From ff2fbf9da796c1c351b162f5b9f18b249390cd5d Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 26 Jun 2026 05:52:43 -0700 Subject: [PATCH] [tools] Fix KB model type conversion --- tools/kernel-bench | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tools/kernel-bench b/tools/kernel-bench index c8a5e221..bb122a70 100755 --- a/tools/kernel-bench +++ b/tools/kernel-bench @@ -4,7 +4,6 @@ import argparse import subprocess from datetime import datetime import sys -import ml_dtypes import numpy as np from pathlib import Path import torch @@ -35,7 +34,7 @@ def from_numpy(buf): """ return ( torch.from_numpy(buf.view(np.uint16)).view(torch.bfloat16) - if buf.dtype == ml_dtypes.bfloat16 + if buf.dtype == torch.bfloat16 else torch.from_numpy(buf) ) @@ -67,7 +66,7 @@ def import_torch( model_class_name: str = "Model", sample_args=None, model_init_args=None, - dtype=torch.float32, + model_datatype=torch.float32, ) -> torch.nn.Module: """ Imports a PyTorch model from a KernelBench file and returns the PyTorch module. @@ -79,7 +78,7 @@ def import_torch( model_class_name=model_class_name, sample_args=sample_args, model_init_args=model_init_args, - dtype=dtype, + model_datatype=model_datatype, ) assert isinstance(model, torch.nn.Module) return model, sample_args, sample_kwargs @@ -174,7 +173,7 @@ def parse_module_arguments( # Parse input shapes first, to create sample tensors for the inputs only buffers = [arg.arg for arg in KernelArgumentParser.parse_all(input_shapes_str)] # Build sample torch tensors from input shapes to override hard-coded sizes in get_inputs(). - if any(buf.dtype == ml_dtypes.bfloat16 for buf in buffers): + if any(buf.dtype == torch.bfloat16 for buf in buffers): print( "Warning: bfloat16 is not natively supported by torch.from_numpy(), \ so we are reinterpreting the buffers as uint16. This may cause issues if the buffers are modified after creation." @@ -302,7 +301,7 @@ def get_torch_data_type(dtype_str: str): if dtype_str == "f32": return torch.float32 elif dtype_str == "bf16": - return ml_dtypes.bfloat16 + return torch.bfloat16 else: raise ValueError(f"Unsupported data type: {dtype_str}") @@ -448,7 +447,7 @@ if __name__ == "__main__": ) # Validate data type argument - dtype = get_torch_data_type(args.dtype) + model_datatype = get_torch_data_type(args.dtype) # Initialize the device data buffers, sample_tensors, init_args = parse_module_arguments( @@ -460,7 +459,7 @@ if __name__ == "__main__": args.kernel_bench_model, sample_args=sample_tensors, model_init_args=init_args, - dtype=dtype, + model_datatype=model_datatype, ) if args.validate: out_ref = model(*sample_args, **sample_kwargs)