GoodTurn

SDPO Python: Style Auxiliary Loss Fails to Prevent Batch Style Drift During Distillation


A per-sample style auxiliary loss (writeprints distance to the baseline feature means) failed to constrain batch-level drift of the style distribution during SDPO distillation. With mmd_aux_weight=0.05, the weighted aux term was ~0.024 while distill_loss ranged from ~0.04 to ~0.16. At the upper end of that range the aux term was only about 15% the size of distill_loss (roughly 13% of the total loss), so the style signal was easily outweighed. The model trained with the aux loss ended up with a worse batch MMD (0.588) than a model trained without any aux loss (0.443), even though individual samples were penalized for style deviation. This is possible because a text can score close to the baseline feature means while the batch distribution still collapses into a narrow mode that is far from the corpus distribution.
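The failure mode in the last sentence can be reproduced with synthetic data. This is a minimal sketch, not the original training code: it assumes Gaussian stand-ins for writeprints feature vectors (an 8-dimensional width chosen only for illustration), uses Euclidean distance to the baseline mean as the per-sample signal, and uses a biased Gaussian-RBF MMD² estimate as the batch-level signal. The kernel and its bandwidth are choices made here, not details taken from the post.

```python
import numpy as np


def centroid_distance(features: np.ndarray, baseline_mean: np.ndarray) -> np.ndarray:
    """Per-sample Euclidean distance to the baseline feature mean (the aux signal that failed)."""
    return np.linalg.norm(features - baseline_mean, axis=1)


def rbf_mmd2_biased(x: np.ndarray, y: np.ndarray, sigma: float) -> float:
    """Biased squared MMD between samples x (n, d) and y (m, d) under a Gaussian RBF kernel."""

    def kernel(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
        return np.exp(-sq_dists / (2.0 * sigma**2))

    return float(kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean())


rng = np.random.default_rng(0)
dim = 8  # hypothetical feature width, for illustration only
corpus = rng.normal(0.0, 1.0, size=(512, dim))  # stand-in for corpus style features
baseline_mean = corpus.mean(axis=0)

# Healthy batch: drawn from the same distribution as the corpus.
healthy = rng.normal(0.0, 1.0, size=(64, dim))
# Collapsed batch: every sample sits in a tight cluster at the baseline mean.
collapsed = baseline_mean + rng.normal(0.0, 0.05, size=(64, dim))

# Assumed bandwidth: 2 * sigma^2 equals the expected squared distance
# between two independent corpus samples (2 * dim for unit-variance features).
sigma = float(np.sqrt(dim))
healthy_centroid = centroid_distance(healthy, baseline_mean).mean()
collapsed_centroid = centroid_distance(collapsed, baseline_mean).mean()
healthy_mmd = rbf_mmd2_biased(healthy, corpus, sigma)
collapsed_mmd = rbf_mmd2_biased(collapsed, corpus, sigma)

print(f"mean distance to baseline mean: healthy={healthy_centroid:.3f} collapsed={collapsed_centroid:.3f}")
print(f"MMD^2 vs corpus:                healthy={healthy_mmd:.4f} collapsed={collapsed_mmd:.4f}")
```

The collapsed batch scores far better on per-sample centroid distance and far worse on MMD, mirroring the reported pattern: good per-sample style scores alongside a poor batch-level match to the corpus.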

1 solution
✓ ACCEPTED

Per-sample distance-to-mean is the wrong geometry for constraining distributional properties. It penalizes deviation from the centroid but does not prevent mode collapse: every output can cluster near the mean while the batch loses the variance structure of the corpus.

Two potential fixes:

(1) Increase the weight substantially (to the 0.2-0.5 range) so the style signal is not overwhelmed by the distillation loss, at the risk of degrading generation quality.

(2) Replace the per-sample aux loss with a batch-level MMD loss computed directly on writeprints features, which penalizes distributional mismatch rather than distance to the centroid.

Option 2 is the architecturally correct fix, but it requires accumulating features across a batch before computing the loss, which conflicts with the single-sample gradient accumulation pattern used in most SDPO implementations.
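Option 2 and the accumulation conflict can be sketched together. This is a hypothetical NumPy sketch, not an SDPO or library API: `mmd2_unbiased` computes an unbiased MMD² using a sum of RBF kernels at a few assumed bandwidths, and `StyleFeatureBuffer` is an invented helper that keeps a sliding window of per-sample feature vectors so a batch-level MMD can be evaluated while the trainer still processes one sample per step. NumPy has no autograd, so this computes only the loss value. In a real trainer the same arithmetic would be written in the framework's tensor ops, and one possible workaround (an assumption here, not something the post confirms) is to detach the buffered features so only the current sample receives gradient.

```python
from collections import deque
from typing import Deque, Optional, Sequence

import numpy as np


def mmd2_unbiased(x: np.ndarray, y: np.ndarray, sigmas: Sequence[float] = (1.0, 2.0, 4.0)) -> float:
    """Unbiased squared MMD between samples x (n, d) and y (m, d).

    The kernel is a sum of Gaussian RBF kernels at several bandwidths, which
    makes the estimate less sensitive to a single hand-picked sigma.
    """
    n, m = len(x), len(y)
    if n < 2 or m < 2:
        raise ValueError("need at least two samples on each side")

    def kernel(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
        return sum(np.exp(-sq_dists / (2.0 * s**2)) for s in sigmas)

    kxx, kyy, kxy = kernel(x, x), kernel(y, y), kernel(x, y)
    # Exclude the diagonal (each sample compared with itself) for the unbiased estimate.
    term_xx = (kxx.sum() - np.trace(kxx)) / (n * (n - 1))
    term_yy = (kyy.sum() - np.trace(kyy)) / (m * (m - 1))
    return float(term_xx + term_yy - 2.0 * kxy.mean())


class StyleFeatureBuffer:
    """Hypothetical sliding window of per-sample style feature vectors.

    Lets a batch-level MMD be evaluated even when the trainer handles one
    generated sample per step (single-sample gradient accumulation).
    """

    def __init__(self, reference: np.ndarray, window: int = 32) -> None:
        self.reference = reference  # corpus style features, shape (m, d)
        self.window = window
        self.buffer: Deque[np.ndarray] = deque(maxlen=window)  # oldest entries drop out

    def push(self, features: np.ndarray) -> Optional[float]:
        """Add one sample's features; return MMD^2 once the window is full, else None."""
        self.buffer.append(np.asarray(features, dtype=float))
        if len(self.buffer) < self.window:
            return None
        return mmd2_unbiased(np.stack(list(self.buffer)), self.reference)


# Usage: compare a stream of on-distribution samples with a collapsed stream.
rng = np.random.default_rng(0)
dim = 8  # hypothetical feature width
reference = rng.normal(size=(256, dim))  # stand-in for corpus writeprints features
center = reference.mean(axis=0)

healthy_buf = StyleFeatureBuffer(reference, window=32)
collapsed_buf = StyleFeatureBuffer(reference, window=32)
healthy_mmd: Optional[float] = None
collapsed_mmd: Optional[float] = None
for _ in range(32):  # one generated sample per step
    healthy_mmd = healthy_buf.push(rng.normal(size=dim))
    collapsed_mmd = collapsed_buf.push(center + rng.normal(scale=0.05, size=dim))

print(f"healthy window MMD^2:   {healthy_mmd:.4f}")
print(f"collapsed window MMD^2: {collapsed_mmd:.4f}")
```

The sliding window trades exactness for compatibility with per-sample stepping: older entries were generated by earlier policy weights, so the estimate lags the current model. A larger window lowers the variance of the MMD estimate but makes it staler.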