// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// #pragma use_vulkan_memory_model // results in spirv-remap validation error
//
// Each "pass" scatters the keyvals to their new destinations.
//
// clang-format off
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_KHR_shader_subgroup_vote : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_KHR_shader_subgroup_ballot : require
// clang-format on
//
// Load arch/keyval configuration
//
#include "config.h"
//
// Optional switches:
//
// #define RS_SCATTER_DISABLE_REORDER
// #define RS_SCATTER_ENABLE_BITFIELD_EXTRACT
// #define RS_SCATTER_ENABLE_NV_MATCH
// #define RS_SCATTER_ENABLE_BROADCAST_MATCH
// #define RS_SCATTER_DISABLE_COMPONENTS_IN_REGISTERS
// #define RS_SCATTER_NONSEQUENTIAL_DISPATCH
//
//
// Use NVIDIA Turing/Volta+ partitioning operator (`match_any()`)?
//
#ifdef RS_SCATTER_ENABLE_NV_MATCH
#extension GL_NV_shader_subgroup_partitioned : require
#endif
//
// Store prefix intermediates in registers?
//
#ifdef RS_SCATTER_DISABLE_COMPONENTS_IN_REGISTERS
#define RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS
#endif
//
// Buffer reference macros and push constants
//
#include "bufref.h"
#include "push.h"
//
// Push constants for scatter shader
//
RS_STRUCT_PUSH_SCATTER();
layout(push_constant) uniform block_push
{
rs_push_scatter push;
};
//
// Subgroup uniform support
//
#if defined(RS_SCATTER_SUBGROUP_UNIFORM_DISABLE) && defined(GL_EXT_subgroupuniform_qualifier)
#extension GL_EXT_subgroupuniform_qualifier : require
#define RS_SUBGROUP_UNIFORM subgroupuniformEXT
#else
#define RS_SUBGROUP_UNIFORM
#endif
//
// Check all mandatory switches are defined
//
// What's the size of the keyval?
#ifndef RS_KEYVAL_DWORDS
#error "Undefined: RS_KEYVAL_DWORDS"
#endif
// Which keyval dword does this shader bitfieldExtract() the digit from?
#ifndef RS_SCATTER_KEYVAL_DWORD_BASE
#error "Undefined: RS_SCATTER_KEYVAL_DWORD_BASE"
#endif
//
#ifndef RS_SCATTER_BLOCK_ROWS
#error "Undefined: RS_SCATTER_BLOCK_ROWS"
#endif
//
#ifndef RS_SCATTER_SUBGROUP_SIZE_LOG2
#error "Undefined: RS_SCATTER_SUBGROUP_SIZE_LOG2"
#endif
//
#ifndef RS_SCATTER_WORKGROUP_SIZE_LOG2
#error "Undefined: RS_SCATTER_WORKGROUP_SIZE_LOG2"
#endif
//
// Status masks are defined differently for the scatter_even and
// scatter_odd shaders.
//
#ifndef RS_PARTITION_STATUS_INVALID
#error "Undefined: RS_PARTITION_STATUS_INVALID"
#endif
#ifndef RS_PARTITION_STATUS_REDUCTION
#error "Undefined: RS_PARTITION_STATUS_REDUCTION"
#endif
#ifndef RS_PARTITION_STATUS_PREFIX
#error "Undefined: RS_PARTITION_STATUS_PREFIX"
#endif
//
// Assumes (RS_RADIX_LOG2 == 8)
//
// Error if this ever changes!
//
#if (RS_RADIX_LOG2 != 8)
#error "Error: (RS_RADIX_LOG2 != 8)"
#endif
//
// Masks are different for scatter_even/odd.
//
// clang-format off
#define RS_PARTITION_MASK_INVALID (RS_PARTITION_STATUS_INVALID << 30)
#define RS_PARTITION_MASK_REDUCTION (RS_PARTITION_STATUS_REDUCTION << 30)
#define RS_PARTITION_MASK_PREFIX (RS_PARTITION_STATUS_PREFIX << 30)
#define RS_PARTITION_MASK_STATUS 0xC0000000
#define RS_PARTITION_MASK_COUNT 0x3FFFFFFF
// clang-format on
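//
// Worked example of the status math relied on below: the lookback's
// single `atomicAdd(..., exc | (1 << 30))` can only upgrade a
// REDUCTION to a PREFIX because the status values differ by one.
// E.g., a partition word holding a REDUCTION of 1000 is
//
//   RS_PARTITION_MASK_REDUCTION | 1000
//
// and adding `exc | (1 << 30)` yields
//
//   RS_PARTITION_MASK_PREFIX | (exc + 1000)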
//
// Local macros
//
// clang-format off
#define RS_KEYVAL_SIZE (RS_KEYVAL_DWORDS * 4)
#define RS_WORKGROUP_SIZE (1 << RS_SCATTER_WORKGROUP_SIZE_LOG2)
#define RS_SUBGROUP_SIZE (1 << RS_SCATTER_SUBGROUP_SIZE_LOG2)
#define RS_WORKGROUP_SUBGROUPS (RS_WORKGROUP_SIZE / RS_SUBGROUP_SIZE)
#define RS_SUBGROUP_KEYVALS (RS_SCATTER_BLOCK_ROWS * RS_SUBGROUP_SIZE)
#define RS_BLOCK_KEYVALS (RS_SCATTER_BLOCK_ROWS * RS_WORKGROUP_SIZE)
#define RS_RADIX_MASK ((1 << RS_RADIX_LOG2) - 1)
// clang-format on
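//
// For illustration only, a hypothetical configuration with
// RS_SCATTER_WORKGROUP_SIZE_LOG2 = 8, RS_SCATTER_SUBGROUP_SIZE_LOG2 = 5,
// RS_SCATTER_BLOCK_ROWS = 8, and RS_KEYVAL_DWORDS = 1 yields:
//
//   RS_KEYVAL_SIZE         = 4
//   RS_WORKGROUP_SIZE      = 256
//   RS_SUBGROUP_SIZE       = 32
//   RS_WORKGROUP_SUBGROUPS = 8
//   RS_SUBGROUP_KEYVALS    = 256
//   RS_BLOCK_KEYVALS       = 2048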
//
// Validate that the number of keyvals fits in a uint16_t.
//
#if (RS_BLOCK_KEYVALS >= 65536)
#error "Error: (RS_BLOCK_KEYVALS >= 65536)"
#endif
//
// Keyval type
//
#if (RS_KEYVAL_DWORDS == 1)
#define RS_KEYVAL_TYPE uint32_t
#elif (RS_KEYVAL_DWORDS == 2)
#define RS_KEYVAL_TYPE u32vec2
#else
#error "Error: Unsupported RS_KEYVAL_DWORDS"
#endif
//
// Set up match mask
//
#if (RS_SUBGROUP_SIZE <= 32)
#if (RS_SUBGROUP_SIZE == 32)
#define RS_SUBGROUP_MASK 0xFFFFFFFF
#else
#define RS_SUBGROUP_MASK ((1 << RS_SUBGROUP_SIZE) - 1)
#endif
#endif
//
// Determine at compile time the base of the final iteration for
// workgroups smaller than RS_RADIX_SIZE.
//
#if (RS_WORKGROUP_SIZE < RS_RADIX_SIZE)
#define RS_WORKGROUP_BASE_FINAL ((RS_RADIX_SIZE / RS_WORKGROUP_SIZE) * RS_WORKGROUP_SIZE)
#endif
//
// Max macro
//
#define RS_MAX_2(a_, b_) (((a_) >= (b_)) ? (a_) : (b_))
//
// Select a keyval dword
//
#if (RS_KEYVAL_DWORDS == 1)
#define RS_KV_DWORD(kv_, dword_) (kv_)
#else
#define RS_KV_DWORD(kv_, dword_) (kv_)[dword_]
#endif
//
// Is bitfield extract faster?
//
#ifdef RS_SCATTER_ENABLE_BITFIELD_EXTRACT
//----------------------------------------------------------------------
//
// Test a bit in a radix digit
//
#define RS_BIT_IS_ONE(val_, bit_) (bitfieldExtract(val_, bit_, 1) != 0)
//
// Extract a keyval digit
//
#if (RS_KEYVAL_DWORDS == 1)
#define RS_KV_EXTRACT_DIGIT(kv_) bitfieldExtract(kv_, int32_t(push.pass_offset), RS_RADIX_LOG2)
#else
#define RS_KV_EXTRACT_DIGIT(kv_) \
bitfieldExtract(kv_[RS_SCATTER_KEYVAL_DWORD_BASE], int32_t(push.pass_offset), RS_RADIX_LOG2)
#endif
//----------------------------------------------------------------------
#else
//----------------------------------------------------------------------
//
// Test a bit in a radix digit
//
#define RS_BIT_IS_ONE(val_, bit_) (((val_) & (1 << (bit_))) != 0)
//
// Extract a keyval digit
//
#if (RS_KEYVAL_DWORDS == 1)
#define RS_KV_EXTRACT_DIGIT(kv_) ((kv_ >> push.pass_offset) & RS_RADIX_MASK)
#else
#define RS_KV_EXTRACT_DIGIT(kv_) \
((kv_[RS_SCATTER_KEYVAL_DWORD_BASE] >> push.pass_offset) & RS_RADIX_MASK)
#endif
//----------------------------------------------------------------------
#endif
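//
// Either way, RS_KV_EXTRACT_DIGIT() yields the digit for the current
// pass. E.g., for a hypothetical 32-bit keyval 0xAABBCCDD with
// push.pass_offset == 8:
//
//   RS_KV_EXTRACT_DIGIT(kv) == 0xCC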
//
// Load the prefix limits before including the prefix function and
// before calculating the SMEM limits.
//
#include "prefix_limits.h"
//
// - The lookback span is RS_RADIX_SIZE dwords and overwrites the
// ballots span.
//
// - The histogram span is RS_RADIX_SIZE dwords
//
// - The keyvals span is at least one dword per keyval in the
// workgroup. This span overwrites anything past the lookback
// radix span.
//
// Shared memory map phase 1:
//
// < LOOKBACK > < HISTOGRAM > < PREFIX > ...
//
// Shared memory map phase 3:
//
// < LOOKBACK > < REORDER > ...
//
// FIXME(allanmac): Create a spreadsheet showing the exact shared
// memory footprint (RS_SMEM_DWORDS) for a configuration.
//
// Span      | Dwords                                    | Bytes
// ----------+-------------------------------------------+--------
// Lookback  | 256                                       | 1 KB
// Histogram | 256                                       | 1 KB
// Prefix    | 4-84                                      | 16-336
// Reorder   | RS_WORKGROUP_SIZE * RS_SCATTER_BLOCK_ROWS | 2-8 KB
//
// clang-format off
#define RS_SMEM_LOOKBACK_SIZE RS_RADIX_SIZE
#define RS_SMEM_HISTOGRAM_SIZE RS_RADIX_SIZE
#define RS_SMEM_REORDER_SIZE (RS_SCATTER_BLOCK_ROWS * RS_WORKGROUP_SIZE)
#define RS_SMEM_DWORDS_PHASE_1 (RS_SMEM_LOOKBACK_SIZE + RS_SMEM_HISTOGRAM_SIZE + RS_SWEEP_SIZE)
#define RS_SMEM_DWORDS_PHASE_2 (RS_SMEM_LOOKBACK_SIZE + RS_SMEM_REORDER_SIZE)
#define RS_SMEM_DWORDS RS_MAX_2(RS_SMEM_DWORDS_PHASE_1, RS_SMEM_DWORDS_PHASE_2)
#define RS_SMEM_LOOKBACK_OFFSET 0
#define RS_SMEM_HISTOGRAM_OFFSET (RS_SMEM_LOOKBACK_OFFSET + RS_SMEM_LOOKBACK_SIZE)
#define RS_SMEM_PREFIX_OFFSET (RS_SMEM_HISTOGRAM_OFFSET + RS_SMEM_HISTOGRAM_SIZE)
#define RS_SMEM_REORDER_OFFSET (RS_SMEM_LOOKBACK_OFFSET + RS_SMEM_LOOKBACK_SIZE)
// clang-format on
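//
// Under the illustrative 256-invocation configuration above, the
// lookback span occupies dwords [0,256), the histogram span [256,512),
// and the prefix sweeps start at dword 512; during reordering the
// reorder span reuses dwords [256,2304).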
//
//
//
layout(local_size_x = RS_WORKGROUP_SIZE) in;
//
//
//
layout(buffer_reference, std430) buffer buffer_rs_kv
{
RS_KEYVAL_TYPE extent[];
};
layout(buffer_reference, std430) buffer buffer_rs_histogram // single histogram
{
uint32_t extent[];
};
layout(buffer_reference, std430) buffer buffer_rs_partitions
{
uint32_t extent[];
};
//
// Declare shared memory
//
struct rs_scatter_smem
{
uint32_t extent[RS_SMEM_DWORDS];
};
shared rs_scatter_smem smem;
//
// The shared memory barrier is either subgroup-wide or
// workgroup-wide.
//
#if (RS_WORKGROUP_SUBGROUPS == 1)
#define RS_BARRIER() subgroupBarrier()
#else
#define RS_BARRIER() barrier()
#endif
//
// If multi-subgroup then define the shared memory prefix sweeps
//
#if (RS_WORKGROUP_SUBGROUPS > 1)
//----------------------------------------
#define RS_PREFIX_SWEEP0(idx_) smem.extent[RS_SMEM_PREFIX_OFFSET + RS_SWEEP_0_OFFSET + (idx_)]
//----------------------------------------
#if (RS_SWEEP_1_SIZE > 0)
//----------------------------------------
#define RS_PREFIX_SWEEP1(idx_) smem.extent[RS_SMEM_PREFIX_OFFSET + RS_SWEEP_1_OFFSET + (idx_)]
//----------------------------------------
#endif
#if (RS_SWEEP_2_SIZE > 0)
//----------------------------------------
#define RS_PREFIX_SWEEP2(idx_) smem.extent[RS_SMEM_PREFIX_OFFSET + RS_SWEEP_2_OFFSET + (idx_)]
//----------------------------------------
#endif
#endif
//
// Define prefix load/store macros
//
// clang-format off
#if (RS_WORKGROUP_SUBGROUPS == 1)
#define RS_PREFIX_LOAD(idx_) smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID + (idx_)]
#define RS_PREFIX_STORE(idx_) smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID + (idx_)]
#else
#define RS_PREFIX_LOAD(idx_) smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x + (idx_)]
#define RS_PREFIX_STORE(idx_) smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x + (idx_)]
#endif
// clang-format on
//
// If this is a nonsequential dispatch device then atomically acquire
// a scatter block index instead of using `gl_WorkGroupID.x`
//
#ifdef RS_SCATTER_NONSEQUENTIAL_DISPATCH
layout(buffer_reference, std430) buffer buffer_rs_workgroup_id
{
uint32_t x[RS_KEYVAL_DWORDS * 4];
};
#if (RS_WORKGROUP_SUBGROUPS == 1)
#define RS_IS_FIRST_LOCAL_INVOCATION() (gl_SubgroupInvocationID == 0)
#else
#define RS_IS_FIRST_LOCAL_INVOCATION() (gl_LocalInvocationID.x == 0)
#endif
RS_SUBGROUP_UNIFORM uint32_t rs_gl_workgroup_id_x;
#define RS_GL_WORKGROUP_ID_X rs_gl_workgroup_id_x
//
// Default is a device that sequentially dispatches workgroups.
//
#else
#define RS_GL_WORKGROUP_ID_X gl_WorkGroupID.x
#endif
//
// Load the prefix function
//
// The prefix function operates on shared memory so there are no
// arguments.
//
#define RS_PREFIX_ARGS // EMPTY
#include "prefix.h"
//
// Zero the SMEM histogram
//
void
rs_histogram_zero()
{
#if (RS_WORKGROUP_SUBGROUPS == 1)
const uint32_t smem_offset = RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID;
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
{
smem.extent[smem_offset + ii] = 0;
}
#elif (RS_WORKGROUP_SIZE < RS_RADIX_SIZE)
const uint32_t smem_offset = RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x;
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
{
smem.extent[smem_offset + ii] = 0;
}
#if (RS_WORKGROUP_BASE_FINAL < RS_RADIX_SIZE)
const uint32_t smem_offset_final = smem_offset + RS_WORKGROUP_BASE_FINAL;
if (smem_offset_final < RS_SMEM_HISTOGRAM_OFFSET + RS_RADIX_SIZE)
{
smem.extent[smem_offset_final] = 0;
}
#endif
#elif (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE)
#if (RS_WORKGROUP_SIZE > RS_RADIX_SIZE)
if (gl_LocalInvocationID.x < RS_RADIX_SIZE)
#endif
{
smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x] = 0;
}
#endif
RS_BARRIER();
}
//
// Perform a workgroup-wide match operation that computes both a
// workgroup-wide index for each keyval and a workgroup-wide
// histogram.
//
// FIXME(allanmac): Special case (RS_WORKGROUP_SUBGROUPS==1)
//
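// Each kr[] entry below is packed as:
//
//   kr[ii] = (count << 16) | rank
//
// where `count` is the number of subgroup invocations holding the same
// digit and `rank` is this invocation's 1-based position among them.
// E.g., if invocations {3, 9, 20} of a subgroup all hold digit 0x5A,
// invocation 9 computes kr[ii] = (3 << 16) | 2.
//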
void
rs_histogram_rank(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
out uint32_t kr[RS_SCATTER_BLOCK_ROWS])
{
// clang-format off
#define RS_HISTOGRAM_LOAD(digit_) smem.extent[RS_SMEM_HISTOGRAM_OFFSET + (digit_)]
#define RS_HISTOGRAM_STORE(digit_, count_) smem.extent[RS_SMEM_HISTOGRAM_OFFSET + (digit_)] = (count_)
// clang-format on
#ifdef RS_SCATTER_ENABLE_NV_MATCH
//----------------------------------------------------------------------
//
// Use the Volta/Turing `match.sync` instruction.
//
// Note that performance is quite poor: `match.sync` only breaks even
// when matching more bits than are matched here.
//
//----------------------------------------------------------------------
#if (RS_SUBGROUP_SIZE == 32)
//
// 32
//
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
//
// NOTE(allanmac): Unfortunately there is no `match.any.sync.b8`
//
// TODO(allanmac): Consider using the `atomicOr()` match approach
// described by Adinets since Volta/Turing have extremely fast
// atomic smem operations.
//
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
const uint32_t match = subgroupPartitionNV(digit).x;
kr[ii] = (bitCount(match) << 16) | bitCount(match & gl_SubgroupLeMask.x);
}
#else
//
// Undefined!
//
#error "Error: rs_histogram_rank() undefined for subgroup size"
#endif
#elif !defined(RS_SCATTER_ENABLE_BROADCAST_MATCH)
//----------------------------------------------------------------------
//
// Default is to emulate a `match` operation with ballots.
//
//----------------------------------------------------------------------
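//
// Each ballot keeps only the invocations whose digit agrees in that
// bit; ANDing all RS_RADIX_LOG2 ballots leaves exactly the invocations
// holding the same digit. Shrunk to a 2-bit digit for illustration: a
// 4-wide subgroup holding digits {1, 1, 2, 1} ballots bit 1 as 0b0100,
// so invocation 0 starts with match = 0b1011; the bit 0 ballot 0b1011
// then leaves match = 0b1011, i.e. invocations {0, 1, 3}.
//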
#if (RS_SUBGROUP_SIZE == 64)
//
// 64
//
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
u32vec2 match;
{
const bool is_one = RS_BIT_IS_ONE(digit, RS_RADIX_LOG2 - 1);
const u32vec2 ballot = subgroupBallot(is_one).xy;
const uint32_t mask = is_one ? 0 : 0xFFFFFFFF;
match.x = (ballot.x ^ mask);
match.y = (ballot.y ^ mask);
}
[[unroll]] for (int32_t bit = RS_RADIX_LOG2 - 2; bit >= 0; bit--)
{
const bool is_one = RS_BIT_IS_ONE(digit, bit);
const u32vec2 ballot = subgroupBallot(is_one).xy;
const uint32_t mask = is_one ? 0 : 0xFFFFFFFF;
match.x &= (ballot.x ^ mask);
match.y &= (ballot.y ^ mask);
//
// Use the pigeonhole principle to exit early.
//
// If every key partition only contains itself then we're done.
//
#ifdef RS_SCATTER_RANK_EARLY_EXIT_LOG2
if ((bit > 0) && (bit <= RS_RADIX_LOG2 - RS_SCATTER_SUBGROUP_SIZE_LOG2) &&
(bit == RS_SCATTER_RANK_EARLY_EXIT_LOG2))
{
if (subgroupAll(bitCount(match.x) + bitCount(match.y) == 1))
{
break;
}
}
#endif
}
kr[ii] = ((bitCount(match.x) + bitCount(match.y)) << 16) |
(bitCount(match.x & gl_SubgroupLeMask.x) + //
bitCount(match.y & gl_SubgroupLeMask.y));
}
#elif ((RS_SUBGROUP_SIZE <= 32) && !defined(RS_SCATTER_ENABLE_NV_MATCH))
//
// <= 32
//
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
uint32_t match;
{
const bool is_one = RS_BIT_IS_ONE(digit, RS_RADIX_LOG2 - 1);
const uint32_t ballot = subgroupBallot(is_one).x;
const uint32_t mask = is_one ? 0 : RS_SUBGROUP_MASK;
match = (ballot ^ mask);
}
[[unroll]] for (int32_t bit = RS_RADIX_LOG2 - 2; bit >= 0; bit--)
{
const bool is_one = RS_BIT_IS_ONE(digit, bit);
const uint32_t ballot = subgroupBallot(is_one).x;
const uint32_t mask = is_one ? 0 : RS_SUBGROUP_MASK;
match &= (ballot ^ mask);
//
// Use the pigeonhole principle to exit early.
//
// If every key partition only contains itself then we're done.
//
#ifdef RS_SCATTER_RANK_EARLY_EXIT_LOG2
if ((bit > 0) && (bit <= RS_RADIX_LOG2 - RS_SCATTER_SUBGROUP_SIZE_LOG2) &&
(bit == RS_SCATTER_RANK_EARLY_EXIT_LOG2))
{
if (subgroupAll(bitCount(match) == 1))
{
break;
}
}
#endif
}
kr[ii] = (bitCount(match) << 16) | bitCount(match & gl_SubgroupLeMask.x);
}
#else
//
// Undefined!
//
#error "Error: rs_histogram_rank() undefined for subgroup size"
#endif
#else
//----------------------------------------------------------------------
//
// Emulate a `match` operation with broadcasts.
//
// In general, using broadcasts is a win for narrow subgroups.
//
//----------------------------------------------------------------------
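//
// E.g., a 4-wide subgroup holding digits {7, 3, 7, 7} performs four
// broadcasts; invocation 0 compares each broadcast digit against its
// own and accumulates match = 0b1101.
//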
#if (RS_SUBGROUP_SIZE == 64)
//
// 64
//
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
u32vec2 match;
// subgroup invocation 0
{
match[0] = (subgroupBroadcast(digit, 0) == digit) ? (1u << 0) : 0;
}
// subgroup invocations 1-31
[[unroll]] for (int32_t jj = 1; jj < 32; jj++)
{
match[0] |= (subgroupBroadcast(digit, jj) == digit) ? (1u << jj) : 0;
}
// subgroup invocation 32
{
match[1] = (subgroupBroadcast(digit, 32) == digit) ? (1u << 0) : 0;
}
// subgroup invocations 33-63
[[unroll]] for (int32_t jj = 1; jj < 32; jj++)
{
match[1] |= (subgroupBroadcast(digit, 32 + jj) == digit) ? (1u << jj) : 0;
}
kr[ii] = ((bitCount(match.x) + bitCount(match.y)) << 16) |
(bitCount(match.x & gl_SubgroupLeMask.x) + //
bitCount(match.y & gl_SubgroupLeMask.y));
}
#elif ((RS_SUBGROUP_SIZE <= 32) && !defined(RS_SCATTER_ENABLE_NV_MATCH))
//
// <= 32
//
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
// subgroup invocation 0
uint32_t match = (subgroupBroadcast(digit, 0) == digit) ? (1u << 0) : 0;
// subgroup invocations 1-(RS_SUBGROUP_SIZE-1)
[[unroll]] for (int32_t jj = 1; jj < RS_SUBGROUP_SIZE; jj++)
{
match |= (subgroupBroadcast(digit, jj) == digit) ? (1u << jj) : 0;
}
kr[ii] = (bitCount(match) << 16) | bitCount(match & gl_SubgroupLeMask.x);
}
#else
//
// Undefined!
//
#error "Error: rs_histogram_rank() undefined for subgroup size"
#endif
#endif
//
// This is a little unconventional but cycling through a subgroup at
// a time is a performance win on the tested architectures.
//
for (uint32_t ii = 0; ii < RS_WORKGROUP_SUBGROUPS; ii++)
{
if (gl_SubgroupID == ii)
{
[[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
{
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[jj]);
const uint32_t prev = RS_HISTOGRAM_LOAD(digit);
const uint32_t rank = kr[jj] & 0xFFFF;
const uint32_t count = kr[jj] >> 16;
kr[jj] = prev + rank;
if (rank == count)
{
RS_HISTOGRAM_STORE(digit, (prev + count));
}
subgroupMemoryBarrierShared();
}
}
RS_BARRIER();
}
}
//
// Other partitions may lookback on this partition.
//
// Load the global exclusive prefix, store the per-digit exclusive
// prefix to shared memory, and store the final inclusive prefix to
// global memory.
//
void
rs_first_prefix_store(restrict buffer_rs_partitions rs_partitions)
{
//
// Define the histogram reference
//
#if (RS_WORKGROUP_SUBGROUPS == 1)
const uint32_t hist_offset_bytes = gl_SubgroupInvocationID * 4;
#else
const uint32_t hist_offset_bytes = gl_LocalInvocationID.x * 4;
#endif
readonly RS_BUFREF_DEFINE_AT_OFFSET_UINT32(buffer_rs_histogram,
rs_histogram,
push.devaddr_histograms,
hist_offset_bytes);
#if (RS_WORKGROUP_SUBGROUPS == 1)
////////////////////////////////////////////////////////////////////////////
//
// (RS_WORKGROUP_SUBGROUPS == 1)
//
const uint32_t smem_offset_h = RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID;
const uint32_t smem_offset_l = RS_SMEM_LOOKBACK_OFFSET + gl_SubgroupInvocationID;
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
{
const uint32_t exc = rs_histogram.extent[ii];
const uint32_t red = smem.extent[smem_offset_h + ii];
smem.extent[smem_offset_l + ii] = exc;
const uint32_t inc = exc + red;
atomicStore(rs_partitions.extent[ii],
inc | RS_PARTITION_MASK_PREFIX,
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsRelease);
}
#elif (RS_WORKGROUP_SIZE < RS_RADIX_SIZE)
////////////////////////////////////////////////////////////////////////////
//
// (RS_WORKGROUP_SIZE < RS_RADIX_SIZE)
//
const uint32_t smem_offset_h = RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x;
const uint32_t smem_offset_l = RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x;
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
{
const uint32_t exc = rs_histogram.extent[ii];
const uint32_t red = smem.extent[smem_offset_h + ii];
smem.extent[smem_offset_l + ii] = exc;
const uint32_t inc = exc + red;
atomicStore(rs_partitions.extent[ii],
inc | RS_PARTITION_MASK_PREFIX,
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsRelease);
}
#if (RS_WORKGROUP_BASE_FINAL < RS_RADIX_SIZE)
const uint32_t smem_offset_final_h = smem_offset_h + RS_WORKGROUP_BASE_FINAL;
const uint32_t smem_offset_final_l = smem_offset_l + RS_WORKGROUP_BASE_FINAL;
if (smem_offset_final_h < RS_SMEM_HISTOGRAM_OFFSET + RS_RADIX_SIZE)
{
const uint32_t exc = rs_histogram.extent[RS_WORKGROUP_BASE_FINAL];
const uint32_t red = smem.extent[smem_offset_final_h];
smem.extent[smem_offset_final_l] = exc;
const uint32_t inc = exc + red;
atomicStore(rs_partitions.extent[RS_WORKGROUP_BASE_FINAL],
inc | RS_PARTITION_MASK_PREFIX,
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsRelease);
}
#endif
#elif (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE)
////////////////////////////////////////////////////////////////////////////
//
// (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE)
//
#if (RS_WORKGROUP_SIZE > RS_RADIX_SIZE)
if (gl_LocalInvocationID.x < RS_RADIX_SIZE)
#endif
{
const uint32_t exc = rs_histogram.extent[0];
const uint32_t red = smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x];
smem.extent[RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x] = exc;
const uint32_t inc = exc + red;
atomicStore(rs_partitions.extent[0],
inc | RS_PARTITION_MASK_PREFIX,
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsRelease);
}
#endif
}
//
// Atomically store the reduction to the global partition.
//
void
rs_reduction_store(restrict buffer_rs_partitions rs_partitions,
RS_SUBGROUP_UNIFORM const uint32_t partition_base)
{
#if (RS_WORKGROUP_SUBGROUPS == 1)
////////////////////////////////////////////////////////////////////////////
//
// (RS_WORKGROUP_SUBGROUPS == 1)
//
const uint32_t smem_offset = RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID;
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
{
const uint32_t red = smem.extent[smem_offset + ii];
atomicStore(rs_partitions.extent[partition_base + ii],
red | RS_PARTITION_MASK_REDUCTION,
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsRelease);
}
#elif (RS_WORKGROUP_SIZE < RS_RADIX_SIZE)
////////////////////////////////////////////////////////////////////////////
//
// (RS_WORKGROUP_SIZE < RS_RADIX_SIZE)
//
const uint32_t smem_offset = RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x;
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
{
const uint32_t red = smem.extent[smem_offset + ii];
atomicStore(rs_partitions.extent[partition_base + ii],
red | RS_PARTITION_MASK_REDUCTION,
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsRelease);
}
#if (RS_WORKGROUP_BASE_FINAL < RS_RADIX_SIZE)
const uint32_t smem_offset_final = smem_offset + RS_WORKGROUP_BASE_FINAL;
if (smem_offset_final < RS_SMEM_HISTOGRAM_OFFSET + RS_RADIX_SIZE)
{
const uint32_t red = smem.extent[smem_offset_final];
atomicStore(rs_partitions.extent[partition_base + RS_WORKGROUP_BASE_FINAL],
red | RS_PARTITION_MASK_REDUCTION,
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsRelease);
}
#endif
#elif (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE)
////////////////////////////////////////////////////////////////////////////
//
// (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE)
//
#if (RS_WORKGROUP_SIZE > RS_RADIX_SIZE)
if (gl_LocalInvocationID.x < RS_RADIX_SIZE)
#endif
{
const uint32_t red = smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x];
atomicStore(rs_partitions.extent[partition_base],
red | RS_PARTITION_MASK_REDUCTION,
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsRelease);
}
#endif
}
//
// Lookback and accumulate reductions until a PREFIX partition is
// reached and then update this workgroup's partition and local
// histogram prefix.
//
// TODO(allanmac): Consider reenabling the cyclic/ring buffer of
// partitions in order to save memory. It actually adds complexity
// but reduces the amount of pre-scatter buffer zeroing.
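//
// E.g., if partitions p-1 and p-2 have published REDUCTIONs of 100 and
// 50 and partition p-3 has published a PREFIX of 10, each invocation
// spins backward accumulating exc = 100 + 50 + 10 = 160 for its digit
// and then publishes its own inclusive prefix of 160 plus its local
// reduction.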
//
void
rs_lookback_store(restrict buffer_rs_partitions rs_partitions,
RS_SUBGROUP_UNIFORM const uint32_t partition_base)
{
#if (RS_WORKGROUP_SUBGROUPS == 1)
////////////////////////////////////////////////////////////////////////////
//
// (RS_WORKGROUP_SUBGROUPS == 1)
//
const uint32_t smem_offset = RS_SMEM_LOOKBACK_OFFSET + gl_SubgroupInvocationID;
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
{
uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE;
uint32_t exc = 0;
//
// NOTE: Each workgroup invocation can proceed independently.
// Subgroups and workgroups do NOT have to coordinate.
//
while (true)
{
const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev + ii],
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquire);
// spin until valid
if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID)
{
continue;
}
exc += (prev & RS_PARTITION_MASK_COUNT);
if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX)
{
// continue accumulating reductions
partition_base_prev -= RS_RADIX_SIZE;
continue;
}
//
// Otherwise, save the exclusive scan and atomically transform the
// reduction into an inclusive prefix. Status math:
//
// reduction + 1 = prefix
//
smem.extent[smem_offset + ii] = exc;
atomicAdd(rs_partitions.extent[partition_base + ii],
exc | (1 << 30),
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquireRelease);
break;
}
}
#elif (RS_WORKGROUP_SIZE < RS_RADIX_SIZE)
////////////////////////////////////////////////////////////////////////////
//
// (RS_WORKGROUP_SIZE < RS_RADIX_SIZE)
//
const uint32_t smem_offset = RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x;
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
{
uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE;
uint32_t exc = 0;
//
// NOTE: Each workgroup invocation can proceed independently.
// Subgroups and workgroups do NOT have to coordinate.
//
while (true)
{
const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev + ii],
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquire);
// spin until valid
if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID)
{
continue;
}
exc += (prev & RS_PARTITION_MASK_COUNT);
if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX)
{
// continue accumulating reductions
partition_base_prev -= RS_RADIX_SIZE;
continue;
}
//
// Otherwise, save the exclusive scan and atomically transform the
// reduction into an inclusive prefix. Status math:
//
// reduction + 1 = prefix
//
smem.extent[smem_offset + ii] = exc;
atomicAdd(rs_partitions.extent[partition_base + ii],
exc | (1 << 30),
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquireRelease);
break;
}
}
#if (RS_WORKGROUP_BASE_FINAL < RS_RADIX_SIZE)
const uint32_t smem_offset_final = smem_offset + RS_WORKGROUP_BASE_FINAL;
if (smem_offset_final < RS_SMEM_LOOKBACK_OFFSET + RS_RADIX_SIZE)
{
uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE;
uint32_t exc = 0;
//
// NOTE: Each workgroup invocation can proceed independently.
// Subgroups and workgroups do NOT have to coordinate.
//
while (true)
{
const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev + RS_WORKGROUP_BASE_FINAL],
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquire);
// spin until valid
if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID)
{
continue;
}
exc += (prev & RS_PARTITION_MASK_COUNT);
if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX)
{
// continue accumulating reductions
partition_base_prev -= RS_RADIX_SIZE;
continue;
}
//
// Otherwise, save the exclusive scan and atomically transform the
// reduction into an inclusive prefix. Status math:
//
// reduction + 1 = prefix
//
smem.extent[smem_offset_final] = exc;
atomicAdd(rs_partitions.extent[partition_base + RS_WORKGROUP_BASE_FINAL],
exc | (1 << 30),
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquireRelease);
break;
}
}
#endif
#elif (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE)
////////////////////////////////////////////////////////////////////////////
//
// (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE)
//
#if (RS_WORKGROUP_SIZE > RS_RADIX_SIZE)
if (gl_LocalInvocationID.x < RS_RADIX_SIZE)
#endif
{
uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE;
uint32_t exc = 0;
//
// NOTE: Each workgroup invocation can proceed independently.
// Subgroups and workgroups do NOT have to coordinate.
//
while (true)
{
const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev],
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquire);
// spin until valid
if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID)
{
continue;
}
exc += (prev & RS_PARTITION_MASK_COUNT);
if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX)
{
// continue accumulating reductions
partition_base_prev -= RS_RADIX_SIZE;
continue;
}
//
// Otherwise, save the exclusive scan and atomically transform the
// reduction into an inclusive prefix. Status math:
//
// reduction + 1 = prefix
//
smem.extent[RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x] = exc;
atomicAdd(rs_partitions.extent[partition_base],
exc | (1 << 30),
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquireRelease);
break;
}
}
#endif
}
//
// Lookback and accumulate reductions until a PREFIX partition is
// reached and then update this workgroup's local histogram prefix.
//
// Skip updating this workgroup's partition because it's last.
//
void
rs_lookback_skip_store(restrict buffer_rs_partitions rs_partitions,
RS_SUBGROUP_UNIFORM const uint32_t partition_base)
{
#if (RS_WORKGROUP_SUBGROUPS == 1)
////////////////////////////////////////////////////////////////////////////
//
// (RS_WORKGROUP_SUBGROUPS == 1)
//
const uint32_t smem_offset = RS_SMEM_LOOKBACK_OFFSET + gl_SubgroupInvocationID;
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE)
{
uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE;
uint32_t exc = 0;
//
// NOTE: Each workgroup invocation can proceed independently.
// Subgroups and workgroups do NOT have to coordinate.
//
while (true)
{
const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev + ii],
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquire);
// spin until valid
if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID)
{
continue;
}
exc += (prev & RS_PARTITION_MASK_COUNT);
if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX)
{
// continue accumulating reductions
partition_base_prev -= RS_RADIX_SIZE;
continue;
}
// Otherwise, save the exclusive scan.
smem.extent[smem_offset + ii] = exc;
break;
}
}
#elif (RS_WORKGROUP_SIZE < RS_RADIX_SIZE)
////////////////////////////////////////////////////////////////////////////
//
// (RS_WORKGROUP_SIZE < RS_RADIX_SIZE)
//
const uint32_t smem_offset = RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x;
[[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE)
{
uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE;
uint32_t exc = 0;
//
// NOTE: Each workgroup invocation can proceed independently.
// Subgroups and workgroups do NOT have to coordinate.
//
while (true)
{
const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev + ii],
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquire);
// spin until valid
if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID)
{
continue;
}
exc += (prev & RS_PARTITION_MASK_COUNT);
if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX)
{
// continue accumulating reductions
partition_base_prev -= RS_RADIX_SIZE;
continue;
}
// Otherwise, save the exclusive scan.
smem.extent[smem_offset + ii] = exc;
break;
}
}
#if (RS_WORKGROUP_BASE_FINAL < RS_RADIX_SIZE)
const uint32_t smem_offset_final = smem_offset + RS_WORKGROUP_BASE_FINAL;
if (smem_offset_final < RS_RADIX_SIZE)
{
uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE;
uint32_t exc = 0;
//
// NOTE: Each workgroup invocation can proceed independently.
// Subgroups and workgroups do NOT have to coordinate.
//
while (true)
{
const uint32_t prev =
atomicLoad(rs_partitions.extent[partition_base_prev + RS_WORKGROUP_BASE_FINAL],
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquire);
// spin until valid
if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID)
{
continue;
}
exc += (prev & RS_PARTITION_MASK_COUNT);
if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX)
{
// continue accumulating reductions
partition_base_prev -= RS_RADIX_SIZE;
continue;
}
// Otherwise, save the exclusive scan.
smem.extent[smem_offset_final] = exc;
break;
}
}
#endif
#elif (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE)
////////////////////////////////////////////////////////////////////////////
//
// (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE)
//
#if (RS_WORKGROUP_SIZE > RS_RADIX_SIZE)
if (gl_LocalInvocationID.x < RS_RADIX_SIZE)
#endif
{
uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE;
uint32_t exc = 0;
//
// NOTE: Each workgroup invocation can proceed independently.
// Subgroups and workgroups do NOT have to coordinate.
//
while (true)
{
const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev],
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquire);
// spin until valid
if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID)
{
continue;
}
exc += (prev & RS_PARTITION_MASK_COUNT);
if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX)
{
// continue accumulating reductions
partition_base_prev -= RS_RADIX_SIZE;
continue;
}
// Otherwise, save the exclusive scan.
smem.extent[RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x] = exc;
break;
}
}
#endif
}
//
// Compute a 1-based local index for each keyval by adding the 1-based
// rank to the local histogram prefix.
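//
// E.g., if a keyval's digit has a local exclusive prefix of 17 and the
// keyval's 1-based rank is 3, the 1-based local index 20 is packed
// into the high 16 bits of kr[].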
//
void
rs_rank_to_local(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
inout uint32_t kr[RS_SCATTER_BLOCK_ROWS])
{
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
const uint32_t exc = smem.extent[RS_SMEM_HISTOGRAM_OFFSET + digit];
const uint32_t idx = exc + kr[ii];
kr[ii] |= (idx << 16);
}
//
// Reordering phase will overwrite histogram span.
//
RS_BARRIER();
}
//
// Convert each keyval's 1-based rank into a 0-based global index by
// adding the global histogram's exclusive prefix.
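//
// E.g., if a keyval's digit has a global exclusive prefix of 100 and
// the keyval's 1-based rank is 3, its 0-based global index is
// 100 + 3 - 1 = 102.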
//
void
rs_rank_to_global(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
inout uint32_t kr[RS_SCATTER_BLOCK_ROWS])
{
//
// Define the histogram reference
//
readonly RS_BUFREF_DEFINE(buffer_rs_histogram, rs_histogram, push.devaddr_histograms);
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
const uint32_t exc = rs_histogram.extent[digit];
kr[ii] += (exc - 1);
}
}
//
// Using the local indices, rearrange the keyvals into sorted order.
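//
// The keyvals are cycled through shared memory one dword at a time, so
// the reorder span only needs RS_BLOCK_KEYVALS dwords even for
// multi-dword keyvals. The `- 1` below converts the 1-based local
// index in the high half of kr[] into a 0-based smem slot.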
//
void
rs_reorder(inout RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], inout uint32_t kr[RS_SCATTER_BLOCK_ROWS])
{
// clang-format off
#if (RS_WORKGROUP_SUBGROUPS == 1)
const uint32_t smem_base = RS_SMEM_REORDER_OFFSET + gl_SubgroupInvocationID;
#else
const uint32_t smem_base = RS_SMEM_REORDER_OFFSET + gl_LocalInvocationID.x;
#endif
// clang-format on
[[unroll]] for (uint32_t ii = 0; ii < RS_KEYVAL_DWORDS; ii++)
{
//
// Store keyval dword to sorted location
//
[[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
{
const uint32_t smem_idx = (RS_SMEM_REORDER_OFFSET - 1) + (kr[jj] >> 16);
smem.extent[smem_idx] = RS_KV_DWORD(kv[jj], ii);
}
RS_BARRIER();
//
// Load keyval dword from sorted location
//
[[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
{
RS_KV_DWORD(kv[jj], ii) = smem.extent[smem_base + jj * RS_WORKGROUP_SIZE];
}
RS_BARRIER();
}
//
// Store the digit-index to sorted location
//
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
const uint32_t smem_idx = (RS_SMEM_REORDER_OFFSET - 1) + (kr[ii] >> 16);
smem.extent[smem_idx] = uint32_t(kr[ii]);
}
RS_BARRIER();
//
// Load kr[] from sorted location -- we only need the rank.
//
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
kr[ii] = smem.extent[smem_base + ii * RS_WORKGROUP_SIZE] & 0xFFFF;
}
}
//
// Using the global/local indices obtained by a single workgroup,
// rearrange the keyvals into sorted order.
//
void
rs_reorder_1(inout RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
inout uint32_t kr[RS_SCATTER_BLOCK_ROWS])
{
// clang-format off
#if (RS_WORKGROUP_SUBGROUPS == 1)
const uint32_t smem_base = RS_SMEM_REORDER_OFFSET + gl_SubgroupInvocationID;
#else
const uint32_t smem_base = RS_SMEM_REORDER_OFFSET + gl_LocalInvocationID.x;
#endif
// clang-format on
[[unroll]] for (uint32_t ii = 0; ii < RS_KEYVAL_DWORDS; ii++)
{
//
// Store keyval dword to sorted location
//
[[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
{
const uint32_t smem_idx = RS_SMEM_REORDER_OFFSET + kr[jj];
smem.extent[smem_idx] = RS_KV_DWORD(kv[jj], ii);
}
RS_BARRIER();
//
// Load keyval dword from sorted location
//
[[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++)
{
RS_KV_DWORD(kv[jj], ii) = smem.extent[smem_base + jj * RS_WORKGROUP_SIZE];
}
RS_BARRIER();
}
//
// Store the digit-index to sorted location
//
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
const uint32_t smem_idx = RS_SMEM_REORDER_OFFSET + kr[ii];
smem.extent[smem_idx] = uint32_t(kr[ii]);
}
RS_BARRIER();
//
// Load kr[] from sorted location -- we only need the rank.
//
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
kr[ii] = smem.extent[smem_base + ii * RS_WORKGROUP_SIZE];
}
}
//
// Each subgroup loads RS_SCATTER_BLOCK_ROWS rows of keyvals into
// registers.
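//
// Under the illustrative 2048-keyval block configuration, workgroup 2,
// subgroup 1, invocation 7 starts at keyval 2 * 2048 + 1 * 256 + 7 and
// loads RS_SCATTER_BLOCK_ROWS keyvals strided by RS_SUBGROUP_SIZE.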
//
void
rs_load(out RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS])
{
//
// Set up buffer reference
//
const uint32_t kv_in_offset_keyvals = RS_GL_WORKGROUP_ID_X * RS_BLOCK_KEYVALS +
gl_SubgroupID * RS_SUBGROUP_KEYVALS +
gl_SubgroupInvocationID;
u32vec2 kv_in_offset;
umulExtended(kv_in_offset_keyvals,
RS_KEYVAL_SIZE,
kv_in_offset.y, // msb
kv_in_offset.x); // lsb
readonly RS_BUFREF_DEFINE_AT_OFFSET_U32VEC2(buffer_rs_kv,
rs_kv_in,
RS_DEVADDR_KEYVALS_IN(push),
kv_in_offset);
//
// Load keyvals
//
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
kv[ii] = rs_kv_in.extent[ii * RS_SUBGROUP_SIZE];
}
}
//
// Convert local index to global
//
void
rs_local_to_global(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS],
inout uint32_t kr[RS_SCATTER_BLOCK_ROWS])
{
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]);
const uint32_t exc = smem.extent[RS_SMEM_LOOKBACK_OFFSET + digit];
kr[ii] += (exc - 1);
}
}
//
// Store a single workgroup
//
void
rs_store(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], const uint32_t kr[RS_SCATTER_BLOCK_ROWS])
{
//
// Define kv_out bufref
//
writeonly RS_BUFREF_DEFINE(buffer_rs_kv, rs_kv_out, RS_DEVADDR_KEYVALS_OUT(push));
//
// Store keyval:
//
// "out[ keyval.rank ] = keyval"
//
// FIXME(allanmac): Consider implementing an aligned writeout
// strategy to avoid excess global memory transactions.
//
[[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++)
{
rs_kv_out.extent[kr[ii]] = kv[ii];
}
}
//
//
//
void
main()
{
//
// If this is a nonsequential dispatch device then acquire a virtual
// workgroup id.
//
// This is only run once and is a special compile-time-enabled case
// so we leverage the existing `push.devaddr_partitions` address
// instead of altering the push constant structure definition.
//
#ifdef RS_SCATTER_NONSEQUENTIAL_DISPATCH
if (RS_IS_FIRST_LOCAL_INVOCATION())
{
// The "internal" memory map looks like this:
//
// +---------------------------------+ <-- 0
// | histograms[keyval_size] |
// +---------------------------------+ <-- keyval_size * histo_size
// | partitions[scatter_blocks_ru-1] |
// +---------------------------------+ <-- (keyval_size + scatter_blocks_ru - 1) * histo_size
// | workgroup_ids[keyval_size] |
// +---------------------------------+ <-- (keyval_size + scatter_blocks_ru - 1) * histo_size + workgroup_ids_size
//
// Extended multiply to avoid 4GB overflow
//
u32vec2 workgroup_id_offset;
umulExtended((gl_NumWorkGroups.x - 1), // virtual workgroup ids follow partitions[]
4 * RS_RADIX_SIZE, // sizeof(uint32_t) * 256
workgroup_id_offset.y, // msb
workgroup_id_offset.x); // lsb
RS_BUFREF_DEFINE_AT_OFFSET_U32VEC2(buffer_rs_workgroup_id,
rs_workgroup_id,
push.devaddr_partitions,
workgroup_id_offset);
const uint32_t x_idx = RS_SCATTER_KEYVAL_DWORD_BASE * 4 + (push.pass_offset / RS_RADIX_LOG2);
smem.extent[0] = atomicAdd(rs_workgroup_id.x[x_idx],
1,
gl_ScopeQueueFamily,
gl_StorageSemanticsBuffer,
gl_SemanticsAcquireRelease);
}
RS_BARRIER();
rs_gl_workgroup_id_x = smem.extent[0];
RS_BARRIER();
#endif
//
// Load keyvals
//
RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS];
rs_load(kv);
//
// Zero shared histogram
//
// Ends with barrier.
//
rs_histogram_zero();
//
// Compute histogram and bin-relative keyval indices
//
// This histogram can immediately be used to update the partition
// with either a PREFIX or REDUCTION flag.
//
// Ends with a barrier.
//
uint32_t kr[RS_SCATTER_BLOCK_ROWS];
rs_histogram_rank(kv, kr);
//
// When there is a single workgroup then the local and global
// exclusive scanned histograms are the same.
//
if (gl_NumWorkGroups.x == 1)
{
rs_rank_to_global(kv, kr);
#ifndef RS_SCATTER_DISABLE_REORDER
rs_reorder_1(kv, kr);
#endif
rs_store(kv, kr);
}
else
{
//
// Define partitions bufref
//
#if (RS_WORKGROUP_SUBGROUPS == 1)
const uint32_t partition_offset_bytes = gl_SubgroupInvocationID * 4;
#else
const uint32_t partition_offset_bytes = gl_LocalInvocationID.x * 4;
#endif
RS_BUFREF_DEFINE_AT_OFFSET_UINT32(buffer_rs_partitions,
rs_partitions,
push.devaddr_partitions,
partition_offset_bytes);
//
// The first partition is a special case.
//
if (RS_GL_WORKGROUP_ID_X == 0)
{
//
// Other workgroups may lookback on this partition.
//
// Load the global histogram and local histogram and store
// the exclusive prefix.
//
rs_first_prefix_store(rs_partitions);
}
else
{
//
// Otherwise, this is not the first workgroup.
//
RS_SUBGROUP_UNIFORM const uint32_t partition_base = RS_GL_WORKGROUP_ID_X * RS_RADIX_SIZE;
//
// The last partition is a special case.
//
if (RS_GL_WORKGROUP_ID_X + 1 < gl_NumWorkGroups.x)
{
//
// Atomically store the reduction to the global partition.
//
rs_reduction_store(rs_partitions, partition_base);
//
// Lookback and accumulate reductions until a PREFIX
// partition is reached and then update this workgroup's
// partition and local histogram prefix.
//
rs_lookback_store(rs_partitions, partition_base);
}
else
{
//
// Lookback and accumulate reductions until a PREFIX
// partition is reached and then update this workgroup's
// local histogram prefix.
//
// Skip updating this workgroup's partition because it's
// last.
//
rs_lookback_skip_store(rs_partitions, partition_base);
}
}
#ifndef RS_SCATTER_DISABLE_REORDER
//
// Compute exclusive prefix scan of histogram.
//
// No barrier.
//
rs_prefix();
//
// Barrier before reading prefix scanned histogram.
//
RS_BARRIER();
//
// Convert keyval's rank to a local index
//
// Ends with a barrier.
//
rs_rank_to_local(kv, kr);
//
// Reorder kv[] and kr[]
//
// Ends with a barrier.
//
rs_reorder(kv, kr);
#else
//
// Wait for lookback to complete.
//
RS_BARRIER();
#endif
//
// Convert local index to a global index.
//
rs_local_to_global(kv, kr);
//
// Store keyvals to their new locations
//
rs_store(kv, kr);
}
}
//
//
//