Skip to content

Commit 706187e

Browse files
authored
impr: Refine std function and conversion handling (#2126)
1 parent 3ad2398 commit 706187e

28 files changed

Lines changed: 692 additions & 232 deletions
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import { type MapValueToSnippet, snip } from '../../data/snippet.ts';
2+
import { type BaseData, isPtr } from '../../data/wgslTypes.ts';
3+
import { setName } from '../../shared/meta.ts';
4+
import { $gpuCallable } from '../../shared/symbols.ts';
5+
import { tryConvertSnippet } from '../../tgsl/conversion.ts';
6+
import {
7+
type DualFn,
8+
isKnownAtComptime,
9+
NormalState,
10+
type ResolutionCtx,
11+
} from '../../types.ts';
12+
import type { AnyFn } from './fnTypes.ts';
13+
14+
type MapValueToDataType<T> = { [K in keyof T]: BaseData };
15+
16+
interface CallableSchemaOptions<T extends AnyFn> {
17+
readonly name: string;
18+
readonly normalImpl: T;
19+
readonly codegenImpl: (
20+
ctx: ResolutionCtx,
21+
args: MapValueToSnippet<Parameters<T>>,
22+
) => string;
23+
readonly signature: (
24+
...inArgTypes: MapValueToDataType<Parameters<T>>
25+
) => { argTypes: (BaseData | BaseData[])[]; returnType: BaseData };
26+
}
27+
28+
export function callableSchema<T extends AnyFn>(
29+
options: CallableSchemaOptions<T>,
30+
): DualFn<T> {
31+
const impl = ((...args: Parameters<T>) => {
32+
return options.normalImpl(...args);
33+
}) as DualFn<T>;
34+
35+
setName(impl, options.name);
36+
impl.toString = () => options.name;
37+
impl[$gpuCallable] = {
38+
get strictSignature() {
39+
return undefined;
40+
},
41+
call(ctx, args) {
42+
const { argTypes, returnType } = options.signature(
43+
...args.map((s) => {
44+
// Dereference implicit pointers
45+
if (isPtr(s.dataType) && s.dataType.implicit) {
46+
return s.dataType.inner;
47+
}
48+
return s.dataType;
49+
}) as MapValueToDataType<Parameters<T>>,
50+
);
51+
52+
const converted = args.map((s, idx) => {
53+
const argType = argTypes[idx];
54+
if (!argType) {
55+
throw new Error('Function called with invalid arguments');
56+
}
57+
return tryConvertSnippet(
58+
ctx,
59+
s,
60+
argType,
61+
false,
62+
);
63+
}) as MapValueToSnippet<Parameters<T>>;
64+
65+
if (converted.every((s) => isKnownAtComptime(s))) {
66+
ctx.pushMode(new NormalState());
67+
try {
68+
return snip(
69+
options.normalImpl(...converted.map((s) => s.value) as never[]),
70+
returnType,
71+
// Functions give up ownership of their return value
72+
/* origin */ 'constant',
73+
);
74+
} finally {
75+
ctx.popMode('normal');
76+
}
77+
}
78+
79+
return snip(
80+
options.codegenImpl(ctx, converted),
81+
returnType,
82+
// Functions give up ownership of their return value
83+
/* origin */ 'runtime',
84+
);
85+
},
86+
};
87+
88+
return impl;
89+
}

packages/typegpu/src/core/function/dualImpl.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { type MapValueToSnippet, snip } from '../../data/snippet.ts';
22
import { setName } from '../../shared/meta.ts';
33
import { $gpuCallable } from '../../shared/symbols.ts';
44
import { tryConvertSnippet } from '../../tgsl/conversion.ts';
5+
import { concretize } from '../../tgsl/generationHelpers.ts';
56
import {
67
type DualFn,
78
isKnownAtComptime,
@@ -21,10 +22,13 @@ interface DualImplOptions<T extends AnyFn> {
2122
args: MapValueToSnippet<Parameters<T>>,
2223
) => string;
2324
readonly signature:
24-
| { argTypes: BaseData[]; returnType: BaseData }
25+
| {
26+
argTypes: (BaseData | BaseData[])[];
27+
returnType: BaseData;
28+
}
2529
| ((
2630
...inArgTypes: MapValueToDataType<Parameters<T>>
27-
) => { argTypes: BaseData[]; returnType: BaseData });
31+
) => { argTypes: (BaseData | BaseData[])[]; returnType: BaseData });
2832
/**
2933
* Whether the function should skip trying to execute the "normal" implementation if
3034
* all arguments are known at compile time.
@@ -112,7 +116,7 @@ export function dualImpl<T extends AnyFn>(
112116

113117
return snip(
114118
options.codegenImpl(ctx, converted),
115-
returnType,
119+
concretize(returnType),
116120
// Functions give up ownership of their return value
117121
/* origin */ 'runtime',
118122
);

