| //===- LegalizeData.cpp - -------------------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/OpenACC/Transforms/Passes.h" |
| |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/OpenACC/OpenACC.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/RegionUtils.h" |
| |
| namespace mlir { |
| namespace acc { |
| #define GEN_PASS_DEF_LEGALIZEDATAINREGION |
| #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" |
| } // namespace acc |
| } // namespace mlir |
| |
| using namespace mlir; |
| |
| namespace { |
| |
| static void collectPtrs(mlir::ValueRange operands, |
| llvm::SmallVector<std::pair<Value, Value>> &values, |
| bool hostToDevice) { |
| for (auto operand : operands) { |
| Value varPtr = acc::getVarPtr(operand.getDefiningOp()); |
| Value accPtr = acc::getAccPtr(operand.getDefiningOp()); |
| if (varPtr && accPtr) { |
| if (hostToDevice) |
| values.push_back({varPtr, accPtr}); |
| else |
| values.push_back({accPtr, varPtr}); |
| } |
| } |
| } |
| |
| template <typename Op> |
| static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { |
| llvm::SmallVector<std::pair<Value, Value>> values; |
| |
| if constexpr (std::is_same_v<Op, acc::LoopOp>) { |
| collectPtrs(op.getReductionOperands(), values, hostToDevice); |
| collectPtrs(op.getPrivateOperands(), values, hostToDevice); |
| } else { |
| collectPtrs(op.getDataClauseOperands(), values, hostToDevice); |
| if constexpr (!std::is_same_v<Op, acc::KernelsOp>) { |
| collectPtrs(op.getReductionOperands(), values, hostToDevice); |
| collectPtrs(op.getGangPrivateOperands(), values, hostToDevice); |
| collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice); |
| } |
| } |
| |
| for (auto p : values) |
| replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion()); |
| } |
| |
| struct LegalizeDataInRegion |
| : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> { |
| |
| void runOnOperation() override { |
| func::FuncOp funcOp = getOperation(); |
| bool replaceHostVsDevice = this->hostToDevice.getValue(); |
| |
| funcOp.walk([&](Operation *op) { |
| if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op)) |
| return; |
| |
| if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) { |
| collectAndReplaceInRegion(parallelOp, replaceHostVsDevice); |
| } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) { |
| collectAndReplaceInRegion(serialOp, replaceHostVsDevice); |
| } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) { |
| collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice); |
| } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) { |
| collectAndReplaceInRegion(loopOp, replaceHostVsDevice); |
| } |
| }); |
| } |
| }; |
| |
| } // end anonymous namespace |
| |
| std::unique_ptr<OperationPass<func::FuncOp>> |
| mlir::acc::createLegalizeDataInRegion() { |
| return std::make_unique<LegalizeDataInRegion>(); |
| } |