diff --git a/onnxscript/rewriter/rules/common/_basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py index 7d4e9d9b33..67ebdaa495 100644 --- a/onnxscript/rewriter/rules/common/_basic_rules_test.py +++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py @@ -528,7 +528,11 @@ def test_reshape_reshape_dynamic_rule(self, input_shape, shape2, allowzero2=0): # Check inference. inputs = np.random.default_rng(7).random(input_shape, dtype="float32") - testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + # Use the reference implementation to avoid ORT incorrectly folding/rewriting + # the original two-reshape model (e.g. ignoring allowzero=1). + testing.assert_numerically_equal( + model, updated_model, (inputs,), atol=0, rtol=0, use_reference=True + ) @parameterized.parameterized.expand( [ diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 4e1fcae128..4bf98f587d 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -838,7 +838,9 @@ def _where_input_wrangler( "logcumsumexp", core_ops.aten_logcumsumexp, tolerance={torch.float16: (1e-2, 1e-1)} ), TorchLibOpInfo("logdet", core_ops.aten_logdet), - TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp), + TorchLibOpInfo( + "logsumexp", core_ops.aten_logsumexp, tolerance={torch.float16: (2e-2, 1e-4)} + ), TorchLibOpInfo("lt", core_ops.aten_lt), TorchLibOpInfo("masked_fill", core_ops.aten_masked_fill).xfail( dtypes=(torch.bool,),