diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py
index 3d6fe704e1..10baae0d9a 100644
--- a/tests/pytorch/test_fusible_ops.py
+++ b/tests/pytorch/test_fusible_ops.py
@@ -4,7 +4,7 @@
 
 from __future__ import annotations
 
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence
 import functools
 import io
 import math
@@ -200,6 +200,18 @@ def make_reference_and_test_tensors(
     return ref, test
 
 
+def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+    """Convert to an FP64 CPU tensor"""
+    if tensor is None:
+        return None
+    out = tensor.detach()
+    if isinstance(out, QuantizedTensor):
+        out = out.dequantize()
+    out = out.to(dtype=torch.float64, device="cpu")
+    out = out.requires_grad_(requires_grad=tensor.requires_grad)
+    return out
+
+
 class MegatronTrainingHelper:
     """Test-side stand-in for the Megatron-Core DDP / MegatronFSDP wrapper.
     Megatron's DDP wrapper (and MegatronFSDP) owns the per-parameter
@@ -3368,25 +3380,17 @@ def test_layernorm_mlp(
         y_test = forward(x_test)
         y_test.backward(dy_test)
 
-        def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
-            """Convert to FP64 CPU tensor"""
-            if tensor is None:
-                return None
-            out = tensor.detach().to(dtype=torch.float64, device="cpu")
-            out = out.requires_grad_(requires_grad=tensor.requires_grad)
-            return out
-
         # Check values
         tols = {"rtol": 0.25, "atol": 0.5}  # Loose tols for sanity checking
-        torch.testing.assert_close(to_cpu(y_test), y_ref, **tols)
-        torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols)
-        torch.testing.assert_close(to_cpu(norm.weight.grad), norm_w_ref.grad, **tols)
-        torch.testing.assert_close(to_cpu(norm.bias.grad), norm_b_ref.grad, **tols)
-        torch.testing.assert_close(to_cpu(ffn2.weight.grad), w2_ref.grad, **tols)
-        torch.testing.assert_close(to_cpu(ffn1.weight.grad), w1_ref.grad, **tols)
+        assert_close(y_test, y_ref, **tols)
+        assert_close(x_test.grad, x_ref.grad, **tols)
+        assert_close_grads(norm.weight, norm_w_ref, **tols)
+        assert_close_grads(norm.bias, norm_b_ref, **tols)
+        assert_close_grads(ffn2.weight, w2_ref, **tols)
+        assert_close_grads(ffn1.weight, w1_ref, **tols)
         if bias:
-            torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
-            torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)
+            assert_close_grads(ffn1.bias, b1_ref, **tols)
+            assert_close_grads(ffn2.bias, b2_ref, **tols)
 
     @pytest.mark.parametrize("bias", (False, True))
     @pytest.mark.parametrize("dtype", _dtypes)
@@ -4740,6 +4744,232 @@ def fuse_ops(
     torch.testing.assert_close(dw_test, w_ref.grad, **tols)
 
 
+class TestTrainingLoops:
+
+    def _linear_train_stage(
+        self,
+        module: te.ops.Linear,
+        *,
+        steps: int = 3,
+        in_shape: Sequence[int],
+        out_shape: Sequence[int],
+        dtype: torch.dtype,
+        device: torch.device,
+        quantization: Optional[str],
+        recipe: Optional[transformer_engine.common.recipe.Recipe],
+    ) -> None:
+        """Perform training steps with linear op"""
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if dtype == torch.float32:
+            tols = dtype_tols(torch.float16)  # TF32 GEMM
+        if quantization is not None:
+            tols = quantization_tols(quantization)
+
+        for _ in range(steps):
+            # Update parameters with random values to simulate
+            # optimizer step or FSDP param all-gather
+            with torch.no_grad():
+                module.weight.copy_(torch.empty_like(module.weight).uniform_())
+                module.bias.copy_(torch.empty_like(module.bias).uniform_())
+            for param in module.parameters():
+                param.grad = None
+
+            # Random data
+            x_ref, x_test = make_reference_and_test_tensors(
+                in_shape,
+                quantization=quantization,
+                test_dtype=dtype,
+                test_device=device,
+            )
+            dy_ref, dy_test = make_reference_and_test_tensors(
+                out_shape,
+                quantization=quantization,
+                test_dtype=dtype,
+                test_device=device,
+            )
+            w_ref = to_cpu(module.weight)
+            b_ref = to_cpu(module.bias)
+
+            # Plain PyTorch implementation
+            y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
+            y_ref.backward(dy_ref)
+
+            # Implementation with linear op
+            with te.autocast(enabled=quantization is not None, recipe=recipe):
+                y_test = module(x_test)
+            y_test.backward(dy_test)
+
+            # Check results
+            assert_close(y_test, y_ref, **tols)
+            assert_close_grads(x_test, x_ref, **tols)
+            assert_close_grads(module.weight, w_ref, **tols)
+            assert_close_grads(module.bias, b_ref, **tols)
+
+    @torch.inference_mode
+    def _linear_infer_stage(
+        self,
+        module: te.ops.Linear,
+        *,
+        steps: int = 3,
+        in_shape: Sequence[int],
+        dtype: torch.dtype,
+        device: torch.device,
+        quantization: Optional[str],
+        recipe: Optional[transformer_engine.common.recipe.Recipe],
+    ) -> None:
+        """Perform inference steps with linear op"""
+
+        # Parameter reference values
+        w_ref = to_cpu(module.weight)
+        b_ref = to_cpu(module.bias)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if dtype == torch.float32:
+            tols = dtype_tols(torch.float16)  # TF32 GEMM
+        if quantization is not None:
+            tols = quantization_tols(quantization)
+
+        for _ in range(steps):
+            # Random data
+            x_ref, x_test = make_reference_and_test_tensors(
+                in_shape,
+                quantization=quantization,
+                test_dtype=dtype,
+                test_device=device,
+            )
+
+            # Plain PyTorch implementation
+            y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
+
+            # Implementation with linear op
+            with te.autocast(enabled=quantization is not None, recipe=recipe):
+                y_test = module(x_test)
+
+            # Check results
+            assert_close(y_test, y_ref, **tols)
+
+    @pytest.mark.parametrize("stages", (["train", "infer"] * 2, ["infer", "train"] * 2))
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    @pytest.mark.parametrize("quantized_weight", (False, True))
+    def test_linear_training_loop(
+        self,
+        *,
+        stages: Sequence[str],
+        weight_shape: tuple[int, int] = (32, 32),
+        in_shape: Sequence[int] = (32, -1),
+        dtype: Optional[torch.dtype] = None,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantized_weight: bool,
+    ) -> None:
+        """Training loops with linear op"""
+        if dtype is None:
+            dtype = torch.bfloat16 if is_bf16_available() else torch.float32
+
+        # Make input and weight shapes consistent
+        out_features, in_features = weight_shape
+        in_shape = list(in_shape)[:-1] + [in_features]
+        out_shape = in_shape[:-1] + [out_features]
+
+        # Skip invalid configurations
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
+        maybe_skip_quantization(quantization, dims=out_shape)
+        if quantization is None and quantized_weight:
+            pytest.skip("Quantization scheme is not specified")
+
+        # Construct module with random weights
+        recipe = make_recipe(quantization)
+        with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
+            module = te.ops.Linear(
+                in_features,
+                out_features,
+                device=device,
+                dtype=dtype,
+            )
+        with torch.no_grad():
+            for param in module.parameters():
+                param.copy_(torch.empty_like(param).uniform_())
+
+        # Training loop stages
+        for stage in stages:
+            if stage == "train":
+                self._linear_train_stage(
+                    module,
+                    in_shape=in_shape,
+                    out_shape=out_shape,
+                    dtype=dtype,
+                    device=device,
+                    quantization=quantization,
+                    recipe=recipe,
+                )
+            elif stage == "infer":
+                self._linear_infer_stage(
+                    module,
+                    in_shape=in_shape,
+                    dtype=dtype,
+                    device=device,
+                    quantization=quantization,
+                    recipe=recipe,
+                )
+            else:
+                raise ValueError(f"Unrecognized stage ({stage})")
+
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    @pytest.mark.parametrize("quantized_weight", (False, True))
+    def test_linear_inference_loop(
+        self,
+        *,
+        weight_shape: tuple[int, int] = (32, 32),
+        in_shape: Sequence[int] = (32, -1),
+        dtype: Optional[torch.dtype] = None,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantized_weight: bool,
+    ) -> None:
+        """Inference loop with linear op"""
+        if dtype is None:
+            dtype = torch.bfloat16 if is_bf16_available() else torch.float32
+
+        # Make input and weight shapes consistent
+        out_features, in_features = weight_shape
+        in_shape = list(in_shape)[:-1] + [in_features]
+        out_shape = in_shape[:-1] + [out_features]
+
+        # Skip invalid configurations
+        maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
+        maybe_skip_quantization(quantization, dims=out_shape)
+        if quantization is None and quantized_weight:
+            pytest.skip("Quantization scheme is not specified")
+
+        # Construct module with random weights
+        recipe = make_recipe(quantization)
+        with (
+            torch.inference_mode(),
+            te.quantized_model_init(enabled=quantized_weight, recipe=recipe),
+        ):
+            module = te.ops.Linear(
+                in_features,
+                out_features,
+                device=device,
+                dtype=dtype,
+            )
+            for param in module.parameters():
+                param.copy_(torch.empty_like(param).uniform_())
+
+        # Inference loop
+        self._linear_infer_stage(
+            module,
+            in_shape=in_shape,
+            dtype=dtype,
+            device=device,
+            quantization=quantization,
+            recipe=recipe,
+        )
+
+
 def test_grouped_gemm_quant_cute_matches_mxfp8_quantized() -> None:
     if not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py
index 17594726cc..19fcf62ced 100644
--- a/transformer_engine/pytorch/ops/basic/basic_linear.py
+++ b/transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -329,8 +329,6 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None:
         super().pre_fuser_forward(requires_grad=requires_grad)
         if FP8GlobalStateManager.is_fp8_enabled():
             # Configure quantizer usages
-            # Note: We cache the quantized input for backward pass,
-            # but discard the quantized weights.
             weight_requires_grad = requires_grad and self.weight.requires_grad
             columnwise_usage = weight_requires_grad
             if FP8GlobalStateManager.get_fp8_recipe().backward_override is not None:
@@ -339,13 +337,13 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None:
             weight_quantizer = self.get_quantizer("forward", 1)
             grad_output_quantizer = self.get_quantizer("backward", 0)
             input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
-            weight_quantizer.set_usage(rowwise=True, columnwise=False)
+            weight_quantizer.set_usage(rowwise=True, columnwise=requires_grad)
             grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
 
     def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
         super().reset_recipe_state(recipe=recipe)
 
-        # Configure input/grad output tensor
+        # Configure input/grad output quantizers
         # Note: These tensors are only used internally. If there is no
         # tensor-parallel communication, they are only used for GEMM.
         input_quantizer = self.get_quantizer("forward", 0)
@@ -370,21 +368,15 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
 
         # Configure weight quantizer
         # Note: This function may be called in base class constructor,
-        # before any basic linear attrs have been set.
+        # before basic linear attrs have been set.
         weight_quantizer = self.get_quantizer("forward", 1)
-        if weight_quantizer is None:
-            pass
-        elif is_quantized_tensor(getattr(self, "weight", None)):
-            # Make sure weight param has correct quantizer
-            weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
-            weight_quantizer.internal = False
-            self.weight.update_quantizer(weight_quantizer.copy())
-        else:
-            # Use internal tensors if quantized weights will not be
-            # exposed externally
-            weight_quantizer.internal = (
-                not FP8GlobalStateManager.with_fp8_parameters()
-                and not getattr(self, "_with_quantized_weight", False)
+        weight = getattr(self, "weight", None)
+        if weight_quantizer is not None:
+            # Determine if quantized weight is exposed as parameter
+            weight_quantizer.internal = not (
+                FP8GlobalStateManager.with_fp8_parameters()
+                or getattr(self, "_with_quantized_weight", False)
+                or is_quantized_tensor(weight)
             )
 
         # Recipe-specific configuration
@@ -416,6 +408,18 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
                 grad_output_quantizer.with_amax_reduction = True
                 grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
 
+        # Update quantizer in quantized weight tensor
+        if weight_quantizer is not None and is_quantized_tensor(weight):
+            if weight._quantizer is not None:
+                # Preserve existing usages in weight tensor. Even if a
+                # usage is currently unnecessary, the weight tensor
+                # may be used elsewhere.
+                weight_quantizer.set_usage(
+                    rowwise=weight._quantizer.rowwise_usage,
+                    columnwise=weight._quantizer.columnwise_usage,
+                )
+            weight.update_quantizer(weight_quantizer.copy())
+
     @staticmethod
     def _functional_forward(
         input: torch.Tensor,  # pylint: disable=redefined-builtin
diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py
index fe5997a71e..b503cb186b 100644
--- a/transformer_engine/pytorch/ops/basic/grouped_linear.py
+++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py
@@ -619,14 +619,12 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None:
             weight_requires_grad = requires_grad and weight_requires_grad
 
             # Configure quantizer usages
-            # Note: We cache the quantized input for backward pass,
-            # but discard the quantized weights.
             for group_idx in range(self.num_groups):
                 input_quantizer = self.get_quantizer("forward", 2 * group_idx)
                 weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1)
                 grad_output_quantizer = self.get_quantizer("backward", group_idx)
                 input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
-                weight_quantizer.set_usage(rowwise=True, columnwise=False)
+                weight_quantizer.set_usage(rowwise=True, columnwise=requires_grad)
                 grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
 
     def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
@@ -641,32 +639,29 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
             if grad_output_quantizer is not None:
                 grad_output_quantizer.internal = True
 
-            # Handle weight quantizer
+            # Get weight tensor
             # Note: This function may be called in base class constructor,
-            # before any basic linear attrs have been set.
-            weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1)
-            if weight_quantizer is None:
-                pass
-            elif is_quantized_tensor(getattr(self, f"weight{group_idx}", None)):
-                # Make sure weight param has correct quantizer
-                weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
-                weight_quantizer.internal = False
-                if self.single_grouped_weight:
-                    self.weight.quantizer = weight_quantizer.copy()
-                else:
-                    getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy())
+            # before any grouped linear attrs have been set.
+            weight = None
+            weight_is_quantized = False
+            if getattr(self, "single_grouped_weight", False):
+                weight = getattr(self, "weight", None)
+                weight_is_quantized = weight is not None and weight.quantizer is not None
             else:
-                # Use internal tensors if quantized weights will not be
-                # exposed externally
-                weight_quantizer.internal = (
-                    not FP8GlobalStateManager.with_fp8_parameters()
-                    and not getattr(self, "_with_quantized_weight", False)
-                    and not self.single_grouped_weight
+                weight = getattr(self, f"weight{group_idx}", None)
+                weight_is_quantized = is_quantized_tensor(weight)
+
+            # Configure weight quantizer
+            weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1)
+            if weight_quantizer is not None:
+                # Determine if quantized weight is exposed as parameter
+                weight_quantizer.internal = not (
+                    FP8GlobalStateManager.with_fp8_parameters()
+                    or getattr(self, "_with_quantized_weight", False)
+                    or weight_is_quantized
                 )
 
             # Recipe-specific configuration
-            # Note: This function may be called in base class constructor,
-            # before any basic linear attrs have been set.
             if recipe is not None:
                 if recipe.float8_current_scaling():
                     input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
@@ -680,6 +675,29 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
                         recipe.fp8_quant_bwd_grad.amax_epsilon
                     )
 
+            # Update quantizer in quantized weight tensor
+            if weight_quantizer is not None and weight_is_quantized:
+                # Get quantizer from weight tensor
+                weight_tensor_quantizer = (
+                    weight.quantizer if self.single_grouped_weight else weight._quantizer
+                )
+
+                # Preserve existing usages in weight tensor. Even if a
+                # usage is currently unnecessary, the weight tensor
+                # may be used elsewhere.
+                if weight_tensor_quantizer is not None:
+                    weight_quantizer.set_usage(
+                        rowwise=weight_tensor_quantizer.rowwise_usage,
+                        columnwise=weight_tensor_quantizer.columnwise_usage,
+                    )
+
+                # Update weight tensor
+                if self.single_grouped_weight:
+                    if group_idx == 0:
+                        weight.quantizer = weight_quantizer.copy()
+                else:
+                    weight.update_quantizer(weight_quantizer.copy())
+
     def op_forward(self, *args, **kwargs):
         raise RuntimeError(
             f"{self.__class__.__name__} operation has "