blob: 44b0f67d1e3078319affb2e9c0e504f9ff00f581 [file] [log] [blame]
//===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines several support classes for defining interfaces.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_SUPPORT_INTERFACESUPPORT_H
#define MLIR_SUPPORT_INTERFACESUPPORT_H
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/TypeName.h"
namespace mlir {
namespace detail {
//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//
/// This class represents an abstract interface. An interface is a simplified
/// mechanism for attaching concept based polymorphism to a class hierarchy. An
/// interface is comprised of two components:
/// * The derived interface class: This is what users interact with, and invoke
/// methods on.
/// * An interface `Trait` class: This is the class that is attached to the
/// object implementing the interface. It is the mechanism with which models
/// are specialized.
///
/// Derived interface types must provide the following template types:
/// * ConcreteType: The CRTP derived type.
/// * ValueT: The opaque type the derived interface operates on. For example
/// `Operation*` for operation interfaces, or `Attribute` for
/// attribute interfaces.
/// * Traits: A class that contains definitions for a 'Concept' and a 'Model'
/// class. The 'Concept' class defines an abstract virtual interface,
/// whereas the 'Model' class implements this interface for a
/// specific derived T type. Both of these classes must *not* contain
/// non-static data. A simple example is shown below:
///
/// ```c++
/// struct ExampleInterfaceTraits {
/// struct Concept {
/// virtual unsigned getNumInputs(T t) const = 0;
/// };
/// template <typename DerivedT> class Model : public Concept {
/// unsigned getNumInputs(T t) const final {
/// return cast<DerivedT>(t).getNumInputs();
/// }
/// };
/// };
/// ```
///
/// * BaseType: A desired base type for the interface. This is a class that
/// provides specific functionality for the `ValueT` value. For
/// instance the specific `Op` that will wrap the `Operation*` for
/// an `OpInterface`.
/// * BaseTrait: The base type for the interface trait. This is the base class
/// to use for the interface trait that will be attached to each
/// instance of `ValueT` that implements this interface.
///
template <typename ConcreteType, typename ValueT, typename Traits,
          typename BaseType,
          template <typename, template <typename> class> class BaseTrait>
class Interface : public BaseType {
public:
  using Concept = typename Traits::Concept;
  template <typename T>
  using Model = typename Traits::template Model<T>;
  using InterfaceBase =
      Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;

  /// The trait that an object type `ConcreteT` derives from in order to
  /// register this interface. It exposes the model specialized for
  /// `ConcreteT` as `ModelT`.
  template <typename ConcreteT>
  struct Trait : public BaseTrait<ConcreteT, Trait> {
    using ModelT = Model<ConcreteT>;

    /// Returns the ID of the interface registered by this trait.
    static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
  };

  /// Wrap the given value in this interface. A null value yields an interface
  /// with a null concept; a non-null value is asserted to provide an
  /// implementation of the interface.
  Interface(ValueT t = ValueT()) : BaseType(t), impl(findConcept(t)) {
    assert((!t || impl) && "expected value to provide interface instance");
  }

  /// Wrap an instance of a type `T` that has this interface's trait attached.
  /// This overload only participates when `T` derives from `Trait<T>`.
  template <typename T, typename std::enable_if_t<
                            std::is_base_of<Trait<T>, T>::value> * = nullptr>
  Interface(T t) : BaseType(t), impl(findConcept(t)) {
    assert((!t || impl) && "expected value to provide interface instance");
  }

  /// Support 'classof': a value is an instance of this interface exactly when
  /// the concrete interface can produce a concept for it.
  static bool classof(ValueT t) {
    return ConcreteType::getInterfaceFor(t) != nullptr;
  }

  /// Returns the ID of this interface.
  static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }

protected:
  /// Accessors for the concept object this interface dispatches through.
  const Concept *getImpl() const { return impl; }
  Concept *getImpl() { return impl; }

private:
  /// Asks the concrete interface for the concept registered for `value`.
  /// Null values are never forwarded to `getInterfaceFor`; they map to a null
  /// concept directly.
  template <typename T>
  static Concept *findConcept(T &value) {
    if (value)
      return ConcreteType::getInterfaceFor(value);
    return nullptr;
  }

