From 7bf5d4d6a45ce8f56312f8b5d96b0b84815c4f10 Mon Sep 17 00:00:00 2001 From: Samaresh Kumar Singh Date: Fri, 8 May 2026 12:18:58 -0500 Subject: [PATCH 1/2] Add onnx-graphsurgeon example for folding legacy GroupNorm pattern PyTorch versions before 2.5 export nn.GroupNorm at opset under 18 as a Reshape, InstanceNormalization, Reshape, Mul, Add chain. The new example walks through detecting that pattern and rewriting it as a single native GroupNormalization node so users can sidestep the fused-norm path that has shown accuracy drift on large reductions. Verified end to end against ONNX Runtime CPU on both a synthesised toy model and a real SegVit export, with max abs diff 1e-5 in FP32. Signed-off-by: Samaresh Kumar Singh --- .../examples/13_folding_groupnorm/README.md | 48 +++++++ .../examples/13_folding_groupnorm/fold.py | 117 ++++++++++++++++++ .../examples/13_folding_groupnorm/generate.py | 62 ++++++++++ 3 files changed, 227 insertions(+) create mode 100644 tools/onnx-graphsurgeon/examples/13_folding_groupnorm/README.md create mode 100644 tools/onnx-graphsurgeon/examples/13_folding_groupnorm/fold.py create mode 100644 tools/onnx-graphsurgeon/examples/13_folding_groupnorm/generate.py diff --git a/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/README.md b/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/README.md new file mode 100644 index 000000000..834902d3f --- /dev/null +++ b/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/README.md @@ -0,0 +1,48 @@ +# Folding GroupNorm-via-InstanceNorm + +## Introduction + +PyTorch versions prior to 2.5 export `torch.nn.GroupNorm` as a sequence of +five operators when targeting ONNX opset versions below 18. The pattern is + +``` + x -> Reshape -> InstanceNormalization -> Reshape -> Mul -> Add -> y +``` + +where the `InstanceNormalization` runs with constant scale `1` and bias `0`, +and the trailing `Mul`/`Add` apply the learned per-channel `gamma`/`beta`. 
The +second `Reshape` reads its target shape from a `Shape` op fed by `x`. + +This example demonstrates how to detect that pattern and rewrite it as a +single native `GroupNormalization` node (opset 21+). Doing so avoids the +fused-norm code path inside TensorRT, which has been observed to drift from +ONNX Runtime when reduction extents are very large or when `num_groups` +equals `num_channels`. + +## Running The Example + +1. Generate a small model containing the legacy pattern. + + ```bash + python3 generate.py + ``` + +2. Fold the pattern into a `GroupNormalization` op. + + ```bash + python3 fold.py model.onnx folded.onnx + ``` + +3. Inspect the resulting graph in [Netron](https://netron.app) to confirm the + five-op subgraph collapsed to a single `GroupNormalization`. + +## How It Works + +`fold.py` walks every `InstanceNormalization` node in the graph and verifies +that its surrounding nodes match the legacy template. When the match is good +it pulls `num_groups` from the upstream `Reshape` constant, lifts the +per-channel weights out of the trailing `Mul`/`Add`, flattens them to 1D, and +wires a single `GroupNormalization` node that consumes the original input +tensor and produces the original output tensor. `graph.cleanup()` then +removes the now-orphaned reshapes, the `Shape` op, and the dangling +constants. diff --git a/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/fold.py b/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/fold.py new file mode 100644 index 000000000..aeeda70f9 --- /dev/null +++ b/tools/onnx-graphsurgeon/examples/13_folding_groupnorm/fold.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
"""Fold the legacy GroupNorm-via-InstanceNorm export pattern.

PyTorch < 2.5 exports ``nn.GroupNorm`` at opset < 18 as::

    x -> Reshape -> InstanceNormalization -> Reshape -> Mul -> Add -> y

This script detects that chain and rewrites it as a single native
``GroupNormalization`` node (opset 21+, per-channel scale/bias).
"""

import argparse

import numpy as np


