| //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" |
| |
| #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" |
| #include "mlir/Conversion/ComplexCommon/DivisionConverter.h" |
| #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
| #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
| #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Complex/IR/Complex.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Pass/Pass.h" |
| |
| namespace mlir { |
| #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS |
| #include "mlir/Conversion/Passes.h.inc" |
| } // namespace mlir |
| |
| using namespace mlir; |
| using namespace mlir::LLVM; |
| using namespace mlir::arith; |
| |
| //===----------------------------------------------------------------------===// |
| // ComplexStructBuilder implementation. |
| //===----------------------------------------------------------------------===// |
| |
/// Position of the real component inside the struct that represents a complex
/// number.
static constexpr unsigned kRealPosInComplexNumberStruct = 0;
/// Position of the imaginary component inside the struct that represents a
/// complex number.
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
| |
| ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder, |
| Location loc, Type type) { |
| Value val = builder.create<LLVM::PoisonOp>(loc, type); |
| return ComplexStructBuilder(val); |
| } |
| |
| void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, |
| Value real) { |
| setPtr(builder, loc, kRealPosInComplexNumberStruct, real); |
| } |
| |
| Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { |
| return extractPtr(builder, loc, kRealPosInComplexNumberStruct); |
| } |
| |
| void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, |
| Value imaginary) { |
| setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); |
| } |
| |
| Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { |
| return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Conversion patterns. |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { |
| using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| |
| ComplexStructBuilder complexStruct(adaptor.getComplex()); |
| Value real = complexStruct.real(rewriter, op.getLoc()); |
| Value imag = complexStruct.imaginary(rewriter, op.getLoc()); |
| |
| arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
| LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
| op.getContext(), |
| convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
| Value sqNorm = rewriter.create<LLVM::FAddOp>( |
| loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), |
| rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); |
| |
| rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); |
| return success(); |
| } |
| }; |
| |
| struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> { |
| using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| return LLVM::detail::oneToOneRewrite( |
| op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), |
| op->getAttrs(), *getTypeConverter(), rewriter); |
| } |
| }; |
| |
| struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> { |
| using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Pack real and imaginary part in a complex number struct. |
| auto loc = complexOp.getLoc(); |
| auto structType = typeConverter->convertType(complexOp.getType()); |
| auto complexStruct = |
| ComplexStructBuilder::poison(rewriter, loc, structType); |
| complexStruct.setReal(rewriter, loc, adaptor.getReal()); |
| complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary()); |
| |
| rewriter.replaceOp(complexOp, {complexStruct}); |
| return success(); |
| } |
| }; |
| |
| struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> { |
| using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(complex::ReOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Extract real part from the complex number struct. |
| ComplexStructBuilder complexStruct(adaptor.getComplex()); |
| Value real = complexStruct.real(rewriter, op.getLoc()); |
| rewriter.replaceOp(op, real); |
| |
| return success(); |
| } |
| }; |
| |
| struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> { |
| using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(complex::ImOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Extract imaginary part from the complex number struct. |
| ComplexStructBuilder complexStruct(adaptor.getComplex()); |
| Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); |
| rewriter.replaceOp(op, imaginary); |
| |
| return success(); |
| } |
| }; |
| |
| struct BinaryComplexOperands { |
| std::complex<Value> lhs; |
| std::complex<Value> rhs; |
| }; |
| |
| template <typename OpTy> |
| BinaryComplexOperands |
| unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) { |
| auto loc = op.getLoc(); |
| |
| // Extract real and imaginary values from operands. |
| BinaryComplexOperands unpacked; |
| ComplexStructBuilder lhs(adaptor.getLhs()); |
| unpacked.lhs.real(lhs.real(rewriter, loc)); |
| unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); |
| ComplexStructBuilder rhs(adaptor.getRhs()); |
| unpacked.rhs.real(rhs.real(rewriter, loc)); |
| unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); |
| |
| return unpacked; |
| } |
| |
| struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { |
| using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(complex::AddOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| BinaryComplexOperands arg = |
| unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter); |
| |
| // Initialize complex number struct for result. |
| auto structType = typeConverter->convertType(op.getType()); |
| auto result = ComplexStructBuilder::poison(rewriter, loc, structType); |
| |
| // Emit IR to add complex numbers. |
| arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
| LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
| op.getContext(), |
| convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
| Value real = |
| rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); |
| Value imag = |
| rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); |
| result.setReal(rewriter, loc, real); |
| result.setImaginary(rewriter, loc, imag); |
| |
| rewriter.replaceOp(op, {result}); |
| return success(); |
| } |
| }; |
| |
| struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> { |
| DivOpConversion(const LLVMTypeConverter &converter, |
| complex::ComplexRangeFlags target) |
| : ConvertOpToLLVMPattern<complex::DivOp>(converter), |
| complexRange(target) {} |
| |
| using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| BinaryComplexOperands arg = |
| unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter); |
| |
| // Initialize complex number struct for result. |
| auto structType = typeConverter->convertType(op.getType()); |
| auto result = ComplexStructBuilder::poison(rewriter, loc, structType); |
| |
| // Emit IR to add complex numbers. |
| arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
| LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
| op.getContext(), |
| convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
| Value rhsRe = arg.rhs.real(); |
| Value rhsIm = arg.rhs.imag(); |
| Value lhsRe = arg.lhs.real(); |
| Value lhsIm = arg.lhs.imag(); |
| |
| Value resultRe, resultIm; |
| |
| if (complexRange == complex::ComplexRangeFlags::basic || |
| complexRange == complex::ComplexRangeFlags::none) { |
| mlir::complex::convertDivToLLVMUsingAlgebraic( |
| rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm); |
| } else if (complexRange == complex::ComplexRangeFlags::improved) { |
| mlir::complex::convertDivToLLVMUsingRangeReduction( |
| rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm); |
| } |
| |
| result.setReal(rewriter, loc, resultRe); |
| result.setImaginary(rewriter, loc, resultIm); |
| |
| rewriter.replaceOp(op, {result}); |
| return success(); |
| } |
| |
| private: |
| complex::ComplexRangeFlags complexRange; |
| }; |
| |
| struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { |
| using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| BinaryComplexOperands arg = |
| unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter); |
| |
| // Initialize complex number struct for result. |
| auto structType = typeConverter->convertType(op.getType()); |
| auto result = ComplexStructBuilder::poison(rewriter, loc, structType); |
| |
| // Emit IR to add complex numbers. |
| arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
| LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
| op.getContext(), |
| convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
| Value rhsRe = arg.rhs.real(); |
| Value rhsIm = arg.rhs.imag(); |
| Value lhsRe = arg.lhs.real(); |
| Value lhsIm = arg.lhs.imag(); |
| |
| Value real = rewriter.create<LLVM::FSubOp>( |
| loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), |
| rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); |
| |
| Value imag = rewriter.create<LLVM::FAddOp>( |
| loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), |
| rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); |
| |
| result.setReal(rewriter, loc, real); |
| result.setImaginary(rewriter, loc, imag); |
| |
| rewriter.replaceOp(op, {result}); |
| return success(); |
| } |
| }; |
| |
| struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { |
| using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(complex::SubOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| BinaryComplexOperands arg = |
| unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter); |
| |
| // Initialize complex number struct for result. |
| auto structType = typeConverter->convertType(op.getType()); |
| auto result = ComplexStructBuilder::poison(rewriter, loc, structType); |
| |
| // Emit IR to substract complex numbers. |
| arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); |
| LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( |
| op.getContext(), |
| convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); |
| Value real = |
| rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); |
| Value imag = |
| rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); |
| result.setReal(rewriter, loc, real); |
| result.setImaginary(rewriter, loc, imag); |
| |
| rewriter.replaceOp(op, {result}); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void mlir::populateComplexToLLVMConversionPatterns( |
| const LLVMTypeConverter &converter, RewritePatternSet &patterns, |
| complex::ComplexRangeFlags complexRange) { |
| // clang-format off |
| patterns.add< |
| AbsOpConversion, |
| AddOpConversion, |
| ConstantOpLowering, |
| CreateOpConversion, |
| ImOpConversion, |
| MulOpConversion, |
| ReOpConversion, |
| SubOpConversion |
| >(converter); |
| |
| patterns.add<DivOpConversion>(converter, complexRange); |
| // clang-format on |
| } |
| |
| namespace { |
| struct ConvertComplexToLLVMPass |
| : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> { |
| using Base::Base; |
| |
| void runOnOperation() override; |
| }; |
| } // namespace |
| |
| void ConvertComplexToLLVMPass::runOnOperation() { |
| // Convert to the LLVM IR dialect using the converter defined above. |
| RewritePatternSet patterns(&getContext()); |
| LLVMTypeConverter converter(&getContext()); |
| populateComplexToLLVMConversionPatterns(converter, patterns, complexRange); |
| |
| LLVMConversionTarget target(getContext()); |
| target.addIllegalDialect<complex::ComplexDialect>(); |
| if (failed( |
| applyPartialConversion(getOperation(), target, std::move(patterns)))) |
| signalPassFailure(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ConvertToLLVMPatternInterface implementation |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// Implement the interface to convert MemRef to LLVM. |
| struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
| using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
| void loadDependentDialects(MLIRContext *context) const final { |
| context->loadDialect<LLVM::LLVMDialect>(); |
| } |
| |
| /// Hook for derived dialect interface to provide conversion patterns |
| /// and mark dialect legal for the conversion target. |
| void populateConvertToLLVMConversionPatterns( |
| ConversionTarget &target, LLVMTypeConverter &typeConverter, |
| RewritePatternSet &patterns) const final { |
| populateComplexToLLVMConversionPatterns(typeConverter, patterns); |
| } |
| }; |
| } // namespace |
| |
| void mlir::registerConvertComplexToLLVMInterface(DialectRegistry ®istry) { |
| registry.addExtension( |
| +[](MLIRContext *ctx, complex::ComplexDialect *dialect) { |
| dialect->addInterfaces<ComplexToLLVMDialectInterface>(); |
| }); |
| } |