blob: a6f9af83b97738b018e8258f6e143f35f4d4a376 [file] [log] [blame]
//===- PresburgerSpace.cpp - MLIR PresburgerSpace Class -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/PresburgerSpace.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
using namespace mlir;
using namespace presburger;
/// Returns true iff both identifiers carry a non-null value and those values
/// are the same. An identifier without a value never compares equal to
/// anything, including another identifier without a value.
bool Identifier::isEqual(const Identifier &other) const {
  const bool eitherUnset = value == nullptr || other.value == nullptr;
  if (eitherUnset)
    return false;

  const bool sameValue = value == other.value;
  // Sanity check: identifiers that share a value must also share a type.
  assert((!sameValue || idType == other.idType) &&
         "Values of Identifiers are equal but their types do not match.");
  return sameValue;
}
/// Writes this identifier to `os` in the form `Id<value>`.
void Identifier::print(llvm::raw_ostream &os) const {
  os << "Id<";
  os << value;
  os << ">";
}
void Identifier::dump() const {
print(llvm::errs());
llvm::errs() << "\n";
}
/// Returns a copy of this space in which the range variables have been
/// dropped and the domain variables have been converted to set dimensions.
PresburgerSpace PresburgerSpace::getDomainSpace() const {
  const unsigned rangeCount = getNumRangeVars();
  const unsigned domainCount = getNumDomainVars();

  PresburgerSpace domainSpace = *this;
  // Drop every range variable first...
  domainSpace.removeVarRange(VarKind::Range, 0, rangeCount);
  // ...then turn all domain variables into set dimensions, starting at 0.
  domainSpace.convertVarKind(VarKind::Domain, 0, domainCount, VarKind::SetDim,
                             0);
  return domainSpace;
}
/// Returns a copy of this space with all of its domain variables removed.
PresburgerSpace PresburgerSpace::getRangeSpace() const {
  const unsigned domainCount = getNumDomainVars();
  PresburgerSpace rangeSpace = *this;
  rangeSpace.removeVarRange(VarKind::Domain, 0, domainCount);
  return rangeSpace;
}
/// Returns a copy of this space with all of its local variables removed.
PresburgerSpace PresburgerSpace::getSpaceWithoutLocals() const {
  const unsigned localCount = getNumLocalVars();
  PresburgerSpace withoutLocals = *this;
  withoutLocals.removeVarRange(VarKind::Local, 0, localCount);
  return withoutLocals;
}
/// Returns the number of variables of the given `kind` in this space.
unsigned PresburgerSpace::getNumVarKind(VarKind kind) const {
  switch (kind) {
  case VarKind::Domain:
    return getNumDomainVars();
  case VarKind::Range:
    return getNumRangeVars();
  case VarKind::Symbol:
    return getNumSymbolVars();
  case VarKind::Local:
    return getNumLocalVars();
  }
  llvm_unreachable("VarKind does not exist!");
}
/// Returns the absolute position at which variables of the given `kind`
/// begin. Kinds are laid out in the order: domain, range, symbol, local.
unsigned PresburgerSpace::getVarKindOffset(VarKind kind) const {
  switch (kind) {
  case VarKind::Domain:
    // Domain variables come first.
    return 0;
  case VarKind::Range:
    // Range variables follow the domain variables.
    return getNumDomainVars();
  case VarKind::Symbol:
    // Symbols follow all dimension variables.
    return getNumDimVars();
  case VarKind::Local:
    // Locals follow all dimension and symbol variables.
    return getNumDimAndSymbolVars();
  }
  llvm_unreachable("VarKind does not exist!");
}
/// Returns the absolute position one past the last variable of the given
/// `kind`, i.e. the kind's offset plus the number of variables of that kind.
unsigned PresburgerSpace::getVarKindEnd(VarKind kind) const {
  const unsigned kindBegin = getVarKindOffset(kind);
  const unsigned kindCount = getNumVarKind(kind);
  return kindBegin + kindCount;
}
/// Returns how many variables of the given `kind` lie in the absolute
/// half-open range [varStart, varLimit), i.e. the size of the intersection of
/// that range with the range of positions occupied by `kind`.
unsigned PresburgerSpace::getVarKindOverlap(VarKind kind, unsigned varStart,
                                            unsigned varLimit) const {
  const unsigned kindBegin = getVarKindOffset(kind);
  const unsigned kindEnd = getVarKindEnd(kind);

  // Clamp the query range to the range occupied by `kind`.
  const unsigned lo = std::max(varStart, kindBegin);
  const unsigned hi = std::min(varLimit, kindEnd);

  // Disjoint ranges produce lo > hi; report no overlap in that case.
  return lo > hi ? 0 : hi - lo;
}
/// Returns the kind of the variable at absolute position `pos`. `pos` must be
/// smaller than the total number of variables.
VarKind PresburgerSpace::getVarKindAt(unsigned pos) const {
  assert(pos < getNumVars() && "`pos` should represent a valid var position");

  // Walk the kinds in layout order; the first kind whose end lies beyond `pos`
  // is the one that contains it.
  const VarKind layoutOrder[] = {VarKind::Domain, VarKind::Range,
                                 VarKind::Symbol, VarKind::Local};
  for (VarKind candidate : layoutOrder)
    if (pos < getVarKindEnd(candidate))
      return candidate;

  llvm_unreachable("`pos` should represent a valid var position");
}
/// Inserts `num` variables of the given `kind` at position `pos`, where `pos`
/// is relative to the start of that kind. Returns the absolute position of
/// the first inserted variable.
unsigned PresburgerSpace::insertVar(VarKind kind, unsigned pos, unsigned num) {
  assert(pos <= getNumVarKind(kind));

  // The absolute position must be computed before the counts are updated.
  const unsigned insertedAt = getVarKindOffset(kind) + pos;

  switch (kind) {
  case VarKind::Domain:
    numDomain += num;
    break;
  case VarKind::Range:
    numRange += num;
    break;
  case VarKind::Symbol:
    numSymbols += num;
    break;
  case VarKind::Local:
    numLocals += num;
    break;
  }

  // Locals carry no identifiers. For every other kind, when identifiers are
  // in use, pad the identifier list with default-constructed entries.
  const bool needsIds = usingIds && kind != VarKind::Local;
  if (needsIds)
    identifiers.insert(identifiers.begin() + insertedAt, num, Identifier());

  return insertedAt;
}
/// Removes the variables of the given `kind` in the half-open range
/// [varStart, varLimit), where both bounds are relative to the start of that
/// kind. An empty or inverted range is a no-op.
void PresburgerSpace::removeVarRange(VarKind kind, unsigned varStart,
                                     unsigned varLimit) {
  assert(varLimit <= getNumVarKind(kind) && "invalid var limit");
  if (varStart >= varLimit)
    return;

  const unsigned removedCount = varLimit - varStart;
  switch (kind) {
  case VarKind::Domain:
    numDomain -= removedCount;
    break;
  case VarKind::Range:
    numRange -= removedCount;
    break;
  case VarKind::Symbol:
    numSymbols -= removedCount;
    break;
  case VarKind::Local:
    numLocals -= removedCount;
    break;
  }

  // Locals carry no identifiers. For every other kind, when identifiers are
  // in use, drop the matching identifier entries. The offset of `kind` only
  // depends on the kinds preceding it, so it is unaffected by the count update
  // above.
  const bool hasIds = usingIds && kind != VarKind::Local;
  if (hasIds) {
    auto *eraseBegin = identifiers.begin() + getVarKindOffset(kind) + varStart;
    auto *eraseEnd = identifiers.begin() + getVarKindOffset(kind) + varLimit;
    identifiers.erase(eraseBegin, eraseEnd);
  }
}
/// Converts `num` variables of kind `srcKind`, starting at `srcPos`, into
/// variables of kind `dstKind`, placing them at `dstPos`. Both positions are
/// relative to the start of their respective kinds. `srcKind` and `dstKind`
/// must differ.
///
/// When identifiers are in use they follow the variables: identifiers are
/// moved between non-local kinds, dropped when converting to locals, and
/// default-constructed when converting from locals.
void PresburgerSpace::convertVarKind(VarKind srcKind, unsigned srcPos,
                                     unsigned num, VarKind dstKind,
                                     unsigned dstPos) {
  assert(srcKind != dstKind && "cannot convert variables to the same kind");
  assert(srcPos + num <= getNumVarKind(srcKind) &&
         "invalid range for source variables");
  assert(dstPos <= getNumVarKind(dstKind) &&
         "invalid position for destination variables");

  // Absolute positions, computed before any count is modified.
  unsigned srcOffset = getVarKindOffset(srcKind) + srcPos;
  unsigned dstOffset = getVarKindOffset(dstKind) + dstPos;

  if (isUsingIds()) {
    const bool srcHasIds = srcKind != VarKind::Local;
    const bool dstHasIds = dstKind != VarKind::Local;
    if (srcHasIds && dstHasIds) {
      // Make room at the destination, move the identifiers over, then drop
      // the now-vacated source slots.
      identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
      // Inserting at or before the source range shifts the source range right
      // by `num`. The equality case must be included: otherwise the move below
      // would have its destination equal to the start of its source range,
      // which violates the precondition of std::move.
      if (dstOffset <= srcOffset)
        srcOffset += num;
      std::move(identifiers.begin() + srcOffset,
                identifiers.begin() + srcOffset + num,
                identifiers.begin() + dstOffset);
      identifiers.erase(identifiers.begin() + srcOffset,
                        identifiers.begin() + srcOffset + num);
    } else if (srcHasIds) {
      // Converting to locals: the identifiers are simply discarded.
      identifiers.erase(identifiers.begin() + srcOffset,
                        identifiers.begin() + srcOffset + num);
    } else if (dstHasIds) {
      // Converting from locals: the new variables get empty identifiers.
      identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
    }
  }

  // Adjust the per-kind variable counts by a signed delta.
  auto adjustCount = [&](VarKind kind, int delta) {
    switch (kind) {
    case VarKind::Domain:
      numDomain += delta;
      break;
    case VarKind::Range:
      numRange += delta;
      break;
    case VarKind::Symbol:
      numSymbols += delta;
      break;
    case VarKind::Local:
      numLocals += delta;
      break;
    }
  };
  adjustCount(srcKind, -static_cast<int>(num));
  adjustCount(dstKind, static_cast<int>(num));
}
/// Updates identifiers to reflect a swap of the variable at `posA` of kind
/// `kindA` with the variable at `posB` of kind `kindB`. Only identifiers are
/// touched; this is a no-op when the space is not using identifiers.
void PresburgerSpace::swapVar(VarKind kindA, VarKind kindB, unsigned posA,
                              unsigned posB) {
  if (!isUsingIds())
    return;

  const bool aIsLocal = kindA == VarKind::Local;
  const bool bIsLocal = kindB == VarKind::Local;

  // Locals have no identifiers, so swapping two locals changes nothing.
  if (aIsLocal && bIsLocal)
    return;

  // Swapping with a local: the non-local slot loses its identifier.
  if (aIsLocal || bIsLocal) {
    if (aIsLocal)
      getId(kindB, posB) = Identifier();
    else
      getId(kindA, posA) = Identifier();
    return;
  }

  // Both variables carry identifiers: exchange them.
  std::swap(getId(kindA, posA), getId(kindB, posB));
}
/// Returns true if both spaces have the same number of domain, range and
/// symbol variables. Local variables are not taken into account.
bool PresburgerSpace::isCompatible(const PresburgerSpace &other) const {
  if (getNumDomainVars() != other.getNumDomainVars())
    return false;
  if (getNumRangeVars() != other.getNumRangeVars())
    return false;
  return getNumSymbolVars() == other.getNumSymbolVars();
}
/// Returns true if the spaces are compatible and additionally have the same
/// number of local variables.
bool PresburgerSpace::isEqual(const PresburgerSpace &other) const {
  if (!isCompatible(other))
    return false;
  return getNumLocalVars() == other.getNumLocalVars();
}
/// Returns true if the two spaces have the same number of variables of the
/// given `kind` and, for non-local kinds, the same identifiers for those
/// variables. Both spaces must be using identifiers.
static bool areIdsEqual(const PresburgerSpace &spaceA,
                        const PresburgerSpace &spaceB, VarKind kind) {
  assert(spaceA.isUsingIds() && spaceB.isUsingIds() &&
         "Both spaces should be using ids");

  const bool sameCount = spaceA.getNumVarKind(kind) == spaceB.getNumVarKind(kind);
  if (!sameCount)
    return false;

  // Locals have no identifiers, so matching counts are sufficient.
  if (kind == VarKind::Local)
    return true; // No ids.

  return spaceA.getIds(kind) == spaceB.getIds(kind);
}
bool PresburgerSpace::isAligned(const PresburgerSpace &other) const {
// If only one of the spaces is using identifiers, then they are
// not aligned.
if (isUsingIds() != other.isUsingIds())
return false;
// If both spaces are using identifiers, then they are aligned if
// their identifiers are equal. Identifiers being equal implies
// that the number of variables of each kind is same, which implies
// compatiblity, so we do not check for that.
if (isUsingIds())
return areIdsEqual(*this, other, VarKind::Domain) &&
areIdsEqual(*this, other, VarKind::Range) &&
areIdsEqual(*this, other, VarKind::Symbol);
// If neither space is using identifiers, then they are aligned if
// they are compatible.
return isCompatible(other);
}
/// Returns true if this space and `other` are aligned on the variables of the
/// given `kind` only.
bool PresburgerSpace::isAligned(const PresburgerSpace &other,
                                VarKind kind) const {
  const bool thisUsesIds = isUsingIds();

  // A space with identifiers is never aligned with one without.
  if (thisUsesIds != other.isUsingIds())
    return false;

  // Without identifiers on either side, only the variable counts of `kind`
  // need to match.
  if (!thisUsesIds)
    return getNumVarKind(kind) == other.getNumVarKind(kind);

  // With identifiers on both sides, the identifiers of `kind` must match;
  // that check also covers the variable count.
  return areIdsEqual(*this, other, kind);
}
/// Moves the boundary between range variables and symbols so that exactly
/// `newSymbolCount` variables are treated as symbols. The combined number of
/// range and symbol variables is unchanged.
void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) {
  assert(newSymbolCount <= getNumDimAndSymbolVars() &&
         "invalid separation position");

  // Whatever is not a symbol any more becomes a range variable, and vice
  // versa.
  const auto rangeAndSymbols = numRange + numSymbols;
  numRange = rangeAndSymbols - newSymbolCount;
  numSymbols = newSymbolCount;

  // `identifiers` is left untouched: only the boundary moves, the ordering of
  // the identifiers stays the same.
}
/// Merges and aligns the symbol variables of this space and `other` so that
/// afterwards both spaces list the same symbol identifiers in the same order.
/// Symbols of this space come first, in their existing order; symbols that
/// only exist in `other` follow. Both spaces must be using identifiers.
void PresburgerSpace::mergeAndAlignSymbols(PresburgerSpace &other) {
  assert(usingIds && other.usingIds &&
         "Both spaces need to have identifers to merge & align");

  // Phase 1: make the leading symbols of `other` match the symbols of `this`.
  const unsigned otherSymbolBase = other.getVarKindOffset(VarKind::Symbol);
  const unsigned thisSymbolBase = getVarKindOffset(VarKind::Symbol);
  const unsigned thisSymbolCount = getNumVarKind(VarKind::Symbol);

  unsigned aligned = 0;
  for (; aligned < thisSymbolCount; ++aligned) {
    const Identifier &wanted = identifiers[thisSymbolBase + aligned];

    // Everything left of `aligned` in `other` is already in place, so only
    // the remaining symbols of `other` need to be searched.
    auto *slot = other.identifiers.begin() + otherSymbolBase + aligned;
    auto *searchEnd =
        other.identifiers.begin() + other.getVarKindEnd(VarKind::Symbol);
    auto *match = std::find(slot, searchEnd, wanted);

    if (match == searchEnd) {
      // `other` does not know this identifier yet: add a fresh symbol for it.
      other.insertVar(VarKind::Symbol, aligned);
      other.getId(VarKind::Symbol, aligned) = wanted;
      continue;
    }

    // `other` already has this identifier: bring it to the aligned position.
    std::iter_swap(slot, match);
  }

  // Phase 2: append to `this` the symbols that only `other` has.
  const unsigned otherSymbolCount = other.getNumVarKind(VarKind::Symbol);
  for (; aligned < otherSymbolCount; ++aligned) {
    insertVar(VarKind::Symbol, aligned);
    getId(VarKind::Symbol, aligned) = other.getId(VarKind::Symbol, aligned);
  }
}
/// Writes the variable counts of each kind to `os` and, when the space is
/// using identifiers, the identifiers of the domain, range and symbol
/// variables on the following line.
void PresburgerSpace::print(llvm::raw_ostream &os) const {
  os << "Domain: " << getNumDomainVars() << ", ";
  os << "Range: " << getNumRangeVars() << ", ";
  os << "Symbols: " << getNumSymbolVars() << ", ";
  os << "Locals: " << getNumLocalVars() << "\n";

  if (!isUsingIds())
    return;

  // Emits the identifiers of one kind, each followed by a space; identifiers
  // without a value are shown as "None".
  auto emitIds = [&](VarKind kind) {
    os << " ";
    for (Identifier id : getIds(kind)) {
      if (!id.hasValue())
        os << "None";
      else
        id.print(os);
      os << " ";
    }
  };

  os << "(";
  emitIds(VarKind::Domain);
  os << ") -> (";
  emitIds(VarKind::Range);
  os << ") : [";
  emitIds(VarKind::Symbol);
  os << "]";
}
void PresburgerSpace::dump() const {
print(llvm::errs());
llvm::errs() << "\n";
}