# lecture_08_utils.py (forked from stanford-cs336/lectures)
import torch.distributed as dist
from inspect import isfunction
from typing import Callable
import sys
import torch.multiprocessing as mp
import torch
import torch.nn as nn
import math
from torch_util import get_device
class DisableDistributed:
    """Context manager that temporarily disables distributed functions (replaces with no-ops).

    On entry, every attribute of ``dist`` for which ``inspect.isfunction`` is
    true is saved and replaced by a callable that accepts any arguments and
    returns ``None``.  On exit the saved originals are put back.  Attributes
    that are not plain Python functions (classes, builtins, constants) are
    left untouched.
    """

    def __enter__(self):
        """Patch ``dist`` and return this manager so ``with ... as ctx`` binds it."""
        self.old_functions = {}
        for name in dir(dist):
            value = getattr(dist, name, None)
            if isfunction(value):
                # Remember the real function so __exit__ can restore it.
                self.old_functions[name] = value
                setattr(dist, name, lambda *args, **kwargs: None)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore every function replaced in ``__enter__``.

        Returns ``None``, so an exception raised inside the ``with`` body
        still propagates after the originals are restored.
        """
        for name, original in self.old_functions.items():
            setattr(dist, name, original)
def spawn(func: Callable, world_size: int, *args, **kwargs):
    """Run ``func`` across ``world_size`` processes, or inline when being traced.

    Keyword arguments are flattened to positional ones (values only, in the
    order the caller wrote them), so they must already be in the order
    ``func`` expects.

    * No trace function installed: ``mp.spawn`` is called with
      ``nprocs=world_size``, ``join=True`` and
      ``args=(world_size, *args, *kwargs.values())``.
    * Trace function installed (``sys.gettrace()`` is truthy): ``func`` is
      called once, directly, as ``func(0, world_size, *args, *kwargs.values())``
      inside ``DisableDistributed()``, because tracing cannot follow execution
      through ``mp.spawn``.
    """
    extra = args + tuple(kwargs.values())

    if not sys.gettrace():
        mp.spawn(func, args=(world_size,) + extra, nprocs=world_size, join=True)
        return

    # Traced run: call func in this process with 0 as its first argument.
    with DisableDistributed():
        func(0, world_size, *extra)
def int_divide(a: int, b: int):
    """Return the exact quotient ``a // b``, failing if there is a remainder.

    Args:
        a: Dividend.
        b: Divisor.

    Returns:
        ``a // b`` when ``b`` divides ``a`` evenly.

    Raises:
        AssertionError: If ``a % b`` is non-zero.
        ZeroDivisionError: If ``b`` is zero (raised by the ``%`` operator).
    """
    # Raise explicitly rather than using `assert`, which `python -O` strips
    # and would turn this into a silent floor division.
    if a % b != 0:
        raise AssertionError(f"{a} is not evenly divisible by {b}")
    return a // b
def summarize_tensor(tensor: torch.Tensor) -> str:
    """Return a short one-line description of ``tensor``: its shape and first value.

    The result is the dimensions joined by ``"x"`` followed by the first
    element rounded to 4 decimal places, e.g. ``"2x3[0.1235...]"``.  A tensor
    with no elements is rendered with an empty bracket pair, e.g. ``"0x3[]"``.

    Args:
        tensor: The tensor to summarize; it may be non-contiguous or empty.

    Returns:
        The summary string.
    """
    shape = "x".join(map(str, tensor.shape))
    if tensor.numel() == 0:
        # There is no first element to show.
        return shape + "[]"
    # Index the first element directly (all-zero index) instead of going
    # through view(-1), which raises on non-contiguous tensors such as
    # transposes.
    first = tensor[(0,) * tensor.dim()].item()
    return f"{shape}[{round(first, 4)}...]"
def get_init_params(num_inputs: int, num_outputs: int, rank: int) -> nn.Parameter:
    """Build a reproducible ``num_inputs x num_outputs`` parameter matrix.

    The global torch RNG is re-seeded to 0 on every call, so repeated calls
    with the same arguments draw the same values.  Entries are sampled with
    ``torch.randn`` on ``get_device(rank)`` and divided by
    ``sqrt(num_outputs)``.

    Args:
        num_inputs: Number of rows.
        num_outputs: Number of columns; also sets the scaling factor.
        rank: Passed to ``get_device`` to choose the target device.

    Returns:
        The scaled samples wrapped in ``nn.Parameter``.
    """
    # Re-seed first so the draw below is deterministic.
    torch.random.manual_seed(0)
    target_device = get_device(rank)
    samples = torch.randn(num_inputs, num_outputs, device=target_device)
    scale = math.sqrt(num_outputs)
    return nn.Parameter(samples / scale)
def render_duration(duration: float) -> str:
    """Format a duration given in seconds with a readable unit.

    Values below 1e-3 are shown in microseconds (``us``), values below 1 in
    milliseconds (``ms``), and everything else in seconds (``s``); each is
    printed with two decimal places.
    """
    # (exclusive upper bound in seconds, multiplier to the unit, unit suffix)
    sub_second_units = (
        (1e-3, 1e6, "us"),
        (1, 1e3, "ms"),
    )
    for upper_bound, multiplier, suffix in sub_second_units:
        if duration < upper_bound:
            return f"{duration * multiplier:.2f}{suffix}"
    return f"{duration:.2f}s"