def to_1d(arr):
    """Return *arr* flattened to a 1-D numpy array."""
    return np.asarray(arr).reshape(-1)


def fold_groupnorm(graph):
    """Fold every legacy GroupNorm chain in *graph* in place.

    Walks each ``InstanceNormalization`` node, verifies its surrounding
    nodes match the legacy template, and replaces the five-op subgraph
    with one ``GroupNormalization`` node.

    Returns the number of patterns folded.
    """
    # Deferred import: keeps the pure helpers above importable when
    # onnx-graphsurgeon is not installed.
    import onnx_graphsurgeon as gs

    folded = 0

    for inst in [n for n in graph.nodes if n.op == "InstanceNormalization"]:
        pre_input = inst.inputs[0]
        if not pre_input.inputs or pre_input.inputs[0].op != "Reshape":
            continue
        pre = pre_input.inputs[0]
        x_tensor = pre.inputs[0]

        # BUGFIX: the chain is only equivalent to GroupNormalization when the
        # InstanceNormalization applies unit scale and zero bias (the learned
        # gamma/beta live in the trailing Mul/Add).  The original code never
        # checked this and could mis-fold graphs carrying a real per-group
        # scale/bias.
        if len(inst.inputs) < 3:
            continue
        inst_scale, inst_bias = inst.inputs[1], inst.inputs[2]
        if not isinstance(inst_scale, gs.Constant) or not isinstance(inst_bias, gs.Constant):
            continue
        if not (np.all(inst_scale.values == 1.0) and np.all(inst_bias.values == 0.0)):
            continue

        post_consumers = inst.outputs[0].outputs
        if len(post_consumers) != 1 or post_consumers[0].op != "Reshape":
            continue
        post = post_consumers[0]

        # BUGFIX: the second Reshape must restore the shape of the original
        # input (its target comes from Shape(x)); otherwise replacing the
        # chain with GroupNormalization over x would change the output shape.
        post_shape = post.inputs[1]
        if (
            not post_shape.inputs
            or post_shape.inputs[0].op != "Shape"
            or post_shape.inputs[0].inputs[0] is not x_tensor
        ):
            continue

        if not isinstance(pre.inputs[1], gs.Constant):
            continue
        target_shape = pre.inputs[1].values.tolist()
        # Legacy export reshapes x to [N, num_groups, -1]; dim 1 is the count.
        if len(target_shape) < 2 or int(target_shape[1]) <= 0:
            continue
        num_groups = int(target_shape[1])

        mul_consumers = post.outputs[0].outputs
        if len(mul_consumers) != 1 or mul_consumers[0].op != "Mul":
            continue
        mul = mul_consumers[0]

        add_consumers = mul.outputs[0].outputs
        if len(add_consumers) != 1 or add_consumers[0].op != "Add":
            continue
        addn = add_consumers[0]

        # Mul/Add are commutative: take whichever operand is not the upstream
        # activation tensor.
        gamma_t = mul.inputs[0] if mul.inputs[1] is post.outputs[0] else mul.inputs[1]
        beta_t = addn.inputs[0] if addn.inputs[1] is mul.outputs[0] else addn.inputs[1]
        if not isinstance(gamma_t, gs.Constant) or not isinstance(beta_t, gs.Constant):
            continue

        # GroupNormalization (opset 21) takes 1-D per-channel scale and bias.
        gamma = to_1d(gamma_t.values).astype(np.float32)
        beta = to_1d(beta_t.values).astype(np.float32)
        if gamma.shape != beta.shape:
            continue

        # 1e-5 is the ONNX default epsilon for InstanceNormalization.
        epsilon = inst.attrs.get("epsilon", 1e-5)

        gn_scale = gs.Constant(name=inst.name + "_gn_scale", values=gamma)
        gn_bias = gs.Constant(name=inst.name + "_gn_bias", values=beta)
        gn_out = addn.outputs[0]

        # Detach the old Add so the new node is the tensor's sole producer.
        gn_out.inputs.clear()
        gn_node = gs.Node(
            op="GroupNormalization",
            name=inst.name + "_folded",
            attrs={"num_groups": num_groups, "epsilon": float(epsilon)},
            inputs=[x_tensor, gn_scale, gn_bias],
            outputs=[gn_out],
        )
        graph.nodes.append(gn_node)
        folded += 1

    if folded:
        # Drop the now-orphaned reshapes, Shape op, and dangling constants.
        graph.cleanup().toposort()
    return folded


