blob: 24719fe748fe4fab2cdb6f776c2ed122481c18cd [file] [log] [blame]
//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
namespace xegpu {
/// Registers all types, operations and attributes of the XeGPU dialect.
///
/// The class lists are not written by hand: each `GET_*_LIST` macro selects
/// the list section of the corresponding TableGen-generated `.inc` file, and
/// that include expands in place to the template argument list of the
/// surrounding `add*<...>()` call.
void XeGPUDialect::initialize() {
// Types from the generated XeGPUTypes.cpp.inc list section.
addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
>();
// Operations from the generated XeGPU.cpp.inc list section.
addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
>();
// Attributes from the generated XeGPUAttrs.cpp.inc list section.
addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
>();
}
//===----------------------------------------------------------------------===//
// XeGPU_TensorDescAttr
//===----------------------------------------------------------------------===//
/// Builds a TensorDescAttr from plain C++ values.
///
/// Each value is first wrapped in an attribute, then the four attributes are
/// handed to the generated `Base::get`:
///   - memory_scope   -> MemoryScopeAttr
///   - array_length   -> IntegerAttr of a 64-bit integer type
///   - boundary_check -> BoolAttr
///   - scattered      -> BoolAttr
///
/// \param context        context that owns the resulting attribute.
/// \param memory_scope   memory scope of the described tensor.
/// \param array_length   stored as a 64-bit integer attribute.
/// \param boundary_check stored as a BoolAttr.
/// \param scattered      stored as a BoolAttr.
TensorDescAttr TensorDescAttr::get(mlir::MLIRContext *context,
                                   xegpu::MemoryScope memory_scope,
                                   int array_length, bool boundary_check,
                                   bool scattered) {
  // The attributes are created one statement at a time, in the same order
  // as the parameters, so creation order stays deterministic.
  auto memScope = MemoryScopeAttr::get(context, memory_scope);
  auto i64Type = IntegerType::get(context, 64);
  auto arrayLen = IntegerAttr::get(i64Type, array_length);
  auto boundsCheck = BoolAttr::get(context, boundary_check);
  auto isScattered = BoolAttr::get(context, scattered);

  return Base::get(context, memScope, arrayLen, boundsCheck, isScattered);
}
//===----------------------------------------------------------------------===//
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//
/// Parses the parameter list of a tensor descriptor type:
///
///   `<` dimension-list element-type (`,` encoding-attribute)? `>`
///
/// The dimension list is read with `parseDimensionList`, so it includes its
/// trailing 'x' (for example `8x16xf32`). The encoding attribute is optional.
/// When it is absent, a null attribute is passed to the generated `get`.
///
/// \returns the parsed type, or a null Type after reporting a diagnostic.
mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
  // Opening '<'.
  if (parser.parseLess())
    return {};

  // Shape, e.g. "8x16x".
  llvm::SmallVector<int64_t> shape;
  const auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  // Element type, e.g. "f32".
  mlir::Type elementType;
  const auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // Optional ", <attribute>" carrying the encoding; stays null if omitted.
  mlir::Attribute encoding;
  if (mlir::succeeded(parser.parseOptionalComma())) {
    auto parsedEncoding = mlir::FieldParser<mlir::Attribute>::parse(parser);
    if (mlir::failed(parsedEncoding)) {
      parser.emitError(
          parser.getCurrentLocation(),
          "Failed to parse the attribute field for TensorDescType.\n");
      return {};
    }
    encoding = *parsedEncoding;
  }

  // Closing '>'.
  if (parser.parseGreater())
    return {};

  return TensorDescType::get(parser.getContext(), shape, elementType,
                             encoding);
}
/// Prints the type parameters in the form read back by TensorDescType::parse.
///
/// Every dimension is emitted followed by 'x' (dynamic dimensions as '?'),
/// then the element type, then `, <encoding>` when an encoding is present.
/// The whole sequence is enclosed in angle brackets.
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
  printer << "<";

  for (const int64_t extent : getShape()) {
    if (!mlir::ShapedType::isDynamic(extent))
      printer << extent;
    else
      printer << '?';
    // Each dimension, including the last, carries the 'x' separator that
    // precedes the element type.
    printer << 'x';
  }

  printer << getElementType();

  auto encoding = getEncoding();
  if (encoding)
    printer << ", " << encoding;

  printer << ">";
}
/// Convenience builder that packs the scalar descriptor options into a
/// TensorDescAttr encoding and builds the type from it.
///
/// The MLIRContext is taken from `elementType`, so `elementType` must be a
/// non-null type.
///
/// \param shape          dimensions of the described tensor.
/// \param elementType    element type; also supplies the context.
/// \param scattered      forwarded to TensorDescAttr::get.
/// \param array_length   forwarded to TensorDescAttr::get.
/// \param memory_scope   forwarded to TensorDescAttr::get.
/// \param boundary_check forwarded to TensorDescAttr::get.
TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, bool scattered,
                                   int array_length, MemoryScope memory_scope,
                                   bool boundary_check) {
  mlir::MLIRContext *ctx = elementType.getContext();

  // TensorDescAttr::get takes its options in a different order than this
  // builder: (scope, array length, boundary check, scattered).
  auto descAttr = TensorDescAttr::get(ctx, memory_scope, array_length,
                                      boundary_check, scattered);

  return Base::get(ctx, shape, elementType, descAttr);
}
} // namespace xegpu
} // namespace mlir
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>