88
99@triton .jit
1010def act_quant_kernel (x_ptr , y_ptr , s_ptr , BLOCK_SIZE : tl .constexpr ):
11+ """
12+ Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
13+
14+ Args:
15+ x_ptr (triton.Pointer): Pointer to the input tensor.
16+ y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
17+ s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
18+ BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
19+
20+ Returns:
21+ None
22+ """
1123 pid = tl .program_id (axis = 0 )
1224 offs = pid * BLOCK_SIZE + tl .arange (0 , BLOCK_SIZE )
1325 x = tl .load (x_ptr + offs ).to (tl .float32 )
@@ -19,6 +31,18 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
1931
2032
2133def act_quant (x : torch .Tensor , block_size : int = 128 ) -> Tuple [torch .Tensor , torch .Tensor ]:
34+ """
35+ Quantizes the input tensor `x` using block-wise quantization.
36+
37+ Args:
38+ x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
39+ block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
40+
41+ Returns:
42+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
43+ - The quantized tensor with dtype `torch.float8_e4m3fn`.
44+ - A tensor of scaling factors with dtype `torch.float32`.
45+ """
2246 assert x .is_contiguous ()
2347 assert x .size (- 1 ) % block_size == 0
2448 y = torch .empty_like (x , dtype = torch .float8_e4m3fn )
@@ -30,6 +54,20 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
3054
3155@triton .jit
3256def weight_dequant_kernel (x_ptr , s_ptr , y_ptr , M , N , BLOCK_SIZE : tl .constexpr ):
57+ """
58+ Dequantizes weights using the provided scaling factors and stores the result.
59+
60+ Args:
61+ x_ptr (tl.pointer): Pointer to the quantized weights.
62+ s_ptr (tl.pointer): Pointer to the scaling factors.
63+ y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
64+ M (int): Number of rows in the weight matrix.
65+ N (int): Number of columns in the weight matrix.
66+ BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
67+
68+ Returns:
69+ None
70+ """
3371 pid_m = tl .program_id (axis = 0 )
3472 pid_n = tl .program_id (axis = 1 )
3573 n = tl .cdiv (N , BLOCK_SIZE )
@@ -44,6 +82,20 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
4482
4583
4684def weight_dequant (x : torch .Tensor , s : torch .Tensor , block_size : int = 128 ) -> torch .Tensor :
85+ """
86+ Dequantizes the given weight tensor using the provided scale tensor.
87+
88+ Args:
89+ x (torch.Tensor): The quantized weight tensor of shape (M, N).
90+ s (torch.Tensor): The scale tensor of shape (M, N).
91+ block_size (int, optional): The block size to use for dequantization. Defaults to 128.
92+
93+ Returns:
94+ torch.Tensor: The dequantized weight tensor of the same shape as `x`.
95+
96+ Raises:
97+ AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
98+ """
4799 assert x .is_contiguous () and s .is_contiguous ()
48100 assert x .dim () == 2 and s .dim () == 2
49101 M , N = x .size ()
@@ -66,6 +118,25 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
66118 BLOCK_SIZE_M : tl .constexpr ,
67119 BLOCK_SIZE_N : tl .constexpr ,
68120 BLOCK_SIZE_K : tl .constexpr ):
121+ """
122+ Performs a matrix multiplication operation on FP8 matrices with scaling factors.
123+
124+ Args:
125+ a_ptr (tl.tensor): Pointer to the first input matrix A.
126+ b_ptr (tl.tensor): Pointer to the second input matrix B.
127+ c_ptr (tl.tensor): Pointer to the output matrix C.
128+ a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
129+ b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
130+ M (int): Number of rows in matrix A and C.
131+ N (tl.constexpr): Number of columns in matrix B and C.
132+ K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
133+ BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
134+ BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
135+ BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
136+
137+ Returns:
138+ None
139+ """
69140 pid_m = tl .program_id (axis = 0 )
70141 pid_n = tl .program_id (axis = 1 )
71142 k = tl .cdiv (K , BLOCK_SIZE_K )
@@ -97,6 +168,18 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
97168
98169
99170def fp8_gemm (a : torch .Tensor , a_s : torch .Tensor , b : torch .Tensor , b_s : torch .Tensor ):
171+ """
172+ Perform a matrix multiplication using FP8 precision.
173+
174+ Args:
175+ a (torch.Tensor): The first input matrix, must be contiguous.
176+ a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
177+ b (torch.Tensor): The second input matrix, must be contiguous.
178+ b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
179+
180+ Returns:
181+ torch.Tensor: The result of the matrix multiplication.
182+ """
100183 assert a .is_contiguous () and b .is_contiguous ()
101184 assert a_s .is_contiguous () and b_s .is_contiguous ()
102185 K = a .size (- 1 )
0 commit comments