GoodTurn

SDPO: KL divergence regularization causes model collapse (degenerate output) despite anchor fix


SDPO training with KL regularization causes complete model collapse: the model emits degenerate repetition such as 'la la la la...'. This happens even with the KL snapshot anchor fix, which reduced typical per-sample KL from 1e6-1e8 to 10-25. Training appeared to converge (loss decreased from step to step), but the model's outputs were destroyed.

1 solution
✓ ACCEPTED

Root cause: KL outlier samples (individual KL up to 404K) dominate the loss once they are multiplied by kl_reg. With kl_reg=0.1, a single outlier contributes about 40K to the loss, while the distillation signal from all 947 samples totals 0.4. The result is a gradient explosion (grad_norm=27K in step 2) that irreversibly destroys the LoRA weights. Steps 3-4 show lower grad norms, but the damage is already done.
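To make the imbalance concrete, the arithmetic can be checked in a few lines. This is only a sketch using the figures quoted above. It assumes the loss is a plain sum of the distillation term and kl_reg times the per-sample KL, and the variable names are illustrative rather than taken from the training code.

```python
# Figures quoted in the post; variable names are illustrative.
kl_reg = 0.1
outlier_kl = 404_000      # largest single-sample KL reported
distill_total = 0.4       # distillation signal summed over all 947 samples

# Assumption: the KL penalty enters the loss as kl_reg * KL, summed over samples.
outlier_term = kl_reg * outlier_kl

# How many times larger the one outlier's penalty is than the useful signal.
ratio = outlier_term / distill_total

print(f"outlier KL term: {outlier_term:,.0f}")
print(f"outlier term / distillation signal: {ratio:,.0f}x")
```

Under that assumption, one sample outweighs the whole batch's distillation signal by about five orders of magnitude, so the gradient at that step is almost entirely determined by the outlier.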

Fixes (all three needed):

  1. Gradient clipping: set max_grad_norm=1.0 (standard for transformer fine-tuning).
  2. Per-sample KL cap: clamp each sample's KL at ~100 before computing the loss. Outliers above the cap contribute noise, not signal.
  3. Lower learning rate (1e-5 or 5e-6) when using KL reg, or drop KL reg entirely (the baseline with kl_reg=0.0 scored 0.55 combined vs. 0.13 with kl_reg=0.1).
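Fixes 1 and 2 can be sketched without any framework, as below. This is a minimal illustration, not the actual training code: the function names and the flat-list representation of gradients are hypothetical, and the KL penalty is assumed to be summed over samples.

```python
import math

KL_CAP = 100.0        # per-sample cap from fix 2
MAX_GRAD_NORM = 1.0   # clipping threshold from fix 1


def capped_kl_penalty(per_sample_kl, kl_reg, cap=KL_CAP):
    """Clamp each sample's KL at `cap`, then sum and scale by kl_reg."""
    return kl_reg * sum(min(kl, cap) for kl in per_sample_kl)


def clip_by_global_norm(grads, max_norm=MAX_GRAD_NORM):
    """Rescale a flat list of gradient values so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


# Two typical samples (KL in the 10-25 range) and one 404K outlier.
kls = [12.0, 25.0, 404_000.0]
uncapped = 0.1 * sum(kls)                    # dominated by the outlier
capped = capped_kl_penalty(kls, kl_reg=0.1)  # outlier counted as 100

# A gradient vector with norm 5 is scaled down to norm 1.
clipped = clip_by_global_norm([3.0, 4.0])
```

In PyTorch, the equivalent calls would typically be torch.clamp(per_sample_kl, max=100.0) before reducing the KL term, and torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) between backward() and the optimizer step. A hard clamp also gives zero gradient to samples above the cap, so they stop steering the update.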

The snapshot anchor fix was necessary but insufficient: it handles the common case, but outlier samples still need explicit capping.