import { ScanBlockSize, prefix_sum_comp_spv, block_prefix_sum_comp_spv, add_block_sums_comp_spv } from "./embedded_shaders"; export var alignTo = function(val, align) { return Math.floor((val + align - 1) / align) * align; }; // Serial scan for validation var serialExclusiveScan = function(array, output) { output[0] = 0; for (var i = 1; i < array.length; ++i) { output[i] = array[i - 1] + output[i - 1]; } return output[array.length - 1] + array[array.length - 1]; }; export var ExclusiveScanPipeline = function(device) { this.device = device; // Each thread in a work group is responsible for 2 elements this.workGroupSize = ScanBlockSize / 2; // The max size which can be scanned by a single batch without carry in/out this.maxScanSize = ScanBlockSize * ScanBlockSize; console.log(`Block size: ${ScanBlockSize}, max scan size: ${this.maxScanSize}`); // Buffer to clear the block sums for each new scan var clearBlocks = device.createBuffer({ size: ScanBlockSize * 4, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, mappedAtCreation: true, }); new Uint32Array(clearBlocks.getMappedRange()).fill(0); clearBlocks.unmap(); this.clearBuf = clearBlocks; this.scanBlocksLayout = device.createBindGroupLayout({ entries: [ {binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {type: "storage"}}, { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage", } }, ], }); this.scanBlockResultsLayout = device.createBindGroupLayout({ entries: [ { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage", } }, { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage", } }, ], }); this.scanBlocksPipeline = device.createComputePipeline({ layout: device.createPipelineLayout({ bindGroupLayouts: [this.scanBlocksLayout], }), compute: { module: device.createShaderModule({code: prefix_sum_comp_spv}), entryPoint: "main", }, }); this.scanBlockResultsPipeline = device.createComputePipeline({ layout: device.createPipelineLayout({ bindGroupLayouts: [this.scanBlockResultsLayout], }), compute: { module: device.createShaderModule({code: block_prefix_sum_comp_spv}), entryPoint: "main", }, }); this.addBlockSumsPipeline = device.createComputePipeline({ layout: device.createPipelineLayout({ bindGroupLayouts: [this.scanBlocksLayout], }), compute: { module: device.createShaderModule({code: add_block_sums_comp_spv}), entryPoint: "main", }, }); }; ExclusiveScanPipeline.prototype.getAlignedSize = function(size) { return alignTo(size, ScanBlockSize); }; // TODO: refactor to have this return a prepared scanner object? // Then the pipelines and bind group layouts can be re-used and shared between the scanners ExclusiveScanPipeline.prototype.prepareInput = function(cpuArray) { var alignedSize = alignTo(cpuArray.length, ScanBlockSize); // Upload input and pad to block size elements var inputBuf = this.device.createBuffer({ size: alignedSize * 4, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, mappedAtCreation: true, }); new Uint32Array(inputBuf.getMappedRange()).set(cpuArray); inputBuf.unmap(); return new ExclusiveScanner(this, inputBuf, alignedSize, cpuArray.length); }; ExclusiveScanPipeline.prototype.prepareGPUInput = function(gpuBuffer, alignedSize) { if (this.getAlignedSize(alignedSize) != alignedSize) { alert("Error: GPU input must be aligned to getAlignedSize"); } return new ExclusiveScanner(this, gpuBuffer, alignedSize); }; var ExclusiveScanner = function(scanPipeline, gpuBuffer, alignedSize) { this.scanPipeline = scanPipeline; this.inputSize = alignedSize; this.inputBuf = gpuBuffer; this.readbackBuf = scanPipeline.device.createBuffer({ size: 4, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, }); // Block sum buffer var blockSumBuf = scanPipeline.device.createBuffer({ size: ScanBlockSize * 4, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, mappedAtCreation: true, }); new Uint32Array(blockSumBuf.getMappedRange()).fill(0); blockSumBuf.unmap(); this.blockSumBuf = blockSumBuf; var carryBuf = scanPipeline.device.createBuffer({ size: 8, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, mappedAtCreation: true, }); new Uint32Array(carryBuf.getMappedRange()).fill(0); carryBuf.unmap(); this.carryBuf = carryBuf; // Can't copy from a buffer to itself so we need an intermediate to move the carry this.carryIntermediateBuf = scanPipeline.device.createBuffer({ size: 4, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, }); this.scanBlockResultsBindGroup = scanPipeline.device.createBindGroup({ layout: this.scanPipeline.scanBlockResultsLayout, entries: [ { binding: 0, resource: { buffer: blockSumBuf, }, }, { binding: 1, resource: { buffer: carryBuf, }, }, ], }); }; ExclusiveScanner.prototype.scan = async function(dataSize) { // If the data size we're scanning within the larger input array has changed, // we just need to re-record the scan commands var numChunks = Math.ceil(dataSize / this.scanPipeline.maxScanSize); this.offsets = new Uint32Array(numChunks); for (var i = 0; i < numChunks; ++i) { this.offsets.set([i * this.scanPipeline.maxScanSize * 4], i); } // Scan through the data in chunks, updating carry in/out at the end to carry // over the results of the previous chunks var commandEncoder = this.scanPipeline.device.createCommandEncoder(); // Clear the carry buffer and the readback sum entry if it's not scan size aligned commandEncoder.copyBufferToBuffer(this.scanPipeline.clearBuf, 0, this.carryBuf, 0, 8); for (var i = 0; i < numChunks; ++i) { var nWorkGroups = Math.min((this.inputSize - i * this.scanPipeline.maxScanSize) / ScanBlockSize, ScanBlockSize); var scanBlockBG = null; if (nWorkGroups === ScanBlockSize) { scanBlockBG = this.scanPipeline.device.createBindGroup({ layout: this.scanPipeline.scanBlocksLayout, entries: [ { binding: 0, resource: { buffer: this.inputBuf, size: Math.min(this.scanPipeline.maxScanSize, this.inputSize) * 4, offset: this.offsets[i], }, }, { binding: 1, resource: { buffer: this.blockSumBuf, }, }, ], }); } else { // Bind groups for processing the remainder if the aligned size isn't // an even multiple of the max scan size scanBlockBG = this.scanPipeline.device.createBindGroup({ layout: this.scanPipeline.scanBlocksLayout, entries: [ { binding: 0, resource: { buffer: this.inputBuf, size: (this.inputSize % this.scanPipeline.maxScanSize) * 4, offset: this.offsets[i], }, }, { binding: 1, resource: { buffer: this.blockSumBuf, }, }, ], }); } // Clear the previous block sums commandEncoder.copyBufferToBuffer( this.scanPipeline.clearBuf, 0, this.blockSumBuf, 0, ScanBlockSize * 4); var computePass = commandEncoder.beginComputePass(); computePass.setPipeline(this.scanPipeline.scanBlocksPipeline); computePass.setBindGroup(0, scanBlockBG); computePass.dispatchWorkgroups(nWorkGroups, 1, 1); computePass.setPipeline(this.scanPipeline.scanBlockResultsPipeline); computePass.setBindGroup(0, this.scanBlockResultsBindGroup); computePass.dispatchWorkgroups(1, 1, 1); computePass.setPipeline(this.scanPipeline.addBlockSumsPipeline); computePass.setBindGroup(0, scanBlockBG); computePass.dispatchWorkgroups(nWorkGroups, 1, 1); computePass.end(); // Update the carry in value for the next chunk, copy carry out to carry in commandEncoder.copyBufferToBuffer(this.carryBuf, 4, this.carryIntermediateBuf, 0, 4); commandEncoder.copyBufferToBuffer(this.carryIntermediateBuf, 0, this.carryBuf, 0, 4); } var commandBuffer = commandEncoder.finish(); // We need to clear a different element in the input buf for the last item if the data size // shrinks if (dataSize < this.inputSize) { var commandEncoder = this.scanPipeline.device.createCommandEncoder(); commandEncoder.copyBufferToBuffer( this.scanPipeline.clearBuf, 0, this.inputBuf, dataSize * 4, 4); this.scanPipeline.device.queue.submit([commandEncoder.finish()]); } this.scanPipeline.device.queue.submit([commandBuffer]); // Readback the the last element to return the total sum as well var commandEncoder = this.scanPipeline.device.createCommandEncoder(); if (dataSize < this.inputSize) { commandEncoder.copyBufferToBuffer(this.inputBuf, dataSize * 4, this.readbackBuf, 0, 4); } else { commandEncoder.copyBufferToBuffer(this.carryBuf, 4, this.readbackBuf, 0, 4); } this.scanPipeline.device.queue.submit([commandEncoder.finish()]); await this.readbackBuf.mapAsync(GPUMapMode.READ); var mapping = new Uint32Array(this.readbackBuf.getMappedRange()); var sum = mapping[0]; this.readbackBuf.unmap(); return sum; };