This repository trains Crosslayer Transcoders and variants with PyTorch/Lightning on multi‑GPU via tensor parallelism. It implements Anthropic’s crosslayer transcoders and related architectures (per‑layer transcoders, MOLTs, SAEs, Matryoshka CLTs) and supports ReLU, JumpReLU, TopK, and BatchTopK nonlinearities, for learning human‑interpretable features from LLM activations and building replacement models for circuit tracing.
We want to understand the “brain” of LLMs: what their representations encode and what algorithms emerge from their circuits. A starting point is to learn feature dictionaries with Sparse Autoencoders, which break activations into human‑interpretable features. These dictionaries tell us what features the representations contain, but not how the features interact to form circuits and algorithms. For that, we need to sparsify the entire model (a sparse replacement model), not just the representations of a single layer. One approach is Transcoders, which learn features that approximate MLP components; this lets us swap in a replacement model and trace circuits end to end. Crosslayer transcoders go further by allowing features to affect all subsequent layers, essentially letting features live across layers. This yields smaller, more interpretable circuits and enables circuit tracing and studies of LLM biology.
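To make the crosslayer idea concrete, here is a minimal PyTorch sketch (an illustrative toy under assumed interfaces, not this repository's implementation): each layer gets its own encoder, and the features encoded at a layer are decoded into the reconstructed MLP outputs of that layer and every later layer.

```python
import torch
import torch.nn as nn


class TinyCrossLayerTranscoder(nn.Module):
    """Toy crosslayer transcoder: features encoded at layer l contribute to the
    reconstructed MLP output of every layer >= l, i.e. features live across layers."""

    def __init__(self, n_layers: int, d_model: int, d_features: int):
        super().__init__()
        self.n_layers = n_layers
        # One encoder per layer: MLP input -> sparse feature activations.
        self.encoders = nn.ModuleList(
            [nn.Linear(d_model, d_features) for _ in range(n_layers)]
        )
        # One decoder per (source layer, target layer >= source layer) pair.
        self.decoders = nn.ModuleDict({
            f"{src}->{tgt}": nn.Linear(d_features, d_model, bias=False)
            for src in range(n_layers)
            for tgt in range(src, n_layers)
        })

    def forward(self, mlp_inputs: list[torch.Tensor]) -> list[torch.Tensor]:
        # mlp_inputs[l]: (batch, d_model) activations feeding the MLP at layer l.
        feats = [torch.relu(self.encoders[l](mlp_inputs[l])) for l in range(self.n_layers)]
        recons = []
        for tgt in range(self.n_layers):
            # The reconstruction at layer `tgt` sums contributions from the
            # features of every layer <= tgt.
            recons.append(sum(self.decoders[f"{src}->{tgt}"](feats[src]) for src in range(tgt + 1)))
        return recons
```

In training, the reconstructions are fit to the frozen LLM's MLP outputs under a sparsity constraint (e.g., a penalty or TopK), which is what makes the resulting replacement model sparse and interpretable.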
⚠️ Early Development Disclaimer
This repository is still in early, active development. It's not yet a stable, production-ready package, and there will likely be many breaking changes as the codebase evolves. Use at your own risk and expect API changes between commits.
- ✅ Per-Layer Transcoder (PLT)
- ✅ Crosslayer Transcoder (CLT)
- ✅ Sparse Mixture of Linear Transforms (MOLT)
- ⏳ Matryoshka CLTs
- ⏳ SAEs (by tweaking the activation data extractor)
- ✅ ReLU and JumpReLU (via straight-through estimators)
- ✅ TopK
- ✅ BatchTopK (per layer and across layers; see the sketch after this list)
- ✅ Pre-activation loss (reduces dead features for (Jump-)ReLU CLTs)
- ✅ TopK auxiliary loss (reduces dead features for (Batch-)TopK CLTs)
- ✅ Tanh sparsity penalty
- ✅ Activation standardization
- ✅ On-demand activation extraction and streaming using a shared-memory activation buffer
- ⚠️ Tensor parallelism using PyTorch DTensor API (requires PyTorch 2.8; comms optimization in progress)
- ⏳ Sparse Kernels
- ✅ Distributed data parallelism
- ✅ WandB logging
- ✅ Mixed precision (float16 + gradient scaler or bfloat16), gradient accumulation, checkpointing, profiling
- ✅ Replacement Model Accuracy and KL divergence
- ✅ Dead Features
- ✅ Feature activation frequency and other statistics
- ✅ L0
- ⏳ Replacement Model Score
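To illustrate the BatchTopK options referenced in the list above, here is a hedged sketch (function names and shapes are illustrative assumptions, not code from this repo): the per-layer variant keeps the top `k · batch` pre-activations within each layer separately, while the across-layers variant concatenates every layer's features so they compete for one shared budget.

```python
import torch


def batch_topk(pre_acts: torch.Tensor, k_per_token: int) -> torch.Tensor:
    """Keep the k_per_token * batch largest pre-activations across the whole
    batch, so each token has k_per_token active features on average
    (ties at the threshold may keep a few extra).

    pre_acts: (batch, d_features) pre-activations for one layer.
    """
    batch = pre_acts.shape[0]
    k_total = k_per_token * batch
    threshold = pre_acts.flatten().topk(k_total).values.min()  # smallest kept value
    return pre_acts * (pre_acts >= threshold)


def batch_topk_across_layers(
    pre_acts_per_layer: list[torch.Tensor], k_per_token: int
) -> list[torch.Tensor]:
    """Across-layers variant: concatenate every layer's features first, so all
    layers share a single k budget instead of getting k_per_token each."""
    stacked = torch.cat(pre_acts_per_layer, dim=-1)  # (batch, n_layers * d_features)
    kept = batch_topk(stacked, k_per_token)
    return list(kept.split([a.shape[-1] for a in pre_acts_per_layer], dim=-1))
```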
Recommended: use the setup script (it installs uv if needed and creates the venv).
```bash
# Clone the repository and enter the directory
git clone https://github.com/Goreg12345/crosslayer-transcoder.git
cd crosslayer-transcoder

# Run the setup script (installs uv if needed and creates .venv)
./setup.sh
```

Notes:

- This will create `.venv/` and install from `pyproject.toml`, using `uv.lock` for reproducibility.
- For GPU installs, ensure you have a compatible PyTorch build for your CUDA setup. If needed, follow the official PyTorch instructions to select the right wheel for your CUDA version.
You can customize almost everything: datasets and activation extraction, model architecture, loss functions, and all training hyperparameters. This works by using PyTorch Lightning’s CLI to read a YAML config that defines which classes to use and how to compose them. By editing a single config.yaml, you control the entire run and keep every parameter in one place; you can still override any field from the command line for quick experiments.
- Why this is great
  - Single source of truth for all settings → easy to reproduce and share
  - Composable: swap architectures, losses, and data pipelines by changing class entries in the YAML
  - Discoverable and explicit: every knob is visible in one file, with sane defaults
  - Fast iteration: override any value from the CLI without touching code
- How it works
  - The CLI loads the YAML, instantiates the classes declared there, and wires them together (data → model → trainer)
  - You can keep multiple configs (e.g., under `config/`) for different experiments and hardware setups
  - Any setting can be overridden via dot-notation flags on the command line
```bash
# Train with a config file
python main.py fit --config ./config/default.yaml

# Override anything at runtime (examples)
python main.py fit --config ./config/default.yaml \
  --trainer.max_epochs=10 \
  --data.num_workers=8 \
  --model.nonlinearity=topk \
  --model.num_features=32768
```

You can also switch components by editing the YAML to point to different classes and init args. A minimal schematic (names are illustrative):
```yaml
model:
  class_path: crosslayer_transcoder.model.CrossLayerTranscoderModule
  init_args:
    d_features: 32768
    nonlinearity_theta: 0.03
    lambda_sparsity: 0.0004

data:
  class_path: crosslayer_transcoder.data.ActivationDataModule
  init_args:
    dataset_name: "Skylion007/openwebtext"
    buffer_size: 1000000
    batch_size: 4000

trainer:
  max_epochs: 10
  accelerator: gpu
  devices: 4
```

The `config` folder contains example configuration files for different architectures. These serve as good starting points and can be easily customized for your needs:
- `default.yaml` and `jumprelu-clt.yaml` train a small JumpReLU Crosslayer Transcoder with 120k features on gpt2-small using the openwebtext dataset. They use a small pre-activation loss to reduce the number of dead features.
- `topk-clt.yaml` trains a small TopK Crosslayer Transcoder with 240k features and 16 active features per token and layer. It uses an auxiliary loss to reduce the number of dead features.
- `jumprelu-plt.yaml` trains an ordinary Transcoder ("Per-Layer Transcoder").
- Implement your new architecture, loss, or data class following the same interface as existing components (a rough sketch follows this list)
- Reference it in the YAML via its import path and init args
- It plugs into the same training loop; multi‑GPU (DDP) works out of the box
- Tensor parallelism also works automatically: PyTorch Lightning handles the distributed setup, and PyTorch's Distributed Tensor (DTensor) API shards the model across GPUs without changes to your component code
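For example, a custom nonlinearity could look roughly like the sketch below (the class, its interface, and the config keys shown in the comment are illustrative assumptions, not this repository's actual API):

```python
# my_components.py -- illustrative sketch; the real base-class interfaces live
# in the crosslayer_transcoder package and may differ.
import torch
import torch.nn as nn


class ScaledJumpReLU(nn.Module):
    """Hypothetical nonlinearity with a learnable per-feature threshold."""

    def __init__(self, d_features: int, init_theta: float = 0.03):
        super().__init__()
        # Parameterize the threshold in log space so it stays positive.
        self.log_theta = nn.Parameter(torch.full((d_features,), init_theta).log())

    def forward(self, pre_acts: torch.Tensor) -> torch.Tensor:
        theta = self.log_theta.exp()
        # Hard threshold in the forward pass; real training code would use a
        # straight-through estimator to get gradients w.r.t. theta.
        return pre_acts * (pre_acts > theta)


# Referencing it from a config would look roughly like this (key names are
# illustrative -- check the example configs for the actual schema):
#
# model:
#   init_args:
#     nonlinearity:
#       class_path: my_components.ScaledJumpReLU
#       init_args:
#         d_features: 32768
#         init_theta: 0.03
```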
Run the test suite to ensure everything is working correctly:
```bash
# Run all tests
uv run pytest

# Run tests excluding GPU-dependent ones
uv run pytest -m "not gpu"

# Run only fast tests (exclude slow ones)
uv run pytest -m "not slow"
```