diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index 3a9f71be..9e37bbe1 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -1,5 +1,6 @@ # RUN: %PYTHON %s --sizes 512 1024 128 --dump-kernel=xegpu-wg | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# RUN: %PYTHON %s --dump-kernel=xegpu-wg --ab-type bf16 | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --transpose-a | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --transpose-b | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --bias | FileCheck %s @@ -78,6 +79,10 @@ class XeGPUMatMul: If `has_bias` is True, adds a bias term to the result. If `has_relu` is True, applies ReLU activation to the result (after bias if any). If `truncate_c` is True, truncates the C to A/B data type after accumulation. + + `ab_type` specifies the data type for A and B matrices (f16 or bf16). + `c_type` specifies the data type for the result C matrix (f32 by default, or ab_type if truncate_c is True). + `acc_type` specifies the data type for accumulation (f32 by default). """ payload_function_name: ClassVar[str] = "payload" @@ -112,8 +117,8 @@ def __post_init__(self): self.acc_type = ir.F32Type.get() if self.c_type is None: self.c_type = self.ab_type if self.truncate_c else self.acc_type - assert isinstance(self.ab_type, ir.F16Type), ( - "Only f16 type is supported for A and B" + assert isinstance(self.ab_type, (ir.F16Type, ir.BF16Type)), ( + "Only f16 and bf16 types are supported for A and B" ) assert isinstance(self.acc_type, ir.F32Type), "Only f32 type is supported for C" if self.truncate_c: @@ -247,6 +252,8 @@ def check_results( D_ref += bias.astype(f32) if mmul.has_relu: D_ref = np.maximum(D_ref, 0) + if mmul.truncate_c: + D_ref = D_ref.astype(mmul.ab_dtype) D_host = host_solution.astype(np.float32) if verbose > 1: @@ -261,6 +268,8 @@ def check_results( print("PASSED") else: print("FAILED Result mismatch!") + diff = np.abs(D_host - D_ref) + print(f"Max absolute difference: {np.max(diff)}") return success @@ -281,6 +290,13 @@ def cli_parser(description): default=[4096, 4096, 4096], help="M,N,K matrix sizes (A=MxK, B=KxN, C=MxN).", ) + parser.add_argument( + "--ab-type", + type=str, + choices=["f16", "bf16"], + default="f16", + help="Data type for A and B matrices.", + ) parser.add_argument( "--transpose-a", action="store_true", @@ -499,6 +515,7 @@ def parse_cli_args(description): M=params["m"], N=params["n"], K=params["k"], + ab_type=args.ab_type, transpose_a=params["transpose_a"], transpose_b=params["transpose_b"], has_bias=args.bias, diff --git a/examples/xegpu/mlp.py b/examples/xegpu/mlp.py index 2b7c7f25..3f908b38 100644 --- a/examples/xegpu/mlp.py +++ b/examples/xegpu/mlp.py @@ -1,5 +1,6 @@ # RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 | FileCheck %s +# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --ab-type bf16 | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --transpose-a | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --transpose-b | FileCheck %s # RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --relu | FileCheck %s @@ -133,8 +134,8 @@ def __post_init__(self): self.ab_type = ir.F16Type.get() if self.acc_type is None: self.acc_type = ir.F32Type.get() - assert isinstance(self.ab_type, ir.F16Type), ( - "Only f16 type is supported for A and B" + assert isinstance(self.ab_type, (ir.F16Type, ir.BF16Type)), ( + "Only f16 and bf16 types are supported for A and B" ) assert isinstance(self.acc_type, ir.F32Type), ( "Only f32 type is supported for accumulator" @@ -299,6 +300,13 @@ def parse_cli(): nargs="+", help="Number of features in each hidden layers.", ) + parser.add_argument( + "--ab-type", + type=str, + choices=["f16", "bf16"], + default="f16", + help="Data type for A and B matrices.", + ) parser.add_argument( "--nruns", type=int, @@ -397,6 +405,7 @@ def parse_cli(): input_size=args.input_size, output_size=args.output_size, hidden_layer_sizes=args.hidden_sizes, + ab_type=args.ab_type, has_bias=args.bias, has_relu=args.relu, transpose_a=tr_a,