When a full SDPO loss function is replaced with a memory-efficient fused JSD kernel for distillation training, the importance sampling (IS) ratio correction for off-policy steps is silently dropped. The fused kernel computes only the JSD and returns a scalar loss. With steps_per_batch > 1, steps 2 and later therefore train on stale (off-policy) data, because the student has already been updated since the data was produced, and no IS correction is applied. The kernel's row_weights parameter, designed for per-token masking, can serve as the IS correction vector: compute the per-token ratio exp(current student log-prob - cached student log-prob), clamp it to the IS clip threshold (is_clip), and pass the result as row_weights.
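A minimal NumPy sketch of what the correction computes is shown below. It is not the fused kernel: `weighted_jsd_loss` is a hypothetical unfused reference that materializes full probabilities, uses the symmetric JSD (mixture weight 0.5), and assumes the kernel reduces with a mean over rows. `is_row_weights` is also hypothetical and clips the ratio from above only, which is one reading of "clamped to the IS clip threshold".

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last (vocab) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def is_row_weights(current_logp, cached_logp, is_clip):
    """Per-token IS ratio pi_current / pi_behavior = exp(current - cached).

    Clipped from above at is_clip (upper-only clipping is an assumption).
    """
    return np.minimum(np.exp(current_logp - cached_logp), is_clip)


def weighted_jsd_loss(student_logits, teacher_logits, row_weights):
    """Hypothetical unfused reference: per-row symmetric JSD scaled by row_weights.

    Assumes a mean over rows as the final reduction; the real kernel may differ.
    """
    log_p = log_softmax(teacher_logits)   # teacher distribution P
    log_q = log_softmax(student_logits)   # student distribution Q
    p, q = np.exp(log_p), np.exp(log_q)
    log_m = np.log(0.5 * p + 0.5 * q)     # mixture M = (P + Q) / 2
    kl_p_m = (p * (log_p - log_m)).sum(axis=-1)
    kl_q_m = (q * (log_q - log_m)).sum(axis=-1)
    per_row = 0.5 * kl_p_m + 0.5 * kl_q_m  # shape: (rows,)
    return float((row_weights * per_row).mean())


rng = np.random.default_rng(0)
student = rng.normal(size=(4, 10))
teacher = rng.normal(size=(4, 10))
cached = np.array([-1.0, -2.0, -0.5, -3.0])   # behavior log-probs at chosen tokens
current = np.array([-1.0, -1.0, -0.5, -3.5])  # current student log-probs
w = is_row_weights(current, cached, is_clip=2.0)  # w -> [1.0, 2.0, 1.0, ~0.607]
loss = weighted_jsd_loss(student, teacher, w)
```

Because each row's loss is multiplied by a weight that is treated as a constant, that row's gradient is scaled by the same factor, which is why passing the ratio as row_weights corrects both the loss value and its gradient.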
Before training, pre-compute the behavior log-probs by running the student model under no_grad and passing its hidden states through the LM head in chunks, keeping only the log-prob of each chosen token. Chunking avoids materializing the full seq_len × vocab logits tensor. Cache the result on CPU. If snapshot log-probs already exist for KL regularization, reuse them instead of running a separate pass.
On steps 2 and later, compute the current student log-probs from the hidden states with the same chunked pattern, compute the IS ratio as exp(current - cached) clamped to is_clip, and pass it as row_weights to the fused kernel. Step 1 is on-policy (the ratio is exactly 1.0), so the default all-ones row_weights suffice. Because the kernel multiplies row_weights into both the per-row loss and the per-row gradient scaling, the IS correction is applied consistently in the forward and backward passes.
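The pre-computation and per-step schedule can be sketched as follows, again in NumPy with hypothetical names (`chunked_token_logprobs`, `run_steps`) and a toy setup: a bias-free LM head, hidden states held fixed across steps, and random perturbations standing in for optimizer updates. The commented-out kernel call is a placeholder, not a real API.

```python
import numpy as np


def chunked_token_logprobs(hidden, lm_head_weight, token_ids, chunk_size=128):
    """Log-prob of each chosen token, computed one row-chunk at a time.

    hidden:         (seq_len, d_model) final hidden states
    lm_head_weight: (vocab, d_model) output projection, assumed bias-free
    token_ids:      (seq_len,) chosen token ids
    Only a (chunk_size, vocab) logits block exists at once, never (seq_len, vocab).
    """
    seq_len = hidden.shape[0]
    out = np.empty(seq_len, dtype=np.float64)
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        logits = hidden[start:end] @ lm_head_weight.T          # (chunk, vocab)
        logits = logits - logits.max(axis=-1, keepdims=True)    # numerical stability
        log_z = np.log(np.exp(logits).sum(axis=-1))             # per-row log-partition
        chosen = logits[np.arange(end - start), token_ids[start:end]]
        out[start:end] = chosen - log_z                         # log-softmax at chosen token
    return out


def run_steps(hidden, tokens, w_behavior, steps_per_batch=3, is_clip=2.0, seed=0):
    """Return the row_weights that would be passed to the fused kernel at each step."""
    rng = np.random.default_rng(seed)
    # Behavior log-probs, computed once before training and cached.
    # (NumPy has no autograd; in PyTorch this pass would run under torch.no_grad()
    # and the result would be moved to CPU.)
    cached_logp = chunked_token_logprobs(hidden, w_behavior, tokens)
    w_student = w_behavior.copy()
    history = []
    for step in range(steps_per_batch):
        if step == 0:
            row_weights = np.ones_like(cached_logp)  # on-policy: ratio is exactly 1
        else:
            current_logp = chunked_token_logprobs(hidden, w_student, tokens)
            row_weights = np.minimum(np.exp(current_logp - cached_logp), is_clip)
        history.append(row_weights)
        # Hypothetical: loss = fused_jsd(..., row_weights=row_weights); backward; step.
        # Stand-in for an optimizer update. Hidden states are held fixed for brevity;
        # in real training they come from the current student's forward pass.
        w_student = w_student + 0.05 * rng.normal(size=w_student.shape)
    return history


rng = np.random.default_rng(0)
hidden = rng.normal(size=(300, 16))
tokens = rng.integers(0, 50, size=300)
w_behavior = 0.1 * rng.normal(size=(50, 16))
weights = run_steps(hidden, tokens, w_behavior)
```

Skipping the current-log-prob pass on the first step is a small saving; the main memory saving comes from never holding more than one chunk of logits at a time.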