GoodTurn

PyTorch gradient accumulation loop overwrites grad norm metric with last micro-batch value


A custom gradient-accumulation training loop captures grad_norm_val = float(clip_grad_norm_(params, max_norm)) only at each accumulation boundary (every ACCUM_STEPS samples) and at the final remainder flush. Each capture overwrites the previous value, so the logged metric is whatever the LAST optimizer step's grad norm was. When len(dataset) % ACCUM_STEPS != 0, the final flush contains only the leftover samples (fewer than ACCUM_STEPS). Its grad norm is tiny, and a four-decimal log format prints it as 0.0000. Result: every logged step reports grad_norm=0.0000 and training looks dead, while in reality 46 of the 47 accumulation boundaries had healthy gradients. Concrete case: 370 samples with ACCUM_STEPS=8 gives 46 full windows plus 1 window of 2 samples. The 2-sample window's clip norm was the value that got logged, and it read 0.0000 across all 4 epochs. A direct tensor diff of the safetensors adapter files confirmed that 869 of 1198 LoRA tensors changed (max |Δ| = 0.139): training was happening, and the metric was lying.
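The arithmetic of the reported case, and why a single overwritten variable ends up showing 0.0000, can be sketched in plain Python. The per-step norms below are hypothetical placeholders (1.0 for each full window, 3e-5 for the remainder window), not measured values:

```python
# Worked arithmetic for the reported case: 370 samples, ACCUM_STEPS = 8.
N_SAMPLES = 370
ACCUM_STEPS = 8

full_windows, remainder = divmod(N_SAMPLES, ACCUM_STEPS)
# One optimizer step per window; the partial window (if any) comes last.
window_sizes = [ACCUM_STEPS] * full_windows + ([remainder] if remainder else [])
print(full_windows, remainder, len(window_sizes))  # 46 2 47

# Hypothetical per-step grad norms: healthy full windows, tiny remainder.
step_norms = [1.0] * full_windows + [3e-5]

# Buggy pattern: a single variable overwritten at every optimizer step.
grad_norm_val = 0.0
for norm in step_norms:
    grad_norm_val = norm
print(f"grad_norm={grad_norm_val:.4f}")  # grad_norm=0.0000

# Keeping every value instead exposes the healthy steps.
print(f"max={max(step_norms):.4f} mean={sum(step_norms) / len(step_norms):.4f}")  # max=1.0000 mean=0.9787
```

Only the last window reaches the log in the buggy pattern, so one small remainder window hides 46 healthy ones.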

1 solution
✓ ACCEPTED

Two separable fixes:

  1. Record the grad norm at every optimizer step (accumulation boundary) in the logging interval, not just the last one. Log the max and the mean. Max is more diagnostic because it catches gradient spikes that the mean dilutes:
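```python
import torch

# Assumes params, optimizer, samples, metrics, ACCUM_STEPS and max_norm are already defined.
step_grad_norms: list[float] = []  # one entry per optimizer step (accumulation boundary)
micro_batch_count = 0
optimizer.zero_grad()
for i, sample in enumerate(samples):
    ...  # forward pass producing scaled_loss (e.g. loss / ACCUM_STEPS)
    scaled_loss.backward()
    micro_batch_count += 1
    if micro_batch_count >= ACCUM_STEPS:
        gn = torch.nn.utils.clip_grad_norm_(params, max_norm)
        step_grad_norms.append(float(gn))  # record every boundary instead of overwriting
        optimizer.step()
        optimizer.zero_grad()
        micro_batch_count = 0
# remainder: flush the last partial window (fewer than ACCUM_STEPS samples)
if micro_batch_count > 0:
    gn = torch.nn.utils.clip_grad_norm_(params, max_norm)
    step_grad_norms.append(float(gn))
    optimizer.step()
    optimizer.zero_grad()  # don't leak leftover grads into the next epoch

metrics['grad_norm_max'] = max(step_grad_norms) if step_grad_norms else 0.0
metrics['grad_norm_mean'] = sum(step_grad_norms) / len(step_grad_norms) if step_grad_norms else 0.0
metrics['grad_norm_last'] = step_grad_norms[-1] if step_grad_norms else 0.0
metrics['n_optimizer_steps'] = len(step_grad_norms)  # counts clip/step calls, not micro-batches
```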
  2. Verify that training actually happened by diffing the adapter weights, not just the logs. Logs can lie; tensors don't. After training, compare the adapter saved before training with the one saved after:
```python
from safetensors import safe_open
with safe_open('before.safetensors', framework='pt') as bf, safe_open('after.safetensors', framework='pt') as af:
    common = sorted(set(bf.keys()) & set(af.keys()))  # tensors present in both files
    changed, max_delta = 0, 0.0
    for k in common:
        delta = (bf.get_tensor(k).float() - af.get_tensor(k).float()).abs().max().item()
        changed += int(delta > 1e-8)
        max_delta = max(max_delta, delta)
    print(f'{changed}/{len(common)} tensors changed, max |Δ| = {max_delta:.3g}')
```

If this returns 0/N, no training happened: gradient flow is broken (NaN, a detached graph, requires_grad=False, etc.). If a substantial share of the tensors changed with reasonable max deltas (roughly 1e-4 to 1e-1 for a typical LoRA run; the case above showed 869/1198 with max |Δ| = 0.139), training is real even if the logged grad_norm looks suspicious. Note: running sha256sum on the two safetensors files is NOT a reliable check, because metadata or ordering differences make the file hashes differ even when the tensors are bit-identical.
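A stdlib-only sketch can show why file hashes mislead. It hand-writes two files in a simplified safetensors layout: an 8-byte little-endian header length, a JSON header, then raw float32 bytes, with no alignment padding or validation. This is not the safetensors library's writer. The helper functions, the tensor names lora_A and lora_B, and the metadata values are all hypothetical. Both files hold bit-identical tensors but differ in metadata and tensor order:

```python
import hashlib
import json
import os
import struct
import tempfile


def write_safetensors(path, tensors, metadata=None):
    """Hypothetical writer for 1-D float32 tensors in a simplified safetensors layout:
    8-byte little-endian header length, JSON header, then raw tensor bytes."""
    header, data = {}, b""
    for name, values in tensors.items():
        raw = struct.pack(f"<{len(values)}f", *values)
        header[name] = {
            "dtype": "F32",
            "shape": [len(values)],
            "data_offsets": [len(data), len(data) + len(raw)],
        }
        data += raw
    if metadata is not None:
        header["__metadata__"] = metadata
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        f.write(data)


def read_tensor_bytes(path):
    """Return {name: raw bytes} for every tensor, ignoring metadata and storage order."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
        data = f.read()
    header.pop("__metadata__", None)
    return {
        name: data[info["data_offsets"][0]:info["data_offsets"][1]]
        for name, info in header.items()
    }


def sha256_file(path):
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()


with tempfile.TemporaryDirectory() as tmp:
    a_path = os.path.join(tmp, "a.safetensors")
    b_path = os.path.join(tmp, "b.safetensors")
    # Same tensor values, but different metadata and tensor order.
    write_safetensors(a_path, {"lora_A": [0.1, 0.2], "lora_B": [0.0, 0.5]}, {"note": "run-1"})
    write_safetensors(b_path, {"lora_B": [0.0, 0.5], "lora_A": [0.1, 0.2]}, {"note": "run-2"})
    same_hash = sha256_file(a_path) == sha256_file(b_path)
    same_tensors = read_tensor_bytes(a_path) == read_tensor_bytes(b_path)

print(f"file hashes equal: {same_hash}")  # file hashes equal: False
print(f"tensor bytes equal: {same_tensors}")  # tensor bytes equal: True
```

The file hashes differ while every tensor's bytes match, which is why the check in fix 2 compares tensors by name rather than whole files.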