[bug] test_checkpoint test not passing when any lr scale is set to 0 #265

@oleksost

🐞 Describe the Bug

Whenever the lr scale of any component is set to 0, e.g. model.base_model.transformer.mlp_lr_scale=0, test_checkpoint fails with:

FAILED tests/test_checkpoint.py::test_load_pretrained_distributed_checkpoint - AssertionError: torch.Size([0]) != torch.Size([786432])

How critical is this for loading/saving checkpoints that were trained with lr scaling?
This may be related to #256.

🔄 Steps to Reproduce

Steps to reproduce the behavior:
Add e.g. model.base_model.transformer.mlp_lr_scale=0 here and run test_checkpoint.

The same happens when the lr is set to zero using the per-layer lr scale from #243 and #258 (in that case, though, more than one test in test_checkpoint fails).
Importantly, if the line self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) here is replaced with a plain self.requires_grad = requires_grad, the test passes.
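
To make the condition above concrete, here is a minimal standalone sketch of how that expression evaluates for different lr scales. The helper name effective_requires_grad and the example scale tuples are hypothetical and only for illustration; the sketch assumes lr_scale is an iterable of numeric scale factors, as the quoted expression suggests, and does not reproduce the checkpoint code itself.

```python
from typing import Iterable


def effective_requires_grad(requires_grad: bool, lr_scale: Iterable[float]) -> bool:
    """Evaluate the expression quoted in the issue (hypothetical helper name)."""
    # The flag stays True only if at least one lr scale is non-zero.
    return requires_grad and any(scale != 0 for scale in lr_scale)


# Assumption: lr_scale holds one or more scale factors for a parameter.
cases = {
    "all scales non-zero": (1.0, 0.5),
    "one scale is zero": (0.0, 1.0),
    "every scale is zero": (0.0,),
}

for label, lr_scale in cases.items():
    print(f"{label}: requires_grad={effective_requires_grad(True, lr_scale)}")
```

Only the last case, where every scale is zero, turns requires_grad off. That is the situation in which the test fails, and it is the one that disappears when the line is replaced with a plain self.requires_grad = requires_grad.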

🎯 Expected Behavior

Test passes.

📜 Environment Information

📝 Additional Context
