paddlenlp/trainer/training_args.py (68 additions, 0 deletions)
@@ -670,6 +670,14 @@
)
},
)
expert_parallel_degree: int = field(
default=-1,
metadata={"help": ("The paddle expert data parallel strategy.")},
)
expert_tensor_parallel_degree: int = field(
default=-1,
metadata={"help": ("The paddle expert tensor parallel strategy. Currently is not supported. DO NOT SET.")},
)
data_parallel_config: str = field(
default="",
metadata={
@@ -1123,6 +1131,13 @@
sep_parallel_degree = max(self.sep_parallel_degree, 1)
context_parallel_degree = max(self.context_parallel_degree, 1)
pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1)
expert_parallel_degree = max(self.expert_parallel_degree, 1)
expert_tensor_parallel_degree = max(self.expert_tensor_parallel_degree, 1)

# TODO(@gexiao): support expert_tensor_parallel_degree > 1 in the future
assert (
expert_tensor_parallel_degree == 1
), f"Currently only expert_tensor_parallel_degree=1 is supported, but got expert_tensor_parallel_degree={expert_tensor_parallel_degree}"

assert (
world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0
@@ -1146,6 +1161,11 @@
logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
self.sharding = []

if sharding_parallel_degree > 1:
assert (
sharding_parallel_degree % expert_parallel_degree == 0
), f"sharding_parallel_degree should be divisible by expert_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}."

self.data_parallel_degree = world_size // (
sharding_parallel_degree
* tensor_parallel_degree
@@ -1154,19 +1174,27 @@
* pipeline_parallel_degree
)

assert not (
self.data_parallel_degree > 1 and expert_parallel_degree > 1
), f"Currently the expert_data_parallel strategy is only supported together with the sharding_parallel strategy, not with the data_parallel strategy. Current data_parallel_degree is {self.data_parallel_degree}."

if (
sharding_parallel_degree > 1
or tensor_parallel_degree > 1
or pipeline_parallel_degree > 1
or self.sep_parallel_degree > 1
or self.context_parallel_degree > 1
or expert_parallel_degree > 1
or expert_tensor_parallel_degree > 1
):
self.use_hybrid_parallel = True
self.sharding_parallel_degree = sharding_parallel_degree
self.tensor_parallel_degree = tensor_parallel_degree
self.pipeline_parallel_degree = pipeline_parallel_degree
self.sep_parallel_degree = sep_parallel_degree
self.context_parallel_degree = context_parallel_degree
self.expert_parallel_degree = expert_parallel_degree
self.expert_tensor_parallel_degree = expert_tensor_parallel_degree

if not self.use_hybrid_parallel:
self.sharding = []
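The hunks above add the expert-parallel bookkeeping on top of the existing world-size decomposition: expert_tensor_parallel_degree must stay at 1, sharding_parallel_degree must be divisible by expert_parallel_degree, and expert parallelism may not be combined with data_parallel_degree > 1. Below is a minimal standalone sketch of those checks; the numbers are illustrative and not taken from this PR, and the division also includes the sep and context degrees hidden in the collapsed part of the hunk.

# Standalone sketch of the degree checks added in this PR; all numbers are illustrative.
world_size = 32
tensor_parallel_degree = 2
pipeline_parallel_degree = 2
sep_parallel_degree = 1
context_parallel_degree = 1
sharding_parallel_degree = 8
expert_parallel_degree = 4          # must divide sharding_parallel_degree
expert_tensor_parallel_degree = 1   # only 1 is supported for now

assert expert_tensor_parallel_degree == 1
assert sharding_parallel_degree % expert_parallel_degree == 0

data_parallel_degree = world_size // (
    sharding_parallel_degree
    * tensor_parallel_degree
    * sep_parallel_degree
    * context_parallel_degree
    * pipeline_parallel_degree
)
# Expert parallelism reuses the sharding groups, so plain data parallelism must stay at 1.
assert not (data_parallel_degree > 1 and expert_parallel_degree > 1)

# Each 8-rank sharding group holds 8 // 4 = 2 replicas of the expert layout.
experts_replicas = sharding_parallel_degree // expert_parallel_degree
print(data_parallel_degree, experts_replicas)  # -> 1 2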
@@ -1175,6 +1203,8 @@
self.pipeline_parallel_degree = -1
self.sep_parallel_degree = -1
self.context_parallel_degree = -1
self.expert_parallel_degree = -1
self.expert_tensor_parallel_degree = -1

if self.hybrid_parallel_topo_order is None:
self.hybrid_parallel_topo_order = "pp_first"
@@ -1530,6 +1560,9 @@
fleet.init(is_collective=True, strategy=strategy)
logger.info(strategy)

if self.expert_parallel_degree > 1:
self.add_moe_comm_group()

elif self.enable_auto_parallel:
self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
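For orientation, here is a hedged sketch of a configuration that would exercise the new path, where self.expert_parallel_degree > 1 triggers add_moe_comm_group after fleet.init. The expert_* field names come from this diff; the import path, the remaining values, and the assumed 32-GPU paddle.distributed.launch run are illustrative and not prescribed by the PR.

# Hypothetical arguments for a 32-GPU run; only the expert_* fields are defined in this diff.
from paddlenlp.trainer import TrainingArguments  # assumed import path

args = TrainingArguments(
    output_dir="./checkpoints",
    sharding="stage1",
    sharding_parallel_degree=8,
    tensor_parallel_degree=2,
    pipeline_parallel_degree=2,
    expert_parallel_degree=4,  # new field; > 1 enables add_moe_comm_group()
    # expert_tensor_parallel_degree is left at its default; values > 1 are rejected.
    hybrid_parallel_topo_order="pp_first",  # recommended with expert parallelism
)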
@@ -1738,6 +1771,10 @@
order = ["pp", "dp", "sharding", "sep", "mp"]
elif self.hybrid_parallel_topo_order == "sharding_first":
order = ["dp", "sharding", "pp", "sep", "mp"]
if self.expert_parallel_degree > 1:
logger.warning(
"Currently using the sharding_first topo order, but pp_first is recommended for better performance with expert parallelism."
)

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
@@ -1877,6 +1914,37 @@
)
self.pdc_download_ckpt = False

def add_moe_comm_group(self):
hcg = fleet.get_hybrid_communicate_group()
topo = hcg._topo
sharding_parallel_groups = topo.get_comm_list("sharding")
experts_replicas = self.sharding_parallel_degree // self.expert_parallel_degree

# init experts groups inside all sharding groups
for ranks_in_current_sharding_group in sharding_parallel_groups:
# init experts parallel groups (dispatch & combine)
for i in range(experts_replicas):
rank_indices = list(range(i * self.expert_parallel_degree, (i + 1) * self.expert_parallel_degree))
ranks = [ranks_in_current_sharding_group[i] for i in rank_indices]
group = dist.new_group(ranks=ranks)
if dist.get_rank() in ranks:
assert not hasattr(hcg, "expert_parallel_group"), "expert_parallel_group can not be set repeate"
setattr(hcg, "expert_parallel_group", group)

# init experts gradients comm groups
for i in range(self.expert_parallel_degree):
rank_indices = list(range(i, self.sharding_parallel_degree, self.expert_parallel_degree))
ranks = [ranks_in_current_sharding_group[i] for i in rank_indices]
group = dist.new_group(ranks=ranks)
if dist.get_rank() in ranks:
assert not hasattr(hcg, "expert_grad_comm_group"), "expert_grad_comm_group can not be set repeate"
setattr(hcg, "expert_grad_comm_group", group)

assert hasattr(hcg, "expert_parallel_group") and hasattr(hcg, "expert_grad_comm_group")
logger.info(
f"experts groups are created, expert_parallel_group: {hcg.expert_parallel_group}, expert_grad_comm_group: {hcg.expert_grad_comm_group}"
)

def __str__(self):
self_as_dict = asdict(self)
self_as_dict = {k: f"<{k.upper()}>" if k.endswith("_token") else v for k, v in self_as_dict.items()}
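To make the group construction in add_moe_comm_group easier to follow, here is a dependency-free sketch that reproduces only the rank-partitioning arithmetic for a single sharding group. The 8-rank group and expert_parallel_degree=4 are illustrative; the real code wraps each rank list in dist.new_group and attaches the resulting groups to the hybrid communicate group.

# Rank partitioning mirrored from add_moe_comm_group(), without any Paddle calls.
sharding_group = [0, 1, 2, 3, 4, 5, 6, 7]   # ranks of one sharding group (illustrative)
expert_parallel_degree = 4
experts_replicas = len(sharding_group) // expert_parallel_degree  # 2

# Expert-parallel groups (dispatch & combine): consecutive slices of the sharding group.
expert_parallel_groups = [
    sharding_group[i * expert_parallel_degree:(i + 1) * expert_parallel_degree]
    for i in range(experts_replicas)
]

# Expert-gradient comm groups: strided slices pairing the same expert slot across replicas.
expert_grad_comm_groups = [
    sharding_group[i::expert_parallel_degree]
    for i in range(expert_parallel_degree)
]

print(expert_parallel_groups)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(expert_grad_comm_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]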