-
-
Notifications
You must be signed in to change notification settings - Fork 63
impr: Use f16 and better subgroup shader in MNIST example
#2412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: release
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,40 +1,20 @@ | ||
| import tgpu, { d, type StorageFlag, type TgpuBuffer } from 'typegpu'; | ||
|
|
||
| export const ReadonlyFloats = { | ||
| storage: d.arrayOf(d.f32), | ||
| access: 'readonly', | ||
| } as const; | ||
|
|
||
| export const MutableFloats = { | ||
| storage: d.arrayOf(d.f32), | ||
| access: 'mutable', | ||
| } as const; | ||
|
|
||
| export const ioLayout = tgpu.bindGroupLayout({ | ||
| input: ReadonlyFloats, | ||
| output: MutableFloats, | ||
| }); | ||
|
|
||
| export const weightsBiasesLayout = tgpu.bindGroupLayout({ | ||
| weights: ReadonlyFloats, | ||
| biases: ReadonlyFloats, | ||
| }); | ||
| import { d, type StorageFlag, type TgpuBuffer } from 'typegpu'; | ||
|
|
||
| export interface LayerData { | ||
| shape: readonly [number] | readonly [number, number]; | ||
| buffer: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag; | ||
| buffer: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag; | ||
| } | ||
|
|
||
| export interface Layer { | ||
| weights: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag; | ||
| biases: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag; | ||
| state: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag; | ||
| weights: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag; | ||
| biases: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag; | ||
| state: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag; | ||
| } | ||
|
|
||
| export interface Network { | ||
| layers: Layer[]; | ||
| input: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag; | ||
| output: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag; | ||
| input: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag; | ||
| output: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag; | ||
|
|
||
| inference(data: number[]): Promise<number[]>; | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,7 +37,10 @@ function getLayerData(layer: ArrayBuffer): { | |
| }; | ||
| } | ||
|
|
||
| export function downloadLayers(root: TgpuRoot): Promise<[LayerData, LayerData][]> { | ||
| export function downloadLayers( | ||
| root: TgpuRoot, | ||
| floatShcema: d.F32 | d.F16, | ||
| ): Promise<[LayerData, LayerData][]> { | ||
|
Comment on lines
+40
to
+43
|
||
| const downloadLayer = async (fileName: string): Promise<LayerData> => { | ||
| const buffer = await fetch(`/TypeGPU/assets/mnist-weights/${fileName}`).then((res) => | ||
| res.arrayBuffer(), | ||
|
|
@@ -46,7 +49,7 @@ export function downloadLayers(root: TgpuRoot): Promise<[LayerData, LayerData][] | |
| const { shape, data } = getLayerData(buffer); | ||
|
|
||
| const layerBuffer = root | ||
| .createBuffer(d.arrayOf(d.f32, data.length), [...data]) | ||
| .createBuffer(d.arrayOf(floatShcema, data.length), [...data]) | ||
| .$usage('storage'); | ||
|
|
||
| return { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,69 +1,89 @@ | ||
| import tgpu, { d, std } from 'typegpu'; | ||
| import { ioLayout, type LayerData, type Network, weightsBiasesLayout } from './data.ts'; | ||
| import type { LayerData, Network } from './data.ts'; | ||
| import { downloadLayers } from './helpers.ts'; | ||
| import { defineControls } from '../../common/defineControls.ts'; | ||
|
|
||
| const SIZE = 28; | ||
| const WORKGROUP_SIZE = 64; | ||
|
|
||
| const root = await tgpu.init({ | ||
| device: { | ||
| optionalFeatures: ['timestamp-query', 'subgroups'], | ||
| }, | ||
| device: { optionalFeatures: ['timestamp-query', 'subgroups', 'shader-f16'] }, | ||
| }); | ||
| const hasTimestampQuery = root.enabledFeatures.has('timestamp-query'); | ||
| const hasSubgroups = root.enabledFeatures.has('subgroups'); | ||
| const hasShaderF16 = root.enabledFeatures.has('shader-f16'); | ||
| let useSubgroups = hasSubgroups; | ||
|
|
||
| const float = hasShaderF16 ? d.f16 : d.f32; | ||
|
|
||
| const ioLayout = tgpu.bindGroupLayout({ | ||
| input: { storage: d.arrayOf(float) }, | ||
| output: { | ||
| storage: d.arrayOf(float), | ||
| access: 'mutable', | ||
| }, | ||
| }); | ||
|
|
||
| const weightsBiasesLayout = tgpu.bindGroupLayout({ | ||
| weights: { storage: d.arrayOf(float) }, | ||
| biases: { storage: d.arrayOf(float) }, | ||
| }); | ||
|
|
||
| const canvasData = Array.from({ length: SIZE ** 2 }, () => 0); | ||
|
|
||
| // Shaders | ||
|
|
||
| const relu = tgpu.fn([d.f32], d.f32)((x) => std.max(0, x)); | ||
| function relu(x: number): number { | ||
| 'use gpu'; | ||
| return std.max(0, x); | ||
| } | ||
|
|
||
| const defaultCompute = tgpu.computeFn({ | ||
| in: { | ||
| gid: d.builtin.globalInvocationId, | ||
| }, | ||
| workgroupSize: [1], | ||
| in: { gid: d.builtin.globalInvocationId }, | ||
| workgroupSize: [WORKGROUP_SIZE], | ||
| })(({ gid }) => { | ||
| const inputSize = ioLayout.$.input.length; | ||
|
|
||
| const i = gid.x; | ||
| const outLen = ioLayout.$.output.length; | ||
| if (i >= outLen) { | ||
| return; | ||
| } | ||
|
|
||
| const inputSize = ioLayout.$.input.length; | ||
| const weightsOffset = i * inputSize; | ||
| let sum = d.f32(); | ||
| let sum = float(); | ||
|
|
||
| for (let j = d.u32(); j < inputSize; j++) { | ||
| for (let j = d.u32(0); j < inputSize; j++) { | ||
| sum = std.fma(ioLayout.$.input[j], weightsBiasesLayout.$.weights[weightsOffset + j], sum); | ||
| } | ||
|
|
||
| const total = sum + weightsBiasesLayout.$.biases[i]; | ||
| ioLayout.$.output[i] = relu(total); | ||
| }); | ||
|
|
||
| const workgroupSize = tgpu.const(d.u32, 128); | ||
| const subgroupCompute = tgpu.computeFn({ | ||
| in: { | ||
| lid: d.builtin.localInvocationId, | ||
| wid: d.builtin.workgroupId, | ||
| sid: d.builtin.subgroupInvocationId, | ||
| ssize: d.builtin.subgroupSize, | ||
| sgid: d.builtin.subgroupId, | ||
| nsg: d.builtin.numSubgroups, | ||
| }, | ||
| workgroupSize: [128], | ||
| })(({ lid, wid, sid, ssize }) => { | ||
| const subgroupId = d.u32(lid.x / ssize); | ||
| const outputsPerWG = d.u32(workgroupSize.$ / ssize); | ||
| const neuronIndex = wid.x * outputsPerWG + subgroupId; | ||
|
|
||
| workgroupSize: [WORKGROUP_SIZE], | ||
| })(({ wid, sid, sgid, nsg }) => { | ||
| const outLen = ioLayout.$.output.length; | ||
| const inputSize = ioLayout.$.input.length; | ||
|
|
||
| const neuronIndex = wid.x * nsg + sgid; | ||
| const valid = neuronIndex < outLen; | ||
|
iwoplaza marked this conversation as resolved.
|
||
|
|
||
| const inputSize = ioLayout.$.input.length; | ||
| // Actual number of active lanes in this subgroup. | ||
| const laneCount = std.subgroupAdd(1); | ||
|
|
||
| let partial = d.f32(); | ||
| let partial = float(0); | ||
|
|
||
| if (valid) { | ||
| const weightsOffset = neuronIndex * inputSize; | ||
| for (let j = sid; j < inputSize; j += ssize) { | ||
|
|
||
| for (let j = sid; j < inputSize; j += laneCount) { | ||
| partial = std.fma( | ||
|
reczkok marked this conversation as resolved.
|
||
| ioLayout.$.input[j], | ||
| weightsBiasesLayout.$.weights[weightsOffset + j], | ||
|
|
@@ -74,7 +94,7 @@ const subgroupCompute = tgpu.computeFn({ | |
|
|
||
| const sum = std.subgroupAdd(partial); | ||
|
|
||
| if (valid && sid === 0) { | ||
| if (valid && std.subgroupElect()) { | ||
| ioLayout.$.output[neuronIndex] = relu(sum + weightsBiasesLayout.$.biases[neuronIndex]); | ||
| } | ||
| }); | ||
|
|
@@ -107,11 +127,11 @@ function createNetwork(layers: [LayerData, LayerData][]): Network { | |
| return { | ||
| weights: weights.buffer, | ||
| biases: biases.buffer, | ||
| state: root.createBuffer(d.arrayOf(d.f32, biases.shape[0])).$usage('storage'), | ||
| state: root.createBuffer(d.arrayOf(float, biases.shape[0])).$usage('storage'), | ||
| }; | ||
| }); | ||
|
|
||
| const input = root.createBuffer(d.arrayOf(d.f32, layers[0][0].shape[0])).$usage('storage'); | ||
| const input = root.createBuffer(d.arrayOf(float, layers[0][0].shape[0])).$usage('storage'); | ||
| const output = buffers[buffers.length - 1].state; | ||
|
|
||
| const ioBindGroups = buffers.map((_, i) => | ||
|
|
@@ -137,7 +157,8 @@ function createNetwork(layers: [LayerData, LayerData][]): Network { | |
| } | ||
| input.write(data); | ||
|
|
||
| const pipeline = useSubgroups && pipelines.subgroup ? pipelines.subgroup : pipelines.default; | ||
| const subgroupPipeline = useSubgroups ? pipelines.subgroup : null; | ||
| const pipeline = subgroupPipeline ?? pipelines.default; | ||
|
|
||
| // Run the network | ||
| for (let i = 0; i < buffers.length; i++) { | ||
|
|
@@ -155,7 +176,10 @@ function createNetwork(layers: [LayerData, LayerData][]): Network { | |
| boundPipeline = boundPipeline.withTimestampWrites(descriptor); | ||
| } | ||
|
|
||
| boundPipeline.dispatchWorkgroups(buffers[i].biases.dataType.elementCount); | ||
| const outputCount = buffers[i].biases.dataType.elementCount; | ||
| boundPipeline.dispatchWorkgroups( | ||
| subgroupPipeline ? outputCount : Math.ceil(outputCount / WORKGROUP_SIZE), | ||
| ); | ||
|
Comment on lines
+179
to
+182
|
||
| } | ||
|
|
||
| if (querySet?.available) { | ||
|
|
@@ -180,7 +204,7 @@ function createNetwork(layers: [LayerData, LayerData][]): Network { | |
| }; | ||
| } | ||
|
|
||
| const network = createNetwork(await downloadLayers(root)); | ||
| const network = createNetwork(await downloadLayers(root, float)); | ||
|
|
||
| // #region Example controls and cleanup | ||
|
|
||
|
|
@@ -386,7 +410,7 @@ export const controls = defineControls({ | |
| 'Test Resolution': import.meta.env.DEV && { | ||
| onButtonClick: () => | ||
| [defaultCompute, subgroupCompute] | ||
| .map((fn) => tgpu.resolve([fn], { enableExtensions: ['subgroups'] })) | ||
| .map((fn) => tgpu.resolve([fn], { enableExtensions: ['subgroups', 'f16'] })) | ||
|
reczkok marked this conversation as resolved.
|
||
| .map((r) => root.device.createShaderModule({ code: r })), | ||
|
reczkok marked this conversation as resolved.
|
||
| }, | ||
| }); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parameter name
`floatShcema` appears to be a typo; rename to `floatSchema` to avoid propagating a misspelled identifier through the API.