blob: 48f7a731eba2a4f37cfa618c3b3f42752b9d0f7f [file] [log] [blame]
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use std::{borrow::Cow, convert::TryFrom, mem};
use bytemuck::{Pod, Zeroable};
use ramhorns::{Content, Template};
use wgpu::util::DeviceExt;
// Cap on workgroups dispatched along one dimension (2^16) — presumably
// matches the device's per-dimension dispatch limit; confirm against the
// WebGPU `maxComputeWorkgroupsPerDimension` limit of the target devices.
const MAX_WORKGROUPS: usize = 65_536;
// Block dimensions rendered into the sort.wgsl template; one block holds
// `block_width * block_height` = 576 elements.
pub const BLOCK_SIZE: BlockSize = BlockSize::new(64, 9);
/// Returns `n / d` rounded up to the nearest integer.
///
/// Written as `n / d + (n % d != 0)` so it cannot overflow; the common
/// `(n + d - 1) / d` form wraps (panics in debug builds) once
/// `n + d - 1 > usize::MAX`.
///
/// # Panics
/// Panics if `d == 0` (division by zero), matching the original behavior.
fn div_round_up(n: usize, d: usize) -> usize {
    n / d + usize::from(n % d != 0)
}
/// Returns `ceil(log2(n))` for `n > 0`.
///
/// Exact powers of two map to the index of their single set bit; any other
/// value maps to its bit width (highest set bit index + 1), which equals
/// `floor(log2(n)) + 1 == ceil(log2(n))`.
///
/// Returns 0 for `n == 0` (log2 is undefined there); callers must not rely
/// on that value.
fn log2_round_up(n: usize) -> u32 {
    if n.is_power_of_two() {
        n.trailing_zeros()
    } else {
        // Bit width of n. `usize::BITS` replaces the original
        // `mem::size_of::<usize>() as u32 * 8`.
        usize::BITS - n.leading_zeros()
    }
}
/// Dimensions of one sort block; rendered into the WGSL shader template
/// (the `Content` derive exposes the fields to ramhorns placeholders).
#[derive(Content, Debug)]
pub struct BlockSize {
    // Width of a block in elements — presumably the workgroup x-size of the
    // shader; confirm against sort.wgsl.
    block_width: u32,
    // Rows of `block_width` elements per block — confirm against sort.wgsl.
    block_height: u32,
    // Total elements per block: `block_width * block_height`.
    pub block_len: u32,
}
impl BlockSize {
pub const fn new(block_width: u32, block_height: u32) -> Self {
Self { block_width, block_height, block_len: block_width * block_height }
}
}
/// Compute pipelines (and their auto-derived bind group layouts) for the
/// three shader entry points used by `encode`: `blockSort`,
/// `findMergeOffsets`, and `mergeBlocks`.
#[derive(Debug)]
pub struct SortContext {
    // Pipeline for the `blockSort` entry point.
    block_sort_pipeline: wgpu::ComputePipeline,
    block_sort_bind_group_layout: wgpu::BindGroupLayout,
    // Pipeline for the `findMergeOffsets` entry point.
    find_merge_offsets_pipeline: wgpu::ComputePipeline,
    find_merge_offsets_bind_group_layout: wgpu::BindGroupLayout,
    // Pipeline for the `mergeBlocks` entry point.
    merge_blocks_pipeline: wgpu::ComputePipeline,
    merge_blocks_bind_group_layout: wgpu::BindGroupLayout,
}
/// Renders the `sort.wgsl` template with `BLOCK_SIZE`, compiles it, and
/// builds the three compute pipelines (plus bind group layouts) that
/// `encode` uses.
///
/// # Panics
/// Panics if the embedded `sort.wgsl` is not a valid ramhorns template
/// (a build-time invariant, not a runtime condition).
pub fn init(device: &wgpu::Device) -> SortContext {
    let template = Template::new(include_str!("sort.wgsl"))
        .expect("embedded sort.wgsl is a valid ramhorns template");
    let module = device.create_shader_module(&wgpu::ShaderModuleDescriptor {
        label: None,
        source: wgpu::ShaderSource::Wgsl(Cow::Owned(template.render(&BLOCK_SIZE))),
    });
    // All three pipelines share the shader module and use auto-derived
    // (`layout: None`) pipeline layouts; only the entry point differs.
    let create_pipeline = |entry_point: &str| {
        device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: None,
            layout: None,
            module: &module,
            entry_point,
        })
    };
    let block_sort_pipeline = create_pipeline("blockSort");
    let block_sort_bind_group_layout = block_sort_pipeline.get_bind_group_layout(0);
    let find_merge_offsets_pipeline = create_pipeline("findMergeOffsets");
    let find_merge_offsets_bind_group_layout =
        find_merge_offsets_pipeline.get_bind_group_layout(0);
    let merge_blocks_pipeline = create_pipeline("mergeBlocks");
    let merge_blocks_bind_group_layout = merge_blocks_pipeline.get_bind_group_layout(0);
    SortContext {
        block_sort_pipeline,
        block_sort_bind_group_layout,
        find_merge_offsets_pipeline,
        find_merge_offsets_bind_group_layout,
        merge_blocks_pipeline,
        merge_blocks_bind_group_layout,
    }
}
/// Uniform data uploaded to the shader each pass; `#[repr(C)]` + `Pod`
/// allow it to be byte-copied into a buffer. Field order must match the
/// WGSL-side struct — confirm against sort.wgsl.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
struct Config {
    // Total number of elements to sort.
    len: u32,
    // Number of BLOCK_SIZE-sized blocks covering `len`, rounded up.
    len_in_blocks: u32,
    // Merge width in blocks for the current round: 0 during the initial
    // block sort, then 2, 4, 8, ... (set per round by `encode`).
    n_way: u32,
}
impl Config {
    /// Builds a `Config` for sorting `len` elements, or `None` if `len`
    /// does not fit in a `u32`.
    pub fn new(len: usize) -> Option<Self> {
        // Checked conversion: the original used `len as _`, which silently
        // truncated lengths above u32::MAX — and the `len_in_blocks`
        // conversion below does not catch that case, because the block
        // count (len / 576 rounded up) can still fit in a u32.
        let len = u32::try_from(len).ok()?;
        let len_in_blocks =
            u32::try_from(div_round_up(len as usize, BLOCK_SIZE.block_len as usize)).ok()?;
        Some(Self { len, len_in_blocks, n_way: 0 })
    }