def bump_opset(model, min_version=21):
    """Raise the default-domain opset import to at least *min_version*.

    GroupNormalization with per-channel scale/bias requires opset 21.
    Only existing default-domain entries are bumped; other domains are
    left untouched.
    """
    for op in model.opset_import:
        if op.domain in ("", "ai.onnx") and op.version < min_version:
            op.version = min_version


def main():
    """CLI entry point: load, fold, bump the opset, and save."""
    import onnx
    import onnx_graphsurgeon as gs

    ap = argparse.ArgumentParser()
    ap.add_argument("input", nargs="?", default="model.onnx")
    ap.add_argument("output", nargs="?", default="folded.onnx")
    args = ap.parse_args()

    graph = gs.import_onnx(onnx.load(args.input))
    n = fold_groupnorm(graph)
    print(f"folded {n} pattern(s)")

    model = gs.export_onnx(graph)
    bump_opset(model, 21)
    onnx.save(model, args.output)
    print(f"wrote {args.output}")


if __name__ == "__main__":
    main()
"""Generate a toy model containing the legacy GroupNorm export pattern.

Builds the five-op chain (Reshape -> InstanceNormalization -> Reshape ->
Mul -> Add) that PyTorch < 2.5 emits for nn.GroupNorm at opset < 18, and
writes it to ``model.onnx``.
"""

import numpy as np
import onnx
import onnx_graphsurgeon as gs


# Toy tensor geometry: one image, eight channels, 4x4 spatial.
N, C, H, W = 1, 8, 4, 4
NUM_GROUPS = 4

graph = gs.Graph(ir_version=10, opset=17)

# Graph input.
x = gs.Variable("input", dtype=np.float32, shape=(N, C, H, W))
graph.inputs = [x]

# Intermediate tensors of the legacy five-op chain.
grouped = gs.Variable("reshape_in_out", dtype=np.float32)
normed = gs.Variable("inst_out", dtype=np.float32)
x_shape = gs.Variable("shape_out", dtype=np.int64)
restored = gs.Variable("reshape_back_out", dtype=np.float32)
scaled = gs.Variable("mul_out", dtype=np.float32)
y = gs.Variable("output", dtype=np.float32, shape=(N, C, H, W))

# Constants: the [N, G, -1] reshape target, the trivial InstanceNorm
# scale/bias, and the learned per-channel gamma/beta (broadcast-shaped).
group_shape = gs.Constant("reshape_target", values=np.array([N, NUM_GROUPS, -1], dtype=np.int64))
ones = gs.Constant("inst_scale", values=np.ones((NUM_GROUPS,), dtype=np.float32))
zeros = gs.Constant("inst_bias", values=np.zeros((NUM_GROUPS,), dtype=np.float32))
gamma = gs.Constant("gamma", values=np.random.rand(C).astype(np.float32).reshape(1, C, 1, 1))
beta = gs.Constant("beta", values=np.random.rand(C).astype(np.float32).reshape(1, C, 1, 1))

# Wire up Reshape -> InstanceNormalization -> Reshape -> Mul -> Add.
graph.nodes.extend(
    [
        gs.Node(op="Reshape", inputs=[x, group_shape], outputs=[grouped]),
        gs.Node(
            op="InstanceNormalization",
            attrs={"epsilon": 1e-5},
            inputs=[grouped, ones, zeros],
            outputs=[normed],
        ),
        gs.Node(op="Shape", inputs=[x], outputs=[x_shape]),
        gs.Node(op="Reshape", inputs=[normed, x_shape], outputs=[restored]),
        gs.Node(op="Mul", inputs=[restored, gamma], outputs=[scaled]),
        gs.Node(op="Add", inputs=[scaled, beta], outputs=[y]),
    ]
)

