// blob: 0ca8809bca8a41ed907ba4933cdb12b3754689bf [file] [log] [blame]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "emitc/emitc_std.h"
#include "emitc/emitc_types.h"
namespace {
using namespace emitc;
using ::testing::Eq;
using ::testing::FloatEq;
using ::testing::Pointwise;
// Exercises standard::splat for tensor ranks 0 through 4: each expectation
// requires every element of the produced tensor to match the scalar that was
// passed in, and the element count to match the length of the expected list.
TEST(std, splat) {
  // Rank 0: invoked directly inside the assertion; one element expected.
  using Rank0 = Tensor0D<uint32_t>;
  EXPECT_THAT(standard::splat<Rank0>(1), Pointwise(Eq(), {1}));

  // Ranks 1-4 go through a lambda with an explicit return type, so the result
  // is also checked after being returned by value from a callable.

  // Rank 1: three elements, compared exactly.
  using Rank1 = Tensor1D<int32_t, 3>;
  const auto make_rank1 = []() -> Rank1 { return standard::splat<Rank1>(-1); };
  EXPECT_THAT(make_rank1(), Pointwise(Eq(), {-1, -1, -1}));

  // Rank 2: 2x2 -> four elements, compared with FloatEq.
  using Rank2 = Tensor2D<float, 2, 2>;
  const auto make_rank2 = []() -> Rank2 {
    return standard::splat<Rank2>(1.5f);
  };
  EXPECT_THAT(make_rank2(), Pointwise(FloatEq(), {1.5f, 1.5f, 1.5f, 1.5f}));

  // Rank 3: 2x3x1 -> six elements.
  using Rank3 = Tensor3D<float, 2, 3, 1>;
  const auto make_rank3 = []() -> Rank3 {
    return standard::splat<Rank3>(1.2f);
  };
  EXPECT_THAT(make_rank3(),
              Pointwise(FloatEq(), {1.2f, 1.2f, 1.2f, 1.2f, 1.2f, 1.2f}));

  // Rank 4: 2x3x1x2 -> twelve elements.
  using Rank4 = Tensor4D<float, 2, 3, 1, 2>;
  const auto make_rank4 = []() -> Rank4 {
    return standard::splat<Rank4>(1.1f);
  };
  EXPECT_THAT(make_rank4(),
              Pointwise(FloatEq(), {1.1f, 1.1f, 1.1f, 1.1f, 1.1f, 1.1f,
                                    1.1f, 1.1f, 1.1f, 1.1f, 1.1f, 1.1f}));
}
} // namespace