diff --git a/docs/source/usage.rst b/docs/source/usage.rst index eeeef3bb..2e7124cd 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -1324,3 +1324,94 @@ Below is the documentation on the available arguments. --train-ratio 0.8 Ratio of train dataset. The remaining will be used for valid and test split. --valid-ratio 0.1 Ratio of validation set after the train data split. The remaining will be test split --share-model + +Model training using the Torch API +---------------------------------- + +The scikit-learn API provides parametrization to many common use cases. +The Torch API however allows for more flexibility and customization, for e.g. +sampling, criterions, and data loaders. + +In this minimal example we show how to initialize a CEBRA model using the Torch API. +Here the :py:class:`cebra.data.single_session.DiscreteDataLoader` +gets initialized which also allows the `prior` to be directly parametrized. + +👉 For an example notebook using the Torch API check out the :doc:`demo_notebooks/Demo_Allen`. + + +.. testcode:: + + import numpy as np + import cebra.datasets + import torch + + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + neural_data = cebra.load_data(file="neural_data.npz", key="neural") + + discrete_label = cebra.load_data( + file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"], + ) + + # 1. Define a CEBRA-ready dataset + input_data = cebra.data.TensorDataset( + torch.from_numpy(neural_data).type(torch.FloatTensor), + discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor), + ).to(device) + + # 2. Define a CEBRA model + neural_model = cebra.models.init( + name="offset10-model", + num_neurons=input_data.input_dimension, + num_units=32, + num_output=2, + ).to(device) + + input_data.configure_for(neural_model) + + # 3. 
Define the Loss Function (Criterion) and Optimizer + crit = cebra.models.criterions.LearnableCosineInfoNCE( + temperature=1, + ).to(device) + + opt = torch.optim.Adam( + list(neural_model.parameters()) + list(crit.parameters()), + lr=0.001, + weight_decay=0, + ) + + # 4. Initialize the CEBRA solver + solver = cebra.solver.init( + name="single-session", + model=neural_model, + criterion=crit, + optimizer=opt, + tqdm_on=True, + ).to(device) + + # 5. Define Data Loader + loader = cebra.data.single_session.DiscreteDataLoader( + dataset=input_data, num_steps=10, batch_size=200, prior="uniform" + ) + + # 6. Fit Model + solver.fit(loader=loader) + + # 7. Transform Embedding + train_batches = np.lib.stride_tricks.sliding_window_view( + neural_data, neural_model.get_offset().__len__(), axis=0 + ) + + x_train_emb = solver.transform( + torch.from_numpy(train_batches[:]).type(torch.FloatTensor).to(device) + ).to(device) + + # 8. Plot Embedding + cebra.plot_embedding( + x_train_emb, + discrete_label[neural_model.get_offset().__len__() - 1 :, 0], + markersize=10, + )