From 458e65feb7bb60ed4598c188c83394715675cadd Mon Sep 17 00:00:00 2001 From: eisene Date: Wed, 12 Apr 2023 14:28:37 -0400 Subject: [PATCH 01/10] fix whitespace --- csrc/transformer/general_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/transformer/general_kernels.cu b/csrc/transformer/general_kernels.cu index a4193da94702..a987eec5ef0b 100644 --- a/csrc/transformer/general_kernels.cu +++ b/csrc/transformer/general_kernels.cu @@ -162,7 +162,7 @@ void launch_fused_add2(float* out, int total_count = batch_size * seq_length * hidden_dim / 4; dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); - dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); fused_add2_kernel<<>>(total_count, out, inp1, inp2); } @@ -179,7 +179,7 @@ void launch_fused_add2<__half>(__half* out, int total_count = batch_size * seq_length * hidden_dim / 4; dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); - dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); fused_add2_kernel<<>>(total_count, out, inp1, inp2); } From d25d202563aff81912113f6cc0b1b3ba53ac6cb5 Mon Sep 17 00:00:00 2001 From: eisene Date: Wed, 12 Apr 2023 16:54:08 -0400 Subject: [PATCH 02/10] make deepspeed.zero.Init() idempotent (microsoft#3202) --- deepspeed/runtime/zero/partition_parameters.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 84e628ef487c..ac827de03773 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -34,7 +34,7 @@ param_count = 0 partitioned_param_data_shape = [0] -zero_init_enabled = False +zero_init_enabled = 0 class NoGatherHandle: @@ -295,7 +295,9 @@ def __enter__(self): global zero_init_enabled if not self.enabled: return - zero_init_enabled = True + zero_init_enabled += 1 + if zero_init_enabled > 1: + return def apply_with_gather(orig_module_apply_fn: Callable) -> Callable: """many models make use of child modules like Linear or Embedding which @@ -397,6 +399,7 @@ def _enable_class(cls): cls.__init__ = partition_after(cls.__init__) def _init_subclass(cls, **kwargs): + cls._old_init = cls.__init__ cls.__init__ = partition_after(cls.__init__) # Replace .__init__() for all existing subclasses of torch.nn.Module recursively @@ -461,7 +464,8 @@ def _set_dtype(self, ds_config, dtype): def shutdown_init_context(): global zero_init_enabled - if not zero_init_enabled: + zero_init_enabled -= 1 + if not zero_init_enabled == 0: return def _disable_class(cls): From fa547c4d3b23de333a1c9634019e00ef23263776 Mon Sep 17 00:00:00 2001 From: eisene Date: Thu, 13 Apr 2023 17:01:39 -0400 Subject: [PATCH 03/10] do not print finished init message multiple times --- deepspeed/runtime/zero/partition_parameters.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index ac827de03773..050cb28ce5b7 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -433,9 +433,7 @@ def __exit__(self, exc_type, exc_value, traceback): if not self.enabled: return - shutdown_init_context() - - if dist.get_rank() == 0: + if shutdown_init_context() and dist.get_rank() == 0: logger.info("finished initializing model with %.2fB parameters", param_count / 1e9) # Now that we cleaned up the metaclass injection, raise the exception. @@ -466,7 +464,7 @@ def shutdown_init_context(): zero_init_enabled -= 1 if not zero_init_enabled == 0: - return + return False def _disable_class(cls): cls.__init__ = cls._old_init @@ -491,8 +489,7 @@ def _disable_class(cls): # if self.mem_efficient_linear: # torch.nn.functional.linear = self.linear_bk - zero_init_enabled = False - + return True class AllGatherHandle: From d3f1ff725c66a728060d2176f01cade7b91d5acc Mon Sep 17 00:00:00 2001 From: eisene Date: Thu, 13 Apr 2023 18:25:01 -0400 Subject: [PATCH 04/10] add tests for nested init case --- .../runtime/zero/test_zero_dynamic_class.py | 50 +++++++++++++++++++ .../runtime/zero/test_zero_nesting_init.py | 24 +++++++++ 2 files changed, 74 insertions(+) create mode 100644 tests/unit/runtime/zero/test_zero_dynamic_class.py create mode 100644 tests/unit/runtime/zero/test_zero_nesting_init.py diff --git a/tests/unit/runtime/zero/test_zero_dynamic_class.py b/tests/unit/runtime/zero/test_zero_dynamic_class.py new file mode 100644 index 000000000000..de7e3220f0a5 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_dynamic_class.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from unit.common import DistributedTest + +import deepspeed + + +class TestNewClassDeclaredInsideNestedInit(DistributedTest): + world_size = 1 + + def test_new_class_declared_inside_nested_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + + class MyModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + model = MyModel() + + # ensure that zero3 processed the parameter + assert hasattr(model.fc.weight, "ds_id") + + +class TestNewClassDeclaredInsideInit(DistributedTest): + world_size = 1 + + def test_new_class_declared_inside_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + + class MyModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(1, 1) + + model = MyModel() + # ensure that zero3 processed the parameter + assert hasattr(model.fc.weight, "ds_id") diff --git a/tests/unit/runtime/zero/test_zero_nesting_init.py b/tests/unit/runtime/zero/test_zero_nesting_init.py new file mode 100644 index 000000000000..20389913c996 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_nesting_init.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from unit.common import DistributedTest + +import deepspeed + + +class TestNestingInit(DistributedTest): + world_size = 1 + + def test_nesting_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + with deepspeed.zero.Init(config_dict_or_path=ds_config): + model = torch.nn.Linear(4, 4) + + # ensure that zero3 processed the parameter + assert hasattr(model.weight, "ds_id") From 01e7eb2b49fae58ba60eebec8edcd14814c6b500 Mon Sep 17 00:00:00 2001 From: eisene Date: Thu, 13 Apr 2023 19:23:08 -0400 Subject: [PATCH 05/10] fix formatting --- deepspeed/runtime/zero/partition_parameters.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 050cb28ce5b7..d2ef4009e829 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -491,6 +491,7 @@ def _disable_class(cls): return True + class AllGatherHandle: def __init__(self, handle, param: Parameter) -> None: From 8e2fcc0ad46ca2de2a6e50c37d2995af2bc843bd Mon Sep 17 00:00:00 2001 From: eisene Date: Fri, 14 Apr 2023 11:14:42 -0400 Subject: [PATCH 06/10] make shutdown_init_context also idempotent --- deepspeed/runtime/zero/partition_parameters.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index d2ef4009e829..cc7136aaa8e8 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -433,7 +433,10 @@ def __exit__(self, exc_type, exc_value, traceback): if not self.enabled: return - if shutdown_init_context() and dist.get_rank() == 0: + if not shutdown_init_context(): + return + + if dist.get_rank() == 0: logger.info("finished initializing model with %.2fB parameters", param_count / 1e9) # Now that we cleaned up the metaclass injection, raise the exception. @@ -463,6 +466,12 @@ def shutdown_init_context(): global zero_init_enabled zero_init_enabled -= 1 + if zero_init_enabled < 0: + # This can happen because deepspeed.initialize calls shutdown_init_context outside an Init() context. If the + # deepspeed.initialize call is wrapped in an Init() context to begin with, then when that context exits this + # method will be called again. This happens in the HF accelerate tests, for example. + zero_init_enabled = 0 + return False if not zero_init_enabled == 0: return False From 5ba5e0855a10f0ca81aea1167c8b536fca7bdbc5 Mon Sep 17 00:00:00 2001 From: eisene Date: Fri, 14 Apr 2023 11:25:05 -0400 Subject: [PATCH 07/10] fix formatting --- csrc/transformer/general_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/transformer/general_kernels.cu b/csrc/transformer/general_kernels.cu index a987eec5ef0b..a4193da94702 100644 --- a/csrc/transformer/general_kernels.cu +++ b/csrc/transformer/general_kernels.cu @@ -162,7 +162,7 @@ void launch_fused_add2(float* out, int total_count = batch_size * seq_length * hidden_dim / 4; dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); - dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); fused_add2_kernel<<>>(total_count, out, inp1, inp2); } @@ -179,7 +179,7 @@ void launch_fused_add2<__half>(__half* out, int total_count = batch_size * seq_length * hidden_dim / 4; dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); - dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); fused_add2_kernel<<>>(total_count, out, inp1, inp2); } From d0b6952c92e8f00251712631fdc3b0aedd19bdd9 Mon Sep 17 00:00:00 2001 From: eisene Date: Fri, 14 Apr 2023 11:34:44 -0400 Subject: [PATCH 08/10] add test for initialize call inside Init context --- tests/unit/runtime/zero/test_zero_nesting_init.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/unit/runtime/zero/test_zero_nesting_init.py b/tests/unit/runtime/zero/test_zero_nesting_init.py index 20389913c996..d8429087dc81 100644 --- a/tests/unit/runtime/zero/test_zero_nesting_init.py +++ b/tests/unit/runtime/zero/test_zero_nesting_init.py @@ -22,3 +22,14 @@ def test_nesting_init(self): # ensure that zero3 processed the parameter assert hasattr(model.weight, "ds_id") + + def test_initialize_inside_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + assert(deepspeed.zero.partition_parameters.zero_init_enabled == 1) + model = torch.nn.Linear(4, 4) + _, *_ = deepspeed.initialize(model=model, config_params=ds_config) + assert(deepspeed.zero.partition_parameters.zero_init_enabled == 0) + + assert(deepspeed.zero.partition_parameters.zero_init_enabled == 0) From f2fb40eb9941f63f2152ec13cf107149914b3203 Mon Sep 17 00:00:00 2001 From: eisene Date: Wed, 19 Apr 2023 14:45:59 -0400 Subject: [PATCH 09/10] fix formatting --- tests/unit/runtime/zero/test_zero_nesting_init.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/runtime/zero/test_zero_nesting_init.py b/tests/unit/runtime/zero/test_zero_nesting_init.py index d8429087dc81..f0b5494f1145 100644 --- a/tests/unit/runtime/zero/test_zero_nesting_init.py +++ b/tests/unit/runtime/zero/test_zero_nesting_init.py @@ -27,9 +27,9 @@ def test_initialize_inside_init(self): ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) with deepspeed.zero.Init(config_dict_or_path=ds_config): - assert(deepspeed.zero.partition_parameters.zero_init_enabled == 1) + assert (deepspeed.zero.partition_parameters.zero_init_enabled == 1) model = torch.nn.Linear(4, 4) _, *_ = deepspeed.initialize(model=model, config_params=ds_config) - assert(deepspeed.zero.partition_parameters.zero_init_enabled == 0) + assert (deepspeed.zero.partition_parameters.zero_init_enabled == 0) - assert(deepspeed.zero.partition_parameters.zero_init_enabled == 0) + assert (deepspeed.zero.partition_parameters.zero_init_enabled == 0) From 69546e013dc9f9ef6ddad5efb6ea7bc92b357df3 Mon Sep 17 00:00:00 2001 From: eisene Date: Thu, 4 May 2023 14:58:48 -0400 Subject: [PATCH 10/10] Do not partition DeepSpeedEngine with a decorator --- deepspeed/__init__.py | 3 -- deepspeed/runtime/engine.py | 4 +- .../runtime/zero/partition_parameters.py | 43 +++++++++++++------ .../runtime/zero/test_zero_nesting_init.py | 20 +++++++-- 4 files changed, 49 insertions(+), 21 deletions(-) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 255dacdccf6e..f4eb1625aab9 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -118,9 +118,6 @@ def initialize(args=None, __git_branch__), ranks=[0]) - # Disable zero.Init context if it's currently enabled - zero.partition_parameters.shutdown_init_context() - assert model is not None, "deepspeed.initialize requires a model" global dist diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 29223423d2f4..1b02a513b843 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -25,7 +25,7 @@ from deepspeed.runtime.utils import see_memory_usage, DummyOptim from .zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer -from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, NoPartitioningDecorator from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION @@ -177,7 +177,7 @@ def __init__(self, enable_micro_timers, enable_global_timers): STEP_GLOBAL_TIMER ] - +@NoPartitioningDecorator() class DeepSpeedEngine(Module): r"""DeepSpeed engine for training.""" diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 69def4133065..09313564a05d 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -34,7 +34,19 @@ param_count = 0 partitioned_param_data_shape = [0] -zero_init_enabled = 0 +_zero_init_nesting_depth = 0 +_cls_excluded_from_partitioning = set() + + +class NoPartitioningDecorator: + def __call__(self, cls): + global _cls_excluded_from_partitioning + _cls_excluded_from_partitioning.add(cls) + return cls + + +def is_zero_init_enabled(): + return _zero_init_nesting_depth > 0 class NoGatherHandle: @@ -292,11 +304,11 @@ def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtyp ], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]" def __enter__(self): - global zero_init_enabled + global _zero_init_nesting_depth if not self.enabled: return - zero_init_enabled += 1 - if zero_init_enabled > 1: + _zero_init_nesting_depth += 1 + if _zero_init_nesting_depth > 1: return def apply_with_gather(orig_module_apply_fn: Callable) -> Callable: @@ -395,10 +407,14 @@ def wrapper(module, *args, **kwargs): return wrapper def _enable_class(cls): + if cls in _cls_excluded_from_partitioning: + return cls._old_init = cls.__init__ cls.__init__ = partition_after(cls.__init__) def _init_subclass(cls, **kwargs): + if cls in _cls_excluded_from_partitioning: + return cls._old_init = cls.__init__ cls.__init__ = partition_after(cls.__init__) @@ -463,19 +479,20 @@ def _set_dtype(self, ds_config, dtype): def shutdown_init_context(): - global zero_init_enabled - - zero_init_enabled -= 1 - if zero_init_enabled < 0: - # This can happen because deepspeed.initialize calls shutdown_init_context outside an Init() context. If the - # deepspeed.initialize call is wrapped in an Init() context to begin with, then when that context exits this - # method will be called again. This happens in the HF accelerate tests, for example. - zero_init_enabled = 0 + global _zero_init_nesting_depth + + _zero_init_nesting_depth -= 1 + if _zero_init_nesting_depth < 0: + # This can happen if someone calls shutdown_init_context() explicitly, and not as an exit from an Init() + # context. This used to happen in deepspeed.initialize() before NoPartitioningDecorator was implemented. + _zero_init_nesting_depth = 0 return False - if not zero_init_enabled == 0: + if not _zero_init_nesting_depth == 0: return False def _disable_class(cls): + if cls in _cls_excluded_from_partitioning: + return cls.__init__ = cls._old_init # Replace .__init__() for all existing subclasses of torch.nn.Module diff --git a/tests/unit/runtime/zero/test_zero_nesting_init.py b/tests/unit/runtime/zero/test_zero_nesting_init.py index f0b5494f1145..499bf1faeab1 100644 --- a/tests/unit/runtime/zero/test_zero_nesting_init.py +++ b/tests/unit/runtime/zero/test_zero_nesting_init.py @@ -27,9 +27,23 @@ def test_initialize_inside_init(self): ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) with deepspeed.zero.Init(config_dict_or_path=ds_config): - assert (deepspeed.zero.partition_parameters.zero_init_enabled == 1) + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 1) model = torch.nn.Linear(4, 4) _, *_ = deepspeed.initialize(model=model, config_params=ds_config) - assert (deepspeed.zero.partition_parameters.zero_init_enabled == 0) + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 1) - assert (deepspeed.zero.partition_parameters.zero_init_enabled == 0) + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 0) + + def test_initialize_inside_nested_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + with deepspeed.zero.Init(config_dict_or_path=ds_config): + model = torch.nn.Linear(4, 4) + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 2) + _, *_ = deepspeed.initialize(model=model, config_params=ds_config) + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 2) + + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 1) + + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 0)