From 818cd062d618932c300162c82991c8f586fb4137 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Mon, 1 Jul 2024 16:07:42 +0200 Subject: [PATCH] Support partial overrides for logical_axis_rules. --- MaxText/configs/llama2_70b_gpu.yml | 37 +------------ MaxText/configs/llama2_7b_gpu.yml | 38 +------------ MaxText/pyconfig.py | 26 ++++++++- MaxText/tests/pyconfig_test.py | 88 ++++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+), 74 deletions(-) create mode 100644 MaxText/tests/pyconfig_test.py diff --git a/MaxText/configs/llama2_70b_gpu.yml b/MaxText/configs/llama2_70b_gpu.yml index a19ef339cb..c70f50b25f 100644 --- a/MaxText/configs/llama2_70b_gpu.yml +++ b/MaxText/configs/llama2_70b_gpu.yml @@ -17,39 +17,4 @@ logits_dot_in_fp32: False per_device_batch_size: 6 max_target_length: 4096 -mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'] -logical_axis_rules: [ - ['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]], - # For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages. - # Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape. - # The "stage" needs to be listed first since the microbatch dimension is first before the reshape. - ['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']], - ['activation_heads', ['tensor','sequence']], - ['activation_kv_heads', ['tensor','sequence']], - ['activation_length', 'sequence'], - ['activation_embed', 'tensor'], - ['activation_mlp', 'tensor'], - ['activation_kv', 'tensor'], - ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose',]], - ['activation_kv_head_dim', 'tensor'], - ['activation_vocab', ['tensor', 'sequence']], - ['activation_vocab', 'tensor'], - ['activation_vocab', 'sequence'], - ['activation_stage','stage'], - ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']], - ['vocab', ['tensor', 'autoregressive']], - ['embed', ['fsdp', 'fsdp_transpose', 'sequence']], - ['embed', ['fsdp', 'sequence']], - ['norm', 'fsdp'], - ['heads', ['tensor', 'autoregressive']], - ['layers', 'stage'], - ['kv', []], - ['kv_heads', ['tensor', 'autoregressive']], - ['kv_head_dim', []], - ['cache_batch', []], - ['cache_heads', ['autoregressive', 'tensor']], - ['cache_kv', []], - ['cache_sequence', []], - ] -# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']] \ No newline at end of file +logical_axis_rules: [['norm', 'fsdp']] diff --git a/MaxText/configs/llama2_7b_gpu.yml b/MaxText/configs/llama2_7b_gpu.yml index fcff05e9f8..571652bb39 100644 --- a/MaxText/configs/llama2_7b_gpu.yml +++ b/MaxText/configs/llama2_7b_gpu.yml @@ -18,40 +18,4 @@ logits_dot_in_fp32: False per_device_batch_size: 4 max_target_length: 4096 -mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'] -logical_axis_rules: [ - ['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]], - # For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages. - # Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape. - # The "stage" needs to be listed first since the microbatch dimension is first before the reshape. - ['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']], - ['activation_heads', ['tensor','sequence']], - ['activation_kv_heads', ['tensor','sequence']], - ['activation_length', 'sequence'], - ['activation_embed', 'tensor'], - ['activation_mlp', 'tensor'], - ['activation_kv', 'tensor'], - ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose',]], - ['activation_kv_head_dim', 'tensor'], - ['activation_vocab', ['tensor', 'sequence']], - ['activation_vocab', 'tensor'], - ['activation_vocab', 'sequence'], - ['activation_stage','stage'], - ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']], - ['vocab', ['tensor', 'autoregressive']], - ['embed', ['fsdp', 'fsdp_transpose', 'sequence']], - ['embed', ['fsdp', 'sequence']], - ['norm', 'fsdp'], - ['heads', ['tensor', 'autoregressive']], - ['layers', 'stage'], - ['kv', []], - ['kv_heads', ['tensor', 'autoregressive']], - ['kv_head_dim', []], - ['cache_batch', []], - ['cache_heads', ['autoregressive', 'tensor']], - ['cache_kv', []], - ['cache_sequence', []], - ] - -# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']] \ No newline at end of file +logical_axis_rules: [['norm', 'fsdp']] diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 15b50e1281..ec463c171f 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -381,6 +381,30 @@ def validate_megablox_parallelism(raw_keys): using_pipeline_parallelism(raw_keys)): raise ValueError("Currently we only support Megablox with data parallelism.") +def create_new_logical_axis_rules(old_logical_axis_rules, new_logical_axis_rules): + new_logical_axis = set() + replacements = [] + for logical_axis, mesh_axes in new_logical_axis_rules: + logical_axis_exists = any(rule for rule in old_logical_axis_rules if rule[0] == logical_axis) + if not logical_axis_exists: + continue + replacements.append((logical_axis, mesh_axes)) + new_logical_axis.add(logical_axis) + old_logical_rules_filtered = [(old_logical_axis, old_mesh_axes) for old_logical_axis, old_mesh_axes + in old_logical_axis_rules if old_logical_axis not in new_logical_axis] + return old_logical_rules_filtered + replacements + + +def update_model_keys(raw_keys, model_keys, key): + """Update `key` value in `raw_keys` from the value in `model_keys`. """ + assert key in model_keys and key in raw_keys + if key == 'logical_axis_rules': + raw_keys[key] = create_new_logical_axis_rules( + old_logical_axis_rules=raw_keys[key], + new_logical_axis_rules=model_keys[key]) + return + raw_keys[key] = model_keys[key] + def validate_and_update_keys(raw_keys, model_keys, config_name: str): """Validate and update model specific config keys""" max_logging.log("Updating following parameters in config\n") @@ -395,7 +419,7 @@ def validate_and_update_keys(raw_keys, model_keys, config_name: str): elif not isinstance(raw_keys[k], type(model_keys[k])): raise ValueError(f"Type of key:{k} does not match with {type(model_keys[k])}") else: - raw_keys[k] = model_keys[k] + update_model_keys(raw_keys, model_keys, k) return raw_keys diff --git a/MaxText/tests/pyconfig_test.py b/MaxText/tests/pyconfig_test.py new file mode 100644 index 0000000000..212a7f6e59 --- /dev/null +++ b/MaxText/tests/pyconfig_test.py @@ -0,0 +1,88 @@ +""" +Copyright 2024 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + +import unittest +import pyconfig + +class PyconfigTest(unittest.TestCase): + """Tests for pyconfig.py""" + + def test_basic_override(self): + raw_keys = { + 'megablox': None, + 'foo': ['bar', 'baz'] + } + model_keys = { + 'foo': ['x', 'y'] + } + + pyconfig.validate_and_update_keys(raw_keys, model_keys, config_name='config') + + self.assertEqual(raw_keys, { + 'megablox': None, + 'foo': ['x', 'y'] + }) + + def test_logical_axis_override(self): + raw_keys = { + 'megablox': None, + 'foo': ['bar', 'baz'], + 'logical_axis_rules': [ + ['activation', ['data', 'fsdp']], + ['norm', 'tensor'] + ] + } + model_keys = { + 'logical_axis_rules': [ + ['activation', ['data', 'fsdp_transpose']], + ['norm', 'fsdp'] + ] + } + + pyconfig.validate_and_update_keys(raw_keys, model_keys, config_name='config') + + self.assertEqual(raw_keys, { + 'megablox': None, + 'foo': ['bar', 'baz'], + 'logical_axis_rules': [ + ('activation', ['data', 'fsdp_transpose']), + ('norm', 'fsdp') + ] + }) + + def test_logical_axis_partial_override(self): + raw_keys = { + 'megablox': None, + 'foo': ['bar', 'baz'], + 'logical_axis_rules': [ + ['activation', ['data', 'fsdp']], + ['norm', 'tensor'] + ] + } + model_keys = { + 'logical_axis_rules': [ + ['norm', 'fsdp'] + ] + } + + pyconfig.validate_and_update_keys(raw_keys, model_keys, config_name='config') + + self.assertEqual(raw_keys, { + 'megablox': None, + 'foo': ['bar', 'baz'], + 'logical_axis_rules': [ + ('activation', ['data', 'fsdp']), + ('norm', 'fsdp') + ] + })