// blob: efcf18a11deb9d3802f937b3daca3bb277d09ba8 [file] [log] [blame] [edit]
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef LIB_UTF_UTILS_INTERNAL_GENERIC_SIMD_H_
#define LIB_UTF_UTILS_INTERNAL_GENERIC_SIMD_H_
#include <lib/utf-utils/internal/scalar.h>
#include <algorithm>
#include <array>
#include <cassert>
#include <cinttypes>
#include <cstddef>
#include <cstring>
// This is an implementation of a SIMD-based validation method based on the lookup method described
// in the paper DOI:10.1002/spe.2920.
//
// To add an implementation for a new set of SIMD instructions, simply implement the following
// skeleton class:
//
// ```cpp
// class ArchImpl {
// public:
// class Vector {
// public:
// // Alias for the underlying vector type.
// using Underlying = int;
//
// // Creates a vector from data loaded from a pointer to an array.
// //
// // The array must be at least of size `sizeof(Vector)`.
// static Vector LoadFromArray(const void* ptr);
//
// // Creates a vector filled with a value.
// static Vector Fill(uint8_t val);
//
// // Creates and fills a vector with up to 32 elements.
// //
// // If the vector holds fewer than 32 elements, the elements at the front of the array are
// // ignored.
// static Vector Set32(const std::array<uint8_t, 32>& vals);
//
// // Creates and fills a vector with 16 elements.
// //
// // If the vector holds more than 16 elements, the elements are repeated in chunks of 16.
// static Vector SetRepeat16(const std::array<uint8_t, 16>& vals);
//
// // Returns whether the vector is all zero.
// bool IsAllZero() const;
//
// // Returns whether all bytes in the vector are ASCII.
// bool IsAscii() const;
//
// // Returns the underlying vector.
// Underlying& value();
//
// // Returns the underlying vector.
// const Underlying& value() const;
//
// // Performs a logical OR on every byte between two vectors.
// friend Vector operator|(const Vector& a, const Vector& b);
// Vector& operator|=(const Vector& other);
//
// // Performs a logical AND on every byte between two vectors.
// friend Vector operator&(const Vector& a, const Vector& b);
//
// // Performs a logical XOR on every byte between two vectors.
// friend Vector operator^(const Vector& a, const Vector& b);
//
// // Performs a saturating subtraction on every byte in the vector, given a subtrahend.
// Vector SaturatingSub(const Vector& subtrahend) const;
//
// // Shifts the vector contents right by 4.
// Vector Shr4() const;
//
// // Grabs `N` bytes from the lower part of the previous vector, `prev`, and in a new vector,
// // replaces the upper parts of the current vector with the `N` bytes.
// template <size_t N>
// Vector Prev(const Vector& prev) const;
//
// // Performs a vector table lookup.
// Vector Lookup16(const std::array<uint8_t, 16>& table) const;
//
// // Stores the vector to an array.
// //
// // The array must be at least of size `sizeof(Vector)`.
// void StoreToArray(void *ptr);
// };
//
// // Prefetch data at a given address.
// static void Prefetch(const void* ptr);
//
// // Returns the size of the vector.
// static constexpr size_t VectorSize();
//
// // Checks for continuation bytes in previous vectors.
// static Vector Check2Or3Continuation(const Vector& prev2, const Vector& prev3);
// };
// ```
//
// The resulting class can then be passed as a template parameter to `IsValidUtf8Simd()`, which runs
// the generic algorithm.
namespace utfutils {
namespace internal {
// Bundles several SIMD vectors into one fixed-size "chunk". Working a chunk at a time lets the
// validation loops run a (slightly) longer string of vectorized instructions between branch
// checks, which lowers the frequency (and thus cost) of those checks.
//
// `ArchImpl` must supply the `Vector` type and the static `VectorSize()` function described in the
// skeleton at the top of this file.
template <typename ArchImpl>
class VectorChunk {
 private:
  using Vector = typename ArchImpl::Vector;

  // Chunks aim to cover a common cache line size.
  static constexpr size_t kTargetChunkSize = 64;

  // Number of vectors per chunk: as many as fit in `kTargetChunkSize`, but never fewer than one.
  static constexpr size_t GetVectorCount() {
    return std::max(kTargetChunkSize / ArchImpl::VectorSize(), size_t{1});
  }

  using Storage = std::array<Vector, GetVectorCount()>;

 public:
  // Returns the number of bytes covered by one chunk.
  static constexpr size_t GetSize() { return GetVectorCount() * ArchImpl::VectorSize(); }

  // Builds a chunk from the bytes at `ptr`. `GetSize()` bytes are read, one vector after another.
  static VectorChunk LoadFromArray(const void* ptr) {
    VectorChunk loaded;
    const uint8_t* cursor = static_cast<const uint8_t*>(ptr);
    for (Vector& slot : loaded.chunk_) {
      slot = Vector::LoadFromArray(cursor);
      cursor += ArchImpl::VectorSize();
    }
    return loaded;
  }

