Skip to content

Commit fdbd5be

Browse files
authored
Merge pull request deepseek-ai#193 from enochkan/main
Add docstrings to functions in inference modules for better clarity
2 parents fd011c1 + bc77f22 commit fdbd5be

File tree

6 files changed

+563
-1
lines changed

6 files changed

+563
-1
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,4 +165,8 @@ cython_debug/
165165
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166166
# and can be added to the global gitignore or merged into this file. For a more nuclear
167167
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
168-
#.idea/
168+
#.idea/
169+
170+
.vscode/*
171+
172+
.DS_Store

inference/convert.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@
3131

3232

3333
def main(hf_ckpt_path, save_path, n_experts, mp):
34+
"""
35+
Converts and saves model checkpoint files into a specified format.
36+
37+
Args:
38+
hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
39+
save_path (str): Path to the directory where the converted checkpoint files will be saved.
40+
n_experts (int): Total number of experts in the model.
41+
mp (int): Model parallelism factor.
42+
43+
Returns:
44+
None
45+
"""
3446
torch.set_num_threads(8)
3547
n_local_experts = n_experts // mp
3648
state_dicts = [{} for _ in range(mp)]

inference/fp8_cast_bf16.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,25 @@
1010
from kernel import weight_dequant
1111

1212
def main(fp8_path, bf16_path):
13+
"""
14+
Converts FP8 weights to BF16 and saves the converted weights.
15+
16+
This function reads FP8 weights from the specified directory, converts them to BF16,
17+
and saves the converted weights to another specified directory. It also updates the
18+
model index file to reflect the changes.
19+
20+
Args:
21+
fp8_path (str): The path to the directory containing the FP8 weights and model index file.
22+
bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
23+
24+
Raises:
25+
KeyError: If a required scale_inv tensor is missing for a weight.
26+
27+
Notes:
28+
- The function assumes that the FP8 weights are stored in safetensor files.
29+
- The function caches loaded safetensor files to optimize memory usage.
30+
- The function updates the model index file to remove references to scale_inv tensors.
31+
"""
1332
torch.set_default_dtype(torch.bfloat16)
1433
os.makedirs(bf16_path, exist_ok=True)
1534
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
@@ -23,6 +42,18 @@ def main(fp8_path, bf16_path):
2342

2443
# Helper function to get tensor from the correct file
2544
def get_tensor(tensor_name):
45+
"""
46+
Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
47+
48+
Args:
49+
tensor_name (str): The name of the tensor to retrieve.
50+
51+
Returns:
52+
torch.Tensor: The retrieved tensor.
53+
54+
Raises:
55+
KeyError: If the tensor does not exist in the safetensor file.
56+
"""
2657
file_name = weight_map[tensor_name]
2758
if file_name not in loaded_files:
2859
file_path = os.path.join(fp8_path, file_name)

inference/generate.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@
1212

1313

1414
def sample(logits, temperature: float = 1.0):
    """
    Draw one token index per row from `logits` using temperature-scaled sampling.

    The draw is performed by dividing the softmax probabilities by i.i.d.
    Exponential(1) noise and taking the argmax — an exponential-race scheme
    equivalent to sampling from the categorical distribution.

    Args:
        logits (torch.Tensor): Logits over the vocabulary for each sequence.
        temperature (float, optional): Scaling temperature; clamped below at
            1e-5 so near-zero values degrade gracefully toward greedy
            selection. Defaults to 1.0.

    Returns:
        torch.Tensor: The sampled token index for each row of `logits`.
    """
    scaled = logits / max(temperature, 1e-5)
    weights = torch.softmax(scaled, dim=-1)
    noise = torch.empty_like(weights).exponential_(1)
    return weights.div_(noise).argmax(dim=-1)
@@ -25,6 +35,19 @@ def generate(
2535
eos_id: int,
2636
temperature: float = 1.0
2737
) -> List[List[int]]:
38+
"""
39+
Generates new tokens based on the given prompt tokens using the specified model.
40+
41+
Args:
42+
model (Transformer): The transformer model used for token generation.
43+
prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
44+
max_new_tokens (int): The maximum number of new tokens to generate.
45+
eos_id (int): The end-of-sequence token ID.
46+
temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
47+
48+
Returns:
49+
List[List[int]]: A list of lists containing the generated tokens for each sequence.
50+
"""
2851
prompt_lens = [len(t) for t in prompt_tokens]
2952
assert max(prompt_lens) <= model.max_seq_len
3053
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
@@ -63,6 +86,17 @@ def main(
6386
max_new_tokens: int = 100,
6487
temperature: float = 1.0,
6588
) -> None:
89+
"""
90+
Main function to load the model and perform interactive or batch text generation.
91+
92+
Args:
93+
ckpt_path (str): Path to the model checkpoint directory.
94+
config (str): Path to the model configuration file.
95+
input_file (str, optional): Path to a file containing input prompts. Defaults to "".
96+
interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
97+
max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
98+
temperature (float, optional): Temperature for sampling. Defaults to 1.0.
99+
"""
66100
world_size = int(os.getenv("WORLD_SIZE", "1"))
67101
rank = int(os.getenv("RANK", "0"))
68102
local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -125,6 +159,20 @@ def main(
125159

126160

127161
if __name__ == "__main__":
162+
"""
163+
Command-line interface for distributed text generation.
164+
165+
Arguments:
166+
--ckpt-path (str): Path to the model checkpoint directory.
167+
--config (str): Path to the model configuration file.
168+
--input-file (str, optional): File containing prompts for batch processing.
169+
--interactive (bool, optional): Enable interactive mode for generating text.
170+
--max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
171+
--temperature (float, optional): Temperature for sampling. Defaults to 0.2.
172+
173+
Raises:
174+
AssertionError: If neither input-file nor interactive mode is specified.
175+
"""
128176
parser = ArgumentParser()
129177
parser.add_argument("--ckpt-path", type=str, required=True)
130178
parser.add_argument("--config", type=str, required=True)

inference/kernel.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@
88

99
@triton.jit
1010
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
11+
"""
12+
Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
13+
14+
Args:
15+
x_ptr (triton.Pointer): Pointer to the input tensor.
16+
y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
17+
s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
18+
BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
19+
20+
Returns:
21+
None
22+
"""
1123
pid = tl.program_id(axis=0)
1224
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
1325
x = tl.load(x_ptr + offs).to(tl.float32)
@@ -19,6 +31,18 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
1931

2032

2133
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
34+
"""
35+
Quantizes the input tensor `x` using block-wise quantization.
36+
37+
Args:
38+
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
39+
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
40+
41+
Returns:
42+
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
43+
- The quantized tensor with dtype `torch.float8_e4m3fn`.
44+
- A tensor of scaling factors with dtype `torch.float32`.
45+
"""
2246
assert x.is_contiguous()
2347
assert x.size(-1) % block_size == 0
2448
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
@@ -30,6 +54,20 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
3054

3155
@triton.jit
3256
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
57+
"""
58+
Dequantizes weights using the provided scaling factors and stores the result.
59+
60+
Args:
61+
x_ptr (tl.pointer): Pointer to the quantized weights.
62+
s_ptr (tl.pointer): Pointer to the scaling factors.
63+
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
64+
M (int): Number of rows in the weight matrix.
65+
N (int): Number of columns in the weight matrix.
66+
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
67+
68+
Returns:
69+
None
70+
"""
3371
pid_m = tl.program_id(axis=0)
3472
pid_n = tl.program_id(axis=1)
3573
n = tl.cdiv(N, BLOCK_SIZE)
@@ -44,6 +82,20 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
4482

4583

4684
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
85+
"""
86+
Dequantizes the given weight tensor using the provided scale tensor.
87+
88+
Args:
89+
x (torch.Tensor): The quantized weight tensor of shape (M, N).
90+
s (torch.Tensor): The scale tensor of shape (M, N).
91+
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
92+
93+
Returns:
94+
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
95+
96+
Raises:
97+
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
98+
"""
4799
assert x.is_contiguous() and s.is_contiguous()
48100
assert x.dim() == 2 and s.dim() == 2
49101
M, N = x.size()
@@ -66,6 +118,25 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
66118
BLOCK_SIZE_M: tl.constexpr,
67119
BLOCK_SIZE_N: tl.constexpr,
68120
BLOCK_SIZE_K: tl.constexpr):
121+
"""
122+
Performs a matrix multiplication operation on FP8 matrices with scaling factors.
123+
124+
Args:
125+
a_ptr (tl.tensor): Pointer to the first input matrix A.
126+
b_ptr (tl.tensor): Pointer to the second input matrix B.
127+
c_ptr (tl.tensor): Pointer to the output matrix C.
128+
a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
129+
b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
130+
M (int): Number of rows in matrix A and C.
131+
N (tl.constexpr): Number of columns in matrix B and C.
132+
K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
133+
BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
134+
BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
135+
BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
136+
137+
Returns:
138+
None
139+
"""
69140
pid_m = tl.program_id(axis=0)
70141
pid_n = tl.program_id(axis=1)
71142
k = tl.cdiv(K, BLOCK_SIZE_K)
@@ -97,6 +168,18 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
97168

98169

99170
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
171+
"""
172+
Perform a matrix multiplication using FP8 precision.
173+
174+
Args:
175+
a (torch.Tensor): The first input matrix, must be contiguous.
176+
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
177+
b (torch.Tensor): The second input matrix, must be contiguous.
178+
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
179+
180+
Returns:
181+
torch.Tensor: The result of the matrix multiplication.
182+
"""
100183
assert a.is_contiguous() and b.is_contiguous()
101184
assert a_s.is_contiguous() and b_s.is_contiguous()
102185
K = a.size(-1)

0 commit comments

Comments
 (0)