Skip to content

[mlir][math] Add constant folding for math.fpowi#193761

Open
CoTinker wants to merge 2 commits intollvm:mainfrom
CoTinker:fpowi
Open

[mlir][math] Add constant folding for math.fpowi#193761
CoTinker wants to merge 2 commits intollvm:mainfrom
CoTinker:fpowi

Conversation

@CoTinker
Copy link
Copy Markdown
Contributor

Adds a constant folder for math.fpowi when both operands are constant and the integer exponent is exactly representable in the floating-point type of the base.

Adds a constant folder for math.fpowi when both operands are constant and the integer exponent is exactly representable in the floating-point type of the base.
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 23, 2026

@llvm/pr-subscribers-mlir-math

Author: Longsheng Mou (CoTinker)

Changes

Adds a constant folder for math.fpowi when both operands are constant and the integer exponent is exactly representable in the floating-point type of the base.


Full diff: https://github.com/llvm/llvm-project/pull/193761.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/CommonFolders.h (+52-42)
  • (modified) mlir/include/mlir/Dialect/Math/IR/MathOps.td (+2-4)
  • (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+28)
  • (modified) mlir/test/Dialect/Math/canonicalize.mlir (+33)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+2-1)
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 113765157946d..736b16ed25d44 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -36,13 +36,14 @@ class PoisonAttr;
 /// Uses `resultType` for the type of the returned attribute.
 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
 /// which will be directly propagated to result.
-template <class AttrElementT, //
-          class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+          class LElementValueT = typename LAttrElementT::ValueType,
+          class RElementValueT = typename RAttrElementT::ValueType,
           class PoisonAttr = ub::PoisonAttr,
-          class ResultAttrElementT = AttrElementT,
+          class ResultAttrElementT = LAttrElementT,
           class ResultElementValueT = typename ResultAttrElementT::ValueType,
