PyTorch, rewritten from scratch in pure Rust.
ferrotorch is a deep learning framework with reverse-mode automatic differentiation, neural network modules, optimizers, GPU acceleration, and a JIT compiler --- all in pure Rust with no C++ dependencies. It provides the same eager-mode, dynamic-graph experience that researchers know from PyTorch, backed by ferray as its NumPy-equivalent array engine.
If you have ever wanted to train a ResNet or a transformer in Rust without pulling in libtorch, wrapping C++ with FFI, or giving up autograd --- this is it.
- Pure Rust, no C/C++ FFI --- the only foreign call is cudarc for the CUDA driver API. Everything else compiles with `cargo build`.
- Reverse-mode autograd with 30+ differentiable operations, a topological-sort backward pass, gradient accumulation, and checkpointing.
- Operator overloading --- write `&a + &b`, `&x * &y`, `-z` with natural Rust syntax. All ownership combinations are supported.
- 24+ neural network layers, including Linear, Conv1d/2d, LSTM, MultiheadAttention, BatchNorm, LayerNorm, RMSNorm, and LLM modules (RoPE, SwiGLU, KV cache, TransformerEncoder/DecoderLayer).
- 8 optimizers --- SGD, Adam, AdamW, RMSprop, Adagrad, L-BFGS, Muon, and K-FAC (Kronecker-factored Fisher) --- with parameter groups and 5 LR schedulers.
- JIT compiler --- trace a forward pass into a static IR, then run constant folding, dead code elimination, operator fusion, and memory planning. The `compile()` API mirrors `torch.compile`.
- GPU acceleration --- unified device-aware tensors (`tensor.cuda()`, `model.to_device(Device::Cuda(0))`) with auto-dispatch to CPU or GPU. NVIDIA via cudarc + cuBLAS (81.8x matmul speedup on an RTX 3090), AMD/Intel/Apple via CubeCL (WGPU, ROCm, Vulkan, Metal). No separate `GpuTensor` type.
- GPU memory safety --- pre-OOM hooks, VRAM reservation, budget enforcement, a pressure watchdog, and emergency checkpointing. Never lose a training run to a Steam game again.
- ONNX export --- trace a model and emit a standard `.onnx` file loadable by onnxruntime, TensorRT, and CoreML. Hand-written protobuf encoder, no external dependency.
- Operation fusion --- chain elementwise ops into a single kernel with PTX codegen, for a 2-5x GPU speedup on fused chains.
- SafeTensors + PyTorch `.pt` import --- load HuggingFace models directly; a pure-Rust pickle parser (28 opcodes) handles PyTorch checkpoints.
- 8 vision model architectures --- ResNet, VGG, ViT, EfficientNet, ConvNeXt, Swin Transformer, U-Net, YOLO.
- INT8/INT4 quantization --- per-tensor and per-channel post-training quantization with quantized matmul.
- Distributed training --- DDP with gradient synchronization over a TCP backend and GPU-aware collectives.
- Training loop --- a `Learner` abstraction with metrics (loss, accuracy, top-k), callbacks (early stopping, progress logging), and training history.
- `#[derive(Module)]` proc macro --- annotate fields with `#[param]`, `#[submodule]`, or `#[skip]` and the Module trait is implemented for you.
- Einops --- `rearrange("b c h w -> b (c h w)")`, `repeat`, and `reduce` with readable string patterns.
- LoRA --- parameter-efficient fine-tuning via `LoRALinear`, with trainable low-rank A/B matrices and `merge()` for zero-overhead inference.
- Fixed-point derivatives --- implicit differentiation for equilibrium models, Neural CAs, and DEQ networks.
- K-FAC natural gradient --- Kronecker-factored Fisher approximation for second-order optimization.
- Zero-panic guarantee --- every public function returns `Result<T, FerrotorchError>`.
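The reverse-mode autograd above is easiest to picture as a tape: every operation is recorded during the forward pass, then the tape is replayed in reverse topological order, applying the chain rule at each node. A dependency-free scalar sketch of that idea (illustrative only, not ferrotorch's internals):

```rust
// Tape-based reverse-mode autodiff on scalars: record ops forward,
// accumulate gradients by walking the tape backward.
#[derive(Clone, Copy)]
enum Op {
    Leaf,
    Add(usize, usize),
    Mul(usize, usize),
}

struct Tape {
    value: Vec<f32>,
    op: Vec<Op>,
}

impl Tape {
    fn new() -> Self { Tape { value: vec![], op: vec![] } }
    fn leaf(&mut self, v: f32) -> usize { self.push(v, Op::Leaf) }
    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.value[a] + self.value[b];
        self.push(v, Op::Add(a, b))
    }
    fn mul(&mut self, a: usize, b: usize) -> usize {
        let v = self.value[a] * self.value[b];
        self.push(v, Op::Mul(a, b))
    }
    fn push(&mut self, v: f32, op: Op) -> usize {
        self.value.push(v);
        self.op.push(op);
        self.value.len() - 1
    }
    // Nodes were pushed in execution order, so iterating in reverse
    // visits every node after all of its users (a topological order).
    fn backward(&self, root: usize) -> Vec<f32> {
        let mut grad = vec![0.0; self.value.len()];
        grad[root] = 1.0;
        for i in (0..=root).rev() {
            match self.op[i] {
                Op::Leaf => {}
                Op::Add(a, b) => { grad[a] += grad[i]; grad[b] += grad[i]; }
                Op::Mul(a, b) => {
                    grad[a] += grad[i] * self.value[b];
                    grad[b] += grad[i] * self.value[a];
                }
            }
        }
        grad
    }
}

fn main() {
    let mut t = Tape::new();
    let a = t.leaf(2.0);
    let b = t.leaf(3.0);
    let c = t.mul(a, b); // c = a * b
    let g = t.backward(c);
    println!("dc/da = {}, dc/db = {}", g[a], g[b]); // 3, 2
}
```

Because nodes are appended in execution order, a single reverse sweep already respects dependencies; that is the property a topological-sort backward pass relies on.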
```rust
use ferrotorch_core::*;

// Build a small computation graph
let a = scalar(2.0f32)?.requires_grad_(true);
let b = scalar(3.0f32)?.requires_grad_(true);
let c = (&a * &b)?;

// Reverse-mode autodiff
c.backward()?;
println!("{}", a.grad()?.unwrap()); // tensor(3.)
println!("{}", b.grad()?.unwrap()); // tensor(2.)
```

A condensed version of `ferrotorch/examples/train_mnist.rs`:
```rust
use ferrotorch_core::*;
use ferrotorch_nn::*;
use ferrotorch_optim::*;

// 3-layer MLP: 784 -> 128 -> 64 -> 10
let mut model = Sequential::new(vec![
    Box::new(Linear::<f32>::new(784, 128, true)?),
    Box::new(ReLU::default()),
    Box::new(Linear::<f32>::new(128, 64, true)?),
    Box::new(ReLU::default()),
    Box::new(Linear::<f32>::new(64, 10, true)?),
]);
let mut optimizer = Adam::new(
    model.parameters().into_iter().cloned().collect(),
    AdamConfig::default(),
);
let loss_fn = CrossEntropyLoss::new(Reduction::Mean, 0.0);

for epoch in 0..10 {
    for batch in train_loader.iter(epoch) {
        let (input, target) = batch?;
        let logits = model.forward(&input)?;
        let loss = loss_fn.forward(&logits, &target)?;
        optimizer.zero_grad()?;
        loss.backward()?;
        optimizer.step()?;
    }
}
```

Or use the high-level Learner API:
```rust
use ferrotorch_train::*;

let mut learner = Learner::new(model, optimizer, loss_fn)
    .with_train_metric(Box::new(LossMetric::new()))
    .with_callback(Box::new(EarlyStopping::new(5, 0.001)))
    .with_callback(Box::new(ProgressLogger::new(100)));
let history = learner.fit(&train_loader, Some(&val_loader), 50)?;
```

ferrotorch is a workspace of 16 crates. Use the umbrella crate for convenience, or depend on individual crates for minimal compile times.
| Crate | Description |
|---|---|
| ferrotorch | Top-level re-export crate (cargo add ferrotorch) |
| ferrotorch-core | Tensor, autograd engine, 30+ differentiable ops, quantization |
| ferrotorch-nn | Module trait, 24+ layers, losses, activations, #[derive(Module)] |
| ferrotorch-nn-derive | Proc macro for #[derive(Module)] |
| ferrotorch-optim | 8 optimizers, 5 LR schedulers, GradScaler for mixed precision |
| ferrotorch-data | Dataset, DataLoader, samplers, transforms |
| ferrotorch-train | Learner, metrics, callbacks, training history |
| ferrotorch-vision | 8 model architectures, MNIST/CIFAR datasets, image I/O |
| ferrotorch-jit | Tracing, IR graph, optimization passes, codegen backends |
| ferrotorch-serialize | SafeTensors, PyTorch .pt import, ONNX export, checkpoints |
| ferrotorch-gpu | NVIDIA CUDA backend, cuBLAS, memory guard, pre-OOM hooks |
| ferrotorch-cubecl | Portable GPU via CubeCL (NVIDIA, AMD, Intel, Apple) |
| ferrotorch-distributed | DDP, allreduce, broadcast, TCP backend |
| ferrotorch-distributions | Probability distributions (Normal, Uniform, Bernoulli, Categorical) |
| ferrotorch-hub | Pretrained model registry, download, and caching |
| ferrotorch-profiler | Operation profiling and Chrome trace export |
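The `ferrotorch-nn-derive` row refers to generated boilerplate such as parameter collection. A sketch of what a `#[derive(Module)]`-style macro saves you from writing by hand; the trait shape and stand-in types here are illustrative assumptions, not ferrotorch-nn's actual definitions:

```rust
// Hand-written version of what a Module derive macro would generate
// from #[param] / #[submodule] / #[skip] field annotations.
// (Param and the Module trait below are simplified stand-ins.)
#[allow(dead_code)]
struct Param(Vec<f32>); // stand-in for a trainable tensor

trait Module {
    fn parameters(&self) -> Vec<&Param>;
}

struct Linear {
    weight: Param, // would carry #[param]
    bias: Param,   // would carry #[param]
}

#[allow(dead_code)]
struct Mlp {
    fc1: Linear,  // would carry #[submodule]
    fc2: Linear,  // would carry #[submodule]
    name: String, // would carry #[skip]: not trainable
}

impl Module for Linear {
    fn parameters(&self) -> Vec<&Param> {
        vec![&self.weight, &self.bias]
    }
}

// This impl is the boilerplate the derive macro removes: collect
// params from #[param] fields and recurse into #[submodule] fields.
impl Module for Mlp {
    fn parameters(&self) -> Vec<&Param> {
        let mut p = self.fc1.parameters();
        p.extend(self.fc2.parameters());
        p
    }
}

fn main() {
    let m = Mlp {
        fc1: Linear { weight: Param(vec![0.0; 6]), bias: Param(vec![0.0; 2]) },
        fc2: Linear { weight: Param(vec![0.0; 2]), bias: Param(vec![0.0; 1]) },
        name: "mlp".into(),
    };
    println!("{} parameter tensors", m.parameters().len()); // 4
}
```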
ferrotorch supports GPU acceleration through two backends:
Uses cudarc for safe Rust bindings to the CUDA driver API, with cuBLAS for matmul/GEMM and custom PTX kernels for elementwise ops. Includes a caching memory allocator modeled after PyTorch's CUDACachingAllocator.
```rust
let x = tensor.cuda()?;      // Move to GPU 0
let y = x.matmul(&weights)?; // cuBLAS GEMM
let z = y.cpu()?;            // Move back
```

Uses CubeCL to compile a single kernel definition to multiple backends:
| Feature flag | Backend | GPU vendors |
|---|---|---|
| `cuda` | NVIDIA CUDA via PTX | NVIDIA |
| `wgpu` | WGPU (Vulkan / Metal / DX12) | AMD, Intel, Apple |
| `rocm` | AMD HIP (native) | AMD |
```rust
use ferrotorch_cubecl::CubeRuntime;

if let Some(rt) = CubeRuntime::auto() {
    println!("Using device: {:?}", rt.device());
}
```

Unlike PyTorch's history of separate CPU/CUDA tensor types, ferrotorch uses a single `Tensor<T>` that is device-aware internally. Operations auto-dispatch:
```rust
let mut model = Linear::new(784, 10, true)?;
model.to_device(Device::Cuda(0))?; // Move weights to GPU

let x = rand::<f32>(&[32, 784])?.cuda()?;
let y = model.forward(&x)?; // Auto-dispatches to GPU
y.backward()?;              // Autograd on GPU
```

Unlike PyTorch, ferrotorch provides proactive GPU memory management. No more lost training runs.
```rust
use std::sync::Arc;
use std::time::Duration;

use ferrotorch_gpu::*;

let device = Arc::new(GpuDevice::new(0)?);

// Reserve 22GB upfront — other apps can't steal it
let guard = MemoryGuardBuilder::new(device.clone())
    .budget_bytes(22 * 1024 * 1024 * 1024)
    .reserve_bytes(22 * 1024 * 1024 * 1024)
    .oom_policy(OomPolicy::WaitAndRetry { timeout_secs: 60 })
    .build()?;

// Register a pre-OOM hook: "halve the batch before crashing"
guard.register_hook(MemoryHook {
    name: "halve_batch".into(),
    estimated_free_bytes: 2 * 1024 * 1024 * 1024,
    execution_overhead_bytes: 50 * 1024 * 1024, // metadata setup cost
    priority: 0,
    callback: Box::new(|| { /* split batch, free old tensors */ 2_000_000_000 }),
});

// Emergency checkpoint on unrecoverable OOM
guard.on_oom(|| save_checkpoint(&model, "emergency.ckpt").unwrap());

// Background watchdog pauses training when VRAM gets tight
let watchdog = Arc::new(MemoryWatchdog::new(device, 512 * 1024 * 1024, Duration::from_secs(1)));
watchdog.clone().start();

// In the training loop
for batch in loader.iter(epoch) {
    watchdog.wait_if_paused();                                 // blocks until VRAM pressure lifts
    let buf = guard.safe_alloc_with_hooks::<f32>(batch_size)?; // hooks fire before OOM
}
```

| Layer | What it prevents |
|---|---|
| MemoryReservation | Other processes stealing VRAM mid-training |
| Budget enforcement | Allocations beyond your declared limit |
| Pre-OOM hooks | Batch splitting, cache clearing — called before failure |
| OomPolicy | Retry, wait, checkpoint-and-fail, or crash (your choice) |
| MemoryWatchdog | Pauses training when free VRAM drops below threshold |
| Emergency checkpoint | Saves model state before crash so you don't lose progress |
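These layers compose into one pattern: try to reclaim memory before declaring OOM. A toy model of the pre-OOM hook mechanism in plain Rust; the names and behavior here are assumptions for illustration, not the ferrotorch-gpu API:

```rust
// Toy pre-OOM hook registry: when an allocation would fail, run
// registered hooks in priority order to reclaim memory, then retry.
struct Hook {
    name: &'static str,
    priority: u32,
    // Returns the number of bytes the hook actually freed.
    callback: Box<dyn FnMut() -> usize>,
}

struct Guard {
    free_bytes: usize,
    hooks: Vec<Hook>,
}

impl Guard {
    fn alloc(&mut self, want: usize) -> Result<(), String> {
        if self.free_bytes < want {
            // Fire hooks (lowest priority value first) until enough is free.
            self.hooks.sort_by_key(|h| h.priority);
            for h in &mut self.hooks {
                if self.free_bytes >= want { break; }
                let freed = (h.callback)();
                println!("hook '{}' freed {} bytes", h.name, freed);
                self.free_bytes += freed;
            }
        }
        if self.free_bytes >= want {
            self.free_bytes -= want;
            Ok(())
        } else {
            Err(format!("OOM: wanted {want}, have {}", self.free_bytes))
        }
    }
}

fn main() {
    let mut guard = Guard {
        free_bytes: 100,
        hooks: vec![Hook {
            name: "clear_cache",
            priority: 0,
            callback: Box::new(|| 400), // pretend we dropped a 400-byte cache
        }],
    };
    // 300 > 100 free, so the hook fires first and the alloc succeeds.
    assert!(guard.alloc(300).is_ok());
    assert_eq!(guard.free_bytes, 200); // 100 + 400 - 300
    println!("allocation succeeded");
}
```

The real system adds reservation, budgets, and a watchdog thread on top, but the core control flow is this reclaim-then-retry loop.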
Export models to the ONNX standard format for deployment on any inference runtime:
```rust
use ferrotorch_serialize::export_onnx;

export_onnx(
    |inputs| model.forward(&inputs[0]),
    &[example_input],
    "model.onnx",
    OnnxExportConfig { opset_version: 17, model_name: "my_model".into() },
)?;
```

The exported `.onnx` file works with onnxruntime (C++/Python), NVIDIA TensorRT, Apple CoreML, and ONNX.js (browser). No external protobuf dependency — hand-written encoder in pure Rust.
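A hand-written protobuf encoder is smaller than it sounds: the wire format is mostly field tags plus base-128 varints. A self-contained sketch of those two primitives (illustrative, not ferrotorch-serialize's actual code):

```rust
// Base-128 varint encoding, the core primitive of the protobuf wire
// format: 7 payload bits per byte, MSB set on every byte but the last.
fn encode_varint(mut v: u64, out: &mut Vec<u8>) {
    loop {
        let byte = (v & 0x7f) as u8;
        v >>= 7;
        if v == 0 {
            out.push(byte);
            return;
        }
        out.push(byte | 0x80);
    }
}

// A field header is itself a varint: (field_number << 3) | wire_type.
fn encode_tag(field: u64, wire_type: u64, out: &mut Vec<u8>) {
    encode_varint((field << 3) | wire_type, out);
}

fn main() {
    let mut buf = Vec::new();
    encode_tag(1, 0, &mut buf);   // field 1, wire type 0 (varint)
    encode_varint(300, &mut buf); // classic wire-format example value
    assert_eq!(buf, vec![0x08, 0xac, 0x02]);
    println!("{buf:02x?}");
}
```

Here `300` encodes to `0xAC 0x02`, the canonical example from the protobuf encoding documentation; everything else in an ONNX file is nested length-delimited messages built from the same pieces.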
Pre-built architectures in ferrotorch-vision, ready to use or fine-tune:
| Architecture | Variants | Task |
|---|---|---|
| ResNet | 18, 34, 50 | Image classification |
| VGG | 11, 16 | Image classification |
| Vision Transformer (ViT) | B/16 | Image classification |
| EfficientNet | B0 | Efficient mobile classification |
| ConvNeXt | Tiny | Modern ConvNet |
| Swin Transformer | Tiny | Hierarchical vision transformer |
| U-Net | --- | Semantic segmentation |
| YOLO | --- | Object detection |
```rust
use ferrotorch_vision::models::{list_models, get_model};

for name in list_models() {
    println!("{name}");
}
let resnet = get_model::<f32>("resnet50", 1000)?;
```

| | ferrotorch | PyTorch | Burn | tch-rs | candle |
|---|---|---|---|---|---|
| Language | Rust | Python/C++ | Rust | Rust (C++ FFI) | Rust |
| C++ dependency | None | libtorch | None | libtorch | None |
| Autograd | Reverse-mode, dynamic graph | Reverse-mode, dynamic graph | Reverse-mode | Via libtorch | Forward ops only |
| GPU | Unified tensors, CUDA + CubeCL | CUDA + ROCm | CubeCL | Via libtorch | CUDA |
| Eager mode | Yes | Yes | Yes | Yes | Yes |
| JIT / compile | Tracing + fusion + codegen | TorchScript / torch.compile | No | Via libtorch | No |
| Distributed | DDP | DDP / FSDP / Pipeline | No | Via libtorch | Partial |
| Quantization | INT8 / INT4 | INT8 / INT4 / FP8 | INT8 | Via libtorch | GGUF |
| ONNX export | Yes (pure Rust) | Yes | No | Yes | No |
| GPU memory safety | Pre-OOM hooks, budget, watchdog | Basic caching | No | Via libtorch | No |
| Model zoo | 8 architectures | Thousands | Limited | Via libtorch | LLM-focused |
| Training loop | Learner + callbacks | Manual / Lightning | Learner | Manual | Manual |
| Proc macro | `#[derive(Module)]` | No (dynamic) | `#[derive(Module)]` | No | No |
| LoRA | Yes (`LoRALinear` + `merge`) | Via libraries | No | No | Yes |
| Einops | Yes (rearrange/repeat/reduce) | Via library | No | No | No |
| Tests | 1,800+ | Extensive | Growing | Via libtorch | Growing |
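The INT8 entry in the quantization row reduces to simple arithmetic: pick a scale that maps the observed range onto the int8 range, round, and dequantize with the same scale. A minimal sketch of the symmetric per-tensor scheme in plain Rust, assuming nothing about ferrotorch-core's quantization API:

```rust
// Symmetric per-tensor INT8 post-training quantization:
// scale = max|x| / 127, q = round(x / scale), x ≈ q * scale.
fn quantize(xs: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = xs.iter().fold(0.0f32, |m, &x| m.max(x.abs()));
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    let q = xs.iter().map(|&x| (x / scale).round() as i8).collect();
    (q, scale)
}

fn dequantize(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}

fn main() {
    let xs = [0.1f32, -0.5, 0.25, 1.0];
    let (q, scale) = quantize(&xs);
    let back = dequantize(&q, scale);
    for (x, y) in xs.iter().zip(&back) {
        // Round-trip error is bounded by half a quantization step.
        assert!((x - y).abs() <= scale / 2.0 + f32::EPSILON);
    }
    println!("q = {q:?}, scale = {scale}");
}
```

Per-channel quantization applies the same formula with one scale per output channel, which tightens the error bound when channel ranges differ.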
Add ferrotorch to your project:
```sh
cargo add ferrotorch
```

Or pick individual crates:

```sh
cargo add ferrotorch-core ferrotorch-nn ferrotorch-optim
```

For GPU support:

```sh
# NVIDIA CUDA
cargo add ferrotorch-gpu

# Portable GPU (AMD, Intel, Apple)
cargo add ferrotorch-cubecl --features wgpu
```

Requires Rust 1.85+ (Edition 2024).
ferrotorch tracks the latest stable Rust edition. The MSRV will only be bumped in minor version releases.
Licensed under either of
at your option.
Contributions are welcome. Whether it is a bug report, a new layer, an optimizer, a model architecture, or a performance improvement --- open an issue or submit a pull request at github.com/dollspace-gay/ferrotorch.
If you are unsure where to start, look for issues labeled "good first issue", or check the CHANGELOG for recently shipped features that could use more tests or documentation.