
Better checkpoints #6

Merged: jlamypoirier merged 10 commits into main from better_checkpoints on Oct 21, 2024

Collaborator

@jlamypoirier jlamypoirier commented Oct 17, 2024

  • Add a FieldUpdate tool to simplify field overrides.
  • Standardize checkpoint saving and loading configs, using mixins so that training checkpoint/export configs expose only the relevant fields (e.g. no path, since it's provided by the trainer).
  • Adjust some fields so they make sense for both saving and loading, and merge the ones for config selection into an enum (see below).
  • Move checkpoint save/load from Run to Trainer and simplify.
  • Make IntervalConfig into a proper class (use FieldUpdate for doc instead).
  • Make export checkpoint fully configurable, and save it separately instead of using a symlink.
  • Add keep_every for checkpoints to get the old behaviour, and an optional callback script.
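
To illustrate the idea behind the first bullet, here is a minimal sketch of how a field-override helper could work. It is not Fast-LLM's implementation: the `Field`, `FieldUpdate` and `Config` classes below are toy stand-ins written for this example, and only the usage pattern `interval = FieldUpdate(desc=...)` comes from this PR.

```python
import dataclasses
import typing


@dataclasses.dataclass(frozen=True)
class Field:
    # Toy field definition: a default value and a description.
    default: typing.Any = None
    desc: str = ""


class FieldUpdate(dict):
    # Toy marker holding the attributes to override on the inherited field.
    pass


class Config:
    # Maps field names to their (possibly inherited) definitions.
    _fields: typing.ClassVar[typing.Dict[str, Field]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        fields = dict(cls._fields)  # Start from a copy of the parent's fields.
        for name, value in list(vars(cls).items()):
            if isinstance(value, FieldUpdate):
                # Copy the parent's field, changing only the given attributes.
                fields[name] = dataclasses.replace(fields[name], **value)
            elif isinstance(value, Field):
                fields[name] = value
        cls._fields = fields


class IntervalConfig(Config):
    interval = Field(default=None, desc="The number of training iterations between each interval.")


class ExportConfig(IntervalConfig):
    # Only the description changes; the default is inherited.
    interval = FieldUpdate(desc="The number of training iterations between each export.")
```

In this sketch the subclass restates only what differs, and the parent's definition is left untouched.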

RENAME [BW COMPATIBLE] pretrained.imported_type -> pretrained.model_type
RENAME [BW COMPATIBLE] pretrained.load_weights -> pretrained.model_weights
RENAME [BW COMPATIBLE] pretrained.load_optimizer -> pretrained.optimizer_state
MERGE (pretrained.override_architecture, pretrained.load_full_base_model_config, pretrained.load_full_fast_llm_config) -> pretrained.load_config: enum
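
For readers wondering what "BW COMPATIBLE" could look like in practice, a backward-compatible rename can be sketched as a small shim that maps old keys to new ones and warns. This is an assumption-laden illustration: the function name `apply_renames` and the dict-based config are hypothetical, and only the three old/new field names come from the notes above. The MERGE into an enum is not shown, since the mapping from the old flags to enum values is not stated here.

```python
import warnings

# Old name -> new name, taken from the RENAME notes above.
_RENAMED_FIELDS = {
    "imported_type": "model_type",
    "load_weights": "model_weights",
    "load_optimizer": "optimizer_state",
}


def apply_renames(pretrained: dict) -> dict:
    """Return a copy of a `pretrained` config dict with old names mapped to new ones."""
    updated = dict(pretrained)
    for old, new in _RENAMED_FIELDS.items():
        if old in updated:
            if new in updated:
                # Refuse ambiguous configs instead of silently picking one value.
                raise ValueError(f"Both `{old}` and `{new}` are set; use only `{new}`.")
            warnings.warn(f"`{old}` has been renamed to `{new}`.", DeprecationWarning)
            updated[new] = updated.pop(old)
    return updated
```

Old configs keep working and produce a warning, while configs that already use the new names pass through unchanged.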

@jlamypoirier jlamypoirier marked this pull request as ready for review October 18, 2024 14:11
@jlamypoirier jlamypoirier marked this pull request as draft October 18, 2024 15:42
@jlamypoirier jlamypoirier marked this pull request as ready for review October 21, 2024 16:58
@jlamypoirier jlamypoirier requested a review from tscholak October 21, 2024 16:59
Collaborator

@tscholak tscholak left a comment

Great changes, thanks @jlamypoirier.
I have a few suggestions and questions for clarification. Can you please answer them?

# Distributed checkpoint for fast checkpointing and resuming.
distributed = "distributed"
# Model state dict, for safe long-term storage in Fast-LLM format.
state_dict = "state_dict"

Collaborator

how about we call this long_term instead?

Collaborator Author

@jlamypoirier jlamypoirier Oct 21, 2024

I'm ok with renaming (in a future PR), but I would prefer a name more related to Fast-LLM, since it's meant to be the "standard" Fast-LLM checkpoint format. (Maybe just "fast_llm"?)

Comment on lines +79 to +81
if default.get("format", None) == "huggingface":
    warnings.warn(f"`huggingface` checkpoint format has been renamed to `external`.")
    default["format"] = CheckpointFormat.external.value

Collaborator

Will there be other external checkpoint formats? If not, why this change?

Collaborator Author

Maybe. In theory, the conversion mechanism can be used for any kind of checkpoint format. I had already renamed everything else, but not this one, because of backward compatibility.

I'm thinking of getting rid of it altogether in the next PR though, and just using the "model_type" as the format for external formats.

@@ -0,0 +1,147 @@
# TODO: Use packaging.version? (Safer but extra requirement)

Collaborator

not sure where to ask this question:
where will we be configuring the layers-per-step option?

Collaborator Author

That's a hack for the single-process converter. It doesn't apply here because we already have the whole model loaded. I'm not entirely sure about memory usage, but it should be OK because we reconstruct one layer at a time.


class CheckpointFormat(str, enum.Enum):
    # Distributed checkpoint for fast checkpointing and resuming.
    distributed = "distributed"

Collaborator

how about we call this fast or short_term?

Collaborator Author

Also an option, let's discuss it later.


# TODO: Simplify branching.
if checkpoint_config.format == CheckpointFormat.external:
    # TODO: Support optimizer?

Collaborator

can you confirm that Fast-LLM will set the checkpoint_config.optimizer_state to False somewhere (or reject the config at launch) before we end up here? Would be frustrating to crash thousands of steps into training because of this.

Collaborator Author

Right now there is no safety check, but one would be easy to add. Looking at it, the bigger problem is that the optimizer state isn't saved by default, so the bigger risk is failing to save it when we want it.
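
A fail-fast check of the kind discussed here could look roughly like the sketch below. It is only an illustration: `validate_checkpoint_options` is a hypothetical helper, not something this PR adds, and the enum is a cut-down copy containing the three format names that appear in this conversation.

```python
import enum


class CheckpointFormat(str, enum.Enum):
    # Cut-down copy of the enum, for illustration only.
    distributed = "distributed"
    state_dict = "state_dict"
    external = "external"


def validate_checkpoint_options(format: CheckpointFormat, optimizer_state: bool) -> None:
    """Reject, before training starts, an optimizer-state request the format cannot honor."""
    if format == CheckpointFormat.external and optimizer_state:
        raise ValueError(
            "Saving the optimizer state is not supported for the `external` format;"
            " set `optimizer_state` to False or use another format."
        )
```

Running such a check during config validation would surface the problem at launch rather than at the first save.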

# Intervals are a common pattern, so we standardize them with this base class.
interval: int | None = Field(
    default=None,
    desc="The number of training iterations between each interval. Setting to None will disable.",

Collaborator

maybe the wrong place to do this, but can we make it so that there is at least a warning that checkpoint saving is disabled? I don't think people will appreciate training for hours (or days) only to find out that they forgot to set a saving interval ;)

Collaborator Author

That's an option, but would the warning actually be useful? A runtime warning is unlikely to be seen, and a validation one is only useful if the config is validated before launch.

Collaborator

Hm, I see. Do you have a better idea?

Collaborator Author

Not really...
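
For reference, the validation-time variant of the warning discussed in this thread could be sketched as follows. This is a toy stand-in, not the `IntervalConfig` from this PR: the `enabled` and `validate` methods are assumptions made for the example, and as noted above the warning only helps if the config is validated before launch.

```python
import dataclasses
import typing
import warnings


@dataclasses.dataclass
class IntervalConfig:
    # Number of training iterations between events; None disables them.
    interval: typing.Optional[int] = None

    def enabled(self, iteration: typing.Optional[int] = None) -> bool:
        """Whether the event is enabled at all, or due at the given iteration."""
        if self.interval is None:
            return False
        return iteration is None or iteration % self.interval == 0

    def validate(self, name: str) -> None:
        """Warn if the interval is unset, and reject non-positive values."""
        if self.interval is None:
            warnings.warn(f"`{name}.interval` is None: {name} is disabled.")
        elif self.interval <= 0:
            raise ValueError(f"`{name}.interval` must be positive, got {self.interval}.")
```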

interval = FieldUpdate(
    desc="The number of training iterations between each Wandb status post (alert)."
    " Setting to None will disable iteration-based wandb alerts."
    " Must be a sub-interval of the logging interval."

Collaborator

you mean that wandb posting can only happen at logging times? That would be super-interval, no? can you clarify what you mean here? I'm confused.

Collaborator Author

Yes. By sub-interval I meant posting iterations are a subset of logging ones, but that's probably not the right term.
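
Put concretely, "posting iterations are a subset of logging ones" amounts to requiring the posting interval to be a multiple of the logging interval. A minimal sketch of such a check follows; `check_is_multiple` is a hypothetical helper and the parameter names are made up for this example.

```python
import typing


def check_is_multiple(
    interval: typing.Optional[int],
    base_interval: typing.Optional[int],
    name: str = "wandb alert",
    base_name: str = "logging",
) -> None:
    """Ensure every iteration hit by `interval` is also hit by `base_interval`."""
    if interval is None:
        return  # Disabled, so there is nothing to check.
    if base_interval is None:
        raise ValueError(f"The {name} interval requires the {base_name} interval to be set.")
    if interval % base_interval != 0:
        raise ValueError(
            f"The {name} interval ({interval}) must be a multiple of the {base_name} interval ({base_interval})."
        )
```

For example, posting every 200 iterations is compatible with logging every 100, while posting every 150 is not.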

class CheckpointBaseConfig(IntervalConfig):
    _abstract = True
    save_name: typing.ClassVar[str] = "save"
    directory_name: typing.ClassVar[str] = "save"

Collaborator

For what kind of checkpoint will we have "save" as the name of the output directory?

Collaborator Author

None, that's a placeholder (I could just remove it?)

save_name: typing.ClassVar[str] = "export"
directory_name = "export"
interval = FieldUpdate(
    desc="The number of training iterations between each export." " Setting to None will disable exports."

Collaborator

Can this interval be incompatible with the checkpointing interval (i.e., not an exact multiple or divisor of the checkpointing interval)?

Collaborator Author

Yes, it doesn't matter anymore because checkpoints and exports are completely independent.
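
As a small illustration of that independence, each interval can be evaluated on its own against the current iteration, so the two values need not divide one another. The helper `actions_at` and the interval values below are made up for the example.

```python
import typing


def actions_at(iteration: int, intervals: typing.Dict[str, typing.Optional[int]]) -> typing.List[str]:
    """Return the names of all actions due at this iteration; None disables an action."""
    return [
        name
        for name, interval in intervals.items()
        if interval is not None and iteration % interval == 0
    ]


# Neither interval is a multiple of the other.
intervals = {"checkpoint": 1000, "export": 1500}

for iteration in range(500, 3001, 500):
    print(iteration, actions_at(iteration, intervals))
```

With these values, checkpoints are due at iterations 1000, 2000 and 3000, exports at 1500 and 3000, and the two coincide only at 3000.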

interval = FieldUpdate(
    desc="The number of training iterations between each automated shutdown."
    " Setting to None will disable automated shutdowns."
    " Must be a sub-interval of the checkpoint interval."

Collaborator

What's an automated shutdown? I am not familiar with this feature. Is this new?

Collaborator Author

It's always been there but never used.

Collaborator

@tscholak tscholak left a comment

Thanks @jlamypoirier for answering my comments. Since most changes I suggested will be future work, this can be merged as is.

@jlamypoirier jlamypoirier merged commit 6350dec into main Oct 21, 2024
@jlamypoirier jlamypoirier deleted the better_checkpoints branch October 21, 2024 19:34
@jlamypoirier jlamypoirier mentioned this pull request Oct 25, 2024
@tscholak tscholak added this to the 0.2.0 milestone Oct 25, 2024