A high-performance, production-ready deep learning system for cassava leaf disease classification, inspired by 3rd-place Kaggle competition solutions. This project leverages Vision Transformers (ViT), image patch division, attention-based weighting, and ensemble learning, wrapped inside a full MLOps pipeline from training to Triton-based deployment.
Cassava is a staple crop for over 800 million people worldwide, yet it is highly vulnerable to leaf diseases that severely impact yield. Accurate and early detection is critical for disease management and food security.
This repository provides an end-to-end solution for classifying cassava leaf diseases from images using state-of-the-art deep learning techniques.
```text
Cassava-Disease-Prediction/
├── configs/
├── deployment/
│   ├── quantization/
│   └── triton/
├── src/cassava_classifier/
│   ├── data/
│   ├── models/
│   ├── pipelines/
│   └── utils/
├── images/
├── artifacts/
├── plots/
├── models/
├── outputs/
│   ├── model1/
│   │   ├── model1_best.ckpt
│   │   ├── model1_best.onnx
│   │   └── model1_best.trt
│   ├── model2/
│   │   ├── model2_best.ckpt
│   │   ├── model2_best.onnx
│   │   └── model2_best.trt
│   └── model3/
│       ├── model3_best.ckpt
│       ├── model3_best.onnx
│       └── model3_best.trt
├── data.dvc
├── pyproject.toml
└── README.md
```

- Ensemble of 3 Vision Transformer models
- ViT-384 (global context)
- ViT-448 ×2 (patch-based fine-grained analysis)
- Image Division Strategy
- 448×448 images split into four 224×224 patches
- Attention-Based Feature Weighting
- Learns importance of spatial regions
- Multi-Dropout Regularization
- Improves robustness and generalization
- Label Smoothing
- Handles noisy labels
- End-to-End MLOps
- DVC, Hydra, PyTorch Lightning, MLflow
- Production Deployment
- ONNX → TensorRT → Triton Inference Server
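The image-division and attention-weighting ideas above can be sketched in a few lines of numpy. This is a simplified illustration, not the repo's actual ViT pipeline, and the function names are hypothetical:

```python
import numpy as np

def split_into_patches(image: np.ndarray, patch: int = 224) -> np.ndarray:
    """Split an (H, W, C) image into non-overlapping (patch, patch, C) tiles."""
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0
    return (image
            .reshape(h // patch, patch, w // patch, patch, c)
            .transpose(0, 2, 1, 3, 4)      # group the row/column tile axes
            .reshape(-1, patch, patch, c))

def attention_pool(patch_features: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Weight per-patch feature vectors by a softmax attention score."""
    scores = patch_features @ w            # one scalar score per patch
    scores = np.exp(scores - scores.max())
    weights = scores / scores.sum()        # softmax over patches
    return weights @ patch_features        # weighted sum -> single feature vector

image = np.random.rand(448, 448, 3)
patches = split_into_patches(image)        # four 224x224 patches
print(patches.shape)                       # (4, 224, 224, 3)
```

In the real model the attention vector `w` is learned end-to-end, so the network discovers which spatial regions matter for each disease.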
- Cassava Bacterial Blight (CBB)
- Cassava Brown Streak Disease (CBSD)
- Cassava Green Mottle (CGM)
- Cassava Mosaic Disease (CMD)
- Healthy
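Label smoothing (listed in the features above) softens the one-hot target so the model is never pushed to full confidence on possibly mislabeled images. A minimal numpy sketch of one common formulation — 1 − ε on the true class, ε spread uniformly over the others; ε = 0.1 here is an assumption, not necessarily the repo's setting:

```python
import numpy as np

def smooth_labels(target: int, num_classes: int = 5, eps: float = 0.1) -> np.ndarray:
    """Replace a one-hot target with a softened distribution."""
    dist = np.full(num_classes, eps / (num_classes - 1))
    dist[target] = 1.0 - eps
    return dist

def cross_entropy(logits: np.ndarray, target_dist: np.ndarray) -> float:
    """Cross-entropy against a soft target distribution."""
    log_probs = logits - np.log(np.sum(np.exp(logits)))
    return float(-(target_dist * log_probs).sum())

y = smooth_labels(3)   # CMD (class index 3): 0.9 on the true class, 0.025 elsewhere
print(y)
```

Because the target never reaches exactly 1.0, a single noisy label contributes a bounded gradient instead of dragging the model toward a wrong hard decision.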
| Component | Tool |
|---|---|
| Training | PyTorch Lightning |
| Models | Vision Transformers (timm) |
| Config | Hydra |
| Data | DVC |
| Tracking | MLflow |
| Inference | ONNX Runtime / TensorRT |
| Serving | Triton Inference Server |
| Env | uv |
- Python 3.9+
- Git
- CUDA GPU (optional)
```bash
git clone https://github.com/faranbutt/Cassava-Disease-Prediction.git
cd Cassava-Disease-Prediction
uv venv
source .venv/bin/activate
uv pip install -e .
uv run pre-commit install
```

Training is orchestrated using Hydra and PyTorch Lightning, with datasets versioned and managed via DVC.
```bash
dvc pull
```

This command fetches the following files:

- `train.csv`
- `train_images/`
Download Dataset from Google Drive
```bash
MPLBACKEND=Agg python src/cassava_classifier/commands.py run_full=true
```

This command automatically:

- Trains 3 ViT models
- Runs K-Fold validation
- Logs metrics to MLflow
- Saves checkpoints
- Exports models to ONNX
- Converts to TensorRT
- Builds Triton ensemble
All parameters are configurable via YAML:
```text
configs/
├── model/        # ViT variants, image size, attention
├── train/        # epochs, batch size, LR, folds
├── data/         # dataset paths
└── config.yaml
```

✅ No hardcoded values; everything is configurable.
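For illustration, a `configs/model/` group entry might look like the following. All keys shown here are hypothetical; check the actual files under `configs/` for the real schema:

```yaml
# configs/model/vit_448.yaml -- illustrative only
name: vit_base_patch16_224   # timm model identifier (assumed)
image_size: 448
patch_grid: 2                # 2x2 split into 224x224 patches
attention_pooling: true
dropout_samples: 5           # multi-dropout heads
```

Hydra composes such group files with `config.yaml`, so any field can be overridden from the command line without editing code.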
You can interact with the live model via Hugging Face Spaces, deployed automatically using GitHub Actions:
Try the Cassava Leaf Disease Detector
⚡ Note: The deployment is linked to this repository and automatically updates via GitHub Actions whenever changes are pushed to the `main` branch.
Generated automatically during training:

- `model_best.ckpt` → PyTorch checkpoints
- `model.onnx` → framework-agnostic inference
- `model.trt` → optimized TensorRT engine
- Triton `config.pbtxt`
These artifacts are excluded from Git and managed locally or via deployment targets.
To accelerate CPU-based inference, models are quantized to FP16 using ONNX Runtime:
- FP16 reduces memory usage and speeds up inference on CPU/GPU with minimal accuracy loss.
- This also saves space on the Hugging Face repo, which only provides 1 GB of storage.
Quantized models are saved as:
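The FP16 trade-off can be illustrated with a quick numpy check. This mimics only the precision cast itself, not the actual ONNX Runtime conversion:

```python
import numpy as np

# Simulated FP32 weight tensor from one model layer.
weights_fp32 = np.random.randn(1024, 768).astype(np.float32)
weights_fp16 = weights_fp32.astype(np.float16)

# Memory footprint halves exactly (4 bytes -> 2 bytes per element).
print(weights_fp32.nbytes // weights_fp16.nbytes)   # 2

# Round-trip error stays small for values in a typical weight range.
max_err = np.abs(weights_fp32 - weights_fp16.astype(np.float32)).max()
print(max_err < 1e-2)   # True
```

FP16 keeps roughly 3 significant decimal digits, which is why accuracy loss is negligible for well-scaled network weights.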
Local (PyTorch / ONNX)
```bash
python src/cassava_classifier/commands.py \
  predict=true \
  +predict.use_ensemble=true \
  +predict.image_path="data/test.jpg"
```

Production (Triton Server)
```bash
docker run --rm -p 8000:8000 \
  -v $(pwd)/triton:/models \
  nvcr.io/nvidia/tritonserver:25.11-py3 \
  tritonserver --model-repository=/models
```

Supported backends: GPU (TensorRT) and CPU-only (ONNX Runtime).
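The `use_ensemble` flag above averages the per-model predictions. A minimal numpy sketch of softmax averaging over the three models — the exact ensembling rule used in the repo may differ:

```python
import numpy as np

CLASSES = ["CBB", "CBSD", "CGM", "CMD", "Healthy"]

def softmax(logits: np.ndarray) -> np.ndarray:
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def ensemble_predict(logits_per_model: list) -> str:
    """Average each model's softmax output, then take the argmax class."""
    probs = np.mean([softmax(l) for l in logits_per_model], axis=0)
    return CLASSES[int(probs.argmax())]

# Three models' logits for one image (illustrative values only).
logits = [np.array([0.1, 0.2, 0.1, 2.5, 0.3]),
          np.array([0.0, 0.4, 0.2, 2.1, 0.1]),
          np.array([0.2, 0.1, 0.3, 1.9, 0.4])]
print(ensemble_predict(logits))   # CMD
```

Averaging probabilities rather than raw logits keeps each model's contribution bounded, so one overconfident model cannot dominate the vote.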
- Hyperparameters
- Metrics: loss, accuracy, F1
- Per-fold results
- Training curves
```bash
mlflow ui --backend-store-uri sqlite:///mlflow.db --port 8080
```

- black, isort, flake8
- Pre-commit enforced
- Clean modular package structure
- CLI-based execution (Hydra)






