
🌿 Cassava Leaf Disease Classification

A high-performance, production-ready deep learning system for cassava leaf disease classification, inspired by 3rd-place Kaggle competition solutions. This project leverages Vision Transformers (ViT), image patch division, attention-based weighting, and ensemble learning, wrapped inside a full MLOps pipeline from training to Triton-based deployment.


🧠 Architecture Overview

View Architecture Diagram


📌 Problem Statement

Cassava is a staple crop for over 800 million people worldwide, yet it is highly vulnerable to leaf diseases that severely impact yield. Accurate and early detection is critical for disease management and food security.

This repository provides an end-to-end solution for classifying cassava leaf diseases from images using state-of-the-art deep learning techniques.


📁 Project Structure

Cassava-Disease-Prediction/
├── configs/
├── deployment/
│   ├── quantization/
│   └── triton/
├── src/cassava_classifier/
│   ├── data/
│   ├── models/
│   ├── pipelines/
│   └── utils/
├── images/
├── artifacts/
├── plots/
├── models/
├── outputs/
│   ├── model1/
│   │   ├── model1_best.ckpt
│   │   ├── model1_best.onnx
│   │   └── model1_best.trt
│   ├── model2/
│   │   ├── model2_best.ckpt
│   │   ├── model2_best.onnx
│   │   └── model2_best.trt
│   └── model3/
│       ├── model3_best.ckpt
│       ├── model3_best.onnx
│       └── model3_best.trt
├── data.dvc
├── pyproject.toml
└── README.md

🚀 Key Features

  • Ensemble of 3 Vision Transformer models
    • ViT-384 (global context)
    • ViT-448 ×2 (patch-based fine-grained analysis)
  • Image Division Strategy
    • 448×448 images split into four 224×224 patches
  • Attention-Based Feature Weighting
    • Learns importance of spatial regions
  • Multi-Dropout Regularization
    • Improves robustness and generalization
  • Label Smoothing
    • Handles noisy labels
  • End-to-End MLOps
    • DVC, Hydra, PyTorch Lightning, MLflow
  • Production Deployment
    • ONNX → TensorRT → Triton Inference Server
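The image-division and attention-weighting ideas above can be sketched in a few lines of NumPy. This is an illustrative sketch, not this repo's actual code: the patch size, the 768-dim feature vectors, and the attention scores below are all assumptions.

```python
import numpy as np

def split_into_patches(img: np.ndarray, patch: int = 224) -> np.ndarray:
    """Split an HxWxC image into non-overlapping patch x patch tiles."""
    h, w, c = img.shape
    assert h % patch == 0 and w % patch == 0
    return (
        img.reshape(h // patch, patch, w // patch, patch, c)
        .transpose(0, 2, 1, 3, 4)
        .reshape(-1, patch, patch, c)
    )

def attention_pool(features: np.ndarray, scores: np.ndarray) -> np.ndarray:
    """Combine per-patch feature vectors with softmax-normalized attention weights."""
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return (w[:, None] * features).sum(axis=0)

img = np.zeros((448, 448, 3), dtype=np.uint8)   # a 448x448 input image
patches = split_into_patches(img)               # four 224x224 tiles
feats = np.random.rand(4, 768)                  # hypothetical per-patch ViT embeddings
pooled = attention_pool(feats, np.array([0.1, 0.9, 0.2, 0.5]))
```

In the real models the attention scores are learned, so informative leaf regions end up dominating the pooled representation.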

🧠 Classes

  • Cassava Bacterial Blight (CBB)
  • Cassava Brown Streak Disease (CBSD)
  • Cassava Green Mottle (CGM)
  • Cassava Mosaic Disease (CMD)
  • Healthy
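With five classes, the label-smoothing feature mentioned above turns a hard one-hot target into a soft distribution. A minimal sketch, assuming a smoothing factor of 0.1 and the index order shown; the authoritative class-index mapping comes from train.csv:

```python
import numpy as np

NUM_CLASSES = 5  # CBB, CBSD, CGM, CMD, Healthy

def smooth_labels(target: int, eps: float = 0.1, k: int = NUM_CLASSES) -> np.ndarray:
    """Soft target: every class shares eps/k; the true class gets 1 - eps on top."""
    dist = np.full(k, eps / k)
    dist[target] += 1.0 - eps
    return dist

target_dist = smooth_labels(4)  # soft target for "Healthy" under the assumed ordering
```

The small probability mass on the wrong classes keeps the model from becoming over-confident on noisy labels.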

🛠️ Tech Stack

| Component | Tool |
|-----------|------|
| Training  | PyTorch Lightning |
| Models    | Vision Transformers (timm) |
| Config    | Hydra |
| Data      | DVC |
| Tracking  | MLflow |
| Inference | ONNX Runtime / TensorRT |
| Serving   | Triton Inference Server |
| Env       | uv |

โš™๏ธ Installation

Prerequisites

  • Python 3.9+
  • Git
  • CUDA GPU (optional)

git clone https://github.com/faranbutt/Cassava-Disease-Prediction.git
cd Cassava-Disease-Prediction

uv venv
source .venv/bin/activate

uv pip install -e .
uv run pre-commit install

🚂 Training Pipeline

Training is orchestrated using Hydra and PyTorch Lightning, with datasets versioned and managed via DVC.


1๏ธโƒฃ Download Dataset

dvc pull

This command fetches the following files:

  • train.csv
  • train_images/

Drive Link:

Download Dataset from Google Drive

2๏ธโƒฃ Run Full Pipeline

MPLBACKEND=Agg python src/cassava_classifier/commands.py run_full=true

This command automatically:

  • Trains 3 ViT models
  • Runs K-Fold validation
  • Logs metrics to MLflow
  • Saves checkpoints
  • Exports models to ONNX
  • Converts to TensorRT
  • Builds the Triton ensemble
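The final ensemble step can be illustrated by averaging softmax probabilities across the three models and taking the argmax. A minimal NumPy sketch with made-up logits; the actual implementation presumably lives under src/cassava_classifier/pipelines/:

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def ensemble_predict(per_model_logits: list) -> int:
    """Average softmax probabilities across models, then take the argmax."""
    probs = np.mean([softmax(l) for l in per_model_logits], axis=0)
    return int(probs.argmax())

# Made-up 5-class logits from three hypothetical models
l1 = np.array([2.0, 1.0, 0.0, 0.0, 0.0])
l2 = np.array([0.0, 3.0, 0.0, 0.0, 0.0])
l3 = np.array([0.0, 2.0, 0.0, 0.0, 0.0])
pred = ensemble_predict([l1, l2, l3])  # class 1 wins on averaged probability
```

Averaging probabilities rather than raw logits keeps each model's vote on a comparable scale.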

โš™๏ธ Configuration (Hydra)

All parameters are configurable via YAML:

configs/
├── model/     # ViT variants, image size, attention
├── train/     # epochs, batch size, LR, folds
├── data/      # dataset paths
└── config.yaml
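As a rough illustration, a training config under configs/train/ might look like the fragment below. The key names here are assumptions; the authoritative ones are in this repo's YAML files.

```yaml
# Illustrative only; see configs/ for the real keys
epochs: 10
batch_size: 32
lr: 1e-4
num_folds: 5
label_smoothing: 0.1
image_size: 448
```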

✅ No hardcoded values; everything is configurable.

🖥️ Streamlit Interface

Streamlit UI

๐ŸŒ Hugging Face Deployment

You can interact with the live model via Hugging Face Spaces, deployed automatically using GitHub Actions:

Try the Cassava Leaf Disease Detector

⚡ Note: The deployment is linked to this repository and automatically updates via GitHub Actions whenever changes are pushed to the main branch.

📦 Production Artifacts

Generated automatically during training:

  • model_best.ckpt – PyTorch checkpoints
  • model.onnx – Framework-agnostic inference
  • model.trt – Optimized TensorRT engine
  • config.pbtxt – Triton model configuration

These artifacts are excluded from Git and managed locally or via deployment targets.

⚡ Model Quantization

To accelerate CPU-based inference, models are quantized to FP16 using ONNX Runtime:

  • FP16 reduces memory usage and speeds up inference on CPU/GPU with minimal accuracy loss.
  • This was also done to save space on the Hugging Face repo, which only provides 1 GB of storage.
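The repo's actual conversion goes through ONNX Runtime tooling, but the core FP16 trade-off (half the bytes per parameter, slightly lower precision) can be illustrated with plain NumPy:

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.random((1000, 768), dtype=np.float32)  # stand-in for model weights

half = weights.astype(np.float16)      # FP16: half the storage
restored = half.astype(np.float32)     # round-trip to measure precision loss

size_ratio = half.nbytes / weights.nbytes            # 0.5
max_abs_err = float(np.abs(weights - restored).max())
```

For values in [0, 1) the round-trip error stays below about 5e-4, which is why FP16 typically costs little classification accuracy.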

Quantized models are saved alongside the full-precision exports.

💾 Saved Models

Saved Models

🔮 Inference Options

🧪 Local (PyTorch / ONNX)

python src/cassava_classifier/commands.py \
  predict=true \
  +predict.use_ensemble=true \
  +predict.image_path="data/test.jpg"

🚀 Production (Triton Server)

docker run --rm -p 8000:8000 \
  -v $(pwd)/triton:/models \
  nvcr.io/nvidia/tritonserver:25.11-py3 \
  tritonserver --model-repository=/models

🚀 Triton Model Serving

Triton Inference

Supports:

  • GPU (TensorRT)
  • CPU-only (ONNX Runtime backend)
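For reference, a Triton model configuration for the ONNX Runtime backend typically looks like the sketch below. The model name, tensor names, and dims here are assumptions; the real config.pbtxt ships under deployment/triton/.

```protobuf
name: "cassava_onnx"
backend: "onnxruntime"
max_batch_size: 8
input [
  {
    name: "input"
    data_type: TYPE_FP32
    dims: [ 3, 448, 448 ]
  }
]
output [
  {
    name: "logits"
    data_type: TYPE_FP32
    dims: [ 5 ]
  }
]
```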

📊 Experiment Tracking

MLflow logs:

  • Hyperparameters
  • Metrics: loss, accuracy, F1
  • Per-fold results
  • Training curves

Launch MLflow:

mlflow ui --backend-store-uri sqlite:///mlflow.db --port 8080

📊 MLflow Metrics

MLflow Metrics

📈 Training Plots

Training Graphs

🧪 Code Quality

  • black, isort, flake8
  • Pre-commit hooks enforced
  • Clean, modular package structure
  • CLI-based execution (Hydra)

Pre-commit Hooks

🔗 References