  // Access to the individual vectors, in input order.
  const Storage& vectors() const { return chunk_; }
  Storage& vectors() { return chunk_; }

  // Returns whether every byte in the chunk is ASCII. All vectors are OR-ed together so that only
  // a single `Vector::IsAscii()` test is needed.
  bool IsAscii() const {
    switch (GetVectorCount()) {
      case 4: {
        // Combine pairwise so the two inner ORs do not depend on each other.
        const Vector first_pair = chunk_[0] | chunk_[1];
        const Vector second_pair = chunk_[2] | chunk_[3];
        return (first_pair | second_pair).IsAscii();
      }
      case 2:
        return (chunk_[0] | chunk_[1]).IsAscii();
      default: {
        // Any other vector count: fold everything into one accumulator.
        Vector merged = Vector::Fill(0);
        for (const Vector& vec : chunk_) {
          merged |= vec;
        }
        return merged.IsAscii();
      }
    }
  }

  // Writes the chunk to `ptr`. `GetSize()` bytes are written, one vector after another.
  void StoreToArray(void* ptr) {
    uint8_t* cursor = static_cast<uint8_t*>(ptr);
    for (Vector& vec : chunk_) {
      vec.StoreToArray(cursor);
      cursor += ArchImpl::VectorSize();
    }
  }