-          class CalculationT = function_ref<
-              std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+          class CalculationT = function_ref<std::optional<ResultElementValueT>(
+              LElementValueT, RElementValueT)>>
 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
                                        Type resultType,
                                        CalculationT &&calculate) {
@@ -62,11 +63,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
   if (!resultType || !operands[0] || !operands[1])
     return {};
 
-  if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
-    auto lhs = cast<AttrElementT>(operands[0]);
-    auto rhs = cast<AttrElementT>(operands[1]);
-    if (lhs.getType() != rhs.getType())
-      return {};
+  if (isa<LAttrElementT>(operands[0]) && isa<RAttrElementT>(operands[1])) {
+    auto lhs = cast<LAttrElementT>(operands[0]);
+    auto rhs = cast<RAttrElementT>(operands[1]);
+    if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+      if (lhs.getType() != rhs.getType())
+        return {};
 
     auto calRes = calculate(lhs.getValue(), rhs.getValue());
 
@@ -82,11 +84,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     // just fold based on the splat value.
     auto lhs = cast<SplatElementsAttr>(operands[0]);
     auto rhs = cast<SplatElementsAttr>(operands[1]);
-    if (lhs.getType() != rhs.getType())
-      return {};
+    if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+      if (lhs.getType() != rhs.getType())
+        return {};
 
-    auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
-                                   rhs.getSplatValue<ElementValueT>());
+    auto elementResult = calculate(lhs.getSplatValue<LElementValueT>(),
+                                   rhs.getSplatValue<RElementValueT>());
     if (!elementResult)
       return {};
 
@@ -98,11 +101,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     // expanding the values.
     auto lhs = cast<ElementsAttr>(operands[0]);
     auto rhs = cast<ElementsAttr>(operands[1]);
-    if (lhs.getType() != rhs.getType())
-      return {};
+    if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+      if (lhs.getType() != rhs.getType())
+        return {};
 
-    auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
-    auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
+    auto maybeLhsIt = lhs.try_value_begin<LElementValueT>();
+    auto maybeRhsIt = rhs.try_value_begin<RElementValueT>();
     if (!maybeLhsIt || !maybeRhsIt)
       return {};
     auto lhsIt = *maybeLhsIt;
@@ -127,13 +131,14 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
 /// attribute.
 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
 /// which will be directly propagated to result.
-template <class AttrElementT, //
-          class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+          class LElementValueT = typename LAttrElementT::ValueType,
+          class RElementValueT = typename RAttrElementT::ValueType,
           class PoisonAttr = ub::PoisonAttr,
-          class ResultAttrElementT = AttrElementT,
+          class ResultAttrElementT = LAttrElementT,
           class ResultElementValueT = typename ResultAttrElementT::ValueType,
-          class CalculationT = function_ref<
-              std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+          class CalculationT = function_ref<std::optional<ResultElementValueT>(
+              LElementValueT, RElementValueT)>>
 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
                                        CalculationT &&calculate) {
   assert(operands.size() == 2 && "binary op takes two operands");
@@ -159,44 +164,49 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
   Type rhsType = getAttrType(operands[1]);
   if (!lhsType || !rhsType)
     return {};
-  if (lhsType != rhsType)
-    return {};
+  if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+    if (lhsType != rhsType)
+      return {};
 
-  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
-                                      ResultAttrElementT, ResultElementValueT,
-                                      CalculationT>(
+  return constFoldBinaryOpConditional<
+      LAttrElementT, RAttrElementT, LElementValueT, RElementValueT, PoisonAttr,
+      ResultAttrElementT, ResultElementValueT, CalculationT>(
       operands, lhsType, std::forward<CalculationT>(calculate));
 }
 
-template <class AttrElementT,
-          class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+          class LElementValueT = typename LAttrElementT::ValueType,
+          class RElementValueT = typename RAttrElementT::ValueType,
           class PoisonAttr = void, //
-          class ResultAttrElementT = AttrElementT,
+          class ResultAttrElementT = LAttrElementT,
           class ResultElementValueT = typename ResultAttrElementT::ValueType,
           class CalculationT =
-              function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+              function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
                             CalculationT &&calculate) {
-  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
-                                      ResultAttrElementT>(
+  return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+                                      LElementValueT, RElementValueT,
+                                      PoisonAttr, ResultAttrElementT>(
       operands, resultType,
-      [&](ElementValueT a, ElementValueT b)
+      [&](LElementValueT a, RElementValueT b)
           -> std::optional<ResultElementValueT> { return calculate(a, b); });
 }
 
-template <class AttrElementT, //
-          class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+          class LElementValueT = typename LAttrElementT::ValueType,
+          class RElementValueT = typename RAttrElementT::ValueType,
           class PoisonAttr = ub::PoisonAttr,
-          class ResultAttrElementT = AttrElementT,
+          class ResultAttrElementT = LAttrElementT,
           class ResultElementValueT = typename ResultAttrElementT::ValueType,
           class CalculationT =
-              function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+              function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
                             CalculationT &&calculate) {
-  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
-                                      ResultAttrElementT>(
+  return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+                                      LElementValueT, RElementValueT,
+                                      PoisonAttr, ResultAttrElementT>(
       operands,
-      [&](ElementValueT a, ElementValueT b)
+      [&](LElementValueT a, RElementValueT b)
           -> std::optional<ResultElementValueT> { return calculate(a, b); });
 }
 
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 1265bfb18aaa2..90f3f121a16d9 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -1148,7 +1148,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
     The operation is elementwise for non-scalars, e.g.:
 
     ```mlir
-    %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32
+    %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32>
     ```
 
     The result is a vector of:
@@ -1172,9 +1172,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
   let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
                           attr-dict `:` type($lhs) `,` type($rhs) }];
 
