Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
std::optional<at::Tensor> first_dims);

std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
std::vector<py::handle> quantizer_list);
std::vector<py::handle> quantizer_list,
const py::object &outputs = py::none());

std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<size_t> &split_sections,
Expand Down
27 changes: 23 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,23 @@ void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
} // namespace

std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
std::vector<py::handle> quantizer_list) {
std::vector<py::handle> quantizer_list,
const py::object &outputs) {
// Check number of tensors
const size_t num_tensors = tensor_list.size();
NVTE_CHECK(quantizer_list.size() == num_tensors, "Expected ", num_tensors,
" quantizers, but got ", quantizer_list.size());

const bool use_provided_outputs = !outputs.is_none();
py::sequence outputs_seq;
if (use_provided_outputs) {
NVTE_CHECK(py::isinstance<py::list>(outputs) || py::isinstance<py::tuple>(outputs),
"multi_tensor_quantize: outputs must be None, a list, or a tuple.");
outputs_seq = py::reinterpret_borrow<py::sequence>(outputs);
NVTE_CHECK(static_cast<size_t>(outputs_seq.size()) == num_tensors, "multi_tensor_quantize: ",
"len(outputs) is ", outputs_seq.size(), " but expected ", num_tensors, ".");
}

// Convert quantizers to C++ objects
std::vector<std::unique_ptr<Quantizer>> quantizer_cpp_list;
for (size_t i = 0; i < num_tensors; i++) {
Expand All @@ -339,9 +350,17 @@ std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &ten
const auto input_shape = input_cpp.shape();
const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type());

// Construct output tensor
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype);
TensorWrapper output_cpp;
py::object output_py;
if (use_provided_outputs) {
py::object output_obj = outputs_seq[static_cast<ssize_t>(i)];
std::tie(output_cpp, output_py) =
quantizer_cpp_list[i]->convert_and_update_tensor(std::move(output_obj));
} else {
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
std::tie(output_cpp, output_py) =
quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype);
}
output_cpp_list.emplace_back(std::move(output_cpp));
output_py_list.emplace_back(std::move(output_py));
}
Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add,
"Fused backward of RMSNorm + add");
m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize,
"Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
"Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"),
py::arg("outputs") = py::none());
m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
"Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false);
Expand Down
109 changes: 109 additions & 0 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1593,6 +1593,115 @@ def get_weight_workspace(

return out

def get_multi_weight_workspace(
    self,
    *,
    tensors: List[torch.Tensor],
    quantizers: List[Quantizer],
    cache_names: Optional[List[str]] = None,
    update_workspace: bool = True,
    skip_update_flag: Optional[torch.Tensor] = None,
    workspace_dtype: Optional[torch.dtype] = None,
) -> List[QuantizedTensor]:
    """Get workspace buffers for a group of weights and maybe update their values.

    Analogous to `get_weight_workspace`, but operates on a whole group of
    weights. For delayed-scaling FP8 the group is cast and transposed with a
    single fused `tex.multi_tensor_quantize` call instead of one quantize call
    per tensor. With caching, both the initial allocation and later updates
    use that fused path when `update_workspace` is True; cached workspaces are
    passed as `outputs` so the fused call writes into the existing buffers
    instead of reallocating them.

    When fusion is not applicable (other recipes, rowwise-only usage,
    already-quantized weights, or a CUDA-graph skip flag) the call falls back
    to `get_weight_workspace` for each tensor, matching the per-tensor path.

    Parameters
    ----------
    tensors : list of torch.Tensor
        Values to copy into the workspaces.
    quantizers : list of Quantizer
        Quantizers used to cast the weights. Must have one entry per tensor.
    cache_names : list of str, optional
        Keys for caching. If None, the workspaces are not cached. When given,
        must have one entry per tensor.
    update_workspace : bool, default = True
        Update workspaces with values from `tensors`.
    skip_update_flag : torch.Tensor, optional
        GPU flag to skip updating the workspaces. Passing a flag disables the
        fused path.
    workspace_dtype : torch.dtype, optional
        High-precision workspace dtype (used for debug quantization). Only
        forwarded to the per-tensor fallback.

    Returns
    -------
    list of QuantizedTensor
        One workspace per input tensor, in the same order as `tensors`.

    Raises
    ------
    ValueError
        If `quantizers` or `cache_names` does not have one entry per tensor.
    """
    num_tensors = len(tensors)

    # Validate up front: the loops below pair these lists with zip(), which
    # would silently truncate on a mismatch and cache only part of the group.
    if len(quantizers) != num_tensors:
        raise ValueError(f"Expected {num_tensors} quantizers, but got {len(quantizers)}")
    if cache_names is not None and len(cache_names) != num_tensors:
        raise ValueError(f"Expected {num_tensors} cache names, but got {len(cache_names)}")

    # Fused path: delayed-scaling FP8 with transpose, high-precision (not
    # already quantized) weights, and no CUDA-graph skip flag (the fused kernel
    # has no device-side noop and would not preserve cached buffer pointers).
    can_fuse = (
        num_tensors > 0
        and skip_update_flag is None
        and all(
            isinstance(quantizer, Float8Quantizer) and quantizer.columnwise_usage
            for quantizer in quantizers
        )
        and not any(isinstance(tensor, QuantizedTensorStorage) for tensor in tensors)
    )

    if can_fuse:
        # No caching: allocate fresh and cast + transpose the whole group at once.
        if cache_names is None:
            return tex.multi_tensor_quantize(list(tensors), quantizers)

        # Look up every cached workspace; a single miss forces a fresh
        # allocation of the whole group.
        workspaces = []
        cache_miss = False
        for name, quantizer in zip(cache_names, quantizers):
            out = self._fp8_workspaces.get(name)
            if out is not None and quantizer is not None:
                # Drop a cached Float8 workspace that lacks the transpose
                # required when non-TN FP8 GEMMs are unsupported.
                reset_cache = False
                if isinstance(out, Float8TensorStorage):
                    if (
                        not is_non_tn_fp8_gemm_supported()
                        and quantizer.columnwise_usage
                        and out._transpose is None
                    ):
                        reset_cache = True
                if reset_cache:
                    del self._fp8_workspaces[name]
                    out = None
            if out is None:
                cache_miss = True
            workspaces.append(out)

        if cache_miss or update_workspace:
            # Single fused kernel for initial allocation and for refreshes.
            # Force internal=False so cached workspaces survive prepare_for_saving.
            saved_internal = [quantizer.internal for quantizer in quantizers]
            try:
                for quantizer in quantizers:
                    quantizer.internal = False
                if cache_miss:
                    workspaces = tex.multi_tensor_quantize(list(tensors), quantizers)
                else:
                    workspaces = tex.multi_tensor_quantize(
                        list(tensors), quantizers, outputs=workspaces
                    )
            finally:
                # Always restore the caller's setting, even if the fused
                # quantize raises, so quantizers are never left modified.
                for quantizer, internal in zip(quantizers, saved_internal):
                    quantizer.internal = internal
            for name, workspace in zip(cache_names, workspaces):
                self._fp8_workspaces[name] = workspace
        return workspaces

    # Fallback: quantize each weight individually.
    workspaces = []
    for i in range(num_tensors):
        workspaces.append(
            self.get_weight_workspace(
                tensor=tensors[i],
                quantizer=quantizers[i],
                cache_name=(None if cache_names is None else cache_names[i]),
                update_workspace=update_workspace,
                skip_update_flag=skip_update_flag,
                workspace_dtype=workspace_dtype,
            )
        )
    return workspaces

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
Expand Down
28 changes: 15 additions & 13 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,20 +180,22 @@ def forward(
# Initialize weights
weights_fp8: list
if fp8 or debug:
# FP8 cast to workspace buffer
weights_fp8 = []
# FP8 cast to workspace buffer. For delayed-scaling FP8 the whole group is
# cast and transposed with a single fused multi_cast_transpose kernel; other
# cases fall back to a per-weight quantize inside get_multi_weight_workspace.
update_workspace = is_first_microbatch is None or is_first_microbatch
for i in range(num_gemms):
weight_fp8 = module.get_weight_workspace(
tensor=weights[i],
quantizer=weight_quantizers[i],
cache_name=(None if is_first_microbatch is None else f"weight{i}"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
workspace_dtype=activation_dtype,
)
weights_fp8.append(weight_fp8)

weights_fp8 = module.get_multi_weight_workspace(
tensors=weights,
quantizers=weight_quantizers,
cache_names=(
None
if is_first_microbatch is None
else [f"weight{i}" for i in range(num_gemms)]
),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
workspace_dtype=activation_dtype,
)
else:
weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights]

Expand Down