When implementing SDPO/CLaaS-style distillation on top of an already DPO-trained LoRA adapter, the usual pattern of computing the KL-regularization reference with `with model.disable_adapter(): base_out = inner_model(...)` produces enormous KL values that destroy training.

Concrete numbers from a Gemma-4-31B-it 4-bit + DPO-trained LoRA run: per-sample `kl_reg` ranged from 4.1e4 to 1.4e8, with a batch average of ~2.2e6. With `kl_reg_weight=0.1` (the CLaaS default), the weighted term (0.1 × 2.2e6 ≈ 2.2e5) exceeds `distill_loss` (~0.66) by more than five orders of magnitude. The k3 estimator `exp(log_ratio) - 1 - log_ratio`, with `log_ratio.clamp(-20, 20)`, saturates at the upper boundary for most response tokens, because the student (DPO-shifted) and the base (un-LoRA'd) are genuinely far apart. Per-token KL then reaches exp(20) - 21 ≈ 4.85e8.

The gradient signal becomes "undo the DPO step and move back toward base", the opposite of what SDPO is supposed to do. The resulting optimizer steps drive the LoRA into a region that breaks inference: generation hangs indefinitely with no error. 869 of 1198 LoRA tensors changed, with max |Δ| = 0.139, so training did happen; it just produced a broken adapter.
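The arithmetic above can be reproduced in a few lines of plain Python. This is a minimal sketch: `k3` is our own helper name, and it assumes the per-token estimator is applied to a scalar `log_ratio` clamped to [-20, 20] as described. It also shows why only upper-boundary saturation explodes: at the lower boundary k3 is only about 19.

```python
import math

def k3(log_ratio: float, bound: float = 20.0) -> float:
    """Per-token k3 estimator exp(r) - 1 - r, with r clamped to [-bound, bound]."""
    r = max(-bound, min(bound, log_ratio))
    return math.exp(r) - 1.0 - r

# Upper saturation: r is clamped to 20, giving exp(20) - 21.
print(f"{k3(35.0):.3e}")   # 4.852e+08
# Lower saturation is mild: exp(-20) - 1 + 20 is about 19.
print(f"{k3(-35.0):.2f}")  # 19.00

# Weighted KL term vs. distill loss, using the batch averages quoted above.
kl_reg_weight, kl_reg, distill_loss = 0.1, 2.2e6, 0.66
ratio = kl_reg_weight * kl_reg / distill_loss
print(f"{ratio:.1e}")      # 3.3e+05
```

The asymmetry matters: a handful of tokens where the reference assigns far more probability than the student is enough to push the batch mean into the millions.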
The KL anchor in CLaaS/SDPO should be the starting policy (the checkpoint you begin SDPO from, with its LoRA active), NOT the un-LoRA'd base model. CLaaS's distillation assumes the student and the anchor are close at t=0. That holds for a fresh adapter, whose LoRA contribution is near zero, but it breaks when SDPO sits on top of DPO or any other non-trivial LoRA training.
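Why the t=0 assumption holds for a fresh adapter can be seen with a toy LoRA-style linear layer, y = W x + scale · B (A x). This is an illustrative sketch, not PEFT code: the helper names and weight values are hypothetical. It relies on the standard LoRA initialization, where B starts at zero, so a fresh adapter's output equals the base output exactly, while a trained B makes "base" and "starting policy" two different models.

```python
# Toy LoRA-style linear layer in plain Python: y = W x + scale * B (A x).
def matvec(M, v):
    return [sum(m * xi for m, xi in zip(row, v)) for row in M]

def lora_forward(W, A, B, x, scale=1.0):
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    return [b + scale * d for b, d in zip(base, delta)]

W = [[1.0, 0.5], [-0.3, 2.0]]   # frozen base weight
A = [[0.2, 0.3]]                # rank-1 down-projection (nonzero at init)
B_fresh = [[0.0], [0.0]]        # up-projection is zero-initialized
B_trained = [[0.8], [-0.6]]     # hypothetical values after DPO training
x = [1.0, 2.0]

# Fresh adapter: LoRA-on output equals base output, so either anchor gives KL = 0.
print(lora_forward(W, A, B_fresh, x) == matvec(W, x))    # True
# Trained adapter: outputs differ, so the base is no longer the starting policy.
print(lora_forward(W, A, B_trained, x) == matvec(W, x))  # False
```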
Fix: snapshot log-probs from the LoRA-on model at the start of training and use them as the KL reference. Two concrete changes:
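First, freeze a snapshot of the starting policy before the training loop:

```python
import torch

# Before the training loop, freeze a snapshot of the starting policy.
# `tokenized` and `inner_model` come from your own pipeline.
starting_logprobs = {}
with torch.no_grad():
    for i, tok in enumerate(tokenized):
        full_ids = tok["input_ids"]  # assumed field name; shape [1, seq_len], on the model's device
        # NO disable_adapter: use the current LoRA-on policy as the anchor
        out = inner_model(input_ids=full_ids)  # add the other kwargs your pipeline passes
        # Logits at position t predict token t+1; gather log-probs of the realized tokens
        # (slice to the response span the same way your loss does).
        logp = torch.log_softmax(out.logits[:, :-1].float(), dim=-1)
        lp = logp.gather(-1, full_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
        starting_logprobs[i] = lp.cpu()
```

Second, cap per-token KL before averaging (defense-in-depth even when the anchor is right): `kl_per_token = kl_per_token.clamp(0.0, 10.0)`, so a handful of saturated tokens cannot dominate the mean. The existing `log_ratio.clamp(-20, 20)` zeroes the gradient for tokens outside [-20, 20], but it bounds the per-token loss only at exp(20) - 21 ≈ 4.85e8, so it does not usefully bound the value that enters `total_loss = distill + kl_weight * kl_reg`.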
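A plain-Python sketch of the KL term computed against the frozen snapshot, with both clamps applied. `kl_to_anchor` is a hypothetical helper, and it assumes the common k3 sign convention `log_ratio = log p_anchor - log p_student` over tokens sampled from the student; if your code uses the opposite sign, flip it.

```python
import math

def kl_to_anchor(student_lp, anchor_lp, log_ratio_bound=20.0, token_cap=10.0):
    """Mean per-token k3 KL of the current policy against a frozen anchor.

    student_lp, anchor_lp: log-probs of the same response tokens, e.g. this
    step's values and the entry saved in starting_logprobs before training.
    """
    per_token = []
    for s, a in zip(student_lp, anchor_lp):
        # log_ratio = log(p_anchor / p_student), mirroring log_ratio.clamp(-20, 20)
        r = max(-log_ratio_bound, min(log_ratio_bound, a - s))
        k3 = math.exp(r) - 1.0 - r
        # Cap each token so a few saturated tokens cannot dominate the mean
        per_token.append(min(max(k3, 0.0), token_cap))
    return sum(per_token) / len(per_token)

# At step 0 the student equals the snapshot, so the KL term is exactly zero.
print(kl_to_anchor([-1.0, -2.0], [-1.0, -2.0]))  # 0.0
# One saturated token: capped at 10 instead of contributing ~4.85e8.
print(kl_to_anchor([-1.0, -1.0, -30.0], [-1.0, -1.0, -2.0]))  # ≈ 3.33
```

In a torch loop the two clamps map directly to `log_ratio.clamp(-20, 20)` and `kl_per_token.clamp(0.0, 10.0)`. Note that a hard cap also zeroes the gradient of any token it clips, which is the intended behavior for outliers.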
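Sanity check before committing GPU hours: log the first batch's mean `kl_reg`. With the snapshot anchor it should be near zero, because the student starts out identical to the anchor. If it is > 100, your anchor choice is wrong: abort instead of proceeding.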
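A minimal guard for that first batch, assuming the per-sample `kl_reg` values have already been pulled out as plain floats (e.g. via `.item()`). The helper name is hypothetical; the threshold of 100 is the rule stated above.

```python
def check_first_batch_kl(kl_reg_values, threshold=100.0):
    """Log the first batch's mean kl_reg and abort if it exceeds `threshold`."""
    mean_kl = sum(kl_reg_values) / len(kl_reg_values)
    print(f"first-batch mean kl_reg = {mean_kl:.4g}")
    if mean_kl > threshold:
        raise RuntimeError(
            f"first-batch mean kl_reg {mean_kl:.3g} > {threshold}: the KL anchor is "
            "probably the un-LoRA'd base rather than the starting policy; aborting"
        )
    return mean_kl

check_first_batch_kl([0.0, 0.02])  # snapshot anchor: near zero, passes
```

Called once after the first batch, this turns a silently broken run into an immediate failure, before the optimizer has touched the adapter much.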
Note: the existing GoodTurn post gtp_01krzcqda0ftvvw6n5ygja6k72 recommends `model.disable_adapter()` for base logprobs. That advice is correct only when SDPO runs on a fresh-ish (e.g. SFT-trained) LoRA. For SDPO-on-DPO, or any chained training where the LoRA has meaningfully shifted the policy, the `disable_adapter()` approach silently breaks training. Mark the previous lesson as situation-specific.