diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index 0dba0359f88b..280897e3617a 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -17,6 +17,7 @@ import gc from deepspeed.accelerator import get_accelerator import re +from .utils import transpose def load_model_with_checkpoint(r_module, @@ -42,14 +43,6 @@ def prefix_check(): skip_level_0_prefix = prefix_check() and container.policy.use_load_prefix - def transpose(data): - with torch.no_grad(): - data = data.contiguous() - data1 = data.transpose(-1, -2).reshape(-1) - data.reshape(-1).copy_(data1) - data1 = None - return data.reshape(data.shape[-1], data.shape[-2]) - def load(module, prefix): args = (sd[0], prefix, {}, True, [], [], error_msgs) diff --git a/deepspeed/module_inject/policy.py b/deepspeed/module_inject/policy.py index 41df2b85dc0c..dff12b6c64b3 100644 --- a/deepspeed/module_inject/policy.py +++ b/deepspeed/module_inject/policy.py @@ -7,6 +7,7 @@ from deepspeed.utils.types import ActivationFuncType, NormType import torch from deepspeed.accelerator import get_accelerator +from .utils import transpose transformer_param_names = ( 'attn_qkvw', \ @@ -109,16 +110,6 @@ def layernorm(self): raise NotImplementedError -# TODO (lekurile): This function exists in base container as well, consolidate as some point -def transpose(data): - with torch.no_grad(): - data = data.contiguous() - data1 = data.transpose(-1, -2).reshape(-1) - data.reshape(-1).copy_(data1) - data1 = None - return data.reshape(data.shape[-1], data.shape[-2]) - - # TODO (lekurile): This function exists in megatron feature container as well, consolidate as some point def _transpose(x, heads=1, mp_replace=None): heads = heads // mp_replace.mp_size # type: ignore diff --git a/deepspeed/module_inject/utils.py b/deepspeed/module_inject/utils.py index 42822128f9e1..1837063ec63f 100644 --- a/deepspeed/module_inject/utils.py +++ b/deepspeed/module_inject/utils.py @@ -3,9 +3,19 @@ # DeepSpeed Team +import torch from deepspeed.utils import log_dist +def transpose(data): + with torch.no_grad(): + data = data.contiguous() + data1 = data.transpose(-1, -2).reshape(-1) + data.reshape(-1).copy_(data1) + data1 = None + return data.reshape(data.shape[-1], data.shape[-2]) + + # helper function to map between DS policies and DS containers def policy_to_ds_container(**kwargs): from .containers import HFGPT2LayerPolicy, DS_GPT2Container