A small GPT-style language model training prototype in Rust, using a hand-written matrix backend (column-major storage, no external tensor framework).
This project is based on:
- Andrej Karpathy's `microgpt` concept and reference implementation.

Compared with the reference implementation, this version:
- Uses an explicit `Matrix` module (`src/algebra.rs`) instead of a tape/autodiff node graph.
- Keeps computations in matrix form for forward and backward passes (`src/model.rs`).
- Uses column-major matrix storage for performance.
- Preserves reference training/inference behavior and sampling outputs as closely as possible while using matrix code paths.

What the program does:
- Loads a text dataset from `input.txt` (line-based documents).
- Builds a character-level vocabulary from the dataset.
- Trains a tiny decoder-only transformer on next-token prediction.
- Prints periodic training loss.
- Runs autoregressive sampling after training to generate text.

How it works:
- Data preparation (`src/main.rs`)
  - Reads non-empty lines from `input.txt`.
  - Shuffles documents with a Python-compatible MT19937 RNG.
  - Builds the character set and adds a special BOS token id.
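
The data preparation steps can be sketched as follows. This is a minimal illustration, not the code in `src/main.rs`: the names `Vocab`, `read_docs`, and `encode` are hypothetical, the MT19937 shuffle is omitted, and placing the BOS id after all character ids is an assumption.

```rust
use std::collections::BTreeSet;

/// Character-level vocabulary: sorted unique characters plus one BOS id.
/// (Hypothetical type for illustration.)
struct Vocab {
    chars: Vec<char>, // token id -> character, in sorted order
    bos: usize,       // special token id, placed after all characters (assumption)
}

impl Vocab {
    /// Collects every distinct character that appears in the documents.
    fn build(docs: &[String]) -> Self {
        let set: BTreeSet<char> = docs.iter().flat_map(|d| d.chars()).collect();
        let chars: Vec<char> = set.into_iter().collect();
        let bos = chars.len();
        Vocab { chars, bos }
    }

    /// Vocabulary size including the BOS token.
    fn size(&self) -> usize {
        self.chars.len() + 1
    }

    /// Maps a document to token ids; characters outside the vocabulary are skipped.
    fn encode(&self, doc: &str) -> Vec<usize> {
        doc.chars()
            .filter_map(|c| self.chars.binary_search(&c).ok())
            .collect()
    }
}

/// Splits raw text into trimmed, non-empty lines (one document per line).
fn read_docs(text: &str) -> Vec<String> {
    text.lines()
        .map(str::trim)
        .filter(|l| !l.is_empty())
        .map(String::from)
        .collect()
}
```

Because the character list is sorted, `encode` can use a binary search instead of a hash map, which is enough for a small character set.
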
- Model definition (`src/model.rs`)
  - Token embedding (`wte`) + positional embedding (`wpe`).
  - One attention block:
    - RMSNorm
    - Multi-head causal self-attention (`wq`, `wk`, `wv`, `wo`)
    - Residual connection
  - One MLP block:
    - RMSNorm
    - `fc1 -> ReLU -> fc2`
    - Residual connection
  - Final projection to vocabulary logits (`lm_head`).
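
To make the causal part of the attention block concrete, here is a sketch of scaled dot-product attention for a single head. It assumes the queries, keys, and values have already been projected, and it uses plain `Vec<Vec<f32>>` (one vector per position) rather than the project's `Matrix` type; the function name `causal_attention` is hypothetical.

```rust
/// Scaled dot-product attention for one head with a causal mask.
/// `q`, `k`, `v` hold one vector per position; position `t` attends to 0..=t only.
fn causal_attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let head_dim = q[0].len();
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut out = Vec::with_capacity(q.len());
    for t in 0..q.len() {
        // Scores against positions 0..=t; later positions are never computed,
        // which is equivalent to masking them out.
        let scores: Vec<f32> = (0..=t)
            .map(|s| q[t].iter().zip(&k[s]).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();
        // Numerically stable softmax over the visible prefix.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        // Weighted sum of the visible value vectors.
        let mut o = vec![0.0f32; v[0].len()];
        for (s, e) in exps.iter().enumerate() {
            let w = *e / total;
            for (oj, vj) in o.iter_mut().zip(&v[s]) {
                *oj += w * *vj;
            }
        }
        out.push(o);
    }
    out
}
```

In a multi-head setup, this routine would run once per head on that head's slice of the projected vectors, and the outputs would be concatenated before the output projection.
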
- Math primitives (`src/algebra.rs`)
  - Column-major `Matrix` type.
  - Core ops: matmul, transpose, ReLU, softmax-by-column, scaling.
  - RMSNorm forward/backward utilities.
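
For readers unfamiliar with column-major layout, the sketch below shows the indexing rule and a matmul whose inner loop walks contiguous columns. It is an illustration under assumptions, not the contents of `src/algebra.rs`: the field names and the methods `zeros`, `from_cols`, `at`, and `matmul` are chosen here for the example.

```rust
/// Dense matrix stored column by column:
/// element (r, c) lives at data[c * rows + r].
#[derive(Debug, Clone, PartialEq)]
struct Matrix {
    rows: usize,
    cols: usize,
    data: Vec<f32>,
}

impl Matrix {
    /// All-zero matrix of the given shape.
    fn zeros(rows: usize, cols: usize) -> Self {
        Matrix { rows, cols, data: vec![0.0; rows * cols] }
    }

    /// Wraps data that is already laid out column by column.
    fn from_cols(rows: usize, cols: usize, data: Vec<f32>) -> Self {
        assert_eq!(data.len(), rows * cols, "data length must equal rows * cols");
        Matrix { rows, cols, data }
    }

    /// Reads element (r, c).
    fn at(&self, r: usize, c: usize) -> f32 {
        self.data[c * self.rows + r]
    }

    /// Computes self (m x k) times rhs (k x n).
    /// The inner loop reads one column of `self` and updates one column of
    /// the output, both of which are contiguous in memory.
    fn matmul(&self, rhs: &Matrix) -> Matrix {
        assert_eq!(self.cols, rhs.rows, "inner dimensions must match");
        let m = self.rows;
        let mut out = Matrix::zeros(m, rhs.cols);
        for j in 0..rhs.cols {
            for p in 0..self.cols {
                let b = rhs.at(p, j);
                let a_col = &self.data[p * m..(p + 1) * m];
                let o_col = &mut out.data[j * m..(j + 1) * m];
                for (o, a) in o_col.iter_mut().zip(a_col) {
                    *o += *a * b;
                }
            }
        }
        out
    }
}
```

With this loop order, each output column is accumulated as a linear combination of the left operand's columns, so no strided access is needed.
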
- Training loop (`src/main.rs`)
  - For each step, creates a token sequence:
    - `[BOS] + encoded_document + [BOS]`
    - truncated to the model block size before the final BOS.
  - Forward pass -> logits.
  - Cross-entropy loss and logits gradient.
  - Backward pass to compute gradients for all parameters.
  - Parameter update with an Adam-like optimizer (`step_all`).
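
Two of the training steps above have compact closed forms, sketched here. The cross-entropy gradient with respect to the logits is `softmax(logits) - one_hot(target)`. The optimizer shown is textbook Adam with bias correction; how closely `step_all` matches it is not stated in this README, so treat the `Adam` struct, its fields, and its hyperparameters as assumptions.

```rust
/// Cross-entropy for one position: returns the loss and d(loss)/d(logits).
fn cross_entropy(logits: &[f32], target: usize) -> (f32, Vec<f32>) {
    // Numerically stable softmax.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|l| (*l - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    let mut grad: Vec<f32> = exps.iter().map(|e| *e / total).collect();
    let loss = -grad[target].ln();
    // Gradient is softmax minus the one-hot target.
    grad[target] -= 1.0;
    (loss, grad)
}

/// Textbook Adam state for a flat parameter vector (hypothetical type).
struct Adam {
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    t: i32,      // number of steps taken so far
    m: Vec<f32>, // first-moment estimate
    v: Vec<f32>, // second-moment estimate
}

impl Adam {
    fn new(n: usize, lr: f32) -> Self {
        Adam { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, t: 0, m: vec![0.0; n], v: vec![0.0; n] }
    }

    /// Applies one bias-corrected Adam update in place.
    fn step(&mut self, params: &mut [f32], grads: &[f32]) {
        self.t += 1;
        let c1 = 1.0 - self.beta1.powi(self.t);
        let c2 = 1.0 - self.beta2.powi(self.t);
        for i in 0..params.len() {
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grads[i];
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grads[i] * grads[i];
            let m_hat = self.m[i] / c1;
            let v_hat = self.v[i] / c2;
            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }
}
```

On the first step the bias correction makes the update roughly `lr * sign(grad)`, which is a quick sanity check when debugging an optimizer.
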
- Inference (`src/main.rs`)
  - Starts from BOS.
  - Repeats:
    - forward pass on current prefix
    - temperature scaling
    - softmax + weighted sampling
    - stop on BOS or max length.
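
The sampling loop above can be sketched as follows. To keep the example self-contained, the model and the random source are passed in as closures: `next_logits` stands in for the forward pass and `uniform` for a generator of numbers in `[0, 1)`. All function and parameter names here are hypothetical and do not come from `src/main.rs`.

```rust
/// Softmax of `logits / temperature`; `temperature` must be positive.
fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Vec<f32> {
    let scaled: Vec<f32> = logits.iter().map(|l| *l / temperature).collect();
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|s| (*s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| *e / total).collect()
}

/// Weighted sampling by inverse CDF: returns the first index whose
/// cumulative probability exceeds `u`, where `u` is in [0, 1).
fn sample_index(probs: &[f32], u: f32) -> usize {
    let mut cum = 0.0f32;
    for (i, p) in probs.iter().enumerate() {
        cum += *p;
        if u < cum {
            return i;
        }
    }
    probs.len() - 1 // guards against rounding when `u` is close to 1
}

/// Autoregressive generation: start from BOS, sample until BOS or `max_len` tokens.
fn generate<F, R>(
    mut next_logits: F,
    mut uniform: R,
    bos: usize,
    max_len: usize,
    temperature: f32,
) -> Vec<usize>
where
    F: FnMut(&[usize]) -> Vec<f32>,
    R: FnMut() -> f32,
{
    let mut tokens = vec![bos];
    while tokens.len() <= max_len {
        let logits = next_logits(&tokens);
        let probs = softmax_with_temperature(&logits, temperature);
        let next = sample_index(&probs, uniform());
        if next == bos {
            break; // the model signalled end of document
        }
        tokens.push(next);
    }
    tokens[1..].to_vec() // drop the leading BOS
}
```

Temperatures below 1 sharpen the distribution toward the most likely token, while temperatures above 1 flatten it.
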

To build and run: `cargo run --release`