-  // TODO: add a constant folder using pow[f] for cases, when
-  //       the power argument is exactly representable in floating
-  //       point type of the base.
+  let hasFolder = 1;
 }
 
 #endif // MATH_OPS
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4c0274ddb18a1..bb552bd253b5f 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -776,6 +776,34 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
       });
 }
 
+//===----------------------------------------------------------------------===//
+// FPowIOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::FPowIOp::fold(FoldAdaptor adaptor) {
+  return constFoldBinaryOpConditional<FloatAttr, IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APFloat &base, const APInt &exp) -> std::optional<APFloat> {
+        const llvm::fltSemantics &sem = base.getSemantics();
+        // Fold when the exponent is exactly representable in the
+        // floating-point type of the base.
+        APFloat fExp(sem);
+        if (fExp.convertFromAPInt(exp, /*isSigned=*/true,
+                                  APFloat::rmNearestTiesToEven) !=
+            APFloat::opOK)
+          return {};
+
+        switch (APFloat::getSizeInBits(sem)) {
+        case 64:
+          return APFloat(pow(base.convertToDouble(), fExp.convertToDouble()));
+        case 32:
+          return APFloat(powf(base.convertToFloat(), fExp.convertToFloat()));
+        default:
+          return {};
+        }
+      });
+}
+
 /// Materialize an integer or floating point constant.
 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 67235c38e9cdf..228faa31781c4 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -614,3 +614,36 @@ func.func @ipowi_i1_const_neg_exp() -> i1 {
   %r = math.ipowi %b, %e : i1
   return %r : i1
 }
