blob: 0004c2e3403e53a78731fa5c5766a6f8c11b2991 [file] [log] [blame]
//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements helper functions to call common simple C functions in
// LLVMIR (e.g. amon others to support printing and debugging).
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
using namespace mlir::LLVM;
/// Helper functions to lookup or create the declaration for commonly used
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as
/// part of the libc).
static constexpr llvm::StringRef kPrintI64 = "printI64";
static constexpr llvm::StringRef kPrintU64 = "printU64";
static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
static constexpr llvm::StringRef kPrintComma = "printComma";
static constexpr llvm::StringRef kPrintNewline = "printNewline";
static constexpr llvm::StringRef kMalloc = "malloc";
static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
static constexpr llvm::StringRef kFree = "free";
static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc";
static constexpr llvm::StringRef kGenericAlignedAlloc =
"_mlir_memref_to_llvm_aligned_alloc";
static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
/// Generic print function lookupOrCreate helper.
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
ArrayRef<Type> paramTypes,
Type resultType, bool isVarArg) {
auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
if (func)
return func;
OpBuilder b(moduleOp.getBodyRegion());
return b.create<LLVM::LLVMFuncOp>(
moduleOp->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintI64,
IntegerType::get(moduleOp->getContext(), 64),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintU64,
IntegerType::get(moduleOp->getContext(), 64),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintBF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF32,
Float32Type::get(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF64,
Float64Type::get(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
return LLVM::LLVMPointerType::get(context);
}
static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
// A char pointer and void ptr are the same in LLVM IR.
return getCharPtr(context);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
ModuleOp moduleOp, std::optional<StringRef> runtimeFunctionName) {
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
getCharPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintClose, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintComma, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
getVoidPtr(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
getVoidPtr(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
getVoidPtr(moduleOp->getContext()));
}
LLVM::LLVMFuncOp
mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
{indexType, indexType},
getVoidPtr(moduleOp->getContext()));
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp
mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
Type unrankedDescriptorType) {
return LLVM::lookupOrCreateFn(
moduleOp, kMemRefCopy,
ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}