blob: 926d580ac78527ed8fcdd2faaafb598d0925ef0d [file] [log] [blame]
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::ml_program;
namespace mlir {
namespace ml_program {
namespace {
template <typename Interface, typename Op>
struct ExternalModelBase
: public BufferizableOpInterface::ExternalModel<Interface, Op> {
AliasingValueList getAliasingValues(Operation *, OpOperand &,
const AnalysisState &) const {
return {};
}
BufferRelation bufferRelation(Operation *, OpResult,
const AnalysisState &) const {
return BufferRelation::Unknown;
}
};
/// Bufferization of ml_program.global into a memref.global
struct GlobalOpInterface
: public ExternalModelBase<GlobalOpInterface, GlobalOp> {
bool bufferizesToMemoryRead(Operation *, OpOperand &,
const AnalysisState &) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
const AnalysisState &) const {
return false;
}
bool hasTensorSemantics(Operation *) const { return true; }
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &) const {
auto globalOp = cast<GlobalOp>(op);
if (!globalOp.getValue().has_value())
return globalOp.emitError("global op must have a value");
auto tensorType = cast<TensorType>(globalOp.getType());
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
replaceOpWithNewBufferizedOp<memref::GlobalOp>(
rewriter, globalOp, globalOp.getSymName(),
/*sym_visibility=*/globalOp.getSymVisibilityAttr(),
/*type=*/cast<MemRefType>(memrefType),
/*initial_value=*/globalOp.getValue().value(),
/*constant=*/!globalOp.getIsMutable(),
/*alignment=*/nullptr);
return success();
}
};
/// Bufferization of ml_program.global_load into a memref.get_global
struct GlobalLoadOpInterface
: public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
bool bufferizesToMemoryRead(Operation *, OpOperand &,
const AnalysisState &) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
const AnalysisState &) const {
return false;
}
bool isWritable(Operation *, Value, const AnalysisState &) const {
return false;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &) const {
auto globalLoadOp = cast<GlobalLoadOp>(op);
auto tensorType = cast<TensorType>(globalLoadOp.getType());
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
rewriter, globalLoadOp, memrefType,
globalLoadOp.getGlobalAttr().getLeafReference());
return success();
}
};
/// Bufferization of ml_program.global_store into a memref.get_global and
/// memcpy
struct GlobalStoreOpInterface
: public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
bool bufferizesToMemoryRead(Operation *, OpOperand &,
const AnalysisState &) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
const AnalysisState &) const {
return true;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto globalStoreOp = cast<GlobalStoreOp>(op);
auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
auto loc = globalStoreOp.getLoc();
auto targetMemref = rewriter.create<memref::GetGlobalOp>(
loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options);
if (failed(sourceMemref)) {
return failure();
}
auto memcpy =
options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
if (failed(memcpy)) {
return failure();
}
rewriter.eraseOp(globalStoreOp);
return success();
}
};
} // namespace
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) {
GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
});
}
} // namespace ml_program
} // namespace mlir