48 changes: 48 additions & 0 deletions tools/onnx-graphsurgeon/examples/13_folding_groupnorm/README.md
@@ -0,0 +1,48 @@
# Folding GroupNorm-via-InstanceNorm

## Introduction

PyTorch versions prior to 2.5 export `torch.nn.GroupNorm` as a sequence of
five operators when targeting ONNX opset versions below 18. The pattern is

```
x -> Reshape -> InstanceNormalization -> Reshape -> Mul -> Add -> y
```

where the `InstanceNormalization` runs with constant scale `1` and bias `0`,
and the trailing `Mul`/`Add` apply the learned per-channel `gamma`/`beta`. The
second `Reshape` reads its target shape from a `Shape` op fed by `x`.
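
To see what the subgraph computes, here is a minimal numpy sketch (illustrative
only, using the shapes from `generate.py`); the final `y` is exactly what
`GroupNormalization` produces:

```python
import numpy as np

N, C, H, W = 1, 8, 4, 4
num_groups, eps = 4, 1e-5

x = np.random.rand(N, C, H, W).astype(np.float32)
gamma = np.random.rand(C).astype(np.float32)  # learned per-channel scale
beta = np.random.rand(C).astype(np.float32)   # learned per-channel bias

# Reshape: (N, C, H, W) -> (N, num_groups, C // num_groups * H * W).
g = x.reshape(N, num_groups, -1)
# InstanceNormalization with scale=1, bias=0: zero-mean, unit-variance
# over the last axis, i.e. over each group independently.
g = (g - g.mean(-1, keepdims=True)) / np.sqrt(g.var(-1, keepdims=True) + eps)
# Reshape back (via the Shape of x) and apply the affine Mul/Add.
y = g.reshape(N, C, H, W) * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)
```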

This example demonstrates how to detect that pattern and rewrite it as a
single native `GroupNormalization` node (opset 21+). Doing so avoids the
fused-norm code path inside TensorRT, which has been observed to drift from
ONNX Runtime when reduction extents are very large or when `num_groups`
equals `num_channels`.

## Running the Example

1. Generate a small model containing the legacy pattern.

    ```bash
    python3 generate.py
    ```

2. Fold the pattern into a `GroupNormalization` op.

    ```bash
    python3 fold.py model.onnx folded.onnx
    ```

3. Inspect the resulting graph in [Netron](https://netron.app) to confirm that
   the five-op subgraph collapsed to a single `GroupNormalization`, or compare
   the two models numerically as sketched below.
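
For the model produced by `generate.py`, `fold.py` should report
`folded 1 pattern(s)`. As an optional numerical check, a sketch along these
lines compares the two models; it assumes `onnxruntime` is installed and that
your ONNX Runtime build implements the opset-21 `GroupNormalization`. The
tensor names come from `generate.py`:

```python
import numpy as np
import onnxruntime as ort

x = np.random.rand(1, 8, 4, 4).astype(np.float32)

before = ort.InferenceSession("model.onnx")
after = ort.InferenceSession("folded.onnx")

# "input" / "output" are the names assigned in generate.py.
y_before = before.run(["output"], {"input": x})[0]
y_after = after.run(["output"], {"input": x})[0]
print("max abs diff:", np.abs(y_before - y_after).max())
```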

## How It Works

`fold.py` walks every `InstanceNormalization` node in the graph and verifies
that its surrounding nodes match the legacy template, including that the
`InstanceNormalization` scale and bias are the constants `1` and `0`. When a
match passes all checks, it pulls `num_groups` from the upstream `Reshape`
constant, lifts the per-channel weights out of the trailing `Mul`/`Add`,
flattens them to 1D, and wires in a single `GroupNormalization` node that
consumes the original input tensor and produces the original output tensor.
`graph.cleanup()` then removes the now-orphaned reshapes, the `Shape` op, and
the dangling constants.
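
After folding, the dataflow reduces to the single node:

```
x -> GroupNormalization(num_groups, epsilon, scale=gamma, bias=beta) -> y
```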
128 changes: 128 additions & 0 deletions tools/onnx-graphsurgeon/examples/13_folding_groupnorm/fold.py
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse

import numpy as np
import onnx
import onnx_graphsurgeon as gs


def to_1d(arr):
    # Flatten broadcast-shaped weights, e.g. (1, C, 1, 1), to the 1-D (C,)
    # layout that GroupNormalization expects for its scale/bias inputs.
    return np.asarray(arr).reshape(-1)


def make_legacy_groupnorm_pattern():
    # Template for the legacy export:
    # x -> Reshape -> InstanceNormalization -> Reshape -> Mul -> Add -> y,
    # where a Shape op fed by x supplies the second Reshape's target.
    pat = gs.GraphPattern()

    x = pat.variable()
    target_shape = pat.variable()
    inst_scale = pat.variable()
    inst_bias = pat.variable()
    gamma = pat.variable()
    beta = pat.variable()

    reshape_in = pat.add("pre", "Reshape", inputs=[x, target_shape])
    inst_out = pat.add("inst", "InstanceNormalization", inputs=[reshape_in, inst_scale, inst_bias])
    shape_out = pat.add("shape", "Shape", inputs=[x])
    reshape_back = pat.add("post", "Reshape", inputs=[inst_out, shape_out])
    mul_out = pat.add("mul", "Mul", inputs=[reshape_back, gamma])
    add_out = pat.add("add", "Add", inputs=[mul_out, beta])

    pat.set_output_tensors([add_out])
    return pat


def fold_groupnorm(graph):
    pattern = make_legacy_groupnorm_pattern()
    matches = pattern.match_all(graph)

    folded = 0
    for match in matches:
        # Each mapping entry wraps a matched node; `.onnx_node` retrieves
        # the underlying gs.Node.
        pre = match.get("pre").onnx_node
        inst = match.get("inst").onnx_node
        mul = match.get("mul").onnx_node
        addn = match.get("add").onnx_node

        # The first Reshape carries the grouping: its constant target is
        # [N, num_groups, -1], so num_groups sits at index 1.
        target = pre.inputs[1]
        if not isinstance(target, gs.Constant):
            continue
        target_shape = target.values.tolist()
        if len(target_shape) < 2 or int(target_shape[1]) <= 0:
            continue
        num_groups = int(target_shape[1])

        # The InstanceNormalization must be a pure normalization (scale all
        # ones, bias all zeros); otherwise folding would change results.
        inst_scale, inst_bias = inst.inputs[1], inst.inputs[2]
        if not isinstance(inst_scale, gs.Constant) or not isinstance(inst_bias, gs.Constant):
            continue
        if not np.allclose(inst_scale.values, 1.0) or not np.allclose(inst_bias.values, 0.0):
            continue

        # Lift the learned per-channel affine out of the trailing Mul/Add.
        gamma_t = mul.inputs[1]
        beta_t = addn.inputs[1]
        if not isinstance(gamma_t, gs.Constant) or not isinstance(beta_t, gs.Constant):
            continue
        gamma = to_1d(gamma_t.values).astype(np.float32)
        beta = to_1d(beta_t.values).astype(np.float32)
        if gamma.shape != beta.shape or gamma.size % num_groups != 0:
            continue

        # ONNX defaults epsilon to 1e-5 when the attribute is absent.
        epsilon = inst.attrs.get("epsilon", 1e-5)

        # Detach the output tensor from the old Add so the new node can
        # produce it; cleanup() will then drop the disconnected subgraph.
        x_tensor = pre.inputs[0]
        gn_out = addn.outputs[0]
        gn_out.inputs.clear()

        graph.nodes.append(
            gs.Node(
                op="GroupNormalization",
                name=inst.name + "_folded",
                attrs={"num_groups": num_groups, "epsilon": float(epsilon)},
                inputs=[
                    x_tensor,
                    gs.Constant(name=inst.name + "_gn_scale", values=gamma),
                    gs.Constant(name=inst.name + "_gn_bias", values=beta),
                ],
                outputs=[gn_out],
            )
        )
        folded += 1

    if folded:
        graph.cleanup().toposort()
    return folded


def bump_opset(model, min_version=21):
    # GroupNormalization with per-channel scale/bias requires opset 21.
    # Note: this bumps the declared version in place without running the
    # ONNX version converter, which is fine for the ops in this example.
    for op in model.opset_import:
        if op.domain in ("", "ai.onnx") and op.version < min_version:
            op.version = min_version


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("input", nargs="?", default="model.onnx")
    ap.add_argument("output", nargs="?", default="folded.onnx")
    args = ap.parse_args()

    graph = gs.import_onnx(onnx.load(args.input))
    n = fold_groupnorm(graph)
    print(f"folded {n} pattern(s)")

    model = gs.export_onnx(graph)
    bump_opset(model, 21)
    onnx.save(model, args.output)
    print(f"wrote {args.output}")


if __name__ == "__main__":
    main()
62 changes: 62 additions & 0 deletions tools/onnx-graphsurgeon/examples/13_folding_groupnorm/generate.py
@@ -0,0 +1,62 @@
#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import onnx
import onnx_graphsurgeon as gs


N, C, H, W = 1, 8, 4, 4
NUM_GROUPS = 4

graph = gs.Graph(ir_version=10, opset=17)

x = gs.Variable("input", dtype=np.float32, shape=(N, C, H, W))
graph.inputs = [x]

# Reshape (N, C, H, W) -> (N, NUM_GROUPS, -1) so InstanceNormalization
# normalizes each group independently.
target_shape = gs.Constant("reshape_target", values=np.array([N, NUM_GROUPS, -1], dtype=np.int64))
reshape_in = gs.Variable("reshape_in_out", dtype=np.float32)
graph.nodes.append(gs.Node(op="Reshape", inputs=[x, target_shape], outputs=[reshape_in]))

# Pure normalization: constant scale = 1, bias = 0 per group.
inst_scale = gs.Constant("inst_scale", values=np.ones((NUM_GROUPS,), dtype=np.float32))
inst_bias = gs.Constant("inst_bias", values=np.zeros((NUM_GROUPS,), dtype=np.float32))
inst_out = gs.Variable("inst_out", dtype=np.float32)
graph.nodes.append(
    gs.Node(
        op="InstanceNormalization",
        attrs={"epsilon": 1e-5},
        inputs=[reshape_in, inst_scale, inst_bias],
        outputs=[inst_out],
    )
)

# Restore the original layout; the target comes from Shape(x), as in the
# PyTorch export.
shape_out = gs.Variable("shape_out", dtype=np.int64)
graph.nodes.append(gs.Node(op="Shape", inputs=[x], outputs=[shape_out]))

reshape_back_out = gs.Variable("reshape_back_out", dtype=np.float32)
graph.nodes.append(gs.Node(op="Reshape", inputs=[inst_out, shape_out], outputs=[reshape_back_out]))

# Learned per-channel affine, shaped (1, C, 1, 1) for NCHW broadcasting.
gamma = gs.Constant("gamma", values=np.random.rand(C).astype(np.float32).reshape(1, C, 1, 1))
beta = gs.Constant("beta", values=np.random.rand(C).astype(np.float32).reshape(1, C, 1, 1))
mul_out = gs.Variable("mul_out", dtype=np.float32)
add_out = gs.Variable("output", dtype=np.float32, shape=(N, C, H, W))
graph.nodes.append(gs.Node(op="Mul", inputs=[reshape_back_out, gamma], outputs=[mul_out]))
graph.nodes.append(gs.Node(op="Add", inputs=[mul_out, beta], outputs=[add_out]))

graph.outputs = [add_out]
onnx.save(gs.export_onnx(graph), "model.onnx")