[mlir][math] Add constant folding for math.fpowi#193761
[mlir][math] Add constant folding for math.fpowi#193761
math.fpowi#193761Conversation
Adds a constant folder for math.fpowi when both operands are constant and the integer exponent is exactly representable in the floating-point type of the base.
|
@llvm/pr-subscribers-mlir-math Author: Longsheng Mou (CoTinker) ChangesAdds a constant folder for Full diff: https://github.com/llvm/llvm-project/pull/193761.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 113765157946d..736b16ed25d44 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -36,13 +36,14 @@ class PoisonAttr;
/// Uses `resultType` for the type of the returned attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
-template <class AttrElementT, //
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
- class CalculationT = function_ref<
- std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+ class CalculationT = function_ref<std::optional<ResultElementValueT>(
+ LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
Type resultType,
CalculationT &&calculate) {
@@ -62,11 +63,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (!resultType || !operands[0] || !operands[1])
return {};
- if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
- auto lhs = cast<AttrElementT>(operands[0]);
- auto rhs = cast<AttrElementT>(operands[1]);
- if (lhs.getType() != rhs.getType())
- return {};
+ if (isa<LAttrElementT>(operands[0]) && isa<RAttrElementT>(operands[1])) {
+ auto lhs = cast<LAttrElementT>(operands[0]);
+ auto rhs = cast<RAttrElementT>(operands[1]);
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhs.getType() != rhs.getType())
+ return {};
auto calRes = calculate(lhs.getValue(), rhs.getValue());
@@ -82,11 +84,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
// just fold based on the splat value.
auto lhs = cast<SplatElementsAttr>(operands[0]);
auto rhs = cast<SplatElementsAttr>(operands[1]);
- if (lhs.getType() != rhs.getType())
- return {};
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhs.getType() != rhs.getType())
+ return {};
- auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
- rhs.getSplatValue<ElementValueT>());
+ auto elementResult = calculate(lhs.getSplatValue<LElementValueT>(),
+ rhs.getSplatValue<RElementValueT>());
if (!elementResult)
return {};
@@ -98,11 +101,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
// expanding the values.
auto lhs = cast<ElementsAttr>(operands[0]);
auto rhs = cast<ElementsAttr>(operands[1]);
- if (lhs.getType() != rhs.getType())
- return {};
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhs.getType() != rhs.getType())
+ return {};
- auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
- auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
+ auto maybeLhsIt = lhs.try_value_begin<LElementValueT>();
+ auto maybeRhsIt = rhs.try_value_begin<RElementValueT>();
if (!maybeLhsIt || !maybeRhsIt)
return {};
auto lhsIt = *maybeLhsIt;
@@ -127,13 +131,14 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
/// attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
-template <class AttrElementT, //
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
- class CalculationT = function_ref<
- std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+ class CalculationT = function_ref<std::optional<ResultElementValueT>(
+ LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
@@ -159,44 +164,49 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
Type rhsType = getAttrType(operands[1]);
if (!lhsType || !rhsType)
return {};
- if (lhsType != rhsType)
- return {};
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhsType != rhsType)
+ return {};
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
- ResultAttrElementT, ResultElementValueT,
- CalculationT>(
+ return constFoldBinaryOpConditional<
+ LAttrElementT, RAttrElementT, LElementValueT, RElementValueT, PoisonAttr,
+ ResultAttrElementT, ResultElementValueT, CalculationT>(
operands, lhsType, std::forward<CalculationT>(calculate));
}
-template <class AttrElementT,
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = void, //
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT =
- function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+ function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
CalculationT &&calculate) {
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
- ResultAttrElementT>(
+ return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+ LElementValueT, RElementValueT,
+ PoisonAttr, ResultAttrElementT>(
operands, resultType,
- [&](ElementValueT a, ElementValueT b)
+ [&](LElementValueT a, RElementValueT b)
-> std::optional<ResultElementValueT> { return calculate(a, b); });
}
-template <class AttrElementT, //
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT =
- function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+ function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
- ResultAttrElementT>(
+ return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+ LElementValueT, RElementValueT,
+ PoisonAttr, ResultAttrElementT>(
operands,
- [&](ElementValueT a, ElementValueT b)
+ [&](LElementValueT a, RElementValueT b)
-> std::optional<ResultElementValueT> { return calculate(a, b); });
}
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 1265bfb18aaa2..90f3f121a16d9 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -1148,7 +1148,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
The operation is elementwise for non-scalars, e.g.:
```mlir
- %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32
+ %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32>
```
The result is a vector of:
@@ -1172,9 +1172,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
attr-dict `:` type($lhs) `,` type($rhs) }];
- // TODO: add a constant folder using pow[f] for cases, when
- // the power argument is exactly representable in floating
- // point type of the base.
+ let hasFolder = 1;
}
#endif // MATH_OPS
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4c0274ddb18a1..bb552bd253b5f 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -776,6 +776,34 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// FPowIOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::FPowIOp::fold(FoldAdaptor adaptor) {
+ return constFoldBinaryOpConditional<FloatAttr, IntegerAttr>(
+ adaptor.getOperands(),
+ [](const APFloat &base, const APInt &exp) -> std::optional<APFloat> {
+ const llvm::fltSemantics &sem = base.getSemantics();
+ // Fold when the exponent is exactly representable in the
+ // floating-point type of the base.
+ APFloat fExp(sem);
+ if (fExp.convertFromAPInt(exp, /*isSigned=*/true,
+ APFloat::rmNearestTiesToEven) !=
+ APFloat::opOK)
+ return {};
+
+ switch (APFloat::getSizeInBits(sem)) {
+ case 64:
+ return APFloat(pow(base.convertToDouble(), fExp.convertToDouble()));
+ case 32:
+ return APFloat(powf(base.convertToFloat(), fExp.convertToFloat()));
+ default:
+ return {};
+ }
+ });
+}
+
/// Materialize an integer or floating point constant.
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 67235c38e9cdf..228faa31781c4 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -614,3 +614,36 @@ func.func @ipowi_i1_const_neg_exp() -> i1 {
%r = math.ipowi %b, %e : i1
return %r : i1
}
+
+// CHECK-LABEL: @fpowi_fold
+// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f64
+// CHECK: %[[cst0:.+]] = arith.constant 4.000000e+00 : f32
+// CHECK: return %[[cst]], %[[cst0]] : f64, f32
+func.func @fpowi_fold() -> (f64, f32) {
+ %cst = arith.constant 2.000000e+00 : f64
+ %cst_0 = arith.constant 2.000000e+00 : f32
+ %c2_i32 = arith.constant 2 : i32
+ %0 = math.fpowi %cst, %c2_i32 : f64, i32
+ %1 = math.fpowi %cst_0, %c2_i32 : f32, i32
+ return %0, %1 : f64, f32
+}
+
+// CHECK-LABEL: @fpowi_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 1.600000e+01, 9.000000e+00, 1.600000e+01]> : vector<4xf32>
+// CHECK: return %[[cst]]
+func.func @fpowi_fold_vec() -> vector<4xf32> {
+ %cst = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+ %cst_0 = arith.constant dense<[2, 4, 2, 2]> : vector<4xi32>
+ %0 = math.fpowi %cst, %cst_0 : vector<4xf32>, vector<4xi32>
+ return %0 : vector<4xf32>
+}
+
+// 16777217 is not exactly representable in f32.
+// CHECK-LABEL: @fpowi_fold_failed
+// CHECK: math.fpowi
+func.func @fpowi_fold_failed() -> f32 {
+ %cst = arith.constant 2.000000e+00 : f32
+ %c16777217_i32 = arith.constant 16777217 : i32
+ %0 = math.fpowi %cst, %c16777217_i32 : f32, i32
+ return %0 : f32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c8be4bf3f0f8d..55e72b57cfd1b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -265,7 +265,8 @@ struct FoldLessThanOpF32ToI1 : public OpRewritePattern<test::LessThanOp> {
Attribute operandAttrs[2] = {lhsAttr, rhsAttr};
TypedAttr res = cast_or_null<TypedAttr>(
- constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void, IntegerAttr>(
+ constFoldBinaryOp<FloatAttr, FloatAttr, FloatAttr::ValueType,
+ FloatAttr::ValueType, void, IntegerAttr>(
operandAttrs, op.getType(), [](APFloat lhs, APFloat rhs) -> APInt {
return APInt(1, lhs < rhs);
}));
|
|
@llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) ChangesAdds a constant folder for Full diff: https://github.com/llvm/llvm-project/pull/193761.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 113765157946d..736b16ed25d44 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -36,13 +36,14 @@ class PoisonAttr;
/// Uses `resultType` for the type of the returned attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
-template <class AttrElementT, //
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
- class CalculationT = function_ref<
- std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+ class CalculationT = function_ref<std::optional<ResultElementValueT>(
+ LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
Type resultType,
CalculationT &&calculate) {
@@ -62,11 +63,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (!resultType || !operands[0] || !operands[1])
return {};
- if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
- auto lhs = cast<AttrElementT>(operands[0]);
- auto rhs = cast<AttrElementT>(operands[1]);
- if (lhs.getType() != rhs.getType())
- return {};
+ if (isa<LAttrElementT>(operands[0]) && isa<RAttrElementT>(operands[1])) {
+ auto lhs = cast<LAttrElementT>(operands[0]);
+ auto rhs = cast<RAttrElementT>(operands[1]);
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhs.getType() != rhs.getType())
+ return {};
auto calRes = calculate(lhs.getValue(), rhs.getValue());
@@ -82,11 +84,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
// just fold based on the splat value.
auto lhs = cast<SplatElementsAttr>(operands[0]);
auto rhs = cast<SplatElementsAttr>(operands[1]);
- if (lhs.getType() != rhs.getType())
- return {};
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhs.getType() != rhs.getType())
+ return {};
- auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
- rhs.getSplatValue<ElementValueT>());
+ auto elementResult = calculate(lhs.getSplatValue<LElementValueT>(),
+ rhs.getSplatValue<RElementValueT>());
if (!elementResult)
return {};
@@ -98,11 +101,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
// expanding the values.
auto lhs = cast<ElementsAttr>(operands[0]);
auto rhs = cast<ElementsAttr>(operands[1]);
- if (lhs.getType() != rhs.getType())
- return {};
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhs.getType() != rhs.getType())
+ return {};
- auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
- auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
+ auto maybeLhsIt = lhs.try_value_begin<LElementValueT>();
+ auto maybeRhsIt = rhs.try_value_begin<RElementValueT>();
if (!maybeLhsIt || !maybeRhsIt)
return {};
auto lhsIt = *maybeLhsIt;
@@ -127,13 +131,14 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
/// attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
-template <class AttrElementT, //
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
- class CalculationT = function_ref<
- std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+ class CalculationT = function_ref<std::optional<ResultElementValueT>(
+ LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
@@ -159,44 +164,49 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
Type rhsType = getAttrType(operands[1]);
if (!lhsType || !rhsType)
return {};
- if (lhsType != rhsType)
- return {};
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhsType != rhsType)
+ return {};
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
- ResultAttrElementT, ResultElementValueT,
- CalculationT>(
+ return constFoldBinaryOpConditional<
+ LAttrElementT, RAttrElementT, LElementValueT, RElementValueT, PoisonAttr,
+ ResultAttrElementT, ResultElementValueT, CalculationT>(
operands, lhsType, std::forward<CalculationT>(calculate));
}
-template <class AttrElementT,
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = void, //
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT =
- function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+ function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
CalculationT &&calculate) {
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
- ResultAttrElementT>(
+ return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+ LElementValueT, RElementValueT,
+ PoisonAttr, ResultAttrElementT>(
operands, resultType,
- [&](ElementValueT a, ElementValueT b)
+ [&](LElementValueT a, RElementValueT b)
-> std::optional<ResultElementValueT> { return calculate(a, b); });
}
-template <class AttrElementT, //
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT =
- function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+ function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
- ResultAttrElementT>(
+ return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+ LElementValueT, RElementValueT,
+ PoisonAttr, ResultAttrElementT>(
operands,
- [&](ElementValueT a, ElementValueT b)
+ [&](LElementValueT a, RElementValueT b)
-> std::optional<ResultElementValueT> { return calculate(a, b); });
}
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 1265bfb18aaa2..90f3f121a16d9 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -1148,7 +1148,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
The operation is elementwise for non-scalars, e.g.:
```mlir
- %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32
+ %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32>
```
The result is a vector of:
@@ -1172,9 +1172,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
attr-dict `:` type($lhs) `,` type($rhs) }];
- // TODO: add a constant folder using pow[f] for cases, when
- // the power argument is exactly representable in floating
- // point type of the base.
+ let hasFolder = 1;
}
#endif // MATH_OPS
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4c0274ddb18a1..bb552bd253b5f 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -776,6 +776,34 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// FPowIOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::FPowIOp::fold(FoldAdaptor adaptor) {
+ return constFoldBinaryOpConditional<FloatAttr, IntegerAttr>(
+ adaptor.getOperands(),
+ [](const APFloat &base, const APInt &exp) -> std::optional<APFloat> {
+ const llvm::fltSemantics &sem = base.getSemantics();
+ // Fold when the exponent is exactly representable in the
+ // floating-point type of the base.
+ APFloat fExp(sem);
+ if (fExp.convertFromAPInt(exp, /*isSigned=*/true,
+ APFloat::rmNearestTiesToEven) !=
+ APFloat::opOK)
+ return {};
+
+ switch (APFloat::getSizeInBits(sem)) {
+ case 64:
+ return APFloat(pow(base.convertToDouble(), fExp.convertToDouble()));
+ case 32:
+ return APFloat(powf(base.convertToFloat(), fExp.convertToFloat()));
+ default:
+ return {};
+ }
+ });
+}
+
/// Materialize an integer or floating point constant.
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 67235c38e9cdf..228faa31781c4 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -614,3 +614,36 @@ func.func @ipowi_i1_const_neg_exp() -> i1 {
%r = math.ipowi %b, %e : i1
return %r : i1
}
+
+// CHECK-LABEL: @fpowi_fold
+// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f64
+// CHECK: %[[cst0:.+]] = arith.constant 4.000000e+00 : f32
+// CHECK: return %[[cst]], %[[cst0]] : f64, f32
+func.func @fpowi_fold() -> (f64, f32) {
+ %cst = arith.constant 2.000000e+00 : f64
+ %cst_0 = arith.constant 2.000000e+00 : f32
+ %c2_i32 = arith.constant 2 : i32
+ %0 = math.fpowi %cst, %c2_i32 : f64, i32
+ %1 = math.fpowi %cst_0, %c2_i32 : f32, i32
+ return %0, %1 : f64, f32
+}
+
+// CHECK-LABEL: @fpowi_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 1.600000e+01, 9.000000e+00, 1.600000e+01]> : vector<4xf32>
+// CHECK: return %[[cst]]
+func.func @fpowi_fold_vec() -> vector<4xf32> {
+ %cst = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+ %cst_0 = arith.constant dense<[2, 4, 2, 2]> : vector<4xi32>
+ %0 = math.fpowi %cst, %cst_0 : vector<4xf32>, vector<4xi32>
+ return %0 : vector<4xf32>
+}
+
+// 16777217 is not exactly representable in f32.
+// CHECK-LABEL: @fpowi_fold_failed
+// CHECK: math.fpowi
+func.func @fpowi_fold_failed() -> f32 {
+ %cst = arith.constant 2.000000e+00 : f32
+ %c16777217_i32 = arith.constant 16777217 : i32
+ %0 = math.fpowi %cst, %c16777217_i32 : f32, i32
+ return %0 : f32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c8be4bf3f0f8d..55e72b57cfd1b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -265,7 +265,8 @@ struct FoldLessThanOpF32ToI1 : public OpRewritePattern<test::LessThanOp> {
Attribute operandAttrs[2] = {lhsAttr, rhsAttr};
TypedAttr res = cast_or_null<TypedAttr>(
- constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void, IntegerAttr>(
+ constFoldBinaryOp<FloatAttr, FloatAttr, FloatAttr::ValueType,
+ FloatAttr::ValueType, void, IntegerAttr>(
operandAttrs, op.getType(), [](APFloat lhs, APFloat rhs) -> APInt {
return APInt(1, lhs < rhs);
}));
|
| switch (APFloat::getSizeInBits(sem)) { | ||
| case 64: | ||
| return APFloat(pow(base.convertToDouble(), fExp.convertToDouble())); | ||
| case 32: |
There was a problem hiding this comment.
I would check for the exact type here. E.g., if (getLhs().getType().isF32()), instead of checking only the bitwidth.
There was a problem hiding this comment.
I referred to the code style in the Math folder, such as math::TruncOp::fold.
There was a problem hiding this comment.
The problem here is that you could have multiple FP types with the same bitwidth. (Although, admittedly that's not the case for 32 and 64 bit today.)
There was a problem hiding this comment.
I see, but the lambda function requires an APFloat type argument. Can we obtain a plain float type from APFloat?
| class PoisonAttr = ub::PoisonAttr, | ||
| class ResultAttrElementT = AttrElementT, | ||
| class ResultAttrElementT = LAttrElementT, |
There was a problem hiding this comment.
Can you summarize the changes here?
There was a problem hiding this comment.
I support binary operations where the left and right operands have different attribute element types, for example, left is float and right is integer.
Adds a constant folder for
math.fpowiwhen both operands are constant and the integer exponent is exactly representable in the floating-point type of the base.