Conversation
ser/train.py
Outdated
```python
learning_rate = 0.01

# ----- TO-DO: save the parameters! -----
params = {'epochs': epochs, 'batch size': batch_size, 'learning rate': learning_rate}
```
A dataclass would be better for the hyperparams!!
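For example, a minimal sketch of what that could look like (the field names follow the `params` dict in the diff; the class name and example values are illustrative, not from the PR):

```python
from dataclasses import dataclass, asdict

@dataclass
class Hyperparams:
    """Groups the run's hyperparameters with real attribute names and types."""
    epochs: int
    batch_size: int
    learning_rate: float

# Attribute access (params.learning_rate) replaces string-keyed lookups,
# and asdict() still gives a plain dict when it's time to serialise.
params = Hyperparams(epochs=2, batch_size=1000, learning_rate=0.01)
print(asdict(params))
```

A dataclass also catches typos (`params.learning_rte` raises `AttributeError`) that a dict with string keys would silently accept.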
ser/train.py
Outdated
```python
ts = get_transforms()

# dataloaders
training_dataloader = load_data(directory="../data", download_bool=True, train_bool=True, ts=ts, batch_size=batch_size, shuffle_bool=True, num_workers=1)
```
Better to use the same DATA_DIR variable here to ensure that both are working out of the same data directory regardless of where you run the command.
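One way to do that (a sketch — the project layout assumed here is a `ser/` package next to a `data/` directory, which may not match the repo exactly):

```python
from pathlib import Path

# Resolve DATA_DIR relative to this module, not the current working
# directory, so every load_data call uses the same location no matter
# where the command is run from.
DATA_DIR = Path(__file__).resolve().parent.parent / "data"

# Both dataloaders would then pass directory=DATA_DIR instead of "../data":
# training_dataloader = load_data(directory=DATA_DIR, train_bool=True, ...)
# validation_dataloader = load_data(directory=DATA_DIR, train_bool=False, ...)
```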
```python
# train
for epoch in range(epochs):
```
This is stylistic, but I think the training and validation loops would be better split out into separate functions.
That will make the overall algorithm very clear, and it will stay clear even if the training or validation code grows in complexity.
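A framework-agnostic sketch of the shape I mean (the function and variable names here are illustrative, not the actual code in this PR — each loop body becomes its own function, so the top level reads as two calls per epoch):

```python
def train_one_epoch(model, dataloader):
    """Run one pass over the training data; returns the mean loss."""
    losses = []
    for batch in dataloader:
        # forward / backward / optimizer step would go here
        losses.append(model(batch))
    return sum(losses) / len(losses)

def validate(model, dataloader):
    """Run one pass over the validation data; returns the mean loss."""
    losses = []
    for batch in dataloader:
        losses.append(model(batch))
    return sum(losses) / len(losses)

def run(model, train_data, val_data, epochs):
    # The top-level algorithm is now obvious at a glance.
    history = []
    for epoch in range(epochs):
        train_loss = train_one_epoch(model, train_data)
        val_loss = validate(model, val_data)
        history.append((epoch, train_loss, val_loss))
    return history
```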
ser/train.py
Outdated
```python
# ----- TO-DO: save the parameters! -----
params = {'epochs': epochs, 'batch size': batch_size, 'learning rate': learning_rate}
filename = RESULTS_DIR / "params.json"
```
This will overwrite the parameters file for every experiment run. What you really need to do is pick an identifier for each experiment, create a directory for that experiment, and then save the results into it. That way you won't lose any old results/parameters.
I'd suggest an experiment name and timestamp as good candidates for an identifier.
Finished refactoring code