
Commit f1f4e46 (merge of 2 parents: d9e36d5 + caf6b5c)

bug fix for sharding

2 files changed, 5 insertions(+), 2 deletions(-)

paddlenlp/trainer/auto_trainer.py (1 addition, 1 deletion)

@@ -90,6 +90,7 @@ def loss_func(loss, outputs):
                 "pipeline_parallel": kwargs["args"].pipeline_parallel_degree > 1,
                 "data_sharding_parallel": kwargs["args"].dataset_world_size > 1,
                 "sharding": kwargs["args"].sharding,
+                "sharding_mesh_dim": kwargs["args"].sharding_parallel_mesh_dimension,
             }
             auto_dist_config = model._generate_auto_dist_config(auto_dist_degree)
             self.auto_dist_config = auto_dist_config
@@ -164,7 +165,6 @@ def _wrap_for_auto(self, model, train_dataloader):
         if self.args.use_intermediate_api:
             assert self.auto_dist_config is not None
             self.optimizer = parallelize_optimizer(
-                model,
                 self.optimizer,
                 dp_config=self.auto_dist_config["dp_config"],
                 mp_config=self.auto_dist_config["mp_config"],
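
For context, a minimal plain-Python sketch of the config plumbing the first hunk changes; SimpleNamespace stands in for the parsed training arguments (kwargs["args"]), the field values are illustrative, and nothing below is PaddleNLP's actual code:

from types import SimpleNamespace

# Stand-in for kwargs["args"]; attribute names mirror the diff, values are made up.
args = SimpleNamespace(
    pipeline_parallel_degree=1,
    dataset_world_size=2,
    sharding="stage2",
    sharding_parallel_mesh_dimension="dp",
)

# Mirrors the auto_dist_degree dict built here after the fix: the sharding
# mesh dimension is now forwarded along with the sharding option, so the
# model can carry it into the generated dp_config.
auto_dist_degree = {
    "pipeline_parallel": args.pipeline_parallel_degree > 1,
    "data_sharding_parallel": args.dataset_world_size > 1,
    "sharding": args.sharding,
    "sharding_mesh_dim": args.sharding_parallel_mesh_dimension,
}
print(auto_dist_degree)

The second hunk is a separate one-line fix: the stray model positional argument is dropped, so self.optimizer is passed to parallelize_optimizer first, ahead of the dp_config/mp_config keyword arguments.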

paddlenlp/transformers/model_utils.py (4 additions, 1 deletion)

@@ -2826,7 +2826,10 @@ def _generate_auto_dist_config(self, auto_dist_degree):
                 level = 2
             if ShardingOption.FULL_SHARD in sharding:
                 level = 3
-            final_config["dp_config"] = {"level": level}
+            final_config["dp_config"] = {
+                "sharding_level": level,
+                "sharding_mesh_dim": auto_dist_degree.get("sharding_mesh_dim", None),
+            }

        return final_config
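
A self-contained sketch of what the generated dp_config now carries, assuming string stage names in place of PaddleNLP's ShardingOption enum (build_dp_config is a hypothetical helper, not the library's code):

# Hypothetical mirror of the fixed logic: the sharding stage maps to a level,
# and the mesh dimension is passed through; note the renamed "sharding_level"
# key (previously "level") plus the new "sharding_mesh_dim" entry.
def build_dp_config(auto_dist_degree):
    sharding = auto_dist_degree.get("sharding") or ""
    level = 0
    if "stage1" in sharding:
        level = 1
    if "stage2" in sharding:
        level = 2
    if "stage3" in sharding:
        level = 3
    return {
        "sharding_level": level,
        "sharding_mesh_dim": auto_dist_degree.get("sharding_mesh_dim", None),
    }

print(build_dp_config({"sharding": "stage2", "sharding_mesh_dim": "dp"}))
# -> {'sharding_level': 2, 'sharding_mesh_dim': 'dp'}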
