diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index fd0013f3e..98b4f27e8 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -312,7 +312,8 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const std::optional first_dims); std::vector multi_tensor_quantize(const std::vector &tensor_list, - std::vector quantizer_list); + std::vector quantizer_list, + const py::object &outputs = py::none()); std::vector split_quantize(const at::Tensor &tensor, const std::vector &split_sections, diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 9ed5502b7..bbfd3fd1a 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -314,12 +314,23 @@ void multi_tensor_quantize_impl(const std::vector &input_list, } // namespace std::vector multi_tensor_quantize(const std::vector &tensor_list, - std::vector quantizer_list) { + std::vector quantizer_list, + const py::object &outputs) { // Check number of tensors const size_t num_tensors = tensor_list.size(); NVTE_CHECK(quantizer_list.size() == num_tensors, "Expected ", num_tensors, " quantizers, but got ", quantizer_list.size()); + const bool use_provided_outputs = !outputs.is_none(); + py::sequence outputs_seq; + if (use_provided_outputs) { + NVTE_CHECK(py::isinstance(outputs) || py::isinstance(outputs), + "multi_tensor_quantize: outputs must be None, a list, or a tuple."); + outputs_seq = py::reinterpret_borrow(outputs); + NVTE_CHECK(static_cast(outputs_seq.size()) == num_tensors, "multi_tensor_quantize: ", + "len(outputs) is ", outputs_seq.size(), " but expected ", num_tensors, "."); + } + // Convert quantizers to C++ objects std::vector> quantizer_cpp_list; for (size_t i = 0; i < num_tensors; i++) { @@ -339,9 +350,17 @@ std::vector multi_tensor_quantize(const 
std::vector &ten const auto input_shape = input_cpp.shape(); const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); - // Construct output tensor - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); - auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype); + TensorWrapper output_cpp; + py::object output_py; + if (use_provided_outputs) { + py::object output_obj = outputs_seq[static_cast(i)]; + std::tie(output_cpp, output_py) = + quantizer_cpp_list[i]->convert_and_update_tensor(std::move(output_obj)); + } else { + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + std::tie(output_cpp, output_py) = + quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype); + } output_cpp_list.emplace_back(std::move(output_cpp)); output_py_list.emplace_back(std::move(output_py)); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 0ea760f89..afd6be1b9 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -297,7 +297,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add, "Fused backward of RMSNorm + add"); m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize, - "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); + "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"), + py::arg("outputs") = py::none()); m.def("split_quantize", &transformer_engine::pytorch::split_quantize, "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"), py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 933060aff..3bbb88e4c 100644 --- 
a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -1593,6 +1593,131 @@ def get_weight_workspace(
 
         return out
 
+    def get_multi_weight_workspace(
+        self,
+        *,
+        tensors: List[torch.Tensor],
+        quantizers: List[Quantizer],
+        cache_names: Optional[List[str]] = None,
+        update_workspace: bool = True,
+        skip_update_flag: Optional[torch.Tensor] = None,
+        workspace_dtype: Optional[torch.dtype] = None,
+    ) -> List[QuantizedTensor]:
+        """Get workspace buffers for a group of weights and maybe update their values.
+
+        Analogous to `get_weight_workspace`, but operates on a whole group of
+        weights. For delayed-scaling FP8 the group is cast and transposed with a
+        single fused multi_cast_transpose kernel instead of one quantize call per
+        tensor. With caching, a cache miss reallocates every workspace in the
+        group, while a refresh (`update_workspace` is True) passes the cached
+        workspaces as `outputs` so the fused kernel writes in place without
+        reallocating buffers. When fusion is not applicable (other recipes,
+        rowwise-only usage, already-quantized weights, or CUDA-graph weight
+        caching) the call falls back to `get_weight_workspace` for each tensor,
+        matching the per-tensor path.
+
+        Parameters
+        ----------
+        tensors : list of torch.Tensor
+            Values to copy into the workspaces.
+        quantizers : list of Quantizer
+            Quantizers used to cast the weights, one per tensor.
+        cache_names : list of str, optional
+            Keys for caching, one per tensor. If None, the workspaces are not cached.
+        update_workspace : bool, default = True
+            Update workspaces with values from `tensors`.
+        skip_update_flag : torch.Tensor, optional
+            GPU flag to skip updating the workspaces. Passing it disables the
+            fused path.
+        workspace_dtype : torch.dtype, optional
+            High-precision workspace dtype (used for debug quantization). Only
+            forwarded on the per-tensor fallback path.
+
+        Returns
+        -------
+        list of QuantizedTensor
+            One workspace per entry of `tensors`, in the same order.
+
+        Raises
+        ------
+        ValueError
+            If `quantizers` or `cache_names` does not have one entry per tensor.
+        """
+        num_tensors = len(tensors)
+        # `zip` would silently truncate mismatched lists, so reject them up front.
+        if len(quantizers) != num_tensors:
+            raise ValueError(f"Expected {num_tensors} quantizers, but got {len(quantizers)}")
+        if cache_names is not None and len(cache_names) != num_tensors:
+            raise ValueError(f"Expected {num_tensors} cache names, but got {len(cache_names)}")
+
+        # Fused path: delayed-scaling FP8 with transpose, high-precision (not
+        # already quantized) weights, and no CUDA-graph skip flag (the fused kernel
+        # has no device-side noop and would not preserve cached buffer pointers).
+        can_fuse = (
+            num_tensors > 0
+            and skip_update_flag is None
+            and all(
+                isinstance(quantizer, Float8Quantizer) and quantizer.columnwise_usage
+                for quantizer in quantizers
+            )
+            and not any(isinstance(tensor, QuantizedTensorStorage) for tensor in tensors)
+        )
+
+        if not can_fuse:
+            # Fallback: quantize each weight individually.
+            return [
+                self.get_weight_workspace(
+                    tensor=tensor,
+                    quantizer=quantizer,
+                    cache_name=(None if cache_names is None else cache_names[i]),
+                    update_workspace=update_workspace,
+                    skip_update_flag=skip_update_flag,
+                    workspace_dtype=workspace_dtype,
+                )
+                for i, (tensor, quantizer) in enumerate(zip(tensors, quantizers))
+            ]
+
+        # No caching: allocate fresh and cast + transpose the whole group at once.
+        if cache_names is None:
+            return tex.multi_tensor_quantize(list(tensors), quantizers)
+
+        # Look up cached workspaces; `None` marks a miss.
+        workspaces = []
+        for name in cache_names:
+            out = self._fp8_workspaces.get(name)
+            # Drop a cached workspace that has no transpose while non-TN FP8 GEMMs
+            # are unsupported (columnwise usage is already guaranteed by `can_fuse`).
+            if (
+                isinstance(out, Float8TensorStorage)
+                and not is_non_tn_fp8_gemm_supported()
+                and out._transpose is None
+            ):
+                del self._fp8_workspaces[name]
+                out = None
+            workspaces.append(out)
+        cache_miss = any(out is None for out in workspaces)
+
+        if cache_miss or update_workspace:
+            # Single fused kernel for initial allocation and for refreshes.
+            # Force internal=False so cached workspaces survive prepare_for_saving.
+            saved_internal = [quantizer.internal for quantizer in quantizers]
+            for quantizer in quantizers:
+                quantizer.internal = False
+            try:
+                if cache_miss:
+                    workspaces = tex.multi_tensor_quantize(list(tensors), quantizers)
+                else:
+                    workspaces = tex.multi_tensor_quantize(
+                        list(tensors), quantizers, outputs=workspaces
+                    )
+            finally:
+                # Restore the caller's flags even if the fused quantize raises.
+                for quantizer, internal in zip(quantizers, saved_internal):
+                    quantizer.internal = internal
+            for name, workspace in zip(cache_names, workspaces):
+                self._fp8_workspaces[name] = workspace
+        return workspaces
+
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index a47912b4c..ac51250e7 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -180,20 +180,22 @@ def forward(
         # Initialize weights
         weights_fp8: list
         if fp8 or debug:
-            # FP8 cast to workspace buffer
-            weights_fp8 = []
+            # FP8 cast to workspace buffer. For delayed-scaling FP8 the whole group is
+            # cast and transposed with a single fused multi_cast_transpose kernel; other
+            # cases fall back to a per-weight quantize inside get_multi_weight_workspace.
             update_workspace = is_first_microbatch is None or is_first_microbatch
-            for i in range(num_gemms):
-                weight_fp8 = module.get_weight_workspace(
-                    tensor=weights[i],
-                    quantizer=weight_quantizers[i],
-                    cache_name=(None if is_first_microbatch is None else f"weight{i}"),
-                    update_workspace=update_workspace,
-                    skip_update_flag=skip_fp8_weight_update,
-                    workspace_dtype=activation_dtype,
-                )
-                weights_fp8.append(weight_fp8)
-
+            weights_fp8 = module.get_multi_weight_workspace(
+                tensors=weights,
+                quantizers=weight_quantizers,
+                cache_names=(
+                    None
+                    if is_first_microbatch is None
+                    else [f"weight{i}" for i in range(num_gemms)]
+                ),
+                update_workspace=update_workspace,
+                skip_update_flag=skip_fp8_weight_update,
+                workspace_dtype=activation_dtype,
+            )
         else:
             weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights]
 