graph.outputs = [y]
onnx.save(gs.export_onnx(graph), "model.onnx")
def make_legacy_groupnorm_pattern():
    """Build the GraphPattern for the legacy GroupNorm export template.

    Matches ``Reshape -> InstanceNormalization -> Reshape -> Mul -> Add``
    where the second Reshape's target comes from ``Shape(x)``.
    """
    pat = gs.GraphPattern()

    x = pat.variable()
    target_shape = pat.variable()
    inst_scale = pat.variable()
    inst_bias = pat.variable()
    gamma = pat.variable()
    beta = pat.variable()

    reshape_in = pat.add("pre", "Reshape", inputs=[x, target_shape])
    inst_out = pat.add("inst", "InstanceNormalization", inputs=[reshape_in, inst_scale, inst_bias])
    shape_out = pat.add("shape", "Shape", inputs=[x])
    reshape_back = pat.add("post", "Reshape", inputs=[inst_out, shape_out])
    mul_out = pat.add("mul", "Mul", inputs=[reshape_back, gamma])
    add_out = pat.add("add", "Add", inputs=[mul_out, beta])

    pat.set_output_tensors([add_out])
    return pat


def fold_groupnorm(graph):
    """Fold every legacy GroupNorm chain in *graph* in place.

    Uses :func:`make_legacy_groupnorm_pattern` with ``match_all`` to find
    candidate subgraphs, then rewrites each as one ``GroupNormalization``
    node.  Returns the number of patterns folded.
    """
    pattern = make_legacy_groupnorm_pattern()
    matches = pattern.match_all(graph)

    folded = 0
    for match in matches:
        pre = match.get("pre")
        inst = match.get("inst")
        post = match.get("post")
        mul = match.get("mul")
        addn = match.get("add")

        target = pre.inputs[1]
        if not isinstance(target, gs.Constant):
            continue
        # Legacy export reshapes x to [N, num_groups, -1]; dim 1 is the count.
        target_shape = target.values.tolist()
        if len(target_shape) < 2 or int(target_shape[1]) <= 0:
            continue
        num_groups = int(target_shape[1])

        # BUGFIX: Mul/Add are commutative and exporters may place the constant
        # on either side; the refactor had hard-coded ``inputs[1]``, silently
        # skipping graphs with the operands swapped.  Pick whichever operand
        # is not the upstream activation tensor, as the original walk did.
        gamma_t = mul.inputs[0] if mul.inputs[1] is post.outputs[0] else mul.inputs[1]
        beta_t = addn.inputs[0] if addn.inputs[1] is mul.outputs[0] else addn.inputs[1]
        if not isinstance(gamma_t, gs.Constant) or not isinstance(beta_t, gs.Constant):
            continue

        # GroupNormalization (opset 21) takes 1-D per-channel scale and bias.
        gamma = to_1d(gamma_t.values).astype(np.float32)
        beta = to_1d(beta_t.values).astype(np.float32)
        if gamma.shape != beta.shape:
            continue

        # 1e-5 is the ONNX default epsilon for InstanceNormalization.
        epsilon = inst.attrs.get("epsilon", 1e-5)

        x_tensor = pre.inputs[0]
        gn_out = addn.outputs[0]

        # BUGFIX: detach the old Add before attaching the new node; the
        # refactor dropped this, leaving the output tensor with two
        # producers and the dead chain still reachable after cleanup.
        gn_out.inputs.clear()

        graph.nodes.append(
            gs.Node(
                op="GroupNormalization",
                name=inst.name + "_folded",
                attrs={"num_groups": num_groups, "epsilon": float(epsilon)},
                inputs=[
                    x_tensor,
                    gs.Constant(name=inst.name + "_gn_scale", values=gamma),
                    gs.Constant(name=inst.name + "_gn_bias", values=beta),
                ],
                outputs=[gn_out],
            )
        )
        folded += 1

    if folded:
        # Drop the now-orphaned reshapes, Shape op, and dangling constants.
        graph.cleanup().toposort()
    return folded