Build an LLM
This repository contains an educational training workflow for a transformer-based, autoregressive, decoder-only language model. It is optimized for learning rather than for speed or cost.
Users can:
- Pre-train an LLM from scratch using a simple, intuitive interface, with a diagram that visualizes your specific configuration and gives visibility into every token and every training row, as well as attention heatmaps.
- Fine-tune a pre-trained model on prompt/response pairs using supervised fine-tuning (SFT), with support for both full-parameter and parameter-efficient fine-tuning via LoRA.
- Explore the code to understand the modularized building blocks of transformer models, with multiple implementation variants for each component. The code shown dynamically adapts to configuration choices.
- Work through equations to understand the math behind the code, with equations dynamically displayed based on configuration.
I built this as I wanted to properly understand LLMs. A great way to learn is to write code yourself; an even better way to learn is to write code in a general, modular manner that's clean enough for others to read.
I'm incredibly grateful to all those from whom I learned and borrowed ideas (see Resources). I hope others find this repository helpful!
Contents
- Usage Guide
- Key Concepts
- Pre-Training Pipeline
- Fine-Tuning Pipeline
- Inference and Sampling
- Core Components Deep Dive
- Resources
Usage Guide
Getting Started
The following command will install uv and run the FastAPI + Next.js app:
./run.sh
The app will open in your browser with the following pages:
- Overview: System details and this README
- Pre-Training Page: Configure and pre-train models with a visual interface
- Fine-Tuning Page: Fine-tune pre-trained models on prompt/response pairs
- Inference Page: Generate text from trained models (pre-trained or fine-tuned)
- Playground
Pre-Training
- Loads training text file
- Creates tokenizer and dataset
- Initializes model based on selected architecture and configuration
- Trains for specified epochs with real-time loss visualization
- Saves checkpoints to timestamped folders: `checkpoints/YYYYMMDDHHMMSS/`
How To
- Upload training data or use the default `training.txt` file
- Choose custom parameters or select a preset:
- GPT-2: Learned positional embeddings, LayerNorm, GELU activation
- LLaMA: RoPE positional encoding, RMSNorm, SwiGLU activation
- OLMo: ALiBi positional encoding, LayerNorm, SwiGLU activation
- DeepSeek V2: LLaMA-style with MoE (64 experts, top-6, 2 shared experts)
- Mixtral: LLaMA-style with MoE (8 experts, top-2 routing). Sparse MoE architecture matching the Mixtral 8x7B design
- Configure model dimensions (or use size presets: small, medium, full)
- Optionally enable MoE (Mixture of Experts) and configure expert settings
- Set training hyperparameters (batch size, learning rate, epochs, etc.)
- Click "Start Training" to enter the Interactive Training Mode:
- Glass Box Training: Watch the model learn step-by-step.
- Colored Tokens: See exactly how the tokenizer splits your input text with color-coded tokens (hover for IDs).
- Real-time Controls: Start/Pause/Resume and Step to Next Batch buttons for manual stepping.
- Attention Heatmaps: Visualize how the model attends to different tokens in the sequence (visible when paused).
- Live Metrics: Monitor Loss, Gradient Norms, and Validation performance in real-time.
Inference
How To
- Select a checkpoint from the dropdown (auto-scans the `checkpoints/` directory)
  - Shows both pre-trained and fine-tuned checkpoints
  - Clearly labeled: "Final Model (Pre-trained)" vs "Final Model (Fine-tuned)"
  - Fine-tuned models are in `checkpoints/{timestamp}/sft/` subdirectories
- Enter a prompt
- Configure sampling parameters (temperature, top-k, top-p)
- Click "Generate" to create text
- Glass Box Internals: Scroll down to inspect the model's internal state during generation:
- Attention Maps: Visualize where the model is looking.
- Logit Lens: See what the model "thinks" the next token is after every layer.
- Layer Norms: Track signal propagation through the network.
Key Concepts
Before diving into the implementation details, here are the key concepts that underpin this codebase:
1. Autoregressive Language Modeling
What it is: Predicting the next token given previous tokens in a sequence.
Example:
Input: "The cat sat on the"
Target: "cat sat on the mat"
At each position, the model predicts what comes next. This is how language models learn to generate coherent text.
Why it works:
- Language has structure and patterns
- Given context, the next token is somewhat predictable
- The model learns these patterns from data
2. Causal Masking
What it is: Preventing the model from seeing future tokens during training.
Why it's needed:
- During inference, we generate one token at a time
- The model should only use past context
- Training must match inference conditions
How it works:
- Attention scores for future positions are set to `-inf`
- After softmax, these become 0 probability
- The model can't attend to future tokens
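A minimal sketch of this in PyTorch (toy `seq_len` and random scores, purely illustrative):
import torch
seq_len = 5
scores = torch.randn(seq_len, seq_len)  # toy attention scores
# Lower-triangular causal mask: position i may only attend to positions j <= i
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = scores.masked_fill(~mask, float("-inf"))  # future positions -> -inf
probs = torch.softmax(scores, dim=-1)              # future positions -> 0 probability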
3. Attention Mechanisms & Heatmaps
What it is: The core mechanism that allows the model to "look back" at previous tokens to inform the current prediction.
The Intuition: Imagine reading a sentence. When you see the word "bank", you need to know if the context is "river" or "money" to understand it. Attention allows the model to look back at words like "river" or "money" earlier in the sentence to disambiguate "bank".
Heatmap Visualization: In the Glass Box view (Pre-Training and Inference), you can see this process happening via Attention Heatmaps:
- X-axis (Key): The tokens the model is looking at (source).
- Y-axis (Query): The token the model is currently generating (destination).
- Color Intensity: More intense colors indicate stronger focus.
- Diagonal Pattern: A strong diagonal line means the model is mostly looking at the immediate previous token (common in early layers).
4. Residual Connections
What it is: Adding input to output: output = input + transformation(input)
Why it helps:
- Allows gradients to flow directly through
- Enables training of very deep networks
- The model can learn the identity function if the transformation isn't needed
5. Layer Normalization
What it is: Normalizing activations across the feature dimension.
Why it helps:
- Stabilizes training
- Allows higher learning rates
- Reduces internal covariate shift
Variants:
- LayerNorm (GPT/OLMo): Subtracts the mean, divides by the standard deviation, then applies a learnable scale and shift
- RMSNorm (LLaMA): Only scales (no mean subtraction, no bias)
6. Positional Encoding
The Problem: Transformers have no inherent notion of sequence order.
Solutions:
- Learned Embeddings (GPT): Fixed embeddings for each position
- RoPE (LLaMA): Rotates query/key vectors by position-dependent angles
- ALiBi (OLMo): Adds distance-based bias to attention scores
7. Pre-training vs Fine-tuning
Pre-training:
- Train on large, diverse text corpus
- Learn general language patterns
- Unsupervised (no labels needed)
- Example: Train on Wikipedia, books, web text
Supervised Fine-Tuning (SFT):
- Take pre-trained model
- Train further on prompt/response pairs
- Supervised (needs labeled data: prompt ā response)
- Lower learning rate (typically 10-100x lower)
- Shorter training (1-5 epochs vs 10+)
- Loss masking: Only compute loss on response tokens, not prompt tokens
- Example: Train on instruction-following datasets
Why Fine-Tune?
- Pre-trained models learn general language but may not follow instructions well
- Fine-tuning teaches the model to respond appropriately to prompts
- Makes the model more useful for specific tasks (Q&A, instruction following, etc.)
Pre-Training Pipeline
1. Data Loading (pretraining/data/dataset.py)
Purpose: Load text, tokenize, and create training sequences.
Process
- Load Text: Read raw text file
- Tokenize: Convert text to token IDs
- Create Sequences: Sliding window approach
  - Input: `[token_0, token_1, ..., token_n-1]`
  - Target: `[token_1, token_2, ..., token_n]` (shifted by 1)
- Split: Train/validation split (default 90/10)
Code Flow
# Load text
text = "Hello world..."
# Tokenize
data = tokenizer.encode_tensor(text) # [total_tokens]
# Create sequences (sliding window)
for i in range(len(data) - block_size):
    X.append(data[i:i+block_size])        # Input
    Y.append(data[i+1:i+block_size+1])    # Target (shifted)
# X: [num_sequences, block_size]
# Y: [num_sequences, block_size]
Why Shift by 1?
- We predict the next token
- At position `i`, we predict the token at position `i+1`
- This is autoregressive language modeling
2. Training Loop (pretraining/training/trainer.py)
Purpose: Train the model using gradient descent.
Key Components
- Optimizer: AdamW (adaptive learning rate with weight decay)
- Loss Function: Cross-entropy (next-token prediction)
- Evaluation: Periodic evaluation on train/val sets
- Checkpointing: Save model periodically
Training Step
# 1. Get random batch
idx = torch.randint(0, len(X_train), (batch_size,))
x_batch = X_train[idx] # [batch_size, seq_len]
y_batch = Y_train[idx] # [batch_size, seq_len]
# 2. Forward pass
logits = model(x_batch) # [batch_size, seq_len, vocab_size]
# 3. Compute loss
# Reshape for cross-entropy
logits_flat = logits.view(-1, vocab_size) # [batch*seq, vocab]
targets_flat = y_batch.view(-1) # [batch*seq]
loss = F.cross_entropy(logits_flat, targets_flat)
# 4. Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
Loss Explanation:
- Cross-entropy measures how well predicted probabilities match true token
- Lower loss = better predictions
- Typical range: starts high (2-4), decreases during training
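As a sanity check, cross-entropy is just the negative log of the probability assigned to the true token. A small, self-contained example (toy logits, not taken from the repository):
import torch
import torch.nn.functional as F
logits = torch.tensor([[2.0, 0.5, 0.1, -1.0]])  # [1 position, vocab of 4]
target = torch.tensor([0])                      # true token id
loss = F.cross_entropy(logits, target)
manual = -torch.log_softmax(logits, dim=-1)[0, target[0]]
print(loss.item(), manual.item())  # identical values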
Evaluation:
- Run the model in `eval()` mode (no dropout, no gradient computation)
- Average loss over multiple random batches
- Compare train vs val loss to detect overfitting
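A hedged sketch of what such an evaluation loop can look like (the helper name `estimate_loss` and the batching logic are illustrative, not necessarily the repository's actual trainer code):
import torch
import torch.nn.functional as F

@torch.no_grad()
def estimate_loss(model, X, Y, batch_size, eval_iters=20):
    model.eval()  # disable dropout; no_grad disables gradient computation
    losses = []
    for _ in range(eval_iters):
        idx = torch.randint(0, len(X), (batch_size,))
        logits = model(X[idx])
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y[idx].view(-1))
        losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses)  # average over random batches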
3. Configuration (config.py, pretraining/training/training_args.py)
ModelConfig - Model Architecture
- `d_model`: Hidden dimension (e.g., 256, 768)
- `n_layers`: Number of transformer blocks (e.g., 4, 12)
- `n_heads`: Number of attention heads (e.g., 4, 12)
- `d_head`: Dimension per head (typically `d_model / n_heads`)
- `d_mlp`: MLP hidden dimension (typically `4 * d_model`)
- `n_ctx`: Context length (max sequence length)
- `d_vocab`: Vocabulary size
- MoE Configuration (when `use_moe=True`):
  - `num_experts`: Number of expert MLPs (e.g., 8, 64)
  - `num_experts_per_tok`: Top-k experts to activate per token (e.g., 2, 6)
  - `use_shared_experts`: Enable shared experts (DeepSeek-style)
  - `num_shared_experts`: Number of always-active shared experts
  - `router_type`: Routing strategy (`top_k` or `top_k_with_shared`)
  - `load_balancing_loss_weight`: Weight for the load balancing auxiliary loss
  - `expert_capacity_factor`: Capacity factor for expert load balancing
TransformerTrainingArgs - Training Hyperparameters
- `batch_size`: Number of sequences per batch
- `epochs`: Number of training epochs
- `lr`: Learning rate
- `weight_decay`: L2 regularization strength
- `eval_iters`: Number of batches for evaluation
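For orientation, here is a minimal sketch of how these settings might be grouped as dataclasses (field names follow the lists above; the class names and default values are examples, not the repository's actual code):
from dataclasses import dataclass

@dataclass
class ModelConfigSketch:
    d_model: int = 256
    n_layers: int = 4
    n_heads: int = 4
    d_head: int = 64       # typically d_model / n_heads
    d_mlp: int = 1024      # typically 4 * d_model
    n_ctx: int = 256
    d_vocab: int = 50257
    use_moe: bool = False
    num_experts: int = 8
    num_experts_per_tok: int = 2

@dataclass
class TrainingArgsSketch:
    batch_size: int = 32
    epochs: int = 10
    lr: float = 1e-3
    weight_decay: float = 0.01
    eval_iters: int = 20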
Fine-Tuning Pipeline
1. Data Loading (finetuning/data/sft_dataset.py)
Purpose: Load prompt/response pairs from CSV and create training sequences with loss masks.
Process
- Load CSV: Read CSV file with `prompt` and `response` columns
- Tokenize: Convert prompts and responses to token IDs
- Create Sequences: Concatenate prompt + response
- Create Masks: 0 for prompt tokens, 1 for response tokens
- Shift by 1: Create input/target pairs (same as pre-training)
- Split: Train/validation split (default 90/10)
Code Flow
# Load CSV
df = pd.read_csv("finetuning.csv")
prompts = df['prompt'].tolist()
responses = df['response'].tolist()
# For each pair
for prompt, response in zip(prompts, responses):
    # Tokenize
    prompt_tokens = tokenizer.encode(prompt)       # [prompt_len]
    response_tokens = tokenizer.encode(response)   # [response_len]
    # Concatenate
    full_tokens = prompt_tokens + response_tokens  # [prompt_len + response_len]
    # Create mask: 0 for prompt, 1 for response
    mask = [0] * len(prompt_tokens) + [1] * len(response_tokens)
    # Shift by 1 (same as pre-training)
    input_seq = full_tokens[:-1]   # [seq_len-1]
    target_seq = full_tokens[1:]   # [seq_len-1]
    mask_seq = mask[1:]            # [seq_len-1]
# X: [num_sequences, seq_len-1] - input sequences
# Y: [num_sequences, seq_len-1] - target sequences
# masks: [num_sequences, seq_len-1] - loss masks (1 for response, 0 for prompt)
Why Mask Prompt Tokens?
- We want the model to learn to generate responses, not repeat prompts
- Computing loss on prompt tokens would teach the model to copy the prompt
- Masking ensures we only learn from the response tokens
2. Training Loop (finetuning/training/sft_trainer.py)
Purpose: Fine-tune the pre-trained model using masked loss.
Key Differences from Pre-Training
- Masked Loss: Only compute loss on response tokens
- Lower Learning Rate: Typically 1e-5 (vs 1e-3 for pre-training)
- Shorter Training: 1-5 epochs (vs 10+ for pre-training)
- Structured Data: Prompt/response pairs instead of raw text
Training Step
# 1. Get random batch
idx = torch.randint(0, len(X_train), (batch_size,))
x_batch = X_train[idx] # [batch_size, seq_len]
y_batch = Y_train[idx] # [batch_size, seq_len]
masks_batch = masks_train[idx] # [batch_size, seq_len]
# 2. Forward pass
logits = model(x_batch) # [batch_size, seq_len, vocab_size]
# 3. Compute masked loss
logits_flat = logits.view(-1, vocab_size) # [batch*seq, vocab]
targets_flat = y_batch.view(-1) # [batch*seq]
masks_flat = masks_batch.view(-1) # [batch*seq]
# Compute loss per token
loss_unmasked = F.cross_entropy(
logits_flat, targets_flat, reduction='none'
) # [batch*seq]
# Apply mask: only average over response tokens (where mask == 1)
loss = (loss_unmasked * masks_flat).sum() / masks_flat.sum().clamp(min=1)
# 4. Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
Loss Explanation:
- `loss_unmasked`: Loss for every token position
- `masks_flat`: 1 for response tokens, 0 for prompt tokens
- `loss_unmasked * masks_flat`: Zero out prompt token losses
- `sum() / masks_flat.sum()`: Average only over response tokens
Why Lower Learning Rate?
- Pre-trained weights are already good
- We want small adjustments, not large changes
- Prevents catastrophic forgetting of pre-trained knowledge
3. Checkpoint Organization
Pre-trained checkpoints: checkpoints/{timestamp}/final_model.pt
Fine-tuned checkpoints: checkpoints/{timestamp}/sft/final_model.pt
Both checkpoints are visible in the inference page, clearly labeled:
- "š Final Model (Pre-trained)"
- "š Final Model (Fine-tuned)"
Inference and Sampling
Text Generation (inference/sampler.py)
Note: The inference page supports both pre-trained and fine-tuned models. Fine-tuned models are often better at following instructions and generating appropriate responses to prompts.
Purpose: Generate text from a trained model.
Autoregressive Generation
Basic Approach (without KV cache):
# Start with prompt
tokens = tokenizer.encode(prompt) # [seq_len]
# Generate tokens one by one
for _ in range(max_new_tokens):
    # Get model predictions (recomputes everything each time)
    logits = model(tokens)  # [1, seq_len, vocab_size]
    # Get logits for last position
    next_token_logits = logits[0, -1, :]  # [vocab_size]
    # Sample next token
    next_token = sample(next_token_logits)
    # Append to sequence
    tokens.append(next_token)
Optimized Approach (with KV cache):
# Start with prompt
tokens = tokenizer.encode(prompt) # [seq_len]
tokens_tensor = torch.tensor([tokens], device=device)
# Process prompt and initialize KV cache
logits, kv_cache = model(tokens_tensor, cache=None, start_pos=0)
next_token_logits = logits[0, -1, :] / temperature
start_pos = tokens_tensor.shape[1]
# Generate tokens one by one
for _ in range(max_new_tokens):
    # Sample next token
    next_token = sample(next_token_logits)
    # Append to sequence
    tokens_tensor = torch.cat([tokens_tensor, next_token.unsqueeze(0)], dim=1)
    # Process only the new token using KV cache
    # This is much faster - only computes Q, K, V for the new token
    new_token_tensor = next_token.unsqueeze(0)  # [1, 1]
    logits, kv_cache = model(new_token_tensor, cache=kv_cache, start_pos=start_pos)
    next_token_logits = logits[0, -1, :] / temperature
    start_pos += 1
Why KV Cache?
- Without cache: Each generation step recomputes Q, K, V for all previous tokens (O(n²) complexity per step)
- With cache: Each generation step only computes Q, K, V for the new token, reusing cached K, V from previous tokens (O(n) complexity per step)
- Speedup: Dramatically faster for long sequences - can be 10-100x faster depending on sequence length
- Memory tradeoff: Uses more memory to store cached K, V tensors, but the speedup is usually worth it
Sampling Strategies
- Temperature Sampling:
  - `logits = logits / temperature`, then `probs = softmax(logits)`
  - `temperature < 1`: More focused (deterministic)
  - `temperature = 1`: Balanced
  - `temperature > 1`: More creative (random)
- Top-k Sampling:
  - Only consider the top k most likely tokens: `top_k_logits = topk(logits, k)`
  - Prevents sampling very unlikely tokens
  - `k=40` is common
- Top-p (Nucleus) Sampling:
  - Consider tokens until cumulative probability > p: `sorted_probs = sort(softmax(logits))`, `cumulative = cumsum(sorted_probs)`, `mask = cumulative <= p`
  - Adaptive: considers more tokens when the distribution is flat
  - `p=0.9` is common
- Combined: Top-k + Top-p + Temperature
  - Most common in practice
  - Provides a good balance of quality and diversity
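A hedged sketch of a combined sampler (the function name and defaults here are illustrative; the actual code in inference/sampler.py may differ):
import torch

def sample_next_token(logits, temperature=1.0, top_k=40, top_p=0.9):
    # logits: [vocab_size] for the last position
    logits = logits / max(temperature, 1e-8)
    # Top-k: keep only the k largest logits
    if top_k is not None:
        kth_value = torch.topk(logits, top_k).values[-1]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    # Top-p: keep the smallest set of tokens whose cumulative probability covers p
    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[1:] = remove[:-1].clone()  # keep the token that crosses the threshold
        remove[0] = False
        logits[sorted_idx[remove]] = float("-inf")
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # [1]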
Core Components Deep Dive
1. Normalization Layers
Layer Normalization (pretraining/normalization/layernorm.py) - GPT/OLMo Style
Purpose: Normalize activations across the feature dimension to stabilize training.
Unified Implementation:
The LayerNorm class supports both einops and PyTorch implementations via the use_einops flag:
With Einops (use_einops=True)
Uses einops.reduce for mean and variance computation:
# Compute mean: [batch, posn, d_model] -> [batch, posn, 1]
residual_mean = einops.reduce(
residual, 'batch posn d_model -> batch posn 1', 'mean'
)
Why this works: Einops makes it explicit that we're reducing over d_model while keeping batch and posn.
Without Einops (use_einops=False)
Uses PyTorch's built-in operations:
# Compute mean: [batch, posn, d_model] -> [batch, posn, 1]
residual_mean = residual.mean(dim=-1, keepdim=True)
Why this works: dim=-1 means "last dimension" (d_model), keepdim=True preserves the dimension.
PyTorch Built-in (LayerNormWithTorch)
Uses PyTorch's nn.LayerNorm:
self.ln = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
Why this works: PyTorch's implementation is optimized and handles edge cases.
Mathematical Formula:
LayerNorm(x) = γ * (x - μ) / (σ + ε) + β
Where:
- `μ` = mean over the d_model dimension
- `σ` = standard deviation over the d_model dimension
- `γ` (w) = learnable scale parameter
- `β` (b) = learnable shift parameter
- `ε` = small constant for numerical stability
Shape Flow:
Input: [batch, posn, d_model]
Mean: [batch, posn, 1] # Mean over d_model
Std: [batch, posn, 1] # Std over d_model
Normalized: [batch, posn, d_model]
Output: [batch, posn, d_model] # After scale and shift
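A compact sketch of the formula above as a standalone PyTorch module (the repository's LayerNorm class additionally supports the use_einops flag and other options):
import torch
import torch.nn as nn

class LayerNormSketch(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.w = nn.Parameter(torch.ones(d_model))   # γ (scale)
        self.b = nn.Parameter(torch.zeros(d_model))  # β (shift)
        self.eps = eps

    def forward(self, residual):
        # residual: [batch, posn, d_model]
        mean = residual.mean(dim=-1, keepdim=True)                # [batch, posn, 1]
        std = residual.std(dim=-1, keepdim=True, unbiased=False)  # [batch, posn, 1]
        return (residual - mean) / (std + self.eps) * self.w + self.b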
RMS Normalization (pretraining/normalization/rmsnorm.py) - LLaMA Style
Purpose: Simpler normalization used in LLaMA (no mean subtraction, no bias).
Key Difference from LayerNorm:
- LayerNorm: `(x - mean) / std * γ + β` (centers then scales)
- RMSNorm: `x / rms * γ` (only scales, no centering, no bias)
Implementation:
# Compute RMS (Root Mean Square)
rms = sqrt(mean(x²) + eps) # [batch, posn, 1]
# Normalize
norm = x / rms # [batch, posn, d_model]
# Apply scale (no bias)
output = norm * w # [batch, posn, d_model]
Why RMSNorm?
- Simpler (fewer operations)
- No bias term needed
- Works well in practice
- Used in LLaMA, PaLM, and other modern models
Shape Flow:
Input: [batch, posn, d_model]
RMS: [batch, posn, 1] # RMS over d_model
Normalized: [batch, posn, d_model]
Output: [batch, posn, d_model] # After scale only
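A compact sketch of the same steps as a standalone module (details may differ from the repository's RMSNorm class):
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.w = nn.Parameter(torch.ones(d_model))  # scale only, no bias
        self.eps = eps

    def forward(self, x):
        # x: [batch, posn, d_model]
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)  # [batch, posn, 1]
        return x / rms * self.w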
2. Token Embeddings (pretraining/embeddings/embed.py)
Purpose: Convert token IDs (integers) into dense vector representations.
EmbedWithoutTorch
Manual implementation using nn.Parameter:
# W_E: [d_vocab, d_model] - embedding matrix
self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
# Forward: Index into embedding matrix
# tokens: [batch, position] -> [batch, position, d_model]
return self.W_E[tokens]
How it works:
- Each token ID (0 to d_vocab-1) maps to a row in `W_E`
- `W_E[tokens]` uses advanced indexing to look up embeddings
- Example: `tokens = [[5, 10, 3]]` → `[W_E[5], W_E[10], W_E[3]]`
EmbedWithTorch
Uses PyTorch's nn.Embedding:
self.embedding = nn.Embedding(cfg.d_vocab, cfg.d_model)
return self.embedding(tokens)
Why both versions:
- Manual version shows what's happening under the hood
- PyTorch version is optimized and handles edge cases
Shape Flow:
Token IDs: [batch, position] # e.g., [[5, 10, 3]]
Embedding Matrix: [d_vocab, d_model] # e.g., [50257, 256]
Output: [batch, position, d_model] # e.g., [[emb[5], emb[10], emb[3]]]
3. Positional Encoding
Learned Positional Embeddings (pretraining/positional_embeddings/positional_embedding.py) - GPT Style
Purpose: Add information about token positions in the sequence.
The PosEmbed class supports both einops and PyTorch implementations via the use_einops flag:
With Einops (use_einops=True)
# W_pos: [n_ctx, d_model] - one embedding per position
# Get embeddings for current sequence length
W_pos[:seq_len] # [seq_len, d_model]
# Repeat for each item in batch
einops.repeat(
W_pos[:seq_len],
"seq d_model -> batch seq d_model",
batch=batch
)
How it works:
- `W_pos[i]` is the embedding for position `i`
- `einops.repeat` broadcasts `[seq_len, d_model]` to `[batch, seq_len, d_model]`
Without Einops (use_einops=False)
# Manual broadcasting
position_embeddings_we_need = self.W_pos[:sequence_length] # [seq_len, d_model]
position_embeddings_with_batch_dim = position_embeddings_we_need.unsqueeze(0) # [1, seq_len, d_model]
position_embeddings_repeated = position_embeddings_with_batch_dim.expand(batch_size, -1, -1) # [batch, seq_len, d_model]
How it works:
- `unsqueeze(0)` adds a batch dimension at position 0
- `expand()` repeats along the batch dimension (memory-efficient, no copying)
Shape Flow:
W_pos: [n_ctx, d_model] # e.g., [1024, 256]
Slice: [seq_len, d_model] # e.g., [128, 256]
Repeat: [batch, seq_len, d_model] # e.g., [32, 128, 256]
Why positional embeddings?
- Transformers have no inherent notion of sequence order
- Positional embeddings encode "this token is at position 5"
- Added to token embeddings: `final_emb = token_emb + pos_emb`
Rotary Position Embedding (pretraining/positional_embeddings/rope.py) - LLaMA Style
Purpose: Encode positions through rotations of query and key vectors (not learned, computed on-the-fly).
Key Concepts:
- Not added to embeddings: Applied directly to Q and K in attention
- Rotation-based: Rotates each dimension pair by position-dependent angles
- Relative positions: Encodes relative distances between tokens
How it works:
- Split Q/K into pairs: `[d_head]` → `[d_head/2 pairs]`
- For each pair `(x_i, x_i+1)`, compute the rotation angle: `θ_i * position`
- Apply the rotation matrix:
  [cos(θ)  -sin(θ)] [x_i  ]
  [sin(θ)   cos(θ)] [x_i+1]
- Different frequencies for different dimensions: `θ_i = 10000^(-2i/d_head)`
Implementation:
# Pre-compute frequencies
freqs = 1.0 / (theta ** (arange(0, d_head, 2) / d_head))
# For each position, compute rotation angles
angles = positions * freqs # [seq, d_head/2]
# Rotate Q and K
q_rotated, k_rotated = apply_rotation(q, k, angles)
Why RoPE?
- Encodes relative positions naturally
- Extrapolates to longer sequences better than learned embeddings
- Applied in attention (not embeddings), so more flexible
- Used in LLaMA, PaLM, and other modern models
Shape Flow:
Q/K: [batch, seq, n_heads, d_head]
  ↓
Reshape to pairs: [batch, seq, n_heads, d_head/2, 2]
  ↓
Rotate each pair: [batch, seq, n_heads, d_head/2, 2]
  ↓
Reshape back: [batch, seq, n_heads, d_head]
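A hedged sketch of the rotation applied to a Q or K tensor (this version pairs adjacent dimensions; the repository's rope.py may use a different pairing convention):
import torch

def apply_rope_sketch(x, theta=10000.0):
    # x: [batch, seq, n_heads, d_head] with d_head even; apply to q and k separately
    batch, seq, n_heads, d_head = x.shape
    freqs = 1.0 / (theta ** (torch.arange(0, d_head, 2, dtype=torch.float32) / d_head))  # [d_head/2]
    angles = torch.arange(seq, dtype=torch.float32)[:, None] * freqs   # [seq, d_head/2]
    cos = angles.cos()[None, :, None, :]  # broadcast to [1, seq, 1, d_head/2]
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]   # dimension pairs (x_i, x_i+1)
    rotated = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return rotated.flatten(-2)            # back to [batch, seq, n_heads, d_head]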
ALiBi - Attention with Linear Biases (pretraining/positional_embeddings/alibi.py) - OLMo Style
Purpose: Encode positions through linear biases added to attention scores (not learned, computed on-the-fly).
Key Concepts:
- Not added to embeddings: Applied directly to attention scores
- Distance-based: Bias depends on distance between positions
- Per-head slopes: Each attention head gets a different slope
- No learned parameters: All computed from fixed formulas
How it works:
- Compute the distance matrix: `distance[i, j] = |i - j|`
- For each head `h`, compute a slope: `slope[h] = 2^(-8/n_heads * h)`
- Apply the bias: `bias[h, i, j] = -slope[h] * distance[i, j]`
- Add the bias to attention scores before softmax
Implementation:
# Pre-compute slopes for each head
slopes = 2^(-8/n_heads * [1, 2, ..., n_heads]) # [n_heads]
# Compute distance matrix
distance = |pos_i - pos_j| # [seq_len, seq_len]
# Apply slopes
bias = -slopes.unsqueeze(-1).unsqueeze(-1) * distance.unsqueeze(0) # [n_heads, seq_len, seq_len]
# Add to attention scores
attn_scores = attn_scores + bias
Why ALiBi?
- Simpler than RoPE (no rotations, just addition)
- Extrapolates extremely well to longer sequences
- Each head learns different distance preferences
- Used in OLMo and other modern models
- No learned parameters, so more efficient
Shape Flow:
Attention scores: [batch, n_heads, seq_len, seq_len]
  ↓
ALiBi bias: [n_heads, seq_len, seq_len]
  ↓
Add bias: [batch, n_heads, seq_len, seq_len]
  ↓
Softmax: [batch, n_heads, seq_len, seq_len]
Key Insight:
- Closer positions get less negative bias (can attend more)
- Farther positions get more negative bias (attend less)
- This biases attention toward nearby tokens (a causal mask is still applied to block future positions)
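A hedged sketch of the bias computation (slope formula as described above; the repository's alibi.py may differ in details):
import torch

def alibi_bias_sketch(n_heads, seq_len):
    h = torch.arange(1, n_heads + 1, dtype=torch.float32)
    slopes = 2.0 ** (-8.0 / n_heads * h)                      # [n_heads], one slope per head
    pos = torch.arange(seq_len)
    distance = (pos[None, :] - pos[:, None]).abs().float()    # [seq_len, seq_len]
    return -slopes[:, None, None] * distance[None, :, :]      # [n_heads, seq_len, seq_len]

# Usage: attn_scores = attn_scores + alibi_bias_sketch(n_heads, seq_len), before masking and softmax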
4. Multi-Head Self-Attention (pretraining/attention/attention.py)
Purpose: Allow tokens to attend to other tokens in the sequence, learning relationships.
Key Concepts
Query (Q), Key (K), Value (V):
- Query: "What am I looking for?"
- Key: "What do I represent?"
- Value: "What information do I contain?"
Attention Mechanism:
- Compute attention scores: `Q @ K^T` (how much each token attends to others)
- Apply causal mask (prevent looking at future tokens)
- Softmax to get attention probabilities
- Weighted sum of values: `Attention(Q, K, V) = softmax(QK^T / √d_head) @ V`
Attention Types: MHA, GQA, and MQA
The codebase supports three attention variants:
1. Multi-Head Attention (MHA) - Standard attention:
- Each head has separate Q, K, V projections
- `n_kv_heads = n_heads` (default)
- Used in: GPT-2, original LLaMA
- KV cache size: `[batch, seq_len, n_heads, d_head]`
2. Grouped Query Attention (GQA) - Efficient attention:
- Groups of Q heads share the same K/V heads
- `n_kv_heads < n_heads` (e.g., 32 Q heads, 8 KV heads = 4:1 ratio)
- Used in: LLaMA 2, Mistral, Mixtral
- KV cache size: `[batch, seq_len, n_kv_heads, d_head]` (smaller!)
- Benefits: ~75% smaller KV cache, faster inference, minimal quality loss
3. Multi-Query Attention (MQA) - Most efficient:
- All Q heads share a single K/V head
- `n_kv_heads = 1`
- Used in: PaLM, some optimized models
- KV cache size: `[batch, seq_len, 1, d_head]` (much smaller!)
- Benefits: Maximum memory efficiency, faster inference, slight quality trade-off
How GQA/MQA Works:
- Compute Q with `n_heads` projections: `Q: [batch, seq, n_heads, d_head]`
- Compute K/V with `n_kv_heads` projections: `K, V: [batch, seq, n_kv_heads, d_head]`
- Broadcast K/V to match Q: repeat each KV head `n_heads / n_kv_heads` times
- Attention computation proceeds identically to MHA after broadcasting
- Cache stores the original (non-broadcasted) K/V to save memory
Attention - Unified Implementation
The Attention class supports both einops and PyTorch implementations via the use_einops flag:
With Einops (use_einops=True):
Step 1: Compute Q, K, V
# residual: [batch, posn, d_model]
# W_Q: [n_heads, d_head, d_model]
# W_K, W_V: [n_kv_heads, d_head, d_model] (for GQA/MQA)
# q: [batch, posn, n_heads, d_head]
q = einops.einsum(
residual, self.W_Q,
"batch posn d_model, n_heads d_head d_model -> batch posn n_heads d_head"
)
# k: [batch, posn, n_kv_heads, d_head] (may be different from n_heads)
k = einops.einsum(
residual, self.W_K,
"batch posn d_model, n_kv_heads d_head d_model -> batch posn n_kv_heads d_head"
)
# v: [batch, posn, n_kv_heads, d_head]
v = einops.einsum(
residual, self.W_V,
"batch posn d_model, n_kv_heads d_head d_model -> batch posn n_kv_heads d_head"
)
What's happening:
- Q is projected with `n_heads` projections (one per Q head)
- K/V are projected with `n_kv_heads` projections (may be fewer than `n_heads` for GQA/MQA)
- For MHA: `n_kv_heads = n_heads` (standard behavior)
- For GQA/MQA: `n_kv_heads < n_heads` (memory efficient)
Step 1b: Broadcast K/V for GQA/MQA (if n_kv_heads < n_heads)
# Broadcast K/V to match Q heads
if self.n_kv_heads < self.n_heads:
    repeat_factor = self.n_heads // self.n_kv_heads
    k = k.repeat_interleave(repeat_factor, dim=2)  # [batch, posn_k, n_heads, d_head]
    v = v.repeat_interleave(repeat_factor, dim=2)  # [batch, posn_k, n_heads, d_head]
Step 2: Compute Attention Scores
# After broadcasting, k and v have n_heads dimension
# q: [batch, posn_q, n_heads, d_head]
# k: [batch, posn_k, n_heads, d_head] (broadcasted if GQA/MQA)
# attn_scores: [batch, n_heads, posn_q, posn_k]
attn_scores = einops.einsum(
q, k,
"batch posn_q n_heads d_head, batch posn_k n_heads d_head -> batch n_heads posn_q posn_k"
) / (self.cfg.d_head ** 0.5)  # Scale by √d_head
What's happening:
- For each position `i` and head `h`, compute the dot product with all positions `j`
- `attn_scores[b, h, i, j]` = how much position `i` attends to position `j` in head `h`
- Scaling by `√d_head` prevents the softmax from saturating
Step 3: Causal Masking
# mask: [seq_len, seq_len] - lower triangular matrix
mask = torch.tril(torch.ones((seq_len, seq_len), device=residual.device))
# Set future positions to -inf (so softmax makes them 0)
attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))
What's happening:
- Lower triangular matrix: `mask[i, j] = 1` if `j <= i`, else `0`
- Prevents the token at position `i` from seeing tokens at positions `> i`
- Essential for autoregressive generation
Step 4: Softmax and Apply to Values
# attn_pattern: [batch, n_heads, posn_q, posn_k] - probabilities
attn_pattern = torch.softmax(attn_scores, dim=-1)
# Weighted sum of values
# attn_output: [batch, posn_q, n_heads, d_head]
attn_output = einops.einsum(
attn_pattern, v,
"batch n_heads posn_q posn_k, batch posn_k n_heads d_head -> batch posn_q n_heads d_head"
)
What's happening:
- `attn_pattern[b, h, i, j]` = probability that position `i` attends to position `j`
- Weighted sum: `output[i] = Σ_j attn_pattern[i, j] * v[j]`
Step 5: Project Back
# attn_output: [batch, posn, n_heads, d_head]
# W_O: [n_heads, d_head, d_model]
# output: [batch, posn, d_model]
output = einops.einsum(
attn_output, self.W_O,
"batch posn n_heads d_head, n_heads d_head d_model -> batch posn d_model"
)
What's happening: Combine all heads and project back to d_model dimensions.
Without Einops (use_einops=False):
Same logic, but using PyTorch operations:
# Compute Q, K, V using einsum
q = torch.einsum("bpd,nhd->bpnh", residual, self.W_Q) # [batch, seq, n_heads, d_head]
k = torch.einsum("bpd,nkd->bpnk", residual, self.W_K) # [batch, seq, n_kv_heads, d_head]
v = torch.einsum("bpd,nkd->bpnk", residual, self.W_V) # [batch, seq, n_kv_heads, d_head]
# Broadcast K/V for GQA/MQA (if needed)
if self.n_kv_heads < self.n_heads:
    repeat_factor = self.n_heads // self.n_kv_heads
    k = k.repeat_interleave(repeat_factor, dim=2)  # [batch, seq, n_heads, d_head]
    v = v.repeat_interleave(repeat_factor, dim=2)  # [batch, seq, n_heads, d_head]
# Transpose for matmul: [batch, seq, n_heads, d_head] -> [batch, n_heads, seq, d_head]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Attention scores
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.cfg.d_head ** 0.5)
Shape Flow (Full Attention):
Input: [batch, seq, d_model]
Q/K/V: [batch, seq, n_heads, d_head] (per head)
Scores: [batch, n_heads, seq, seq] (attention weights)
Pattern: [batch, n_heads, seq, seq] (after softmax)
Output: [batch, seq, n_heads, d_head] (weighted values)
Final: [batch, seq, d_model] (after projection)
Why Multi-Head?
- Different heads can learn different relationships
- Head 1 might learn subject-verb relationships
- Head 2 might learn long-range dependencies
- Head 3 might learn local patterns
- Combining them gives richer representations
KV Caching for Efficient Inference
The Problem: During autoregressive generation, we generate tokens one at a time. Without caching, each step recomputes Q, K, V for all previous tokens, which is wasteful since K and V for previous tokens don't change.
The Solution: Cache K and V tensors from previous tokens, and only compute Q, K, V for the new token.
How it works:
# First forward pass (processing prompt)
# tokens: [batch, prompt_len]
logits, kv_cache = model(tokens, cache=None, start_pos=0)
# kv_cache: List of (K_cache, V_cache) tuples, one per layer
# For MHA: K_cache, V_cache: [batch, prompt_len, n_heads, d_head]
# For GQA/MQA: K_cache, V_cache: [batch, prompt_len, n_kv_heads, d_head] (smaller!)
# Subsequent forward passes (generating new tokens)
# new_token: [batch, 1] - only the new token
logits, kv_cache = model(new_token, cache=kv_cache, start_pos=prompt_len)
# K, V are concatenated: [cached_K, new_K] and [cached_V, new_V]
# Only new_K and new_V are computed, cached ones are reused
# For GQA/MQA: Cache stores original (non-broadcasted) K/V to save memory
Implementation Details:
- Attention Layer: Accepts an optional `cache` parameter containing cached K, V from previous tokens
- Concatenation: New K, V are concatenated with cached K, V along the sequence dimension
- RoPE Handling: Cached K already has RoPE applied, so only new K gets rotated
- ALiBi Handling: Bias matrix is computed for full sequence length (cached + new tokens)
- Cache Update: Returns updated cache containing K, V for all tokens (cached + new)
Benefits:
- Speed: 10-100x faster for long sequences
- Efficiency: Only computes Q, K, V for new tokens (O(n) instead of O(n²) per step)
- Memory: Trades memory for speed (stores cached K, V tensors)
Backward Compatibility: When cache=None, the model works exactly as before (used during training).
5. MLP / Feedforward Network (pretraining/mlp/mlp.py)
Purpose: Apply pointwise non-linear transformations to each position independently.
GPT Architecture (GELU)
Input → Linear(d_model → d_mlp) → GELU → Linear(d_mlp → d_model) → Output
Uses 2 weight matrices: W_in and W_out
LLaMA/OLMo Architecture (SwiGLU)
Input → [Gate Branch: Linear → Swish] × [Up Branch: Linear] → Linear(d_mlp → d_model) → Output
Uses 3 weight matrices: W_gate, W_up, and W_out
MLP - Unified Implementation
The MLP class supports both einops and PyTorch implementations via the use_einops flag:
With Einops (use_einops=True):
# First linear layer
# residual: [batch, posn, d_model]
# W_in: [d_model, d_mlp]
# hidden: [batch, posn, d_mlp]
hidden = einops.einsum(
residual, self.W_in,
"batch posn d_model, d_model d_mlp -> batch posn d_mlp"
) + self.b_in
# GELU activation (element-wise)
hidden = torch.nn.functional.gelu(hidden)
# Second linear layer
# hidden: [batch, posn, d_mlp]
# W_out: [d_mlp, d_model]
# output: [batch, posn, d_model]
output = einops.einsum(
hidden, self.W_out,
"batch posn d_mlp, d_mlp d_model -> batch posn d_model"
) + self.b_out
What's happening:
- Each position is processed independently (no interaction between positions)
- Expands to `d_mlp` (typically `4 * d_model`) for more capacity
- GELU provides non-linearity
- Projects back to `d_model`
Without Einops (use_einops=False):
Same logic using PyTorch operations (typically using torch.einsum or torch.matmul).
Why GELU?
- GELU (Gaussian Error Linear Unit) is smoother than ReLU
- Used in GPT, BERT, and modern transformers
- Formula: `GELU(x) = x * Φ(x)` where Φ is the CDF of the standard normal
SwiGLU (LLaMA/OLMo):
- Swish activation: `Swish(x) = x * sigmoid(x)` (also called SiLU)
- Gated architecture: `SwiGLU(x) = Swish(W_gate @ x) * (W_up @ x)`
- Element-wise multiplication gates the information flow
- More expressive than GELU, allows model to control information flow
- Used in LLaMA, PaLM, OLMo, and other modern models
Shape Flow (GELU):
Input: [batch, posn, d_model] # e.g., [32, 128, 256]
Expand: [batch, posn, d_mlp] # e.g., [32, 128, 1024]
Activate: [batch, posn, d_mlp] # GELU (element-wise)
Project: [batch, posn, d_model] # e.g., [32, 128, 256]
Shape Flow (SwiGLU):
Input: [batch, posn, d_model] # e.g., [32, 128, 256]
Gate: [batch, posn, d_mlp] # Swish(W_gate @ x)
Up: [batch, posn, d_mlp] # W_up @ x
Hidden: [batch, posn, d_mlp] # gate * up (element-wise)
Project: [batch, posn, d_model] # e.g., [32, 128, 256]
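Since the code above shows only the GELU variant, here is a minimal sketch of the SwiGLU branch (the repository's MLP class selects between the variants via configuration; this standalone module is illustrative):
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUMLPSketch(nn.Module):
    def __init__(self, d_model, d_mlp):
        super().__init__()
        self.W_gate = nn.Linear(d_model, d_mlp, bias=False)
        self.W_up = nn.Linear(d_model, d_mlp, bias=False)
        self.W_out = nn.Linear(d_mlp, d_model, bias=False)

    def forward(self, x):
        # x: [batch, posn, d_model]
        gate = F.silu(self.W_gate(x))   # Swish(W_gate @ x)
        up = self.W_up(x)               # W_up @ x
        return self.W_out(gate * up)    # element-wise gating, then project back to d_model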
Attention Types: MHA, GQA, and MQA
Multi-Head Attention (MHA):
- Standard attention: each head has separate Q, K, V
- KV cache: `[batch, seq_len, n_heads, d_head]`
- Used in: GPT-2, original LLaMA
Grouped Query Attention (GQA):
- Groups of Q heads share K/V heads (e.g., 32 Q heads, 8 KV heads)
- KV cache: `[batch, seq_len, n_kv_heads, d_head]` (75% smaller for a 4:1 ratio!)
- Used in: LLaMA 2, Mistral, Mixtral
- Benefits: Smaller KV cache, faster inference, minimal quality loss
Multi-Query Attention (MQA):
- All Q heads share single K/V head
- KV cache: `[batch, seq_len, 1, d_head]` (much smaller!)
- Used in: PaLM, optimized models
- Benefits: Maximum memory efficiency, faster inference
Implementation: K/V are computed with n_kv_heads projections, then broadcast to match Q heads before attention computation. Cache stores original (non-broadcasted) K/V.
MoE Architecture (Mixture of Experts)
Purpose: Scale model capacity efficiently by using multiple expert MLPs and routing tokens to a subset of experts.
Key Concept: Instead of one large MLP, MoE uses multiple smaller expert MLPs. For each token, a router selects the top-k experts to activate, allowing the model to have more parameters while keeping computation per token similar.
Architecture Flow:
Input: [batch, posn, d_model]
  ↓
Router Network → [batch, posn, num_experts] (logits)
  ↓
Softmax → Router Probabilities
  ↓
Top-k Selection → Select k experts per token
  ↓
Expert MLPs (only selected experts compute)
  ↓
Weighted Combination → [batch, posn, d_model]
  ↓
(+ Shared Experts if enabled)
  ↓
Output: [batch, posn, d_model]
Routing Strategies:
- Top-k Routing (Mixtral-style):
  - Router computes logits for all experts
  - Selects top-k experts per token
  - Combines expert outputs with routing weights
  - Only k experts compute per token (sparse activation)
- Top-k with Shared Experts (DeepSeek-style):
  - Same as top-k, but some experts are always active
  - Shared experts handle general knowledge
  - Routed experts specialize based on input
  - Final output = `shared_experts(x) + routed_experts(x)`
Load Balancing Loss:
- Auxiliary loss encourages uniform expert usage
- Prevents expert collapse (where only a few experts are used)
- Formula: `aux_loss = num_experts * sum(P_i * f_i)` where:
  - `P_i` = average routing probability for expert i
  - `f_i` = fraction of tokens routed to expert i
- Added to the main loss: `total_loss = loss + aux_loss * load_balancing_loss_weight`
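A hedged sketch of that formula (the function name and signature are illustrative, not the repository's exact helper):
import torch
import torch.nn.functional as F

def load_balancing_loss_sketch(router_probs, top_k_indices, num_experts):
    # router_probs: [batch, seq, num_experts]; top_k_indices: [batch, seq, k]
    P = router_probs.mean(dim=(0, 1))                          # average routing probability per expert
    one_hot = F.one_hot(top_k_indices, num_experts).float()    # [batch, seq, k, num_experts]
    f = one_hot.sum(dim=2).mean(dim=(0, 1))                    # fraction of tokens routed to each expert
    return num_experts * torch.sum(P * f)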
Benefits:
- Scalability: Can have many experts (e.g., 64) while only activating a few per token
- Efficiency: Computation scales with activated experts, not total experts
- Specialization: Different experts can specialize in different patterns
Implementation (pretraining/mlp/mlp.py):
class MoEMLPBase(nn.Module):
    def forward(self, residual):
        # Router: [batch, seq_len, d_model] -> [batch, seq_len, num_experts]
        router_logits = self.router(residual)
        router_probs = F.softmax(router_logits, dim=-1)
        # Select top-k experts
        top_k_probs, top_k_indices = torch.topk(
            router_probs, k=self.num_experts_per_tok, dim=-1
        )
        # Process each expert
        output = torch.zeros_like(residual)
        for expert_idx in range(self.num_experts):
            # Get expert output and weight by routing probability
            expert_output = self.experts[expert_idx](residual)
            expert_weights = ...  # Extract from top_k_probs
            output += expert_weights.unsqueeze(-1) * expert_output
        # Add shared experts if enabled
        if self.use_shared_experts:
            shared_output = sum(expert(residual) for expert in self.shared_experts)
            output += shared_output / len(self.shared_experts)
        # Compute load balancing loss
        aux_loss = self._compute_load_balancing_loss(...)
        return output, aux_loss
Shape Flow:
Input: [batch, posn, d_model] # e.g., [32, 128, 256]
Router: [batch, posn, num_experts] # e.g., [32, 128, 8]
Top-k: [batch, posn, k] # e.g., [32, 128, 2] (indices and probs)
Expert Outputs: [batch, posn, d_model] # from each selected expert
Weighted Sum: [batch, posn, d_model] # e.g., [32, 128, 256]
6. Transformer Block (pretraining/transformer_blocks/transformer_block.py)
Purpose: Combine attention and MLP with residual connections and layer normalization.
Architecture (Pre-Norm)
Input
  ↓
LayerNorm → Attention → + (residual)
  │                      ↑
  └──────────────────────┘
  ↓
LayerNorm → MLP → + (residual)
  │                ↑
  └────────────────┘
  ↓
Output
Implementation
def forward(self, residual):
    # Pre-norm attention with residual connection
    residual = residual + self.attn(self.ln1(residual))
    # Pre-norm MLP with residual connection
    residual = residual + self.mlp(self.ln2(residual))
    return residual
Key Concepts:
- Pre-Norm vs Post-Norm:
  - Pre-Norm (what we use): `x + f(LN(x))`
  - Post-Norm: `LN(x + f(x))`
  - Pre-norm is more stable for deep networks
- Residual Connections:
  - Allow gradients to flow directly through
  - Enable training of very deep networks
  - Help the model learn the identity function if needed
- Why Two LayerNorms?
  - One before attention (stabilizes attention)
  - One before MLP (stabilizes MLP)
  - Each sub-block gets normalized inputs
Shape Flow:
Input: [batch, posn, d_model]
  ↓
LN1: [batch, posn, d_model]
  ↓
Attention: [batch, posn, d_model]
  ↓
Add: [batch, posn, d_model] (residual connection)
  ↓
LN2: [batch, posn, d_model]
  ↓
MLP: [batch, posn, d_model]
  ↓
Add: [batch, posn, d_model] (residual connection)
  ↓
Output: [batch, posn, d_model]
7. Full Transformer Model (pretraining/model/model.py)
Purpose: Stack all components into a complete language model supporting GPT, LLaMA, and OLMo architectures.
Architecture Flow
GPT Architecture:
Tokens [batch, position]
  ↓
Token Embeddings → [batch, position, d_model]
  ↓
+ Learned Positional Embeddings → [batch, position, d_model]
  ↓
Transformer Block 1 → [batch, position, d_model]
  ↓
...
  ↓
Transformer Block N → [batch, position, d_model]
  ↓
Final LayerNorm → [batch, position, d_model]
  ↓
Unembedding → [batch, position, d_vocab] (logits)
LLaMA Architecture:
Tokens [batch, position]
  ↓
Token Embeddings → [batch, position, d_model]
  ↓
(No positional embedding layer - RoPE applied in attention)
  ↓
Transformer Block 1 (with RoPE) → [batch, position, d_model]
  ↓
...
  ↓
Transformer Block N (with RoPE) → [batch, position, d_model]
  ↓
Final RMSNorm → [batch, position, d_model]
  ↓
Unembedding → [batch, position, d_vocab] (logits)
OLMo Architecture:
Tokens [batch, position]
  ↓
Token Embeddings → [batch, position, d_model]
  ↓
(No positional embedding layer - ALiBi applied in attention)
  ↓
Transformer Block 1 (with ALiBi) → [batch, position, d_model]
  ↓
...
  ↓
Transformer Block N (with ALiBi) → [batch, position, d_model]
  ↓
Final LayerNorm → [batch, position, d_model]
  ↓
Unembedding → [batch, position, d_vocab] (logits)
MoE Architecture (DeepSeek V2 / Mixtral-style):
Tokens [batch, position]
  ↓
Token Embeddings → [batch, position, d_model]
  ↓
(No positional embedding layer - RoPE applied in attention)
  ↓
Transformer Block 1 (with RoPE + MoE MLP) → [batch, position, d_model]
  ↓
...
  ↓
Transformer Block N (with RoPE + MoE MLP) → [batch, position, d_model]
  ↓
Final RMSNorm → [batch, position, d_model]
  ↓
Unembedding → [batch, position, d_vocab] (logits)
MoE Transformer Block:
Input: [batch, posn, d_model]
  ↓
LayerNorm → Attention → + (residual)
  ↓
LayerNorm → MoE MLP → + (residual)
  ↓          │
  ↓          ├─ Router → Top-k Selection
  ↓          ├─ Expert MLPs (sparse activation)
  ↓          └─ Load Balancing Loss (auxiliary)
  ↓
Output: [batch, posn, d_model]
Implementation
The model automatically selects components based on cfg.positional_encoding and related settings:
def forward(self, tokens):
    # tokens: [batch, position]
    # Token embeddings
    residual = self.embed(tokens)  # [batch, position, d_model]
    # Positional embeddings (GPT only)
    if self.pos_embed is not None:  # GPT
        residual = residual + self.pos_embed(tokens)
    # LLaMA: RoPE is applied inside attention blocks
    # OLMo: ALiBi is applied inside attention blocks
    # Pass through transformer blocks
    aux_losses = []
    for block in self.blocks:
        residual, aux_loss = block(residual)  # [batch, position, d_model]
        if aux_loss is not None:  # MoE auxiliary loss
            aux_losses.append(aux_loss)
    # Aggregate MoE auxiliary losses
    total_aux_loss = sum(aux_losses) if aux_losses else None
    # Final normalization (LayerNorm for GPT/OLMo, RMSNorm for LLaMA)
    residual = self.ln_f(residual)  # [batch, position, d_model]
    # Unembedding to logits
    logits = torch.matmul(residual, self.unembed)
    # Return logits and auxiliary loss if MoE is enabled
    if total_aux_loss is not None:
        return logits, total_aux_loss
    return logits
Tests
./test.sh
See tests/README.md and frontend/test/README.md for details.
Resources
LLM from Scratch
- ARENA's Transformers from Scratch
- Andrej Karpathy on building an LLM from scratch
- Sebastian Raschka's Build an LLM from Scratch
- Stanford CS336: Language Modeling From Scratch
- Neel Nanda on building an LLM from scratch
Background
- 3Blue1Brown on LLMs
- Luis Serrano on LLMs
- The Illustrated Transformer
- John Hewitt on pre-training (Stanford CS224N)
- Tom Yeh's AI By Hand
Architectures
Papers
- Attention Is All You Need - Original transformer paper
- Language Models are Unsupervised Multitask Learners - GPT-2 paper
- LLaMA: Open and Efficient Foundation Language Models - LLaMA paper
- OLMo: Accelerating the Science of Language Models - OLMo paper
- Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer - Original MoE paper
- Mixtral of Experts - Mixtral MoE paper
- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model - DeepSeek V2 MoE paper