This repository contains the optimization pipeline for the BB84 QKD protocol, investigating the use of neural networks (NNs) to accelerate parameter optimization. It provides a "Legacy" NumPy-based baseline (Tier 1), a modern JAX-accelerated suite (Tier 3), and a neural-network surrogate model.
Experience the power of real-time AI optimization through our interactive dashboard.
Launch Live App on Streamlit Cloud
Optimizing parameters is crucial for maximizing the performance of Quantum Key Distribution (QKD) systems. This study investigates the efficacy of neural networks (NNs) as a high-speed alternative to Dual Annealing (DA) for determining optimal operational parameters (signal/decoy intensities μk, probabilities Pμk, basis choice Px) for the finite-key decoy-state BB84 protocol. We demonstrate that a trained NN can predict near-optimal parameters with high accuracy, achieving a >6,000-fold practical speedup over traditional legacy solvers.
- 📉 Rate Improvement: ~47x - 60x higher key rate at long distances (180 km+).
- 📏 Range Extension: +5.0 km extended secure communication distance.
- ⚡ Implementation Speed: >6,000-fold practical speedup (End-to-End Latency).
- 🎯 High Accuracy: The NN predictions closely match the numerically optimized ground truth.
- `Optimization/`: Main project folder for numerical optimization.
  - `run_scripts/`: Executable scripts for the different optimization tiers.
    - `tier1_legacy_baseline.py`: Baseline (Tier 1). Pure NumPy/SciPy `dual_annealing`. Slow but accurate.
    - `tier3_global_jax.py`: Sequential Adaptive (Tier 3). JAX-accelerated. The "gold standard" for smooth parameters.
    - `tier3_jax_segmented.py`: Segmented Adaptive. Faster parallel execution with L-splitting.
    - `tier2_local.py`: Fast local search for quick verification.
    - `plot_results.py`: Tool to generate plots from JSON results.
  - `outputs/`: Centralized storage for all run results.
    - `tier1_legacy_baseline/`: Logs and JSONs from Tier 1 runs.
    - `global_jax/`: Logs and JSONs from Tier 3 runs.
  - `notebooks/`: Jupyter notebooks (archived/research).
  - `tools/`: Helper utilities (timing reports, etc.).
- `src/`: Shared physics library (`qkd.model`, `qkd.model_numpy`).
- `NeuralNetwork/`: Neural network architecture, training notebooks, and pre-trained models.
- `Analysis/`: Validation notebooks for the analytical model.
- Create the environment:

  ```bash
  conda create -n qkd-opt python=3.10
  conda activate qkd-opt
  ```
- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```
- Install JAX (OS-specific):
  - Mac (Apple Silicon) / Linux / Windows (CPU):

    ```bash
    pip install "jax[cpu]"
    ```

  - Linux (NVIDIA GPU):

    ```bash
    pip install "jax[cuda12]"
    ```
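After installing, a quick sanity check (a minimal sketch, not part of the repository) confirms that JAX imports and can JIT-compile:

```python
import jax
import jax.numpy as jnp

# List available backends (e.g. a single CpuDevice on a CPU-only install).
print(jax.devices())

# JIT-compile a trivial function to confirm the XLA toolchain works.
@jax.jit
def sum_of_squares(x):
    return jnp.sum(x ** 2)

print(sum_of_squares(jnp.arange(3.0)))  # 0^2 + 1^2 + 2^2 = 5.0
```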
For production-grade data generation and benchmarking, use the Python scripts in Optimization/run_scripts/.
Pure NumPy implementation using `dual_annealing`. Useful for verifying results without JAX dependencies.

```bash
python Optimization/run_scripts/tier1_legacy_baseline.py
```

Output: `Optimization/outputs/tier1_legacy_baseline/`
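For reference, the core call pattern of SciPy's `dual_annealing` looks like the minimal sketch below; the quadratic objective and the four bounded parameters are illustrative stand-ins, not the repository's actual key-rate function:

```python
import numpy as np
from scipy.optimize import dual_annealing

# Toy stand-in for the negative secret-key-rate objective.
def objective(x):
    return float(np.sum((x - 0.5) ** 2))

# Four bounded parameters, loosely mirroring (mu, nu, P_mu, P_x).
bounds = [(0.0, 1.0)] * 4

result = dual_annealing(objective, bounds=bounds, seed=0, maxiter=200)
print(result.x, result.fun)
```

The global annealing phase explores the bounded box, and the built-in local search polishes the minimum, which is why this baseline is accurate but slow per evaluation point.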
The state-of-the-art adaptive JAX optimizer. Fastest and most robust.

```bash
python Optimization/run_scripts/tier3_global_jax.py
```

Output: `Optimization/outputs/global_jax/`
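The run scripts write their results as JSON. A small helper like the sketch below (hypothetical; the actual field names depend on the script version) can be used to inspect what each output file contains:

```python
import json
from pathlib import Path

def summarize_results(results_dir):
    """Map each JSON file in results_dir to its sorted top-level keys."""
    summary = {}
    for path in sorted(Path(results_dir).glob("*.json")):
        with open(path) as f:
            summary[path.name] = sorted(json.load(f).keys())
    return summary

# Example: summarize_results("Optimization/outputs/global_jax/")
```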
Follow these steps to reproduce the full research study, including Model Verification and Neural Network Training.
Notebook: Analysis/BB84_Parameters_2014_Analysis_Jax.ipynb
- Verifies the core QKD simulation.
- Calculates Secret Key Rate (SKR) using fixed parameters to ensure correct exponential decay.
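The expected exponential decay follows from standard fiber attenuation: the channel transmittance is eta_ch = 10^(-alpha*L/10). A minimal sketch, assuming the typical alpha = 0.2 dB/km for telecom fiber (an assumption, not a value taken from the repository):

```python
import numpy as np

alpha = 0.2                                       # fiber loss in dB/km (typical telecom value; assumption)
L = np.array([0.0, 50.0, 100.0, 150.0, 200.0])    # fiber length in km
eta_ch = 10 ** (-alpha * L / 10)                  # channel transmittance, exponential in L

print(eta_ch)  # 1.0 at L=0, then falls by 10x per 50 km at 0.2 dB/km
```

Since the SKR scales with the transmittance, a correct simulation shows this straight-line decay on a log scale, which is what the verification notebook checks.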
Notebook: NeuralNetwork/neural_network_updated.ipynb
- Trains the PyTorch model using the dataset generated by the Optimization steps.
- Evaluates accuracy and generates relative error plots.
- Saves the trained model to `NeuralNetwork/models/`.
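The relative-error metric used in such evaluations is typically computed element-wise; a minimal sketch with hypothetical values (not results from the repository):

```python
import numpy as np

def relative_error(pred, truth, eps=1e-12):
    """Element-wise relative error |pred - truth| / |truth|, guarded against zeros."""
    pred, truth = np.asarray(pred), np.asarray(truth)
    return np.abs(pred - truth) / (np.abs(truth) + eps)

# Hypothetical predicted vs. ground-truth parameter values.
pred = np.array([0.48, 0.102, 0.31])
truth = np.array([0.50, 0.100, 0.30])
print(relative_error(pred, truth))
```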
We benchmarked the Neural Network against traditional solvers.
| Solver Type | Time per Point | Speedup Factor |
|---|---|---|
| Legacy (Dual Annealing) | | 1x (Baseline) |
| Modern (JAX/SciPy) | | |
| Neural Network | | >6,000x |
Comparison of the background training data generation process:
| Feature | Old Method (Baseline) | New Method (JAX Adaptive) | Improvement |
|---|---|---|---|
| Runtime (6000 pts) | ~1 hr 30 min | ~2 min | ~45x Faster ⚡ |
| Algorithm | `scipy.dual_annealing` | `jax.grad` + Adaptive L-BFGS-B | Gradient-based |
| Quality | Noisy (Staircase artifacts) | Smooth (Physically realistic) | High Stability |
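The gradient-based approach pairs JAX autodiff with SciPy's bounded L-BFGS-B. A minimal sketch of that pattern, with a toy quadratic standing in for the negative key-rate objective (illustrative, not the repository's actual code):

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import minimize

# Toy stand-in for the negative secret-key-rate objective.
def neg_rate(x):
    return jnp.sum((x - 0.3) ** 2)

grad_fn = jax.grad(neg_rate)  # exact gradient via autodiff, no finite differences

result = minimize(
    fun=lambda x: float(neg_rate(jnp.asarray(x))),
    x0=np.zeros(4),
    jac=lambda x: np.asarray(grad_fn(jnp.asarray(x)), dtype=float),
    method="L-BFGS-B",
    bounds=[(0.0, 1.0)] * 4,
)
print(result.x)
```

Supplying the exact autodiff gradient to L-BFGS-B is what removes the staircase artifacts: the optimizer follows the smooth objective instead of sampling it stochastically.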
- Prediction Accuracy Near Physical Limits: Relative error is <5% across the useful operating range but grows near the zero-key-rate boundary, where the practical impact is negligible.
- Static Channel Assumption: Trained for static fiber conditions.
- Physics-Informed Neural Networks (PINNs): Incorporate rate equations into the loss function.
- Real-Time Adaptation: Deploy on FPGA/Edge devices for dynamic scenarios.
- Transfer Learning: Extend to MDI-QKD or Twin-Field QKD.
If you use this work in your research, please cite the original project:
@mastersthesis{leung2024mlqkd,
author = {Leung, Shek Lun},
title = {Machine Learning for Quantum Key Distribution Network Optimization},
school = {KTH Royal Institute of Technology},
year = {2024},
supervisor = {Svanberg, Erik and Foletto, Giulio and Adya, Vaishali},
examiner = {Gallo, Katia}
}

This work is based on the analytical model presented in: Lim, C. C. W., et al. (2014). "Concise security bounds for practical decoy-state quantum key distribution". Physical Review A, 89(2), 022307.