TritonTorch

License: MIT · Python 3.8+ · PyTorch 2.0+ · Triton

Overview

TritonTorch is a fully differentiable, efficient, and modular open-source library of PyTorch neural network modules and operations implemented in Triton. It provides GPU-accelerated primitives that leverage Triton's low-level control and parallelism, enabling seamless integration of deep learning building blocks into your workflows.

Table of Contents

  • Installation
  • Quick Start
  • Supported Modules
  • Testing
  • Contributing
  • License
  • Acknowledgments

Installation

Prerequisites

  • Linux operating system (WSL for Windows users)
  • CUDA-capable GPU
  • Python 3.8+
  • PyTorch 2.0+
  • Triton (installed via pip or from source)

From Source

git clone https://github.com/ayoussf/TritonTorch.git
cd TritonTorch
pip install .

Development Installation

pip install -e '.[dev]'

Quick Start

import torch
from TritonTorch.Normalization import LayerNorm
from TritonTorch.Activations import ReLU

# Configuration
batch, length, dim = 2, 100, 128
device = "cuda"
dtype = torch.float16  # Or torch.float32

# Initialize input tensor
x = torch.randn(batch, length, dim, device=device, dtype=dtype)

# Create modules
layernorm = LayerNorm(dim, eps=1e-6, elementwise_affine=True, 
                      bias=True, device=device, dtype=dtype)
relu = ReLU()

# Forward pass
x = layernorm(x)
x = relu(x)

Supported Modules

All modules support both forward and backward passes for full differentiability.

Note

Some modules (e.g., Conv1d, Conv2d) are not yet fully optimized; it is still to be determined whether the bottleneck lies in the kernel implementations or in the autotuning configurations.

Activation Functions

| Function | Description |
| --- | --- |
| GeLU | Gaussian Error Linear Unit (with/without tanh approximation) |
| ReLU | Rectified Linear Unit |
| LeakyReLU | Leaky Rectified Linear Unit |
| ReLU6 | ReLU clamped at 6 |
| Sigmoid | Sigmoid activation |
| Tanh | Hyperbolic tangent |
| Mish | Mish activation function |
| SiLU | Sigmoid Linear Unit (Swish) |
| Softmax | Softmax normalization |
| LogSoftmax | Log-Softmax |
| Softmin | Softmin normalization |
| Softplus | Smooth approximation of ReLU |
| Threshold | Thresholded activation |
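As a note on the two GeLU variants listed above: the exact form uses the Gaussian CDF, while the tanh form is an approximation. A pure-Python sketch of the two formulas (illustrative reference math only, not TritonTorch's kernels):

```python
import math

def gelu_exact(x: float) -> float:
    # GeLU(x) = x * Phi(x), where Phi is the standard normal CDF
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"{x:+.1f}  exact={gelu_exact(x):+.6f}  tanh={gelu_tanh(x):+.6f}")
```

The two variants agree to roughly three decimal places over typical activation ranges, which is why the cheaper tanh form is often preferred in fused kernels.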

Normalization Layers

| Layer | Description |
| --- | --- |
| LayerNorm | Layer normalization |
| RMSNorm | Root Mean Square normalization |
| BatchNorm | Batch normalization (In Progress) |
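The key difference between LayerNorm and RMSNorm is whether the mean is subtracted before scaling. A small pure-Python reference of the underlying math (illustrative only, independent of the Triton kernels; the affine scale/bias terms are omitted):

```python
import math

def layer_norm(x, eps=1e-6):
    # normalize to zero mean and unit variance over the last dimension
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def rms_norm(x, eps=1e-6):
    # scale by the root mean square; no mean subtraction
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

x = [1.0, 2.0, 3.0, 4.0]
print("LayerNorm:", [round(v, 4) for v in layer_norm(x)])
print("RMSNorm:  ", [round(v, 4) for v in rms_norm(x)])
```

Skipping the mean reduction makes RMSNorm cheaper than LayerNorm, which is one reason it appears in many recent architectures.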

Neural Network Layers

| Layer | Description |
| --- | --- |
| Linear | Fully connected layer |
| Dropout | Dropout regularization |
| MLP | Multi-Layer Perceptron (Gated-MLP / FFN) |
| Multi-Head Attention | Scaled dot-product attention |
| Conv1d | 1D convolution |
| Conv2d | 2D convolution |
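For Multi-Head Attention, each head computes softmax(QKᵀ/√d)·V. A single-head, pure-Python sketch of that core computation on tiny lists (illustrative reference math, not the fused Triton kernel):

```python
import math

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def attention(Q, K, V):
    d = len(Q[0])
    # scores[i][j] = <Q_i, K_j> / sqrt(d)
    scores = [[sum(q * k for q, k in zip(qi, kj)) / math.sqrt(d) for kj in K]
              for qi in Q]
    # each output row is a convex combination of the rows of V
    weights = [softmax(row) for row in scores]
    return [[sum(w * vj[c] for w, vj in zip(row, V)) for c in range(len(V[0]))]
            for row in weights]

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(attention(Q, K, V))
```

Because the softmax weights of each row sum to 1, every output row stays inside the range spanned by the rows of V.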

Operations

| Operation | Description |
| --- | --- |
| BMM | Batched matrix multiplication (supports unbatched inputs) |
| Normalize | L1, L2, and p-norm tensor normalization |
| Norm | Matrix/vector L1, L2, and p-norms |
| Pairwise Cosine Similarity | Distance computation between vectors |
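The Norm, Normalize, and cosine-similarity operations follow the standard definitions; a pure-Python reference of the math involved (illustrative only, independent of the Triton implementations):

```python
import math

def p_norm(x, p=2.0):
    # ||x||_p = (sum |x_i|^p)^(1/p)
    return sum(abs(v) ** p for v in x) ** (1.0 / p)

def normalize(x, p=2.0, eps=1e-12):
    # scale x to unit p-norm; eps guards against division by zero
    n = max(p_norm(x, p), eps)
    return [v / n for v in x]

def cosine_similarity(a, b, eps=1e-8):
    dot = sum(u * v for u, v in zip(a, b))
    return dot / max(p_norm(a) * p_norm(b), eps)

a, b = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]
print(cosine_similarity(a, b))  # parallel vectors give ~1.0
```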

Testing

See tests/README.md for full documentation on running tests, CLI options, and test structure.

Contributing

Contributions are welcome! To contribute:

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/your-feature)
  3. Add unit tests under the tests/ directory
  4. Ensure compatibility with PyTorch and Triton
  5. Submit a pull request

Found a bug or have a suggestion? Please open an issue or submit a pull request.

License

TritonTorch is released under the MIT License. You are free to use, modify, and distribute it.

Acknowledgments

Special thanks to the authors of Mamba. Their work has been a valuable reference for parts of this repository.
