Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions tools/kernel-bench
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import argparse
import subprocess
from datetime import datetime
import sys
import ml_dtypes
import numpy as np
from pathlib import Path
import torch
Expand Down Expand Up @@ -35,7 +34,7 @@ def from_numpy(buf):
"""
return (
torch.from_numpy(buf.view(np.uint16)).view(torch.bfloat16)
if buf.dtype == ml_dtypes.bfloat16
if buf.dtype == torch.bfloat16
else torch.from_numpy(buf)
)

Expand Down Expand Up @@ -67,7 +66,7 @@ def import_torch(
model_class_name: str = "Model",
sample_args=None,
model_init_args=None,
dtype=torch.float32,
model_datatype=torch.float32,
) -> torch.nn.Module:
"""
Imports a PyTorch model from a KernelBench file and returns the PyTorch module.
Expand All @@ -79,7 +78,7 @@ def import_torch(
model_class_name=model_class_name,
sample_args=sample_args,
model_init_args=model_init_args,
dtype=dtype,
model_datatype=model_datatype,
)
assert isinstance(model, torch.nn.Module)
return model, sample_args, sample_kwargs
Expand Down Expand Up @@ -174,7 +173,7 @@ def parse_module_arguments(
# Parse input shapes first, to create sample tensors for the inputs only
buffers = [arg.arg for arg in KernelArgumentParser.parse_all(input_shapes_str)]
# Build sample torch tensors from input shapes to override hard-coded sizes in get_inputs().
if any(buf.dtype == ml_dtypes.bfloat16 for buf in buffers):
if any(buf.dtype == torch.bfloat16 for buf in buffers):
print(
"Warning: bfloat16 is not natively supported by torch.from_numpy(), \
so we are reinterpreting the buffers as uint16. This may cause issues if the buffers are modified after creation."
Expand Down Expand Up @@ -302,7 +301,7 @@ def get_torch_data_type(dtype_str: str):
if dtype_str == "f32":
return torch.float32
elif dtype_str == "bf16":
return ml_dtypes.bfloat16
return torch.bfloat16
else:
raise ValueError(f"Unsupported data type: {dtype_str}")

Expand Down Expand Up @@ -448,7 +447,7 @@ if __name__ == "__main__":
)

# Validate data type argument
dtype = get_torch_data_type(args.dtype)
model_datatype = get_torch_data_type(args.dtype)

# Initialize the device data
buffers, sample_tensors, init_args = parse_module_arguments(
Expand All @@ -460,7 +459,7 @@ if __name__ == "__main__":
args.kernel_bench_model,
sample_args=sample_tensors,
model_init_args=init_args,
dtype=dtype,
model_datatype=model_datatype,
)
if args.validate:
out_ref = model(*sample_args, **sample_kwargs)
Expand Down
Loading