Fix ConvertMmToBmmPass for quantized (int8/int16) mm ops (#18974)
apullin wants to merge 1 commit into pytorch:main
Conversation
digantdesai left a comment:
Review automatically exported from Phabricator review in Meta.
This PR needs a
Summary:
This diff is experimental, but appears to address incomplete support for INT pathways for BMM. TBD.
The pass converts rank-2 mm to rank-3 bmm (required by TOSA spec) via
unsqueeze/bmm/squeeze. Previously it called super().call() to re-trace
the graph on FakeTensors for shape propagation, but aten.bmm rejects
int8/int16 FakeTensors, causing failures for any quantized mm ops.
Since mm→bmm is a pure shape transformation (adding a batch dim of 1),
we can set the output metadata directly: unsqueeze the mm's FakeTensor
for the bmm node, and use the original for the squeeze. No need to
re-execute the op.
Reviewed By: digantdesai
Differential Revision: D99857137
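
To make the approach in the summary concrete, here is a minimal, self-contained sketch of the same idea written directly against torch.fx rather than the ExecuTorch ExportPass infrastructure. The function name `convert_mm_to_bmm` and its overall structure are illustrative assumptions, not the code in this PR; it only shows how the unsqueeze/bmm/squeeze rewrite can fill in `val` metadata by hand so aten.bmm is never evaluated on an int8/int16 FakeTensor.

```python
# Sketch only (not the actual ExecuTorch pass): rewrite rank-2 aten.mm into
# unsqueeze -> bmm -> squeeze on a torch.fx graph, propagating shape metadata
# by hand instead of re-tracing, so quantized FakeTensors never reach aten.bmm.
import torch
from torch.fx import GraphModule


def convert_mm_to_bmm(gm: GraphModule) -> GraphModule:
    graph = gm.graph
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target != torch.ops.aten.mm.default:
            continue
        lhs, rhs = node.args
        mm_val = node.meta.get("val")  # FakeTensor describing the rank-2 mm output

        with graph.inserting_before(node):
            lhs3 = graph.call_function(torch.ops.aten.unsqueeze.default, (lhs, 0))
            rhs3 = graph.call_function(torch.ops.aten.unsqueeze.default, (rhs, 0))
            bmm = graph.call_function(torch.ops.aten.bmm.default, (lhs3, rhs3))
            out = graph.call_function(torch.ops.aten.squeeze.dim, (bmm, 0))

        # mm -> bmm only adds a leading batch dim of 1, so the bmm output is the
        # mm output unsqueezed, and the squeeze output reuses the original mm
        # metadata unchanged. No op is re-executed on the FakeTensors.
        if mm_val is not None:
            bmm.meta["val"] = mm_val.unsqueeze(0)
            out.meta["val"] = mm_val

        node.replace_all_uses_with(out)
        graph.erase_node(node)

    graph.lint()
    gm.recompile()
    return gm
```

The essential point is the metadata assignment: because the transformation is purely a shape change, the bmm node's FakeTensor can be derived with a plain `unsqueeze(0)` of the existing mm metadata, avoiding the dtype check that rejects int8/int16 inputs to aten.bmm.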