# Mochi: Aligning Pre-training and Inference for Efficient Graph Foundation Models via Meta-Learning

*Mochi model diagram (image generated by ChatGPT 5.4).*

Reference implementation of Mochi and Mochi++, meta-learned few-shot graph foundation models that unify node classification, link prediction, and graph classification under a single differentiable-ridge readout.

Both variants share the same encoder: multi-hop SVD projectors feed a GAMLP network, whose output goes into a closed-form ridge classifier (R2-D2 style). The differences:

| Variant | Output MLP | Training batching |
|---------|------------|-------------------|
| Mochi   | 2-layer | One episode sampled from the task mix per step |
| Mochi++ | 3-layer + skip | 1 LP + 1 NC + 1 GC episode per step |
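
A minimal sketch of what an R2-D2-style closed-form ridge readout looks like (function and tensor names here are illustrative, not the actual `mochi.model` API; only the `ridge_lambda` default comes from the paper setup below):

```python
import torch
import torch.nn.functional as F

def ridge_readout(support_feats, support_labels, query_feats, ridge_lambda=10.0):
    """Illustrative R2-D2-style closed-form ridge classifier (not the repo's exact code)."""
    Y = F.one_hot(support_labels).float()                  # [n_support, num_classes]
    X = support_feats                                      # [n_support, dim]
    # Dual (Woodbury) form: solve an n_support x n_support system instead of dim x dim.
    G = X @ X.T + ridge_lambda * torch.eye(X.shape[0], device=X.device)
    W = X.T @ torch.linalg.solve(G, Y)                     # [dim, num_classes]
    return query_feats @ W                                 # query logits
```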

## Install

```bash
pip install -r requirements.txt
```

The codebase depends on PyTorch, PyG, torch-scatter, and OGB. Use the PyG install instructions for your CUDA version.

## Quickstart (library interface)

```python
from mochi import Mochi, default_params

model = Mochi(**default_params)            # Mochi++ — paper default
model = Mochi(model_variant="mochi")       # Mochi    — ablation variant
```

Load pretrained weights from Hugging Face (jrm28/mochi):

```python
from mochi import Mochi, default_params, load_pretrained

model = Mochi(**default_params)
load_pretrained(model, seed=2)             # downloads + loads weights in place
```

Three seeds are available (seed=0, 1, 2) — all trained on the paper's default setup (latdim=512, gnn_layer=3, ridge_lambda=10, 12 991 steps across 15 LP + 4 NC + 3 GC datasets).
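
For example, all three released checkpoints can be instantiated with the calls shown above (the loop and dictionary are just for illustration):

```python
from mochi import Mochi, default_params, load_pretrained

# Load each released seed into its own model instance.
models = {}
for seed in (0, 1, 2):
    model = Mochi(**default_params)
    load_pretrained(model, seed=seed)      # downloads + loads weights in place
    models[seed] = model
```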

Full end-to-end training and evaluation:

```python
from mochi import MochiConfig, default_params, Mochi, build_datasets, train, evaluate

cfg = MochiConfig(**default_params)
model = Mochi(**cfg.as_dict()).to("cuda:0")

sampler, lp, nc, gc, device = build_datasets(cfg, repo_root=".")
train(model, sampler, lp, nc, gc, device, cfg)
evaluate(model, sampler, lp, nc, gc, device, cfg)
```

## Quickstart (CLI)

```bash
# Mochi++ with the paper defaults
python train.py

# Override hyperparameters
python train.py --model_variant mochi --seed 1 --gpu 0
python train.py --dataset_setting smoke --nc_datasets cora --train_steps 100

# Evaluate a local checkpoint on held-out datasets
python train.py --eval_only --load_model checkpoints/mochi++_s2.pt \
    --dataset_setting link2 \
    --nc_datasets cora cs photo arxiv Fitness \
    --gc_datasets MUTAG PROTEINS DD ENZYMES NCI1 IMDB-BINARY COLLAB REDDIT-MULTI-5K

# Evaluate using pretrained weights from Hugging Face (no training)
python train.py --eval_only --load_pretrained --seed 2

# Multi-GPU DDP
torchrun --nproc_per_node=4 train_ddp.py --gpus 0,1,2,3
```

See run.sh for the three-seed reproduction commands.

## Data layout

The loaders expect:

```
data/
├── pyg/          # PyG-managed: Planetoid / Coauthor / Amazon / TUDataset
├── ogb/          # OGB-managed (optional)
└── lp/
    ├── <dataset_name>/
    │   ├── trn_mat.pkl    # training adjacency (scipy COO / any 2-D array)
    │   └── feats.pkl      # optional node features ([N, d] numpy)
    └── ...
```

Override locations via `--data_root`, `--lp_data_root`, and `--cstag_root` (for CS-TAG CSV datasets like Photo and Fitness).
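
As an illustration, a custom link-prediction dataset can be dropped into that layout as follows (the dataset name, edges, and shapes are placeholders; only the file names and formats come from the tree above):

```python
import os
import pickle
import numpy as np
import scipy.sparse as sp

# Placeholder graph: 1000 nodes, a few edges, random 128-d features.
num_nodes, feat_dim = 1000, 128
rows = np.array([0, 1, 2])
cols = np.array([1, 2, 3])
trn_mat = sp.coo_matrix((np.ones(len(rows)), (rows, cols)), shape=(num_nodes, num_nodes))
feats = np.random.randn(num_nodes, feat_dim).astype(np.float32)

out_dir = "data/lp/my_dataset"                      # hypothetical dataset name
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, "trn_mat.pkl"), "wb") as f:
    pickle.dump(trn_mat, f)                         # training adjacency (scipy COO)
with open(os.path.join(out_dir, "feats.pkl"), "wb") as f:
    pickle.dump(feats, f)                           # optional [N, d] node features
```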

Projectors are cached under `cache/projectors/` keyed by `(task, dataset, latdim, gnn_layer, niter)` — subsequent runs skip the SVD step.
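
A rough sketch of what a cached multi-hop SVD projection can look like (the real `mochi/projectors.py` may differ; function and argument names are assumptions, with `latdim` and `niter` borrowed from the cache key above):

```python
import torch

def multi_hop_svd_projections(adj, feats, latdim=512, num_hops=3, niter=5):
    """Illustrative only: low-rank SVD projection followed by hop-wise propagation."""
    # Truncated SVD compresses the raw node features to latdim dimensions.
    U, S, _ = torch.svd_lowrank(feats, q=latdim, niter=niter)
    x = U * S                            # [num_nodes, latdim]
    hops = [x]
    for _ in range(num_hops):
        x = adj @ x                      # propagate one hop over the graph
        hops.append(x)
    return hops                          # multi-hop inputs for the GAMLP encoder
```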

## Package layout

```
mochi/
├── config.py        # MochiConfig dataclass + default_params dict
├── model.py         # Mochi / Mochi++ model, GAMLP encoder, ridge readout
├── projectors.py    # SVD + multi-hop propagation
├── data.py          # NC / LP / GC dataset loaders
├── samplers.py      # Episode samplers
├── training.py      # train / evaluate / save_embeddings
├── entrypoint.py    # build_datasets (convenience wrapper)
└── pretrained.py    # load_pretrained — fetch weights from Hugging Face
train.py             # Single-GPU CLI
train_ddp.py         # Multi-GPU CLI (torchrun)
```

## Citation

If you use this code, please cite the paper.