+
+// CHECK-LABEL: @fpowi_fold
+// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f64
+// CHECK: %[[cst0:.+]] = arith.constant 4.000000e+00 : f32
+// CHECK: return %[[cst]], %[[cst0]] : f64, f32
+func.func @fpowi_fold() -> (f64, f32) {
+  %cst = arith.constant 2.000000e+00 : f64
+  %cst_0 = arith.constant 2.000000e+00 : f32
+  %c2_i32 = arith.constant 2 : i32
+  %0 = math.fpowi %cst, %c2_i32 : f64, i32
+  %1 = math.fpowi %cst_0, %c2_i32 : f32, i32
+  return %0, %1 : f64, f32
+}
+
+// CHECK-LABEL: @fpowi_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 1.600000e+01, 9.000000e+00, 1.600000e+01]> : vector<4xf32>
+// CHECK: return %[[cst]]
+func.func @fpowi_fold_vec() -> vector<4xf32> {
+  %cst = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+  %cst_0 = arith.constant dense<[2, 4, 2, 2]> : vector<4xi32>
+  %0 = math.fpowi %cst, %cst_0 : vector<4xf32>, vector<4xi32>
+  return %0 : vector<4xf32>
+}
+
+// 16777217 is not exactly representable in f32.
+// CHECK-LABEL: @fpowi_fold_failed
+// CHECK:       math.fpowi
+func.func @fpowi_fold_failed() -> f32 {
+  %cst = arith.constant 2.000000e+00 : f32
+  %c16777217_i32 = arith.constant 16777217 : i32
+  %0 = math.fpowi %cst, %c16777217_i32 : f32, i32
+  return %0 : f32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c8be4bf3f0f8d..55e72b57cfd1b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -265,7 +265,8 @@ struct FoldLessThanOpF32ToI1 : public OpRewritePattern<test::LessThanOp> {
 
     Attribute operandAttrs[2] = {lhsAttr, rhsAttr};
     TypedAttr res = cast_or_null<TypedAttr>(
-        constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void, IntegerAttr>(
+        constFoldBinaryOp<FloatAttr, FloatAttr, FloatAttr::ValueType,
+                          FloatAttr::ValueType, void, IntegerAttr>(
             operandAttrs, op.getType(), [](APFloat lhs, APFloat rhs) -> APInt {
               return APInt(1, lhs < rhs);
             }));

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 23, 2026

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

Adds a constant folder for math.fpowi when both operands are constant and the integer exponent is exactly representable in the floating-point type of the base.


Full diff: https://github.com/llvm/llvm-project/pull/193761.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/CommonFolders.h (+52-42)
  • (modified) mlir/include/mlir/Dialect/Math/IR/MathOps.td (+2-4)
  • (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+28)
  • (modified) mlir/test/Dialect/Math/canonicalize.mlir (+33)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+2-1)
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 113765157946d..736b16ed25d44 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -36,13 +36,14 @@ class PoisonAttr;
 /// Uses `resultType` for the type of the returned attribute.
 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
 /// which will be directly propagated to result.
-template <class AttrElementT, //
-          class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+          class LElementValueT = typename LAttrElementT::ValueType,
+          class RElementValueT = typename RAttrElementT::ValueType,
           class PoisonAttr = ub::PoisonAttr,
-          class ResultAttrElementT = AttrElementT,
+          class ResultAttrElementT = LAttrElementT,
           class ResultElementValueT = typename ResultAttrElementT::ValueType,
-          class CalculationT = function_ref<
-              std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+          class CalculationT = function_ref<std::optional<ResultElementValueT>(
+              LElementValueT, RElementValueT)>>
 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
                                        Type resultType,
                                        CalculationT &&calculate) {
@@ -62,11 +63,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
   if (!resultType || !operands[0] || !operands[1])
     return {};
 
-  if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
-    auto lhs = cast<AttrElementT>(operands[0]);
-    auto rhs = cast<AttrElementT>(operands[1]);
-    if (lhs.getType() != rhs.getType())
-      return {};
+  if (isa<LAttrElementT>(operands[0]) && isa<RAttrElementT>(operands[1])) {
+    auto lhs = cast<LAttrElementT>(operands[0]);
+    auto rhs = cast<RAttrElementT>(operands[1]);
+    if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+      if (lhs.getType() != rhs.getType())
+        return {};
 
     auto calRes = calculate(lhs.getValue(), rhs.getValue());
 
@@ -82,11 +84,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     // just fold based on the splat value.
     auto lhs = cast<SplatElementsAttr>(operands[0]);
     auto rhs = cast<SplatElementsAttr>(operands[1]);
-    if (lhs.getType() != rhs.getType())
-      return {};
+    if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+      if (lhs.getType() != rhs.getType())
+        return {};
 
-    auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
-                                   rhs.getSplatValue<ElementValueT>());
+    auto elementResult = calculate(lhs.getSplatValue<LElementValueT>(),
+                                   rhs.getSplatValue<RElementValueT>());
     if (!elementResult)
       return {};
 
@@ -98,11 +101,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     // expanding the values.
     auto lhs = cast<ElementsAttr>(operands[0]);
     auto rhs = cast<ElementsAttr>(operands[1]);
-    if (lhs.getType() != rhs.getType())
-      return {};
+    if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+      if (lhs.getType() != rhs.getType())
+        return {};
 
-    auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
-    auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
+    auto maybeLhsIt = lhs.try_value_begin<LElementValueT>();
+    auto maybeRhsIt = rhs.try_value_begin<RElementValueT>();
     if (!maybeLhsIt || !maybeRhsIt)
       return {};
     auto lhsIt = *maybeLhsIt;
@@ -127,13 +131,14 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
 /// attribute.
 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
 /// which will be directly propagated to result.
-template <class AttrElementT, //
-          class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+          class LElementValueT = typename LAttrElementT::ValueType,
+          class RElementValueT = typename RAttrElementT::ValueType,
           class PoisonAttr = ub::PoisonAttr,
-          class ResultAttrElementT = AttrElementT,
+          class ResultAttrElementT = LAttrElementT,
           class ResultElementValueT = typename ResultAttrElementT::ValueType,
-          class CalculationT = function_ref<
-              std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+          class CalculationT = function_ref<std::optional<ResultElementValueT>(
+              LElementValueT, RElementValueT)>>
 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
                                        CalculationT &&calculate) {
   assert(operands.size() == 2 && "binary op takes two operands");
@@ -159,44 +164,49 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
   Type rhsType = getAttrType(operands[1]);
   if (!lhsType || !rhsType)
     return {};
-  if (lhsType != rhsType)
-    return {};
+  if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+    if (lhsType != rhsType)
+      return {};
 
-  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
-                                      ResultAttrElementT, ResultElementValueT,
-                                      CalculationT>(
+  return constFoldBinaryOpConditional<
+      LAttrElementT, RAttrElementT, LElementValueT, RElementValueT, PoisonAttr,
+      ResultAttrElementT, ResultElementValueT, CalculationT>(
       operands, lhsType, std::forward<CalculationT>(calculate));
 }
 
-template <class AttrElementT,
-          class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+          class LElementValueT = typename LAttrElementT::ValueType,
+          class RElementValueT = typename RAttrElementT::ValueType,
           class PoisonAttr = void, //
-          class ResultAttrElementT = AttrElementT,
+          class ResultAttrElementT = LAttrElementT,
           class ResultElementValueT = typename ResultAttrElementT::ValueType,
           class CalculationT =
-              function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+              function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
                             CalculationT &&calculate) {
-  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
-                                      ResultAttrElementT>(
+  return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+                                      LElementValueT, RElementValueT,
+                                      PoisonAttr, ResultAttrElementT>(
       operands, resultType,
-      [&](ElementValueT a, ElementValueT b)
+      [&](LElementValueT a, RElementValueT b)
           -> std::optional<ResultElementValueT> { return calculate(a, b); });
 }
 
-template <class AttrElementT, //
-          class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+          class LElementValueT = typename LAttrElementT::ValueType,
+          class RElementValueT = typename RAttrElementT::ValueType,
           class PoisonAttr = ub::PoisonAttr,
-          class ResultAttrElementT = AttrElementT,
+          class ResultAttrElementT = LAttrElementT,
           class ResultElementValueT = typename ResultAttrElementT::ValueType,
           class CalculationT =
-              function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+              function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
                             CalculationT &&calculate) {
-  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
-                                      ResultAttrElementT>(
+  return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+                                      LElementValueT, RElementValueT,
+                                      PoisonAttr, ResultAttrElementT>(
       operands,
-      [&](ElementValueT a, ElementValueT b)
+      [&](LElementValueT a, RElementValueT b)
           -> std::optional<ResultElementValueT> { return calculate(a, b); });
 }
 
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 1265bfb18aaa2..90f3f121a16d9 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -1148,7 +1148,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
     The operation is elementwise for non-scalars, e.g.:
 
     ```mlir
-    %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32
+    %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32>
     ```
 
     The result is a vector of:
@@ -1172,9 +1172,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
   let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
                           attr-dict `:` type($lhs) `,` type($rhs) }];
 
