Better checkpoints#6
Conversation
tscholak
left a comment
There was a problem hiding this comment.
Great changes, thanks @jlamypoirier.
I have a few suggestions and questions for clarification. Can you please answer them?
| # Distributed checkpoint for fast checkpointing and resuming. | ||
| distributed = "distributed" | ||
| # Model state dict, for safe long-term storage in Fast-LLM format. | ||
| state_dict = "state_dict" |
There was a problem hiding this comment.
how about we call this long_term instead?
There was a problem hiding this comment.
I'm ok with renaming (in a future PR), but I would prefer a name more related to Fast-LLM, since it's meant to be the "standard" Fast-LLM checkpoint format. (Maybe just "fast_llm"?)
| if default.get("format", None) == "huggingface": | ||
| warnings.warn(f"`huggingface` checkpoint format has been renamed to `external`.") | ||
| default["format"] = CheckpointFormat.external.value |
There was a problem hiding this comment.
will there be other external checkpoint formats or why this change?
There was a problem hiding this comment.
Maybe, in theory the conversion mechanism can be used for any kind of checkpoint format, I had renamed anything else already but not this one because of backward compatibility.
I'm thinking of getting rid of it altogether in the next PR though, and just use the "model_type" as the format for external formats.
| @@ -0,0 +1,147 @@ | |||
| # TODO: Use packaging.version? (Safer but extra requirement) | |||
There was a problem hiding this comment.
not sure where to ask this question:
where will we be configuring the layers-per-step option?
There was a problem hiding this comment.
That's a hack for the single process converter, it doesn't apply here because we already have the whole model loaded. I'm not entirely sure about memory usage but it should be OK because we reconstruct one layer at the time.
|
|
||
| class CheckpointFormat(str, enum.Enum): | ||
| # Distributed checkpoint for fast checkpointing and resuming. | ||
| distributed = "distributed" |
There was a problem hiding this comment.
how about we call this fast or short_term?
There was a problem hiding this comment.
Also an option, let's discuss it later.
|
|
||
| # TODO: Simplify branching. | ||
| if checkpoint_config.format == CheckpointFormat.external: | ||
| # TODO: Support optimizer? |
There was a problem hiding this comment.
can you confirm that Fast-LLM will set the checkpoint_config.optimizer_state to False somewhere (or reject the config at launch) before we end up here? Would be frustrating to crash thousands of steps into training because of this.
There was a problem hiding this comment.
Right now there is no safety check, but that would be easy to add. Looking at it, right now the bigger problem is that the optimizer isn't saved by default so the bigger risk is not saving it when we want to...
| # Intervals are a common pattern, so we standardize them with this base class. | ||
| interval: int | None = Field( | ||
| default=None, | ||
| desc="The number of training iterations between each interval. Setting to None will disable.", |
There was a problem hiding this comment.
maybe the wrong place to do this, but can we make it so that there is at least a warning that checkpoint saving is disabled? I don't think people will appreciate training for hours (or days) only to find out that they forgot to set a saving interval ;)
There was a problem hiding this comment.
That's an option, but would the warning actually be useful? A runtime warning is unlikely to be seen, and a validation one is only useful if the config is validated before launch.
There was a problem hiding this comment.
Hm, I see. Do you have a better idea?
There was a problem hiding this comment.
Not really....
| interval = FieldUpdate( | ||
| desc="The number of training iterations between each Wandb status post (alert)." | ||
| " Setting to None will disable iteration-based wandb alerts." | ||
| " Must be a sub-interval of the logging interval." |
There was a problem hiding this comment.
you mean that wandb posting can only happen at logging times? That would be super-interval, no? can you clarify what you mean here? I'm confused.
There was a problem hiding this comment.
Yes. By sub-interval I meant posting iterations are a subset of logging ones, but that's probably not the right term.
| class CheckpointBaseConfig(IntervalConfig): | ||
| _abstract = True | ||
| save_name: typing.ClassVar[str] = "save" | ||
| directory_name: typing.ClassVar[str] = "save" |
There was a problem hiding this comment.
for what kind of checkpoint will we have save as the name of the output directory?
There was a problem hiding this comment.
None, that's a placeholder (I could just remove?)
| save_name: typing.ClassVar[str] = "export" | ||
| directory_name = "export" | ||
| interval = FieldUpdate( | ||
| desc="The number of training iterations between each export." " Setting to None will disable exports." |
There was a problem hiding this comment.
Can this interval be incompatible with the checkpointing interval (i.e., not an exact multiple or divisor of the checkpointing interval)?
There was a problem hiding this comment.
Yes it doesn't matter anymore because checkpoints and exports are completely independent.
| interval = FieldUpdate( | ||
| desc="The number of training iterations between each automated shutdown." | ||
| " Setting to None will disable automated shutdowns." | ||
| " Must be a sub-interval of the checkpoint interval." |
There was a problem hiding this comment.
what's an automated shutdown? I am not familiar with this feature? is this new?
There was a problem hiding this comment.
It's always been there but never used.
tscholak
left a comment
There was a problem hiding this comment.
Thanks @jlamypoirier for answering my comments. Since most changes I suggested will be future work, this can be merged as is.
RENAME[BW COMPATIBLE] pretrained.imported_type ->pretrained.model_type
RENAME[BW COMPATIBLE] pretrained.load_weights -> pretrained.model_weights
RENAME[BW COMPATIBLE] pretrained.load_optimizer -> pretrained.optimizer_state
MERGE (pretrained.override_architecture, pretrained.load_full_base_model_config, pretrained.load_full_fast_llm_config) ->pretrained.load_config:enum