Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions packages/typegpu/src/core/root/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,30 @@ export class TgpuGuardedComputePipelineImpl<
);
}

withPerformanceCallback(
callback: (start: bigint, end: bigint) => void | Promise<void>,
): TgpuGuardedComputePipeline<TArgs> {
return new TgpuGuardedComputePipelineImpl(
this.#root,
this.#pipeline.withPerformanceCallback(callback),
this.#sizeUniform,
this.#workgroupSize,
);
}

withTimestampWrites(options: {
querySet: TgpuQuerySet<'timestamp'> | GPUQuerySet;
beginningOfPassWriteIndex?: number;
endOfPassWriteIndex?: number;
}): TgpuGuardedComputePipeline<TArgs> {
return new TgpuGuardedComputePipelineImpl(
this.#root,
this.#pipeline.withTimestampWrites(options),
this.#sizeUniform,
this.#workgroupSize,
);
}

dispatchThreads(...threads: TArgs): void {
const sanitizedSize = toVec3(threads);
const workgroupCount = ceil(vec3f(sanitizedSize).div(vec3f(this.#workgroupSize)));
Expand Down
40 changes: 29 additions & 11 deletions packages/typegpu/src/core/root/rootTypes.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { AnyComputeBuiltin, AnyFragmentInputBuiltin, OmitBuiltins } from '../../builtin.ts';
import type { TgpuQuerySet } from '../../core/querySet/querySet.ts';
import type { AnyData, Disarray, UndecorateRecord } from '../../data/dataTypes.ts';
import type { InstanceToSchema } from '../../data/instanceToSchema.ts';
import type { WgslComparisonSamplerProps, WgslSamplerProps } from '../../data/sampler.ts';
import type {
AnyWgslData,
Expand All @@ -12,6 +13,7 @@ import type {
Void,
WgslArray,
} from '../../data/wgslTypes.ts';
import type { TgpuNamable } from '../../shared/meta.ts';
import type {
ExtractInvalidSchemaError,
InferGPURecord,
Expand All @@ -33,7 +35,13 @@ import type { ShaderGenerator } from '../../tgsl/shaderGenerator.ts';
import type { Unwrapper } from '../../unwrapper.ts';
import type { TgpuBuffer, VertexFlag } from '../buffer/buffer.ts';
import type { TgpuMutable, TgpuReadonly, TgpuUniform } from '../buffer/bufferShorthand.ts';
import type { TgpuFixedComparisonSampler, TgpuFixedSampler } from '../sampler/sampler.ts';
import type {
AnyAutoCustoms,
AutoFragmentIn,
AutoFragmentOut,
AutoVertexIn,
AutoVertexOut,
} from '../function/autoIO.ts';
import type { IORecord } from '../function/fnTypes.ts';
import type {
FragmentInConstrained,
Expand All @@ -44,6 +52,7 @@ import type {
import type { TgpuVertexFn } from '../function/tgpuVertexFn.ts';
import type { TgpuComputePipeline } from '../pipeline/computePipeline.ts';
import type { FragmentOutToTargets, TgpuRenderPipeline } from '../pipeline/renderPipeline.ts';
import type { TgpuFixedComparisonSampler, TgpuFixedSampler } from '../sampler/sampler.ts';
import type { Eventual, TgpuAccessor, TgpuMutableAccessor, TgpuSlot } from '../slot/slotTypes.ts';
import type { TgpuTexture } from '../texture/texture.ts';
import type {
Expand All @@ -52,15 +61,6 @@ import type {
} from '../vertexLayout/vertexAttribute.ts';
import type { TgpuVertexLayout } from '../vertexLayout/vertexLayout.ts';
import type { TgpuComputeFn } from './../function/tgpuComputeFn.ts';
import type { TgpuNamable } from '../../shared/meta.ts';
import type {
AnyAutoCustoms,
AutoFragmentIn,
AutoFragmentOut,
AutoVertexIn,
AutoVertexOut,
} from '../function/autoIO.ts';
import type { InstanceToSchema } from '../../data/instanceToSchema.ts';

// ----------
// Public API
Expand All @@ -80,6 +80,24 @@ export interface TgpuGuardedComputePipeline<TArgs extends number[] = number[]> e
*/
with(encoder: GPUCommandEncoder): TgpuGuardedComputePipeline<TArgs>;

/**
* Returns a pipeline wrapper with the given performance callback attached.
* Analogous to `TgpuComputePipeline.withPerformanceCallback(callback)`.
*/
withPerformanceCallback(
callback: (start: bigint, end: bigint) => void | Promise<void>,
): TgpuGuardedComputePipeline<TArgs>;

/**
* Returns a pipeline wrapper with the given timestamp writes configuration.
* Analogous to `TgpuComputePipeline.withTimestampWrites(options)`.
*/
withTimestampWrites(options: {
querySet: TgpuQuerySet<'timestamp'> | GPUQuerySet;
beginningOfPassWriteIndex?: number;
endOfPassWriteIndex?: number;
}): TgpuGuardedComputePipeline<TArgs>;

/**
* Dispatches the pipeline.
* Unlike `TgpuComputePipeline.dispatchWorkgroups()`, this method takes in the
Expand Down Expand Up @@ -378,7 +396,7 @@ export interface WithBinding extends Withable<WithBinding> {

/**
* Creates a compute pipeline that executes the given callback in an exact number of threads.
* This is different from `withCompute(...).createPipeline()` in that it does a bounds check on the
* This is different from `createComputePipeline()` in that it does a bounds check on the
* thread id, where as regular pipelines do not and work in units of workgroups.
Comment thread
cieplypolar marked this conversation as resolved.
*
* @param callback A function converted to WGSL and executed on the GPU.
Expand Down
32 changes: 31 additions & 1 deletion packages/typegpu/tests/guardedComputePipeline.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { describe, expect } from 'vitest';
import { describe, expect, vi } from 'vitest';
import { it } from 'typegpu-testing-utility';
import { getName } from '../src/shared/meta.ts';
import { bindGroupLayout } from '../src/tgpuBindGroupLayout.ts';
Expand Down Expand Up @@ -31,4 +31,34 @@ describe('TgpuGuardedComputePipeline', () => {
expect(getName(pipeline)).toBe('myPipeline');
expect(getName(pipeline.pipeline)).toBe('myPipeline');
});

it('delegates `withPerformanceCallback` to the underlying pipeline', ({ root }) => {
const callback = vi.fn();
const guarded = root.createGuardedComputePipeline(() => {
'use gpu';
});

const spy = vi.spyOn(guarded.pipeline, 'withPerformanceCallback');
guarded.withPerformanceCallback(callback);

expect(spy).toHaveBeenCalledWith(callback);
});

it('delegates `withTimestampWrites` to the underlying pipeline', ({ root }) => {
const querySet = root.createQuerySet('timestamp', 2);
const guarded = root.createGuardedComputePipeline(() => {
'use gpu';
});

const options = {
querySet,
beginningOfPassWriteIndex: 0,
endOfPassWriteIndex: 1,
};

const spy = vi.spyOn(guarded.pipeline, 'withTimestampWrites');
guarded.withTimestampWrites(options);

expect(spy).toHaveBeenCalledWith(options);
});
});
Loading