diff --git a/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/README.md b/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/README.md
new file mode 100644
index 000000000..834902d3f
--- /dev/null
+++ b/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/README.md
@@ -0,0 +1,48 @@
+# Folding GroupNorm-via-InstanceNorm
+
+## Introduction
+
+PyTorch versions prior to 2.5 export `torch.nn.GroupNorm` as a sequence of
+five operators when targeting ONNX opset versions below 18. The pattern is
+
+```
+   x -> Reshape -> InstanceNormalization -> Reshape -> Mul -> Add -> y
+```
+
+where the `InstanceNormalization` runs with constant scale `1` and bias `0`,
+and the trailing `Mul`/`Add` apply the learned per-channel `gamma`/`beta`. The
+second `Reshape` reads its target shape from a `Shape` op fed by `x`.
+
+This example demonstrates how to detect that pattern and rewrite it as a
+single native `GroupNormalization` node (opset 21+, since per-channel scale
+and bias require the opset 21 definition of the op). Doing so avoids the
+fused-norm code path inside TensorRT, which has been observed to drift from
+ONNX Runtime when reduction extents are very large or when `num_groups`
+equals `num_channels`.
+
+## Running The Example
+
+1. Generate a small model containing the legacy pattern.
+
+   ```bash
+   python3 generate.py
+   ```
+
+2. Fold the pattern into a `GroupNormalization` op.
+
+   ```bash
+   python3 fold.py model.onnx folded.onnx
+   ```
+
+3. Inspect the resulting graph in [Netron](https://netron.app) to confirm that
+   the five-op subgraph collapsed to a single `GroupNormalization`.
+
+## How It Works
+
+`fold.py` builds a `GraphPattern` describing the five-op template and matches
+it against the graph. When a match succeeds, it pulls `num_groups` from the
+upstream `Reshape` constant, lifts the per-channel weights out of the
+trailing `Mul`/`Add`, flattens them to 1D, and wires in a single
+`GroupNormalization` node that consumes the original input tensor and
+produces the original output tensor. `graph.cleanup()` then removes the
+now-orphaned reshapes, the `Shape` op, and the dangling constants.
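+
+If ONNX Runtime is available, you can also spot-check that the rewrite is
+numerically faithful. The snippet below is a minimal sketch, not part of the
+example scripts, and assumes an ONNX Runtime build recent enough to run
+`GroupNormalization` at opset 21:
+
+```python
+import numpy as np
+import onnxruntime as ort
+
+# Same input name and shape as the model produced by generate.py.
+feed = {"input": np.random.rand(1, 8, 4, 4).astype(np.float32)}
+ref = ort.InferenceSession("model.onnx").run(None, feed)[0]
+out = ort.InferenceSession("folded.onnx").run(None, feed)[0]
+assert np.allclose(ref, out, atol=1e-5)
+```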
diff --git a/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/fold.py b/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/fold.py
new file mode 100644
index 000000000..34f102ab7
--- /dev/null
+++ b/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/fold.py
@@ -0,0 +1,128 @@
+#!/usr/bin/env python3
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+
+import numpy as np
+import onnx
+import onnx_graphsurgeon as gs
+
+
+def to_1d(arr):
+    return np.asarray(arr).reshape(-1)
+
+
+def make_legacy_groupnorm_pattern():
+    # Template for the five-op decomposition emitted by older PyTorch exporters:
+    #   x -> Reshape -> InstanceNormalization -> Reshape -> Mul -> Add -> y
+    pat = gs.GraphPattern()
+
+    x = pat.variable()
+    target_shape = pat.variable()
+    inst_scale = pat.variable()
+    inst_bias = pat.variable()
+    gamma = pat.variable()
+    beta = pat.variable()
+
+    reshape_in = pat.add("pre", "Reshape", inputs=[x, target_shape])
+    inst_out = pat.add("inst", "InstanceNormalization", inputs=[reshape_in, inst_scale, inst_bias])
+    shape_out = pat.add("shape", "Shape", inputs=[x])
+    reshape_back = pat.add("post", "Reshape", inputs=[inst_out, shape_out])
+    mul_out = pat.add("mul", "Mul", inputs=[reshape_back, gamma])
+    add_out = pat.add("add", "Add", inputs=[mul_out, beta])
+
+    pat.set_output_tensors([add_out])
+    return pat
+
+
+def fold_groupnorm(graph):
+    pattern = make_legacy_groupnorm_pattern()
+    matches = pattern.match_all(graph)
+
+    folded = 0
+    for match in matches:
+        pre = match.get("pre")
+        inst = match.get("inst")
+        mul = match.get("mul")
+        addn = match.get("add")
+
+        # The first Reshape's target must be a constant of the form
+        # [N, num_groups, -1]; num_groups lives at index 1.
+        target = pre.inputs[1]
+        if not isinstance(target, gs.Constant):
+            continue
+        target_shape = target.values.tolist()
+        if len(target_shape) < 2 or int(target_shape[1]) <= 0:
+            continue
+        num_groups = int(target_shape[1])
+
+        # Lift the per-channel weights out of the trailing Mul/Add. This
+        # assumes gamma/beta are the second inputs, as the exporter emits them.
+        gamma_t = mul.inputs[1]
+        beta_t = addn.inputs[1]
+        if not isinstance(gamma_t, gs.Constant) or not isinstance(beta_t, gs.Constant):
+            continue
+        gamma = to_1d(gamma_t.values).astype(np.float32)
+        beta = to_1d(beta_t.values).astype(np.float32)
+        if gamma.shape != beta.shape or gamma.size % num_groups != 0:
+            continue
+
+        epsilon = inst.attrs.get("epsilon", 1e-5)
+
+        # Disconnect the original output tensor from its old producer, then
+        # route it through a single GroupNormalization node instead.
+        x_tensor = pre.inputs[0]
+        gn_out = addn.outputs[0]
+        gn_out.inputs.clear()
+
+        graph.nodes.append(
+            gs.Node(
+                op="GroupNormalization",
+                name=inst.name + "_folded",
+                attrs={"num_groups": num_groups, "epsilon": float(epsilon)},
+                inputs=[
+                    x_tensor,
+                    gs.Constant(name=inst.name + "_gn_scale", values=gamma),
+                    gs.Constant(name=inst.name + "_gn_bias", values=beta),
+                ],
+                outputs=[gn_out],
+            )
+        )
+        folded += 1
+
+    if folded:
+        # Remove the now-orphaned reshapes, the Shape op, and dangling constants.
+        graph.cleanup().toposort()
+    return folded
+
+
+def bump_opset(model, min_version=21):
+    # GroupNormalization with per-channel scale/bias requires opset 21. A
+    # blanket bump like this is safe only if the ops remaining in the graph
+    # are forward compatible with the new opset.
+    for op in model.opset_import:
+        if op.domain in ("", "ai.onnx") and op.version < min_version:
+            op.version = min_version
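+
+
+def assert_folded(graph):
+    # Optional structural check, included here purely for illustration (it is
+    # not part of the folding logic and main() does not call it): after a
+    # successful fold, the legacy decomposition should be gone entirely and
+    # at least one native GroupNormalization node should remain.
+    ops = [node.op for node in graph.nodes]
+    assert "InstanceNormalization" not in ops
+    assert "GroupNormalization" in ops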
+
+
+def main():
+    ap = argparse.ArgumentParser()
+    ap.add_argument("input", nargs="?", default="model.onnx")
+    ap.add_argument("output", nargs="?", default="folded.onnx")
+    args = ap.parse_args()
+
+    graph = gs.import_onnx(onnx.load(args.input))
+    n = fold_groupnorm(graph)
+    print(f"folded {n} pattern(s)")
+
+    model = gs.export_onnx(graph)
+    bump_opset(model, 21)
+    onnx.save(model, args.output)
+    print(f"wrote {args.output}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/generate.py b/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/generate.py
new file mode 100644
index 000000000..242da4cd2
--- /dev/null
+++ b/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/generate.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import onnx
+import onnx_graphsurgeon as gs
+
+
+N, C, H, W = 1, 8, 4, 4
+NUM_GROUPS = 4
+
+graph = gs.Graph(ir_version=10, opset=17)
+
+x = gs.Variable("input", dtype=np.float32, shape=(N, C, H, W))
+graph.inputs = [x]
+
+# First Reshape: fold the channel/spatial dims into [N, num_groups, -1] so
+# that InstanceNormalization normalizes one group per "channel".
+target_shape = gs.Constant("reshape_target", values=np.array([N, NUM_GROUPS, -1], dtype=np.int64))
+reshape_in = gs.Variable("reshape_in_out", dtype=np.float32)
+graph.nodes.append(gs.Node(op="Reshape", inputs=[x, target_shape], outputs=[reshape_in]))
+
+# InstanceNormalization with identity scale/bias; the learned weights are
+# applied afterwards by the Mul/Add pair.
+inst_scale = gs.Constant("inst_scale", values=np.ones((NUM_GROUPS,), dtype=np.float32))
+inst_bias = gs.Constant("inst_bias", values=np.zeros((NUM_GROUPS,), dtype=np.float32))
+inst_out = gs.Variable("inst_out", dtype=np.float32)
+graph.nodes.append(
+    gs.Node(
+        op="InstanceNormalization",
+        attrs={"epsilon": 1e-5},
+        inputs=[reshape_in, inst_scale, inst_bias],
+        outputs=[inst_out],
+    )
+)
+
+# Second Reshape: restore the original layout, reading the target shape from
+# a Shape op fed by the graph input.
+shape_out = gs.Variable("shape_out", dtype=np.int64)
+graph.nodes.append(gs.Node(op="Shape", inputs=[x], outputs=[shape_out]))
+
+reshape_back_out = gs.Variable("reshape_back_out", dtype=np.float32)
+graph.nodes.append(gs.Node(op="Reshape", inputs=[inst_out, shape_out], outputs=[reshape_back_out]))
+
+# Learned per-channel gamma/beta, shaped (1, C, 1, 1) for broadcasting.
+gamma = gs.Constant("gamma", values=np.random.rand(C).astype(np.float32).reshape(1, C, 1, 1))
+beta = gs.Constant("beta", values=np.random.rand(C).astype(np.float32).reshape(1, C, 1, 1))
+mul_out = gs.Variable("mul_out", dtype=np.float32)
+add_out = gs.Variable("output", dtype=np.float32, shape=(N, C, H, W))
+graph.nodes.append(gs.Node(op="Mul", inputs=[reshape_back_out, gamma], outputs=[mul_out]))
+graph.nodes.append(gs.Node(op="Add", inputs=[mul_out, beta], outputs=[add_out]))
+
+graph.outputs = [add_out]
+onnx.save(gs.export_onnx(graph), "model.onnx")
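+
+# Optional: validate the serialized model. A minimal sanity check, not part
+# of the original generator; onnx.checker.check_model accepts a file path as
+# well as a ModelProto.
+onnx.checker.check_model("model.onnx")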