High-performance machine learning library for Node.js with GPU acceleration
MLX-Node brings Apple's MLX framework to JavaScript/TypeScript, enabling efficient on-device ML inference and training on Apple Silicon, with CUDA support planned. Built with a Rust compute layer and TypeScript orchestration, it delivers production-ready GRPO training with 100% feature parity with Hugging Face's TRL library.

| | Feature | Description |
|---|---|---|
| ⚡ | Metal GPU Acceleration | Native Apple Silicon performance via MLX with lazy evaluation and operation fusion |
| 🎯 | GRPO Training | Complete reinforcement learning pipeline with 4 loss variants (GRPO, DAPO, Dr.GRPO, BNPO) |
| 🤖 | Qwen Models | Support for 0.6B, 1.7B, 4B, 8B, 14B, 32B parameter models with advanced sampling |
| 🔄 | Automatic Differentiation | Compute gradients through entire models via functional forward pass |
| 🚫 | Zero Python Dependency | Pure Rust/TypeScript implementation — no Python runtime required |
| 📊 | TypedArray-First API | Zero-copy operations using native JavaScript typed arrays |
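
The zero-copy idea behind a TypedArray-first API can be illustrated with plain JavaScript typed arrays. The sketch below uses only standard built-ins and does not call any MLX-Node API; how the library passes buffers to native code is not shown here.

```typescript
// Standard JavaScript only; no MLX-Node API is used in this sketch.
const data = new Float32Array([1, 2, 3, 4, 5, 6]);

// subarray() creates a view over the same ArrayBuffer: no bytes are copied.
const view = data.subarray(2, 4);

// slice() allocates a new buffer and copies the selected elements.
const copy = data.slice(2, 4);

// Write through the original array after both were created.
data[2] = 42;

const sharesMemory = view.buffer === data.buffer; // the view shares the buffer
const viewSeesWrite = view[0] === 42; // the view observes the new value
const copyUnchanged = copy[0] === 3; // the copy keeps the old value

console.log({ sharesMemory, viewSeesWrite, copyUnchanged });
```

Views avoid both allocation and copying, which is why they matter for large tensors; the trade-off is that writes through one view are visible through every other view of the same buffer.
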
- macOS on Apple Silicon (M1/M2/M3/M4/M5) with Metal; Linux/Windows support via CUDA is coming soon
- Node.js 18+
- Rust 1.90
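
A quick way to check the platform and Node.js prerequisites from a script is sketched below. `checkPrerequisites` and `EnvInfo` are hypothetical names used for illustration only; they are not part of MLX-Node, and the Rust toolchain version is not checked.

```typescript
// Hypothetical helper for illustration; not part of MLX-Node.
interface EnvInfo {
  platform: string; // e.g. the value of process.platform
  arch: string; // e.g. the value of process.arch
  nodeVersion: string; // e.g. the value of process.versions.node, such as "20.11.1"
}

/** Returns a list of unmet prerequisites; an empty list means all checks passed. */
function checkPrerequisites(env: EnvInfo): string[] {
  const problems: string[] = [];
  if (env.platform !== 'darwin' || env.arch !== 'arm64') {
    problems.push('macOS on Apple Silicon (darwin/arm64) is required');
  }
  // parseInt reads the leading integer, i.e. the major version.
  const major = Number.parseInt(env.nodeVersion, 10);
  if (Number.isNaN(major) || major < 18) {
    problems.push('Node.js 18 or newer is required');
  }
  return problems;
}

const okEnv = checkPrerequisites({ platform: 'darwin', arch: 'arm64', nodeVersion: '20.11.1' });
const badEnv = checkPrerequisites({ platform: 'linux', arch: 'x64', nodeVersion: '16.20.0' });
console.log(okEnv, badEnv);
```

In a Node.js script the function could be called with `process.platform`, `process.arch`, and `process.versions.node`.
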

```bash
git clone https://github.com/mlx-node/mlx-node.git
cd mlx-node
git submodule update --init --recursive
yarn install
yarn build
```

```bash
yarn mlx download model
yarn mlx convert --input .cache/models/qwen3-0.6b -d bf16 --output .cache/models/qwen3-0.6b-mlx-bf16
```

```bash
yarn oxnode ./examples/lm.ts
```

```typescript
import { Qwen3Model } from '@mlx-node/lm';

const model = await Qwen3Model.load('.cache/models/qwen3-0.6b-mlx-bf16');

const result = await model.generate([{ role: 'user', content: 'Write a haiku about TypeScript.' }], {
  maxNewTokens: 50,
  temperature: 0.8,
});

console.log(result.text);
```

Train language models using Group Relative Policy Optimization:

```typescript
import { GRPOTrainer, loadLocalGsm8kDataset } from '@mlx-node/trl';

const trainer = await GRPOTrainer.create({
  modelPath: '.cache/models/qwen3-0.6b-mlx-bf16',
  outputDir: 'outputs/my-training',

  // Training hyperparameters
  learningRate: 5e-6,
  batchSize: 4,
  groupSize: 4,
  numEpochs: 3,

  // Generation
  maxNewTokens: 256,
  temperature: 0.8,
  repetitionPenalty: 1.1,

  // GRPO parameters
  clipEpsilon: 0.2,
  klCoef: 0.1,
  lossType: 'grpo', // or 'dapo', 'dr_grpo', 'bnpo'

  // Custom reward function
  rewardFunction: async (prompts, completions, answers) => {
    return completions.map((completion, i) => {
      const expected = answers[i];
      if (!expected) return 0;
      return completion.includes(expected) ? 1.0 : 0.0;
    });
  },
});

const dataset = await loadLocalGsm8kDataset('.cache/gsm8k', 100);
await trainer.train(dataset);
```

```typescript
// Register multiple reward functions
trainer.registerBuiltinReward({
  rewardType: 'ToolUse',
  allowedTools: ['search', 'calculate'],
  weight: 1.0,
});

trainer.registerBuiltinReward({
  rewardType: 'XmlFormat',
  requiredTags: ['thinking', 'answer'],
  weight: 0.5,
});

trainer.registerBuiltinReward({
  rewardType: 'Length',
  minLength: 100,
  maxLength: 500,
});
```

| Variant | Description |
|---|---|
| `grpo` | Standard Group Relative Policy Optimization |
| `dapo` | DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization): adaptive clipping |
| `dr_grpo` | Dr. GRPO ("GRPO Done Right"): improved stability |
| `bnpo` | Batch-Normalized Policy Optimization: normalized advantages |

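
The arithmetic that `groupSize` and `clipEpsilon` feed into can be sketched generically. The code below is a minimal illustration of group-relative advantages and the PPO-style clipped surrogate as they are commonly described for GRPO; the function names are illustrative, the use of the population standard deviation and the `eps` value are assumptions, and none of this is taken from MLX-Node's Rust implementation.

```typescript
// Generic GRPO arithmetic; names are illustrative, not MLX-Node APIs.

/** Normalize rewards within one group: A_i = (r_i - mean) / (std + eps). */
function groupAdvantages(rewards: number[], eps = 1e-4): number[] {
  const n = rewards.length;
  const mean = rewards.reduce((sum, r) => sum + r, 0) / n;
  // Population variance (divide by n); a sample estimate would divide by n - 1.
  const variance = rewards.reduce((sum, r) => sum + (r - mean) ** 2, 0) / n;
  const std = Math.sqrt(variance);
  return rewards.map((r) => (r - mean) / (std + eps));
}

/** PPO-style clipped surrogate for one token; the objective is maximized. */
function clippedSurrogate(ratio: number, advantage: number, clipEpsilon = 0.2): number {
  const clipped = Math.min(Math.max(ratio, 1 - clipEpsilon), 1 + clipEpsilon);
  // Taking the minimum removes the incentive to push the ratio outside the clip range.
  return Math.min(ratio * advantage, clipped * advantage);
}

// One prompt with groupSize = 4 completions and binary rewards.
const advantages = groupAdvantages([1, 0, 0, 1]);
const gainCapped = clippedSurrogate(1.5, 1.0); // ratio above 1 + clipEpsilon, positive advantage
const lossCapped = clippedSurrogate(0.5, -1.0); // ratio below 1 - clipEpsilon, negative advantage

console.log(advantages, gainCapped, lossCapped);
```

Because advantages are computed relative to the other completions for the same prompt, a group in which every completion receives the same reward yields zero advantages and therefore no learning signal from that prompt.
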
```bash
# Training with a complex reward function
yarn oxnode examples/grpo/train-github-tool.ts
```

MLX-Node includes a terminal user interface (TUI) built with Ratatui for real-time training visualization and control.

```bash
# Basic usage
cargo run -p mlx-tui -- --import '@oxc-node/core/register' --script ./examples/grpo/train-tool-use.ts
```

The TUI wraps your Node.js training script and communicates via stdout (JSONL messages) and stdin (control commands).

| Panel | Description |
|---|---|
| Header | Model name, epoch/step progress, training status |
| Metrics | Loss, reward, and advantage with sparkline history |
| Progress | Epoch and step progress bars with percentages |
| Stats | Token count, elapsed time, step speed breakdown |
| Logs | Real-time training logs (scrollable) |
| Samples | Generated samples with rewards (best/worst/latest modes) |
| Config | Current training configuration |

| Key | Action |
|---|---|
| `p` | Pause training |
| `r` | Resume training |
| `s` | Save checkpoint |
| `Tab` | Switch between tabs (Logs/Samples/Config) |
| `↑` `↓` | Scroll within current tab |
| `m` | Cycle sample display mode (Best → Worst → Latest) |
| `?` | Toggle help overlay |
| `q` | Quit TUI |

To make your training script compatible with the TUI, enable tuiMode in the trainer:

```typescript
import { GRPOTrainer } from '@mlx-node/trl';

const trainer = await GRPOTrainer.create({
  modelPath: '.cache/models/qwen3-0.6b-mlx-bf16',
  outputDir: 'outputs/training',
  tuiMode: true, // Enable TUI-compatible output
  // ... other options
});
```

When `tuiMode` is enabled:

- All logging output uses JSONL format for TUI parsing
- The trainer listens for stdin commands (pause, resume, save)
- Progress updates are sent as structured messages
The TUI communicates with training scripts via a simple protocol:
Training → TUI (stdout, JSONL):

```json
{"type": "step", "epoch": 1, "step": 10, "loss": 0.5, "reward": 4.2}
{"type": "log", "level": "info", "message": "Starting epoch 2"}
{"type": "sample", "prompt": "...", "completion": "...", "reward": 5.0}
```

TUI → Training (stdin, line commands):

```text
pause
resume
save
```

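
Both directions of the protocol above are simple enough to sketch as a pair of pure functions. The message shapes mirror the three JSONL examples; the type and function names (`TrainingMessage`, `ControlCommand`, `encodeMessage`, `parseCommand`) are hypothetical, and the real messages may carry additional fields.

```typescript
// Message shapes mirror the JSONL examples; all names here are hypothetical.
type TrainingMessage =
  | { type: 'step'; epoch: number; step: number; loss: number; reward: number }
  | { type: 'log'; level: string; message: string }
  | { type: 'sample'; prompt: string; completion: string; reward: number };

type ControlCommand = 'pause' | 'resume' | 'save';

/** Serialize one message as a single JSONL line (one JSON object per line). */
function encodeMessage(msg: TrainingMessage): string {
  return JSON.stringify(msg) + '\n';
}

/** Parse one stdin line; unknown or empty input yields null. */
function parseCommand(line: string): ControlCommand | null {
  switch (line.trim()) {
    case 'pause':
      return 'pause';
    case 'resume':
      return 'resume';
    case 'save':
      return 'save';
    default:
      return null;
  }
}

const stepLine = encodeMessage({ type: 'step', epoch: 1, step: 10, loss: 0.5, reward: 4.2 });
const accepted = parseCommand('pause\n');
const rejected = parseCommand('restart');

console.log(stepLine.trimEnd(), accepted, rejected);
```

Keeping each message on one line lets the reader split the stream on newlines and parse every line independently, so a malformed line can be skipped without losing the rest of the stream.
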
MLX-Node uses a clean two-layer architecture: Rust for compute, TypeScript for orchestration.

```text
┌──────────────────────────────────────────────────────────────────────────┐
│                      TypeScript Orchestration Layer                      │
│    ┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐     │
│    │ @mlx-node/lm    │    │ @mlx-node/trl   │    │ @mlx-node/core  │     │
│    │ Model loading   │    │ GRPO Trainer    │    │ (internal)      │     │
│    │ Generation      │    │ Rewards         │    │ NAPI bindings   │     │
│    │ Configs         │    │ Datasets        │    │ Type exports    │     │
│    └─────────────────┘    └─────────────────┘    └─────────────────┘     │
├──────────────────────────────────────────────────────────────────────────┤
│                              NAPI-RS Bridge                              │
├──────────────────────────────────────────────────────────────────────────┤
│                            Rust Compute Layer                            │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐  │
│  │ array/       │  │ transformer/ │  │ grpo/        │  │ optimizers/  │  │
│  │ 90+ ops      │  │ Attention    │  │ Loss         │  │ Adam(W)      │  │
│  │ Masking      │  │ KVCache      │  │ Advantages   │  │ SGD          │  │
│  │ Padding      │  │ MLP          │  │ Autograd     │  │ RMSprop      │  │
│  └──────────────┘  └──────────────┘  └──────────────┘  └──────────────┘  │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐  │
│  │ nn/          │  │ models       │  │ sampling     │  │ tokenizer    │  │
│  │ Linear       │  │ Forward      │  │ Top-k/p      │  │ HuggingFace  │  │
│  │ RMSNorm      │  │ Generation   │  │ Min-p        │  │ Chat         │  │
│  │ Embedding    │  │ Persistence  │  │ Rep. Pen.    │  │ Templates    │  │
│  └──────────────┘  └──────────────┘  └──────────────┘  └──────────────┘  │
├──────────────────────────────────────────────────────────────────────────┤
│                         mlx-sys FFI + C++ Bridge                         │
├──────────────────────────────────────────────────────────────────────────┤
│                     MLX Library + Metal GPU Backend                      │
│             Lazy evaluation · Operation fusion · GPU kernels             │
└──────────────────────────────────────────────────────────────────────────┘
```


| Package | Purpose | Use For |
|---|---|---|
| `@mlx-node/lm` | Model loading & inference | Loading models, generating text, model configs |
| `@mlx-node/trl` | Training & optimization | GRPO training, custom rewards, optimizers |
| `@mlx-node/core` | Native bindings (internal) | Low-level operations (usually import via lm/trl) |
| `@mlx-node/cli` | CLI | Download models, quantize weights |
| `@mlx-node/vlm` | Vision-language models | PaddleOCR-VL, document processing |

- Zero-copy TypedArray operations: direct memory access without serialization
- Lazy evaluation: operations are traced and fused before execution
- Fused kernels: combined attention, MLP, and transformer blocks in C++
- Rust training loop: the training loop runs in native Rust code for speed
- Thread-safe handles: `Arc<MxHandle>` for safe multi-threaded access
- Memory-efficient caching: Standard, Batch, and Rotating KV cache options
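
To make "traced and fused before execution" concrete, here is a toy model of lazy evaluation for elementwise operations. It is a teaching sketch only: `LazyVector` is a hypothetical class, it handles nothing beyond unary elementwise ops, and it does not reflect how MLX or MLX-Node build or schedule their compute graphs.

```typescript
// Toy model of lazy evaluation with elementwise fusion; not MLX's or MLX-Node's API.
type UnaryOp = (x: number) => number;

class LazyVector {
  private readonly source: Float32Array;
  private readonly ops: readonly UnaryOp[];

  private constructor(source: Float32Array, ops: readonly UnaryOp[]) {
    this.source = source;
    this.ops = ops;
  }

  static from(values: number[]): LazyVector {
    return new LazyVector(Float32Array.from(values), []);
  }

  /** Record an operation without touching the data. */
  map(op: UnaryOp): LazyVector {
    return new LazyVector(this.source, [...this.ops, op]);
  }

  /** Apply every recorded op in a single pass over the data (the "fused" loop). */
  evaluate(): Float32Array {
    return this.source.map((value) => this.ops.reduce((acc, op) => op(acc), value));
  }
}

let calls = 0;
const pending = LazyVector.from([1, 2, 3])
  .map((x) => { calls += 1; return x * 2; })
  .map((x) => x + 1);

const callsBeforeEvaluate = calls; // nothing has run yet
const evaluated = pending.evaluate(); // one pass computes x * 2 + 1 per element

console.log(callsBeforeEvaluate, Array.from(evaluated));
```

Deferring the work means no intermediate array is allocated for the `x * 2` step; a real framework applies the same idea to whole compute graphs rather than a list of closures.
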

```bash
# Build
yarn build          # Release build (native + TypeScript)
yarn build:debug    # Debug build
yarn build:native   # Native addon only
yarn build:ts       # TypeScript packages only

# Test
vp test             # All tests
vp test run <path>  # Specific test file

# Quality
vp check            # Linting, formatting, and type checking
```

MIT License. See LICENSE for details.

- Apple MLX Team for the ML framework
- HuggingFace TRL for GRPO reference implementation
- unsloth for dynamic quantization
- NAPI-RS for seamless Node.js bindings
- Qwen Team for the model architecture
