Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 7 additions & 27 deletions apps/typegpu-docs/src/examples/algorithms/mnist-inference/data.ts
Original file line number Diff line number Diff line change
@@ -1,40 +1,20 @@
import tgpu, { d, type StorageFlag, type TgpuBuffer } from 'typegpu';

export const ReadonlyFloats = {
storage: d.arrayOf(d.f32),
access: 'readonly',
} as const;

export const MutableFloats = {
storage: d.arrayOf(d.f32),
access: 'mutable',
} as const;

export const ioLayout = tgpu.bindGroupLayout({
input: ReadonlyFloats,
output: MutableFloats,
});

export const weightsBiasesLayout = tgpu.bindGroupLayout({
weights: ReadonlyFloats,
biases: ReadonlyFloats,
});
import { d, type StorageFlag, type TgpuBuffer } from 'typegpu';

export interface LayerData {
shape: readonly [number] | readonly [number, number];
buffer: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
buffer: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
}

export interface Layer {
weights: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
biases: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
state: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
weights: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
biases: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
state: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
}

export interface Network {
layers: Layer[];
input: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
output: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
input: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
output: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;

inference(data: number[]): Promise<number[]>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ function getLayerData(layer: ArrayBuffer): {
};
}

export function downloadLayers(root: TgpuRoot): Promise<[LayerData, LayerData][]> {
export function downloadLayers(
root: TgpuRoot,
floatShcema: d.F32 | d.F16,
): Promise<[LayerData, LayerData][]> {
Comment on lines +40 to +43
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parameter name `floatShcema` appears to be a typo; rename it to `floatSchema` to avoid propagating a misspelled identifier through the API.

Copilot uses AI. Check for mistakes.
Comment on lines +40 to +43
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in the parameter name `floatShcema` (should be `floatSchema`). Keeping the misspelling makes the API harder to read and search, and increases the chance of propagating the typo to call sites.

Copilot uses AI. Check for mistakes.
const downloadLayer = async (fileName: string): Promise<LayerData> => {
const buffer = await fetch(`/TypeGPU/assets/mnist-weights/${fileName}`).then((res) =>
res.arrayBuffer(),
Expand All @@ -46,7 +49,7 @@ export function downloadLayers(root: TgpuRoot): Promise<[LayerData, LayerData][]
const { shape, data } = getLayerData(buffer);

const layerBuffer = root
.createBuffer(d.arrayOf(d.f32, data.length), [...data])
.createBuffer(d.arrayOf(floatShcema, data.length), [...data])
.$usage('storage');

return {
Expand Down
88 changes: 56 additions & 32 deletions apps/typegpu-docs/src/examples/algorithms/mnist-inference/index.ts
Original file line number Diff line number Diff line change
@@ -1,69 +1,89 @@
import tgpu, { d, std } from 'typegpu';
import { ioLayout, type LayerData, type Network, weightsBiasesLayout } from './data.ts';
import type { LayerData, Network } from './data.ts';
import { downloadLayers } from './helpers.ts';
import { defineControls } from '../../common/defineControls.ts';

const SIZE = 28;
const WORKGROUP_SIZE = 64;

const root = await tgpu.init({
device: {
optionalFeatures: ['timestamp-query', 'subgroups'],
},
device: { optionalFeatures: ['timestamp-query', 'subgroups', 'shader-f16'] },
});
const hasTimestampQuery = root.enabledFeatures.has('timestamp-query');
const hasSubgroups = root.enabledFeatures.has('subgroups');
const hasShaderF16 = root.enabledFeatures.has('shader-f16');
let useSubgroups = hasSubgroups;

const float = hasShaderF16 ? d.f16 : d.f32;

const ioLayout = tgpu.bindGroupLayout({
input: { storage: d.arrayOf(float) },
output: {
storage: d.arrayOf(float),
access: 'mutable',
},
});

const weightsBiasesLayout = tgpu.bindGroupLayout({
weights: { storage: d.arrayOf(float) },
biases: { storage: d.arrayOf(float) },
});

const canvasData = Array.from({ length: SIZE ** 2 }, () => 0);

// Shaders

const relu = tgpu.fn([d.f32], d.f32)((x) => std.max(0, x));
function relu(x: number): number {
'use gpu';
return std.max(0, x);
}

const defaultCompute = tgpu.computeFn({
in: {
gid: d.builtin.globalInvocationId,
},
workgroupSize: [1],
in: { gid: d.builtin.globalInvocationId },
workgroupSize: [WORKGROUP_SIZE],
})(({ gid }) => {
const inputSize = ioLayout.$.input.length;

const i = gid.x;
const outLen = ioLayout.$.output.length;
if (i >= outLen) {
return;
}

const inputSize = ioLayout.$.input.length;
const weightsOffset = i * inputSize;
let sum = d.f32();
let sum = float();

for (let j = d.u32(); j < inputSize; j++) {
for (let j = d.u32(0); j < inputSize; j++) {
sum = std.fma(ioLayout.$.input[j], weightsBiasesLayout.$.weights[weightsOffset + j], sum);
}

const total = sum + weightsBiasesLayout.$.biases[i];
ioLayout.$.output[i] = relu(total);
});

const workgroupSize = tgpu.const(d.u32, 128);
const subgroupCompute = tgpu.computeFn({
in: {
lid: d.builtin.localInvocationId,
wid: d.builtin.workgroupId,
sid: d.builtin.subgroupInvocationId,
ssize: d.builtin.subgroupSize,
sgid: d.builtin.subgroupId,
nsg: d.builtin.numSubgroups,
},
workgroupSize: [128],
})(({ lid, wid, sid, ssize }) => {
const subgroupId = d.u32(lid.x / ssize);
const outputsPerWG = d.u32(workgroupSize.$ / ssize);
const neuronIndex = wid.x * outputsPerWG + subgroupId;

workgroupSize: [WORKGROUP_SIZE],
})(({ wid, sid, sgid, nsg }) => {
const outLen = ioLayout.$.output.length;
const inputSize = ioLayout.$.input.length;

const neuronIndex = wid.x * nsg + sgid;
const valid = neuronIndex < outLen;
Comment thread
iwoplaza marked this conversation as resolved.

const inputSize = ioLayout.$.input.length;
// Actual number of active lanes in this subgroup.
const laneCount = std.subgroupAdd(1);

let partial = d.f32();
let partial = float(0);

if (valid) {
const weightsOffset = neuronIndex * inputSize;
for (let j = sid; j < inputSize; j += ssize) {

for (let j = sid; j < inputSize; j += laneCount) {
partial = std.fma(
Comment thread
reczkok marked this conversation as resolved.
ioLayout.$.input[j],
weightsBiasesLayout.$.weights[weightsOffset + j],
Expand All @@ -74,7 +94,7 @@ const subgroupCompute = tgpu.computeFn({

const sum = std.subgroupAdd(partial);

if (valid && sid === 0) {
if (valid && std.subgroupElect()) {
ioLayout.$.output[neuronIndex] = relu(sum + weightsBiasesLayout.$.biases[neuronIndex]);
}
});
Expand Down Expand Up @@ -107,11 +127,11 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
return {
weights: weights.buffer,
biases: biases.buffer,
state: root.createBuffer(d.arrayOf(d.f32, biases.shape[0])).$usage('storage'),
state: root.createBuffer(d.arrayOf(float, biases.shape[0])).$usage('storage'),
};
});

const input = root.createBuffer(d.arrayOf(d.f32, layers[0][0].shape[0])).$usage('storage');
const input = root.createBuffer(d.arrayOf(float, layers[0][0].shape[0])).$usage('storage');
const output = buffers[buffers.length - 1].state;

const ioBindGroups = buffers.map((_, i) =>
Expand All @@ -137,7 +157,8 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
}
input.write(data);

const pipeline = useSubgroups && pipelines.subgroup ? pipelines.subgroup : pipelines.default;
const subgroupPipeline = useSubgroups ? pipelines.subgroup : null;
const pipeline = subgroupPipeline ?? pipelines.default;

// Run the network
for (let i = 0; i < buffers.length; i++) {
Expand All @@ -155,7 +176,10 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
boundPipeline = boundPipeline.withTimestampWrites(descriptor);
}

boundPipeline.dispatchWorkgroups(buffers[i].biases.dataType.elementCount);
const outputCount = buffers[i].biases.dataType.elementCount;
boundPipeline.dispatchWorkgroups(
subgroupPipeline ? outputCount : Math.ceil(outputCount / WORKGROUP_SIZE),
);
Comment on lines +179 to +182
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`dispatchWorkgroups` is called with `outputCount` when the subgroup pipeline is selected, but `subgroupCompute` computes `num_subgroups` outputs per workgroup (`neuronIndex = wid.x * nsg + sgid`). This over-dispatches workgroups by a factor of `nsg` (e.g., 2x for 64 threads with 32-wide subgroups), doing unnecessary work for larger layers. Consider either dispatching `ceil(outputCount / outputsPerWorkgroup)` (if you can determine `outputsPerWorkgroup`) or adjusting the shader/work mapping so that each workgroup corresponds to exactly one output when the dispatch count must be `outputCount`.

Copilot uses AI. Check for mistakes.
}

if (querySet?.available) {
Expand All @@ -180,7 +204,7 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
};
}

const network = createNetwork(await downloadLayers(root));
const network = createNetwork(await downloadLayers(root, float));

// #region Example controls and cleanup

Expand Down Expand Up @@ -386,7 +410,7 @@ export const controls = defineControls({
'Test Resolution': import.meta.env.DEV && {
onButtonClick: () =>
[defaultCompute, subgroupCompute]
.map((fn) => tgpu.resolve([fn], { enableExtensions: ['subgroups'] }))
.map((fn) => tgpu.resolve([fn], { enableExtensions: ['subgroups', 'f16'] }))
Comment thread
reczkok marked this conversation as resolved.
.map((r) => root.device.createShaderModule({ code: r })),
Comment thread
reczkok marked this conversation as resolved.
},
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,27 @@ describe('mnist inference example', () => {

expect(shaderCodes).toMatchInlineSnapshot(`
"enable subgroups;
enable f16;

@group(0) @binding(1) var<storage, read_write> output: array<f32>;

@group(0) @binding(0) var<storage, read> input: array<f32>;

@group(1) @binding(0) var<storage, read> weights: array<f32>;

@group(1) @binding(1) var<storage, read> biases: array<f32>;

@group(0) @binding(1) var<storage, read_write> output: array<f32>;

fn relu(x: f32) -> f32 {
return max(0f, x);
}

@compute @workgroup_size(1) fn defaultCompute(@builtin(global_invocation_id) gid: vec3u) {
let inputSize = arrayLength(&input);
@compute @workgroup_size(64) fn defaultCompute(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
let outLen = arrayLength(&output);
if ((i >= outLen)) {
return;
}
let inputSize = arrayLength(&input);
let weightsOffset = (i * inputSize);
var sum = 0f;
for (var j = 0u; (j < inputSize); j++) {
Expand All @@ -50,8 +55,7 @@ describe('mnist inference example', () => {
}

enable subgroups;

const workgroupSize: u32 = 128u;
enable f16;

@group(0) @binding(1) var<storage, read_write> output: array<f32>;

Expand All @@ -65,22 +69,21 @@ describe('mnist inference example', () => {
return max(0f, x);
}

@compute @workgroup_size(128) fn subgroupCompute(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u, @builtin(subgroup_invocation_id) sid: u32, @builtin(subgroup_size) ssize: u32) {
let subgroupId = u32((f32(lid.x) / f32(ssize)));
let outputsPerWG = u32((f32(workgroupSize) / f32(ssize)));
let neuronIndex = ((wid.x * outputsPerWG) + subgroupId);
@compute @workgroup_size(64) fn subgroupCompute(@builtin(workgroup_id) wid: vec3u, @builtin(subgroup_invocation_id) sid: u32, @builtin(subgroup_id) sgid: u32, @builtin(num_subgroups) nsg: u32) {
let outLen = arrayLength(&output);
let valid = (neuronIndex < outLen);
let inputSize = arrayLength(&input);
let neuronIndex = ((wid.x * nsg) + sgid);
let valid = (neuronIndex < outLen);
let laneCount = subgroupAdd(1);
var partial = 0f;
if (valid) {
let weightsOffset = (neuronIndex * inputSize);
for (var j = sid; (j < inputSize); j += ssize) {
for (var j = sid; (j < inputSize); j += u32(laneCount)) {
partial = fma(input[j], weights[(weightsOffset + j)], partial);
}
}
let sum = subgroupAdd(partial);
if ((valid && (sid == 0u))) {
if ((valid && subgroupElect())) {
output[neuronIndex] = relu((sum + biases[neuronIndex]));
}
}"
Expand Down
Loading