/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "ruy/prepacked_cache.h"

#include <thread>  // NOLINT(build/c++11)

#include "ruy/context.h"
#include "ruy/context_get_ctx.h"
#include "ruy/gtest_wrapper.h"
#include "ruy/ruy.h"
#include "ruy/time.h"

namespace ruy {
namespace {

TEST(PrepackedCacheTest, TestCacheEjection) {
// Create the cache.
PrepackedCache prepacked_cache(32);
// Allocate the prepacked matrix.
PrepackedMatrix mat1;
mat1.data_size = 16;
mat1.sums_size = 8;
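  // mat1 occupies data_size + sums_size = 24 bytes, which fits on its own
  // within the 32-byte ejection threshold.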
prepacked_cache.AllocatePrepackedMatrix(&mat1);
auto cache_key1 = std::make_pair(nullptr, mat1.data);
prepacked_cache.Insert(cache_key1, mat1);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
// Get a time point after the insertion into the cache.
TimePoint current = CoarseNow();
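  // The 10 ms sleeps on either side of `current` ensure that timestamps
  // recorded before and after it differ even under the coarse-grained clock.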
std::this_thread::sleep_for(std::chrono::milliseconds(10));
PrepackedCache::CacheIterator itr = prepacked_cache.FindAndUpdate(cache_key1);
EXPECT_NE(itr, prepacked_cache.cend());
  // By finding mat1, we updated its timestamp. Verify that `current` is older
  // than the timestamp now associated with mat1.
EXPECT_LT(current, itr->second.second);
PrepackedMatrix mat2;
mat2.data_size = 8;
mat2.sums_size = 4;
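  // mat2 adds another 8 + 4 = 12 bytes; 24 + 12 = 36 exceeds the 32-byte
  // ejection threshold, so this insertion should trigger an ejection.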
prepacked_cache.AllocatePrepackedMatrix(&mat2);
auto cache_key2 = std::make_pair(nullptr, mat2.data);
prepacked_cache.Insert(cache_key2, mat2);
// The cache size was exceeded by inserting mat2. Ensure that mat1 was
// ejected.
EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
}

TEST(PrepackedCacheTest, TestCacheBasic) {
// Create the cache.
PrepackedCache prepacked_cache(48);
// Allocate the prepacked matrix.
PrepackedMatrix mat1;
mat1.data_size = 16;
mat1.sums_size = 8;
prepacked_cache.AllocatePrepackedMatrix(&mat1);
auto cache_key1 = std::make_pair(nullptr, mat1.data);
prepacked_cache.Insert(cache_key1, mat1);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
PrepackedMatrix mat2;
mat2.data_size = 8;
mat2.sums_size = 4;
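  // mat1 (24 bytes) plus mat2 (12 bytes) totals 36 bytes, still under the
  // 48-byte ejection threshold, so no ejection is expected.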
prepacked_cache.AllocatePrepackedMatrix(&mat2);
auto cache_key2 = std::make_pair(nullptr, mat2.data);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
prepacked_cache.Insert(cache_key2, mat2);
// The cache size was not exceeded by inserting mat2. Ensure that mat1 was not
// ejected.
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
}

TEST(PrepackedCacheTest, TestCacheEjection2) {
// Create the cache.
PrepackedCache prepacked_cache(73);
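  // Each matrix below occupies 16 + 8 = 24 bytes, so three of them (72 bytes)
  // fit under the 73-byte ejection threshold but a fourth does not.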
// Allocate the prepacked matrix 1.
PrepackedMatrix mat1;
mat1.data_size = 16;
mat1.sums_size = 8;
prepacked_cache.AllocatePrepackedMatrix(&mat1);
auto cache_key1 = std::make_pair(nullptr, mat1.data);
prepacked_cache.Insert(cache_key1, mat1);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
// Allocate the prepacked matrix 2.
PrepackedMatrix mat2;
mat2.data_size = 16;
mat2.sums_size = 8;
prepacked_cache.AllocatePrepackedMatrix(&mat2);
auto cache_key2 = std::make_pair(nullptr, mat2.data);
prepacked_cache.Insert(cache_key2, mat2);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
// Allocate the prepacked matrix 3.
  PrepackedMatrix mat3;
  mat3.data_size = 16;
  mat3.sums_size = 8;
  prepacked_cache.AllocatePrepackedMatrix(&mat3);
  auto cache_key3 = std::make_pair(nullptr, mat3.data);
prepacked_cache.Insert(cache_key3, mat31);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
  // The next insertion will cause the cache size to go over the ejection
  // threshold. Touch matrix 1 and matrix 3 to make matrix 2 the oldest.
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend());
std::this_thread::sleep_for(std::chrono::milliseconds(10));
// Allocate the prepacked matrix 4.
PrepackedMatrix mat4;
mat4.data_size = 16;
mat4.sums_size = 8;
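  // Inserting mat4 brings the total to 96 bytes; ejecting the least recently
  // used entry (mat2) gets the cache back under the 73-byte threshold.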
prepacked_cache.AllocatePrepackedMatrix(&mat4);
auto cache_key4 = std::make_pair(nullptr, mat4.data);
prepacked_cache.Insert(cache_key4, mat4);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
// Ensure that mat2 was ejected, but mat1, mat3, and mat4 were not.
EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key2), prepacked_cache.cend());
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend());
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend());
}
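
// Runs a small 2x2 * 2x1 float multiplication through ruy::Mul twice: first
// with no cache policy on the LHS (nothing should be cached), then with
// `cache_policy` set on the LHS. Checks that the prepacked LHS ends up in the
// context's cache exactly when `expected_cached` says it should.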
void TestCacheOnCacheable(CachePolicy cache_policy, bool expected_cached) {
ruy::Context context;
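  // get_ctx() exposes the internal Ctx behind the public Context so the test
  // can inspect the per-context prepacked cache directly.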
ruy::Ctx* ctx = get_ctx(&context);
PrepackedCache* cache = ctx->GetPrepackedCache();
EXPECT_EQ(cache->TotalSize(), 0);
const float lhs_data[] = {1, 2, 3, 4};
const float rhs_data[] = {1, 2};
float dst_data[4];
ruy::Matrix<float> lhs;
ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
lhs.set_data(lhs_data);
ruy::Matrix<float> rhs;
ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, rhs.mutable_layout());
rhs.set_data(rhs_data);
ruy::Matrix<float> dst;
ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, dst.mutable_layout());
dst.set_data(dst_data);
ruy::MulParams<float, float> mul_params;
// Perform the multiplication and confirm no caching occurred.
ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst);
EXPECT_EQ(cache->TotalSize(), 0);
  // Set the cache policy on the LHS, repeat the multiplication, and check
  // whether caching occurred as expected for this policy.
lhs.set_cache_policy(cache_policy);
ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst);
const bool actual_cached = cache->TotalSize() > 0;
EXPECT_EQ(actual_cached, expected_cached);
}

TEST(PrepackedCacheTest, TestCacheOnCacheable) {
for (CachePolicy cache_policy :
{CachePolicy::kNeverCache, CachePolicy::kCacheIfLargeSpeedup,
CachePolicy::kCacheIfSignificantSpeedup, CachePolicy::kAlwaysCache,
CachePolicy::kCacheLikeTheOldCode}) {
TestCacheOnCacheable(cache_policy,
cache_policy != CachePolicy::kNeverCache);
}
}

TEST(PrepackedCacheTest, TestClearCache) {
ruy::Context context;
PrepackedCache* cache = get_ctx(&context)->GetPrepackedCache();
EXPECT_EQ(cache->TotalSize(), 0);
const float lhs_data[] = {1, 2, 3, 4};
const float rhs_data[] = {1, 2};
float dst_data[4];
ruy::Matrix<float> lhs;
ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
lhs.set_data(lhs_data);
ruy::Matrix<float> rhs;
ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, rhs.mutable_layout());
rhs.set_data(rhs_data);
ruy::Matrix<float> dst;
ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, dst.mutable_layout());
dst.set_data(dst_data);
ruy::MulParams<float, float> mul_params;
  // Set the cache policy on the LHS to kAlwaysCache and verify that caching
  // occurs.
lhs.set_cache_policy(CachePolicy::kAlwaysCache);
ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst);
EXPECT_NE(cache->TotalSize(), 0);
// Clear the cache via the Context.
context.ClearPrepackedCache();
// Verify that the cache is now empty.
cache = get_ctx(&context)->GetPrepackedCache();
EXPECT_EQ(cache->TotalSize(), 0);
}

}  // namespace
}  // namespace ruy

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}