ai · level 7

Training

Gradient descent plus billions of tokens equals a world model.

200 XP

Training is the process of adjusting a model's parameters to minimize a loss function. For language models, the loss is next-token prediction error averaged over billions of examples. The optimizer that does this is almost always a variant of stochastic gradient descent.

Analogy

Imagine a blindfolded hiker standing somewhere on a mountain range, trying to reach the lowest valley. They cannot see the map. All they can do is tap the ground around their feet with a walking stick, feel which direction slopes down the steepest, take a small step that way, and repeat. A million steps later they are in a low basin — not guaranteed the deepest, but a lot lower than where they started. The terrain is the loss landscape; each step is one weight update.

The loss function

Language model pre-training uses cross-entropy loss on next-token prediction:

loss = -log( p(true_next_token | context) )

If the model assigns high probability to the correct next token, the loss is low. If it spreads probability across many wrong tokens, loss is high. The average loss over a training batch drives every parameter update.

Perplexity is exp(average loss): the geometric mean of the inverse probability assigned to each correct token. Intuitively, it is how many tokens the model is "choosing between" on average. A perplexity of 10 means the model behaves as if it were uniformly picking among 10 plausible tokens at each step. Lower is better.
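A minimal sketch in plain Python makes both quantities concrete; the per-token probabilities here are invented for illustration:

    import math

    # Probabilities the model assigned to the correct next token at
    # each position (made-up numbers standing in for model output).
    probs = [0.45, 0.12, 0.80, 0.05, 0.33]

    # Cross-entropy: average negative log-probability per token.
    loss = -sum(math.log(p) for p in probs) / len(probs)

    # Perplexity: exp of the average loss, i.e. the geometric mean
    # of the inverse probabilities.
    perplexity = math.exp(loss)

    print(f"loss = {loss:.3f}, perplexity = {perplexity:.2f}")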

Gradient descent

The gradient of the loss with respect to the parameters points in the direction of steepest increase. Stepping in the opposite direction is the fastest local way to decrease the loss.

θ ← θ − η ∇_θ L

η (eta) is the learning rate — how large a step to take. Too large: training diverges. Too small: training is slow or gets stuck.

Stochastic gradient descent uses a random mini-batch, not the whole dataset, for each step. The gradient estimate is noisy, but each step is cheap. In large-scale pre-training, a single batch typically holds hundreds of thousands to millions of tokens.
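To make the update rule concrete, here is a toy SGD loop in plain Python that fits a single weight on a made-up 1-D dataset; the data and learning rate are illustrative only:

    import random

    # Toy task: learn w such that w * x ≈ y, where the true w is 3.
    data = [(x, 3.0 * x) for x in range(1, 11)]
    w = 0.0
    lr = 0.001                          # learning rate η

    for step in range(5000):
        x, y = random.choice(data)      # "mini-batch" of size 1
        grad = 2 * (w * x - y) * x      # dL/dw for L = (w*x - y)^2
        w = w - lr * grad               # θ ← θ − η ∇_θ L
    print(w)                            # ≈ 3.0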

The Adam optimizer

Raw SGD is rarely used for transformers. Adam (Adaptive Moment Estimation) maintains per-parameter running estimates of the gradient's first moment (its mean) and second moment (its mean square). Parameters with large or noisy gradients get smaller effective learning rates; parameters with small, consistent gradients get relatively larger steps.

This makes Adam robust to the wildly different gradient magnitudes inside a transformer (attention weights vs. embedding table vs. FFN bias). AdamW adds decoupled weight decay, a penalty that shrinks every parameter toward zero at each step, which improves generalization.
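A sketch of one AdamW update for a flat list of parameters, following the standard formulation; the defaults shown are common choices, not tied to any particular model:

    def adamw_step(params, grads, m, v, t,
                   lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
        """One update. m and v are running moment estimates; t starts at 1."""
        for i, g in enumerate(grads):
            m[i] = beta1 * m[i] + (1 - beta1) * g      # first moment (mean)
            v[i] = beta2 * v[i] + (1 - beta2) * g * g  # second moment (mean square)
            m_hat = m[i] / (1 - beta1 ** t)            # bias correction for the
            v_hat = v[i] / (1 - beta2 ** t)            # zero-initialized moments
            # Adaptive step: a large second moment shrinks the step.
            params[i] -= lr * m_hat / (v_hat ** 0.5 + eps)
            params[i] -= lr * wd * params[i]           # decoupled weight decay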

Learning rate schedules

A fixed learning rate is rarely optimal. Standard practice:

Warmup — start with a tiny learning rate, linearly ramp to the target over a few thousand steps. Prevents large early gradients from corrupting random initialization.

Cosine decay — after warmup, decrease the learning rate along a cosine curve down to near zero. Most tokens are processed while the rate is decaying, and the shrinking steps let the model settle into a good minimum.
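Both phases fit in a few lines. The peak rate, warmup length, and total steps below are placeholder values, not recommendations for any specific model:

    import math

    def lr_at(step, peak=3e-4, warmup=2000, total=100_000, floor=3e-5):
        if step < warmup:                              # linear warmup
            return peak * step / warmup
        progress = (step - warmup) / (total - warmup)  # 0 → 1 after warmup
        return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * progress))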

Epochs, steps, and tokens

Step — one forward pass plus one backward pass on one mini-batch.

Epoch — one full pass over the training dataset.

Language models rarely complete even one epoch. GPT-3 trained on ~300 billion tokens; each token appeared roughly once. At that scale, you don't overfit by repeating data — you run out of compute before you run out of data.

Token budget: for pre-training, Chinchilla scaling laws suggest using ~20 tokens per parameter. A 7B-parameter model should see ~140B tokens. Using fewer tokens undertrains the model; using far more produces diminishing returns.
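The arithmetic is worth running once. The 4M-token batch below is an assumption, included only to show how a token budget converts into optimizer steps:

    params = 7e9                    # 7B parameters
    tokens = 20 * params            # Chinchilla budget ≈ 1.4e11, i.e. ~140B
    batch_tokens = 4e6              # assumed tokens per batch
    steps = tokens / batch_tokens   # ≈ 35,000 optimizer steps
    print(f"{tokens:.1e} tokens, {steps:,.0f} steps")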

Batch size and gradient accumulation

Larger batches give less noisy gradient estimates and support higher learning rates, but they need more GPU memory. Gradient accumulation runs several small micro-batches and accumulates their gradients before applying a single update, simulating a large batch without the memory cost.
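A sketch of the loop using PyTorch-style calls, where loss.backward() accumulates gradients into each parameter's .grad; model, loader, optimizer, and the compute_loss helper are assumed to exist:

    accum_steps = 8                    # 8 micro-batches per effective batch

    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(loader):
        # compute_loss is a hypothetical helper returning a scalar loss.
        # Dividing by accum_steps makes the summed gradients an average.
        loss = compute_loss(model, inputs, targets) / accum_steps
        loss.backward()                # gradients sum into .grad
        if (i + 1) % accum_steps == 0:
            optimizer.step()           # one update per effective batch
            optimizer.zero_grad()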

Why loss curves look the way they do

A healthy training run shows:

  1. Rapid initial descent — the model quickly learns the obvious patterns
  2. Slower but steady decrease — incremental gains on harder examples
  3. A final dip near the end of cosine decay — as the learning rate shrinks, the model settles into its minimum

Spikes indicate instability, often from bad batches or a too-high learning rate. A spike that never recovers means the run has diverged.

Compute budget

Training cost scales as:

FLOPs ≈ 6 × N × D

Where N is the parameter count and D the number of training tokens. A 7B-parameter model trained on 1T tokens needs roughly 4.2 × 10^22 floating-point operations: on the order of a hundred thousand A100-hours (a few times fewer on H100s), a compute bill in the hundreds of thousands of dollars. Frontier-scale models push well into the millions.
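Plugging in the numbers, with ballpark assumptions for H100 throughput and utilization:

    N = 7e9                        # parameters
    D = 1e12                       # training tokens
    flops = 6 * N * D              # ≈ 4.2e22 FLOPs

    h100_peak = 989e12             # dense BF16 FLOP/s (H100 spec sheet)
    mfu = 0.4                      # assumed 40% model-FLOPs utilization
    gpu_hours = flops / (h100_peak * mfu) / 3600
    print(f"{flops:.1e} FLOPs, ~{gpu_hours:,.0f} H100-hours")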