// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This file defines the EmitC core ops.
#ifndef EMITC_EMITC_CORE_OPS_H
#define EMITC_EMITC_CORE_OPS_H
#include <algorithm>
#include <cassert>
#include <cmath>
#include <complex>
#include <cstdint>
#include <cstring>
#include <functional>
#include <numeric>
#include <type_traits>
#include <vector>
#include "emitc_types.h"
namespace emitc {
/// Functions for unary elementwise ops.
// AbsOp
// TODO: Add support for complex numbers.
template <typename Src>
inline Src abs(Src x) {
using ET_Src = typename get_element_type<Src>::type;
auto f = static_cast<ET_Src (*)(ET_Src)>(std::abs);
return unary<Src>(x, f);
}
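// Example usage (a minimal sketch, assuming the Tensor1D alias and
// initializer-list construction provided by emitc_types.h); the other unary
// wrappers below follow the same elementwise pattern:
//   Tensor1D<float, 3> x{-1.5f, 0.0f, 2.5f};
//   auto y = emitc::abs(x); // {1.5f, 0.0f, 2.5f}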
// CeilOp
template <typename Src>
inline Src ceil(Src x) {
using ET_Src = typename get_element_type<Src>::type;
auto f = static_cast<ET_Src (*)(ET_Src)>(std::ceil);
return unary<Src>(x, f);
}
// ConvertOp
template <typename Dest, typename Src>
inline Dest convert(Src x) {
using ET_Dest = typename get_element_type<Dest>::type;
using ET_Src = typename get_element_type<Src>::type;
auto cast = [](ET_Src value) { return static_cast<ET_Dest>(value); };
return unary<Dest, Src, UnaryFuncType<ET_Dest, ET_Src>>(x, cast);
}
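// Example usage (a sketch, assuming the Tensor1D alias from emitc_types.h);
// the element type is converted via static_cast, so float-to-int truncates
// toward zero:
//   Tensor1D<float, 2> x{1.7f, -2.7f};
//   auto y = emitc::convert<Tensor1D<int32_t, 2>>(x); // {1, -2}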
// ExpOp
template <typename Src>
inline Src exp(Src x) {
using ET_Src = typename get_element_type<Src>::type;
auto f = static_cast<ET_Src (*)(ET_Src)>(std::exp);
return unary<Src>(x, f);
}
// FloorOp
template <typename Src>
inline Src floor(Src x) {
using ET_Src = typename get_element_type<Src>::type;
auto f = static_cast<ET_Src (*)(ET_Src)>(std::floor);
return unary<Src>(x, f);
}
// LogOp
template <typename Src>
inline Src log(Src x) {
using ET_Src = typename get_element_type<Src>::type;
auto f = static_cast<ET_Src (*)(ET_Src)>(std::log);
return unary<Src>(x, f);
}
// NegateOp
template <typename Src>
inline Src negate(Src x) {
using ET_Src = typename get_element_type<Src>::type;
auto f = std::negate<ET_Src>{};
return unary<Src>(x, f);
}
// ClampOp
template <typename Min, typename Src, typename Max>
inline Src clamp(Min min, Src operand, Max max) {
static_assert(
std::is_same<Min, Src>::value ||
(is_tensor_of_dim<0, Min>::value &&
std::is_same<typename get_element_type<Src>::type,
typename get_element_type<Min>::type>::value),
"Expected the same type for min and operand or a 0-dim tensor of the "
"same element type for min");
static_assert(
std::is_same<Max, Src>::value ||
(is_tensor_of_dim<0, Max>::value &&
std::is_same<typename get_element_type<Src>::type,
typename get_element_type<Max>::type>::value),
"Expected the same type for min and operand or a 0-dim tensor of the "
"same element type for max");
const bool broadcast_min = !std::is_same<Min, Src>::value;
const bool broadcast_max = !std::is_same<Max, Src>::value;
Src result;
for (size_t index = 0; index < Src::size(); index++) {
const auto value_min = broadcast_min ? min[0] : min[index];
const auto value_max = broadcast_max ? max[0] : max[index];
auto value = operand[index];
value = value < value_min ? value_min : value;
value = value > value_max ? value_max : value;
result[index] = value;
}
return result;
}
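// Example usage (a sketch, assuming the Tensor0D/Tensor1D aliases from
// emitc_types.h); a 0-dim tensor for min/max is broadcast over the operand:
//   Tensor0D<float> lo{0.0f};
//   Tensor0D<float> hi{6.0f};
//   Tensor1D<float, 4> v{-1.0f, 2.0f, 5.0f, 9.0f};
//   auto r = emitc::clamp(lo, v, hi); // {0.0f, 2.0f, 5.0f, 6.0f}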
// SqrtOp
template <typename Src>
inline Src sqrt(Src x) {
using ET_Src = typename get_element_type<Src>::type;
auto f = static_cast<ET_Src (*)(ET_Src)>(std::sqrt);
return unary<Src>(x, f);
}
// TanhOp
template <typename Src>
inline Src tanh(Src x) {
using ET_Src = typename get_element_type<Src>::type;
auto f = static_cast<ET_Src (*)(ET_Src)>(std::tanh);
return unary<Src>(x, f);
}
/// Functions for binary elementwise ops.
// AddOp
template <typename Src>
inline Src add(Src x, Src y) {
using ET_Src = typename get_element_type<Src>::type;
auto f = std::plus<ET_Src>{};
return binary<Src>(x, y, f);
}
// MaxOp
template <typename Src>
inline Src max(Src x, Src y) {
using ET_Src = typename get_element_type<Src>::type;
auto f =
static_cast<const ET_Src &(*)(const ET_Src &, const ET_Src &)>(std::max);
return binary<Src>(x, y, f);
}
// MinOp
template <typename Src>
inline Src min(Src x, Src y) {
using ET_Src = typename get_element_type<Src>::type;
auto f =
static_cast<const ET_Src &(*)(const ET_Src &, const ET_Src &)>(std::min);
return binary<Src>(x, y, f);
}
// MulOp
template <typename Src>
inline Src mul(Src x, Src y) {
using ET_Src = typename get_element_type<Src>::type;
auto f = std::multiplies<ET_Src>{};
return binary<Src>(x, y, f);
}
// PowOp
template <typename Src>
inline Src pow(Src x, Src y) {
using ET_Src = typename get_element_type<Src>::type;
  auto f = [](ET_Src a, ET_Src b) -> ET_Src {
    if (std::is_integral<ET_Src>::value) {
      // Integer exponentiation by repeated multiplication. A negative
      // exponent yields the truncated reciprocal, i.e. 0 unless the
      // intermediate result is 1.
      const bool negative = b < 0;
      if (negative) {
        b = -b;
      }
      ET_Src result = 1;
      for (ET_Src i = 0; i < b; i++) {
        result *= a;
      }
      if (negative) {
        result = 1 / result;
      }
      return result;
    } else {
      return std::pow(a, b);
    }
  };
return binary<Src>(x, y, f);
}
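// Example usage (a sketch, assuming the Tensor1D alias from emitc_types.h);
// note the truncating semantics for negative integer exponents:
//   Tensor1D<int32_t, 2> base{2, 2};
//   Tensor1D<int32_t, 2> exponent{10, -1};
//   auto r = emitc::pow(base, exponent); // {1024, 0}: 2^-1 == 1/2
//                                        // truncates to 0.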
// SubOp
template <typename Src>
inline Src sub(Src x, Src y) {
using ET_Src = typename get_element_type<Src>::type;
auto f = std::minus<ET_Src>{};
return binary<Src>(x, y, f);
}
/// Other ops.
// BroadcastInDimOp
// The broadcast_dimensions argument maps Src dimensions to Dest dimensions.
template <typename Dest, typename Src>
inline Dest
broadcast_in_dim(Src operand,
Tensor<int64_t, Src::rank()> broadcast_dimensions) {
static_assert(is_tensor<Src>::value, "Expected tensor argument");
static_assert(is_tensor<Dest>::value, "Expected tensor result");
std::vector<size_t> retainedDimensions(Dest::rank());
std::iota(retainedDimensions.begin(), retainedDimensions.end(), 0);
  // Keep only the Dest dimensions that appear in broadcast_dimensions; the
  // assert below then verifies that broadcast_dimensions is a subset of
  // 0 .. Dest::rank().
retainedDimensions.erase(
std::remove_if(retainedDimensions.begin(), retainedDimensions.end(),
[&broadcast_dimensions](size_t i) {
return std::find(broadcast_dimensions.begin(),
broadcast_dimensions.end(),
i) == broadcast_dimensions.end();
}),
retainedDimensions.end());
assert(retainedDimensions.size() == Src::rank());
Dest result;
for (size_t i = 0; i < result.size(); i++) {
auto dest_index = result.unravel_index(i);
// Reverse mapping with broadcast_dimensions
std::array<size_t, Src::rank()> src_index;
for (size_t j = 0; j < src_index.size(); j++) {
src_index[j] = dest_index[broadcast_dimensions(j)];
}
    // Handle broadcasting of dimensions of size 1.
    for (size_t j = 0; j < src_index.size(); ++j) {
      if (Src::shape()[j] == 1) {
        src_index[j] = 0;
      }
    }
result[i] = operand[operand.ravel_index(src_index)];
}
return result;
}
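// Example usage (a sketch, assuming the Tensor1D/Tensor2D aliases from
// emitc_types.h); a length-3 vector is broadcast along the rows of a 2x3
// result, with broadcast_dimensions mapping Src dim 0 to Dest dim 1:
//   Tensor1D<float, 3> v{1.0f, 2.0f, 3.0f};
//   Tensor<int64_t, 1> dims{1};
//   auto m = emitc::broadcast_in_dim<Tensor2D<float, 2, 3>>(v, dims);
//   // m == {1, 2, 3,
//   //       1, 2, 3}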
// DotOp
template <typename Dest, typename Lhs, typename Rhs>
Dest dot(Lhs lhs, Rhs rhs) {
static_assert(is_tensor_of_dim<2, Lhs>::value, "Expected 2 dimensional lhs");
static_assert(is_tensor_of_dim<2, Rhs>::value, "Expected 2 dimensional rhs");
static_assert(Lhs::dim(1) == Rhs::dim(0),
"Expected contracting dimension to match");
Dest output;
for (size_t m = 0; m < lhs.dim(0); m++) {
for (size_t n = 0; n < lhs.dim(1); n++) {
for (size_t k = 0; k < rhs.dim(1); k++) {
output(m, k) += lhs(m, n) * rhs(n, k);
}
}
}
return output;
}
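// Example usage (a sketch, assuming the Tensor2D alias from emitc_types.h
// and value-initialized, i.e. zeroed, tensors, which the accumulation above
// relies on):
//   Tensor2D<float, 2, 3> a{1, 2, 3, 4, 5, 6};
//   Tensor2D<float, 3, 2> b{1, 0, 0, 1, 1, 1};
//   auto c = emitc::dot<Tensor2D<float, 2, 2>>(a, b);
//   // c == {4, 5,
//   //       10, 11}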
// BatchMatmulOp
template <typename Dest, typename Lhs, typename Rhs>
Dest batch_matmul(Lhs lhs, Rhs rhs) {
static_assert(is_tensor_of_dim<3, Lhs>::value, "Expected 3 dimensional lhs");
static_assert(is_tensor_of_dim<3, Rhs>::value, "Expected 3 dimensional rhs");
static_assert(Lhs::dim(0) == Rhs::dim(0) && Lhs::dim(0) == Dest::dim(0),
"Expected batch dimension to match");
static_assert(Lhs::dim(2) == Rhs::dim(1),
"Expected contracting dimension to match");
static_assert(Dest::dim(1) == Lhs::dim(1), "Expected row dimension to match");
static_assert(Dest::dim(2) == Rhs::dim(2),
"Expected column dimension to match");
Dest output;
for (size_t b = 0; b < lhs.dim(0); b++) {
for (size_t m = 0; m < lhs.dim(1); m++) {
for (size_t n = 0; n < lhs.dim(2); n++) {
for (size_t k = 0; k < rhs.dim(2); k++) {
output(b, m, k) += lhs(b, m, n) * rhs(b, n, k);
}
}
}
}
return output;
}
// ReshapeOp
template <typename Dest, typename Src>
inline Dest reshape(Src x) {
static_assert(is_tensor<Src>::value, "Expected tensor argument");
static_assert(is_tensor<Dest>::value, "Expected tensor result");
using ET_Src = typename get_element_type<Src>::type;
using ET_Dest = typename get_element_type<Dest>::type;
static_assert(std::is_same<ET_Src, ET_Dest>::value, "Element type mismatch");
static_assert(Src::size() == Dest::size(), "Tensor size mismatch");
Dest z;
std::copy(x.begin(), x.end(), z.begin());
return z;
}
// SliceOp
// Overload for 1d case.
template <typename Dest, typename Src, IsTensorOfDim<1, Src> = true>
Dest slice(Src x, Tensor<int64_t, 1> start_indices,
Tensor<int64_t, 1> limit_indices, Tensor<int64_t, 1> strides) {
Dest z;
size_t index = 0;
for (int64_t i = start_indices[0]; i < limit_indices[0]; i += strides[0]) {
z[index++] = x(i);
}
return z;
}
// Overload for 2d case.
template <typename Dest, typename Src, IsTensorOfDim<2, Src> = true>
Dest slice(Src x, Tensor<int64_t, 2> start_indices,
Tensor<int64_t, 2> limit_indices, Tensor<int64_t, 2> strides) {
Dest z;
size_t index = 0;
for (int64_t i = start_indices[0]; i < limit_indices[0]; i += strides[0]) {
for (int64_t j = start_indices[1]; j < limit_indices[1]; j += strides[1]) {
z[index++] = x(i, j);
}
}
return z;
}
// Overload for 3d case.
template <typename Dest, typename Src, IsTensorOfDim<3, Src> = true>
Dest slice(Src x, Tensor<int64_t, 3> start_indices,
Tensor<int64_t, 3> limit_indices, Tensor<int64_t, 3> strides) {
Dest z;
size_t index = 0;
for (int64_t i = start_indices[0]; i < limit_indices[0]; i += strides[0]) {
for (int64_t j = start_indices[1]; j < limit_indices[1]; j += strides[1]) {
for (int64_t k = start_indices[2]; k < limit_indices[2];
k += strides[2]) {
z[index++] = x(i, j, k);
}
}
}
return z;
}
// Overload for 4d case.
template <typename Dest, typename Src, IsTensorOfDim<4, Src> = true>
Dest slice(Src x, Tensor<int64_t, 4> start_indices,
Tensor<int64_t, 4> limit_indices, Tensor<int64_t, 4> strides) {
Dest z;
size_t index = 0;
for (int64_t i = start_indices[0]; i < limit_indices[0]; i += strides[0]) {
for (int64_t j = start_indices[1]; j < limit_indices[1]; j += strides[1]) {
for (int64_t k = start_indices[2]; k < limit_indices[2];
k += strides[2]) {
for (int64_t c = start_indices[3]; c < limit_indices[3];
c += strides[3]) {
z[index++] = x(i, j, k, c);
}
}
}
}
return z;
}
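// Example usage of the 2d overload (a sketch, assuming the Tensor2D alias
// from emitc_types.h); strides of 2 select every second row and column:
//   Tensor2D<float, 4, 4> x; // filled elsewhere
//   Tensor<int64_t, 2> start{0, 0};
//   Tensor<int64_t, 2> limit{4, 4};
//   Tensor<int64_t, 2> strides{2, 2};
//   auto y = emitc::slice<Tensor2D<float, 2, 2>>(x, start, limit, strides);
//   // y == {x(0, 0), x(0, 2),
//   //       x(2, 0), x(2, 2)}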
// PadOp
// TODO: Add support for negative edge padding
template <typename Dest, typename Src>
inline Dest pad(Src operand,
Tensor<typename get_element_type<Src>::type> padding_value,
Tensor<int64_t, Src::rank()> edge_padding_low,
Tensor<int64_t, Src::rank()> edge_padding_high,
Tensor<int64_t, Src::rank()> interior_padding) {
assert(std::all_of(interior_padding.begin(), interior_padding.end(),
[](int64_t i) { return i >= 0; }));
assert(std::all_of(edge_padding_low.begin(), edge_padding_low.end(),
[](int64_t i) { return i >= 0; }));
assert(std::all_of(edge_padding_high.begin(), edge_padding_high.end(),
[](int64_t i) { return i >= 0; }));
Dest result;
auto interior = [&interior_padding](std::array<size_t, Src::rank()> index) {
for (size_t i = 0; i < index.size(); i++) {
if (index[i] % (interior_padding[i] + 1) != 0) {
return true;
}
}
return false;
};
  auto out_of_bounds = [](std::array<size_t, Src::rank()> index) {
    for (size_t i = 0; i < index.size(); i++) {
      // Indices shifted below zero wrap around to large size_t values, so
      // the upper-bound check also catches them.
      if (index[i] >= Src::dim(i)) {
        return true;
      }
    }
    return false;
  };
for (size_t i = 0; i < result.size(); i++) {
auto index = result.unravel_index(i);
// Shift by low padding
for (size_t j = 0; j < index.size(); j++) {
index[j] -= edge_padding_low[j];
}
if (interior(index)) {
result[i] = padding_value();
} else {
      // Squeeze by interior padding
for (size_t j = 0; j < index.size(); j++) {
size_t pad = interior_padding[j];
assert(index[j] % (pad + 1) == 0);
index[j] /= (pad + 1);
}
if (out_of_bounds(index)) {
result[i] = padding_value();
} else {
result[i] = operand[operand.ravel_index(index)];
}
}
}
return result;
}
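// Example usage (a sketch, assuming the Tensor0D/Tensor2D aliases from
// emitc_types.h); interior padding of 1 expands a 2x2 operand to 3x3:
//   Tensor2D<float, 2, 2> x{1, 2, 3, 4};
//   Tensor0D<float> zero{0};
//   Tensor<int64_t, 2> low{0, 0};
//   Tensor<int64_t, 2> high{0, 0};
//   Tensor<int64_t, 2> interior{1, 1};
//   auto y = emitc::pad<Tensor2D<float, 3, 3>>(x, zero, low, high, interior);
//   // y == {1, 0, 2,
//   //       0, 0, 0,
//   //       3, 0, 4}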
} // namespace emitc
#endif // EMITC_EMITC_CORE_OPS_H