Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions examples/xegpu/matmul.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# RUN: %PYTHON %s --sizes 512 1024 128 --dump-kernel=xegpu-wg | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --ab-type bf16 | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --transpose-a | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --transpose-b | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --bias | FileCheck %s
Expand Down Expand Up @@ -78,6 +79,10 @@ class XeGPUMatMul:
If `has_bias` is True, adds a bias term to the result.
If `has_relu` is True, applies ReLU activation to the result (after bias if any).
If `truncate_c` is True, truncates the C to A/B data type after accumulation.

`ab_type` specifies the data type for A and B matrices (f16 or bf16).
`c_type` specifies the data type for the result C matrix (f32 by default, or ab_type if truncate_c is True).
`acc_type` specifies the data type for accumulation (f32 by default).
"""

payload_function_name: ClassVar[str] = "payload"
Expand Down Expand Up @@ -112,8 +117,8 @@ def __post_init__(self):
self.acc_type = ir.F32Type.get()
if self.c_type is None:
self.c_type = self.ab_type if self.truncate_c else self.acc_type
assert isinstance(self.ab_type, ir.F16Type), (
"Only f16 type is supported for A and B"
assert isinstance(self.ab_type, (ir.F16Type, ir.BF16Type)), (
"Only f16 and bf16 types are supported for A and B"
)
assert isinstance(self.acc_type, ir.F32Type), "Only f32 type is supported for C"
if self.truncate_c:
Expand Down Expand Up @@ -247,6 +252,8 @@ def check_results(
D_ref += bias.astype(f32)
if mmul.has_relu:
D_ref = np.maximum(D_ref, 0)
if mmul.truncate_c:
D_ref = D_ref.astype(mmul.ab_dtype)

D_host = host_solution.astype(np.float32)
if verbose > 1:
Expand All @@ -261,6 +268,8 @@ def check_results(
print("PASSED")
else:
print("FAILED Result mismatch!")
diff = np.abs(D_host - D_ref)
print(f"Max absolute difference: {np.max(diff)}")

return success

Expand All @@ -281,6 +290,13 @@ def cli_parser(description):
default=[4096, 4096, 4096],
help="M,N,K matrix sizes (A=MxK, B=KxN, C=MxN).",
)
parser.add_argument(
"--ab-type",
type=str,
choices=["f16", "bf16"],
default="f16",
help="Data type for A and B matrices.",
)
parser.add_argument(
"--transpose-a",
action="store_true",
Expand Down Expand Up @@ -499,6 +515,7 @@ def parse_cli_args(description):
M=params["m"],
N=params["n"],
K=params["k"],
ab_type=args.ab_type,
transpose_a=params["transpose_a"],
transpose_b=params["transpose_b"],
has_bias=args.bias,
Expand Down
13 changes: 11 additions & 2 deletions examples/xegpu/mlp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --ab-type bf16 | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --transpose-a | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --transpose-b | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --relu | FileCheck %s
Expand Down Expand Up @@ -133,8 +134,8 @@ def __post_init__(self):
self.ab_type = ir.F16Type.get()
if self.acc_type is None:
self.acc_type = ir.F32Type.get()
assert isinstance(self.ab_type, ir.F16Type), (
"Only f16 type is supported for A and B"
assert isinstance(self.ab_type, (ir.F16Type, ir.BF16Type)), (
"Only f16 and bf16 types are supported for A and B"
)
assert isinstance(self.acc_type, ir.F32Type), (
"Only f32 type is supported for accumulator"
Expand Down Expand Up @@ -299,6 +300,13 @@ def parse_cli():
nargs="+",
help="Number of features in each hidden layers.",
)
parser.add_argument(
"--ab-type",
type=str,
choices=["f16", "bf16"],
default="f16",
help="Data type for A and B matrices.",
)
parser.add_argument(
"--nruns",
type=int,
Expand Down Expand Up @@ -397,6 +405,7 @@ def parse_cli():
input_size=args.input_size,
output_size=args.output_size,
hidden_layer_sizes=args.hidden_sizes,
ab_type=args.ab_type,
has_bias=args.bias,
has_relu=args.relu,
transpose_a=tr_a,
Expand Down
Loading