blob: 187b3b71e3458baf44bd158158d59ddab1df14d4 [file] [log] [blame]
//===- Var.cpp ------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Var.h"
#include "DimLvlMap.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
using namespace mlir::sparse_tensor::ir_detail;
//===----------------------------------------------------------------------===//
// `VarKind` helpers.
//===----------------------------------------------------------------------===//
/// Every `VarKind` enumerator, listed once so that code in this file can
/// visit all kinds with a range-based for loop.
static constexpr VarKind everyVarKind[] = {
    VarKind::Dimension,
    VarKind::Symbol,
    VarKind::Level,
};
//===----------------------------------------------------------------------===//
// `Var` implementation.
//===----------------------------------------------------------------------===//
std::string Var::str() const {
std::string str;
llvm::raw_string_ostream os(str);
print(os);
return os.str();
}
/// Prints this variable to the stream owned by `printer`; simply forwards
/// to the `llvm::raw_ostream` overload.
void Var::print(AsmPrinter &printer) const {
  llvm::raw_ostream &os = printer.getStream();
  print(os);
}
/// Writes the result of `toChar` for this variable's kind, immediately
/// followed by its number, to `os`.
void Var::print(llvm::raw_ostream &os) const {
  os << toChar(getKind());
  os << getNum();
}
void Var::dump() const {
print(llvm::errs());
llvm::errs() << "\n";
}
//===----------------------------------------------------------------------===//
// `Ranks` implementation.
//===----------------------------------------------------------------------===//
/// Two `Ranks` compare equal iff they report the same rank for every
/// `VarKind`.
bool Ranks::operator==(Ranks const &other) const {
  bool allEqual = true;
  for (const VarKind kind : everyVarKind) {
    if (getRank(kind) != other.getRank(kind)) {
      allEqual = false;
      break; // One mismatch is enough to decide.
    }
  }
  return allEqual;
}
/// Returns true iff every variable identifier occurring in `expr` is
/// strictly below the corresponding rank recorded in this object.
/// `expr` must be non-null (asserted).
bool Ranks::isValid(DimLvlExpr expr) const {
  assert(expr);
  // Largest identifiers seen for dim/lvl-vars and for symbol-vars; both
  // start at -1 before the scan.  Each `DimLvlExpr` only allows one kind
  // of non-symbol variable, so a single `largestVar` suffices.
  int64_t largestVar = -1;
  int64_t largestSym = -1;
  mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>({{expr.getAffineExpr()}},
                                                 largestVar, largestSym);
  if (largestSym >= getSymRank())
    return false;
  return largestVar < getRank(expr.getAllowedVarKind());
}
//===----------------------------------------------------------------------===//
// `VarSet` implementation.
//===----------------------------------------------------------------------===//
/// Constructs a set whose per-kind bitvectors are sized according to
/// `ranks`.
VarSet::VarSet(Ranks const &ranks) {
  for (const VarKind kind : everyVarKind) {
    const auto rank = ranks.getRank(kind);
    impl[kind] = llvm::SmallBitVector(rank);
  }
  // Sanity check: the sizes we just installed must round-trip.
  assert(getRanks() == ranks);
}
/// Returns true iff `var` is a member of this set.
///
/// NOTE: An out-of-bounds variable number yields `false` rather than an
/// assertion failure, for consistency with the `anyCommon` implementation
/// of `VarSet::occursIn(VarSet)`.  Beware that, as always with silencing
/// OOB, this can hide bugs in client code.
bool VarSet::contains(Var var) const {
  const llvm::SmallBitVector &kindBits = impl[var.getKind()];
  const auto index = var.getNum();
  if (index >= kindBits.size())
    return false;
  return kindBits[index];
}
/// Inserts a single variable into the set.
void VarSet::add(Var var) {
  llvm::SmallBitVector &kindBits = impl[var.getKind()];
  // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
  kindBits[var.getNum()] = true;
}
/// Inserts every variable of `other` into this set (set union).
///
/// For each `VarKind`, `other`'s bitvector must not be larger than this
/// set's bitvector of the same kind (asserted).
void VarSet::add(VarSet const &other) {
  // NOTE: `SmallBitVector::operator|=` will implicitly grow the receiver
  // to the larger of the two sizes, so we add an assertion against OOB
  // for consistency with the implementation of `VarSet::add(Var)`.
  for (const auto vk : everyVarKind) {
    assert(impl[vk].size() >= other.impl[vk].size());
    // Union, not intersection: `&=` would clear bits that are already
    // members of this set instead of adding the ones from `other`.
    impl[vk] |= other.impl[vk];
  }
}
/// Inserts every variable occurring in `expr` into the set, walking the
/// expression tree recursively.  A null expression is ignored.
void VarSet::add(DimLvlExpr expr) {
  // Nothing to collect from a null expression.
  if (!expr)
    return;
  switch (expr.getAffineKind()) {
  // Binary operators: collect from both operands; the operator itself
  // contributes nothing.
  case AffineExprKind::Add:
  case AffineExprKind::Mul:
  case AffineExprKind::Mod:
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv: {
    const auto [left, binopKind, right] = expr.unpackBinop();
    (void)binopKind;
    add(left);
    add(right);
    return;
  }
  // Leaves that name a variable.
  case AffineExprKind::DimId: {
    add(expr.castDimLvlVar());
    return;
  }
  case AffineExprKind::SymbolId: {
    add(expr.castSymVar());
    return;
  }
  // Constants mention no variables.
  case AffineExprKind::Constant:
    return;
  }
  llvm_unreachable("unknown AffineExprKind");
}
//===----------------------------------------------------------------------===//
// `VarInfo` implementation.
//===----------------------------------------------------------------------===//
/// Records the variable number for this `VarInfo`.  The number may only be
/// set once, and must satisfy `Var::isWF_Num` (both asserted).
void VarInfo::setNum(Var::Num n) {
  assert(!hasNum() && "Var::Num is already set");
  assert(Var::isWF_Num(n) && "Var::Num is too large");
  this->num = n;
}
//===----------------------------------------------------------------------===//
// `VarEnv` implementation.
//===----------------------------------------------------------------------===//
/// Helper function for `assertUsageConsistency` to better handle SMLoc
/// mismatches.
///
/// Returns whichever of `sm1`/`sm2` comes first by (line, column), with
/// `sm1` winning ties.  If the two locations are in different files, a
/// default-constructed `SMLoc` is returned instead.  Both locations must
/// resolve to a `FileLineColLoc` (asserted).
LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
  const auto first = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1));
  assert(first && "Could not get `FileLineColLoc` for first `SMLoc`");
  const auto second =
      dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2));
  assert(second && "Could not get `FileLineColLoc` for second `SMLoc`");

  // Locations from different files are not comparable.
  if (first.getFilename() != second.getFilename())
    return SMLoc();

  // Lexicographic comparison: line first, then column.
  const auto line1 = first.getLine();
  const auto line2 = second.getLine();
  if (line1 != line2)
    return line1 < line2 ? sm1 : sm2;
  return first.getColumn() <= second.getColumn() ? sm1 : sm2;
}
/// Returns true iff the `VarInfo` stored in `env` under `id` carries both
/// the expected `name` and the same `id`.
bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
  const auto &info = env.access(id);
  if (info.getName() != name)
    return false;
  return info.getID() == id;
}
/// Returns true iff the `VarInfo` stored in `env` under `id` has kind `vk`.
/// The `loc` parameter is accepted but not inspected.
bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
                       VarKind vk) {
  const auto recordedKind = env.access(id).getKind();
  return recordedKind == vk;
}
/// Looks up the ID registered for `name`.  Returns `std::nullopt` when the
/// name is unknown, or when the stored entry fails the internal-consistency
/// check.
std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
  const auto entry = ids.find(name);
  if (entry != ids.end()) {
    const auto id = entry->second;
    if (isInternalConsistent(*this, id, name))
      return id;
  }
  return std::nullopt;
}
/// Registers `name` if it is not yet known, otherwise validates the
/// existing entry.
///
/// Returns the variable's ID paired with a flag telling whether a new entry
/// was inserted.  Returns `std::nullopt` when an existing entry is
/// internally inconsistent, or when `verifyUsage` is set and the existing
/// entry's kind differs from `vk`.
std::optional<std::pair<VarInfo::ID, bool>>
VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) {
  const auto &[slot, inserted] = ids.try_emplace(name, nextID());
  const auto id = slot->second;

  // Fresh name: record its info and report success.
  if (inserted) {
    vars.emplace_back(id, name, loc, vk);
    return std::make_pair(id, inserted);
  }

  // Known name: the stored entry must pass the requested checks.
  if (!isInternalConsistent(*this, id, name))
    return std::nullopt;
  if (verifyUsage && !isUsageConsistent(*this, id, loc, vk))
    return std::nullopt;
  return std::make_pair(id, inserted);
}
/// Looks up or creates the variable called `name`, as directed by
/// `creationPolicy`:
///   * `Policy::MustNot`: the variable must already exist with kind `vk`;
///     never creates.
///   * `Policy::May`: creates the variable if absent; otherwise the
///     existing entry must have kind `vk`.
///   * `Policy::Must`: the variable must not already exist; always creates.
///
/// On success returns the variable's ID paired with a flag telling whether
/// it was newly created; returns `std::nullopt` when the policy is violated
/// or a consistency check fails.
std::optional<std::pair<VarInfo::ID, bool>>
VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
                       VarKind vk) {
  switch (creationPolicy) {
  case Policy::MustNot: {
    const auto oid = lookup(name);
    if (!oid)
      return std::nullopt; // Doesn't exist, but must not create.
    if (!isUsageConsistent(*this, *oid, loc, vk))
      return std::nullopt;
    return std::make_pair(*oid, false);
  }
  case Policy::May:
    return create(name, loc, vk, /*verifyUsage=*/true);
  case Policy::Must: {
    const auto res = create(name, loc, vk, /*verifyUsage=*/false);
    // `create` yields nullopt when an existing entry is internally
    // inconsistent, so `res` must be checked before it is dereferenced.
    if (!res)
      return std::nullopt;
    const auto didCreate = res->second;
    if (!didCreate)
      return std::nullopt; // Already exists, but must create.
    return res;
  }
  }
  llvm_unreachable("unknown Policy");
}
/// Hands out a `Var` of kind `vk` carrying the current value of that kind's
/// `nextNum` counter, and advances the counter.
Var VarEnv::bindUnusedVar(VarKind vk) {
  const auto fresh = nextNum[vk]++;
  return Var(vk, fresh);
}
/// Binds the variable registered under `id` to a fresh number of its own
/// kind, stores that number in its `VarInfo`, and returns the resulting
/// `Var`.
Var VarEnv::bindVar(VarInfo::ID id) {
  auto &entry = access(id);
  const auto bound = bindUnusedVar(entry.getKind());
  entry.setNum(bound.getNum());
  return bound;
}
/// Emits an error at the location of the first variable in `vars` that has
/// no number yet.  Returns a default-constructed `InFlightDiagnostic` when
/// every variable has a number.
InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const {
  for (const auto &info : vars) {
    if (info.hasNum())
      continue;
    // Only the first offender is reported.
    return parser.emitError(info.getLoc(),
                            "Unbound variable: " + info.getName());
  }
  return {};
}
//===----------------------------------------------------------------------===//