    /// Number of workgroups to dispatch: one per block, capped at
    /// `MAX_WORKGROUPS`.
    pub fn workgroup_size(&self) -> u32 {
        self.len_in_blocks.min(MAX_WORKGROUPS as u32)
    }
}
/// Encodes the full GPU merge sort into `encoder`: one `blockSort` pass
/// that sorts each block independently, followed by rounds of
/// `findMergeOffsets` + `mergeBlocks` passes that ping-pong the data
/// between `storage_buffer0` and `storage_buffer1`, doubling the merge
/// width (`n_way`) each round. Returns the buffer expected to hold the
/// final sorted data.
///
/// `timestamp`, when present, is `(query_set, start_index, end_index)`:
/// a timestamp is written at `start_index` before the first pass and at
/// `end_index` inside the last merge pass.
///
/// Panics (via `expect`) if `len` is too large for `Config::new`.
#[allow(clippy::too_many_arguments)]
pub fn encode<'b>(
    device: &wgpu::Device,
    encoder: &mut wgpu::CommandEncoder,
    context: &SortContext,
    len: usize,
    storage_buffer0: &'b wgpu::Buffer,
    storage_buffer1: &'b wgpu::Buffer,
    offsets_buffer: &wgpu::Buffer,
    timestamp: Option<(&wgpu::QuerySet, u32, u32)>,
) -> &'b wgpu::Buffer {
    let mut config = Config::new(len).expect("numbers length too high");
    // Uniform buffer for the initial block-sort pass (n_way == 0).
    let config_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: None,
        contents: bytemuck::bytes_of(&config),
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::UNIFORM,
    });
    // NOTE(review): binding 1 is intentionally absent here — presumably the
    // blockSort entry point sorts buffer 0 in place; confirm against sort.wgsl.
    let block_sort_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &context.block_sort_bind_group_layout,
        entries: &[
            wgpu::BindGroupEntry { binding: 0, resource: storage_buffer0.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 2, resource: config_buffer.as_entire_binding() },
        ],
    });
    // Pass 1: sort each block of BLOCK_SIZE.block_len elements.
    {
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
        cpass.set_pipeline(&context.block_sort_pipeline);
        cpass.set_bind_group(0, &block_sort_bind_group, &[]);
        if let Some((timestamp, start_index, _)) = timestamp {
            cpass.write_timestamp(timestamp, start_index);
        }
        cpass.dispatch_workgroups(config.workgroup_size(), 1, 1);
    }
    // `rounds` is how many merge rounds the actual data needs;
    // `max_rounds` is derived from the device's maximum storage-buffer
    // binding size (an upper bound on how large the data could be).
    let rounds = log2_round_up(config.len_in_blocks as usize);
    let max_rounds = log2_round_up(div_round_up(
        device.limits().max_storage_buffer_binding_size as usize,
        BLOCK_SIZE.block_len as usize,
    ));
    // NOTE(review): the loop runs `max_rounds` times but the final buffer
    // choice below uses the parity of `rounds`. If `max_rounds > rounds`,
    // the extra rounds presumably must leave the data where the `rounds`
    // parity predicts — confirm the shader's behavior for
    // n_way >= len_in_blocks, or whether the loop should run `rounds` times.
    for round in 0..max_rounds {
        // Merge width doubles each round: 2, 4, 8, ... blocks.
        config.n_way = 1 << (round + 1);
        // Fresh uniform buffer per round (passes in one submission cannot
        // observe an in-place update of a single buffer).
        let config_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: None,
            contents: bytemuck::bytes_of(&config),
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::UNIFORM,
        });
        // Source buffer alternates each round: buffer0 on even rounds,
        // buffer1 on odd rounds.
        let find_merge_offsets_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &context.find_merge_offsets_bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: if round % 2 == 0 {
                        storage_buffer0.as_entire_binding()
                    } else {
                        storage_buffer1.as_entire_binding()
                    },
                },
                wgpu::BindGroupEntry { binding: 2, resource: config_buffer.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: offsets_buffer.as_entire_binding() },
            ],
        });
        // Pass 2a: compute merge offsets for this round into offsets_buffer.
        {
            let mut cpass =
                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
            cpass.set_pipeline(&context.find_merge_offsets_pipeline);
            cpass.set_bind_group(0, &find_merge_offsets_bind_group, &[]);
            cpass.dispatch_workgroups(
                div_round_up(config.workgroup_size() as usize, BLOCK_SIZE.block_width as usize)
                    as u32,
                1,
                1,
            );
        }
        // Merge reads from this round's source (binding 0) and writes to the
        // other buffer (binding 1); roles swap every round.
        let merge_blocks_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &context.merge_blocks_bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: if round % 2 == 0 {
                        storage_buffer0.as_entire_binding()
                    } else {
                        storage_buffer1.as_entire_binding()
                    },
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: if round % 2 == 0 {
                        storage_buffer1.as_entire_binding()
                    } else {
                        storage_buffer0.as_entire_binding()
                    },
                },
                wgpu::BindGroupEntry { binding: 2, resource: config_buffer.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: offsets_buffer.as_entire_binding() },
            ],
        });
        // Pass 2b: merge blocks using the offsets computed above.
        {
            let mut cpass =
                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
            cpass.set_pipeline(&context.merge_blocks_pipeline);
            cpass.set_bind_group(0, &merge_blocks_bind_group, &[]);
            cpass.dispatch_workgroups(config.workgroup_size(), 1, 1);
            // Record the end timestamp inside the final merge pass.
            if round == max_rounds - 1 {
                if let Some((timestamp, _, end_index)) = timestamp {
                    cpass.write_timestamp(timestamp, end_index);
                }
            }
        }
    }
    // After an even number of needed rounds the data is back in buffer0;
    // after an odd number it sits in buffer1 (see NOTE above re: parity).
    if rounds % 2 == 0 {
        storage_buffer0
    } else {
        storage_buffer1
    }
}