| // Copyright 2021 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
#include <assert.h>
#include <inttypes.h>
#include <stdalign.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
| |
| #include "common/macros.h" |
| #include "common/util.h" |
| #include "common/vk/assert.h" |
| #include "common/vk/barrier.h" |
| #include "radix_sort/platforms/vk/radix_sort_vk_devaddr.h" |
| #include "shaders/push.h" |
| #include "target.h" |
| #include "target_archive/target_archive.h" |
| |
| // |
| // |
| // |
| |
| #ifdef RADIX_SORT_VK_ENABLE_DEBUG_UTILS |
| #include "common/vk/debug_utils.h" |
| #endif |
| |
| // |
| // |
| // |
| |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| #include "radix_sort_vk_ext.h" |
| #endif |
| |
| // |
| // NOTE: The library currently supports uint32_t and uint64_t keyvals. |
| // |
| |
| #define RS_KV_DWORDS_MAX 2 |
| |
| // |
| // |
| // |
| |
| struct rs_pipeline_layout_scatter |
| { |
| VkPipelineLayout even; |
| VkPipelineLayout odd; |
| }; |
| |
| struct rs_pipeline_scatter |
| { |
| VkPipeline even; |
| VkPipeline odd; |
| }; |
| |
| // |
| // |
| // |
| |
| struct rs_pipeline_layouts_named |
| { |
| VkPipelineLayout init; |
| VkPipelineLayout fill; |
| VkPipelineLayout histogram; |
| VkPipelineLayout prefix; |
| struct rs_pipeline_layout_scatter scatter[RS_KV_DWORDS_MAX]; |
| }; |
| |
| struct rs_pipelines_named |
| { |
| VkPipeline init; |
| VkPipeline fill; |
| VkPipeline histogram; |
| VkPipeline prefix; |
| struct rs_pipeline_scatter scatter[RS_KV_DWORDS_MAX]; |
| }; |
| |
| // clang-format off |
| #define RS_PIPELINE_LAYOUTS_HANDLES (sizeof(struct rs_pipeline_layouts_named) / sizeof(VkPipelineLayout)) |
| #define RS_PIPELINES_HANDLES (sizeof(struct rs_pipelines_named) / sizeof(VkPipeline)) |
| // clang-format on |
| |
| // |
| // |
| // |
| |
| struct radix_sort_vk |
| { |
| struct radix_sort_vk_target_config config; |
| |
| union |
| { |
| struct rs_pipeline_layouts_named named; |
| VkPipelineLayout handles[RS_PIPELINE_LAYOUTS_HANDLES]; |
| } pipeline_layouts; |
| |
| union |
| { |
| struct rs_pipelines_named named; |
| VkPipeline handles[RS_PIPELINES_HANDLES]; |
| } pipelines; |
| |
| struct |
| { |
| struct |
| { |
| VkDeviceSize histograms; |
| VkDeviceSize partitions; |
| } offset; |
| } internal; |
| }; |
| |
| // |
| // FIXME(allanmac): Memoize some of these calculations. |
| // |
| void |
| radix_sort_vk_get_memory_requirements(radix_sort_vk_t const * rs, |
| uint32_t count, |
| radix_sort_vk_memory_requirements_t * mr) |
| { |
| // |
| // Keyval size |
| // |
| mr->keyval_size = rs->config.keyval_dwords * sizeof(uint32_t); |
| |
| // |
| // Subgroup and workgroup sizes |
| // |
| uint32_t const histo_sg_size = 1 << rs->config.histogram.subgroup_size_log2; |
| uint32_t const histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2; |
| uint32_t const prefix_sg_size = 1 << rs->config.prefix.subgroup_size_log2; |
| uint32_t const scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2; |
| uint32_t const internal_sg_size = MAX_MACRO(uint32_t, histo_sg_size, prefix_sg_size); |
| |
| // |
  // If count is zero for some reason, initialize the requirements appropriately.
| // |
| if (count == 0) |
| { |
| mr->keyvals_size = 0; |
| mr->keyvals_alignment = mr->keyval_size * histo_sg_size; |
| mr->internal_size = 0; |
| mr->internal_alignment = internal_sg_size * sizeof(uint32_t); |
| mr->indirect_size = 0; |
| mr->indirect_alignment = internal_sg_size * sizeof(uint32_t); |
| } |
| else |
| { |
| // |
| // Keyvals |
| // |
| |
| // |
| // Round up to the scatter block size. |
| // |
| // Then round up to the histogram block size. |
| // |
| // Fill the difference between this new count and the original keyval |
| // count. |
| // |
| // How many scatter blocks? |
| // |
| uint32_t const scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows; |
| uint32_t const scatter_blocks_ru = (count + scatter_block_kvs - 1) / scatter_block_kvs; |
| uint32_t const count_ru_scatter = scatter_blocks_ru * scatter_block_kvs; |
| |
| // |
| // How many histogram blocks? |
| // |
| // Note that it's OK to have more max-valued digits counted by the histogram |
| // than sorted by the scatters because the sort is stable. |
| // |
| uint32_t const histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows; |
| uint32_t const histo_blocks_ru = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs; |
| uint32_t const count_ru_histo = histo_blocks_ru * histo_block_kvs; |
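
      //
      // For example, with a hypothetical config of a 256-invocation scatter
      // workgroup and 4 block rows, scatter_block_kvs is 1024 and a count of
      // 100000 rounds up to 100352 (98 blocks).  With a 256-invocation
      // histogram workgroup and 8 block rows, histo_block_kvs is 2048 and
      // count_ru_histo remains 100352 (49 blocks).
      //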
| |
| mr->keyvals_size = mr->keyval_size * count_ru_histo; |
| mr->keyvals_alignment = mr->keyval_size * histo_sg_size; |
| |
| // |
| // Internal |
| // |
| // NOTE: Assumes .histograms are before .partitions. |
| // |
      // Each RS_RADIX_LOG2 (8-bit) pass has a zero-initialized histogram of
      // RS_RADIX_SIZE dwords -- one histogram per keyval byte.
| // |
| // The last scatter workgroup skips writing to a partition so it doesn't |
| // need to be allocated. |
| // |
| // If the device doesn't support "sequential dispatch" of workgroups, then |
| // we need a zero-initialized dword counter per radix pass in the keyval |
| // to atomically acquire a virtual workgroup id. On sequentially |
| // dispatched devices, this is simply `gl_WorkGroupID.x`. |
| // |
| // The "internal" memory map looks like this: |
| // |
| // +---------------------------------+ <-- 0 |
| // | histograms[keyval_size] | |
| // +---------------------------------+ <-- keyval_size * histo_size |
| // | partitions[scatter_blocks_ru-1] | |
| // +---------------------------------+ <-- (keyval_size + scatter_blocks_ru - 1) * histo_size |
| // | workgroup_ids[keyval_size] | |
| // +---------------------------------+ <-- (keyval_size + scatter_blocks_ru - 1) * histo_size + workgroup_ids_size |
| // |
| // The `.workgroup_ids[]` are located after the last partition. |
| // |
| VkDeviceSize const histo_size = RS_RADIX_SIZE * sizeof(uint32_t); |
| |
| mr->internal_size = (mr->keyval_size + scatter_blocks_ru - 1) * histo_size; |
| mr->internal_alignment = internal_sg_size * sizeof(uint32_t); |
| |
| // |
| // Support for nonsequential dispatch can be disabled. |
| // |
| #ifndef RADIX_SORT_VK_DISABLE_NONSEQUENTIAL_DISPATCH |
| VkDeviceSize const workgroup_ids_size = mr->keyval_size * sizeof(uint32_t); |
| |
| mr->internal_size += workgroup_ids_size; |
| #endif |
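
      //
      // Continuing the hypothetical example above: for 32-bit keyvals
      // (keyval_size == 4) and 98 scatter blocks, internal_size is
      // (4 + 98 - 1) * 1024 = 103424 bytes, plus 16 bytes of workgroup ids
      // when nonsequential dispatch support is enabled.
      //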
| |
| // |
| // Indirect |
| // |
| mr->indirect_size = sizeof(struct rs_indirect_info); |
| mr->indirect_alignment = sizeof(struct u32vec4); |
| } |
| } |
| |
| // |
| // |
| // |
| #ifdef RADIX_SORT_VK_ENABLE_DEBUG_UTILS |
| |
| static void |
| rs_debug_utils_set(VkDevice device, radix_sort_vk_t * rs) |
| { |
| if (pfn_vkSetDebugUtilsObjectNameEXT != NULL) |
| { |
| VkDebugUtilsObjectNameInfoEXT duoni = { |
| .sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_OBJECT_NAME_INFO_EXT, |
| .pNext = NULL, |
| .objectType = VK_OBJECT_TYPE_PIPELINE, |
| }; |
| |
| duoni.objectHandle = (uint64_t)rs->pipelines.named.init; |
| duoni.pObjectName = "radix_sort_init"; |
| vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); |
| |
| duoni.objectHandle = (uint64_t)rs->pipelines.named.fill; |
| duoni.pObjectName = "radix_sort_fill"; |
| vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); |
| |
| duoni.objectHandle = (uint64_t)rs->pipelines.named.histogram; |
| duoni.pObjectName = "radix_sort_histogram"; |
| vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); |
| |
| duoni.objectHandle = (uint64_t)rs->pipelines.named.prefix; |
| duoni.pObjectName = "radix_sort_prefix"; |
| vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); |
| |
| duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[0].even; |
| duoni.pObjectName = "radix_sort_scatter_0_even"; |
| vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); |
| |
| duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[0].odd; |
| duoni.pObjectName = "radix_sort_scatter_0_odd"; |
| vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); |
| |
| if (rs->config.keyval_dwords >= 2) |
| { |
| duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[1].even; |
| duoni.pObjectName = "radix_sort_scatter_1_even"; |
| vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); |
| |
| duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[1].odd; |
| duoni.pObjectName = "radix_sort_scatter_1_odd"; |
| vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); |
| } |
| } |
| } |
| |
| #endif |
| |
| // |
| // |
| // |
| struct radix_sort_vk_target |
| { |
| struct target_archive_header ar_header; |
| }; |
| |
| // |
| // How many pipelines are there? |
| // |
| static uint32_t |
| rs_pipeline_count(radix_sort_vk_t const * rs) |
| { |
| return 1 + // init |
| 1 + // fill |
| 1 + // histogram |
| 1 + // prefix |
| 2 * rs->config.keyval_dwords; // scatters.even/odd[keyval_dwords] |
| } |
| |
| // |
| // |
| // |
| radix_sort_vk_t * |
| radix_sort_vk_create(VkDevice device, |
| VkAllocationCallbacks const * ac, |
| VkPipelineCache pc, |
| radix_sort_vk_target_t const * target) |
| { |
| // |
| // Unmarshalling assumes dword alignment. |
| // |
| assert(alignof(struct radix_sort_vk_target_header) == 4); |
| |
| // |
| // Must not be NULL. |
| // |
| if (target == NULL) |
| { |
| return NULL; |
| } |
| |
| #ifndef RADIX_SORT_VK_DISABLE_VERIFY |
| // |
| // Verify target archive is valid archive. |
| // |
| if (target->ar_header.magic != TARGET_ARCHIVE_MAGIC) |
| { |
| #ifndef NDEBUG |
| fprintf(stderr, "Error: Invalid target -- missing magic."); |
| #endif |
| return NULL; |
| } |
| #endif |
| |
| // |
| // Get the target archive header. |
| // |
| struct target_archive_header const * const ar_header = &target->ar_header; |
| struct target_archive_entry const * const ar_entries = ar_header->entries; |
| uint32_t const * const ar_data = ar_entries[ar_header->count - 1].data; |
| |
| // |
| // Get the radix sort target header. |
| // |
| struct radix_sort_vk_target_header const * rs_target_header; |
| |
  // The dword-alignment assert above guarantees that `ar_data` can be viewed
  // as a target header; memcpy()'ing the pointer value avoids an
  // incompatible-pointer-type cast.
| memcpy(&rs_target_header, &ar_data, sizeof(ar_data)); |
| |
| // |
| // Verify target is compatible with the library. |
| // |
| // TODO(allanmac): Verify `ar_header->count` but note that not all target |
| // archives will have a static count. |
| // |
| #ifndef RADIX_SORT_VK_DISABLE_VERIFY |
| if (rs_target_header->magic != RS_HEADER_MAGIC) |
| { |
| #ifndef NDEBUG |
| fprintf(stderr, "Error: Target is not compatible with library."); |
| #endif |
| return NULL; |
| } |
| #endif |
| |
| // |
| // Allocate radix_sort_vk |
| // |
| radix_sort_vk_t * const rs = MALLOC_MACRO(sizeof(*rs)); |
| |
| // |
  // Save the config for later
| // |
| rs->config = rs_target_header->config; |
| |
| // |
| // How many pipelines? |
| // |
| uint32_t const pipeline_count = rs_pipeline_count(rs); |
| |
| // |
| // Prepare to create pipelines |
| // |
| VkPushConstantRange const pcr[] = { |
| { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // |
| .offset = 0, |
| .size = sizeof(struct rs_push_init) }, |
| |
| { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // |
| .offset = 0, |
| .size = sizeof(struct rs_push_fill) }, |
| |
| { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // |
| .offset = 0, |
| .size = sizeof(struct rs_push_histogram) }, |
| |
| { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // |
| .offset = 0, |
| .size = sizeof(struct rs_push_prefix) }, |
| |
| { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // |
| .offset = 0, |
| .size = sizeof(struct rs_push_scatter) }, // scatter_0_even |
| |
| { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // |
| .offset = 0, |
| .size = sizeof(struct rs_push_scatter) }, // scatter_0_odd |
| |
| { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // |
| .offset = 0, |
| .size = sizeof(struct rs_push_scatter) }, // scatter_1_even |
| |
| { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // |
| .offset = 0, |
| .size = sizeof(struct rs_push_scatter) }, // scatter_1_odd |
| }; |
| |
| VkPipelineLayoutCreateInfo plci = { |
| |
| .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, |
| .pNext = NULL, |
| .flags = 0, |
| .setLayoutCount = 0, |
| .pSetLayouts = NULL, |
| .pushConstantRangeCount = 1, |
| // .pPushConstantRanges = pcr + ii; |
| }; |
| |
| for (uint32_t ii = 0; ii < pipeline_count; ii++) |
| { |
| plci.pPushConstantRanges = pcr + ii; |
| |
      vk(CreatePipelineLayout(device, &plci, ac, rs->pipeline_layouts.handles + ii));
| } |
| |
| // |
| // Create compute pipelines |
| // |
| VkShaderModuleCreateInfo smci = { |
| |
| .sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, |
| .pNext = NULL, |
| .flags = 0, |
| // .codeSize = ar_entries[...].size; |
| // .pCode = ar_data + ...; |
| }; |
| |
| VkShaderModule sms[ARRAY_LENGTH_MACRO(rs->pipelines.handles)]; |
| |
| for (uint32_t ii = 0; ii < pipeline_count; ii++) |
| { |
| smci.codeSize = ar_entries[ii + 1].size; |
| smci.pCode = ar_data + (ar_entries[ii + 1].offset >> 2); |
| |
| vk(CreateShaderModule(device, &smci, ac, sms + ii)); |
| } |
| |
| // |
| // If necessary, set the expected subgroup size. |
| // |
| #define RS_SUBGROUP_SIZE_CREATE_INFO_SET(size_) \ |
| { \ |
| .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT, \ |
| .pNext = NULL, \ |
| .requiredSubgroupSize = size_, \ |
| } |
| |
| #define RS_SUBGROUP_SIZE_CREATE_INFO_NAME(name_) \ |
| RS_SUBGROUP_SIZE_CREATE_INFO_SET(1 << rs_target_header->config.name_.subgroup_size_log2) |
| |
| #define RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(name_) RS_SUBGROUP_SIZE_CREATE_INFO_SET(0) |
| |
| VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT const rsscis[] = { |
| RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(init), // init |
| RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(fill), // fill |
| RS_SUBGROUP_SIZE_CREATE_INFO_NAME(histogram), // histogram |
| RS_SUBGROUP_SIZE_CREATE_INFO_NAME(prefix), // prefix |
| RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter), // scatter[0].even |
| RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter), // scatter[0].odd |
| RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter), // scatter[1].even |
| RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter), // scatter[1].odd |
| }; |
| |
| // |
| // Define compute pipeline create infos. |
| // |
| #define RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(idx_) \ |
| { .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, \ |
| .pNext = NULL, \ |
| .flags = 0, \ |
| .stage = { .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, \ |
| .pNext = NULL, \ |
| .flags = 0, \ |
| .stage = VK_SHADER_STAGE_COMPUTE_BIT, \ |
| .module = sms[idx_], \ |
| .pName = "main", \ |
| .pSpecializationInfo = NULL }, \ |
| \ |
| .layout = rs->pipeline_layouts.handles[idx_], \ |
| .basePipelineHandle = VK_NULL_HANDLE, \ |
| .basePipelineIndex = 0 } |
| |
| VkComputePipelineCreateInfo cpcis[] = { |
| RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(0), // init |
| RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(1), // fill |
| RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(2), // histogram |
| RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(3), // prefix |
| RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(4), // scatter[0].even |
| RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(5), // scatter[0].odd |
| RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(6), // scatter[1].even |
| RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(7), // scatter[1].odd |
| }; |
| |
| // |
| // Which of these compute pipelines require subgroup size control? |
| // |
| if (rs_target_header->extensions.named.EXT_subgroup_size_control) |
| { |
| for (uint32_t ii = 0; ii < pipeline_count; ii++) |
| { |
| if (rsscis[ii].requiredSubgroupSize > 1) |
| { |
| cpcis[ii].stage.pNext = rsscis + ii; |
| } |
| } |
| } |
| |
| // |
| // Create the compute pipelines. |
| // |
| vk(CreateComputePipelines(device, pc, pipeline_count, cpcis, ac, rs->pipelines.handles)); |
| |
| // |
| // Shader modules can be destroyed now. |
| // |
| for (uint32_t ii = 0; ii < pipeline_count; ii++) |
| { |
| vkDestroyShaderModule(device, sms[ii], ac); |
| } |
| |
| #ifdef RADIX_SORT_VK_ENABLE_DEBUG_UTILS |
| // |
| // Tag pipelines with names. |
| // |
| rs_debug_utils_set(device, rs); |
| #endif |
| |
| // |
| // Initialize `.internal` buffer offsets. |
| // |
| // See the internal memory map diagram. |
| // |
  // NOTE(allanmac): The partitions.offset must be aligned differently if
  // RS_RADIX_LOG2 is less than the target's subgroup size log2.  At this time,
  // no known GPU meets this criterion.
| // |
| // The `internal.workgroup_ids[keyval_size]` offset is implicitly located |
| // after the last partition and determined at runtime. |
| // |
  VkDeviceSize const keyval_size     = rs->config.keyval_dwords * sizeof(uint32_t);
  VkDeviceSize const histograms_size = keyval_size * RS_RADIX_SIZE * sizeof(uint32_t);

  rs->internal.offset.histograms = 0;
  rs->internal.offset.partitions = histograms_size;
| |
| return rs; |
| } |
| |
| // |
| // |
| // |
| void |
| radix_sort_vk_destroy(radix_sort_vk_t * rs, VkDevice d, VkAllocationCallbacks const * const ac) |
| { |
| uint32_t const pipeline_count = rs_pipeline_count(rs); |
| |
| // Destroy pipelines |
| for (uint32_t ii = 0; ii < pipeline_count; ii++) |
| { |
| vkDestroyPipeline(d, rs->pipelines.handles[ii], ac); |
| } |
| |
| // Destroy pipeline layouts |
| for (uint32_t ii = 0; ii < pipeline_count; ii++) |
| { |
| vkDestroyPipelineLayout(d, rs->pipeline_layouts.handles[ii], ac); |
| } |
| |
| free(rs); |
| } |
| |
| // |
| // |
| // |
| static VkDeviceAddress |
| rs_get_devaddr(VkDevice device, VkDescriptorBufferInfo const * dbi) |
| { |
| VkBufferDeviceAddressInfo const bdai = { |
| |
| .sType = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, |
| .pNext = NULL, |
| .buffer = dbi->buffer |
| }; |
| |
| VkDeviceAddress const devaddr = vkGetBufferDeviceAddress(device, &bdai) + dbi->offset; |
| |
| return devaddr; |
| } |
| |
| // |
| // |
| // |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| |
static void
| rs_ext_cmd_write_timestamp(struct radix_sort_vk_ext_timestamps * ext_timestamps, |
| VkCommandBuffer cb, |
| VkPipelineStageFlagBits pipeline_stage) |
| { |
| if ((ext_timestamps != NULL) && |
| (ext_timestamps->timestamps_set < ext_timestamps->timestamp_count)) |
| { |
| vkCmdWriteTimestamp(cb, |
| pipeline_stage, |
| ext_timestamps->timestamps, |
| ext_timestamps->timestamps_set++); |
| } |
| } |
| |
| #endif |
| |
| // |
| // Validate alignment of buffer device addresses. |
| // |
| // Note that the size of the extents can also be validated when using the |
| // VkDescriptorBufferInfo sorter. |
| // |
| #ifndef NDEBUG |
| |
| static void |
| radix_sort_vk_sort_validate_info(radix_sort_vk_t const * rs, |
| radix_sort_vk_sort_devaddr_info_t const * info) |
| { |
| assert(info->count > 0); |
| |
| radix_sort_vk_memory_requirements_t mr; |
| |
| radix_sort_vk_get_memory_requirements(rs, info->count, &mr); |
| |
| // clang-format off |
| assert((info->keyvals_even.devaddr & (mr.keyvals_alignment - 1)) == 0); |
| assert((info->keyvals_odd & (mr.keyvals_alignment - 1)) == 0); |
| assert((info->internal.devaddr & (mr.internal_alignment - 1)) == 0); |
| // clang-format on |
| } |
| |
| static void |
| radix_sort_vk_sort_indirect_validate_info(radix_sort_vk_t const * rs, |
| radix_sort_vk_sort_indirect_devaddr_info_t const * info) |
| { |
| radix_sort_vk_memory_requirements_t mr; |
| |
| radix_sort_vk_get_memory_requirements(rs, 1 << 20, &mr); |
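
  // The returned alignments do not depend on the keyval count, so an arbitrary
  // count is sufficient for validating alignment.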
| |
| // clang-format off |
| assert((info->count & (sizeof(uint32_t) - 1)) == 0); |
| assert((info->keyvals_even & (mr.keyvals_alignment - 1)) == 0); |
| assert((info->keyvals_odd & (mr.keyvals_alignment - 1)) == 0); |
| assert((info->internal & (mr.internal_alignment - 1)) == 0); |
| // clang-format on |
| } |
| |
| #endif |
| |
| // |
| // |
| // |
| |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| |
| struct radix_sort_vk_ext_base |
| { |
| void * ext; |
| enum radix_sort_vk_ext_type type; |
| }; |
| |
| #endif |
| |
| // |
| // |
| // |
| void |
| radix_sort_vk_sort_devaddr(radix_sort_vk_t const * rs, |
| radix_sort_vk_sort_devaddr_info_t const * info, |
| VkDevice device, |
| VkCommandBuffer cb, |
| VkDeviceAddress * keyvals_sorted) |
| { |
| // |
| // Anything to do? |
| // |
| if ((info->count <= 1) || (info->key_bits == 0)) |
| { |
| *keyvals_sorted = info->keyvals_even.devaddr; |
| |
| return; |
| } |
| |
| #ifndef NDEBUG |
| radix_sort_vk_sort_validate_info(rs, info); |
| #endif |
| |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| // |
| // Any extensions? |
| // |
| struct radix_sort_vk_ext_timestamps * ext_timestamps = NULL; |
| |
| void * ext_next = info->ext; |
| |
| while (ext_next != NULL) |
| { |
| struct radix_sort_vk_ext_base * const base = ext_next; |
| |
| switch (base->type) |
| { |
| case RADIX_SORT_VK_EXT_TIMESTAMPS: |
| ext_timestamps = ext_next; |
| ext_timestamps->timestamps_set = 0; |
| break; |
| } |
| |
| ext_next = base->ext; |
| } |
| #endif |
| |
| //////////////////////////////////////////////////////////////////////// |
| // |
| // OVERVIEW |
| // |
  // 1. Pad the keyvals in `keyvals_even`.
| // 2. Zero the `histograms` and `partitions`. |
| // --- BARRIER --- |
| // 3. HISTOGRAM is dispatched before PREFIX. |
| // --- BARRIER --- |
| // 4. PREFIX is dispatched before the first SCATTER. |
| // --- BARRIER --- |
| // 5. One or more SCATTER dispatches. |
| // |
| // Note that the `partitions` buffer can be zeroed anytime before the first |
| // scatter. |
| // |
| //////////////////////////////////////////////////////////////////////// |
| |
| // |
| // Label the command buffer |
| // |
| #ifdef RADIX_SORT_VK_ENABLE_DEBUG_UTILS |
| if (pfn_vkCmdBeginDebugUtilsLabelEXT != NULL) |
| { |
| VkDebugUtilsLabelEXT const label = { |
| .sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT, |
| .pNext = NULL, |
| .pLabelName = "radix_sort_vk_sort", |
| }; |
| |
| pfn_vkCmdBeginDebugUtilsLabelEXT(cb, &label); |
| } |
| #endif |
| |
| // |
| // How many passes? |
| // |
| uint32_t const keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t); |
| uint32_t const keyval_bits = keyval_bytes * 8; |
| uint32_t const key_bits = MIN_MACRO(uint32_t, info->key_bits, keyval_bits); |
| uint32_t const passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2; |
| |
| *keyvals_sorted = ((passes & 1) != 0) ? info->keyvals_odd : info->keyvals_even.devaddr; |
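
  //
  // For example, with 32-bit keyvals and key_bits == 32 there are 4 passes and
  // the sorted keyvals land back in the "even" buffer; with key_bits == 24
  // there are 3 passes and they land in the "odd" buffer.
  //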
| |
| //////////////////////////////////////////////////////////////////////// |
| // |
| // PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS |
| // |
| // Pad fractional blocks with max-valued keyvals. |
| // |
| // Zero the histograms and partitions buffer. |
| // |
| // This assumes the partitions follow the histograms. |
| // |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT); |
| #endif |
| |
| // |
| // FIXME(allanmac): Consider precomputing some of these values and hang them |
| // off `rs`. |
| // |
| |
| // |
| // How many scatter blocks? |
| // |
| uint32_t const scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2; |
| uint32_t const scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows; |
| uint32_t const scatter_blocks_ru = (info->count + scatter_block_kvs - 1) / scatter_block_kvs; |
| uint32_t const count_ru_scatter = scatter_blocks_ru * scatter_block_kvs; |
| |
| // |
| // How many histogram blocks? |
| // |
| // Note that it's OK to have more max-valued digits counted by the histogram |
| // than sorted by the scatters because the sort is stable. |
| // |
| uint32_t const histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2; |
| uint32_t const histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows; |
| uint32_t const histo_blocks_ru = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs; |
| uint32_t const count_ru_histo = histo_blocks_ru * histo_block_kvs; |
| |
| // |
| // Fill with max values |
| // |
| if (count_ru_histo > info->count) |
| { |
| info->fill_buffer_pfn(cb, |
| &info->keyvals_even, |
| info->count * keyval_bytes, |
| (count_ru_histo - info->count) * keyval_bytes, |
| 0xFFFFFFFF); |
| } |
| |
| // |
| // Zero histograms, partitions and (optionally) workgroup_ids. |
| // |
| // Note that the actively used histograms are "right justified" to the end of |
| // the `.histograms[keyval_size]` region. That is, if the sort uses only 8 |
| // bits of key in the keyval then the starting histogram index points to the |
| // last histogram. This creates a contiguous region of histograms and |
| // partitions that can be zeroed with one dispatch. |
| // |
| // Note that the partition invalidation only needs to be performed once |
  // because the even/odd scatter dispatches rely on the previous pass to
| // leave the partitions in an invalid state. |
| // |
| // Note that the last workgroup doesn't read/write a partition so it doesn't |
| // need to be initialized. |
| // |
  // The additional `workgroup_ids_size` bytes support devices that do not
  // dispatch workgroups sequentially.
| // |
| uint32_t const histo_partition_count = passes + scatter_blocks_ru - 1; |
| uint32_t pass_idx = (keyval_bytes - passes); |
| |
| VkDeviceSize const histo_size = RS_RADIX_SIZE * sizeof(uint32_t); |
| VkDeviceSize const zero_offset = pass_idx * histo_size; |
| VkDeviceSize zero_size = histo_partition_count * histo_size; |
| |
| #ifndef RADIX_SORT_VK_DISABLE_NONSEQUENTIAL_DISPATCH |
| VkDeviceSize const workgroup_ids_size = keyval_bytes * sizeof(uint32_t); |
| |
| zero_size += workgroup_ids_size; |
| #endif |
| |
| info->fill_buffer_pfn(cb, |
| &info->internal, |
| rs->internal.offset.histograms + zero_offset, |
| zero_size, |
| 0); |
| |
| //////////////////////////////////////////////////////////////////////// |
| // |
| // Pipeline: HISTOGRAM |
| // |
| // TODO(allanmac): All subgroups should try to process approximately the same |
| // number of blocks in order to minimize tail effects. This was implemented |
| // and reverted but should be reimplemented and benchmarked later. |
| // |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TRANSFER_BIT); |
| #endif |
| |
| vk_barrier_transfer_w_to_compute_r(cb); |
| |
| // clang-format off |
| VkDeviceAddress const devaddr_histograms = info->internal.devaddr + rs->internal.offset.histograms; |
| VkDeviceAddress const devaddr_keyvals_even = info->keyvals_even.devaddr; |
| // clang-format on |
| |
| // |
| // Dispatch histogram |
| // |
| struct rs_push_histogram const push_histogram = { |
| |
| .devaddr_histograms = devaddr_histograms, |
| .devaddr_keyvals = devaddr_keyvals_even, |
| .passes = passes |
| }; |
| |
| vkCmdPushConstants(cb, |
| rs->pipeline_layouts.named.histogram, |
| VK_SHADER_STAGE_COMPUTE_BIT, |
| 0, |
| sizeof(push_histogram), |
| &push_histogram); |
| |
| vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram); |
| |
| vkCmdDispatch(cb, histo_blocks_ru, 1, 1); |
| |
| //////////////////////////////////////////////////////////////////////// |
| // |
| // Pipeline: PREFIX |
| // |
| // Launch one workgroup per pass. |
| // |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); |
| #endif |
| |
| vk_barrier_compute_w_to_compute_r(cb); |
| |
| struct rs_push_prefix const push_prefix = { |
| |
| .devaddr_histograms = devaddr_histograms, |
| }; |
| |
| vkCmdPushConstants(cb, |
| rs->pipeline_layouts.named.prefix, |
| VK_SHADER_STAGE_COMPUTE_BIT, |
| 0, |
| sizeof(push_prefix), |
| &push_prefix); |
| |
| vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix); |
| |
| vkCmdDispatch(cb, passes, 1, 1); |
| |
| //////////////////////////////////////////////////////////////////////// |
| // |
| // Pipeline: SCATTER |
| // |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); |
| #endif |
| |
| vk_barrier_compute_w_to_compute_r(cb); |
| |
| // clang-format off |
| uint32_t const histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t)); |
| VkDeviceAddress const devaddr_keyvals_odd = info->keyvals_odd; |
| VkDeviceAddress const devaddr_partitions = info->internal.devaddr + rs->internal.offset.partitions; |
| // clang-format on |
| |
| struct rs_push_scatter push_scatter = { |
| |
| .devaddr_keyvals_even = devaddr_keyvals_even, |
| .devaddr_keyvals_odd = devaddr_keyvals_odd, |
| .devaddr_partitions = devaddr_partitions, |
| .devaddr_histograms = devaddr_histograms + histogram_offset, |
| .pass_offset = (pass_idx & 3) * RS_RADIX_LOG2, |
| }; |
| |
| { |
| uint32_t const pass_dword = pass_idx / 4; |
| |
| vkCmdPushConstants(cb, |
| rs->pipeline_layouts.named.scatter[pass_dword].even, |
| VK_SHADER_STAGE_COMPUTE_BIT, |
| 0, |
| sizeof(push_scatter), |
| &push_scatter); |
| |
| vkCmdBindPipeline(cb, |
| VK_PIPELINE_BIND_POINT_COMPUTE, |
| rs->pipelines.named.scatter[pass_dword].even); |
| } |
| |
| bool is_even = true; |
| |
| while (true) |
| { |
| vkCmdDispatch(cb, scatter_blocks_ru, 1, 1); |
| |
| // |
| // Continue? |
| // |
| if (++pass_idx >= keyval_bytes) |
| { |
| break; |
| } |
| |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); |
| #endif |
| |
| vk_barrier_compute_w_to_compute_r(cb); |
| |
| // clang-format off |
| is_even ^= true; |
| push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t)); |
| push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2; |
| // clang-format on |
| |
| uint32_t const pass_dword = pass_idx / 4; |
| |
| // |
| // Update push constants that changed |
| // |
| VkPipelineLayout const pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even // |
| : rs->pipeline_layouts.named.scatter[pass_dword].odd; |
| vkCmdPushConstants(cb, |
| pl, |
| VK_SHADER_STAGE_COMPUTE_BIT, |
| OFFSETOF_MACRO(struct rs_push_scatter, devaddr_histograms), |
| sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset), |
| &push_scatter.devaddr_histograms); |
| |
| // |
| // Bind new pipeline |
| // |
| VkPipeline const p = is_even ? rs->pipelines.named.scatter[pass_dword].even // |
| : rs->pipelines.named.scatter[pass_dword].odd; |
| |
| vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, p); |
| } |
| |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); |
| #endif |
| |
| // |
| // End the label |
| // |
| #ifdef RADIX_SORT_VK_ENABLE_DEBUG_UTILS |
| if (pfn_vkCmdEndDebugUtilsLabelEXT != NULL) |
| { |
| pfn_vkCmdEndDebugUtilsLabelEXT(cb); |
| } |
| #endif |
| } |
| |
| // |
| // |
| // |
| void |
| radix_sort_vk_sort_indirect_devaddr(radix_sort_vk_t const * rs, |
| radix_sort_vk_sort_indirect_devaddr_info_t const * info, |
| VkDevice device, |
| VkCommandBuffer cb, |
| VkDeviceAddress * keyvals_sorted) |
| { |
| // |
| // Anything to do? |
| // |
| if (info->key_bits == 0) |
| { |
| *keyvals_sorted = info->keyvals_even; |
| return; |
| } |
| |
| #ifndef NDEBUG |
| radix_sort_vk_sort_indirect_validate_info(rs, info); |
| #endif |
| |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| // |
| // Any extensions? |
| // |
| struct radix_sort_vk_ext_timestamps * ext_timestamps = NULL; |
| |
| void * ext_next = info->ext; |
| |
| while (ext_next != NULL) |
| { |
| struct radix_sort_vk_ext_base * const base = ext_next; |
| |
| switch (base->type) |
| { |
| case RADIX_SORT_VK_EXT_TIMESTAMPS: |
| ext_timestamps = ext_next; |
| ext_timestamps->timestamps_set = 0; |
| break; |
| } |
| |
| ext_next = base->ext; |
| } |
| #endif |
| |
| //////////////////////////////////////////////////////////////////////// |
| // |
| // OVERVIEW |
| // |
| // 1. Init |
| // --- BARRIER --- |
  // 2. Pad the keyvals in `keyvals_even`.
| // 3. Zero the `histograms` and `partitions`. |
| // --- BARRIER --- |
| // 4. HISTOGRAM is dispatched before PREFIX. |
| // --- BARRIER --- |
| // 5. PREFIX is dispatched before the first SCATTER. |
| // --- BARRIER --- |
| // 6. One or more SCATTER dispatches. |
| // |
| // Note that the `partitions` buffer can be zeroed anytime before the first |
| // scatter. |
| // |
| //////////////////////////////////////////////////////////////////////// |
| |
| // |
| // Label the command buffer |
| // |
| #ifdef RADIX_SORT_VK_ENABLE_DEBUG_UTILS |
| if (pfn_vkCmdBeginDebugUtilsLabelEXT != NULL) |
| { |
| VkDebugUtilsLabelEXT const label = { |
| .sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT, |
| .pNext = NULL, |
| .pLabelName = "radix_sort_vk_sort_indirect", |
| }; |
| |
| pfn_vkCmdBeginDebugUtilsLabelEXT(cb, &label); |
| } |
| #endif |
| |
| // |
| // How many passes? |
| // |
| uint32_t const keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t); |
| uint32_t const keyval_bits = keyval_bytes * 8; |
| uint32_t const key_bits = MIN_MACRO(uint32_t, info->key_bits, keyval_bits); |
| uint32_t const passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2; |
| uint32_t pass_idx = (keyval_bytes - passes); |
| |
| *keyvals_sorted = ((passes & 1) != 0) ? info->keyvals_odd : info->keyvals_even; |
| |
| // |
| // NOTE(allanmac): Some of these initializations appear redundant but for now |
| // we're going to assume the compiler will elide them. |
| // |
| // clang-format off |
| VkDeviceAddress const devaddr_info = info->indirect.devaddr; |
| VkDeviceAddress const devaddr_count = info->count; |
| VkDeviceAddress const devaddr_histograms = info->internal + rs->internal.offset.histograms; |
| VkDeviceAddress const devaddr_keyvals_even = info->keyvals_even; |
| // clang-format on |
| |
| // |
| // START |
| // |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT); |
| #endif |
| |
| // |
| // INIT |
| // |
| { |
| struct rs_push_init const push_init = { |
| |
| .devaddr_info = devaddr_info, |
| .devaddr_count = devaddr_count, |
| .passes = passes |
| }; |
| |
| vkCmdPushConstants(cb, |
| rs->pipeline_layouts.named.init, |
| VK_SHADER_STAGE_COMPUTE_BIT, |
| 0, |
| sizeof(push_init), |
| &push_init); |
| |
| vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.init); |
| |
| vkCmdDispatch(cb, 1, 1, 1); |
| } |
| |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); |
| #endif |
| |
| vk_barrier_compute_w_to_indirect_compute_r(cb); |
| |
| { |
| // |
| // PAD |
| // |
| struct rs_push_fill const push_pad = { |
| |
| .devaddr_info = devaddr_info + offsetof(struct rs_indirect_info, pad), |
| .devaddr_dwords = devaddr_keyvals_even, |
| .dword = 0xFFFFFFFF |
| }; |
| |
| vkCmdPushConstants(cb, |
| rs->pipeline_layouts.named.fill, |
| VK_SHADER_STAGE_COMPUTE_BIT, |
| 0, |
| sizeof(push_pad), |
| &push_pad); |
| |
| vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.fill); |
| |
| info->dispatch_indirect_pfn(cb, |
| &info->indirect, |
| offsetof(struct rs_indirect_info, dispatch.pad)); |
| } |
| |
| // |
| // ZERO |
| // |
| { |
| VkDeviceSize const histo_offset = pass_idx * (sizeof(uint32_t) * RS_RADIX_SIZE); |
| |
| struct rs_push_fill const push_zero = { |
| |
| .devaddr_info = devaddr_info + offsetof(struct rs_indirect_info, zero), |
| .devaddr_dwords = devaddr_histograms + histo_offset, |
| .dword = 0 |
| }; |
| |
| vkCmdPushConstants(cb, |
| rs->pipeline_layouts.named.fill, |
| VK_SHADER_STAGE_COMPUTE_BIT, |
| 0, |
| sizeof(push_zero), |
| &push_zero); |
| |
| vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.fill); |
| |
| info->dispatch_indirect_pfn(cb, |
| &info->indirect, |
| offsetof(struct rs_indirect_info, dispatch.zero)); |
| } |
| |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); |
| #endif |
| |
| vk_barrier_compute_w_to_compute_r(cb); |
| |
| // |
| // HISTOGRAM |
| // |
| { |
| struct rs_push_histogram const push_histogram = { |
| |
| .devaddr_histograms = devaddr_histograms, |
| .devaddr_keyvals = devaddr_keyvals_even, |
| .passes = passes |
| }; |
| |
| vkCmdPushConstants(cb, |
| rs->pipeline_layouts.named.histogram, |
| VK_SHADER_STAGE_COMPUTE_BIT, |
| 0, |
| sizeof(push_histogram), |
| &push_histogram); |
| |
| vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram); |
| |
| info->dispatch_indirect_pfn(cb, |
| &info->indirect, |
| offsetof(struct rs_indirect_info, dispatch.histogram)); |
| } |
| |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); |
| #endif |
| |
| vk_barrier_compute_w_to_compute_r(cb); |
| |
| // |
| // PREFIX |
| // |
| { |
| struct rs_push_prefix const push_prefix = { |
| .devaddr_histograms = devaddr_histograms, |
| }; |
| |
| vkCmdPushConstants(cb, |
| rs->pipeline_layouts.named.prefix, |
| VK_SHADER_STAGE_COMPUTE_BIT, |
| 0, |
| sizeof(push_prefix), |
| &push_prefix); |
| |
| vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix); |
| |
| vkCmdDispatch(cb, passes, 1, 1); |
| } |
| |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); |
| #endif |
| |
| vk_barrier_compute_w_to_compute_r(cb); |
| |
| // |
| // SCATTER |
| // |
| { |
| // clang-format off |
| uint32_t const histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t)); |
| VkDeviceAddress const devaddr_keyvals_odd = info->keyvals_odd; |
| VkDeviceAddress const devaddr_partitions = info->internal + rs->internal.offset.partitions; |
| // clang-format on |
| |
| struct rs_push_scatter push_scatter = { |
| .devaddr_keyvals_even = devaddr_keyvals_even, |
| .devaddr_keyvals_odd = devaddr_keyvals_odd, |
| .devaddr_partitions = devaddr_partitions, |
| .devaddr_histograms = devaddr_histograms + histogram_offset, |
| .pass_offset = (pass_idx & 3) * RS_RADIX_LOG2, |
| }; |
| |
| { |
| uint32_t const pass_dword = pass_idx / 4; |
| |
| vkCmdPushConstants(cb, |
| rs->pipeline_layouts.named.scatter[pass_dword].even, |
| VK_SHADER_STAGE_COMPUTE_BIT, |
| 0, |
| sizeof(push_scatter), |
| &push_scatter); |
| |
| vkCmdBindPipeline(cb, |
| VK_PIPELINE_BIND_POINT_COMPUTE, |
| rs->pipelines.named.scatter[pass_dword].even); |
| } |
| |
| bool is_even = true; |
| |
| while (true) |
| { |
| info->dispatch_indirect_pfn(cb, |
| &info->indirect, |
| offsetof(struct rs_indirect_info, dispatch.scatter)); |
| |
| // |
| // Continue? |
| // |
| if (++pass_idx >= keyval_bytes) |
| { |
| break; |
| } |
| |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); |
| #endif |
| |
| vk_barrier_compute_w_to_compute_r(cb); |
| |
| // clang-format off |
| is_even ^= true; |
| push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t)); |
| push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2; |
| // clang-format on |
| |
| uint32_t const pass_dword = pass_idx / 4; |
| |
| // |
| // Update push constants that changed |
| // |
| VkPipelineLayout const pl = is_even |
| ? rs->pipeline_layouts.named.scatter[pass_dword].even // |
| : rs->pipeline_layouts.named.scatter[pass_dword].odd; |
| // clang-format off |
| vkCmdPushConstants(cb, |
| pl, |
| VK_SHADER_STAGE_COMPUTE_BIT, |
| OFFSETOF_MACRO(struct rs_push_scatter, devaddr_histograms), |
| sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset), |
| &push_scatter.devaddr_histograms); |
| // clang-format on |
| |
| // |
| // Bind new pipeline |
| // |
| VkPipeline const p = is_even ? rs->pipelines.named.scatter[pass_dword].even // |
| : rs->pipelines.named.scatter[pass_dword].odd; |
| |
| vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, p); |
| } |
| } |
| |
| #ifdef RADIX_SORT_VK_ENABLE_EXTENSIONS |
| rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); |
| #endif |
| |
| // |
| // End the label |
| // |
| #ifdef RADIX_SORT_VK_ENABLE_DEBUG_UTILS |
| if (pfn_vkCmdEndDebugUtilsLabelEXT != NULL) |
| { |
| pfn_vkCmdEndDebugUtilsLabelEXT(cb); |
| } |
| #endif |
| } |
| |
| // |
| // Implementation of radix_sort_vk_fill_buffer_pfn. |
| // |
| static void |
| radix_sort_vk_fill_buffer(VkCommandBuffer cb, |
| radix_sort_vk_buffer_info_t const * buffer_info, |
| VkDeviceSize offset, |
| VkDeviceSize size, |
| uint32_t data) |
| { |
| vkCmdFillBuffer(cb, buffer_info->buffer, buffer_info->offset + offset, size, data); |
| } |
| |
| // |
| // |
| // |
| void |
| radix_sort_vk_sort(radix_sort_vk_t const * rs, |
| radix_sort_vk_sort_info_t const * info, |
| VkDevice device, |
| VkCommandBuffer cb, |
| VkDescriptorBufferInfo * keyvals_sorted) |
| { |
| struct radix_sort_vk_sort_devaddr_info const di = { |
| .ext = info->ext, |
| .key_bits = info->key_bits, |
| .count = info->count, |
| .keyvals_even = { .buffer = info->keyvals_even.buffer, |
| .offset = info->keyvals_even.offset, |
| .devaddr = rs_get_devaddr(device, &info->keyvals_even) }, |
| .keyvals_odd = rs_get_devaddr(device, &info->keyvals_odd), |
| .internal = { .buffer = info->internal.buffer, |
| .offset = info->internal.offset, |
| .devaddr = rs_get_devaddr(device, &info->internal), }, |
| |
| .fill_buffer_pfn = radix_sort_vk_fill_buffer, |
| }; |
| |
| VkDeviceAddress di_keyvals_sorted; |
| |
| radix_sort_vk_sort_devaddr(rs, &di, device, cb, &di_keyvals_sorted); |
| |
| *keyvals_sorted = (di_keyvals_sorted == di.keyvals_even.devaddr) // |
| ? info->keyvals_even |
| : info->keyvals_odd; |
| } |
| |
| // |
| // Implementation of radix_sort_vk_dispatch_indirect_pfn. |
| // |
| static void |
| radix_sort_vk_dispatch_indirect(VkCommandBuffer cb, |
| radix_sort_vk_buffer_info_t const * buffer_info, |
| VkDeviceSize offset) |
| { |
| vkCmdDispatchIndirect(cb, buffer_info->buffer, buffer_info->offset + offset); |
| } |
| |
| // |
| // |
| // |
| void |
| radix_sort_vk_sort_indirect(radix_sort_vk_t const * rs, |
| radix_sort_vk_sort_indirect_info_t const * info, |
| VkDevice device, |
| VkCommandBuffer cb, |
| VkDescriptorBufferInfo * keyvals_sorted) |
| { |
| struct radix_sort_vk_sort_indirect_devaddr_info const idi = { |
| .ext = info->ext, |
| .key_bits = info->key_bits, |
| .count = rs_get_devaddr(device, &info->count), |
| .keyvals_even = rs_get_devaddr(device, &info->keyvals_even), |
| .keyvals_odd = rs_get_devaddr(device, &info->keyvals_odd), |
| .internal = rs_get_devaddr(device, &info->internal), |
| .indirect = { .buffer = info->indirect.buffer, |
| .offset = info->indirect.offset, |
| .devaddr = rs_get_devaddr(device, &info->indirect) }, |
| |
| .dispatch_indirect_pfn = radix_sort_vk_dispatch_indirect, |
| }; |
| |
| VkDeviceAddress idi_keyvals_sorted; |
| |
| radix_sort_vk_sort_indirect_devaddr(rs, &idi, device, cb, &idi_keyvals_sorted); |
| |
| *keyvals_sorted = (idi_keyvals_sorted == idi.keyvals_even) // |
| ? info->keyvals_even |
| : info->keyvals_odd; |
| } |
| |
| // |
| // |
| // |