Skip to content

Commit dd740e6

Browse files
authored
feat: train embedding layers only (borisdayma#291)
1 parent 80171af commit dd740e6

File tree

2 files changed

+167
-89
lines changed

2 files changed

+167
-89
lines changed

src/dalle_mini/model/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,7 @@ class DalleBart(PretrainedFromWandbMixin, FlaxBartForConditionalGeneration):
13951395
def num_params(self, params=None):
13961396
if params is None:
13971397
params = self.params
1398-
num_params = jax.tree_map(
1398+
num_params = jax.tree_util.tree_map(
13991399
lambda param: param.size, flatten_dict(unfreeze(params))
14001400
).values()
14011401
return sum(list(num_params))

0 commit comments

Comments
 (0)