  /// The concept object implementing the interface for the wrapped value, or
  /// null if the wrapped value is null.
  Concept *impl;
};
//===----------------------------------------------------------------------===//
// InterfaceMap
//===----------------------------------------------------------------------===//
/// Utility to filter a given sequence of types based upon a predicate.
///
/// `FilterTypeT<Keep>::type<E>` is the per-element building block of the
/// filter: it is `std::tuple<E>` when `Keep` is true and the empty
/// `std::tuple<>` otherwise.
template <bool Keep>
struct FilterTypeT {
  /// Rejected element: contributes no types.
  template <class E>
  using type = std::tuple<>;
};
template <>
struct FilterTypeT<true> {
  /// Accepted element: contributes exactly `E`.
  template <class E>
  using type = std::tuple<E>;
};
template <template <class> class Pred, class... Es>
struct FilterTypes {
using type = decltype(std::tuple_cat(
std::declval<
typename FilterTypeT<Pred<Es>::value>::template type<Es>>()...));
};
/// This class provides an efficient mapping between a given `Interface` type,
/// and a particular implementation of its concept.
///
/// The map owns the concept models it holds: each model is constructed into
/// `malloc`ed storage when the map is built, and that storage is released with
/// `free` when the map is destroyed. Model destructors are never invoked.
class InterfaceMap {
  /// Trait to check if T provides a static 'getInterfaceID' method.
  template <typename T, typename... Args>
  using has_get_interface_id = decltype(T::getInterfaceID());
  template <typename T>
  using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
  /// The number of types in `Types` that provide 'getInterfaceID'.
  template <typename... Types>
  using num_interface_types = typename std::tuple_size<
      typename FilterTypes<detect_get_interface_id, Types...>::type>;

public:
  /// Moving transfers ownership of the registered models to the new map.
  InterfaceMap(InterfaceMap &&) = default;

  /// Release the storage of every model owned by this map.
  ~InterfaceMap() {
    // A map without interfaces (or one that was moved from) owns nothing.
    if (!interfaces)
      return;
    for (auto &it : *interfaces)
      free(it.second);
  }

  /// Construct an InterfaceMap with the given set of template types. For
  /// convenience given that object trait lists may contain other non-interface
  /// types, not all of the types need to be interfaces. The provided types that
  /// do not represent interfaces are not added to the interface map.
  template <typename... Types>
  static std::enable_if_t<num_interface_types<Types...>::value != 0,
                          InterfaceMap>
  get() {
    // Filter the provided types for those that are interfaces.
    using FilteredTupleType =
        typename FilterTypes<detect_get_interface_id, Types...>::type;
    // The null pointer only carries the filtered type list to `getImpl`.
    return getImpl(static_cast<FilteredTupleType *>(nullptr));
  }
  /// Overload selected when none of `Types` is an interface: returns a map
  /// with no interfaces registered.
  template <typename... Types>
  static std::enable_if_t<num_interface_types<Types...>::value == 0,
                          InterfaceMap>
  get() {
    return InterfaceMap();
  }

  /// Returns an instance of the concept object for the given interface if it
  /// was registered to this map, null otherwise.
  template <typename T> typename T::Concept *lookup() const {
    void *inst = interfaces ? interfaces->lookup(T::getInterfaceID()) : nullptr;
    return static_cast<typename T::Concept *>(inst);
  }

private:
  /// Construct a map with no interfaces registered.
  InterfaceMap() = default;

  /// Construct a map that takes ownership of the given (interface ID, model)
  /// pairs. Each model pointer must refer to `malloc`ed storage. If several
  /// elements share an interface ID, only the first one is registered.
  InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements)
      : interfaces(std::make_unique<llvm::SmallDenseMap<TypeID, void *>>()) {
    interfaces->reserve(elements.size());
    for (std::pair<TypeID, void *> &element : elements) {
      // A model rejected as a duplicate is not reachable from the map, so the
      // destructor would never release it; free it here instead.
      if (!interfaces->insert(element).second)
        free(element.second);
    }
  }

  /// Build a map holding one default-constructed model per interface trait in
  /// `Ts`. The pointer argument is unused; it only conveys the type list.
  template <typename... Ts>
  static InterfaceMap getImpl(std::tuple<Ts...> *) {
    // Models are placed in `malloc`ed storage to pair with the `free` calls
    // that release them.
    std::pair<TypeID, void *> elements[] = {std::make_pair(
        Ts::getInterfaceID(),
        new (malloc(sizeof(typename Ts::ModelT))) typename Ts::ModelT())...};
    return InterfaceMap(elements);
  }

  /// The internal map of interfaces. This is constructed statically for each
  /// set of interfaces. Null when no interfaces are registered.
  std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>> interfaces;
};
} // end namespace detail
} // end namespace mlir
#endif