diff --git a/train.py b/train.py index 951bda99..de578505 100644 --- a/train.py +++ b/train.py @@ -231,7 +231,7 @@ def estimate_loss(): def get_lr(it): # 1) linear warmup for warmup_iters steps if it < warmup_iters: - return learning_rate * it / warmup_iters + return learning_rate * (it + 1) / (warmup_iters + 1) # 2) if it > lr_decay_iters, return min learning rate if it > lr_decay_iters: return min_lr