A TensorFlow/Keras optimizer implementing concepts from Google Research's paper "Nested Learning: The Illusion of Deep Learning Architectures" (Behrouz et al., 2024), extended with multi-timescale associative memory, attention-based gradient retrieval, and depth-aware update scheduling.
In practice, this optimizer has shown significant improvements in performance, convergence speed, and training stability compared to Adam, AdamW, NAdam, and SGD across various architectures.
This implementation extends the theoretical insights from the Nested Learning paper with:
- Continuum Memory System - Three-tier gradient memory (short/medium/long-term) instead of single momentum
- Attention-Based Retrieval - Dynamic memory weighting using scaled dot-product attention
- Depth-Based Update Scheduling - Automatic layer-wise update frequencies based on network depth
- Gradient Accumulation - No information loss for slow-updating layers
- Per-Epoch LR Schedule - Built-in warmup → hold → decay learning rate scheduling
The Nested Learning paper reveals that standard optimizers like Adam and SGD can be viewed as associative memory systems that compress gradient information over time. This optimizer extends that insight by implementing:
- Explicit Multi-Timescale Memory: Instead of a single momentum term, maintains three gradient memories with different decay rates (short/medium/long-term)
- Attention-Based Retrieval: Dynamically weights memory contributions based on current gradient similarity
- Depth-Aware Update Scheduling: Different network depths update at different frequencies
| Feature | Standard Optimizers | This Implementation |
|---|---|---|
| Gradient Memory | Single EMA (momentum) | Three-tier continuum memory |
| Memory Access | Fixed weighted sum | Attention-based retrieval |
| Update Frequency | Uniform across layers | Depth-based scheduling |
| Gradient Handling | Discard between updates | Accumulation for slow layers |
```
Short-term  (decay=0.94)   → Fast adaptation, recent gradients
Medium-term (decay=0.994)  → Intermediate stability
Long-term   (decay=0.9999) → Stable, historical direction
```
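For a rough sense of scale (background arithmetic, not part of the package API): an exponential moving average with decay rate β averages over roughly 1/(1 − β) recent steps, so the three tiers correspond to windows of about 17, 167, and 10,000 steps of gradient history.

```python
# Rough effective averaging window of an EMA with decay rate beta: ~1 / (1 - beta)
for name, beta in [("short", 0.94), ("medium", 0.994), ("long", 0.9999)]:
    print(f"{name}-term: ~{1 / (1 - beta):.0f} steps")
# short-term: ~17 steps, medium-term: ~167 steps, long-term: ~10000 steps
```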
The optimizer blends these memories using one of three modes:
- `fixed`: Static weights (default: 0.1/0.3/0.6)
- `adaptive`: Learned weights that adjust during training
- `attention`: Query-based retrieval using the current gradient as the query (sketched below)
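As a conceptual sketch of the attention mode (plain NumPy for readability; the optimizer itself operates on TensorFlow variables and its internals may differ in detail), scaled dot-product attention between the current gradient and the three memory tiers looks like this:

```python
import numpy as np

def attention_blend(gradient, memories):
    """Blend memory tensors with scaled dot-product attention,
    using the current gradient as the query (conceptual sketch)."""
    g = gradient.ravel()
    dim = g.size
    # Similarity of the current gradient to each memory tier
    scores = np.array([g @ m.ravel() / np.sqrt(dim) for m in memories])
    # Softmax over the three tiers
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    # Weighted combination of short/medium/long-term memories
    return sum(w * m for w, m in zip(weights, memories))
```

The blended memory is then mixed with the raw gradient, as shown in the pseudocode near the end of this README.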
Earlier layers in deep networks tend to learn more general features that stabilize early. This optimizer leverages that by:
- Assigning update intervals based on layer depth
- Supporting multiple schedule curves: `linear`, `exponential`, `cosine` (see the sketch below)
- Accumulating gradients for slow-updating layers (no information loss)
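The sketch below illustrates how a `cosine` curve could map layer depth to an update interval between 1 and `max_interval`; the function and its exact interpolation are assumptions for illustration, not the package's internals. A layer with interval k applies an update every k steps and, with `accumulate_gradients=True`, the skipped gradients are accumulated rather than discarded.

```python
import math

def cosine_update_interval(depth, max_depth, max_interval=6):
    """Illustrative mapping: depth 0 (earliest layer) gets the longest
    interval (slowest updates), the deepest layer updates every step."""
    t = depth / max(max_depth, 1)                # normalized depth in [0, 1]
    frac = 0.5 * (1.0 + math.cos(math.pi * t))   # cosine ease: 1 at t=0, 0 at t=1
    return max(1, round(1 + (max_interval - 1) * frac))

# Example: update intervals for a model with six depth levels
print([cosine_update_interval(d, max_depth=5) for d in range(6)])  # [6, 6, 4, 3, 1, 1]
```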
```bash
pip install nested-learning-optimizer
```

Or from source:

```bash
git clone https://github.com/Kareemfarid/nested-learning-optimizer.git
cd nested-learning-optimizer
pip install -e .
```

Quick start:

```python
from nested_learning_optimizer import NestedLearningOptimizer

optimizer = NestedLearningOptimizer(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='mse')
model.fit(x_train, y_train, epochs=10)
```

Automatic depth-based scheduling:

```python
optimizer = NestedLearningOptimizer(
learning_rate=0.001,
auto_schedule=True,
max_interval=6,
schedule_curve="cosine",
)
# Analyze model structure before compiling
optimizer.compute_depths_from_model(model)
model.compile(optimizer=optimizer, loss='categorical_crossentropy')
```

Per-epoch learning-rate schedule:

```python
steps_per_epoch = len(train_dataset)

optimizer = NestedLearningOptimizer(
learning_rate=0.001,
steps_per_epoch=steps_per_epoch,
lr_warmup_fraction=0.25, # 25% warmup
lr_hold_fraction=0.25, # 25% hold at peak
lr_decay_fraction=0.50, # 50% decay
)
```

Registering custom layer depths:

```python
optimizer = NestedLearningOptimizer(auto_schedule=True)

# Register specific layers with custom depths
optimizer.register_layer_depth(model.layers[0], depth=0) # Slowest updates
optimizer.register_layer_depth(model.layers[-1], depth=10) # Fastest updates
model.compile(optimizer=optimizer, loss='mse')
```

Manual per-layer update intervals:

```python
optimizer = NestedLearningOptimizer(
auto_schedule=False,
layer_update_intervals={
"embedding": 8, # Very slow updates
"encoder": 4, # Moderate
"decoder": 2, # Faster
"output": 1, # Every step
},
default_interval=1,
)
```

Memory parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| `short_term_decay` | float | 0.94 | Decay rate for short-term memory |
| `medium_term_decay` | float | 0.994 | Decay rate for medium-term memory |
| `long_term_decay` | float | 0.9999 | Decay rate for long-term memory |
| `memory_blend_mode` | str | "attention" | How to combine memories: "fixed", "adaptive", "attention" |
| `short_term_weight` | float | 0.1 | Fixed weight for short-term (when mode="fixed") |
| `medium_term_weight` | float | 0.3 | Fixed weight for medium-term |
| `long_term_weight` | float | 0.6 | Fixed weight for long-term |
Depth-scheduling parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| `auto_schedule` | bool | True | Auto-compute intervals from layer depth |
| `max_interval` | int | 6 | Maximum update interval for slowest layers |
| `schedule_curve` | str | "cosine" | Interval curve: "linear", "exponential", "cosine" |
| `reverse_depth_order` | bool | True | If True, deeper layers update slower |
| `warmup_steps` | int | 0 | Steps where all layers update every step |
| `accumulate_gradients` | bool | True | Accumulate gradients for slow layers |
Gradient-compression parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| `use_gradient_compression` | bool | False | Enable low-rank gradient approximation |
| `compression_rank` | int | None | Rank for SVD compression |
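For intuition, the low-rank approximation behind these options can be pictured as a truncated SVD of a 2-D gradient; the snippet below is a generic sketch, not the optimizer's internal code:

```python
import numpy as np

def low_rank_approx(grad_matrix, rank):
    """Generic rank-`rank` approximation of a 2-D gradient via truncated SVD."""
    u, s, vt = np.linalg.svd(grad_matrix, full_matrices=False)
    return (u[:, :rank] * s[:rank]) @ vt[:rank, :]

compressed = low_rank_approx(np.random.randn(256, 128), rank=8)
```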
Adam-compatible parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| `beta_1` | float | 0.9 | First moment decay |
| `beta_2` | float | 0.999 | Second moment decay |
| `epsilon` | float | 1e-7 | Numerical stability |
| `amsgrad` | bool | False | Use AMSGrad variant |
Learning-rate schedule parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| `steps_per_epoch` | int | None | Steps per epoch (enables the schedule) |
| `lr_warmup_fraction` | float | 0.25 | Fraction of epoch for warmup |
| `lr_hold_fraction` | float | 0.25 | Fraction to hold at peak |
| `lr_decay_fraction` | float | 0.50 | Fraction for decay |
| `lr_min_fraction` | float | 0.01 | Min LR as fraction of base |
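For intuition, the per-epoch warmup → hold → decay shape these fractions describe can be sketched as below; only the fractions and the minimum-LR floor come from the parameters above, while the linear warmup and cosine decay shapes are assumptions for illustration.

```python
import math

def epoch_lr(step_in_epoch, steps_per_epoch, base_lr=1e-3,
             warmup=0.25, hold=0.25, decay=0.50, min_fraction=0.01):
    """Warmup -> hold -> decay within one epoch (illustrative shape only)."""
    t = step_in_epoch / steps_per_epoch           # progress through the epoch
    if t < warmup:                                # linear warmup up to base_lr
        return base_lr * (min_fraction + (1 - min_fraction) * t / warmup)
    if t < warmup + hold:                         # hold at the peak
        return base_lr
    d = min((t - warmup - hold) / decay, 1.0)     # decay progress in [0, 1]
    cos = 0.5 * (1 + math.cos(math.pi * d))       # cosine decay to the LR floor
    return base_lr * (min_fraction + (1 - min_fraction) * cos)
```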
`compute_depths_from_model()`: Analyze model structure to assign depths automatically. Call before `model.compile()`.

```python
depths = optimizer.compute_depths_from_model(model)
```

`register_layer_depth()`: Manually set the depth for a layer's variables.

```python
optimizer.register_layer_depth(model.get_layer("encoder"), depth=2)
```

`register_variable_depth()`: Manually set the depth for a specific variable.

```python
optimizer.register_variable_depth(my_variable, depth=5)
```

`set_steps_per_epoch()`: Configure the LR schedule after initialization.

```python
optimizer.set_steps_per_epoch(1000)
```

`print_memory_stats()`: Print debug information about optimizer state.

```python
optimizer.print_memory_stats()
```

Useful attributes:

- `iterations`: Current global step
- `current_learning_rate`: Current scheduled LR
See the `examples/` directory for complete examples:

- `basic_usage.py`: Simple training example
- `depth_scheduling.py`: Using automatic depth-based scheduling
- `custom_memory.py`: Configuring memory blend modes
```
# 1. Update the three memory tiers (EMAs with different decay rates)
short_mem  = 0.94   * short_mem  + 0.06   * gradient
medium_mem = 0.994  * medium_mem + 0.006  * gradient
long_mem   = 0.9999 * long_mem   + 0.0001 * gradient

# 2. Attention-based retrieval: the current gradient acts as the query
scores = [dot(gradient, mem) / sqrt(dim) for mem in memories]
weights = softmax(scores)
combined = sum(w * mem for w, mem in zip(weights, memories))
effective_grad = 0.5 * gradient + 0.5 * combined

# 3. Adam-style second-moment update applied to the blended gradient
v = beta_2 * v + (1 - beta_2) * effective_grad**2
update = lr * effective_grad / (sqrt(v / bias_correction) + epsilon)
weights -= update
```

If you use this optimizer in your research, please cite:

```bibtex
@article{behrouz2024nested,
title={Nested Learning: The Illusion of Deep Learning Architectures},
author={Behrouz, Ali and others},
journal={arXiv preprint arXiv:2402.09747},
year={2024}
}
```

Author: Kareem Farid
- GitHub: @Kareemfarid
- Twitter: @kareemfarid
MIT License - see LICENSE for details.
Contributions welcome! Please see CONTRIBUTING.md for guidelines.