@@ -393,6 +393,47 @@ def test_matmul_fp8_multiply_offload():
     tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)


+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+@pytest.mark.parametrize(
+    "x_shape, y_shape, transpose_y, out_dtype",
+    [
+        ((10, 32), (64, 32), True, "float32"),
+        ((32, 16), (32, 16), True, "float32"),
+        ((2, 10, 32), (2, 64, 32), True, "float32"),
+    ],
+)
+def test_matmul_bfloat16_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    out_dtype,
+):
+    in_dtype = "bfloat16"
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        in_dtype,
+        out_dtype,
+        bias_shape=None,
+        transposed_y=transpose_y,
+        activation=None,
+    )
+    # Generate input data in float32, then convert to bfloat16 using ml_dtypes.
+    x_float32 = np.random.uniform(low=0, high=5, size=x_shape).astype("float32")
+    y_float32 = np.random.uniform(low=0, high=5, size=y_shape).astype("float32")
+    x_bf16 = ml_dtypes.bfloat16(x_float32)
+    y_bf16 = ml_dtypes.bfloat16(y_float32)
+
+    # Compute the reference result in float32, transposing y if needed.
+    z = np.swapaxes(y_float32, -2, -1) if transpose_y else y_float32
+    args = (x_bf16, y_bf16)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref_out = np.matmul(x_float32, z).astype(out_dtype)
+
+    tvm.testing.assert_allclose(out, ref_out, rtol=1e-2, atol=1e-2)
+
+
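The loose rtol/atol of 1e-2 in the new test reflects bfloat16's reduced precision: with only 8 bits of significand, the float32 to bfloat16 cast alone loses roughly two decimal digits. A minimal sketch of that round-trip error, assuming only numpy and ml_dtypes are available (this is illustrative and not part of the PR):

import ml_dtypes
import numpy as np

# Cast float32 inputs to bfloat16 and back, measuring the precision lost
# before any matmul runs. For values up to 5, half a bfloat16 ulp is
# about 1.6e-2, which motivates the 1e-2 tolerances in the test above.
x = np.random.uniform(low=0, high=5, size=(10, 32)).astype("float32")
x_bf16 = ml_dtypes.bfloat16(x)        # same cast the test performs
roundtrip = x_bf16.astype("float32")
print(np.max(np.abs(x - roundtrip)))  # on the order of 1e-2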
 @pytest.mark.parametrize(
     "M, N, K, out_dtype, transposed_y, partition_done",
     [