diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 255dacdccf6e..f4eb1625aab9 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -118,9 +118,6 @@ def initialize(args=None, __git_branch__), ranks=[0]) - # Disable zero.Init context if it's currently enabled - zero.partition_parameters.shutdown_init_context() - assert model is not None, "deepspeed.initialize requires a model" global dist diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 29223423d2f4..1b02a513b843 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -25,7 +25,7 @@ from deepspeed.runtime.utils import see_memory_usage, DummyOptim from .zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer -from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, NoPartitioningDecorator from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION @@ -177,7 +177,7 @@ def __init__(self, enable_micro_timers, enable_global_timers): STEP_GLOBAL_TIMER ] - +@NoPartitioningDecorator() class DeepSpeedEngine(Module): r"""DeepSpeed engine for training.""" diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 431454db0cc9..09313564a05d 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -34,8 +34,19 @@ param_count = 0 partitioned_param_data_shape = [0] -zero_init_context = [] -all_wrapped_classes = set() +_zero_init_nesting_depth = 0 +_cls_excluded_from_partitioning = set() + + +class NoPartitioningDecorator: + def __call__(self, cls): + global _cls_excluded_from_partitioning + _cls_excluded_from_partitioning.add(cls) + return cls + + +def is_zero_init_enabled(): + return _zero_init_nesting_depth > 0 class NoGatherHandle: @@ -291,11 +302,14 @@ def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtyp assert self.dtype in [ torch.half, torch.bfloat16, torch.float ], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]" - self.wrapped_cls = set() def __enter__(self): + global _zero_init_nesting_depth if not self.enabled: return + _zero_init_nesting_depth += 1 + if _zero_init_nesting_depth > 1: + return def apply_with_gather(orig_module_apply_fn: Callable) -> Callable: """many models make use of child modules like Linear or Embedding which @@ -393,65 +407,53 @@ def wrapper(module, *args, **kwargs): return wrapper def _enable_class(cls): + if cls in _cls_excluded_from_partitioning: + return cls._old_init = cls.__init__ cls.__init__ = partition_after(cls.__init__) def _init_subclass(cls, **kwargs): + if cls in _cls_excluded_from_partitioning: + return + cls._old_init = cls.__init__ cls.__init__ = partition_after(cls.__init__) # Replace .__init__() for all existing subclasses of torch.nn.Module recursively - global zero_init_context - self.nest_level = len(zero_init_context) - - global all_wrapped_classes for subclass in get_all_subclasses(torch.nn.modules.module.Module): - # Only wrap classes that haven't been wrapped yet - if subclass not in all_wrapped_classes: - _enable_class(subclass) - self.wrapped_cls.add(subclass) - - all_wrapped_classes = all_wrapped_classes.union(self.wrapped_cls) - - # Wrap some functions only at top level call of Init - if self.nest_level == 0: - # holding onto some methods so we can put them back the way they were in __exit__ - torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__ - torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply - torch.Tensor.__old_new__ = torch.Tensor.__new__ - - # Replace .__init__() for future subclasses of torch.nn.Module - torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) - torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply) - - torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype) - torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype) - torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype) - torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype) - torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype) - - if self.mem_efficient_linear: - print_rank_0( - "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.", - force=False) - self.linear_bk = torch.nn.functional.linear - torch.nn.functional.linear = zero3_linear_wrap + # print(f"subclass={subclass.__module__}.{subclass.__qualname__}") + _enable_class(subclass) + + # holding onto some methods so we can put them back the way they were in __exit__ + torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__ + torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply + torch.Tensor.__old_new__ = torch.Tensor.__new__ - self.torch_func_wrapped = True + # Replace .__init__() for future subclasses of torch.nn.Module + torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) + torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply) - zero_init_context.append(self) + torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype) + torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype) + torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype) + torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype) + torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype) + + if self.mem_efficient_linear: + print_rank_0( + "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.", + force=False) + self.linear_bk = torch.nn.functional.linear + torch.nn.functional.linear = zero3_linear_wrap def __exit__(self, exc_type, exc_value, traceback): if not self.enabled: return - self.remove_wrappers() + if not shutdown_init_context(): + return - # Exiting the top level context - global zero_init_context - zero_init_context.pop() - if self.nest_level == 0: - if dist.get_rank() == 0: - logger.info("finished initializing model with %.2fB parameters", param_count / 1e9) + if dist.get_rank() == 0: + logger.info("finished initializing model with %.2fB parameters", param_count / 1e9) # Now that we cleaned up the metaclass injection, raise the exception. if exc_type is not None: @@ -475,51 +477,45 @@ def _set_dtype(self, ds_config, dtype): else: self.dtype = dtype or torch.half - def remove_wrappers(self): - def _disable_class(cls): - cls.__init__ = cls._old_init - - for subclass in self.wrapped_cls: - _disable_class(subclass) - self.wrapped_cls.clear() +def shutdown_init_context(): + global _zero_init_nesting_depth - # This context is the top level of nested Init - if self.nest_level == 0 and self.torch_func_wrapped: - # putting methods back the way we found them - torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass - torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply + _zero_init_nesting_depth -= 1 + if _zero_init_nesting_depth < 0: + # This can happen if someone calls shutdown_init_context() explicitly, and not as an exit from an Init() + # context. This used to happen in deepspeed.initialize() before NoPartitioningDecorator was implemented. + _zero_init_nesting_depth = 0 + return False + if not _zero_init_nesting_depth == 0: + return False - torch.Tensor.__new__ = torch.Tensor.__old_new__ - torch.empty = _orig_torch_empty - torch.zeros = _orig_torch_zeros - torch.ones = _orig_torch_ones - torch.full = _orig_torch_full + def _disable_class(cls): + if cls in _cls_excluded_from_partitioning: + return + cls.__init__ = cls._old_init - # un doing it here will undo it during training - # if self.mem_efficient_linear: - # torch.nn.functional.linear = self.linear_bk - # if self.mem_efficient_linear: - # torch.nn.functional.linear = self.linear_bk + # Replace .__init__() for all existing subclasses of torch.nn.Module + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _disable_class(subclass) - self.torch_func_wrapped = False + # putting methods back the way we found them + torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass + torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply - global all_wrapped_classes - for subclass in get_all_subclasses(torch.nn.modules.module.Module): - if subclass not in all_wrapped_classes: - msg = f"`{subclass}' was not properly set up for sharding by zero.Init(). A subclass of torch.nn.Module must be defined before zero.Init() where an instance of the class is created." - raise RuntimeError(msg) - all_wrapped_classes.clear() + torch.Tensor.__new__ = torch.Tensor.__old_new__ + torch.empty = _orig_torch_empty + torch.zeros = _orig_torch_zeros + torch.ones = _orig_torch_ones + torch.full = _orig_torch_full + # un doing it here will undo it during training + # if self.mem_efficient_linear: + # torch.nn.functional.linear = self.linear_bk + # if self.mem_efficient_linear: + # torch.nn.functional.linear = self.linear_bk -def shutdown_init_context(): - """ - This function is used to initialize deepspeed engine inside the context of Init. - We need to remove the wrappers but keep the list of contexts. - """ - global zero_init_context - for ctx in zero_init_context: - ctx.remove_wrappers() + return True class AllGatherHandle: diff --git a/tests/unit/runtime/zero/test_zero_dynamic_class.py b/tests/unit/runtime/zero/test_zero_dynamic_class.py index bb57c87f84b8..de7e3220f0a5 100644 --- a/tests/unit/runtime/zero/test_zero_dynamic_class.py +++ b/tests/unit/runtime/zero/test_zero_dynamic_class.py @@ -10,10 +10,10 @@ import deepspeed -class TestNewClassDeclaredInsideInit(DistributedTest): +class TestNewClassDeclaredInsideNestedInit(DistributedTest): world_size = 1 - def test_new_class_declared_inside_init(self): + def test_new_class_declared_inside_nested_init(self): ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) with deepspeed.zero.Init(config_dict_or_path=ds_config): @@ -27,30 +27,24 @@ def __init__(self): with deepspeed.zero.Init(config_dict_or_path=ds_config): model = MyModel() - deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config) # ensure that zero3 processed the parameter - assert hasattr(deepspeed_engine.fc.weight, "ds_id") + assert hasattr(model.fc.weight, "ds_id") -class TestNewClassDeclaredInsideInitFailure(DistributedTest): +class TestNewClassDeclaredInsideInit(DistributedTest): world_size = 1 - def test_new_class_declared_inside_init_failure(self): + def test_new_class_declared_inside_init(self): ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) - try: - with deepspeed.zero.Init(config_dict_or_path=ds_config): - - class MyModel(torch.nn.Module): + with deepspeed.zero.Init(config_dict_or_path=ds_config): - def __init__(self): - super().__init__() - self.fc = torch.nn.Linear(1, 1) + class MyModel(torch.nn.Module): - model = MyModel() + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(1, 1) - assert False, "Should have failed. A subclass of torch.nn.Module must be defined before zero.Init() where an instance of the class is created." - except RuntimeError as e: - pass - except: - assert False, "Should have failed. Runtime error is expected." + model = MyModel() + # ensure that zero3 processed the parameter + assert hasattr(model.fc.weight, "ds_id") diff --git a/tests/unit/runtime/zero/test_zero_nesting_init.py b/tests/unit/runtime/zero/test_zero_nesting_init.py index 5b796511cb9c..499bf1faeab1 100644 --- a/tests/unit/runtime/zero/test_zero_nesting_init.py +++ b/tests/unit/runtime/zero/test_zero_nesting_init.py @@ -20,6 +20,30 @@ def test_nesting_init(self): with deepspeed.zero.Init(config_dict_or_path=ds_config): model = torch.nn.Linear(4, 4) - deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config) # ensure that zero3 processed the parameter - assert hasattr(deepspeed_engine.weight, "ds_id") + assert hasattr(model.weight, "ds_id") + + def test_initialize_inside_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 1) + model = torch.nn.Linear(4, 4) + _, *_ = deepspeed.initialize(model=model, config_params=ds_config) + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 1) + + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 0) + + def test_initialize_inside_nested_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + + with deepspeed.zero.Init(config_dict_or_path=ds_config): + with deepspeed.zero.Init(config_dict_or_path=ds_config): + model = torch.nn.Linear(4, 4) + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 2) + _, *_ = deepspeed.initialize(model=model, config_params=ds_config) + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 2) + + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 1) + + assert (deepspeed.zero.partition_parameters._zero_init_nesting_depth == 0)