Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 1 addition & 36 deletions MaxText/configs/llama2_70b_gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,4 @@ logits_dot_in_fp32: False
per_device_batch_size: 6
max_target_length: 4096

mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
# For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
# Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
# The "stage" needs to be listed first since the microbatch dimension is first before the reshape.
['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']],
['activation_heads', ['tensor','sequence']],
['activation_kv_heads', ['tensor','sequence']],
['activation_length', 'sequence'],
['activation_embed', 'tensor'],
['activation_mlp', 'tensor'],
['activation_kv', 'tensor'],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose',]],
['activation_kv_head_dim', 'tensor'],
['activation_vocab', ['tensor', 'sequence']],
['activation_vocab', 'tensor'],
['activation_vocab', 'sequence'],
['activation_stage','stage'],
['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
['vocab', ['tensor', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
['embed', ['fsdp', 'sequence']],
['norm', 'fsdp'],
['heads', ['tensor', 'autoregressive']],
['layers', 'stage'],
['kv', []],
['kv_heads', ['tensor', 'autoregressive']],
['kv_head_dim', []],
['cache_batch', []],
['cache_heads', ['autoregressive', 'tensor']],
['cache_kv', []],
['cache_sequence', []],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]
logical_axis_rules: [['norm', 'fsdp']]
38 changes: 1 addition & 37 deletions MaxText/configs/llama2_7b_gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,40 +18,4 @@ logits_dot_in_fp32: False
per_device_batch_size: 4
max_target_length: 4096

mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
# For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
# Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
# The "stage" needs to be listed first since the microbatch dimension is first before the reshape.
['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']],
['activation_heads', ['tensor','sequence']],
['activation_kv_heads', ['tensor','sequence']],
['activation_length', 'sequence'],
['activation_embed', 'tensor'],
['activation_mlp', 'tensor'],
['activation_kv', 'tensor'],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose',]],
['activation_kv_head_dim', 'tensor'],
['activation_vocab', ['tensor', 'sequence']],
['activation_vocab', 'tensor'],
['activation_vocab', 'sequence'],
['activation_stage','stage'],
['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
['vocab', ['tensor', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
['embed', ['fsdp', 'sequence']],
['norm', 'fsdp'],
['heads', ['tensor', 'autoregressive']],
['layers', 'stage'],
['kv', []],
['kv_heads', ['tensor', 'autoregressive']],
['kv_head_dim', []],
['cache_batch', []],
['cache_heads', ['autoregressive', 'tensor']],
['cache_kv', []],
['cache_sequence', []],
]

# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]
logical_axis_rules: [['norm', 'fsdp']]
26 changes: 25 additions & 1 deletion MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,30 @@ def validate_megablox_parallelism(raw_keys):
using_pipeline_parallelism(raw_keys)):
raise ValueError("Currently we only support Megablox with data parallelism.")

def create_new_logical_axis_rules(old_logical_axis_rules, new_logical_axis_rules):
new_logical_axis = set()
replacements = []
for logical_axis, mesh_axes in new_logical_axis_rules:
logical_axis_exists = any(rule for rule in old_logical_axis_rules if rule[0] == logical_axis)
if not logical_axis_exists:
continue
replacements.append((logical_axis, mesh_axes))
new_logical_axis.add(logical_axis)
old_logical_rules_filtered = [(old_logical_axis, old_mesh_axes) for old_logical_axis, old_mesh_axes
in old_logical_axis_rules if old_logical_axis not in new_logical_axis]
return old_logical_rules_filtered + replacements


def update_model_keys(raw_keys, model_keys, key):
"""Update `key` value in `raw_keys` from the value in `model_keys`. """
assert key in model_keys and key in raw_keys
if key == 'logical_axis_rules':
Comment thread
golechwierowicz marked this conversation as resolved.
Outdated
raw_keys[key] = create_new_logical_axis_rules(
old_logical_axis_rules=raw_keys[key],
new_logical_axis_rules=model_keys[key])
return
raw_keys[key] = model_keys[key]

def validate_and_update_keys(raw_keys, model_keys, config_name: str):
"""Validate and update model specific config keys"""
max_logging.log("Updating following parameters in config\n")
Expand All @@ -395,7 +419,7 @@ def validate_and_update_keys(raw_keys, model_keys, config_name: str):
elif not isinstance(raw_keys[k], type(model_keys[k])):
raise ValueError(f"Type of key:{k} does not match with {type(model_keys[k])}")
else:
raw_keys[k] = model_keys[k]
update_model_keys(raw_keys, model_keys, k)
return raw_keys


Expand Down
88 changes: 88 additions & 0 deletions MaxText/tests/pyconfig_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""


import unittest
import pyconfig

class PyconfigTest(unittest.TestCase):
"""Tests for pyconfig.py"""

def test_basic_override(self):
raw_keys = {
'megablox': None,
'foo': ['bar', 'baz']
}
model_keys = {
'foo': ['x', 'y']
}

pyconfig.validate_and_update_keys(raw_keys, model_keys, config_name='config')

self.assertEqual(raw_keys, {
'megablox': None,
'foo': ['x', 'y']
})

def test_logical_axis_override(self):
raw_keys = {
'megablox': None,
'foo': ['bar', 'baz'],
'logical_axis_rules': [
['activation', ['data', 'fsdp']],
['norm', 'tensor']
]
}
model_keys = {
'logical_axis_rules': [
['activation', ['data', 'fsdp_transpose']],
['norm', 'fsdp']
]
}

pyconfig.validate_and_update_keys(raw_keys, model_keys, config_name='config')

self.assertEqual(raw_keys, {
'megablox': None,
'foo': ['bar', 'baz'],
'logical_axis_rules': [
('activation', ['data', 'fsdp_transpose']),
('norm', 'fsdp')
]
})

def test_logical_axis_partial_override(self):
raw_keys = {
'megablox': None,
'foo': ['bar', 'baz'],
'logical_axis_rules': [
['activation', ['data', 'fsdp']],
['norm', 'tensor']
]
}
model_keys = {
'logical_axis_rules': [
['norm', 'fsdp']
]
}

pyconfig.validate_and_update_keys(raw_keys, model_keys, config_name='config')

self.assertEqual(raw_keys, {
'megablox': None,
'foo': ['bar', 'baz'],
'logical_axis_rules': [
('activation', ['data', 'fsdp']),
('norm', 'fsdp')
]
})