packages/typegpu/src/data/matrix.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { comptime } from '../core/function/comptime.ts';
2+
import { callableSchema } from '../core/function/createCallableSchema.ts';
23
import { dualImpl } from '../core/function/dualImpl.ts';
34
import { stitch } from '../core/resolve/stitch.ts';
45
import { $repr } from '../shared/symbols.ts';
@@ -64,7 +65,7 @@ function createMatSchema<
6465
>(
6566
options: MatSchemaOptions<TType, ColumnType>,
6667
): { type: TType; [$repr]: ValueType } & MatConstructor<ValueType, ColumnType> {
67-
const construct = dualImpl({
68+
const construct = callableSchema({
6869
name: options.type,
6970
normalImpl: (...args: (number | ColumnType)[]): ValueType => {
7071
const elements: number[] = [];
@@ -94,7 +95,6 @@ function createMatSchema<
9495

9596
return new options.MatImpl(...elements) as ValueType;
9697
},
97-
ignoreImplicitCastWarning: true,
9898
signature: (...args) => ({
9999
argTypes: args.map((arg) => (isVec(arg) ? arg : f32)),
100100
returnType: schema as unknown as BaseData,

packages/typegpu/src/data/numeric.ts

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import { stitch } from '../core/resolve/stitch.ts';
2-
import { dualImpl } from '../core/function/dualImpl.ts';
32
import { $internal } from '../shared/symbols.ts';
43
import type {
54
AbstractFloat,
@@ -11,6 +10,7 @@ import type {
1110
U16,
1211
U32,
1312
} from './wgslTypes.ts';
13+
import { callableSchema } from '../core/function/createCallableSchema.ts';
1414

1515
export const abstractInt = {
1616
[$internal]: {},
@@ -28,7 +28,7 @@ export const abstractFloat = {
2828
},
2929
} as AbstractFloat;
3030

31-
const boolCast = dualImpl({
31+
const boolCast = callableSchema({
3232
name: 'bool',
3333
signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: bool }),
3434
normalImpl(v?: number | boolean) {
@@ -66,7 +66,7 @@ export const bool: Bool = Object.assign(boolCast, {
6666
type: 'bool',
6767
}) as unknown as Bool;
6868

69-
const u32Cast = dualImpl({
69+
const u32Cast = callableSchema({
7070
name: 'u32',
7171
signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: u32 }),
7272
normalImpl(v?: number | boolean) {
@@ -76,6 +76,17 @@ const u32Cast = dualImpl({
7676
if (typeof v === 'boolean') {
7777
return v ? 1 : 0;
7878
}
79+
if (!Number.isInteger(v)) {
80+
const truncated = Math.trunc(v);
81+
if (truncated < 0) {
82+
return 0;
83+
}
84+
if (truncated > 0xffffffff) {
85+
return 0xffffffff;
86+
}
87+
return truncated;
88+
}
89+
// Integer input: treat as bit reinterpretation (i32 -> u32)
7990
return (v & 0xffffffff) >>> 0;
8091
},
8192
codegenImpl: (_ctx, [arg]) =>
@@ -106,7 +117,7 @@ export const u32: U32 = Object.assign(u32Cast, {
106117
type: 'u32',
107118
}) as unknown as U32;
108119

109-
const i32Cast = dualImpl({
120+
const i32Cast = callableSchema({
110121
name: 'i32',
111122
signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: i32 }),
112123
normalImpl(v?: number | boolean) {
@@ -149,7 +160,7 @@ export const i32: I32 = Object.assign(i32Cast, {
149160
type: 'i32',
150161
}) as unknown as I32;
151162

152-
const f32Cast = dualImpl({
163+
const f32Cast = callableSchema({
153164
name: 'f32',
154165
signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: f32 }),
155166
normalImpl(v?: number | boolean) {
@@ -275,7 +286,7 @@ function roundToF16(x: number): number {
275286
return fromHalfBits(toHalfBits(x));
276287
}
277288

278-
const f16Cast = dualImpl({
289+
const f16Cast = callableSchema({
279290
name: 'f16',
280291
signature: (arg) => ({ argTypes: arg ? [arg] : [], returnType: f16 }),
281292
normalImpl(v?: number | boolean) {

packages/typegpu/src/data/vector.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { dualImpl } from '../core/function/dualImpl.ts';
1+
import { callableSchema } from '../core/function/createCallableSchema.ts';
22
import { stitch } from '../core/resolve/stitch.ts';
33
import { $internal, $repr } from '../shared/symbols.ts';
44
import { bool, f16, f32, i32, u32 } from './numeric.ts';
@@ -307,14 +307,13 @@ function makeVecSchema<TValue, S extends number | boolean>(
307307
);
308308
};
309309

310-
const construct = dualImpl({
310+
const construct = callableSchema({
311311
name: type,
312312
signature: (...args) => ({
313313
argTypes: args.map((arg) => isVec(arg) ? arg : primitive),
314314
returnType: schema as unknown as BaseData,
315315
}),
316316
normalImpl: cpuConstruct,
317-
ignoreImplicitCastWarning: true,
318317
codegenImpl: (_ctx, args) => {
319318
if (
320319
args.length === 1 && args[0]?.dataType === schema as unknown as BaseData

packages/typegpu/src/errors.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,20 @@ export class WgslTypeError extends Error {
210210
Object.setPrototypeOf(this, WgslTypeError.prototype);
211211
}
212212
}
213+
214+
export class SignatureNotSupportedError extends Error {
215+
constructor(actual: BaseData[], candidates: BaseData[]) {
216+
super(
217+
`Unsupported data types: ${
218+
actual.map((a) => a.type).join(', ')
219+
}. Supported types are: ${
220+
candidates
221+
.map((r) => r.type)
222+
.join(', ')
223+
}.`,
224+
);
225+
226+
// Set the prototype explicitly.
227+
Object.setPrototypeOf(this, SignatureNotSupportedError.prototype);
228+
}
229+
}

packages/typegpu/src/std/array.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ export const arrayLength = dualImpl({
2323
isRef(a) ? a.$.length : a.length,
2424
codegenImpl(_ctx, [a]) {
2525
const length = sizeOfPointedToArray(a.dataType);
26-
return length > 0 ? String(length) : stitch`arrayLength(${a})`;
26+
return length > 0 ? `${length}u` : stitch`arrayLength(${a})`;
2727
},
2828
});

packages/typegpu/src/std/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ export {
150150

151151
export {
152152
textureDimensions,
153+
textureGather,
153154
textureLoad,
154155
textureSample,
155156
textureSampleBaseClampToEdge,

0 commit comments

Comments
 (0)