Framing model.forward more like ML libraries? #188

@iancze

Description

A common idiom used in many machine learning libraries is something like

model.fit(X, y)

where X is the feature matrix and y is the target vector, and the model is then "trained" internally. This is the standard pattern in scikit-learn, for example.
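The idiom can be sketched with a hypothetical estimator (plain NumPy, not scikit-learn itself; the class name `OLSModel` is made up for illustration) that stores its fitted state internally:

```python
import numpy as np

class OLSModel:
    """Hypothetical estimator mimicking the scikit-learn fit/predict interface."""

    def fit(self, X, y):
        # Solve ordinary least squares; the fitted state lives on the model.
        self.coef_, *_ = np.linalg.lstsq(X, y, rcond=None)
        return self

    def predict(self, X):
        return X @ self.coef_

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, -2.0, 0.5])

model = OLSModel()
model.fit(X, y)           # training happens inside the model
y_pred = model.predict(X)
```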

In PyTorch, the pattern I usually see is something like

outputs = model(inputs)

and then the loss is computed on a separate line rather than inside the model (though in principle it could be internal, as in a cross-validation runner). Either way, this makes it easy to divide a dataset into several smaller batches of inputs and targets and train on them batch by batch.
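As a concrete sketch of that loop (written in plain NumPy rather than actual PyTorch, with a stand-in linear model and a hand-rolled gradient step), batching falls out naturally because each batch of inputs maps to a batch of outputs:

```python
import numpy as np

rng = np.random.default_rng(0)
inputs_all = rng.normal(size=(64, 2))
targets_all = inputs_all @ np.array([3.0, -1.0])

w = np.zeros(2)                # model parameters
model = lambda x: x @ w        # forward pass (stand-in for a torch Module)

batch_size, lr = 16, 0.1
for epoch in range(200):
    for start in range(0, len(inputs_all), batch_size):
        inputs = inputs_all[start:start + batch_size]
        targets = targets_all[start:start + batch_size]
        outputs = model(inputs)                    # forward, one line
        loss = np.mean((outputs - targets) ** 2)   # loss on a separate line
        grad = 2 * inputs.T @ (outputs - targets) / len(inputs)
        w -= lr * grad                             # "optimizer step"
```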

In MPoL, on the other hand, we usually have something like

rml = precomposed.SimpleNet(coords=coords, nchan=dset.nchan)
modelVisibilityCube = rml()

and then a loss function compares modelVisibilityCube to the data. This works, but it has always bugged me a little that we don't take in any X or inputs, which prevents us from training in batches. I think the problem is that there isn't an obvious 1-to-1 relationship between the number of input points and the number of output points, because of a) the Fourier nature of the problem and b) the way data averaging ("gridding") changes the number of visibilities.

The Fourier nature of the problem means that you always need to populate the full image and take the full FFT, even if you are comparing to a single data point, so there is little time saved by training on a smaller "batch" rather than the full one. This is especially true for the GriddedDataset, and probably still applies in some fashion to the NuFFT.
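A toy NumPy illustration of this point (not MPoL code): even for a "batch" of one gridded visibility, the full-image FFT is unavoidable, so the per-batch cost barely shrinks.

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.normal(size=(256, 256))   # stand-in for the model image

# Full set of model visibilities: one full-image FFT.
vis_full = np.fft.fft2(image)

# A "batch" of a single visibility still requires the same full-image FFT
# before indexing out one (u, v) cell, so nothing is saved.
vis_one = np.fft.fft2(image)[10, 42]
```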

If we made the NuFFT layer the default FourierTransformer, then the model could take in u,v coordinates, so that we'd have

modelVisibilities = model(us, vs)

and then these could be used in a loss function.
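Here is a sketch of that interface, using a brute-force direct Fourier sum in NumPy as a stand-in for the NuFFT layer (the `model` function, coordinates, and loss below are all illustrative, not MPoL API):

```python
import numpy as np

rng = np.random.default_rng(0)
npix, cell = 32, 1.0                   # toy image size and pixel scale
image = rng.normal(size=(npix, npix))
x = np.arange(npix) * cell             # sky coordinates (arbitrary units)
X, Y = np.meshgrid(x, x)

def model(us, vs):
    # Direct DFT evaluated only at the requested (u, v) points: the output
    # length matches the input length, so batching is natural.
    phase = -2j * np.pi * (us[:, None] * X.ravel()[None, :]
                           + vs[:, None] * Y.ravel()[None, :])
    return np.exp(phase) @ image.ravel()

us = rng.uniform(0, 0.5, size=100)
vs = rng.uniform(0, 0.5, size=100)
data = model(us, vs)                   # pretend these are the observed data

# Train on a mini-batch of baselines: a 1-to-1 input/output relationship.
batch = slice(0, 16)
modelVisibilities = model(us[batch], vs[batch])
loss = np.mean(np.abs(modelVisibilities - data[batch]) ** 2)
```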

But this doesn't make much sense for the SimpleNet and the FourierLayer, which return a modelVisibilityCube that must then be indexed by the GriddedDataset.

Is this just a quirk of the nature of our problem, and I shouldn't lose sleep over not taking in an X? Or are we framing the network in some sub-optimal way? The GriddedDataset approach is accurate enough that I think there will be few applications where we want the training loop to function using the NuFFT directly. Rather, it's much more useful for predicting loose visibilities for visualization applications.
