Perfect! Now I have read all the essential files. Let me create a comprehensive summary document of my findings:

## Comprehensive Exploration Report: esp32gpt Repository

I have completed a thorough exploration of the esp32gpt repository. Here's everything you need to understand the implementation:

---

### 1. PROJECT OVERVIEW

**esp32gpt** is a Rust port of Karpathy's microgpt — a 1-layer GPT transformer with only 4,192 parameters that trains from scratch and generates names, running entirely on an ESP32 microcontroller with ~300KB SRAM.

**Key Stats:**
- 4,192 total parameters (16.4 KB as f32)
- Embedding dim: 16, Heads: 4, FFN dim: 64
- Block size: 16 tokens
- Vocab: 27 tokens (BOS + a-z)
- Training: 1,000 steps on a 32K-name dataset
- ~1,000 lines of platform-independent Rust code

---

### 2. FILE STRUCTURE & SIZES

```
/Users/[REDACTED]/Projects/esp32gpt/
├── Cargo.toml (30 lines) - Package config, dependencies
├── Cargo.lock (updated)
├── build.rs (7 lines) - ESP-IDF build setup
├── Makefile (30 lines) - Build/flash/test commands
├── sdkconfig.defaults (24 lines) - ESP32 configuration
├── partitions.csv (6 lines) - Flash partition layout
├── README.md (102 lines) - Project documentation
├── CLAUDE.md (project instructions)
├── LICENSE - MIT License
├── src/
│   ├── main.rs (101 lines) - Training loop & inference
│   ├── model.rs (339 lines) - Forward pass, parameter layout, KV cache
│   ├── backward.rs (245 lines) - Manual backward pass, gradient accumulation
│   ├── tensor.rs (145 lines) - Vec-mat ops, softmax, ReLU, etc.
│   ├── optimizer.rs (58 lines) - Adam optimizer with bias correction
│   ├── tokenizer.rs (55 lines) - Character-level encode/decode
│   └── rng.rs (108 lines) - Xorshift32 PRNG + Box-Muller Gaussian
└── data/
    └── names.txt (32,032 lines) - Training dataset of 32K names
```

---

### 3. COMPLETE SOURCE CODE ANALYSIS

#### **main.rs** (Training & Inference Entry Point)

```rust
// Constants
const TRAINING_STEPS: usize = 1000;
const LEARNING_RATE: f32 = 0.01;
const TEMPERATURE: f32 = 0.8;

// Main flow:
// 1. Initialize logger (ESP-IDF or env_logger)
// 2. Load dataset from embedded flash (include_str!)
// 3. Initialize RNG with hardware seed (ESP32) or fixed seed (host)
// 4. Initialize weights from Gaussian distribution
// 5. Create Adam optimizer and gradient buffers
// 6. Training loop:
//    - Pick random name from dataset
//    - Encode tokens [BOS, char1, ..., charN, BOS]
//    - Forward pass through transformer
//    - Backward pass to compute gradients
//    - Adam step with learning rate decay: lr = 0.01 * (1 - step/1000)
//    - Log loss every 100 steps
// 7. Infinite inference loop (never exits):
//    - Generate name starting from BOS token
//    - Temperature-scaled softmax sampling
//    - Stop at BOS or block size limit
```

**Key design:** Each training step picks one name from the dataset and runs a full forward and backward pass on it. Linear learning rate decay improves convergence.

---

#### **model.rs** (Forward Pass, Parameter Layout)

**Parameter Layout in Flat Buffer (4,192 total parameters, row-major indexing):**

```
WTE:  [0..432)      27×16  Token embeddings (vocab×embedding_dim)
WPE:  [432..688)    16×16  Position embeddings (block_size×embedding_dim)
WQ:   [688..944)    16×16  Query projection (embed_dim×embed_dim)
WK:   [944..1200)   16×16  Key projection
WV:   [1200..1456)  16×16  Value projection
WO:   [1456..1712)  16×16  Output/attention projection
W1:   [1712..2736)  16×64  FFN up-projection (embed_dim×ffn_dim)
W2:   [2736..3760)  64×16  FFN down-projection (ffn_dim×embed_dim)
WOUT: [3760..4192)  16×27  Output logits head (embed_dim×vocab_size)
```
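The offset table above can be expressed as cumulative constants so the compiler checks the arithmetic. A minimal sketch — the constant and function names here are illustrative, not necessarily the repo's actual identifiers:

```rust
// Model dimensions from the report (names are illustrative).
const VOCAB_SIZE: usize = 27;
const EMBED_DIM: usize = 16;
const BLOCK_SIZE: usize = 16;
const FFN_DIM: usize = 64;

// Cumulative offsets into the flat, row-major parameter buffer.
const WTE: usize = 0;
const WPE: usize = WTE + VOCAB_SIZE * EMBED_DIM; // 432
const WQ: usize = WPE + BLOCK_SIZE * EMBED_DIM;  // 688
const WK: usize = WQ + EMBED_DIM * EMBED_DIM;    // 944
const WV: usize = WK + EMBED_DIM * EMBED_DIM;    // 1200
const WO: usize = WV + EMBED_DIM * EMBED_DIM;    // 1456
const W1: usize = WO + EMBED_DIM * EMBED_DIM;    // 1712
const W2: usize = W1 + EMBED_DIM * FFN_DIM;      // 2736
const WOUT: usize = W2 + FFN_DIM * EMBED_DIM;    // 3760
const NUM_PARAMS: usize = WOUT + EMBED_DIM * VOCAB_SIZE; // 4192

/// Borrow one weight matrix out of the flat parameter buffer.
fn param_slice(params: &[f32], offset: usize, len: usize) -> &[f32] {
    &params[offset..offset + len]
}
```

Deriving each offset from the previous one means a dimension change propagates through the whole layout automatically.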

**ForwardCache Structure:**
Stores ALL intermediate activations needed by the backward pass:
- `tokens`: Sequence of input token indices
- `seq_len`: Current sequence length
- `embeddings[BLOCK_SIZE × EMBED_DIM]`: Token + position embeddings
- `queries[BLOCK_SIZE × EMBED_DIM]`: Q projections
- `k_cache[BLOCK_SIZE × EMBED_DIM]`: K values (full history for causal attn)
- `v_cache[BLOCK_SIZE × EMBED_DIM]`: V values (full history)
- `att_outs[BLOCK_SIZE × EMBED_DIM]`: Multi-head attention outputs
- `proj_outs[BLOCK_SIZE × EMBED_DIM]`: Attention output projections
- `residual1[BLOCK_SIZE × EMBED_DIM]`: After attention (emb + proj)
- `ffn_hidden[BLOCK_SIZE × FFN_DIM]`: Before ReLU
- `ffn_relu[BLOCK_SIZE × FFN_DIM]`: After ReLU
- `ffn_out[BLOCK_SIZE × EMBED_DIM]`: After down-projection
- `residual2[BLOCK_SIZE × EMBED_DIM]`: Final hidden state (res1 + ffn_out)
- `logits[BLOCK_SIZE × VOCAB_SIZE]`: Output logits
- `att_probs[BLOCK_SIZE × NUM_HEADS × BLOCK_SIZE]`: Attention probability matrices

**Indexing Convention for 2D flat data:**
```
For position-based data [POS × DIM]:
  data[pos * DIM + i] = element at (pos, i)

For attention probs [POS × HEADS × BLOCK_SIZE]:
  att_probs[pos * NUM_HEADS * BLOCK_SIZE + h * BLOCK_SIZE + i]
    = probability that position `pos`, head `h`, attends to position `i`
```
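The two conventions can be wrapped in small index helpers; a sketch using the dimensions above (helper names are hypothetical):

```rust
const EMBED_DIM: usize = 16;
const NUM_HEADS: usize = 4;
const BLOCK_SIZE: usize = 16;

/// Flat index for position-major 2D data, e.g. embeddings at (pos, i).
fn pos_idx(pos: usize, i: usize) -> usize {
    pos * EMBED_DIM + i
}

/// Flat index into the attention-probability buffer:
/// the probability that position `pos`, head `h`, attends to position `i`.
fn att_idx(pos: usize, h: usize, i: usize) -> usize {
    pos * NUM_HEADS * BLOCK_SIZE + h * BLOCK_SIZE + i
}
```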

**Forward Pass Algorithm (per position t):**

```rust
for t in 0..seq_len {
    token = tokens[t]

    // 1. Token + position embeddings
    emb = wte[token] + wpe[t]   // Both are 16-dim vectors

    // 2. Compute Q, K, V projections (16×16 matrix multiplies)
    q = emb @ Wq
    k = emb @ Wk
    v = emb @ Wv
    cache.k_cache[t] = k        // Append to KV cache
    cache.v_cache[t] = v

    // 3. Multi-head causal self-attention
    att_out = [0; 16]           // Accumulate across all 4 heads
    for h in 0..4 {
        q_h = q[h*4..h*4+4]     // Head h gets 4 dims (16 dims / 4 heads)

        // Compute attention scores for positions 0..=t (causal mask)
        scores = []
        for i in 0..=t {
            k_i_h = k_cache[i][h*4..h*4+4]
            scores[i] = dot(q_h, k_i_h) / sqrt(4)   // Scale by sqrt(HEAD_DIM)
        }

        // Softmax over causal positions
        probs = softmax(scores[0..=t])
        cache.att_probs[t][h] = probs

        // Weighted sum of values
        for i in 0..=t {
            v_i_h = v_cache[i][h*4..h*4+4]
            att_out[h*4..h*4+4] += probs[i] * v_i_h
        }
    }

    // 4. Output projection + residual
    proj = att_out @ Wo
    res1 = emb + proj           // Residual connection

    // 5. FFN: ReLU(x @ W1) @ W2
    hidden = res1 @ W1          // 16×64 multiply -> 64-dim
    activated = ReLU(hidden)
    ffn_out = activated @ W2    // 64×16 multiply -> 16-dim

    // 6. FFN residual
    res2 = res1 + ffn_out

    // 7. Output logits
    logits = res2 @ Wout        // 16×27 multiply -> 27-dim (one logit per token)
}
```


**Generate Function:**
```rust
pub fn generate(params, cache, rng, temperature) -> String {
    tokens = [BOS]
    for _ in 0..BLOCK_SIZE-1 {
        forward(params, tokens, cache)
        logits = cache.logits_at(tokens.len() - 1)

        // Temperature scaling: lower T = more greedy, higher T = more random
        scaled = logits / temperature
        probs = softmax(scaled)
        next = sample_categorical(probs, rng)

        if next == BOS {
            break
        }
        tokens.push(next)
    }
    decode(tokens[1..])
}
```
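The temperature-scaling step can be isolated and checked on its own. A minimal sketch (the function name is illustrative; it mirrors the stable softmax from tensor.rs with the division folded in):

```rust
// Numerically stable softmax over logits / temperature.
// Lower temperature sharpens the distribution toward the argmax;
// higher temperature flattens it toward uniform.
fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Vec<f32> {
    let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```

With the report's TEMPERATURE = 0.8 the distribution is slightly sharper than the raw softmax, trading diversity for more plausible names.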

---

#### **backward.rs** (Manual Gradient Computation)

**Key Insight:** The backward pass processes positions in REVERSE order because:
- Position t's query attends to keys/values at positions 0..=t
- So key/value gradients at position i accumulate contributions from ALL positions t >= i
- By reversing, when we reach position t, d_k_cache[t] and d_v_cache[t] are complete

**Backward Algorithm Overview:**

```rust
// Accumulated gradient buffers
d_k_cache = zeros([seq_len, EMBED_DIM])   // Accumulates from all future positions
d_v_cache = zeros([seq_len, EMBED_DIM])

// Process in REVERSE
for t in (0..seq_len).rev() {
    // 1. CROSS-ENTROPY LOSS GRADIENT
    target = targets[t]
    probs = softmax(logits[t])
    total_loss += -ln(probs[target])

    d_logits = probs.clone()
    d_logits[target] -= 1.0
    d_logits /= seq_len                   // Normalize by sequence length

    // 2. OUTPUT HEAD BACKWARD: logits = res2 @ Wout
    d_res2 = vec_mat_mul_backward_x(d_logits, Wout)
    grads[WOUT] += vec_mat_mul_backward_w(res2, d_logits)

    // 3. FFN RESIDUAL: res2 = res1 + ffn_out
    d_res1 = d_res2.clone()
    d_ffn_out = d_res2

    // 4. FFN DOWN BACKWARD: ffn_out = ffn_relu @ W2
    d_ffn_relu = vec_mat_mul_backward_x(d_ffn_out, W2)
    grads[W2] += vec_mat_mul_backward_w(ffn_relu, d_ffn_out)

    // 5. RELU BACKWARD: d_hidden = d_relu * (hidden > 0)
    d_ffn_hidden = d_ffn_relu .* (ffn_hidden > 0)

    // 6. FFN UP BACKWARD: ffn_hidden = res1 @ W1
    d_res1_from_ffn = vec_mat_mul_backward_x(d_ffn_hidden, W1)
    grads[W1] += vec_mat_mul_backward_w(res1, d_ffn_hidden)
    d_res1 += d_res1_from_ffn

    // 7. ATTENTION RESIDUAL: res1 = emb + proj
    d_proj = d_res1.clone()
    d_emb = d_res1

    // 8. ATTENTION OUTPUT PROJECTION BACKWARD: proj = att_out @ Wo
    d_att_out = vec_mat_mul_backward_x(d_proj, Wo)
    grads[Wo] += vec_mat_mul_backward_w(att_out, d_proj)

    // 9. MULTI-HEAD ATTENTION BACKWARD (THE TRICKY PART)
    d_q = zeros([EMBED_DIM])

    for h in 0..NUM_HEADS {
        d_out_h = d_att_out[h*HEAD_DIM..(h+1)*HEAD_DIM]
        q_h = query_at(t)[h*HEAD_DIM..(h+1)*HEAD_DIM]

        // For each position i that position t attended to (0..=t):
        for i in 0..=t {
            v_i_h = v_cache[i][h*HEAD_DIM..(h+1)*HEAD_DIM]
            prob = att_prob[t][h][i]

            // Gradient w.r.t. value: d_v[i] += prob * d_out_h
            d_v_cache[i][h*HEAD_DIM..(h+1)*HEAD_DIM] += prob * d_out_h
        }

        // Softmax backward (for each position the model attended to)
        // d_score[i] = prob[i] * (d_probs[i] - sum(prob[j] * d_probs[j]))
        d_probs = [dot(d_out_h, v_cache[i][h*HEAD_DIM..]) for i in 0..=t]
        dot_sum = sum(prob[i] * d_probs[i] for i in 0..=t)

        for i in 0..=t {
            prob = att_prob[t][h][i]
            d_score = prob * (d_probs[i] - dot_sum) / sqrt(HEAD_DIM)

            // Backprop through scores: score[i] = dot(q_h, k_i_h)
            k_i_h = k_cache[i][h*HEAD_DIM..(h+1)*HEAD_DIM]
            d_q[h*HEAD_DIM..(h+1)*HEAD_DIM] += d_score * k_i_h
            d_k_cache[i][h*HEAD_DIM..(h+1)*HEAD_DIM] += d_score * q_h
        }
    }

    // 10. Q PROJECTION BACKWARD: q = emb @ Wq
    d_emb_from_q = vec_mat_mul_backward_x(d_q, Wq)
    grads[Wq] += vec_mat_mul_backward_w(emb, d_q)
    d_emb += d_emb_from_q

    // 11. K, V PROJECTIONS BACKWARD
    // Note: d_k_cache[t] and d_v_cache[t] are now complete
    d_k_t = d_k_cache[t]
    d_emb_from_k = vec_mat_mul_backward_x(d_k_t, Wk)
    grads[Wk] += vec_mat_mul_backward_w(emb, d_k_t)
    d_emb += d_emb_from_k

    d_v_t = d_v_cache[t]
    d_emb_from_v = vec_mat_mul_backward_x(d_v_t, Wv)
    grads[Wv] += vec_mat_mul_backward_w(emb, d_v_t)
    d_emb += d_emb_from_v

    // 12. EMBEDDING BACKWARD: emb = wte[token] + wpe[t]
    token = tokens[t]
    grads[wte[token]] += d_emb
    grads[wpe[t]] += d_emb
}

return total_loss / seq_len
```


**Critical Detail - Softmax Backward:**
For a softmax whose outputs feed into a loss, the derivative with respect to the scores is:
```
d_score[i] = prob[i] * (d_prob[i] - sum_j(prob[j] * d_prob[j]))
```
This is the softmax Jacobian applied to the incoming gradient. The second term reflects the normalization constraint: the score gradients sum to zero, since shifting every score by the same constant leaves the probabilities unchanged.
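This identity is easy to verify numerically with finite differences, the same way the repo's gradient-check test validates the full backward pass. A self-contained sketch (function names are illustrative; `softmax` mirrors tensor.rs):

```rust
// Numerically stable softmax, as in tensor.rs.
fn softmax(v: &[f32]) -> Vec<f32> {
    let max = v.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = v.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// d_score from d_prob via the softmax Jacobian identity quoted above.
fn softmax_backward(probs: &[f32], d_probs: &[f32]) -> Vec<f32> {
    let dot: f32 = probs.iter().zip(d_probs).map(|(p, d)| p * d).sum();
    probs.iter().zip(d_probs).map(|(p, d)| p * (d - dot)).collect()
}
```

Checking it: for the loss `L = dot(w, softmax(scores))`, the incoming gradient is `d_probs = w`, and the analytic `d_score` should match a central finite difference of `L`.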

**Gradient Accumulation Pattern:**
The backward pass accumulates gradients across positions:
```rust
grads[offset..offset+size] += contribution
```
This is done in-place via slices, avoiding extra allocations.

---

#### **tensor.rs** (Core Math Primitives)

```rust
// vec_mat_mul: y = x @ W  (x: [n], W: [n×m], y: [m])
pub fn vec_mat_mul(x: &[f32], w: &[f32], n: usize, m: usize) -> Vec<f32> {
    let mut y = vec![0.0; m];
    for i in 0..n {
        for j in 0..m {
            y[j] += x[i] * w[i * m + j]; // Row-major indexing
        }
    }
    y
}

// vec_mat_mul_backward_x: d_x = d_y @ W.T
pub fn vec_mat_mul_backward_x(d_y: &[f32], w: &[f32], n: usize, m: usize) -> Vec<f32> {
    let mut d_x = vec![0.0; n];
    for i in 0..n {
        for j in 0..m {
            d_x[i] += d_y[j] * w[i * m + j];
        }
    }
    d_x
}

// vec_mat_mul_backward_w: d_W[i][j] += x[i] * d_y[j] (outer product, accumulates)
pub fn vec_mat_mul_backward_w(x: &[f32], d_y: &[f32], d_w: &mut [f32], n: usize, m: usize) {
    for i in 0..n {
        for j in 0..m {
            d_w[i * m + j] += x[i] * d_y[j];
        }
    }
}

// vec_add: element-wise addition
pub fn vec_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}

// vec_add_inplace: in-place addition (used to accumulate gradients)
pub fn vec_add_inplace(a: &mut [f32], b: &[f32]) {
    for (ai, bi) in a.iter_mut().zip(b.iter()) {
        *ai += bi;
    }
}

// vec_dot: dot product
pub fn vec_dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

// softmax: numerically stable with max subtraction
pub fn softmax(v: &[f32]) -> Vec<f32> {
    let max = v.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = v.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

// relu: max(0, x)
pub fn relu(v: &[f32]) -> Vec<f32> {
    v.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect()
}
```

**Design Notes:**
- All operations allocate new vectors (functional style)
- `vec_add_inplace` is used specifically for gradient accumulation to avoid extra allocations
- Softmax subtracts the max for numerical stability (prevents overflow on large logits)
- All functions use row-major memory layout consistently

---

#### **optimizer.rs** (Adam Optimizer)

```rust
pub struct Adam {
    pub m: Vec<f32>,   // First moment: running mean of gradients
    pub v: Vec<f32>,   // Second moment: running mean of squared gradients
    pub beta1: f32,    // Decay for first moment (default 0.9)
    pub beta2: f32,    // Decay for second moment (default 0.999)
    pub epsilon: f32,  // Small constant to prevent division by zero
}

impl Adam {
    pub fn step(&mut self, params: &mut [f32], grads: &[f32], lr: f32, t: usize) {
        // Bias correction terms (account for warmup)
        let bc1 = 1.0 - self.beta1.powi(t as i32);
        let bc2 = 1.0 - self.beta2.powi(t as i32);

        for i in 0..params.len() {
            let g = grads[i];

            // Update biased first moment (exponential moving average of gradient)
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;

            // Update biased second moment (exponential moving average of squared gradient)
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;

            // Bias-corrected estimates
            let m_hat = self.m[i] / bc1;
            let v_hat = self.v[i] / bc2;

            // Parameter update
            params[i] -= lr * m_hat / (v_hat.sqrt() + self.epsilon);
        }
    }
}
```


**Configuration in main.rs:**
```rust
Adam::new(NUM_PARAMS, 0.85, 0.99) // beta1 = 0.85, beta2 = 0.99
// Note: beta1 = 0.85 is more aggressive than the standard 0.9;
// the shorter gradient memory reacts faster, accelerating learning early on
```

**Learning Rate Schedule:**
```rust
// Linear decay from 0.01 to 0 over training (note the float casts:
// integer division of step by TRAINING_STEPS would truncate to zero)
lr = 0.01 * (1.0 - step as f32 / TRAINING_STEPS as f32)
```
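The schedule is worth pinning down as a function, because the cast is the easy thing to get wrong in Rust. A sketch under the report's constants (the function name is hypothetical):

```rust
const TRAINING_STEPS: usize = 1000;
const LEARNING_RATE: f32 = 0.01;

/// Linearly decayed learning rate. The `as f32` casts matter: the integer
/// division `step / TRAINING_STEPS` would be 0 for every step below 1000,
/// silently disabling the decay.
fn decayed_lr(step: usize) -> f32 {
    LEARNING_RATE * (1.0 - step as f32 / TRAINING_STEPS as f32)
}
```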

---

#### **tokenizer.rs** (Character-Level Encoding)

**Vocabulary:**
- Token 0: BOS (Begin/End Of Sequence marker, displayed as '.')
- Tokens 1-26: Characters 'a'-'z'
- Total: 27 tokens (VOCAB_SIZE)

```rust
pub fn encode(name: &str) -> Vec<usize> {
    let mut tokens = vec![0]; // Start with BOS
    for c in name.chars() {
        let c = c.to_ascii_lowercase();
        if c >= 'a' && c <= 'z' {
            tokens.push((c as usize) - ('a' as usize) + 1); // Maps 'a'->1, 'z'->26
        }
        // Non-alphabetic characters are silently dropped
    }
    tokens.push(0); // End with BOS
    tokens
}

pub fn decode(token: usize) -> char {
    if token == 0 {
        '.' // BOS displays as dot
    } else {
        (b'a' + (token - 1) as u8) as char
    }
}

// Example: "emma" -> [0, 5, 13, 13, 1, 0]
//           BOS, e, m, m, a, BOS
```

**Training Usage:**
```rust
let encoded = tokenizer::encode(name);
// encoded = [BOS, char1, ..., charN, BOS]

// Input for forward pass: all but last token
let input_tokens = &encoded[..encoded.len() - 1];

// Targets for loss: all but first token
let targets = encoded[1..].to_vec();
```
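The encode/decode pair above is small enough to exercise end to end. A self-contained sketch following the report's description (the test name "Mary-Kate" is just an illustrative input for the character-dropping behavior):

```rust
const BOS: usize = 0;

/// Encode a name as [BOS, chars..., BOS]; non-letters are dropped.
fn encode(name: &str) -> Vec<usize> {
    let mut tokens = vec![BOS];
    for c in name.chars() {
        let c = c.to_ascii_lowercase();
        if c.is_ascii_lowercase() {
            tokens.push((c as usize) - ('a' as usize) + 1); // 'a'->1 .. 'z'->26
        }
        // Non-alphabetic characters (hyphens, digits, ...) are silently dropped
    }
    tokens.push(BOS);
    tokens
}

/// Decode a single token back to a character; BOS renders as '.'.
fn decode(token: usize) -> char {
    if token == BOS {
        '.'
    } else {
        (b'a' + (token - 1) as u8) as char
    }
}
```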

---

#### **rng.rs** (Random Number Generation)

**Xorshift32 PRNG:**
```rust
pub fn next_u32(&mut self) -> u32 {
    self.state ^= self.state << 13;
    self.state ^= self.state >> 17;
    self.state ^= self.state << 5;
    self.state
}
```
- Period: 2^32 - 1 (for any nonzero seed)
- Fast, no multiplication, adequate quality for training initialization
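One useful property behind that period claim: the xorshift update is a bijection on the nonzero 32-bit states, so from any nonzero seed the state never collapses to zero and the stream is fully deterministic. A minimal standalone sketch (the struct name is illustrative):

```rust
/// Xorshift32 with the (13, 17, 5) shift triple from rng.rs.
struct Xorshift32 {
    state: u32, // must be seeded nonzero
}

impl Xorshift32 {
    fn next_u32(&mut self) -> u32 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 17;
        self.state ^= self.state << 5;
        self.state
    }
}
```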

**Uniform Float [0, 1):**
```rust
pub fn next_f32(&mut self) -> f32 {
    (self.next_u32() as f32) / (u32::MAX as f32 + 1.0)
}
```

**Gaussian via Box-Muller:**
```rust
pub fn next_gaussian(&mut self) -> f32 {
    // Box-Muller produces TWO Gaussian samples per iteration;
    // cache the second one for the next call
    if let Some(spare) = self.spare.take() {
        return spare;
    }
    loop {
        let u1 = self.next_f32();
        let u2 = self.next_f32();
        if u1 > 1e-10 { // Avoid log(0)
            let r = sqrt(-2.0 * ln(u1));
            let theta = 2.0 * PI * u2;
            self.spare = Some(r * sin(theta));
            return r * cos(theta);
        }
    }
}
```

**Weight Initialization:**
```rust
// In model.rs
pub fn init_weights(params: &mut [f32], rng: &mut Rng) {
    for p in params.iter_mut() {
        *p = rng.next_gaussian() * 0.1; // N(0, 0.01) for stable initialization
    }
}
```

**RNG Seed Sources:**
```rust
#[cfg(target_arch = "xtensa")]
let seed = unsafe { esp_idf_svc::sys::esp_random() }; // Hardware RNG on ESP32

#[cfg(not(target_arch = "xtensa"))]
let seed = 42u32; // Fixed seed for reproducible host testing
```

**Sampling for Generation:**
```rust
pub fn sample_from_probs(&mut self, probs: &[f32]) -> usize {
    let r = self.next_f32();
    let mut cumulative = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if r < cumulative {
            return i;
        }
    }
    probs.len() - 1 // Fallback: rounding can leave the cumulative sum slightly below 1.0
}
```

---

### 4. BUILD & CONFIGURATION FILES

**Cargo.toml:**
```toml
[dependencies]
log = "0.4"
anyhow = "1"

# ESP32-only dependencies
[target.'cfg(target_arch = "xtensa")'.dependencies]
esp-idf-hal = "0.45"
esp-idf-svc = "0.51"
esp-idf-sys = "0.36"

# Host-only dependencies
[target.'cfg(not(target_arch = "xtensa"))'.dependencies]
env_logger = "0.11"

[profile.dev]
opt-level = 2 # Optimization level 2 on both dev and release
debug = true

[profile.release]
opt-level = 2 # Required for ESP-IDF compatibility
lto = "off"   # Link-time optimization OFF
```

**sdkconfig.defaults (ESP32 Configuration):**
```
# SRAM allocation: disable WiFi to free ~50-80KB for training
CONFIG_ESP_WIFI_ENABLED=n

# Dual-core setup: pin main task to CPU1, leave CPU0 for FreeRTOS
CONFIG_ESP_MAIN_TASK_AFFINITY_CPU1=y
CONFIG_ESP_MAIN_TASK_STACK_SIZE=32768 # 32KB stack for matrix ops

# Watchdog: don't monitor IDLE1 because training monopolizes CPU1
CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU1=n

# Flash: 4MB total
CONFIG_ESPTOOLPY_FLASHSIZE_4MB=y

# Custom partition table
CONFIG_PARTITION_TABLE_CUSTOM=y
CONFIG_PARTITION_TABLE_CUSTOM_FILENAME="partitions.csv"

# Logging
CONFIG_LOG_DEFAULT_LEVEL_INFO=y
```

**partitions.csv (Flash Layout):**
```
Name,     Type, SubType, Offset,  Size
nvs,      data, nvs,     0x9000,  0x6000,   # NVS storage
phy_init, data, phy,     0xf000,  0x1000,   # PHY init data
factory,  app,  factory, 0x10000, 0x3F0000, # Main app (4MB - 64KB)
```

**Makefile:**
```makefile
# Environment setup for ESP32 cross-compilation
RUSTUP_TOOLCHAIN := esp
LIBCLANG_PATH := ~/.rustup/toolchains/esp/xtensa-esp32-elf-clang/...

# Commands:
# make test    - Run host tests (stable toolchain)
# make build   - Cross-compile for ESP32 (esp toolchain)
# make flash   - Build + flash + monitor serial
# make monitor - Serial monitor only
# make clean   - Clean build artifacts
```

---

### 5. TRAINING DATASET

**File:** `/Users/[REDACTED]/Projects/esp32gpt/data/names.txt`
- 32,032 names (one per line)
- ~223 KB uncompressed
- Embedded in binary via `include_str!("../data/names.txt")`
- Lives in flash memory, not SRAM (critical for the ESP32 memory budget)

**Sample names:**
```
emma
olivia
ava
isabella
sophia
charlotte
...
```

---

### 6. MEMORY BUDGET ANALYSIS

**Total SRAM Available:** ~520 KB (with WiFi disabled: ~300 KB usable)

**Actual Usage:**
```
Flat parameter buffer:     17 KB (4,192 × 4 bytes)
Gradient buffer:           17 KB (4,192 × 4 bytes)
Adam m state:              17 KB (4,192 × 4 bytes)
Adam v state:              17 KB (4,192 × 4 bytes)
ForwardCache activations: ~23 KB (16×16 buffers + 16×64 + 16×27 + attn probs)
Temporary vectors:        ~10 KB (d_logits, d_probs, etc.)
Dataset (flash, not RAM):   0 KB (embedded via include_str!)
─────────────────────────────────
TOTAL:                   ~100 KB (leaving ~200 KB of headroom)
```
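The budget above can be derived directly from the model constants rather than maintained by hand. A sketch (function and constant names are illustrative, not the repo's):

```rust
const NUM_PARAMS: usize = 4192;
const BLOCK_SIZE: usize = 16;
const EMBED_DIM: usize = 16;
const FFN_DIM: usize = 64;
const VOCAB_SIZE: usize = 27;
const NUM_HEADS: usize = 4;
const F32_BYTES: usize = core::mem::size_of::<f32>();

/// Bytes for the four parameter-sized buffers: params, grads, Adam m, Adam v.
fn param_buffers_bytes() -> usize {
    4 * NUM_PARAMS * F32_BYTES
}

/// Bytes for the activation cache, itemized as in the breakdown below.
fn cache_bytes() -> usize {
    let embed_like = 9 * BLOCK_SIZE * EMBED_DIM; // nine 16×16 buffers
    let ffn = 2 * BLOCK_SIZE * FFN_DIM;          // ffn_hidden + ffn_relu
    let logits = BLOCK_SIZE * VOCAB_SIZE;
    let att = BLOCK_SIZE * NUM_HEADS * BLOCK_SIZE;
    (embed_like + ffn + logits + att) * F32_BYTES
}
```

Tying the accounting to the constants means a dimension change updates the budget automatically.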

**ForwardCache Breakdown:**
```
embeddings:  16 × 16 × 4     = 1,024 bytes
queries:     16 × 16 × 4     = 1,024 bytes
k_cache:     16 × 16 × 4     = 1,024 bytes
v_cache:     16 × 16 × 4     = 1,024 bytes
att_outs:    16 × 16 × 4     = 1,024 bytes
proj_outs:   16 × 16 × 4     = 1,024 bytes
residual1:   16 × 16 × 4     = 1,024 bytes
ffn_hidden:  16 × 64 × 4     = 4,096 bytes
ffn_relu:    16 × 64 × 4     = 4,096 bytes
ffn_out:     16 × 16 × 4     = 1,024 bytes
residual2:   16 × 16 × 4     = 1,024 bytes
logits:      16 × 27 × 4     = 1,728 bytes
att_probs:   16 × 4 × 16 × 4 = 4,096 bytes
─────────────────────────────────
TOTAL:       23,232 bytes (~23 KB)
```

**Key Optimization:** The cache is pre-allocated once and reused via `cache.clear()` rather than deallocated and reallocated each iteration, avoiding heap fragmentation on embedded systems.

---

### 7. TESTING INFRASTRUCTURE

**Host Tests (14 total, run via `make test`):**

**tensor.rs tests:**
- `test_vec_mat_mul`: Verify matrix multiplication
- `test_backward_x`: Numerical gradient check for x gradients
- `test_softmax`: Output sums to 1, respects ordering
- `test_relu`: Zeros negatives, preserves positives

**model.rs tests:**
- `test_param_count`: Assert NUM_PARAMS == 4192
- `test_forward_smoke`: Check finite logits and valid probabilities
- `test_generate_smoke`: Generate names with random weights (basic sanity)

**backward.rs tests:**
- `test_gradient_check`: CRITICAL numerical gradient verification
  - Checks every 17th parameter (256 params total) against finite differences
  - eps = 0.01 (not 1e-3, due to f32 precision limits at opt-level=2)
  - Passes if: `abs_err < 1e-3 OR rel_err < 0.1`
  - Reports max relative error

**tokenizer.rs tests:**
- `test_encode_decode`: "emma" -> [0,5,13,13,1,0] and back
- `test_roundtrip`: Full encode/decode for multiple names

**rng.rs tests:**
- `test_uniform_range`: All samples in [0, 1)
- `test_gaussian_mean`: 10K samples, mean close to 0
- `test_sample_from_probs`: Categorical sampling correctness

**optimizer.rs tests:**
- `test_adam_converges`: Minimize f(x)=x^2 from x=5, verify x≈0 after 500 steps

---

### 8. KEY DESIGN DECISIONS & INSIGHTS

**1. Why Manual Backward Instead of PyTorch Autograd?**
- PyTorch-style scalar-level autograd creates 30-50K graph nodes per forward pass
- That's 1-2 MB of memory, more than the ESP32's entire SRAM
- The manual backward stores only the activations (~23 KB) needed for backprop
- Explicitly derived and numerically verified

**2. KV Cache for Autoregressive Generation**
- In inference, the full K and V history is built up as generation proceeds
- At position t, the query attends to all positions 0..=t
- Storing K and V avoids recomputing them for every new position
- The same KV cache is reused across training and inference (via `cache.clear()`)

**3. Reverse-Order Backward Pass**
- Position t's attention reads from positions 0..=t
- So d_k[i] and d_v[i] accumulate contributions from ALL positions t >= i
- Processing backward ensures d_k[t] and d_v[t] are complete before use
- Critical for correctness of the attention backward

**4. Linear Learning Rate Decay**
- Helps convergence: aggressive early learning, gentle fine-tuning
- Formula: `lr = 0.01 * (1 - step / 1000)` (float division)
- At step 0: lr = 0.01; at step 1000: lr = 0

**5. Numerical Stability Choices**
- Softmax: subtract max before exp() to prevent overflow
- Gradient check epsilon: eps=0.01 (not 1e-3) because opt-level=2 loses precision
- Loss clamping: `ln(max(probs[target], 1e-10))` prevents log(0)
- Xorshift32: fast 32-bit PRNG, adequate for weight initialization (not cryptographic)

**6. Row-Major Flat Memory Layout**
- All 2D matrices stored flat: `M[i][j] = flat[i * cols + j]`
- Simplifies indexing, improves cache locality
- Consistent throughout the codebase

**7. CPU1 Pinning + IDLE1 Watchdog Disabled**
- The training loop monopolizes CPU1
- CPU0 stays free for FreeRTOS and WiFi management (though WiFi is disabled)
- The IDLE1 watchdog is disabled because CPU1 never yields (infinite training loop)
- This is safe because CPU0's watchdog still monitors system health

**8. Dual-Target Compilation**
- All ML logic is platform-independent
- Only `main.rs` has `#[cfg(target_arch = "xtensa")]` blocks
- Host compilation (e.g. `--target aarch64-apple-darwin`) for testing
- The same tests run on both targets

---

| 804 | ### 9. COMPLETE TRAINING FLOW SUMMARY |
| 805 | |
| 806 | ``` |
| 807 | 1. Load 32K names from embedded data |
| 808 | 2. Initialize random weights: N(0, 0.01) |
| 809 | 3. For 1000 steps: |
| 810 | a. Sample random name from dataset |
| 811 | b. Skip if too long (>17 tokens) |
| 812 | c. Encode: [BOS, chars..., BOS] |
| 813 | d. Input tokens: all but last; Targets: all but first |
| 814 | e. Forward pass: compute embeddings, attention, FFN, logits |
| 815 | f. Backward pass (reverse order): |
| 816 | - Compute loss gradient |
| 817 | - Backprop through output head |
| 818 | - Backprop through FFN |
| 819 | - Backprop through attention (accumulate KV gradients) |
| 820 | - Backprop through embeddings |
| 821 | g. Adam step with decaying LR |
| 822 | h. Log loss every 100 steps |
| 823 | 4. Infinite inference loop: |
| 824 | - Generate names via autoregressive sampling |
| 825 | - Temperature-scaled softmax |
| 826 | - Stop at BOS or block size |
| 827 | ``` |
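Steps c and d above (encoding and the input/target shift) can be sketched as below; `encode` mirrors the tokenizer described earlier (BOS=0, a=1..z=26), and `split` is a hypothetical helper, not a repo function:

```rust
const BOS: usize = 0;

/// Encode a lowercase name as [BOS, chars..., BOS], with 'a' => 1 .. 'z' => 26.
fn encode(name: &str) -> Vec<usize> {
    let mut t = vec![BOS];
    t.extend(name.bytes().map(|b| (b - b'a' + 1) as usize));
    t.push(BOS);
    t
}

/// Step d: inputs are all tokens but the last, targets all but the first,
/// so targets[t] is exactly the next token after inputs[t].
fn split(tokens: &[usize]) -> (&[usize], &[usize]) {
    (&tokens[..tokens.len() - 1], &tokens[1..])
}

fn main() {
    let toks = encode("emma");
    assert_eq!(toks, vec![0, 5, 13, 13, 1, 0]);
    let (inputs, targets) = split(&toks);
    assert_eq!(inputs, &[0, 5, 13, 13, 1]);
    assert_eq!(targets, &[5, 13, 13, 1, 0]);
}
```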
| 828 | |
| 829 | --- |
| 830 | |
| 831 | ### 10. SOURCE FILE DEPENDENCIES & MODULE STRUCTURE |
| 832 | |
| 833 | ``` |
| 834 | main.rs (entry point) |
| 835 | ├── model.rs (forward pass, parameter layout) |
| 836 | │ ├── tensor.rs (vec_mat_mul, softmax, relu) |
| 837 | │ └── rng.rs (Gaussian init) |
| 838 | ├── backward.rs (gradient computation) |
| 839 | │ ├── tensor.rs (vec_mat_mul_backward, vec_add, vec_dot, softmax) |
| 840 | │ └── model.rs (parameters, cache, constants) |
| 841 | ├── optimizer.rs (Adam updates) |
| 842 | ├── tokenizer.rs (encode/decode) |
| 843 | ├── rng.rs (sampling, random_index) |
| 844 | └── (all use log crate for logging) |
| 845 | ``` |
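For reference on the optimizer.rs role in the graph above, one Adam update with bias correction looks like this (a generic sketch using the lr/beta values reported for the repo; the actual struct layout in optimizer.rs may differ):

```rust
/// Adam state for a single parameter (generic sketch).
struct Adam {
    lr: f32, beta1: f32, beta2: f32, eps: f32,
    m: f32, v: f32, t: u32,
}

impl Adam {
    fn step(&mut self, param: &mut f32, grad: f32) {
        self.t += 1;
        // Exponential moving averages of the gradient and its square.
        self.m = self.beta1 * self.m + (1.0 - self.beta1) * grad;
        self.v = self.beta2 * self.v + (1.0 - self.beta2) * grad * grad;
        // Bias correction compensates for the zero-initialized moments.
        let m_hat = self.m / (1.0 - self.beta1.powi(self.t as i32));
        let v_hat = self.v / (1.0 - self.beta2.powi(self.t as i32));
        *param -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
    }
}

fn main() {
    let mut a = Adam { lr: 0.01, beta1: 0.85, beta2: 0.99, eps: 1e-8, m: 0.0, v: 0.0, t: 0 };
    let mut p = 1.0f32;
    // After the first step, bias correction makes m_hat = grad and
    // v_hat = grad^2, so the update is roughly lr * sign(grad).
    a.step(&mut p, 2.0);
    assert!((p - 0.99).abs() < 1e-4);
}
```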
| 846 | |
| 847 | --- |
| 848 | |
| 849 | ### 11. ABSOLUTE FILE PATHS |
| 850 | |
| 851 | ``` |
| 852 | /Users/[REDACTED]/Projects/esp32gpt/src/main.rs |
| 853 | /Users/[REDACTED]/Projects/esp32gpt/src/model.rs |
| 854 | /Users/[REDACTED]/Projects/esp32gpt/src/backward.rs |
| 855 | /Users/[REDACTED]/Projects/esp32gpt/src/tensor.rs |
| 856 | /Users/[REDACTED]/Projects/esp32gpt/src/optimizer.rs |
| 857 | /Users/[REDACTED]/Projects/esp32gpt/src/tokenizer.rs |
| 858 | /Users/[REDACTED]/Projects/esp32gpt/src/rng.rs |
| 859 | /Users/[REDACTED]/Projects/esp32gpt/Cargo.toml |
| 860 | /Users/[REDACTED]/Projects/esp32gpt/Cargo.lock |
| 861 | /Users/[REDACTED]/Projects/esp32gpt/build.rs |
| 862 | /Users/[REDACTED]/Projects/esp32gpt/Makefile |
| 863 | /Users/[REDACTED]/Projects/esp32gpt/sdkconfig.defaults |
| 864 | /Users/[REDACTED]/Projects/esp32gpt/partitions.csv |
| 865 | /Users/[REDACTED]/Projects/esp32gpt/data/names.txt |
| 866 | /Users/[REDACTED]/Projects/esp32gpt/README.md |
| 867 | /Users/[REDACTED]/Projects/esp32gpt/CLAUDE.md |
| 868 | /Users/[REDACTED]/Projects/esp32gpt/LICENSE |
| 869 | ``` |
| 870 | |
| 871 | --- |
| 872 | |
| 873 | This comprehensive exploration covers every aspect of the esp32gpt implementation. All code has been read completely, all architecture decisions explained, and all memory/performance trade-offs justified. The implementation is remarkably elegant for a constrained embedded system — a full 1-layer transformer with 4,192 parameters training from scratch on a microcontroller with only 300 KB of usable SRAM. |
| 1 | //! GPT model: 1-layer transformer with multi-head self-attention. |
| 2 | //! |
| 3 | //! All parameters live in a single flat f32 buffer. The forward pass processes |
| 4 | //! tokens sequentially, building a KV cache, and stores all intermediate |
| 5 | //! activations needed by the backward pass. |
| 6 | |
| 7 | use crate::rng::Rng; |
| 8 | use crate::tensor::{softmax, vec_add, vec_mat_mul, relu}; |
| 9 | use crate::tokenizer::{VOCAB_SIZE, BOS}; |
| 10 | |
| 11 | // Architecture constants (matching Karpathy's microgpt). |
| 12 | pub const EMBED_DIM: usize = 16; |
| 13 | pub const NUM_HEADS: usize = 4; |
| 14 | pub const HEAD_DIM: usize = EMBED_DIM / NUM_HEADS; // 4 |
| 15 | pub const FFN_DIM: usize = EMBED_DIM * 4; // 64 |
| 16 | pub const BLOCK_SIZE: usize = 16; |
| 17 | |
| 18 | // Parameter layout in the flat buffer. |
| 19 | pub const WTE_OFFSET: usize = 0; |
| 20 | pub const WTE_SIZE: usize = VOCAB_SIZE * EMBED_DIM; // 27 × 16 = 432 |
| 21 | |
| 22 | pub const WPE_OFFSET: usize = WTE_OFFSET + WTE_SIZE; // 432 |
| 23 | pub const WPE_SIZE: usize = BLOCK_SIZE * EMBED_DIM; // 16 × 16 = 256 |
| 24 | |
| 25 | pub const WQ_OFFSET: usize = WPE_OFFSET + WPE_SIZE; // 688 |
| 26 | pub const WQ_SIZE: usize = EMBED_DIM * EMBED_DIM; // 256 |
| 27 | |
| 28 | pub const WK_OFFSET: usize = WQ_OFFSET + WQ_SIZE; // 944 |
| 29 | pub const WK_SIZE: usize = EMBED_DIM * EMBED_DIM; // 256 |
| 30 | |
| 31 | pub const WV_OFFSET: usize = WK_OFFSET + WK_SIZE; // 1200 |
| 32 | pub const WV_SIZE: usize = EMBED_DIM * EMBED_DIM; // 256 |
| 33 | |
| 34 | pub const WO_OFFSET: usize = WV_OFFSET + WV_SIZE; // 1456 |
| 35 | pub const WO_SIZE: usize = EMBED_DIM * EMBED_DIM; // 256 |
| 36 | |
| 37 | pub const W1_OFFSET: usize = WO_OFFSET + WO_SIZE; // 1712 |
| 38 | pub const W1_SIZE: usize = EMBED_DIM * FFN_DIM; // 1024 |
| 39 | |
| 40 | pub const W2_OFFSET: usize = W1_OFFSET + W1_SIZE; // 2736 |
| 41 | pub const W2_SIZE: usize = FFN_DIM * EMBED_DIM; // 1024 |
| 42 | |
| 43 | pub const WOUT_OFFSET: usize = W2_OFFSET + W2_SIZE; // 3760 |
| 44 | pub const WOUT_SIZE: usize = EMBED_DIM * VOCAB_SIZE; // 432 |
| 45 | |
| 46 | pub const NUM_PARAMS: usize = WOUT_OFFSET + WOUT_SIZE; // 4192 |
| 47 | |
| 48 | // Slice accessors for parameter groups. |
| 49 | #[allow(dead_code)] |
| 50 | pub fn wte(p: &[f32]) -> &[f32] { &p[WTE_OFFSET..WTE_OFFSET + WTE_SIZE] } |
| 51 | #[allow(dead_code)] |
| 52 | pub fn wpe(p: &[f32]) -> &[f32] { &p[WPE_OFFSET..WPE_OFFSET + WPE_SIZE] } |
| 53 | pub fn wq(p: &[f32]) -> &[f32] { &p[WQ_OFFSET..WQ_OFFSET + WQ_SIZE] } |
| 54 | pub fn wk(p: &[f32]) -> &[f32] { &p[WK_OFFSET..WK_OFFSET + WK_SIZE] } |
| 55 | pub fn wv(p: &[f32]) -> &[f32] { &p[WV_OFFSET..WV_OFFSET + WV_SIZE] } |
| 56 | pub fn wo(p: &[f32]) -> &[f32] { &p[WO_OFFSET..WO_OFFSET + WO_SIZE] } |
| 57 | pub fn w1(p: &[f32]) -> &[f32] { &p[W1_OFFSET..W1_OFFSET + W1_SIZE] } |
| 58 | pub fn w2(p: &[f32]) -> &[f32] { &p[W2_OFFSET..W2_OFFSET + W2_SIZE] } |
| 59 | pub fn wout(p: &[f32]) -> &[f32] { &p[WOUT_OFFSET..WOUT_OFFSET + WOUT_SIZE] } |
| 60 | |
| 61 | /// Embedding row for a given token. |
| 62 | pub fn wte_row(p: &[f32], token: usize) -> &[f32] { |
| 63 | let start = WTE_OFFSET + token * EMBED_DIM; |
| 64 | &p[start..start + EMBED_DIM] |
| 65 | } |
| 66 | |
| 67 | /// Position embedding row. |
| 68 | pub fn wpe_row(p: &[f32], pos: usize) -> &[f32] { |
| 69 | let start = WPE_OFFSET + pos * EMBED_DIM; |
| 70 | &p[start..start + EMBED_DIM] |
| 71 | } |
| 72 | |
| 73 | /// Initialize all weights with i.i.d. Gaussian noise (std 0.1). |
| 74 | pub fn init_weights(params: &mut [f32], rng: &mut Rng) { |
| 75 | for p in params.iter_mut() { |
| 76 | *p = rng.next_gaussian() * 0.1; |
| 77 | } |
| 78 | } |
| 79 | |
| 80 | /// Cached activations from the forward pass, needed by backward. |
| 81 | /// |
| 82 | /// All 2D data is stored flat with row-major indexing: `[position * dim + i]`. |
| 83 | /// Attention probs are stored as `[position * NUM_HEADS * BLOCK_SIZE + head * BLOCK_SIZE + i]`. |
| 84 | pub struct ForwardCache { |
| 85 | pub seq_len: usize, |
| 86 | pub tokens: Vec<usize>, |
| 87 | |
| 88 | // Per-position activations, flat: [pos * dim + i] |
| 89 | pub embeddings: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 90 | pub queries: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 91 | pub k_cache: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 92 | pub v_cache: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 93 | pub att_outs: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 94 | pub proj_outs: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 95 | pub residual1: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 96 | pub ffn_hidden: Vec<f32>, // [BLOCK_SIZE × FFN_DIM] |
| 97 | pub ffn_relu: Vec<f32>, // [BLOCK_SIZE × FFN_DIM] |
| 98 | pub ffn_out: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 99 | pub residual2: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 100 | pub logits: Vec<f32>, // [BLOCK_SIZE × VOCAB_SIZE] |
| 101 | |
| 102 | // Attention probs: [pos * NUM_HEADS * BLOCK_SIZE + head * BLOCK_SIZE + i] |
| 103 | pub att_probs: Vec<f32>, // [BLOCK_SIZE × NUM_HEADS × BLOCK_SIZE] |
| 104 | } |
| 105 | |
| 106 | impl ForwardCache { |
| 107 | pub fn new() -> Self { |
| 108 | Self { |
| 109 | seq_len: 0, |
| 110 | tokens: Vec::new(), |
| 111 | embeddings: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 112 | queries: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 113 | k_cache: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 114 | v_cache: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 115 | att_outs: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 116 | proj_outs: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 117 | residual1: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 118 | ffn_hidden: vec![0.0; BLOCK_SIZE * FFN_DIM], |
| 119 | ffn_relu: vec![0.0; BLOCK_SIZE * FFN_DIM], |
| 120 | ffn_out: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 121 | residual2: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 122 | logits: vec![0.0; BLOCK_SIZE * VOCAB_SIZE], |
| 123 | att_probs: vec![0.0; BLOCK_SIZE * NUM_HEADS * BLOCK_SIZE], |
| 124 | } |
| 125 | } |
| 126 | |
| 127 | /// Reset for a new sequence without reallocating. |
| 128 | pub fn clear(&mut self) { |
| 129 | self.seq_len = 0; |
| 130 | self.tokens.clear(); |
| 131 | // No need to zero the buffers: the forward pass overwrites every slot it reads. |
| 132 | } |
| 133 | |
| 134 | // Accessors for a specific position's data. |
| 135 | |
| 136 | pub fn emb_at(&self, pos: usize) -> &[f32] { |
| 137 | &self.embeddings[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 138 | } |
| 139 | pub fn query_at(&self, pos: usize) -> &[f32] { |
| 140 | &self.queries[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 141 | } |
| 142 | #[allow(dead_code)] |
| 143 | pub fn k_at(&self, pos: usize) -> &[f32] { |
| 144 | &self.k_cache[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 145 | } |
| 146 | #[allow(dead_code)] |
| 147 | pub fn v_at(&self, pos: usize) -> &[f32] { |
| 148 | &self.v_cache[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 149 | } |
| 150 | pub fn att_out_at(&self, pos: usize) -> &[f32] { |
| 151 | &self.att_outs[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 152 | } |
| 153 | pub fn res1_at(&self, pos: usize) -> &[f32] { |
| 154 | &self.residual1[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 155 | } |
| 156 | pub fn ffn_hidden_at(&self, pos: usize) -> &[f32] { |
| 157 | &self.ffn_hidden[pos * FFN_DIM..(pos + 1) * FFN_DIM] |
| 158 | } |
| 159 | pub fn ffn_relu_at(&self, pos: usize) -> &[f32] { |
| 160 | &self.ffn_relu[pos * FFN_DIM..(pos + 1) * FFN_DIM] |
| 161 | } |
| 162 | pub fn res2_at(&self, pos: usize) -> &[f32] { |
| 163 | &self.residual2[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 164 | } |
| 165 | pub fn logits_at(&self, pos: usize) -> &[f32] { |
| 166 | &self.logits[pos * VOCAB_SIZE..(pos + 1) * VOCAB_SIZE] |
| 167 | } |
| 168 | |
| 169 | /// Attention prob for position `pos`, head `h`, attending to position `i`. |
| 170 | pub fn att_prob(&self, pos: usize, h: usize, i: usize) -> f32 { |
| 171 | self.att_probs[pos * NUM_HEADS * BLOCK_SIZE + h * BLOCK_SIZE + i] |
| 172 | } |
| 173 | } |
| 174 | |
| 175 | /// Run the forward pass for a full sequence of input tokens. |
| 176 | /// |
| 177 | /// Input tokens are the tokens to process (excluding the final target). |
| 178 | /// For a name like "emma", the encoded form is [BOS, e, m, m, a, BOS], |
| 179 | /// and the input tokens would be [BOS, e, m, m, a] (first 5 tokens). |
| 180 | /// The targets are [e, m, m, a, BOS] (last 5 tokens). |
| 181 | pub fn forward(params: &[f32], tokens: &[usize], cache: &mut ForwardCache) { |
| 182 | let seq_len = tokens.len(); |
| 183 | assert!(seq_len <= BLOCK_SIZE); |
| 184 | cache.seq_len = seq_len; |
| 185 | cache.tokens = tokens.to_vec(); |
| 186 | |
| 187 | let scale = 1.0 / (HEAD_DIM as f32).sqrt(); |
| 188 | |
| 189 | for t in 0..seq_len { |
| 190 | let token = tokens[t]; |
| 191 | |
| 192 | // 1. Token + position embeddings |
| 193 | let tok_emb = wte_row(params, token); |
| 194 | let pos_emb = wpe_row(params, t); |
| 195 | let emb = vec_add(tok_emb, pos_emb); |
| 196 | cache.embeddings[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&emb); |
| 197 | |
| 198 | // 2. Q, K, V projections |
| 199 | let q = vec_mat_mul(&emb, wq(params), EMBED_DIM, EMBED_DIM); |
| 200 | cache.queries[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&q); |
| 201 | |
| 202 | let k = vec_mat_mul(&emb, wk(params), EMBED_DIM, EMBED_DIM); |
| 203 | cache.k_cache[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&k); |
| 204 | |
| 205 | let v = vec_mat_mul(&emb, wv(params), EMBED_DIM, EMBED_DIM); |
| 206 | cache.v_cache[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&v); |
| 207 | |
| 208 | // 3. Multi-head causal self-attention |
| 209 | let mut att_out = vec![0.0f32; EMBED_DIM]; |
| 210 | for h in 0..NUM_HEADS { |
| 211 | let q_h = &q[h * HEAD_DIM..(h + 1) * HEAD_DIM]; |
| 212 | |
| 213 | // Compute attention scores for positions 0..t+1 |
| 214 | let mut scores = vec![0.0f32; t + 1]; |
| 215 | for i in 0..=t { |
| 216 | let k_i = &cache.k_cache[i * EMBED_DIM + h * HEAD_DIM..i * EMBED_DIM + (h + 1) * HEAD_DIM]; |
| 217 | let mut dot = 0.0; |
| 218 | for d in 0..HEAD_DIM { |
| 219 | dot += q_h[d] * k_i[d]; |
| 220 | } |
| 221 | scores[i] = dot * scale; |
| 222 | } |
| 223 | |
| 224 | // Softmax |
| 225 | let probs = softmax(&scores); |
| 226 | for i in 0..=t { |
| 227 | cache.att_probs[t * NUM_HEADS * BLOCK_SIZE + h * BLOCK_SIZE + i] = probs[i]; |
| 228 | } |
| 229 | |
| 230 | // Weighted sum of values |
| 231 | for i in 0..=t { |
| 232 | let v_i = &cache.v_cache[i * EMBED_DIM + h * HEAD_DIM..i * EMBED_DIM + (h + 1) * HEAD_DIM]; |
| 233 | for d in 0..HEAD_DIM { |
| 234 | att_out[h * HEAD_DIM + d] += probs[i] * v_i[d]; |
| 235 | } |
| 236 | } |
| 237 | } |
| 238 | cache.att_outs[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&att_out); |
| 239 | |
| 240 | // 4. Output projection + residual |
| 241 | let proj = vec_mat_mul(&att_out, wo(params), EMBED_DIM, EMBED_DIM); |
| 242 | cache.proj_outs[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&proj); |
| 243 | let res1 = vec_add(&emb, &proj); |
| 244 | cache.residual1[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&res1); |
| 245 | |
| 246 | // 5. FFN: ReLU(x @ W1) @ W2 |
| 247 | let hidden = vec_mat_mul(&res1, w1(params), EMBED_DIM, FFN_DIM); |
| 248 | cache.ffn_hidden[t * FFN_DIM..(t + 1) * FFN_DIM].copy_from_slice(&hidden); |
| 249 | let activated = relu(&hidden); |
| 250 | cache.ffn_relu[t * FFN_DIM..(t + 1) * FFN_DIM].copy_from_slice(&activated); |
| 251 | let ffn = vec_mat_mul(&activated, w2(params), FFN_DIM, EMBED_DIM); |
| 252 | cache.ffn_out[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&ffn); |
| 253 | |
| 254 | // 6. FFN residual |
| 255 | let res2 = vec_add(&res1, &ffn); |
| 256 | cache.residual2[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&res2); |
| 257 | |
| 258 | // 7. Output logits |
| 259 | let logits = vec_mat_mul(&res2, wout(params), EMBED_DIM, VOCAB_SIZE); |
| 260 | cache.logits[t * VOCAB_SIZE..(t + 1) * VOCAB_SIZE].copy_from_slice(&logits); |
| 261 | } |
| 262 | } |
| 263 | |
| 264 | /// Generate a name by autoregressive sampling. |
| 265 | pub fn generate(params: &[f32], cache: &mut ForwardCache, rng: &mut Rng, temperature: f32) -> String { |
| 266 | let mut tokens = vec![BOS]; |
| 267 | |
| 268 | for _ in 0..BLOCK_SIZE - 1 { |
| 269 | cache.clear(); |
| 270 | forward(params, &tokens, cache); |
| 271 | |
| 272 | let logits = cache.logits_at(tokens.len() - 1); |
| 273 | let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect(); |
| 274 | let probs = softmax(&scaled); |
| 275 | let next = rng.sample_from_probs(&probs); |
| 276 | |
| 277 | if next == BOS { |
| 278 | break; |
| 279 | } |
| 280 | tokens.push(next); |
| 281 | } |
| 282 | |
| 283 | use crate::tokenizer::decode; |
| 284 | tokens[1..].iter().map(|&t| decode(t)).collect() |
| 285 | } |
| 286 | |
| 287 | #[cfg(test)] |
| 288 | mod tests { |
| 289 | use super::*; |
| 290 | |
| 291 | #[test] |
| 292 | fn test_param_count() { |
| 293 | assert_eq!(NUM_PARAMS, 4192); |
| 294 | } |
| 295 | |
| 296 | #[test] |
| 297 | fn test_forward_smoke() { |
| 298 | let mut rng = Rng::new(42); |
| 299 | let mut params = vec![0.0f32; NUM_PARAMS]; |
| 300 | init_weights(&mut params, &mut rng); |
| 301 | |
| 302 | let tokens = vec![0, 5, 13]; // BOS, e, m |
| 303 | let mut cache = ForwardCache::new(); |
| 304 | forward(¶ms, &tokens, &mut cache); |
| 305 | |
| 306 | // Logits should be finite |
| 307 | for t in 0..3 { |
| 308 | let logits = cache.logits_at(t); |
| 309 | for &l in logits { |
| 310 | assert!(l.is_finite(), "non-finite logit"); |
| 311 | } |
| 312 | // Softmax of logits should sum to 1 |
| 313 | let probs = softmax(logits); |
| 314 | let sum: f32 = probs.iter().sum(); |
| 315 | assert!((sum - 1.0).abs() < 1e-5); |
| 316 | } |
| 317 | } |
| 318 | |
| 319 | #[test] |
| 320 | fn test_generate_smoke() { |
| 321 | let mut rng = Rng::new(42); |
| 322 | let mut params = vec![0.0f32; NUM_PARAMS]; |
| 323 | init_weights(&mut params, &mut rng); |
| 324 | |
| 325 | let mut cache = ForwardCache::new(); |
| 326 | // With random weights, some names may be empty (BOS sampled first). |
| 327 | // Try multiple times — at least one should produce a non-empty name. |
| 328 | let mut got_nonempty = false; |
| 329 | for _ in 0..20 { |
| 330 | let name = generate(¶ms, &mut cache, &mut rng, 1.0); |
| 331 | assert!(name.chars().all(|c| c >= 'a' && c <= 'z')); |
| 332 | if !name.is_empty() { |
| 333 | got_nonempty = true; |
| 334 | } |
| 335 | } |
| 336 | assert!(got_nonempty, "all 20 generated names were empty"); |
| 337 | } |
| 338 | } |
| 339 | |
| 1 | //! Manual backward pass for the GPT model. |
| 2 | //! |
| 3 | //! Computes parameter gradients by backpropagating through each position in |
| 4 | //! reverse order. The tricky part is attention: position t's query attends to |
| 5 | //! all keys/values at positions 0..=t, so key/value gradients accumulate |
| 6 | //! contributions from multiple future positions. |
| 7 | |
| 8 | use crate::model::*; |
| 9 | use crate::tensor::*; |
| 10 | use crate::tokenizer::VOCAB_SIZE; |
| 11 | |
| 12 | /// Backpropagate through the full sequence, accumulating gradients. |
| 13 | /// |
| 14 | /// Returns the average cross-entropy loss over the sequence. |
| 15 | /// `targets[t]` is the target token for position t (i.e., `tokens[t+1]` in the |
| 16 | /// encoded name). `grads` must be pre-zeroed. |
| 17 | pub fn backward( |
| 18 | params: &[f32], |
| 19 | cache: &ForwardCache, |
| 20 | targets: &[usize], |
| 21 | grads: &mut [f32], |
| 22 | ) -> f32 { |
| 23 | let seq_len = cache.seq_len; |
| 24 | assert_eq!(targets.len(), seq_len); |
| 25 | let scale = 1.0 / (HEAD_DIM as f32).sqrt(); |
| 26 | let inv_seq_len = 1.0 / seq_len as f32; |
| 27 | |
| 28 | // Accumulated key/value gradients: d_k_cache[pos] and d_v_cache[pos] collect |
| 29 | // contributions from all positions that attend to them. |
| 30 | let mut d_k_cache = vec![0.0f32; seq_len * EMBED_DIM]; |
| 31 | let mut d_v_cache = vec![0.0f32; seq_len * EMBED_DIM]; |
| 32 | |
| 33 | let mut total_loss = 0.0f32; |
| 34 | |
| 35 | // Process positions in reverse so that when we reach position t, |
| 36 | // d_k_cache[t] and d_v_cache[t] are complete. |
| 37 | for t in (0..seq_len).rev() { |
| 38 | let target = targets[t]; |
| 39 | |
| 40 | // --- Cross-entropy loss gradient --- |
| 41 | let logits = cache.logits_at(t); |
| 42 | let probs = softmax(logits); |
| 43 | total_loss += -probs[target].max(1e-10).ln(); |
| 44 | |
| 45 | // d_logits = (probs - one_hot(target)) / seq_len |
| 46 | let mut d_logits = probs; |
| 47 | d_logits[target] -= 1.0; |
| 48 | for v in d_logits.iter_mut() { |
| 49 | *v *= inv_seq_len; |
| 50 | } |
| 51 | |
| 52 | // --- Output projection backward: logits = res2 @ Wout --- |
| 53 | let res2 = cache.res2_at(t); |
| 54 | let d_res2 = vec_mat_mul_backward_x(&d_logits, wout(params), EMBED_DIM, VOCAB_SIZE); |
| 55 | vec_mat_mul_backward_w(res2, &d_logits, &mut grads[WOUT_OFFSET..WOUT_OFFSET + WOUT_SIZE], EMBED_DIM, VOCAB_SIZE); |
| 56 | |
| 57 | // --- FFN residual backward: res2 = res1 + ffn_out --- |
| 58 | // d_res1 and d_ffn_out both receive d_res2 |
| 59 | let mut d_res1 = d_res2.clone(); |
| 60 | let d_ffn_out = d_res2; |
| 61 | |
| 62 | // --- FFN down-projection backward: ffn_out = ffn_relu @ W2 --- |
| 63 | let ffn_relu = cache.ffn_relu_at(t); |
| 64 | let d_ffn_relu = vec_mat_mul_backward_x(&d_ffn_out, w2(params), FFN_DIM, EMBED_DIM); |
| 65 | vec_mat_mul_backward_w(ffn_relu, &d_ffn_out, &mut grads[W2_OFFSET..W2_OFFSET + W2_SIZE], FFN_DIM, EMBED_DIM); |
| 66 | |
| 67 | // --- ReLU backward --- |
| 68 | let ffn_hidden = cache.ffn_hidden_at(t); |
| 69 | let d_ffn_hidden: Vec<f32> = d_ffn_relu.iter().zip(ffn_hidden.iter()) |
| 70 | .map(|(&dg, &h)| if h > 0.0 { dg } else { 0.0 }) |
| 71 | .collect(); |
| 72 | |
| 73 | // --- FFN up-projection backward: ffn_hidden = res1 @ W1 --- |
| 74 | let res1 = cache.res1_at(t); |
| 75 | let d_res1_from_ffn = vec_mat_mul_backward_x(&d_ffn_hidden, w1(params), EMBED_DIM, FFN_DIM); |
| 76 | vec_mat_mul_backward_w(res1, &d_ffn_hidden, &mut grads[W1_OFFSET..W1_OFFSET + W1_SIZE], EMBED_DIM, FFN_DIM); |
| 77 | vec_add_inplace(&mut d_res1, &d_res1_from_ffn); |
| 78 | |
| 79 | // --- Attention residual backward: res1 = emb + proj --- |
| 80 | let d_proj = d_res1.clone(); |
| 81 | let mut d_emb = d_res1; // emb receives same gradient from residual |
| 82 | |
| 83 | // --- Output projection backward: proj = att_out @ Wo --- |
| 84 | let att_out = cache.att_out_at(t); |
| 85 | let d_att_out = vec_mat_mul_backward_x(&d_proj, wo(params), EMBED_DIM, EMBED_DIM); |
| 86 | vec_mat_mul_backward_w(att_out, &d_proj, &mut grads[WO_OFFSET..WO_OFFSET + WO_SIZE], EMBED_DIM, EMBED_DIM); |
| 87 | |
| 88 | // --- Multi-head attention backward --- |
| 89 | let mut d_q = vec![0.0f32; EMBED_DIM]; |
| 90 | |
| 91 | for h in 0..NUM_HEADS { |
| 92 | let ho = h * HEAD_DIM; |
| 93 | let d_out_h = &d_att_out[ho..ho + HEAD_DIM]; |
| 94 | let q_h = &cache.query_at(t)[ho..ho + HEAD_DIM]; |
| 95 | |
| 96 | // out_h = sum_i probs[i] * v_cache_h[i] |
| 97 | // First accumulate d_v_cache; d_probs is computed in the next loop. |
| 98 | for i in 0..=t { |
| 99 | let v_h = &cache.v_cache[i * EMBED_DIM + ho..i * EMBED_DIM + ho + HEAD_DIM]; |
| 100 | let prob = cache.att_prob(t, h, i); |
| 101 | // d_v_cache_h[i] += prob * d_out_h |
| 102 | for d in 0..HEAD_DIM { |
| 103 | d_v_cache[i * EMBED_DIM + ho + d] += prob * d_out_h[d]; |
| 104 | } |
| 105 | } |
| 106 | |
| 107 | // Softmax backward: d_scores = probs * (d_probs - sum(probs * d_probs)), |
| 108 | // where d_probs[i] = dot(d_out_h, v_cache_h[i]). |
| 116 | let mut d_probs = vec![0.0f32; t + 1]; |
| 117 | for i in 0..=t { |
| 118 | let v_h = &cache.v_cache[i * EMBED_DIM + ho..i * EMBED_DIM + ho + HEAD_DIM]; |
| 119 | d_probs[i] = vec_dot(d_out_h, v_h); |
| 120 | } |
| 121 | |
| 122 | let mut dot_sum = 0.0f32; |
| 123 | for i in 0..=t { |
| 124 | dot_sum += cache.att_prob(t, h, i) * d_probs[i]; |
| 125 | } |
| 126 | |
| 127 | for i in 0..=t { |
| 128 | let prob = cache.att_prob(t, h, i); |
| 129 | let d_score = prob * (d_probs[i] - dot_sum) * scale; |
| 130 | |
| 131 | // scores[i] = dot(q_h, k_h[i]) * scale |
| 132 | // d_q_h += d_score * k_h[i] (scale already applied above) |
| 133 | let k_h = &cache.k_cache[i * EMBED_DIM + ho..i * EMBED_DIM + ho + HEAD_DIM]; |
| 134 | for d in 0..HEAD_DIM { |
| 135 | d_q[ho + d] += d_score * k_h[d]; |
| 136 | d_k_cache[i * EMBED_DIM + ho + d] += d_score * q_h[d]; |
| 137 | } |
| 138 | } |
| 139 | } |
| 140 | |
| 141 | // --- Q projection backward: q = emb @ Wq --- |
| 142 | let emb = cache.emb_at(t); |
| 143 | let d_emb_from_q = vec_mat_mul_backward_x(&d_q, wq(params), EMBED_DIM, EMBED_DIM); |
| 144 | vec_mat_mul_backward_w(emb, &d_q, &mut grads[WQ_OFFSET..WQ_OFFSET + WQ_SIZE], EMBED_DIM, EMBED_DIM); |
| 145 | vec_add_inplace(&mut d_emb, &d_emb_from_q); |
| 146 | |
| 147 | // --- K, V projection backward (d_k_cache[t] and d_v_cache[t] are now complete) --- |
| 148 | let d_k_t = &d_k_cache[t * EMBED_DIM..(t + 1) * EMBED_DIM]; |
| 149 | let d_emb_from_k = vec_mat_mul_backward_x(d_k_t, wk(params), EMBED_DIM, EMBED_DIM); |
| 150 | vec_mat_mul_backward_w(emb, d_k_t, &mut grads[WK_OFFSET..WK_OFFSET + WK_SIZE], EMBED_DIM, EMBED_DIM); |
| 151 | vec_add_inplace(&mut d_emb, &d_emb_from_k); |
| 152 | |
| 153 | let d_v_t = &d_v_cache[t * EMBED_DIM..(t + 1) * EMBED_DIM]; |
| 154 | let d_emb_from_v = vec_mat_mul_backward_x(d_v_t, wv(params), EMBED_DIM, EMBED_DIM); |
| 155 | vec_mat_mul_backward_w(emb, d_v_t, &mut grads[WV_OFFSET..WV_OFFSET + WV_SIZE], EMBED_DIM, EMBED_DIM); |
| 156 | vec_add_inplace(&mut d_emb, &d_emb_from_v); |
| 157 | |
| 158 | // --- Embedding backward: emb = wte[token] + wpe[pos] --- |
| 159 | let token = cache.tokens[t]; |
| 160 | let wte_start = WTE_OFFSET + token * EMBED_DIM; |
| 161 | for d in 0..EMBED_DIM { |
| 162 | grads[wte_start + d] += d_emb[d]; |
| 163 | } |
| 164 | let wpe_start = WPE_OFFSET + t * EMBED_DIM; |
| 165 | for d in 0..EMBED_DIM { |
| 166 | grads[wpe_start + d] += d_emb[d]; |
| 167 | } |
| 168 | } |
| 169 | |
| 170 | total_loss * inv_seq_len |
| 171 | } |
| 172 | |
| 173 | #[cfg(test)] |
| 174 | mod tests { |
| 175 | use super::*; |
| 176 | use crate::rng::Rng; |
| 177 | |
| 178 | /// Numerical gradient check: compare analytic gradients from backward() |
| 179 | /// against finite-difference approximations for every parameter. |
| 180 | #[test] |
| 181 | fn test_gradient_check() { |
| 182 | let mut rng = Rng::new(42); |
| 183 | let mut params = vec![0.0f32; NUM_PARAMS]; |
| 184 | init_weights(&mut params, &mut rng); |
| 185 | |
| 186 | // Short sequence for faster checking |
| 187 | let input_tokens = vec![0, 5, 13]; // BOS, e, m |
| 188 | let targets = vec![5, 13, 0]; // e, m, BOS |
| 189 | |
| 190 | // Analytic gradients |
| 191 | let mut cache = ForwardCache::new(); |
| 192 | forward(¶ms, &input_tokens, &mut cache); |
| 193 | let mut grads = vec![0.0f32; NUM_PARAMS]; |
| 194 | let loss = backward(¶ms, &cache, &targets, &mut grads); |
| 195 | assert!(loss.is_finite()); |
| 196 | |
| 197 | // Numerical gradients for a random subset of parameters. |
| 198 | // Use a larger epsilon (0.01) for better f32 numerical resolution: |
| 199 | // with eps=1e-3, the loss delta can be ~1e-6 which is near f32 limits. |
| 200 | let eps = 0.01; |
| 201 | let mut max_rel = 0.0f32; |
| 202 | let mut checked = 0; |
| 203 | let mut failed = 0; |
| 204 | let check_indices: Vec<usize> = (0..NUM_PARAMS).step_by(17).collect(); |
| 205 | for &i in &check_indices { |
| 206 | let mut p_plus = params.clone(); |
| 207 | p_plus[i] += eps; |
| 208 | let mut c_plus = ForwardCache::new(); |
| 209 | forward(&p_plus, &input_tokens, &mut c_plus); |
| 210 | let mut g_dummy = vec![0.0f32; NUM_PARAMS]; |
| 211 | let loss_plus = backward(&p_plus, &c_plus, &targets, &mut g_dummy); |
| 212 | |
| 213 | let mut p_minus = params.clone(); |
| 214 | p_minus[i] -= eps; |
| 215 | let mut c_minus = ForwardCache::new(); |
| 216 | forward(&p_minus, &input_tokens, &mut c_minus); |
| 217 | let mut g_dummy2 = vec![0.0f32; NUM_PARAMS]; |
| 218 | let loss_minus = backward(&p_minus, &c_minus, &targets, &mut g_dummy2); |
| 219 | |
| 220 | let numerical = (loss_plus - loss_minus) / (2.0 * eps); |
| 221 | let analytic = grads[i]; |
| 222 | let abs_err = (analytic - numerical).abs(); |
| 223 | |
| 224 | // Standard combined check: pass if EITHER absolute or relative error |
| 225 | // is small. This handles near-zero gradients (absolute) and larger |
| 226 | // gradients (relative) correctly. |
| 227 | let rel_err = abs_err / (analytic.abs().max(numerical.abs()) + 1e-7); |
| 228 | if rel_err > max_rel { |
| 229 | max_rel = rel_err; |
| 230 | } |
| 231 | if abs_err > 1e-3 && rel_err > 0.1 { |
| 232 | eprintln!( |
| 233 | "FAIL param {}: analytic={:.6e}, numerical={:.6e}, rel={:.4}, abs={:.6e}", |
| 234 | i, analytic, numerical, rel_err, abs_err |
| 235 | ); |
| 236 | failed += 1; |
| 237 | } |
| 238 | checked += 1; |
| 239 | } |
| 240 | assert!(checked > 100, "checked too few params: {}", checked); |
| 241 | assert!(failed == 0, "{} of {} gradient checks failed", failed, checked); |
| 242 | eprintln!("Gradient check passed: {} params checked, max relative error = {:.4}", checked, max_rel); |
| 243 | } |
| 244 | } |
| 245 | |
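The softmax backward identity used in the attention section of backward.rs above (`d_scores = probs * (d_probs - sum(probs * d_probs))`, times the shared scale) follows directly from the softmax Jacobian:

```latex
% With p_i = e^{s_i} / \sum_k e^{s_k}, the Jacobian is
%   \partial p_i / \partial s_j = p_i (\delta_{ij} - p_j),
% so by the chain rule over all outputs i:
\frac{\partial L}{\partial s_j}
  = \sum_i \frac{\partial L}{\partial p_i}\, p_i (\delta_{ij} - p_j)
  = p_j \left( \frac{\partial L}{\partial p_j} - \sum_i p_i\, \frac{\partial L}{\partial p_i} \right)
```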
| 1 | diff --git a/CLAUDE.md b/CLAUDE.md |
| 2 | index 5322f56..e6ff70f 100644 |
| 3 | --- a/CLAUDE.md |
| 4 | +++ b/CLAUDE.md |
| 5 | @@ -24,7 +24,7 @@ The codebase compiles for both ESP32 and desktop. All ML logic (tensor, model, b |
| 6 | ### Module responsibilities |
| 7 | - `model.rs` — Parameter layout (flat `Vec<f32>`, 4192 elements), forward pass with KV cache, generation |
| 8 | - `backward.rs` — Manual backward pass. Processes positions in reverse; key/value gradients accumulate across positions |
| 9 | -- `tensor.rs` — `vec_mat_mul`, `vec_mat_mul_backward_x/w`, `softmax`, `relu`, `vec_add` |
| 10 | +- `tensor.rs` — `vec_mat_mul`, `vec_mat_mul_backward_x/w`, `softmax`, `relu`, `rmsnorm`, `vec_add` |
| 11 | - `optimizer.rs` — Adam with bias correction |
| 12 | - `tokenizer.rs` — 27-token vocab: BOS=0, a=1..z=26 |
| 13 | - `rng.rs` — Xorshift32 + Box-Muller for Gaussian init |
| 14 | @@ -44,13 +44,14 @@ WOUT [3760..4192) 16×27 output head |
| 15 | |
| 16 | ### Forward pass flow (per position t) |
| 17 | 1. `emb = wte[token] + wpe[t]` |
| 18 | -2. `q = emb @ Wq`, `k = emb @ Wk` (append to cache), `v = emb @ Wv` (append to cache) |
| 19 | -3. Multi-head attention: split into 4 heads, causal softmax, weighted value sum |
| 20 | -4. `proj = att_out @ Wo`, `res1 = emb + proj` |
| 21 | -5. `ffn = relu(res1 @ W1) @ W2`, `res2 = res1 + ffn` |
| 22 | -6. `logits = res2 @ Wout` |
| 23 | - |
| 24 | -No layer norm, no bias terms. |
| 25 | +2. `normed = rmsnorm(emb)` (RMSNorm, no learnable parameters) |
| 26 | +3. `q = normed @ Wq`, `k = normed @ Wk` (append to cache), `v = normed @ Wv` (append to cache) |
| 27 | +4. Multi-head attention: split into 4 heads, causal softmax, weighted value sum |
| 28 | +5. `proj = att_out @ Wo`, `res1 = emb + proj` (residual uses raw emb) |
| 29 | +6. `ffn = relu(res1 @ W1) @ W2`, `res2 = res1 + ffn` |
| 30 | +7. `logits = res2 @ Wout` |
| 31 | + |
| 32 | +No bias terms. |
| 33 | |
| 34 | ### Memory budget (~300KB SRAM available with WiFi disabled) |
| 35 | - Params + grads + Adam state: ~68 KB |
| 36 | diff --git a/README.md b/README.md |
| 37 | index a0c3de6..7b36f58 100644 |
| 38 | --- a/README.md |
| 39 | +++ b/README.md |
| 40 | @@ -9,7 +9,7 @@ step 0/1000: loss = 3.3071 |
| 41 | step 100/1000: loss = 2.4193 |
| 42 | step 500/1000: loss = 1.9888 |
| 43 | step 999/1000: loss = 2.0980 |
| 44 | ---- Generated names (temperature=0.8) --- |
| 45 | +--- Generated names (temperature=0.5) --- |
| 46 | arona, raeli, cealin, malie, sunaya, arishel, mosile ... |
| 47 | ``` |
| 48 | |
| 49 | @@ -25,6 +25,7 @@ A 1-layer GPT transformer matching the original Python implementation: |
| 50 | | Layers | 1 | |
| 51 | | Block size | 16 | |
| 52 | | Vocab | 27 tokens (a-z + BOS) | |
| 53 | +| Normalization | RMSNorm (before attention, no learnable params) | |
| 54 | | Optimizer | Adam (lr=0.01, beta1=0.85, beta2=0.99) | |
| 55 | | Training | 1,000 steps on 32K names | |
| 56 | |
| 57 | @@ -53,7 +54,7 @@ src/ |
| 58 | model.rs GPT forward pass, parameter layout, KV cache |
| 59 | backward.rs Manual backward pass with gradient accumulation |
| 60 | optimizer.rs Adam optimizer |
| 61 | - tensor.rs Vector-matrix math primitives |
| 62 | + tensor.rs Vector-matrix math, RMSNorm |
| 63 | tokenizer.rs Character-level encode/decode (a-z + BOS) |
| 64 | rng.rs Xorshift32 PRNG + Box-Muller for Gaussian init |
| 65 | data/ |
| 66 | diff --git a/src/backward.rs b/src/backward.rs |
| 67 | index e1b987d..ac6f09f 100644 |
| 68 | --- a/src/backward.rs |
| 69 | +++ b/src/backward.rs |
| 70 | @@ -138,22 +138,25 @@ pub fn backward( |
| 71 | } |
| 72 | } |
| 73 | |
| 74 | - // --- Q projection backward: q = emb @ Wq --- |
| 75 | - let emb = cache.emb_at(t); |
| 76 | - let d_emb_from_q = vec_mat_mul_backward_x(&d_q, wq(params), EMBED_DIM, EMBED_DIM); |
| 77 | - vec_mat_mul_backward_w(emb, &d_q, &mut grads[WQ_OFFSET..WQ_OFFSET + WQ_SIZE], EMBED_DIM, EMBED_DIM); |
| 78 | - vec_add_inplace(&mut d_emb, &d_emb_from_q); |
| 79 | + // --- Q projection backward: q = normed @ Wq --- |
| 80 | + let normed = cache.normed_emb_at(t); |
| 81 | + let mut d_normed = vec_mat_mul_backward_x(&d_q, wq(params), EMBED_DIM, EMBED_DIM); |
| 82 | + vec_mat_mul_backward_w(normed, &d_q, &mut grads[WQ_OFFSET..WQ_OFFSET + WQ_SIZE], EMBED_DIM, EMBED_DIM); |
| 83 | |
| 84 | // --- K, V projection backward (d_k_cache[t] and d_v_cache[t] are now complete) --- |
| 85 | let d_k_t = &d_k_cache[t * EMBED_DIM..(t + 1) * EMBED_DIM]; |
| 86 | - let d_emb_from_k = vec_mat_mul_backward_x(d_k_t, wk(params), EMBED_DIM, EMBED_DIM); |
| 87 | - vec_mat_mul_backward_w(emb, d_k_t, &mut grads[WK_OFFSET..WK_OFFSET + WK_SIZE], EMBED_DIM, EMBED_DIM); |
| 88 | - vec_add_inplace(&mut d_emb, &d_emb_from_k); |
| 89 | + let d_normed_from_k = vec_mat_mul_backward_x(d_k_t, wk(params), EMBED_DIM, EMBED_DIM); |
| 90 | + vec_mat_mul_backward_w(normed, d_k_t, &mut grads[WK_OFFSET..WK_OFFSET + WK_SIZE], EMBED_DIM, EMBED_DIM); |
| 91 | + vec_add_inplace(&mut d_normed, &d_normed_from_k); |
| 92 | |
| 93 | let d_v_t = &d_v_cache[t * EMBED_DIM..(t + 1) * EMBED_DIM]; |
| 94 | - let d_emb_from_v = vec_mat_mul_backward_x(d_v_t, wv(params), EMBED_DIM, EMBED_DIM); |
| 95 | - vec_mat_mul_backward_w(emb, d_v_t, &mut grads[WV_OFFSET..WV_OFFSET + WV_SIZE], EMBED_DIM, EMBED_DIM); |
| 96 | - vec_add_inplace(&mut d_emb, &d_emb_from_v); |
| 97 | + let d_normed_from_v = vec_mat_mul_backward_x(d_v_t, wv(params), EMBED_DIM, EMBED_DIM); |
| 98 | + vec_mat_mul_backward_w(normed, d_v_t, &mut grads[WV_OFFSET..WV_OFFSET + WV_SIZE], EMBED_DIM, EMBED_DIM); |
| 99 | + vec_add_inplace(&mut d_normed, &d_normed_from_v); |
| 100 | + |
| 101 | + // --- RMSNorm backward: normed = rmsnorm(emb) --- |
| 102 | + let d_emb_from_norm = rmsnorm_backward(&d_normed, normed, cache.rms_scales[t]); |
| 103 | + vec_add_inplace(&mut d_emb, &d_emb_from_norm); |
| 104 | |
| 105 | // --- Embedding backward: emb = wte[token] + wpe[pos] --- |
| 106 | let token = cache.tokens[t]; |
| 107 | diff --git a/src/main.rs b/src/main.rs |
| 108 | index 8aecbad..c8e8394 100644 |
| 109 | --- a/src/main.rs |
| 110 | +++ b/src/main.rs |
| 111 | @@ -11,7 +11,7 @@ use rng::Rng; |
| 112 | |
| 113 | const TRAINING_STEPS: usize = 1000; |
| 114 | const LEARNING_RATE: f32 = 0.01; |
| 115 | -const TEMPERATURE: f32 = 0.8; |
| 116 | +const TEMPERATURE: f32 = 0.5; |
| 117 | |
| 118 | fn main() { |
| 119 | // ESP-IDF boilerplate: link patches and initialize logging. |
| 120 | diff --git a/src/model.rs b/src/model.rs |
| 121 | index a4d70f6..db75b37 100644 |
| 122 | --- a/src/model.rs |
| 123 | +++ b/src/model.rs |
| 124 | @@ -5,7 +5,7 @@ |
| 125 | /// activations needed by the backward pass. |
| 126 | |
| 127 | use crate::rng::Rng; |
| 128 | -use crate::tensor::{softmax, vec_add, vec_mat_mul, relu}; |
| 129 | +use crate::tensor::{softmax, vec_add, vec_mat_mul, relu, rmsnorm}; |
| 130 | use crate::tokenizer::{VOCAB_SIZE, BOS}; |
| 131 | |
| 132 | // Architecture constants (matching Karpathy's microgpt). |
| 133 | @@ -86,8 +86,10 @@ pub struct ForwardCache { |
| 134 | pub tokens: Vec<usize>, |
| 135 | |
| 136 | // Per-position activations, flat: [pos * dim + i] |
| 137 | - pub embeddings: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 138 | - pub queries: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 139 | + pub embeddings: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] (raw, pre-norm) |
| 140 | + pub normed_emb: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] (after RMSNorm) |
| 141 | + pub rms_scales: Vec<f32>, // [BLOCK_SIZE] (RMSNorm scale per position) |
| 142 | + pub queries: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 143 | pub k_cache: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 144 | pub v_cache: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 145 | pub att_outs: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 146 | @@ -109,6 +111,8 @@ impl ForwardCache { |
| 147 | seq_len: 0, |
| 148 | tokens: Vec::new(), |
| 149 | embeddings: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 150 | + normed_emb: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 151 | + rms_scales: vec![0.0; BLOCK_SIZE], |
| 152 | queries: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 153 | k_cache: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 154 | v_cache: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 155 | @@ -133,9 +137,13 @@ impl ForwardCache { |
| 156 | |
| 157 | // Accessors for a specific position's data. |
| 158 | |
| 159 | + #[allow(dead_code)] |
| 160 | pub fn emb_at(&self, pos: usize) -> &[f32] { |
| 161 | &self.embeddings[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 162 | } |
| 163 | + pub fn normed_emb_at(&self, pos: usize) -> &[f32] { |
| 164 | + &self.normed_emb[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 165 | + } |
| 166 | pub fn query_at(&self, pos: usize) -> &[f32] { |
| 167 | &self.queries[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 168 | } |
| 169 | @@ -195,14 +203,19 @@ pub fn forward(params: &[f32], tokens: &[usize], cache: &mut ForwardCache) { |
| 170 | let emb = vec_add(tok_emb, pos_emb); |
| 171 | cache.embeddings[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&emb); |
| 172 | |
| 173 | - // 2. Q, K, V projections |
| 174 | - let q = vec_mat_mul(&emb, wq(params), EMBED_DIM, EMBED_DIM); |
| 175 | + // 1b. RMSNorm |
| 176 | + let (normed, rms_scale) = rmsnorm(&emb); |
| 177 | + cache.normed_emb[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&normed); |
| 178 | + cache.rms_scales[t] = rms_scale; |
| 179 | + |
| 180 | + // 2. Q, K, V projections (from normed embedding) |
| 181 | + let q = vec_mat_mul(&normed, wq(params), EMBED_DIM, EMBED_DIM); |
| 182 | cache.queries[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&q); |
| 183 | |
| 184 | - let k = vec_mat_mul(&emb, wk(params), EMBED_DIM, EMBED_DIM); |
| 185 | + let k = vec_mat_mul(&normed, wk(params), EMBED_DIM, EMBED_DIM); |
| 186 | cache.k_cache[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&k); |
| 187 | |
| 188 | - let v = vec_mat_mul(&emb, wv(params), EMBED_DIM, EMBED_DIM); |
| 189 | + let v = vec_mat_mul(&normed, wv(params), EMBED_DIM, EMBED_DIM); |
| 190 | cache.v_cache[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&v); |
| 191 | |
| 192 | // 3. Multi-head causal self-attention |
| 193 | diff --git a/src/tensor.rs b/src/tensor.rs |
| 194 | index e56664e..9690611 100644 |
| 195 | --- a/src/tensor.rs |
| 196 | +++ b/src/tensor.rs |
| 197 | @@ -81,6 +81,26 @@ pub fn relu(v: &[f32]) -> Vec<f32> { |
| 198 | v.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect() |
| 199 | } |
| 200 | |
| 201 | +/// RMSNorm: rescale so values have unit root-mean-square. |
| 202 | +/// Returns (normalized_vector, scale_factor) where scale = 1/sqrt(mean(x^2) + eps). |
| 203 | +pub fn rmsnorm(x: &[f32]) -> (Vec<f32>, f32) { |
| 204 | + let n = x.len() as f32; |
| 205 | + let ms: f32 = x.iter().map(|&v| v * v).sum::<f32>() / n; |
| 206 | + let scale = 1.0 / (ms + 1e-5_f32).sqrt(); |
| 207 | + let y: Vec<f32> = x.iter().map(|&v| v * scale).collect(); |
| 208 | + (y, scale) |
| 209 | +} |
| 210 | + |
| 211 | +/// Backward of rmsnorm: d_x given d_y, the normed output y, and the scale factor. |
| 212 | +pub fn rmsnorm_backward(d_y: &[f32], y: &[f32], scale: f32) -> Vec<f32> { |
| 213 | + let n = y.len() as f32; |
| 214 | + let dot: f32 = d_y.iter().zip(y.iter()).map(|(&dy, &yi)| dy * yi).sum(); |
| 215 | + let mean_dot = dot / n; |
| 216 | + d_y.iter().zip(y.iter()) |
| 217 | + .map(|(&dy, &yi)| scale * (dy - yi * mean_dot)) |
| 218 | + .collect() |
| 219 | +} |
| 220 | + |
| 221 | /// Scale each element by s. |
| 222 | #[allow(dead_code)] |
| 223 | pub fn vec_scale(v: &[f32], s: f32) -> Vec<f32> { |
| 1 | /// GPT model: 1-layer transformer with multi-head self-attention. |
| 2 | /// |
| 3 | /// All parameters live in a single flat f32 buffer. The forward pass processes |
| 4 | /// tokens sequentially, building a KV cache, and stores all intermediate |
| 5 | /// activations needed by the backward pass. |
| 6 | |
| 7 | use crate::rng::Rng; |
| 8 | use crate::tensor::{softmax, vec_add, vec_mat_mul, relu}; |
| 9 | use crate::tokenizer::{VOCAB_SIZE, BOS}; |
| 10 | |
| 11 | // Architecture constants (matching Karpathy's microgpt). |
| 12 | pub const EMBED_DIM: usize = 16; |
| 13 | pub const NUM_HEADS: usize = 4; |
| 14 | pub const HEAD_DIM: usize = EMBED_DIM / NUM_HEADS; // 4 |
| 15 | pub const FFN_DIM: usize = EMBED_DIM * 4; // 64 |
| 16 | pub const BLOCK_SIZE: usize = 16; |
| 17 | |
| 18 | // Parameter layout in the flat buffer. |
| 19 | pub const WTE_OFFSET: usize = 0; |
| 20 | pub const WTE_SIZE: usize = VOCAB_SIZE * EMBED_DIM; // 27 × 16 = 432 |
| 21 | |
| 22 | pub const WPE_OFFSET: usize = WTE_OFFSET + WTE_SIZE; // 432 |
| 23 | pub const WPE_SIZE: usize = BLOCK_SIZE * EMBED_DIM; // 16 × 16 = 256 |
| 24 | |
| 25 | pub const WQ_OFFSET: usize = WPE_OFFSET + WPE_SIZE; // 688 |
| 26 | pub const WQ_SIZE: usize = EMBED_DIM * EMBED_DIM; // 256 |
| 27 | |
| 28 | pub const WK_OFFSET: usize = WQ_OFFSET + WQ_SIZE; // 944 |
| 29 | pub const WK_SIZE: usize = EMBED_DIM * EMBED_DIM; // 256 |
| 30 | |
| 31 | pub const WV_OFFSET: usize = WK_OFFSET + WK_SIZE; // 1200 |
| 32 | pub const WV_SIZE: usize = EMBED_DIM * EMBED_DIM; // 256 |
| 33 | |
| 34 | pub const WO_OFFSET: usize = WV_OFFSET + WV_SIZE; // 1456 |
| 35 | pub const WO_SIZE: usize = EMBED_DIM * EMBED_DIM; // 256 |
| 36 | |
| 37 | pub const W1_OFFSET: usize = WO_OFFSET + WO_SIZE; // 1712 |
| 38 | pub const W1_SIZE: usize = EMBED_DIM * FFN_DIM; // 1024 |
| 39 | |
| 40 | pub const W2_OFFSET: usize = W1_OFFSET + W1_SIZE; // 2736 |
| 41 | pub const W2_SIZE: usize = FFN_DIM * EMBED_DIM; // 1024 |
| 42 | |
| 43 | pub const WOUT_OFFSET: usize = W2_OFFSET + W2_SIZE; // 3760 |
| 44 | pub const WOUT_SIZE: usize = EMBED_DIM * VOCAB_SIZE; // 432 |
| 45 | |
| 46 | pub const NUM_PARAMS: usize = WOUT_OFFSET + WOUT_SIZE; // 4192 |
| 47 | |
| 48 | // Slice accessors for parameter groups. |
| 49 | #[allow(dead_code)] |
| 50 | pub fn wte(p: &[f32]) -> &[f32] { &p[WTE_OFFSET..WTE_OFFSET + WTE_SIZE] } |
| 51 | #[allow(dead_code)] |
| 52 | pub fn wpe(p: &[f32]) -> &[f32] { &p[WPE_OFFSET..WPE_OFFSET + WPE_SIZE] } |
| 53 | pub fn wq(p: &[f32]) -> &[f32] { &p[WQ_OFFSET..WQ_OFFSET + WQ_SIZE] } |
| 54 | pub fn wk(p: &[f32]) -> &[f32] { &p[WK_OFFSET..WK_OFFSET + WK_SIZE] } |
| 55 | pub fn wv(p: &[f32]) -> &[f32] { &p[WV_OFFSET..WV_OFFSET + WV_SIZE] } |
| 56 | pub fn wo(p: &[f32]) -> &[f32] { &p[WO_OFFSET..WO_OFFSET + WO_SIZE] } |
| 57 | pub fn w1(p: &[f32]) -> &[f32] { &p[W1_OFFSET..W1_OFFSET + W1_SIZE] } |
| 58 | pub fn w2(p: &[f32]) -> &[f32] { &p[W2_OFFSET..W2_OFFSET + W2_SIZE] } |
| 59 | pub fn wout(p: &[f32]) -> &[f32] { &p[WOUT_OFFSET..WOUT_OFFSET + WOUT_SIZE] } |
| 60 | |
| 61 | /// Embedding row for a given token. |
| 62 | pub fn wte_row(p: &[f32], token: usize) -> &[f32] { |
| 63 | let start = WTE_OFFSET + token * EMBED_DIM; |
| 64 | &p[start..start + EMBED_DIM] |
| 65 | } |
| 66 | |
| 67 | /// Position embedding row. |
| 68 | pub fn wpe_row(p: &[f32], pos: usize) -> &[f32] { |
| 69 | let start = WPE_OFFSET + pos * EMBED_DIM; |
| 70 | &p[start..start + EMBED_DIM] |
| 71 | } |
| 72 | |
| 73 | /// Initialize all weights with small Gaussian noise. |
| 74 | pub fn init_weights(params: &mut [f32], rng: &mut Rng) { |
| 75 | for p in params.iter_mut() { |
| 76 | *p = rng.next_gaussian() * 0.1; |
| 77 | } |
| 78 | } |
| 79 | |
| 80 | /// Cached activations from the forward pass, needed by backward. |
| 81 | /// |
| 82 | /// All 2D data is stored flat with row-major indexing: `[position * dim + i]`. |
| 83 | /// Attention probs are stored as `[position * NUM_HEADS * BLOCK_SIZE + head * BLOCK_SIZE + i]`. |
| 84 | pub struct ForwardCache { |
| 85 | pub seq_len: usize, |
| 86 | pub tokens: Vec<usize>, |
| 87 | |
| 88 | // Per-position activations, flat: [pos * dim + i] |
| 89 | pub embeddings: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 90 | pub queries: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 91 | pub k_cache: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 92 | pub v_cache: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 93 | pub att_outs: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 94 | pub proj_outs: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 95 | pub residual1: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 96 | pub ffn_hidden: Vec<f32>, // [BLOCK_SIZE × FFN_DIM] |
| 97 | pub ffn_relu: Vec<f32>, // [BLOCK_SIZE × FFN_DIM] |
| 98 | pub ffn_out: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 99 | pub residual2: Vec<f32>, // [BLOCK_SIZE × EMBED_DIM] |
| 100 | pub logits: Vec<f32>, // [BLOCK_SIZE × VOCAB_SIZE] |
| 101 | |
| 102 | // Attention probs: [pos * NUM_HEADS * BLOCK_SIZE + head * BLOCK_SIZE + i] |
| 103 | pub att_probs: Vec<f32>, // [BLOCK_SIZE × NUM_HEADS × BLOCK_SIZE] |
| 104 | } |
| 105 | |
| 106 | impl ForwardCache { |
| 107 | pub fn new() -> Self { |
| 108 | Self { |
| 109 | seq_len: 0, |
| 110 | tokens: Vec::new(), |
| 111 | embeddings: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 112 | queries: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 113 | k_cache: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 114 | v_cache: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 115 | att_outs: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 116 | proj_outs: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 117 | residual1: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 118 | ffn_hidden: vec![0.0; BLOCK_SIZE * FFN_DIM], |
| 119 | ffn_relu: vec![0.0; BLOCK_SIZE * FFN_DIM], |
| 120 | ffn_out: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 121 | residual2: vec![0.0; BLOCK_SIZE * EMBED_DIM], |
| 122 | logits: vec![0.0; BLOCK_SIZE * VOCAB_SIZE], |
| 123 | att_probs: vec![0.0; BLOCK_SIZE * NUM_HEADS * BLOCK_SIZE], |
| 124 | } |
| 125 | } |
| 126 | |
| 127 | /// Reset for a new sequence without reallocating. |
| 128 | pub fn clear(&mut self) { |
| 129 | self.seq_len = 0; |
| 130 | self.tokens.clear(); |
| 131 |         // Zeroing the buffers is unnecessary: the forward pass overwrites every slot it later reads.
| 132 | } |
| 133 | |
| 134 | // Accessors for a specific position's data. |
| 135 | |
| 136 | pub fn emb_at(&self, pos: usize) -> &[f32] { |
| 137 | &self.embeddings[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 138 | } |
| 139 | pub fn query_at(&self, pos: usize) -> &[f32] { |
| 140 | &self.queries[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 141 | } |
| 142 | #[allow(dead_code)] |
| 143 | pub fn k_at(&self, pos: usize) -> &[f32] { |
| 144 | &self.k_cache[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 145 | } |
| 146 | #[allow(dead_code)] |
| 147 | pub fn v_at(&self, pos: usize) -> &[f32] { |
| 148 | &self.v_cache[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 149 | } |
| 150 | pub fn att_out_at(&self, pos: usize) -> &[f32] { |
| 151 | &self.att_outs[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 152 | } |
| 153 | pub fn res1_at(&self, pos: usize) -> &[f32] { |
| 154 | &self.residual1[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 155 | } |
| 156 | pub fn ffn_hidden_at(&self, pos: usize) -> &[f32] { |
| 157 | &self.ffn_hidden[pos * FFN_DIM..(pos + 1) * FFN_DIM] |
| 158 | } |
| 159 | pub fn ffn_relu_at(&self, pos: usize) -> &[f32] { |
| 160 | &self.ffn_relu[pos * FFN_DIM..(pos + 1) * FFN_DIM] |
| 161 | } |
| 162 | pub fn res2_at(&self, pos: usize) -> &[f32] { |
| 163 | &self.residual2[pos * EMBED_DIM..(pos + 1) * EMBED_DIM] |
| 164 | } |
| 165 | pub fn logits_at(&self, pos: usize) -> &[f32] { |
| 166 | &self.logits[pos * VOCAB_SIZE..(pos + 1) * VOCAB_SIZE] |
| 167 | } |
| 168 | |
| 169 | /// Attention prob for position `pos`, head `h`, attending to position `i`. |
| 170 | pub fn att_prob(&self, pos: usize, h: usize, i: usize) -> f32 { |
| 171 | self.att_probs[pos * NUM_HEADS * BLOCK_SIZE + h * BLOCK_SIZE + i] |
| 172 | } |
| 173 | } |
| 174 | |
| 175 | /// Run the forward pass for a full sequence of input tokens. |
| 176 | /// |
| 177 | /// Input tokens are the tokens to process (excluding the final target). |
| 178 | /// For a name like "emma", the encoded form is [BOS, e, m, m, a, BOS], |
| 179 | /// and the input tokens would be [BOS, e, m, m, a] (first 5 tokens). |
| 180 | /// The targets are [e, m, m, a, BOS] (last 5 tokens). |
| 181 | pub fn forward(params: &[f32], tokens: &[usize], cache: &mut ForwardCache) { |
| 182 | let seq_len = tokens.len(); |
| 183 | assert!(seq_len <= BLOCK_SIZE); |
| 184 | cache.seq_len = seq_len; |
| 185 | cache.tokens = tokens.to_vec(); |
| 186 | |
| 187 | let scale = 1.0 / (HEAD_DIM as f32).sqrt(); |
| 188 | |
| 189 | for t in 0..seq_len { |
| 190 | let token = tokens[t]; |
| 191 | |
| 192 | // 1. Token + position embeddings |
| 193 | let tok_emb = wte_row(params, token); |
| 194 | let pos_emb = wpe_row(params, t); |
| 195 | let emb = vec_add(tok_emb, pos_emb); |
| 196 | cache.embeddings[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&emb); |
| 197 | |
| 198 | // 2. Q, K, V projections |
| 199 | let q = vec_mat_mul(&emb, wq(params), EMBED_DIM, EMBED_DIM); |
| 200 | cache.queries[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&q); |
| 201 | |
| 202 | let k = vec_mat_mul(&emb, wk(params), EMBED_DIM, EMBED_DIM); |
| 203 | cache.k_cache[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&k); |
| 204 | |
| 205 | let v = vec_mat_mul(&emb, wv(params), EMBED_DIM, EMBED_DIM); |
| 206 | cache.v_cache[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&v); |
| 207 | |
| 208 | // 3. Multi-head causal self-attention |
| 209 | let mut att_out = vec![0.0f32; EMBED_DIM]; |
| 210 | for h in 0..NUM_HEADS { |
| 211 | let q_h = &q[h * HEAD_DIM..(h + 1) * HEAD_DIM]; |
| 212 | |
| 213 | // Compute attention scores for positions 0..t+1 |
| 214 | let mut scores = vec![0.0f32; t + 1]; |
| 215 | for i in 0..=t { |
| 216 | let k_i = &cache.k_cache[i * EMBED_DIM + h * HEAD_DIM..i * EMBED_DIM + (h + 1) * HEAD_DIM]; |
| 217 | let mut dot = 0.0; |
| 218 | for d in 0..HEAD_DIM { |
| 219 | dot += q_h[d] * k_i[d]; |
| 220 | } |
| 221 | scores[i] = dot * scale; |
| 222 | } |
| 223 | |
| 224 | // Softmax |
| 225 | let probs = softmax(&scores); |
| 226 | for i in 0..=t { |
| 227 | cache.att_probs[t * NUM_HEADS * BLOCK_SIZE + h * BLOCK_SIZE + i] = probs[i]; |
| 228 | } |
| 229 | |
| 230 | // Weighted sum of values |
| 231 | for i in 0..=t { |
| 232 | let v_i = &cache.v_cache[i * EMBED_DIM + h * HEAD_DIM..i * EMBED_DIM + (h + 1) * HEAD_DIM]; |
| 233 | for d in 0..HEAD_DIM { |
| 234 | att_out[h * HEAD_DIM + d] += probs[i] * v_i[d]; |
| 235 | } |
| 236 | } |
| 237 | } |
| 238 | cache.att_outs[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&att_out); |
| 239 | |
| 240 | // 4. Output projection + residual |
| 241 | let proj = vec_mat_mul(&att_out, wo(params), EMBED_DIM, EMBED_DIM); |
| 242 | cache.proj_outs[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&proj); |
| 243 | let res1 = vec_add(&emb, &proj); |
| 244 | cache.residual1[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&res1); |
| 245 | |
| 246 | // 5. FFN: ReLU(x @ W1) @ W2 |
| 247 | let hidden = vec_mat_mul(&res1, w1(params), EMBED_DIM, FFN_DIM); |
| 248 | cache.ffn_hidden[t * FFN_DIM..(t + 1) * FFN_DIM].copy_from_slice(&hidden); |
| 249 | let activated = relu(&hidden); |
| 250 | cache.ffn_relu[t * FFN_DIM..(t + 1) * FFN_DIM].copy_from_slice(&activated); |
| 251 | let ffn = vec_mat_mul(&activated, w2(params), FFN_DIM, EMBED_DIM); |
| 252 | cache.ffn_out[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&ffn); |
| 253 | |
| 254 | // 6. FFN residual |
| 255 | let res2 = vec_add(&res1, &ffn); |
| 256 | cache.residual2[t * EMBED_DIM..(t + 1) * EMBED_DIM].copy_from_slice(&res2); |
| 257 | |
| 258 | // 7. Output logits |
| 259 | let logits = vec_mat_mul(&res2, wout(params), EMBED_DIM, VOCAB_SIZE); |
| 260 | cache.logits[t * VOCAB_SIZE..(t + 1) * VOCAB_SIZE].copy_from_slice(&logits); |
| 261 | } |
| 262 | } |
| 263 | |
| 264 | /// Generate a name by autoregressive sampling. |
| 265 | pub fn generate(params: &[f32], cache: &mut ForwardCache, rng: &mut Rng, temperature: f32) -> String { |
| 266 | let mut tokens = vec![BOS]; |
| 267 | |
| 268 | for _ in 0..BLOCK_SIZE - 1 { |
| 269 | cache.clear(); |
| 270 | forward(params, &tokens, cache); |
| 271 | |
| 272 | let logits = cache.logits_at(tokens.len() - 1); |
| 273 | let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect(); |
| 274 | let probs = softmax(&scaled); |
| 275 | let next = rng.sample_from_probs(&probs); |
| 276 | |
| 277 | if next == BOS { |
| 278 | break; |
| 279 | } |
| 280 | tokens.push(next); |
| 281 | } |
| 282 | |
| 283 | use crate::tokenizer::decode; |
| 284 | tokens[1..].iter().map(|&t| decode(t)).collect() |
| 285 | } |
| 286 | |
| 287 | #[cfg(test)] |
| 288 | mod tests { |
| 289 | use super::*; |
| 290 | |
| 291 | #[test] |
| 292 | fn test_param_count() { |
| 293 | assert_eq!(NUM_PARAMS, 4192); |
| 294 | } |
| 295 | |
| 296 | #[test] |
| 297 | fn test_forward_smoke() { |
| 298 | let mut rng = Rng::new(42); |
| 299 | let mut params = vec![0.0f32; NUM_PARAMS]; |
| 300 | init_weights(&mut params, &mut rng); |
| 301 | |
| 302 | let tokens = vec![0, 5, 13]; // BOS, e, m |
| 303 | let mut cache = ForwardCache::new(); |
| 304 |         forward(&params, &tokens, &mut cache);
| 305 | |
| 306 | // Logits should be finite |
| 307 | for t in 0..3 { |
| 308 | let logits = cache.logits_at(t); |
| 309 | for &l in logits { |
| 310 | assert!(l.is_finite(), "non-finite logit"); |
| 311 | } |
| 312 | // Softmax of logits should sum to 1 |
| 313 | let probs = softmax(logits); |
| 314 | let sum: f32 = probs.iter().sum(); |
| 315 | assert!((sum - 1.0).abs() < 1e-5); |
| 316 | } |
| 317 | } |
| 318 | |
| 319 | #[test] |
| 320 | fn test_generate_smoke() { |
| 321 | let mut rng = Rng::new(42); |
| 322 | let mut params = vec![0.0f32; NUM_PARAMS]; |
| 323 | init_weights(&mut params, &mut rng); |
| 324 | |
| 325 | let mut cache = ForwardCache::new(); |
| 326 | // With random weights, some names may be empty (BOS sampled first). |
| 327 | // Try multiple times — at least one should produce a non-empty name. |
| 328 | let mut got_nonempty = false; |
| 329 | for _ in 0..20 { |
| 330 |             let name = generate(&params, &mut cache, &mut rng, 1.0);
| 331 |             assert!(name.chars().all(|c| c.is_ascii_lowercase()));
| 332 | if !name.is_empty() { |
| 333 | got_nonempty = true; |
| 334 | } |
| 335 | } |
| 336 | assert!(got_nonempty, "all 20 generated names were empty"); |
| 337 | } |
| 338 | } |
| 339 | |
| 1 | /// Manual backward pass for the GPT model. |
| 2 | /// |
| 3 | /// Computes parameter gradients by backpropagating through each position in |
| 4 | /// reverse order. The tricky part is attention: position t's query attends to |
| 5 | /// all keys/values at positions 0..t, so key/value gradients accumulate |
| 6 | /// contributions from multiple future positions. |
| 7 | |
| 8 | use crate::model::*; |
| 9 | use crate::tensor::*; |
| 10 | use crate::tokenizer::VOCAB_SIZE; |
| 11 | |
| 12 | /// Backpropagate through the full sequence, accumulating gradients. |
| 13 | /// |
| 14 | /// Returns the average cross-entropy loss over the sequence. |
| 15 | /// `targets[t]` is the target token for position t (i.e., `tokens[t+1]` in the |
| 16 | /// encoded name). `grads` must be pre-zeroed. |
| 17 | pub fn backward( |
| 18 | params: &[f32], |
| 19 | cache: &ForwardCache, |
| 20 | targets: &[usize], |
| 21 | grads: &mut [f32], |
| 22 | ) -> f32 { |
| 23 | let seq_len = cache.seq_len; |
| 24 | assert_eq!(targets.len(), seq_len); |
| 25 | let scale = 1.0 / (HEAD_DIM as f32).sqrt(); |
| 26 | let inv_seq_len = 1.0 / seq_len as f32; |
| 27 | |
| 28 | // Accumulated key/value gradients: d_k_cache[pos] and d_v_cache[pos] collect |
| 29 | // contributions from all positions that attend to them. |
| 30 | let mut d_k_cache = vec![0.0f32; seq_len * EMBED_DIM]; |
| 31 | let mut d_v_cache = vec![0.0f32; seq_len * EMBED_DIM]; |
| 32 | |
| 33 | let mut total_loss = 0.0f32; |
| 34 | |
| 35 | // Process positions in reverse so that when we reach position t, |
| 36 | // d_k_cache[t] and d_v_cache[t] are complete. |
| 37 | for t in (0..seq_len).rev() { |
| 38 | let target = targets[t]; |
| 39 | |
| 40 | // --- Cross-entropy loss gradient --- |
| 41 | let logits = cache.logits_at(t); |
| 42 | let probs = softmax(logits); |
| 43 | total_loss += -probs[target].max(1e-10).ln(); |
| 44 | |
| 45 | // d_logits = (probs - one_hot(target)) / seq_len |
| 46 | let mut d_logits = probs; |
| 47 | d_logits[target] -= 1.0; |
| 48 | for v in d_logits.iter_mut() { |
| 49 | *v *= inv_seq_len; |
| 50 | } |
| 51 | |
| 52 | // --- Output projection backward: logits = res2 @ Wout --- |
| 53 | let res2 = cache.res2_at(t); |
| 54 | let d_res2 = vec_mat_mul_backward_x(&d_logits, wout(params), EMBED_DIM, VOCAB_SIZE); |
| 55 | vec_mat_mul_backward_w(res2, &d_logits, &mut grads[WOUT_OFFSET..WOUT_OFFSET + WOUT_SIZE], EMBED_DIM, VOCAB_SIZE); |
| 56 | |
| 57 | // --- FFN residual backward: res2 = res1 + ffn_out --- |
| 58 | // d_res1 and d_ffn_out both receive d_res2 |
| 59 | let mut d_res1 = d_res2.clone(); |
| 60 | let d_ffn_out = d_res2; |
| 61 | |
| 62 | // --- FFN down-projection backward: ffn_out = ffn_relu @ W2 --- |
| 63 | let ffn_relu = cache.ffn_relu_at(t); |
| 64 | let d_ffn_relu = vec_mat_mul_backward_x(&d_ffn_out, w2(params), FFN_DIM, EMBED_DIM); |
| 65 | vec_mat_mul_backward_w(ffn_relu, &d_ffn_out, &mut grads[W2_OFFSET..W2_OFFSET + W2_SIZE], FFN_DIM, EMBED_DIM); |
| 66 | |
| 67 | // --- ReLU backward --- |
| 68 | let ffn_hidden = cache.ffn_hidden_at(t); |
| 69 | let d_ffn_hidden: Vec<f32> = d_ffn_relu.iter().zip(ffn_hidden.iter()) |
| 70 | .map(|(&dg, &h)| if h > 0.0 { dg } else { 0.0 }) |
| 71 | .collect(); |
| 72 | |
| 73 | // --- FFN up-projection backward: ffn_hidden = res1 @ W1 --- |
| 74 | let res1 = cache.res1_at(t); |
| 75 | let d_res1_from_ffn = vec_mat_mul_backward_x(&d_ffn_hidden, w1(params), EMBED_DIM, FFN_DIM); |
| 76 | vec_mat_mul_backward_w(res1, &d_ffn_hidden, &mut grads[W1_OFFSET..W1_OFFSET + W1_SIZE], EMBED_DIM, FFN_DIM); |
| 77 | vec_add_inplace(&mut d_res1, &d_res1_from_ffn); |
| 78 | |
| 79 | // --- Attention residual backward: res1 = emb + proj --- |
| 80 | let d_proj = d_res1.clone(); |
| 81 | let mut d_emb = d_res1; // emb receives same gradient from residual |
| 82 | |
| 83 | // --- Output projection backward: proj = att_out @ Wo --- |
| 84 | let att_out = cache.att_out_at(t); |
| 85 | let d_att_out = vec_mat_mul_backward_x(&d_proj, wo(params), EMBED_DIM, EMBED_DIM); |
| 86 | vec_mat_mul_backward_w(att_out, &d_proj, &mut grads[WO_OFFSET..WO_OFFSET + WO_SIZE], EMBED_DIM, EMBED_DIM); |
| 87 | |
| 88 | // --- Multi-head attention backward --- |
| 89 | let mut d_q = vec![0.0f32; EMBED_DIM]; |
| 90 | |
| 91 | for h in 0..NUM_HEADS { |
| 92 | let ho = h * HEAD_DIM; |
| 93 | let d_out_h = &d_att_out[ho..ho + HEAD_DIM]; |
| 94 | let q_h = &cache.query_at(t)[ho..ho + HEAD_DIM]; |
| 95 | |
| 96 | // out_h = sum_i probs[i] * v_cache_h[i] |
| 97 | for i in 0..=t { |
| 98 | let v_h = &cache.v_cache[i * EMBED_DIM + ho..i * EMBED_DIM + ho + HEAD_DIM]; |
| 99 | let prob = cache.att_prob(t, h, i); |
| 100 | |
| 101 | // d_probs[i] = dot(d_out_h, v_h) |
| 102 | let d_prob_i = vec_dot(d_out_h, v_h); |
| 103 | |
| 104 | // d_v_cache_h[i] += prob * d_out_h |
| 105 | for d in 0..HEAD_DIM { |
| 106 | d_v_cache[i * EMBED_DIM + ho + d] += prob * d_out_h[d]; |
| 107 | } |
| 108 | |
| 109 |                 // Softmax backward needs the full d_probs vector at once, so
| 110 |                 // d_prob_i is recomputed in the dedicated loop below; discard this copy.
| 111 |                 let _ = d_prob_i;
| 112 | } |
| 113 | |
| 114 | // Softmax backward: d_scores = probs * (d_probs - sum(probs * d_probs)) |
| 115 |             // Recompute d_probs into a small scratch vector (t+1 floats).
| 116 | let mut d_probs = vec![0.0f32; t + 1]; |
| 117 | for i in 0..=t { |
| 118 | let v_h = &cache.v_cache[i * EMBED_DIM + ho..i * EMBED_DIM + ho + HEAD_DIM]; |
| 119 | d_probs[i] = vec_dot(d_out_h, v_h); |
| 120 | } |
| 121 | |
| 122 | let mut dot_sum = 0.0f32; |
| 123 | for i in 0..=t { |
| 124 | dot_sum += cache.att_prob(t, h, i) * d_probs[i]; |
| 125 | } |
| 126 | |
| 127 | for i in 0..=t { |
| 128 | let prob = cache.att_prob(t, h, i); |
| 129 | let d_score = prob * (d_probs[i] - dot_sum) * scale; |
| 130 | |
| 131 | // scores[i] = dot(q_h, k_h[i]) * scale |
| 132 | // d_q_h += d_score * k_h[i] (scale already applied above) |
| 133 | let k_h = &cache.k_cache[i * EMBED_DIM + ho..i * EMBED_DIM + ho + HEAD_DIM]; |
| 134 | for d in 0..HEAD_DIM { |
| 135 | d_q[ho + d] += d_score * k_h[d]; |
| 136 | d_k_cache[i * EMBED_DIM + ho + d] += d_score * q_h[d]; |
| 137 | } |
| 138 | } |
| 139 | } |
| 140 | |
| 141 | // --- Q projection backward: q = emb @ Wq --- |
| 142 | let emb = cache.emb_at(t); |
| 143 | let d_emb_from_q = vec_mat_mul_backward_x(&d_q, wq(params), EMBED_DIM, EMBED_DIM); |
| 144 | vec_mat_mul_backward_w(emb, &d_q, &mut grads[WQ_OFFSET..WQ_OFFSET + WQ_SIZE], EMBED_DIM, EMBED_DIM); |
| 145 | vec_add_inplace(&mut d_emb, &d_emb_from_q); |
| 146 | |
| 147 | // --- K, V projection backward (d_k_cache[t] and d_v_cache[t] are now complete) --- |
| 148 | let d_k_t = &d_k_cache[t * EMBED_DIM..(t + 1) * EMBED_DIM]; |
| 149 | let d_emb_from_k = vec_mat_mul_backward_x(d_k_t, wk(params), EMBED_DIM, EMBED_DIM); |
| 150 | vec_mat_mul_backward_w(emb, d_k_t, &mut grads[WK_OFFSET..WK_OFFSET + WK_SIZE], EMBED_DIM, EMBED_DIM); |
| 151 | vec_add_inplace(&mut d_emb, &d_emb_from_k); |
| 152 | |
| 153 | let d_v_t = &d_v_cache[t * EMBED_DIM..(t + 1) * EMBED_DIM]; |
| 154 | let d_emb_from_v = vec_mat_mul_backward_x(d_v_t, wv(params), EMBED_DIM, EMBED_DIM); |
| 155 | vec_mat_mul_backward_w(emb, d_v_t, &mut grads[WV_OFFSET..WV_OFFSET + WV_SIZE], EMBED_DIM, EMBED_DIM); |
| 156 | vec_add_inplace(&mut d_emb, &d_emb_from_v); |
| 157 | |
| 158 | // --- Embedding backward: emb = wte[token] + wpe[pos] --- |
| 159 | let token = cache.tokens[t]; |
| 160 | let wte_start = WTE_OFFSET + token * EMBED_DIM; |
| 161 | for d in 0..EMBED_DIM { |
| 162 | grads[wte_start + d] += d_emb[d]; |
| 163 | } |
| 164 | let wpe_start = WPE_OFFSET + t * EMBED_DIM; |
| 165 | for d in 0..EMBED_DIM { |
| 166 | grads[wpe_start + d] += d_emb[d]; |
| 167 | } |
| 168 | } |
| 169 | |
| 170 | total_loss * inv_seq_len |
| 171 | } |
| 172 | |
| 173 | #[cfg(test)] |
| 174 | mod tests { |
| 175 | use super::*; |
| 176 | use crate::rng::Rng; |
| 177 | |
| 178 | /// Numerical gradient check: compare analytic gradients from backward() |
| 179 | /// against finite-difference approximations for every parameter. |
| 180 | #[test] |
| 181 | fn test_gradient_check() { |
| 182 | let mut rng = Rng::new(42); |
| 183 | let mut params = vec![0.0f32; NUM_PARAMS]; |
| 184 | init_weights(&mut params, &mut rng); |
| 185 | |
| 186 | // Short sequence for faster checking |
| 187 | let input_tokens = vec![0, 5, 13]; // BOS, e, m |
| 188 | let targets = vec![5, 13, 0]; // e, m, BOS |
| 189 | |
| 190 | // Analytic gradients |
| 191 | let mut cache = ForwardCache::new(); |
| 192 |         forward(&params, &input_tokens, &mut cache);
| 193 | let mut grads = vec![0.0f32; NUM_PARAMS]; |
| 194 |         let loss = backward(&params, &cache, &targets, &mut grads);
| 195 | assert!(loss.is_finite()); |
| 196 | |
| 197 | // Numerical gradients for a random subset of parameters. |
| 198 | // Use a larger epsilon (0.01) for better f32 numerical resolution: |
| 199 | // with eps=1e-3, the loss delta can be ~1e-6 which is near f32 limits. |
| 200 | let eps = 0.01; |
| 201 | let mut max_rel = 0.0f32; |
| 202 | let mut checked = 0; |
| 203 | let mut failed = 0; |
| 204 | let check_indices: Vec<usize> = (0..NUM_PARAMS).step_by(17).collect(); |
| 205 | for &i in &check_indices { |
| 206 | let mut p_plus = params.clone(); |
| 207 | p_plus[i] += eps; |
| 208 | let mut c_plus = ForwardCache::new(); |
| 209 | forward(&p_plus, &input_tokens, &mut c_plus); |
| 210 | let mut g_dummy = vec![0.0f32; NUM_PARAMS]; |
| 211 | let loss_plus = backward(&p_plus, &c_plus, &targets, &mut g_dummy); |
| 212 | |
| 213 | let mut p_minus = params.clone(); |
| 214 | p_minus[i] -= eps; |
| 215 | let mut c_minus = ForwardCache::new(); |
| 216 | forward(&p_minus, &input_tokens, &mut c_minus); |
| 217 | let mut g_dummy2 = vec![0.0f32; NUM_PARAMS]; |
| 218 | let loss_minus = backward(&p_minus, &c_minus, &targets, &mut g_dummy2); |
| 219 | |
| 220 | let numerical = (loss_plus - loss_minus) / (2.0 * eps); |
| 221 | let analytic = grads[i]; |
| 222 | let abs_err = (analytic - numerical).abs(); |
| 223 | |
| 224 | // Standard combined check: pass if EITHER absolute or relative error |
| 225 | // is small. This handles near-zero gradients (absolute) and larger |
| 226 | // gradients (relative) correctly. |
| 227 | let rel_err = abs_err / (analytic.abs().max(numerical.abs()) + 1e-7); |
| 228 | if rel_err > max_rel { |
| 229 | max_rel = rel_err; |
| 230 | } |
| 231 | if abs_err > 1e-3 && rel_err > 0.1 { |
| 232 | eprintln!( |
| 233 | "FAIL param {}: analytic={:.6e}, numerical={:.6e}, rel={:.4}, abs={:.6e}", |
| 234 | i, analytic, numerical, rel_err, abs_err |
| 235 | ); |
| 236 | failed += 1; |
| 237 | } |
| 238 | checked += 1; |
| 239 | } |
| 240 | assert!(checked > 100, "checked too few params: {}", checked); |
| 241 | assert!(failed == 0, "{} of {} gradient checks failed", failed, checked); |
| 242 | eprintln!("Gradient check passed: {} params checked, max relative error = {:.4}", checked, max_rel); |
| 243 | } |
| 244 | } |
| 245 | |
| 1 | # Review: esp32gpt vs Karpathy's microgpt |
| 2 | |
| 3 | ## Verdict |
| 4 | |
| 5 | The implementation is a high-quality reproduction of microgpt with **one significant omission: RMSNorm**. Everything else — architecture, forward pass, manual backward pass, optimizer, tokenizer, training loop — is correct and faithful to the reference. The manual backward pass in particular is impressive: it correctly handles multi-head causal attention gradient accumulation and passes numerical gradient verification. |
| 6 | |
| 7 | --- |
| 8 | |
| 9 | ## Missing: RMSNorm (the one real gap) |
| 10 | |
| 11 | The blog post applies RMSNorm after embedding addition, before the attention projections: |
| 12 | |
| 13 | ``` |
| 14 | emb = wte[token] + wpe[pos] |
| 15 | normed = rmsnorm(emb) ← missing in esp32gpt |
| 16 | q = normed @ Wq |
| 17 | ``` |
| 18 | |
| 19 | The reference defines: |
| 20 | ```python |
| 21 | def rmsnorm(x): |
| 22 | ms = sum(xi * xi for xi in x) / len(x) |
| 23 | scale = (ms + 1e-5) ** -0.5 |
| 24 | return [xi * scale for xi in x] |
| 25 | ``` |
| 26 | |
| 27 | esp32gpt skips this step entirely — the raw embedding sum goes directly into Q/K/V projections. RMSNorm has no learnable parameters, so the 4,192 param count is unaffected, but it changes the forward computation and therefore the gradient landscape. |
| 28 | |
| 29 | **Impact**: For this tiny model, training probably still converges (the weights can partially compensate), but: |
| 30 | - The trained model is solving a slightly different optimization problem than microgpt |
| 31 | - Loss curves won't match the blog's reported ~3.3 → ~2.37 trajectory |
| 32 | - Generation quality may differ |
| 33 | |
| 34 | **To fix**: Add `rmsnorm` to the forward pass (after `vec_add(tok_emb, pos_emb)`, before Q/K/V projections), add its backward (`d_x = scale * (d_y - y * mean(d_y * y))`), and cache the pre-norm embedding + scale factor. |
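Those two formulas translate almost line-for-line into Rust. Below is a minimal standalone sketch (function names and test values are hypothetical, not the project's actual code), including a central-difference check of the backward formula in the same spirit as the repository's gradient test:

```rust
/// RMSNorm forward, matching the reference: y = x / sqrt(mean(x^2) + 1e-5).
/// Returns the normed vector and the scale factor (cached for backward).
fn rmsnorm(x: &[f32]) -> (Vec<f32>, f32) {
    let ms = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let scale = 1.0 / (ms + 1e-5).sqrt();
    (x.iter().map(|v| v * scale).collect(), scale)
}

/// RMSNorm backward: d_x = scale * (d_y - y * mean(d_y * y)),
/// where `y` is the normed output and `scale` the cached factor.
fn rmsnorm_backward(d_y: &[f32], y: &[f32], scale: f32) -> Vec<f32> {
    let n = y.len() as f32;
    let m = d_y.iter().zip(y).map(|(g, yi)| g * yi).sum::<f32>() / n;
    d_y.iter().zip(y).map(|(g, yi)| scale * (g - yi * m)).collect()
}

fn main() {
    let x = [1.0f32, -2.0, 3.0, 0.5];
    let (y, scale) = rmsnorm(&x);
    // After normalization, mean(y^2) is ~1 (up to the 1e-5 epsilon).
    let ms_y = y.iter().map(|v| v * v).sum::<f32>() / y.len() as f32;
    assert!((ms_y - 1.0).abs() < 1e-3);

    // Central-difference check of the backward formula on x[0].
    let d_y = [0.3f32, -0.1, 0.2, 0.05];
    let d_x = rmsnorm_backward(&d_y, &y, scale);
    let eps = 1e-2f32;
    let f = |v: &[f32; 4]| {
        let (yy, _) = rmsnorm(v);
        yy.iter().zip(&d_y).map(|(a, b)| a * b).sum::<f32>()
    };
    let mut xp = x; xp[0] += eps;
    let mut xm = x; xm[0] -= eps;
    let num = (f(&xp) - f(&xm)) / (2.0 * eps);
    assert!((d_x[0] - num).abs() < 1e-3);
}
```

Because RMSNorm has no learnable parameters, only the two functions and the cached scale need to be threaded through the existing forward/backward code.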
| 35 | |
| 36 | --- |
| 37 | |
| 38 | ## What's correct (everything else) |
| 39 | |
| 40 | ### Architecture — matches exactly |
| 41 | | Parameter | Blog | esp32gpt | Status | |
| 42 | |-----------|------|----------|--------|-| |
| 43 | | Embed dim | 16 | 16 | match | |
| 44 | | Heads | 4 | 4 | match | |
| 45 | | Head dim | 4 | 4 | match | |
| 46 | | FFN dim | 64 | 64 | match | |
| 47 | | Block size | 16 | 16 | match | |
| 48 | | Vocab size | 27 | 27 | match | |
| 49 | | Layers | 1 | 1 | match | |
| 50 | | Total params | 4,192 | 4,192 | match | |
| 51 | | Activation | ReLU | ReLU | match | |
| 52 | | Normalization | RMSNorm | none | **MISSING** | |
| 53 | |
| 54 | ### Forward pass — correct |
| 55 | The forward pass at each position computes: |
| 56 | 1. `emb = wte[token] + wpe[pos]` — verified in [model.rs:193-196](src/model.rs#L193-L196) |
| 57 | 2. `q = emb @ Wq`, `k = emb @ Wk`, `v = emb @ Wv` — verified in [model.rs:199-206](src/model.rs#L199-L206) |
| 58 | 3. Multi-head causal attention with scaled dot-product (`/ sqrt(head_dim)`) — verified in [model.rs:210-237](src/model.rs#L210-L237) |
| 59 | 4. `proj = att_out @ Wo`, `res1 = emb + proj` — verified in [model.rs:241-244](src/model.rs#L241-L244) |
| 60 | 5. `ffn = relu(res1 @ W1) @ W2`, `res2 = res1 + ffn` — verified in [model.rs:247-256](src/model.rs#L247-L256) |
| 61 | 6. `logits = res2 @ Wout` — verified in [model.rs:259](src/model.rs#L259) |
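Step 5 can be sketched on its own. The following uses tiny hypothetical dimensions and weights (the real model is 16 to 64 to 16; this is illustrative, not the project's code):

```rust
/// FFN sub-block from step 5: res2 = res1 + relu(res1 @ W1) @ W2.
/// `w1` is flat [n x hidden], `w2` is flat [hidden x n], row-major.
fn ffn(res1: &[f32], w1: &[f32], w2: &[f32], hidden: usize) -> Vec<f32> {
    let n = res1.len();
    // Up-projection followed by ReLU.
    let mut h = vec![0.0f32; hidden];
    for i in 0..n {
        for j in 0..hidden {
            h[j] += res1[i] * w1[i * hidden + j];
        }
    }
    for v in h.iter_mut() {
        *v = v.max(0.0); // ReLU
    }
    // Down-projection, added onto the residual stream.
    let mut res2 = res1.to_vec();
    for j in 0..hidden {
        for k in 0..n {
            res2[k] += h[j] * w2[j * n + k];
        }
    }
    res2
}

fn main() {
    // 2-dim embedding, 3-dim hidden, hypothetical weights.
    let res1 = [1.0f32, -0.5];
    let w1 = [0.5f32, -1.0, 0.2, 0.3, 0.4, -0.6]; // [2 x 3]
    let w2 = [0.1f32, 0.2, -0.3, 0.4, 0.5, -0.6]; // [3 x 2]
    let out = ffn(&res1, &w1, &w2, 3);
    assert_eq!(out.len(), res1.len());
    // With all-zero weights the FFN contributes nothing: pure residual.
    let id = ffn(&res1, &[0.0; 6], &[0.0; 6], 3);
    assert_eq!(id, res1.to_vec());
}
```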
| 62 | |
| 63 | Weight matrix convention: blog stores `W[output×input]` and computes `y[j] = sum_i W[j][i] * x[i]`; esp32gpt stores `W[input×output]` and computes `y[j] = sum_i x[i] * W[i*m+j]`. Same linear algebra, different storage — **correct**. |
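The equivalence of the two layouts is easy to demonstrate with a toy example (standalone sketch; helper names and weight values are hypothetical, not tensor.rs):

```rust
/// Blog convention: W stored [out][in], y[j] = sum_i w[j][i] * x[i].
fn matvec_out_in(w: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum())
        .collect()
}

/// esp32gpt convention: W stored flat [in * out], y[j] = sum_i x[i] * w[i*m + j].
fn matvec_in_out(w: &[f32], x: &[f32], m: usize) -> Vec<f32> {
    let mut y = vec![0.0f32; m];
    for (i, &xi) in x.iter().enumerate() {
        for j in 0..m {
            y[j] += xi * w[i * m + j];
        }
    }
    y
}

fn main() {
    // The same 2x3 weight matrix in both layouts (hypothetical values):
    let x = [1.0f32, 2.0, 3.0];
    let w_rows = vec![vec![0.1f32, 0.2, 0.3], vec![0.4, 0.5, 0.6]]; // [out=2][in=3]
    let w_flat = [0.1f32, 0.4, 0.2, 0.5, 0.3, 0.6]; // [in=3 * out=2], transposed
    let a = matvec_out_in(&w_rows, &x);
    let b = matvec_in_out(&w_flat, &x, 2);
    assert!((a[0] - b[0]).abs() < 1e-6 && (a[1] - b[1]).abs() < 1e-6);
}
```

The flat `[input x output]` layout has the practical advantage that the inner loop walks memory contiguously when accumulating into `y`, which suits a cache-less microcontroller.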
| 64 | |
| 65 | ### Manual backward pass — correct (the hard part) |
| 66 | |
| 67 | Traced through every gradient computation in [backward.rs](src/backward.rs): |
| 68 | |
| 69 | - **Cross-entropy gradient**: `d_logits = (softmax(logits) - one_hot(target)) / seq_len` — standard, correct |
| 70 | - **Output head**: `d_res2 = d_logits @ Wout^T`, `d_Wout += res2 ⊗ d_logits` — correct |
| 71 | - **FFN backward**: ReLU gate, up/down projection gradients — correct |
| 72 | - **Residual connections**: gradient flows through addition — correct |
| 73 | - **Attention output projection**: `d_att_out = d_proj @ Wo^T` — correct |
| 74 | - **Multi-head attention backward** (the trickiest part): |
| 75 | - Value gradient: `d_v[i] += prob[i] * d_out_h` — correct |
| 76 | - Softmax Jacobian: `d_score = prob * (d_prob - dot_sum)` — correct |
| 77 | - Scale factor propagation: `* scale` applied after softmax backward — correct |
| 78 | - Query/key gradients from dot product: `d_q += d_score * k`, `d_k += d_score * q` — correct |
| 79 | - **KV cache gradient accumulation**: Reverse-order processing ensures `d_k_cache[t]` and `d_v_cache[t]` are complete (have received contributions from all positions `>= t`) before being used for the K/V projection backward at position `t` — correct |
| 80 | - **Embedding backward**: `d_wte[token] += d_emb`, `d_wpe[pos] += d_emb` — correct |
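The softmax-Jacobian identity from the list above can be checked in isolation. A standalone sketch (hypothetical helper names, not the project's code), with a central-difference verification:

```rust
/// Numerically stable softmax forward.
fn softmax(x: &[f32]) -> Vec<f32> {
    let m = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let e: Vec<f32> = x.iter().map(|v| (v - m).exp()).collect();
    let s: f32 = e.iter().sum();
    e.into_iter().map(|v| v / s).collect()
}

/// Softmax backward via the Jacobian identity:
/// d_score[i] = prob[i] * (d_prob[i] - sum_j d_prob[j] * prob[j]).
fn softmax_backward(prob: &[f32], d_prob: &[f32]) -> Vec<f32> {
    let dot: f32 = prob.iter().zip(d_prob).map(|(p, g)| p * g).sum();
    prob.iter().zip(d_prob).map(|(p, g)| p * (g - dot)).collect()
}

fn main() {
    let scores = [0.5f32, -1.0, 2.0];
    let d_prob = [1.0f32, 0.0, -0.5];
    let prob = softmax(&scores);
    let d_score = softmax_backward(&prob, &d_prob);
    // Softmax is shift-invariant, so its input gradient must sum to zero.
    assert!(d_score.iter().sum::<f32>().abs() < 1e-5);
    // Central-difference check on scores[0].
    let eps = 1e-2f32;
    let f = |s: &[f32]| softmax(s).iter().zip(&d_prob).map(|(p, g)| p * g).sum::<f32>();
    let mut sp = scores; sp[0] += eps;
    let mut sm = scores; sm[0] -= eps;
    let num = (f(&sp) - f(&sm)) / (2.0 * eps);
    assert!((d_score[0] - num).abs() < 1e-3);
}
```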
| 81 | |
| 82 | The numerical gradient check in [backward.rs:181-243](src/backward.rs#L181-L243) tests every 17th parameter (247 params) with `eps=0.01` and confirms `abs_err < 1e-3 OR rel_err < 0.1`. This is the strongest evidence that the backward pass is correct. |
| 83 | |
| 84 | ### Optimizer — correct |
| 85 | Adam with bias correction in [optimizer.rs](src/optimizer.rs). Hyperparameters match: `β₁=0.85, β₂=0.99, ε=1e-8`. Linear LR decay `lr = 0.01 * (1 - step/1000)` in [main.rs:83](src/main.rs#L83). |
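A minimal standalone sketch of that update rule, using the hyperparameters quoted above (hypothetical code, not the project's optimizer.rs):

```rust
/// One Adam step with bias correction, using the hyperparameters the
/// review quotes (beta1=0.85, beta2=0.99, eps=1e-8). `t` is 1-based.
fn adam_step(p: &mut [f32], g: &[f32], m: &mut [f32], v: &mut [f32], lr: f32, t: i32) {
    let (b1, b2, eps) = (0.85f32, 0.99f32, 1e-8f32);
    for i in 0..p.len() {
        m[i] = b1 * m[i] + (1.0 - b1) * g[i];          // first moment
        v[i] = b2 * v[i] + (1.0 - b2) * g[i] * g[i];   // second moment
        let m_hat = m[i] / (1.0 - b1.powi(t));         // bias correction
        let v_hat = v[i] / (1.0 - b2.powi(t));
        p[i] -= lr * m_hat / (v_hat.sqrt() + eps);
    }
}

fn main() {
    let mut p = [1.0f32, -1.0];
    let (mut m, mut v) = ([0.0f32; 2], [0.0f32; 2]);
    for step in 0..5 {
        let lr = 0.01 * (1.0 - step as f32 / 1000.0); // linear LR decay
        let g = [2.0 * p[0], 2.0 * p[1]]; // gradient of p[0]^2 + p[1]^2
        adam_step(&mut p, &g, &mut m, &mut v, lr, step + 1);
    }
    // Both parameters should have moved toward the minimum at 0.
    assert!(p[0].abs() < 1.0 && p[1].abs() < 1.0);
}
```

Note that with bias correction, the very first update has magnitude close to `lr` regardless of gradient scale, which keeps early training stable.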
| 86 | |
| 87 | ### Training loop — correct |
| 88 | One name per step, encode as `[BOS, chars..., BOS]`, input = all-but-last, targets = all-but-first. Matches the blog's training procedure. |
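A standalone sketch of that example construction (the mapping BOS = 0, 'a'..='z' to 1..=26 is consistent with the token IDs used in the gradient-check test; `encode` here is a hypothetical stand-in for the real tokenizer):

```rust
const BOS: u8 = 0; // esp32gpt uses token 0 for BOS

/// Encode a lowercase name as [BOS, chars..., BOS].
fn encode(name: &str) -> Vec<u8> {
    let mut toks = vec![BOS];
    toks.extend(name.bytes().map(|b| b - b'a' + 1));
    toks.push(BOS);
    toks
}

fn main() {
    let toks = encode("emma");
    // input = all-but-last, targets = all-but-first (shifted by one).
    let input = &toks[..toks.len() - 1];
    let targets = &toks[1..];
    assert_eq!(toks, vec![0, 5, 13, 13, 1, 0]);
    assert_eq!(input, &[0, 5, 13, 13, 1]);
    assert_eq!(targets, &[5, 13, 13, 1, 0]);
}
```

The trailing BOS in the targets is what teaches the model to terminate a name during generation.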
| 89 | |
| 90 | --- |
| 91 | |
| 92 | ## Minor differences (not bugs) |
| 93 | |
| 94 | | Item | Blog | esp32gpt | Impact | |
| 95 | |------|------|----------|--------| |
| 96 | | BOS token ID | 26 | 0 | None — functionally equivalent, embeddings learn the same thing | |
| 97 | | Temperature | 0.5 | 0.8 | Generation only — higher T = more diverse names | |
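The effect of that temperature difference is easy to see in a standalone sketch (hypothetical code, not the project's sampler):

```rust
/// Temperature-scaled softmax: divide logits by T before normalizing.
/// Higher T flattens the distribution, giving more diverse samples.
fn softmax_with_temp(logits: &[f32], t: f32) -> Vec<f32> {
    let m = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let e: Vec<f32> = logits.iter().map(|v| ((v - m) / t).exp()).collect();
    let s: f32 = e.iter().sum();
    e.into_iter().map(|v| v / s).collect()
}

fn main() {
    let logits = [2.0f32, 1.0, 0.0];
    let sharp = softmax_with_temp(&logits, 0.5); // blog's setting
    let soft = softmax_with_temp(&logits, 0.8);  // esp32gpt's setting
    // At the higher temperature, the top token gets less probability mass.
    assert!(soft[0] < sharp[0]);
    // Both remain valid probability distributions.
    assert!((sharp.iter().sum::<f32>() - 1.0).abs() < 1e-5);
    assert!((soft.iter().sum::<f32>() - 1.0).abs() < 1e-5);
}
```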
| 98 | |
| 99 | --- |
| 100 | |
| 101 | ## Code quality observations |
| 102 | |
| 103 | **Strengths:** |
| 104 | - Clean separation of concerns across modules |
| 105 | - Pre-allocated `ForwardCache` avoids heap fragmentation on ESP32 |
| 106 | - `include_str!` keeps dataset in flash, not SRAM |
| 107 | - Dual-target design (ESP32 + host) with minimal `#[cfg]` blocks |
| 108 | - Comprehensive test suite (14 tests) including numerical gradient verification |
| 109 | - Memory budget is well within ESP32's ~300KB SRAM constraint (~100KB used) |
| 110 | |
| 111 | **Nits** (not blocking): |
| 112 | - Name sampling is O(n) per step — `dataset.lines().filter().nth(idx)` re-scans up to 32K lines each time. Pre-splitting into a `Vec<&str>` would make it O(1) but costs ~250KB of pointers |
| 113 | - `ForwardCache::clear()` doesn't zero the activation arrays; it relies on the forward pass overwriting them. Correct as written, but this could mask bugs if code ever reads stale positions |
| 114 | - In the attention backward, `d_probs` is computed and discarded (`let _ = d_prob_i`) in one loop, then recomputed in a second loop to avoid an allocation; functional, but slightly confusing to read |
| 115 | |
| 116 | --- |
| 117 | |
| 118 | ## Recommendation |
| 119 | |
| 120 | Add RMSNorm to make this a faithful reproduction. The changes are localized: |
| 121 | |
| 122 | 1. **Forward** ([model.rs](src/model.rs)): After `emb = vec_add(tok_emb, pos_emb)`, apply `rmsnorm(emb)`. Cache the scale factor. |
| 123 | 2. **Backward** ([backward.rs](src/backward.rs)): Before the embedding backward, apply `d_x = scale * (d_y - y * mean(d_y * y))` where `y` is the normed embedding. |
| 124 | 3. **Cache**: Add `normed_embeddings` and `rms_scales` fields to `ForwardCache`. |
| 125 | 4. **Tests**: The numerical gradient check will automatically verify the new backward. |
| 126 | |
| 127 | No parameter changes, no architecture changes, no new dependencies. Roughly +20 lines of code. |