blob: 536dd7ea07e4a4ae8fc54a506ad5683d5aa05613 [file] [log] [blame]
//===- Rules.h - Helpers for declaring facts and rules ----------*- C++ -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines helper classes and functions for managing state (facts),
// merging and tracking modification for various data types important for
// quantization.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_QUANTIZER_SUPPORT_RULES_H
#define MLIR_QUANTIZER_SUPPORT_RULES_H
#include "llvm/ADT/Optional.h"
#include <algorithm>
#include <limits>
#include <utility>
namespace mlir {
namespace quantizer {
/// Typed indicator of whether a mutator produces a modification.
struct ModificationResult {
enum ModificationEnum { Retained, Modified } value;
ModificationResult(ModificationEnum v) : value(v) {}
ModificationResult operator|(ModificationResult other) {
if (value == Modified || other.value == Modified) {
return ModificationResult(Modified);
} else {
return ModificationResult(Retained);
}
}
ModificationResult operator|=(ModificationResult other) {
value =
(value == Modified || other.value == Modified) ? Modified : Retained;
return *this;
}
};
inline ModificationResult modify(bool isModified = true) {
return ModificationResult{isModified ? ModificationResult::Modified
: ModificationResult::Retained};
}
inline bool modified(ModificationResult m) {
return m.value == ModificationResult::Modified;
}
/// A fact that can converge through forward propagation alone without the
/// need to track ownership or individual assertions. In practice, this works
/// for static assertions that are either minimized or maximized and do not
/// vary dynamically.
///
/// It is expected that ValueTy is appropriate to pass by value and has an
/// operator==. The BinaryReducer type should have two static methods:
/// using ValueTy : Type of the value.
/// ValueTy initialValue() : Returns the initial value of the fact.
/// ValueTy reduce(ValueTy lhs, ValueTy rhs) : Reduces two values.
template <typename BinaryReducer>
class BasePropagatedFact {
public:
using ValueTy = typename BinaryReducer::ValueTy;
using ThisTy = BasePropagatedFact<BinaryReducer>;
BasePropagatedFact()
: value(BinaryReducer::initialValue()),
salience(std::numeric_limits<int>::min()) {}
int getSalience() const { return salience; }
bool hasValue() const { return salience != std::numeric_limits<int>::min(); }
ValueTy getValue() const { return value; }
ModificationResult assertValue(int assertSalience, ValueTy assertValue) {
if (assertSalience > salience) {
// New salience band.
value = assertValue;
salience = assertSalience;
return modify(true);
} else if (assertSalience < salience) {
// Lower salience - ignore.
return modify(false);
}
// Merge within same salience band.
ValueTy updatedValue = BinaryReducer::reduce(value, assertValue);
auto mod = modify(value != updatedValue);
value = updatedValue;
return mod;
}
ModificationResult mergeFrom(const ThisTy &other) {
if (other.hasValue()) {
return assertValue(other.getSalience(), other.getValue());
}
return modify(false);
}
private:
ValueTy value;
int salience;
};
/// A binary reducer that expands a min/max range represented by a pair
/// of doubles such that it represents the largest of all inputs.
/// The initial value is (Inf, -Inf).
struct ExpandingMinMaxReducer {
using ValueTy = std::pair<double, double>;
static ValueTy initialValue() {
return std::make_pair(std::numeric_limits<double>::infinity(),
-std::numeric_limits<double>::infinity());
}
static ValueTy reduce(ValueTy lhs, ValueTy rhs) {
return std::make_pair(std::min(lhs.first, rhs.first),
std::max(lhs.second, rhs.second));
}
};
using ExpandingMinMaxFact = BasePropagatedFact<ExpandingMinMaxReducer>;
/// A binary reducer that minimizing a numeric type.
template <typename T>
struct MinimizingNumericReducer {
using ValueTy = T;
static ValueTy initialValue() {
if (std::numeric_limits<T>::has_infinity()) {
return std::numeric_limits<T>::infinity();
} else {
return std::numeric_limits<T>::max();
}
}
static ValueTy reduce(ValueTy lhs, ValueTy rhs) { return std::min(lhs, rhs); }
};
using MinimizingDoubleFact =
BasePropagatedFact<MinimizingNumericReducer<double>>;
using MinimizingIntFact = BasePropagatedFact<MinimizingNumericReducer<int>>;
/// A binary reducer that maximizes a numeric type.
template <typename T>
struct MaximizingNumericReducer {
using ValueTy = T;
static ValueTy initialValue() {
if (std::numeric_limits<T>::has_infinity()) {
return -std::numeric_limits<T>::infinity();
} else {
return std::numeric_limits<T>::min();
}
}
static ValueTy reduce(ValueTy lhs, ValueTy rhs) { return std::max(lhs, rhs); }
};
using MaximizingDoubleFact =
BasePropagatedFact<MaximizingNumericReducer<double>>;
using MaximizingIntFact = BasePropagatedFact<MaximizingNumericReducer<int>>;
/// A fact and reducer for tracking agreement of discrete values. The value
/// type consists of a |T| value and a flag indicating whether there is a
/// conflict (in which case, the preserved value is arbitrary).
template <typename T>
struct DiscreteReducer {
struct ValueTy {
ValueTy() : conflict(false) {}
ValueTy(T value) : value(value), conflict(false) {}
ValueTy(T value, bool conflict) : value(value), conflict(conflict) {}
llvm::Optional<T> value;
bool conflict;
bool operator==(const ValueTy &other) const {
if (conflict != other.conflict)
return false;
if (value && other.value) {
return *value == *other.value;
} else {
return !value && !other.value;
}
}
bool operator!=(const ValueTy &other) const { return !(*this == other); }
};
static ValueTy initialValue() { return ValueTy(); }
static ValueTy reduce(ValueTy lhs, ValueTy rhs) {
if (!lhs.value && !rhs.value)
return lhs;
else if (!lhs.value)
return rhs;
else if (!rhs.value)
return lhs;
else
return ValueTy(*lhs.value, *lhs.value != *rhs.value);
}
};
template <typename T>
using DiscreteFact = BasePropagatedFact<DiscreteReducer<T>>;
/// Discrete scale/zeroPoint fact.
using DiscreteScaleZeroPointFact = DiscreteFact<std::pair<double, int64_t>>;
} // end namespace quantizer
} // end namespace mlir
#endif // MLIR_QUANTIZER_SUPPORT_RULES_H