-  // TODO: add a constant folder using pow[f] for cases, when
-  //       the power argument is exactly representable in floating
-  //       point type of the base.
+  let hasFolder = 1;
 }
 
 #endif // MATH_OPS
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4c0274ddb18a1..bb552bd253b5f 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -776,6 +776,34 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
       });
 }
 
+//===----------------------------------------------------------------------===//
+// FPowIOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::FPowIOp::fold(FoldAdaptor adaptor) {
+  return constFoldBinaryOpConditional<FloatAttr, IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APFloat &base, const APInt &exp) -> std::optional<APFloat> {
+        const llvm::fltSemantics &sem = base.getSemantics();
+        // Fold when the exponent is exactly representable in the
+        // floating-point type of the base.
+        APFloat fExp(sem);
+        if (fExp.convertFromAPInt(exp, /*isSigned=*/true,
+                                  APFloat::rmNearestTiesToEven) !=
+            APFloat::opOK)
+          return {};
+
+        switch (APFloat::getSizeInBits(sem)) {
+        case 64:
+          return APFloat(pow(base.convertToDouble(), fExp.convertToDouble()));
+        case 32:
+          return APFloat(powf(base.convertToFloat(), fExp.convertToFloat()));
+        default:
+          return {};
+        }
+      });
+}
+
 /// Materialize an integer or floating point constant.
 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 67235c38e9cdf..228faa31781c4 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -614,3 +614,36 @@ func.func @ipowi_i1_const_neg_exp() -> i1 {
   %r = math.ipowi %b, %e : i1
   return %r : i1
 }
