🐞 Describe the Bug
Whenever the lr scale of any component is set to 0, e.g. `model.base_model.transformer.mlp_lr_scale=0`, `test_checkpoint` fails with:
FAILED tests/test_checkpoint.py::test_load_pretrained_distributed_checkpoint - AssertionError: torch.Size([0]) != torch.Size([786432])
I wonder how critical this is for loading/saving checkpoints that were trained with lr scaling.
Maybe related to #256.
🔄 Steps to Reproduce
Steps to reproduce the behavior:
Add, e.g., `model.base_model.transformer.mlp_lr_scale=0` here and run `test_checkpoint`.
The same happens when the lr is set to zero using the per-layer lr scale from #243 and #258 (though in this case more than one test in `test_checkpoint` fails).
Importantly, if the line `self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale)` here is replaced with simply `self.requires_grad = requires_grad`, the test passes.
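To make the effect of that line concrete, here is a minimal pure-Python sketch (no torch). `ParamSketch` and `trainable_numel` are hypothetical names, not part of the codebase, and the sketch *assumes* that some checkpoint buffer is sized only from parameters whose `requires_grad` is true. This is a guess at the mechanism, not a confirmed explanation. Under that assumption, gating on the lr scale drops every parameter whose scales are all zero, so such a buffer shrinks (possibly to size 0, as in the failing assertion), while the workaround keeps the full size.

```python
from dataclasses import dataclass


@dataclass
class ParamSketch:
    """Hypothetical stand-in for a parameter's metadata (not the real class)."""
    name: str
    numel: int
    lr_scale: tuple  # one scale per layer/stage, iterated like self.lr_scale
    requires_grad: bool = True

    def effective_requires_grad(self, gate_on_lr_scale: bool) -> bool:
        if gate_on_lr_scale:
            # Mirrors: requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale)
            return self.requires_grad and any(s != 0 for s in self.lr_scale)
        # Mirrors the workaround: self.requires_grad = requires_grad
        return self.requires_grad


def trainable_numel(params, gate_on_lr_scale: bool) -> int:
    """Size of a hypothetical buffer that only covers trainable parameters."""
    return sum(p.numel for p in params if p.effective_requires_grad(gate_on_lr_scale))


params = [
    ParamSketch("attn.weight", 1024, (1.0,)),
    ParamSketch("mlp.weight", 4096, (0.0,)),  # mlp_lr_scale=0
]
print(trainable_numel(params, gate_on_lr_scale=True))   # mlp weight excluded
print(trainable_numel(params, gate_on_lr_scale=False))  # mlp weight included
```

Note that with per-layer scales the gate only disables a parameter when *all* of its scales are zero, e.g. `(0.0, 1.0)` still counts as trainable. If the checkpoint code and the loading side disagree on which parameters are "trainable", a size mismatch like the one above would be the expected symptom.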
🎯 Expected Behavior
Test passes.
📜 Environment Information
📝 Additional Context