 private:
  Storage chunk_;
};
// Classifies each byte of `cur` together with the byte that precedes it, producing a per-byte
// bitmask of the UTF-8 rule classes that the pair matches. `prev` is the vector that came before
// `cur` and supplies the predecessor of the first byte.
//
// The rules are as follows:
//   * Too Short: the leading byte must be followed by N-1 continuation bytes, where N is the UTF-8
//     codepoint length.
//   * Too Long: the leading byte must not be a continuation byte.
//   * Overlong: the character must be above U+7F for 2-byte codepoints, U+7FF for 3-byte
//     codepoints, and U+FFFF for four-byte characters.
//   * Too Large: the character must be <= U+10FFFF.
//   * Surrogate: the character must not be in the range [U+D800, U+DFFF].
//
// The result is the AND of three 16-entry table lookups: one keyed by the low nibble of the
// previous byte, one by the high nibble of the previous byte, and one by the high nibble of the
// current byte. A class bit is set in the result only when all three lookups carry it.
template <typename Vector>
Vector ClassifyRules(const Vector& cur, const Vector& prev) {
  // One bit per rule class. The overlong classes are named after the length of the sequence.
  constexpr uint8_t kTooShort = 1 << 0;          // Bad: lead byte -> lead byte, lead byte -> ASCII
  constexpr uint8_t kTooLong = 1 << 1;           // Bad: ASCII -> continuation byte
  constexpr uint8_t kOverlongThreeByte = 1 << 2; // Bad: 11100000 100*****
  constexpr uint8_t kTooLarge1001 = 1 << 3;      // Bad: too large, second byte 101***** or 1001****
  constexpr uint8_t kSurrogate = 1 << 4;         // Bad: surrogate code point
  constexpr uint8_t kOverlongTwoByte = 1 << 5;   // Bad: 1100000* 10******
  constexpr uint8_t kTooLarge1000 = 1 << 6;      // Bad: too large, second byte 1000****
  constexpr uint8_t kOverlongFourByte = 1 << 6;  // Bad: 11110000 1000**** (same bit as above)
  constexpr uint8_t kTwoContinuations = 1 << 7;  // This, by itself, is not invalid.

  // Classes that allow anything in the low 4 bits of the first byte, so every low-nibble entry
  // must let them through.
  constexpr uint8_t kCarry = kTooShort | kTooLong | kTwoContinuations;
  constexpr uint8_t kCarryTooLarge = kCarry | kTooLarge1001 | kTooLarge1000;
  // Classes shared by every "current byte is a continuation byte" entry.
  constexpr uint8_t kSecondIsContinuation = kTooLong | kOverlongTwoByte | kTwoContinuations;

  // Indexed by the low nibble of the previous byte.
  constexpr std::array<uint8_t, 16> kByte1LowTable = {{
      kCarry | kOverlongThreeByte | kOverlongTwoByte | kOverlongFourByte,  // ****0000
      kCarry | kOverlongTwoByte,                                           // ****0001
      kCarry,                                                              // ****0010
      kCarry,                                                              // ****0011
      kCarry | kTooLarge1001,                                              // ****0100
      kCarryTooLarge,                                                      // ****0101
      kCarryTooLarge,                                                      // ****0110
      kCarryTooLarge,                                                      // ****0111
      kCarryTooLarge,                                                      // ****1000
      kCarryTooLarge,                                                      // ****1001
      kCarryTooLarge,                                                      // ****1010
      kCarryTooLarge,                                                      // ****1011
      kCarryTooLarge,                                                      // ****1100
      kCarryTooLarge | kSurrogate,                                         // ****1101
      kCarryTooLarge,                                                      // ****1110
      kCarryTooLarge,                                                      // ****1111
  }};

  // Indexed by the high nibble of the previous byte.
  constexpr std::array<uint8_t, 16> kByte1HighTable = {{
      // 0*******: the previous byte is ASCII.
      kTooLong, kTooLong, kTooLong, kTooLong, kTooLong, kTooLong, kTooLong, kTooLong,
      // 10******: the previous byte is a continuation byte.
      kTwoContinuations, kTwoContinuations, kTwoContinuations, kTwoContinuations,
      kTooShort | kOverlongTwoByte,                                   // 1100****
      kTooShort,                                                      // 1101****
      kTooShort | kOverlongThreeByte | kSurrogate,                    // 1110****
      kTooShort | kTooLarge1001 | kTooLarge1000 | kOverlongFourByte,  // 1111****
  }};

  // Indexed by the high nibble of the current byte.
  constexpr std::array<uint8_t, 16> kByte2HighTable = {{
      // 0*******: the current byte is ASCII.
      kTooShort, kTooShort, kTooShort, kTooShort, kTooShort, kTooShort, kTooShort, kTooShort,
      kSecondIsContinuation | kOverlongThreeByte | kTooLarge1000 | kOverlongFourByte,  // 1000****
      kSecondIsContinuation | kOverlongThreeByte | kTooLarge1001,                      // 1001****
      kSecondIsContinuation | kSurrogate | kTooLarge1001,                              // 1010****
      kSecondIsContinuation | kSurrogate | kTooLarge1001,                              // 1011****
      // 11******: the current byte is a lead byte.
      kTooShort, kTooShort, kTooShort, kTooShort,
  }};

  // `before` is `cur` combined with one byte taken from `prev` (see `Prev<N>` in the skeleton).
  Vector before = cur.template Prev<1>(prev);
  Vector low_nibble_classes = (before & Vector::Fill(0xF)).Lookup16(kByte1LowTable);
  Vector high_nibble_classes = before.Shr4().Lookup16(kByte1HighTable);
  Vector current_classes = cur.Shr4().Lookup16(kByte2HighTable);
  return low_nibble_classes & high_nibble_classes & current_classes;
}
// Checks that the multi-byte codepoints are of appropriate length.
//
// Builds the two vectors `cur.Prev<2>(prev)` and `cur.Prev<3>(prev)`, hands them to the
// architecture's `Check2Or3Continuation()`, and keeps only the top bit (0x80) of each byte of the
// answer.
template <typename ArchImpl>
typename ArchImpl::Vector CheckMultiByte(const typename ArchImpl::Vector& cur,
                                         const typename ArchImpl::Vector& prev) {
  using Vector = typename ArchImpl::Vector;
  const auto two_back = cur.template Prev<2>(prev);
  const auto three_back = cur.template Prev<3>(prev);
  const auto continuation_expected = ArchImpl::Check2Or3Continuation(two_back, three_back);
  return continuation_expected & Vector::Fill(0x80);
}
// Checks whether a vector, if terminal, would be incomplete UTF-8.
//
// Each byte has a per-position limit subtracted from it with saturation, so a byte leaves a
// non-zero result only when it exceeds its limit. Every position but the last three has the limit
// 0xFF and therefore always yields zero. The last three positions yield non-zero for:
//   * third from the end:  a byte >= 0xF0
//   * second from the end: a byte >= 0xE0
//   * last:                a byte >= 0xC0
template <typename Vector>
Vector CheckIncomplete(const Vector& vec) {
  constexpr uint8_t kNoLimit = 0xFF;
  constexpr std::array<uint8_t, 32> kLargestAllowed = {{
      kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit,
      kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit,
      kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit,
      kNoLimit, kNoLimit, kNoLimit, kNoLimit, kNoLimit,
      0xF0 - 1,  // 0b11101111
      0xE0 - 1,  // 0b11011111
      0xC0 - 1,  // 0b10111111
  }};
  return vec.SaturatingSub(Vector::Set32(kLargestAllowed));
}
template <typename ArchImpl>
void ProcessChunk(const VectorChunk<ArchImpl>& cur, typename ArchImpl::Vector& prev,
typename ArchImpl::Vector& error) {
using Vector = typename ArchImpl::Vector;
if (cur.IsAscii()) {
// If this chunk is all ASCII, then the last chunk must not have continuation bytes that
// overflow the previous chunk and into this chunk.
error |= CheckIncomplete(prev);
prev = cur.vectors().back();
} else {
// Chunk is not all ASCII so check rules and multibyte.
for (auto& vec : cur.vectors()) {
Vector rule_check = ClassifyRules(vec, prev);
Vector multibyte = CheckMultiByte<ArchImpl>(vec, prev);
error |= (rule_check ^ multibyte);
prev = vec;
}
}
}
// Runs the generic SIMD validate and store algorithm on a particular architecture implementation.
//
// Note the `dst` parameter must only be valid if template parameter `do_copy` is set.
template <typename ArchImpl, bool do_copy>
bool RunValidateAndCopyUtf8Simd(const char* src, __attribute__((unused)) char* dst,
const size_t size) {
using Vector = typename ArchImpl::Vector;
// Use the scalar path for small strings.
constexpr size_t kSmallStringSize = VectorChunk<ArchImpl>::GetSize() * 3 / 2;
if (size < kSmallStringSize) {
if /* constexpr */ (do_copy) {
return ValidateAndCopyUtf8Scalar(src, dst, size);
} else {
return IsValidUtf8Scalar(src, size);
}
}
if (src == nullptr) {
return false;
}
auto prev = Vector::Fill(0);
auto error = Vector::Fill(0);
const size_t aligned_size = size - (size % VectorChunk<ArchImpl>::GetSize());
size_t offset = 0;
bool encountered_non_ascii = false;
// Fast path for ASCII
for (; offset < aligned_size; offset += VectorChunk<ArchImpl>::GetSize()) {
auto cur = VectorChunk<ArchImpl>::LoadFromArray(src + offset);
if (__builtin_expect(!cur.IsAscii(), 0)) {
encountered_non_ascii = true;
break;
}
if /* constexpr */ (do_copy) {
cur.StoreToArray(dst + offset);
}
// Note that it's not necessary to assign `prev = cur` here since `prev` is initialized to all
// zeros, which is valid ASCII. If a non-ASCII sequence were to be detected in this loop and
// break early, the previous chunk (before the one that triggered the break) was some valid
// sequence of ASCII characters. The exact contents of that sequence does not matter.
}
// Process chunks of vectors.
for (; offset < aligned_size; offset += VectorChunk<ArchImpl>::GetSize()) {
// Prefetch the next chunk of data.
ArchImpl::Prefetch(src + offset + VectorChunk<ArchImpl>::GetSize());
auto cur = VectorChunk<ArchImpl>::LoadFromArray(src + offset);
if /* constexpr */ (do_copy) {
cur.StoreToArray(dst + offset);
}
ProcessChunk<ArchImpl>(cur, prev, error);
}
if (offset < size) {
if (!encountered_non_ascii) {
// If a non-ASCII character has yet to be encountered, use the scalar implementation, which is
// quicker for very small strings.
if /* constexpr */ (do_copy) {
return ValidateAndCopyUtf8Scalar(src + offset, dst + offset, size - offset);
} else {
return IsValidUtf8Scalar(src + offset, size - offset);
}
}
// Copy remaining data into a zero-initialized chunk.
VectorChunk<ArchImpl> cur = {};
memcpy(&cur, src + offset, size - offset);
if /* constexpr */ (do_copy) {
memcpy(dst + offset, src + offset, size - offset);
}
ProcessChunk<ArchImpl>(cur, prev, error);
}
if (__builtin_expect(!prev.IsAscii(), 0)) {
// The previous chunk was not all ASCII, so check that there wasn't an incomplete multibyte
// codepoint hanging at the end.
error |= CheckIncomplete(prev);
}
return error.IsAllZero();
}
// Validates the `size` bytes at `data` as UTF-8 using the generic SIMD algorithm instantiated for
// `ArchImpl`, without copying the input anywhere.
//
// Forwards to `RunValidateAndCopyUtf8Simd` with `do_copy` set to false; `nullptr` is passed as the
// destination because that parameter only has to be valid when copying is enabled.
//
// Returns the result of `RunValidateAndCopyUtf8Simd`.
template <typename ArchImpl>
bool IsValidUtf8Simd(const char* data, size_t size) {
return RunValidateAndCopyUtf8Simd<ArchImpl, false>(data, nullptr, size);
}
// Validates the `size` bytes at `src` as UTF-8 using the generic SIMD algorithm instantiated for
// `ArchImpl`, copying the input to `dst` along the way.
//
// Forwards to `RunValidateAndCopyUtf8Simd` with `do_copy` set to true, so `dst` must be a valid
// destination for the copy.
//
// Returns the result of `RunValidateAndCopyUtf8Simd`.
template <typename ArchImpl>
bool ValidateAndCopyUtf8Simd(const char* src, char* dst, size_t size) {
return RunValidateAndCopyUtf8Simd<ArchImpl, true>(src, dst, size);
}
} // namespace internal
} // namespace utfutils
#endif // LIB_UTF_UTILS_INTERNAL_GENERIC_SIMD_H_