+
+// CHECK-LABEL: @fpowi_fold
+// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f64
+// CHECK: %[[cst0:.+]] = arith.constant 4.000000e+00 : f32
+// CHECK: return %[[cst]], %[[cst0]] : f64, f32
+func.func @fpowi_fold() -> (f64, f32) {
+  %cst = arith.constant 2.000000e+00 : f64
+  %cst_0 = arith.constant 2.000000e+00 : f32
+  %c2_i32 = arith.constant 2 : i32
+  %0 = math.fpowi %cst, %c2_i32 : f64, i32
+  %1 = math.fpowi %cst_0, %c2_i32 : f32, i32
+  return %0, %1 : f64, f32
+}
+
+// CHECK-LABEL: @fpowi_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 1.600000e+01, 9.000000e+00, 1.600000e+01]> : vector<4xf32>
+// CHECK: return %[[cst]]
+func.func @fpowi_fold_vec() -> vector<4xf32> {
+  %cst = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+  %cst_0 = arith.constant dense<[2, 4, 2, 2]> : vector<4xi32>
+  %0 = math.fpowi %cst, %cst_0 : vector<4xf32>, vector<4xi32>
+  return %0 : vector<4xf32>
+}
+
+// 16777217 is not exactly representable in f32.
+// CHECK-LABEL: @fpowi_fold_failed
+// CHECK:       math.fpowi
+func.func @fpowi_fold_failed() -> f32 {
+  %cst = arith.constant 2.000000e+00 : f32
+  %c16777217_i32 = arith.constant 16777217 : i32
+  %0 = math.fpowi %cst, %c16777217_i32 : f32, i32
+  return %0 : f32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c8be4bf3f0f8d..55e72b57cfd1b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -265,7 +265,8 @@ struct FoldLessThanOpF32ToI1 : public OpRewritePattern<test::LessThanOp> {
 
     Attribute operandAttrs[2] = {lhsAttr, rhsAttr};
     TypedAttr res = cast_or_null<TypedAttr>(
-        constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void, IntegerAttr>(
+        constFoldBinaryOp<FloatAttr, FloatAttr, FloatAttr::ValueType,
+                          FloatAttr::ValueType, void, IntegerAttr>(
             operandAttrs, op.getType(), [](APFloat lhs, APFloat rhs) -> APInt {
               return APInt(1, lhs < rhs);
             }));

Comment thread mlir/lib/Dialect/Math/IR/MathOps.cpp Outdated
switch (APFloat::getSizeInBits(sem)) {
case 64:
return APFloat(pow(base.convertToDouble(), fExp.convertToDouble()));
case 32:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would check for the exact type here. E.g., if (getLhs().getType().isF32()), instead of checking only the bitwidth.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I referred to the code style in the Math folder, such as math::TruncOp::fold.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem here is that you could have multiple FP types with the same bitwidth. (Although, admittedly that's not the case for 32 and 64 bit today.)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, but the lambda function requires an APFloat type argument. Can we obtain a plain float type from APFloat?

class PoisonAttr = ub::PoisonAttr,
class ResultAttrElementT = AttrElementT,
class ResultAttrElementT = LAttrElementT,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you summarize the changes here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I support binary operations where the left and right operands have different attribute element types, for example, left is float and right is integer.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants