| // Copyright 2022 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| use std::{borrow::Cow, convert::TryFrom, mem}; |
| |
| use bytemuck::{Pod, Zeroable}; |
| use ramhorns::{Content, Template}; |
| use wgpu::util::DeviceExt; |
| |
// Cap on the number of workgroups dispatched along one dimension.
// NOTE(review): presumably mirrors the API's per-dimension dispatch limit — confirm
// against the target backend's `max_compute_workgroups_per_dimension`.
const MAX_WORKGROUPS: usize = 65_536;
// Block dimensions rendered into the WGSL template; one block covers
// 64 * 9 = 576 elements (`block_len`).
pub const BLOCK_SIZE: BlockSize = BlockSize::new(64, 9);
| |
/// Integer division of `n` by `d`, rounding up.
///
/// Written as `n / d + (n % d != 0)` instead of `(n + d - 1) / d` so the
/// intermediate sum cannot overflow when `n` is close to `usize::MAX`.
fn div_round_up(n: usize, d: usize) -> usize {
    n / d + usize::from(n % d != 0)
}
| |
/// Returns `ceil(log2(n))`: exact powers of two map to their exponent, every
/// other value rounds up to the next exponent. `0` maps to `0`.
fn log2_round_up(n: usize) -> u32 {
    let bit_width = mem::size_of::<usize>() as u32 * 8;
    match n.count_ones() {
        // Exactly one bit set: `n` is a power of two, the exponent is the
        // index of that bit.
        1 => n.trailing_zeros(),
        // Otherwise the answer is one past the index of the highest set bit.
        _ => bit_width - n.leading_zeros(),
    }
}
| |
/// Dimensions of one sort block; rendered into the WGSL template via the
/// `Content` derive (field names become template variables).
#[derive(Content, Debug)]
pub struct BlockSize {
    // Work items per row of a block.
    block_width: u32,
    // Rows per block.
    block_height: u32,
    // Total elements per block: `block_width * block_height`.
    pub block_len: u32,
}
| |
| impl BlockSize { |
| pub const fn new(block_width: u32, block_height: u32) -> Self { |
| Self { block_width, block_height, block_len: block_width * block_height } |
| } |
| } |
| |
/// Compute pipelines (and their bind-group layouts) for the three shader
/// entry points created by [`init`]: `blockSort`, `findMergeOffsets` and
/// `mergeBlocks`.
#[derive(Debug)]
pub struct SortContext {
    // Pipeline for the `blockSort` entry point (first pass).
    block_sort_pipeline: wgpu::ComputePipeline,
    block_sort_bind_group_layout: wgpu::BindGroupLayout,
    // Pipeline for the `findMergeOffsets` entry point (per merge round).
    find_merge_offsets_pipeline: wgpu::ComputePipeline,
    find_merge_offsets_bind_group_layout: wgpu::BindGroupLayout,
    // Pipeline for the `mergeBlocks` entry point (per merge round).
    merge_blocks_pipeline: wgpu::ComputePipeline,
    merge_blocks_bind_group_layout: wgpu::BindGroupLayout,
}
| |
| pub fn init(device: &wgpu::Device) -> SortContext { |
| let template = Template::new(include_str!("sort.wgsl")).unwrap(); |
| |
| let module = device.create_shader_module(&wgpu::ShaderModuleDescriptor { |
| label: None, |
| source: wgpu::ShaderSource::Wgsl(Cow::Owned(template.render(&BLOCK_SIZE))), |
| }); |
| |
| let block_sort_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { |
| label: None, |
| layout: None, |
| module: &module, |
| entry_point: "blockSort", |
| }); |
| |
| let block_sort_bind_group_layout = block_sort_pipeline.get_bind_group_layout(0); |
| |
| let find_merge_offsets_pipeline = |
| device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { |
| label: None, |
| layout: None, |
| module: &module, |
| entry_point: "findMergeOffsets", |
| }); |
| |
| let find_merge_offsets_bind_group_layout = find_merge_offsets_pipeline.get_bind_group_layout(0); |
| |
| let merge_blocks_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { |
| label: None, |
| layout: None, |
| module: &module, |
| entry_point: "mergeBlocks", |
| }); |
| |
| let merge_blocks_bind_group_layout = merge_blocks_pipeline.get_bind_group_layout(0); |
| |
| SortContext { |
| block_sort_pipeline, |
| block_sort_bind_group_layout, |
| find_merge_offsets_pipeline, |
| find_merge_offsets_bind_group_layout, |
| merge_blocks_pipeline, |
| merge_blocks_bind_group_layout, |
| } |
| } |
| |
/// Uniform parameters uploaded to the sort shaders as raw bytes
/// (`bytemuck::bytes_of`); `repr(C)` keeps the field order/layout stable.
/// NOTE(review): must match the corresponding WGSL struct in `sort.wgsl` — confirm.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
struct Config {
    // Total number of elements to sort.
    len: u32,
    // Number of BLOCK_SIZE-sized blocks covering `len`, rounded up.
    len_in_blocks: u32,
    // Merge fan-in for the current round (set to 1 << (round + 1) by
    // `encode`); 0 during the initial block-sort pass.
    n_way: u32,
}
| |
| impl Config { |
| pub fn new(len: usize) -> Option<Self> { |
| let len_in_blocks = u32::try_from(div_round_up(len, BLOCK_SIZE.block_len as usize)).ok()?; |
| |
| Some(Self { len: len as _, len_in_blocks, n_way: 0 }) |
| } |
| |
| pub fn workgroup_size(&self) -> u32 { |
| self.len_in_blocks.min(MAX_WORKGROUPS as u32) |
| } |
| } |
| |
| #[allow(clippy::too_many_arguments)] |
| pub fn encode<'b>( |
| device: &wgpu::Device, |
| encoder: &mut wgpu::CommandEncoder, |
| context: &SortContext, |
| len: usize, |
| storage_buffer0: &'b wgpu::Buffer, |
| storage_buffer1: &'b wgpu::Buffer, |
| offsets_buffer: &wgpu::Buffer, |
| timestamp: Option<(&wgpu::QuerySet, u32, u32)>, |
| ) -> &'b wgpu::Buffer { |
| let mut config = Config::new(len).expect("numbers length too high"); |
| |
| let config_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { |
| label: None, |
| contents: bytemuck::bytes_of(&config), |
| usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::UNIFORM, |
| }); |
| |
| let block_sort_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { |
| label: None, |
| layout: &context.block_sort_bind_group_layout, |
| entries: &[ |
| wgpu::BindGroupEntry { binding: 0, resource: storage_buffer0.as_entire_binding() }, |
| wgpu::BindGroupEntry { binding: 2, resource: config_buffer.as_entire_binding() }, |
| ], |
| }); |
| |
| { |
| let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None }); |
| cpass.set_pipeline(&context.block_sort_pipeline); |
| cpass.set_bind_group(0, &block_sort_bind_group, &[]); |
| |
| if let Some((timestamp, start_index, _)) = timestamp { |
| cpass.write_timestamp(timestamp, start_index); |
| } |
| |
| cpass.dispatch_workgroups(config.workgroup_size(), 1, 1); |
| } |
| |
| let rounds = log2_round_up(config.len_in_blocks as usize); |
| let max_rounds = log2_round_up(div_round_up( |
| device.limits().max_storage_buffer_binding_size as usize, |
| BLOCK_SIZE.block_len as usize, |
| )); |
| |
| for round in 0..max_rounds { |
| config.n_way = 1 << (round + 1); |
| |
| let config_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { |
| label: None, |
| contents: bytemuck::bytes_of(&config), |
| usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::UNIFORM, |
| }); |
| |
| let find_merge_offsets_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { |
| label: None, |
| layout: &context.find_merge_offsets_bind_group_layout, |
| entries: &[ |
| wgpu::BindGroupEntry { |
| binding: 0, |
| resource: if round % 2 == 0 { |
| storage_buffer0.as_entire_binding() |
| } else { |
| storage_buffer1.as_entire_binding() |
| }, |
| }, |
| wgpu::BindGroupEntry { binding: 2, resource: config_buffer.as_entire_binding() }, |
| wgpu::BindGroupEntry { binding: 3, resource: offsets_buffer.as_entire_binding() }, |
| ], |
| }); |
| |
| { |
| let mut cpass = |
| encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None }); |
| cpass.set_pipeline(&context.find_merge_offsets_pipeline); |
| cpass.set_bind_group(0, &find_merge_offsets_bind_group, &[]); |
| |
| cpass.dispatch_workgroups( |
| div_round_up(config.workgroup_size() as usize, BLOCK_SIZE.block_width as usize) |
| as u32, |
| 1, |
| 1, |
| ); |
| } |
| |
| let merge_blocks_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { |
| label: None, |
| layout: &context.merge_blocks_bind_group_layout, |
| entries: &[ |
| wgpu::BindGroupEntry { |
| binding: 0, |
| resource: if round % 2 == 0 { |
| storage_buffer0.as_entire_binding() |
| } else { |
| storage_buffer1.as_entire_binding() |
| }, |
| }, |
| wgpu::BindGroupEntry { |
| binding: 1, |
| resource: if round % 2 == 0 { |
| storage_buffer1.as_entire_binding() |
| } else { |
| storage_buffer0.as_entire_binding() |
| }, |
| }, |
| wgpu::BindGroupEntry { binding: 2, resource: config_buffer.as_entire_binding() }, |
| wgpu::BindGroupEntry { binding: 3, resource: offsets_buffer.as_entire_binding() }, |
| ], |
| }); |
| |
| { |
| let mut cpass = |
| encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None }); |
| cpass.set_pipeline(&context.merge_blocks_pipeline); |
| cpass.set_bind_group(0, &merge_blocks_bind_group, &[]); |
| |
| cpass.dispatch_workgroups(config.workgroup_size(), 1, 1); |
| |
| if round == max_rounds - 1 { |
| if let Some((timestamp, _, end_index)) = timestamp { |
| cpass.write_timestamp(timestamp, end_index); |
| } |
| } |
| } |
| } |
| |
| if rounds % 2 == 0 { |
| storage_buffer0 |
| } else { |
| storage_buffer1 |
| } |
| } |