# TritonTorch

TritonTorch is a fully differentiable, efficient, and modular open-source library of PyTorch neural network modules and operations implemented in Triton. It provides GPU-accelerated primitives that leverage Triton's low-level control and parallelism, enabling seamless integration of deep learning building blocks into your workflows.
## Requirements

- Linux operating system (WSL for Windows users)
- CUDA-capable GPU
- Python 3.8+
- PyTorch 2.0+
- Triton (installed via pip or from source)
## Installation

```bash
git clone https://github.com/ayoussf/TritonTorch.git
cd TritonTorch
pip install .
```

For a development (editable) install:

```bash
pip install -e '.[dev]'
```

## Quick Start

```python
import torch

from TritonTorch.Normalization import LayerNorm
from TritonTorch.Activations import ReLU

# Configuration
batch, length, dim = 2, 100, 128
device = "cuda"
dtype = torch.float16  # Or torch.float32

# Initialize input tensor
x = torch.randn(batch, length, dim, device=device, dtype=dtype)

# Create modules
layernorm = LayerNorm(dim, eps=1e-6, elementwise_affine=True,
                      bias=True, device=device, dtype=dtype)
relu = ReLU()

# Forward pass
x = layernorm(x)
x = relu(x)
```

All modules support both forward and backward passes for full differentiability.
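To make "forward and backward passes" concrete, here is a minimal scalar sketch of the forward/backward function pairs a differentiable module must provide. This is an illustration of the math only, not TritonTorch's Triton kernels, which operate on GPU tensors in parallel:

```python
import math

def relu_forward(x):
    return max(x, 0.0)

def relu_backward(x, grad_out):
    # dReLU/dx is 1 for x > 0, else 0; chain rule multiplies by grad_out
    return grad_out if x > 0 else 0.0

def sigmoid_forward(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_backward(x, grad_out):
    # d(sigmoid)/dx = sigmoid(x) * (1 - sigmoid(x))
    s = sigmoid_forward(x)
    return grad_out * s * (1.0 - s)

print(relu_forward(-2.0), relu_backward(-2.0, 1.0))  # 0.0 0.0
print(sigmoid_forward(0.0))                          # 0.5
```

In the library these pairs are wired into PyTorch's autograd, so gradients flow through TritonTorch modules just as they do through native `torch.nn` layers.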
> **Note:** Some modules (e.g., Conv1d, Conv2d) are not fully optimized yet; it is still to be determined whether this is due to the kernel implementation or the autotuning configuration.
## Activations

| Function | Description |
|---|---|
| GeLU | Gaussian Error Linear Unit (with/without tanh approximation) |
| ReLU | Rectified Linear Unit |
| LeakyReLU | Leaky Rectified Linear Unit |
| ReLU6 | ReLU clamped at 6 |
| Sigmoid | Sigmoid activation |
| Tanh | Hyperbolic tangent |
| Mish | Mish activation function |
| SiLU | Sigmoid Linear Unit (Swish) |
| Softmax | Softmax normalization |
| LogSoftmax | Log-Softmax |
| Softmin | Softmin normalization |
| Softplus | Smooth approximation of ReLU |
| Threshold | Thresholded activation |
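For reference, the formulas behind a few of the table entries can be written in plain Python. These are naive scalar definitions for clarity (real kernels guard against overflow and vectorize across the tensor), not the library's Triton implementations:

```python
import math

def softplus(x):
    # log(1 + e^x): smooth approximation of ReLU (naive; no overflow guard)
    return math.log1p(math.exp(x))

def silu(x):
    # x * sigmoid(x), also known as Swish
    return x / (1.0 + math.exp(-x))

def mish(x):
    # x * tanh(softplus(x))
    return x * math.tanh(softplus(x))

def gelu_tanh(x):
    # tanh approximation of GeLU
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi)
                                      * (x + 0.044715 * x ** 3)))

print(silu(1.0))  # ≈ 0.7311
```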
## Normalization

| Layer | Description |
|---|---|
| LayerNorm | Layer normalization |
| RMSNorm | Root Mean Square normalization |
| BatchNorm | Batch normalization (In Progress) |
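The difference between the two available normalizations is easy to see in a scalar sketch (affine scale/shift omitted): LayerNorm centers and scales to unit variance, while RMSNorm skips the mean subtraction and divides by the root-mean-square only. Again, this is reference math, not the Triton kernels:

```python
import math

def layer_norm(xs, eps=1e-6):
    # Zero mean, unit variance over the last dimension
    mean = sum(xs) / len(xs)
    var = sum((v - mean) ** 2 for v in xs) / len(xs)
    return [(v - mean) / math.sqrt(var + eps) for v in xs]

def rms_norm(xs, eps=1e-6):
    # No mean subtraction: divide by the root-mean-square alone
    rms = math.sqrt(sum(v * v for v in xs) / len(xs) + eps)
    return [v / rms for v in xs]

out = layer_norm([1.0, 2.0, 3.0, 4.0])
print(out)
```

RMSNorm's cheaper formula (one reduction instead of two) is the reason it is popular in recent large language models.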
## Layers

| Layer | Description |
|---|---|
| Linear | Fully connected layer |
| Dropout | Dropout regularization |
| MLP | Multi-Layer Perceptron (Gated-MLP / FFN) |
| Multi-Head Attention | Scaled dot-product attention |
| Conv1d | 1D convolution |
| Conv2d | 2D convolution |
## Operations

| Operation | Description |
|---|---|
| BMM | Batched matrix multiplication (supports unbatched inputs) |
| Normalize | L1, L2, and p-norm tensor normalization |
| Norm | Matrix/vector L1, L2, and p-norms |
| Pairwise Cosine Similarity | Cosine similarity between all pairs of vectors |
## Testing

See `tests/README.md` for full documentation on running tests, CLI options, and test structure.
## Contributing

Contributions are welcome! To contribute:
- Fork the repository
- Create a feature branch (`git checkout -b feature/your-feature`)
- Add unit tests under the `tests/` directory
- Ensure compatibility with PyTorch and Triton
- Submit a pull request
Found a bug or have a suggestion? Please open an issue or submit a pull request.
## License

TritonTorch is released under the MIT License. You are free to use, modify, and distribute it.
## Acknowledgements

Special thanks to the authors of Mamba. Their work has been a valuable reference for parts of this repository.
