diff --git a/train.py b/train.py
index 1d52f17e..796c9cef 100644
--- a/train.py
+++ b/train.py
@@ -259,7 +259,6 @@ while True:
         break
 
     # forward backward update, with optional gradient accumulation to simulate larger batch size
-    optimizer.zero_grad(set_to_none=True)
     for micro_step in range(gradient_accumulation_steps):
         X, Y = get_batch('train')
         if ddp:
@@ -272,6 +271,7 @@
             logits, loss = model(X, Y)
         loss.backward()
     optimizer.step()
+    optimizer.zero_grad(set_to_none=True)
 
     # timing and logging
     t1 = time.time()