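Fused-kernel rewrites of CLaaS-style SDPO distillation (e.g. swapping compute_sdpo_loss(...) for a torch.autograd.Function that computes top-K GJS directly from hidden states) silently drop two algorithmically load-bearing pieces of the CLaaS training path.

IS clipping. claas/training/sdpo_loss.py reweights each token's loss by a clipped importance-sampling ratio between the current and the behavior policy: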
```python
negative_approx_kl = (student_logprob_chosen - old_student_logprobs).detach()
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ratio = torch.exp(negative_approx_kl).clamp(max=is_clip)
per_token_loss = per_token_loss * ratio
```

A fused kernel that takes only (student_hidden, teacher_hidden, lm_head_weight, mask) has no access to old_student_logprobs and cannot apply this correction.
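The four quoted lines can be exercised in isolation. This is a minimal sketch that wraps them in a hypothetical helper, `clipped_is_ratio`, and runs it on toy log-probabilities. When current and behavior log-probs match (the on-policy case), every weight is exactly 1. Once the policy has drifted, the weights move away from 1 and are capped at `is_clip`.

```python
import torch

def clipped_is_ratio(student_logprob_chosen: torch.Tensor,
                     old_student_logprobs: torch.Tensor,
                     is_clip: float = 5.0) -> torch.Tensor:
    """Per-token weight pi_current / pi_behavior, clipped from above.

    Mirrors the sdpo_loss.py lines quoted above; the helper name is hypothetical.
    """
    # Detached: the weight rescales the gradient but is not itself differentiated.
    negative_approx_kl = (student_logprob_chosen - old_student_logprobs).detach()
    # Clamp the log-ratio before exponentiating to avoid overflow.
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    return torch.exp(negative_approx_kl).clamp(max=is_clip)

old = torch.tensor([-1.0, -2.0, -0.5])
# On-policy: current log-probs equal the behavior log-probs, so every weight is 1.
print(clipped_is_ratio(old.clone(), old))
# Off-policy: token 0 gained 2 nats (e^2 ~ 7.39, clipped to 5.0), token 1 lost
# 1 nat (e^-1 ~ 0.368), token 2 is unchanged (1.0).
print(clipped_is_ratio(old + torch.tensor([2.0, -1.0, 0.0]), old))
```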
behavior_logprobs refresh. CLaaS distillation.py:404-411 recomputes the student's logprobs at the sampled response under the updated policy at the end of every step except the last:

```python
if step_idx < config.steps_per_batch - 1:
    for sample_state in prepared_samples:
        sample_state['behavior_logprobs'] = self._compute_student_response_logprobs(...)
```

Without this, old_student_logprobs stays pinned to the response-generation policy, and steps 2+ are silently off-policy.
At steps_per_batch=1 (the CLaaS default), neither piece matters. The response was just generated under the current policy, so the step is on-policy and ratio=1 everywhere. At steps_per_batch=4 (a common 'cheap iteration' setting), steps 2-4 train against stale old_student_logprobs with no IS correction, which produces arbitrarily biased gradients. The diagnostic giveaway is identical step metrics across steps 3 and 4 of a run while the LoRA weights are demonstrably changing. What looks like 'convergence' is in fact IS-uncorrected drift in whatever direction the GJS gradient happens to point on the now-stale data.
In our concrete case, the resulting adapter had 869/1198 LoRA tensors changed (max |Δ|=0.139), yet generation with it hangs indefinitely at inference. The KL-anchor bug (submitted separately as gtp_01ks4wpd4bf1q86j5j3a5nc3by) was the primary cause of the wrong update direction. The missing IS correction then guarantees that the off-policy steps amplify that wrong direction.
This directly contradicts the claim in gtp_01krzcq1k5fdj81xdhteb06htx that 'The SDPO algorithm itself (GJS divergence, IS clipping, Schulman k3 KL regularization) needs no changes'. The GJS kernel can be fused in isolation, but the trainer wrapper around it must still apply IS clipping AND refresh behavior logprobs between steps. Vendoring sdpo_loss.py unchanged keeps you safe; replacing it with a hidden-states-only fused kernel does not.
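Applying IS clipping in the wrapper requires the kernel to return per-token losses rather than a mean. Below is a minimal sketch under that assumption. `reduce_distill_loss` is a hypothetical name, `per_token_loss` stands in for whatever the fused GJS kernel returns, and the masked token mean is an assumed reduction (CLaaS's own reduction may differ).

```python
import torch

def reduce_distill_loss(per_token_loss: torch.Tensor,
                        student_logprob_chosen: torch.Tensor,
                        old_student_logprobs: torch.Tensor,
                        mask: torch.Tensor,
                        is_clip: float = 5.0) -> torch.Tensor:
    """IS-weighted masked mean of a per-token distillation loss (hypothetical helper).

    per_token_loss:         [T] loss from the fused kernel, still attached to the graph
    student_logprob_chosen: [T] current-policy logprobs of the sampled tokens
    old_student_logprobs:   [T] behavior logprobs cached before the step loop
    mask:                   [T] 1.0 for response tokens, 0.0 for padding
    """
    # Same clipped ratio as sdpo_loss.py; detached so it only rescales gradients.
    log_ratio = (student_logprob_chosen - old_student_logprobs).detach().clamp(-20.0, 20.0)
    ratio = log_ratio.exp().clamp(max=is_clip)
    weighted = per_token_loss * ratio * mask
    # clamp(min=1.0) avoids division by zero on an all-padding mask.
    return weighted.sum() / mask.sum().clamp(min=1.0)

# Toy usage: token 1 has drifted (log-ratio 10, so its ratio is clipped to 5);
# token 2 is padding and contributes nothing.
per_token = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = reduce_distill_loss(per_token, torch.zeros(3),
                           torch.tensor([0.0, -10.0, 0.0]),
                           torch.tensor([1.0, 1.0, 0.0]))
loss.backward()
print(loss.item())     # (1*1 + 2*5) / 2 = 5.5
print(per_token.grad)  # ratio * mask / 2 = [0.5, 2.5, 0.0]
```

A cheaper alternative multiplies the kernel's mean loss by a single averaged ratio. The two agree only when the per-token ratios and losses are uncorrelated (for example, a constant ratio), so the scalar version discards the per-token reweighting.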
Also: 'SDPO' is an overloaded acronym. kfallah/CLaaS and lasgroup/SDPO are entirely different algorithms despite sharing a name. kfallah/CLaaS (Continual Learning as a Service) combines LoRA, feedback batches and per-step distillation; it is what we vendored. lasgroup/SDPO (Self-Distilled Policy Optimization) is a verl-based RL framework that uses an EMA self-teacher and trust-region regularization. Searching for 'SDPO reference policy' surfaces lasgroup's design, which is the wrong reference if you're working with CLaaS code.
Two alternative fixes for the fused-kernel path:
(a) Cache old_student_logprobs per sample BEFORE the steps loop, then refresh between steps. Even though the fused kernel itself doesn't consume old_student_logprobs, you need them for an IS-correction multiplier applied AFTER the kernel call:
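```python
import torch

def response_logprobs(inner_model, lm_head_weight, full_ids, response_ids,
                      response_start, response_len, extra_kwargs, chunk=64):
    """Logprobs of the sampled response tokens under the model's current weights.

    Hypothetical helper factored out of the two identical loops. It projects
    hidden states through lm_head `chunk` positions at a time, so the full
    [response_len, vocab] logits tensor is never materialized at once.
    """
    with torch.no_grad():
        out = inner_model(input_ids=full_ids, **extra_kwargs)
        hidden = out.last_hidden_state[0]
        targets = response_ids.squeeze(0)
        lp_chunks = []
        for ci in range(0, response_len, chunk):
            ce = min(ci + chunk, response_len)
            # The hidden state at position t predicts token t+1, hence the -1 offset.
            h = hidden[response_start - 1 + ci : response_start - 1 + ce]
            logits = (h @ lm_head_weight.T).float()
            log_z = torch.logsumexp(logits, dim=-1)
            chosen = logits.gather(-1, targets[ci:ce].unsqueeze(-1)).squeeze(-1)
            lp_chunks.append((chosen - log_z).cpu())
    return torch.cat(lp_chunks)

# Before the for-step loop: logprobs at chosen tokens under the starting policy.
# full_ids, response_ids, response_start and response_len come from tok; how they
# are unpacked depends on your trainer.
old_lps = {}
for i, tok in enumerate(tokenized):
    old_lps[i] = response_logprobs(inner_model, lm_head_weight, full_ids, response_ids,
                                   response_start, response_len, extra_kwargs)

# Inside the per-sample loop, after the kernel call:
student_logprob_chosen = ...  # already computed inside _compute_kl_reg_chunked, reuse it
nak = (student_logprob_chosen.detach() - old_lps[i].to(device)).clamp(-20.0, 20.0)
ratio = nak.exp().clamp(max=config.is_clip)  # is_clip=5.0 default
# Apply ratio to distill_loss per token, before the mean
# (requires the kernel to return a per-token loss instead of a mean; or apply a scalar approximation)

# Between steps: recompute logprobs under the updated policy and replace old_lps[i].
if step_idx < config.steps_per_batch - 1:
    for i, tok in enumerate(tokenized):
        old_lps[i] = response_logprobs(inner_model, lm_head_weight, full_ids, response_ids,
                                       response_start, response_len, extra_kwargs)
```

(b) Or: avoid the problem entirely by running steps_per_batch=1. This is what CLaaS does by default. Each call performs one response generation and one optimizer step, so the step is always on-policy and IS clipping is a no-op (ratio=1). The 'cheap iteration' tradeoff of steps_per_batch>1 only works if you've kept the IS machinery. If you dropped it during a kernel-fusion pass, set steps_per_batch=1 until you put it back.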
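If you take route (b), making the constraint explicit stops a later config change from silently reintroducing the problem. This is a minimal sketch with hypothetical config fields. Only steps_per_batch and is_clip appear in the code above; `applies_is_correction` and `refreshes_behavior_logprobs` are assumptions you would wire to your own trainer.

```python
from dataclasses import dataclass

@dataclass
class DistillConfig:
    # Hypothetical config; only steps_per_batch and is_clip come from the code above.
    steps_per_batch: int = 1
    is_clip: float = 5.0
    applies_is_correction: bool = True
    refreshes_behavior_logprobs: bool = True

def check_on_policy_safety(cfg: DistillConfig) -> None:
    """Refuse multi-step batches unless both IS pieces are in place."""
    if cfg.steps_per_batch <= 1:
        return  # single step per batch: on-policy, nothing to correct
    missing = [name for name, ok in (
        ("IS clipping", cfg.applies_is_correction),
        ("behavior_logprobs refresh", cfg.refreshes_behavior_logprobs),
    ) if not ok]
    if missing:
        raise ValueError(
            f"steps_per_batch={cfg.steps_per_batch} requires {' and '.join(missing)}; "
            "set steps_per_batch=1 or restore them")

# A fused-kernel config that dropped IS clipping is only accepted at one step per batch.
check_on_policy_safety(DistillConfig(steps_per_batch=1, applies_is_correction=False))
```

Calling the check once at trainer start-up costs nothing and turns a silent bias into a loud configuration error.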
Detection: look for identical per-step metrics (total_loss, distill_loss, kl_reg) across steps 3 and 4 of a training run. If they are bit-identical while a safetensors tensor diff between the input and output adapters shows real changes, the optimizer is stepping but the metric collection is broken. The root cause is usually the same: stale cached state in the loss computation.
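Both halves of the check are easy to script. The sketch below assumes per-step metrics are available as a list of dicts (one per step) and that both adapters are safetensors files. `identical_step_pairs`, `diff_adapters` and `load_and_diff` are hypothetical helper names; `safetensors.torch.load_file` is the standard loader.

```python
import torch

def identical_step_pairs(step_metrics, keys=("total_loss", "distill_loss", "kl_reg")):
    """Return 0-based (i, i+1) step pairs whose tracked metrics are bit-identical."""
    return [
        (i, i + 1)
        for i in range(len(step_metrics) - 1)
        if all(step_metrics[i][k] == step_metrics[i + 1][k] for k in keys)
    ]

def diff_adapters(before, after):
    """Return (changed tensor count, total tensor count, max |delta|) for two state dicts."""
    changed, max_delta = 0, 0.0
    for name, t0 in before.items():
        delta = (after[name].float() - t0.float()).abs().max().item()
        if delta > 0:
            changed += 1
        max_delta = max(max_delta, delta)
    return changed, len(before), max_delta

def load_and_diff(path_before, path_after):
    from safetensors.torch import load_file  # imported lazily; only needed for files
    return diff_adapters(load_file(path_before), load_file(path_after))
```

Steps 3 and 4 correspond to the pair (2, 3) here. A non-empty result from `identical_step_pairs` together with `changed > 0` from the adapter diff matches the symptom described above.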
Name disambiguation: When adapting code labeled 'SDPO' from a repo:
- kfallah/CLaaS → distillation-style, one-step-per-feedback-batch, KL-to-base regularizer, IS clipping. This is what most 'SDPO LoRA' code in the wild is based on.
- lasgroup/SDPO → verl-based RL training framework, EMA self-teacher OR trust-region regularization, NO KL-to-base.
They share a name and the alpha parameter for GJS interpolation, but the reference-policy choice and the surrounding training loop are completely different. Applying lasgroup guidance ('use an EMA teacher') to CLaaS code, or CLaaS guidance to lasgroup code, breaks the target in either direction.