From f49063655a6e989f8352c50e36e9a014afec68cf Mon Sep 17 00:00:00 2001 From: sudhu2k Date: Wed, 10 Jun 2026 23:01:20 +0000 Subject: [PATCH 1/3] Add get_multi_weight_workspace method for optimized weight handling This new method allows for efficient workspace management of multiple weights, enabling fused operations for delayed-scaling FP8. It enhances performance by reducing the number of quantization calls and supports caching of workspaces. The grouped_linear module has been updated to utilize this method. --- transformer_engine/pytorch/module/base.py | 84 +++++++++++++++++++ .../pytorch/module/grouped_linear.py | 28 ++++--- 2 files changed, 99 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 933060aff..71802b1b9 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1593,6 +1593,90 @@ def get_weight_workspace( return out + def get_multi_weight_workspace( + self, + *, + tensors: List[torch.Tensor], + quantizers: List[Quantizer], + cache_names: Optional[List[str]] = None, + update_workspace: bool = True, + skip_update_flag: Optional[torch.Tensor] = None, + workspace_dtype: Optional[torch.dtype] = None, + ) -> List[QuantizedTensor]: + """Get workspace buffers for a group of weights and maybe update their values. + + Analogous to `get_weight_workspace`, but operates on a whole group of + weights. For delayed-scaling FP8 the group is cast and transposed with a + single fused multi_cast_transpose kernel instead of one quantize call per + tensor. When fusion is not applicable (other recipes, rowwise-only usage, + already-quantized weights, or CUDA-graph weight caching) the call falls + back to `get_weight_workspace` for each tensor, matching the per-tensor path. + + Parameters + ---------- + tensors : list of torch.Tensor + Values to copy into the workspaces. + quantizers : list of Quantizer + Quantizers used to cast the weights. + cache_names : list of str, optional + Keys for caching. If None, the workspaces are not cached. + update_workspace : bool, default = True + Update workspaces with values from `tensors`. + skip_update_flag : torch.Tensor, optional + GPU flag to skip updating the workspaces. + workspace_dtype : torch.dtype, optional + High-precision workspace dtype (used for debug quantization). + """ + num_tensors = len(tensors) + + # Fused path: delayed-scaling FP8 with transpose, high-precision (not + # already quantized) weights, and no CUDA-graph skip flag (the fused kernel + # has no device-side noop and would not preserve cached buffer pointers). + can_fuse = ( + num_tensors > 0 + and skip_update_flag is None + and all( + isinstance(quantizer, Float8Quantizer) and quantizer.columnwise_usage + for quantizer in quantizers + ) + and not any(isinstance(tensor, QuantizedTensorStorage) for tensor in tensors) + ) + + if can_fuse: + caching = cache_names is not None + cache_valid = caching and all( + self._fp8_workspaces.get(name) is not None for name in cache_names + ) + if update_workspace or not cache_valid: + # Force internal=False so cached workspaces survive prepare_for_saving. + saved_internal = [quantizer.internal for quantizer in quantizers] + if caching: + for quantizer in quantizers: + quantizer.internal = False + workspaces = tex.multi_tensor_quantize(list(tensors), quantizers) + if caching: + for quantizer, internal in zip(quantizers, saved_internal): + quantizer.internal = internal + for name, workspace in zip(cache_names, workspaces): + self._fp8_workspaces[name] = workspace + return workspaces + return [self._fp8_workspaces[name] for name in cache_names] + + # Fallback: quantize each weight individually. + workspaces = [] + for i in range(num_tensors): + workspaces.append( + self.get_weight_workspace( + tensor=tensors[i], + quantizer=quantizers[i], + cache_name=(None if cache_names is None else cache_names[i]), + update_workspace=update_workspace, + skip_update_flag=skip_update_flag, + workspace_dtype=workspace_dtype, + ) + ) + return workspaces + def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a47912b4c..ac51250e7 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -180,20 +180,22 @@ def forward( # Initialize weights weights_fp8: list if fp8 or debug: - # FP8 cast to workspace buffer - weights_fp8 = [] + # FP8 cast to workspace buffer. For delayed-scaling FP8 the whole group is + # cast and transposed with a single fused multi_cast_transpose kernel; other + # cases fall back to a per-weight quantize inside get_multi_weight_workspace. update_workspace = is_first_microbatch is None or is_first_microbatch - for i in range(num_gemms): - weight_fp8 = module.get_weight_workspace( - tensor=weights[i], - quantizer=weight_quantizers[i], - cache_name=(None if is_first_microbatch is None else f"weight{i}"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - workspace_dtype=activation_dtype, - ) - weights_fp8.append(weight_fp8) - + weights_fp8 = module.get_multi_weight_workspace( + tensors=weights, + quantizers=weight_quantizers, + cache_names=( + None + if is_first_microbatch is None + else [f"weight{i}" for i in range(num_gemms)] + ), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + workspace_dtype=activation_dtype, + ) else: weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights] From 57d82b2c21501c8d08b15ca0fad41b76e44dbcea Mon Sep 17 00:00:00 2001 From: sudhu2k Date: Thu, 11 Jun 2026 16:33:42 +0000 Subject: [PATCH 2/3] Refactor quantization logic in TransformerEngineBaseModule for improved workspace management This update enhances the quantization process by optimizing workspace allocation and handling cache misses more effectively. It introduces a streamlined approach for both cache hits and misses, ensuring efficient in-place quantization and reducing unnecessary memory reallocations. The changes aim to improve performance during FP8 operations while maintaining compatibility with existing functionality. --- transformer_engine/pytorch/module/base.py | 43 +++++++++++++++-------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 71802b1b9..00f6a69be 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1643,24 +1643,37 @@ def get_multi_weight_workspace( ) if can_fuse: - caching = cache_names is not None - cache_valid = caching and all( - self._fp8_workspaces.get(name) is not None for name in cache_names - ) - if update_workspace or not cache_valid: + # No caching: allocate fresh and cast + transpose the whole group at once. + if cache_names is None: + return tex.multi_tensor_quantize(list(tensors), quantizers) + + workspaces = [self._fp8_workspaces.get(name) for name in cache_names] + if any(workspace is None for workspace in workspaces): + # Cache miss: allocate the workspaces once with a single fused kernel. # Force internal=False so cached workspaces survive prepare_for_saving. saved_internal = [quantizer.internal for quantizer in quantizers] - if caching: - for quantizer in quantizers: - quantizer.internal = False + for quantizer in quantizers: + quantizer.internal = False workspaces = tex.multi_tensor_quantize(list(tensors), quantizers) - if caching: - for quantizer, internal in zip(quantizers, saved_internal): - quantizer.internal = internal - for name, workspace in zip(cache_names, workspaces): - self._fp8_workspaces[name] = workspace - return workspaces - return [self._fp8_workspaces[name] for name in cache_names] + for quantizer, internal in zip(quantizers, saved_internal): + quantizer.internal = internal + for name, workspace in zip(cache_names, workspaces): + self._fp8_workspaces[name] = workspace + elif update_workspace: + # Cache hit: quantize in-place into the existing buffers to avoid + # reallocating FP8 storage (and rebuilding tensor objects) every step. + for tensor, quantizer, workspace in zip(tensors, quantizers, workspaces): + if hasattr(workspace, "quantize_"): + workspace.quantize_(tensor, noop_flag=None) + elif IS_HIP_EXTENSION: + use_cast_transpose_triton = bool( + int(os.environ.get("NVTE_USE_CAST_TRANSPOSE_TRITON", "0")) + ) + quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize + quantize_func(tensor, quantizer, workspace, None) + else: + tex.quantize(tensor, quantizer, workspace, None) + return workspaces # Fallback: quantize each weight individually. workspaces = [] From 1ed060798cc7682877ef2476233fca117f4a8b88 Mon Sep 17 00:00:00 2001 From: sudhu2k Date: Thu, 11 Jun 2026 18:30:46 +0000 Subject: [PATCH 3/3] Enhance multi_tensor_quantize function to support output tensors This update modifies the multi_tensor_quantize function to accept an optional outputs parameter, allowing for in-place quantization when cached workspaces are provided. The changes improve memory efficiency and performance during FP8 operations by reducing unnecessary tensor reallocations. --- transformer_engine/pytorch/csrc/extensions.h | 3 +- .../pytorch/csrc/extensions/cast.cpp | 27 ++++++++-- .../pytorch/csrc/extensions/pybind.cpp | 3 +- transformer_engine/pytorch/module/base.py | 52 ++++++++++++------- 4 files changed, 59 insertions(+), 26 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index fd0013f3e..98b4f27e8 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -312,7 +312,8 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const std::optional first_dims); std::vector multi_tensor_quantize(const std::vector &tensor_list, - std::vector quantizer_list); + std::vector quantizer_list, + const py::object &outputs = py::none()); std::vector split_quantize(const at::Tensor &tensor, const std::vector &split_sections, diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 9ed5502b7..bbfd3fd1a 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -314,12 +314,23 @@ void multi_tensor_quantize_impl(const std::vector &input_list, } // namespace std::vector multi_tensor_quantize(const std::vector &tensor_list, - std::vector quantizer_list) { + std::vector quantizer_list, + const py::object &outputs) { // Check number of tensors const size_t num_tensors = tensor_list.size(); NVTE_CHECK(quantizer_list.size() == num_tensors, "Expected ", num_tensors, " quantizers, but got ", quantizer_list.size()); + const bool use_provided_outputs = !outputs.is_none(); + py::sequence outputs_seq; + if (use_provided_outputs) { + NVTE_CHECK(py::isinstance(outputs) || py::isinstance(outputs), + "multi_tensor_quantize: outputs must be None, a list, or a tuple."); + outputs_seq = py::reinterpret_borrow(outputs); + NVTE_CHECK(static_cast(outputs_seq.size()) == num_tensors, "multi_tensor_quantize: ", + "len(outputs) is ", outputs_seq.size(), " but expected ", num_tensors, "."); + } + // Convert quantizers to C++ objects std::vector> quantizer_cpp_list; for (size_t i = 0; i < num_tensors; i++) { @@ -339,9 +350,17 @@ std::vector multi_tensor_quantize(const std::vector &ten const auto input_shape = input_cpp.shape(); const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); - // Construct output tensor - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); - auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype); + TensorWrapper output_cpp; + py::object output_py; + if (use_provided_outputs) { + py::object output_obj = outputs_seq[static_cast(i)]; + std::tie(output_cpp, output_py) = + quantizer_cpp_list[i]->convert_and_update_tensor(std::move(output_obj)); + } else { + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + std::tie(output_cpp, output_py) = + quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype); + } output_cpp_list.emplace_back(std::move(output_cpp)); output_py_list.emplace_back(std::move(output_py)); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 0ea760f89..afd6be1b9 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -297,7 +297,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add, "Fused backward of RMSNorm + add"); m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize, - "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); + "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"), + py::arg("outputs") = py::none()); m.def("split_quantize", &transformer_engine::pytorch::split_quantize, "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"), py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 00f6a69be..3bbb88e4c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1608,7 +1608,11 @@ def get_multi_weight_workspace( Analogous to `get_weight_workspace`, but operates on a whole group of weights. For delayed-scaling FP8 the group is cast and transposed with a single fused multi_cast_transpose kernel instead of one quantize call per - tensor. When fusion is not applicable (other recipes, rowwise-only usage, + tensor. With caching, both the initial allocation and later updates use that + fused `multi_tensor_quantize` path when `update_workspace` is True; cached + workspaces are passed as `outputs` so the fused kernel writes in place without + reallocating buffers. When + fusion is not applicable (other recipes, rowwise-only usage, already-quantized weights, or CUDA-graph weight caching) the call falls back to `get_weight_workspace` for each tensor, matching the per-tensor path. @@ -1647,34 +1651,42 @@ def get_multi_weight_workspace( if cache_names is None: return tex.multi_tensor_quantize(list(tensors), quantizers) - workspaces = [self._fp8_workspaces.get(name) for name in cache_names] - if any(workspace is None for workspace in workspaces): - # Cache miss: allocate the workspaces once with a single fused kernel. + workspaces = [] + cache_miss = False + for name, quantizer in zip(cache_names, quantizers): + out = self._fp8_workspaces.get(name) + if out is not None and quantizer is not None: + reset_cache = False + if isinstance(out, Float8TensorStorage): + if ( + not is_non_tn_fp8_gemm_supported() + and quantizer.columnwise_usage + and out._transpose is None + ): + reset_cache = True + if reset_cache: + del self._fp8_workspaces[name] + out = None + if out is None: + cache_miss = True + workspaces.append(out) + if cache_miss or update_workspace: + # Single fused kernel for initial allocation and for refreshes. # Force internal=False so cached workspaces survive prepare_for_saving. saved_internal = [quantizer.internal for quantizer in quantizers] for quantizer in quantizers: quantizer.internal = False - workspaces = tex.multi_tensor_quantize(list(tensors), quantizers) + if cache_miss: + workspaces = tex.multi_tensor_quantize(list(tensors), quantizers) + else: + workspaces = tex.multi_tensor_quantize( + list(tensors), quantizers, outputs=workspaces + ) for quantizer, internal in zip(quantizers, saved_internal): quantizer.internal = internal for name, workspace in zip(cache_names, workspaces): self._fp8_workspaces[name] = workspace - elif update_workspace: - # Cache hit: quantize in-place into the existing buffers to avoid - # reallocating FP8 storage (and rebuilding tensor objects) every step. - for tensor, quantizer, workspace in zip(tensors, quantizers, workspaces): - if hasattr(workspace, "quantize_"): - workspace.quantize_(tensor, noop_flag=None) - elif IS_HIP_EXTENSION: - use_cast_transpose_triton = bool( - int(os.environ.get("NVTE_USE_CAST_TRANSPOSE_TRITON", "0")) - ) - quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize - quantize_func(tensor, quantizer, workspace, None) - else: - tex.quantize(tensor, quantizer, workspace, None) return workspaces - # Fallback: quantize each weight individually. workspaces = [] for i in range(num_tensors):