-
Notifications
You must be signed in to change notification settings - Fork 719
[PyTorch] Fusible ops preserve usages in quantized weight tensors #2929
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
525a749
59d417d
4798bdc
3b9a8a1
aaf8588
63328be
aaf0bd6
838ea73
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -329,8 +329,6 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||
| super().pre_fuser_forward(requires_grad=requires_grad) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if FP8GlobalStateManager.is_fp8_enabled(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Configure quantizer usages | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Note: We cache the quantized input for backward pass, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # but discard the quantized weights. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_requires_grad = requires_grad and self.weight.requires_grad | ||||||||||||||||||||||||||||||||||||||||||||||||||
| columnwise_usage = weight_requires_grad | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if FP8GlobalStateManager.get_fp8_recipe().backward_override is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -339,13 +337,13 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||
| weight_quantizer = self.get_quantizer("forward", 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| grad_output_quantizer = self.get_quantizer("backward", 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_quantizer.set_usage(rowwise=True, columnwise=False) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_quantizer.set_usage(rowwise=True, columnwise=requires_grad) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| super().reset_recipe_state(recipe=recipe) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Configure input/grad output tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Configure input/grad output quantizers | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Note: These tensors are only used internally. If there is no | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # tensor-parallel communication, they are only used for GEMM. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| input_quantizer = self.get_quantizer("forward", 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -370,21 +368,15 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Configure weight quantizer | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Note: This function may be called in base class constructor, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # before any basic linear attrs have been set. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # before basic linear attrs have been set. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_quantizer = self.get_quantizer("forward", 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if weight_quantizer is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||||||||||||||||
| elif is_quantized_tensor(getattr(self, "weight", None)): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Make sure weight param has correct quantizer | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_quantizer.internal = False | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.weight.update_quantizer(weight_quantizer.copy()) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Use internal tensors if quantized weights will not be | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # exposed externally | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_quantizer.internal = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| not FP8GlobalStateManager.with_fp8_parameters() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| and not getattr(self, "_with_quantized_weight", False) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight = getattr(self, "weight", None) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if weight_quantizer is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Determine if quantized weight is exposed as parameter | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_quantizer.internal = not ( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| FP8GlobalStateManager.with_fp8_parameters() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| or getattr(self, "_with_quantized_weight", False) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| or is_quantized_tensor(weight) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Recipe-specific configuration | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -416,6 +408,18 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||
| grad_output_quantizer.with_amax_reduction = True | ||||||||||||||||||||||||||||||||||||||||||||||||||
| grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Update quantizer in quantized weight tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if weight_quantizer is not None and is_quantized_tensor(weight): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if weight._quantizer is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Preserve existing usages in weight tensor. Even if a | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # usage is currently unnecessary, the weight tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # may be used elsewhere. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_quantizer.set_usage( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| rowwise=weight._quantizer.rowwise_usage, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| columnwise=weight._quantizer.columnwise_usage, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| weight.update_quantizer(weight_quantizer.copy()) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+411
to
+421
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did you move this to its own section after the other quantizers? Logically, it should follow the rest of the setup for the weight quantizer.
Suggested change
(the assumption here is that the rowwise would always be there for the weights).
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, no. I see that the context of this function is the resetting of the recipe, so sure, we need to create the new quantizer and pass it to the tensor. This makes sense. I don't see the requantization happening in this case though - if we have weight as already quantized tensor then where is the code to actually apply this change in the _quantizer to the data held by it?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The big problem with the existing code is that we update the weight param quantizer, and then we do more weight quantizer configuration afterwards. We should make sure the weight quantizer is fully configured, and only then update the quantized param. The logic for updating the quantized param was also a little convoluted, so I attempted to make it a bit more clear.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think requantization really makes sense: #2929 (comment) I can see the argument that we should blindly preserve usages in the weight param quantizer. This logic may be trying too hard to clean up after a user doing something weird.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't agree with that (if I change the recipe to a different type then trying to run without this requantization would just fail since e.g. we would try to multiply the MXFP8 tensor with FP8 CS tensor) but I think this should be addressed in its own PR rather than here. |
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||
| def _functional_forward( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| input: torch.Tensor, # pylint: disable=redefined-builtin | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment was made incorrect in #1817.