Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions deepspeed/module_inject/load_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import gc
from deepspeed.accelerator import get_accelerator
import re
from .utils import transpose


def load_model_with_checkpoint(r_module,
Expand All @@ -42,14 +43,6 @@ def prefix_check():

skip_level_0_prefix = prefix_check() and container.policy.use_load_prefix

def transpose(data):
    """Swap the last two dimensions of ``data`` by rewriting its buffer.

    The tensor is first made contiguous, then its flattened contents are
    overwritten with the flattened transposed values, and the result is
    returned reshaped to ``(data.shape[-1], data.shape[-2])``.

    NOTE(review): the final reshape takes exactly two sizes, so this looks
    intended for 2-D tensors only -- confirm against callers.
    """
    with torch.no_grad():
        contiguous = data.contiguous()
        # Materialise the transposed values before touching the buffer.
        swapped_flat = contiguous.transpose(-1, -2).reshape(-1)
        flat_buffer = contiguous.reshape(-1)
        flat_buffer.copy_(swapped_flat)
        # Drop the temporary so its memory can be reclaimed promptly.
        del swapped_flat
    n_rows, n_cols = contiguous.shape[-2], contiguous.shape[-1]
    return contiguous.reshape(n_cols, n_rows)

def load(module, prefix):
args = (sd[0], prefix, {}, True, [], [], error_msgs)

Expand Down
11 changes: 1 addition & 10 deletions deepspeed/module_inject/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from deepspeed.utils.types import ActivationFuncType, NormType
import torch
from deepspeed.accelerator import get_accelerator
from .utils import transpose

transformer_param_names = (
'attn_qkvw', \
Expand Down Expand Up @@ -109,16 +110,6 @@ def layernorm(self):
raise NotImplementedError


# TODO (lekurile): This function exists in base container as well, consolidate at some point
def transpose(data):
    """Return ``data`` with its last two dimensions exchanged.

    The contiguous form of ``data`` is overwritten with the transposed
    element order, and that same buffer is returned under the swapped
    shape ``(data.shape[-1], data.shape[-2])``.

    NOTE(review): only two sizes are passed to the final reshape, so inputs
    are presumably 2-D -- verify with callers.
    """
    with torch.no_grad():
        source = data.contiguous()
        transposed_values = source.transpose(-1, -2).reshape(-1)
        # Write the transposed ordering over the flat view of the buffer.
        source.reshape(-1).copy_(transposed_values)
        transposed_values = None  # release the temporary copy
    target_shape = (source.shape[-1], source.shape[-2])
    return source.reshape(*target_shape)


# TODO (lekurile): This function exists in megatron feature container as well, consolidate at some point
def _transpose(x, heads=1, mp_replace=None):
heads = heads // mp_replace.mp_size # type: ignore
Expand Down
10 changes: 10 additions & 0 deletions deepspeed/module_inject/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,19 @@

# DeepSpeed Team

import torch
from deepspeed.utils import log_dist


def transpose(data):
    """Transpose the last two dimensions of ``data``, reusing its buffer.

    Steps, all performed under ``torch.no_grad()``:
      1. make ``data`` contiguous;
      2. flatten the transposed view into a temporary;
      3. copy that temporary over the flattened contiguous buffer.

    The buffer is then returned reshaped to
    ``(data.shape[-1], data.shape[-2])``.

    NOTE(review): the two-argument reshape suggests 2-D inputs only, and an
    already-contiguous input presumably has its own storage overwritten --
    confirm both against callers.
    """
    with torch.no_grad():
        buf = data.contiguous()
        tmp = buf.transpose(-1, -2).reshape(-1)
        flat = buf.reshape(-1)
        flat.copy_(tmp)
        # Free the temporary as soon as the copy is done.
        del tmp
    last, second_last = buf.shape[-1], buf.shape[-2]
    return buf.reshape(last, second_last)


# helper function to map between DS policies and DS containers
def policy_to_ds_container(**kwargs):
from .containers import HFGPT2LayerPolicy, DS_GPT2Container
Expand Down
Loading