Skip to content

Commit fb1c36d

Browse files
authored
impr: intermediate representation of array expression (#2021)
1 parent 90aa897 commit fb1c36d

6 files changed

Lines changed: 318 additions & 37 deletions

File tree

packages/typegpu/src/tgsl/generationHelpers.ts

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,34 @@
11
import { UnknownData } from '../data/dataTypes.ts';
22
import { abstractFloat, abstractInt, bool, f32, i32 } from '../data/numeric.ts';
33
import { isRef } from '../data/ref.ts';
4-
import { isSnippet, snip, type Snippet } from '../data/snippet.ts';
4+
import {
5+
isEphemeralSnippet,
6+
isSnippet,
7+
type ResolvedSnippet,
8+
snip,
9+
type Snippet,
10+
} from '../data/snippet.ts';
511
import {
612
type AnyWgslData,
713
type BaseData,
814
type F32,
915
type I32,
1016
isMatInstance,
17+
isNaturallyEphemeral,
1118
isVecInstance,
19+
type WgslArray,
1220
WORKAROUND_getSchema,
1321
} from '../data/wgslTypes.ts';
1422
import {
1523
type FunctionScopeLayer,
1624
getOwnSnippet,
1725
type ResolutionCtx,
26+
type SelfResolvable,
1827
} from '../types.ts';
1928
import type { ShelllessRepository } from './shellless.ts';
29+
import { stitch } from '../../src/core/resolve/stitch.ts';
30+
import { WgslTypeError } from '../../src/errors.ts';
31+
import { $internal, $resolve } from '../../src/shared/symbols.ts';
2032

2133
export function numericLiteralToSnippet(value: number): Snippet {
2234
if (value >= 2 ** 63 || value < -(2 ** 63)) {
@@ -130,3 +142,50 @@ export function coerceToSnippet(value: unknown): Snippet {
130142

131143
return snip(value, UnknownData, /* origin */ 'constant');
132144
}
145+
146+
/**
147+
* Intermediate representation for WGSL array expressions.
148+
* Defers resolution. Stores array elements as snippets so the
149+
* generator can access them when needed.
150+
*/
151+
export class ArrayExpression implements SelfResolvable {
152+
readonly [$internal] = true;
153+
154+
constructor(
155+
public readonly type: WgslArray<AnyWgslData>,
156+
public readonly elements: Snippet[],
157+
) {
158+
}
159+
160+
toString(): string {
161+
return 'ArrayExpression';
162+
}
163+
164+
[$resolve](ctx: ResolutionCtx): ResolvedSnippet {
165+
for (const elem of this.elements) {
166+
// We check if there are no references among the elements
167+
if (
168+
(elem.origin === 'argument' &&
169+
!isNaturallyEphemeral(elem.dataType)) ||
170+
!isEphemeralSnippet(elem)
171+
) {
172+
const snippetStr = ctx.resolve(elem.value, elem.dataType).value;
173+
const snippetType =
174+
ctx.resolve(concretize(elem.dataType as BaseData)).value;
175+
throw new WgslTypeError(
176+
`'${snippetStr}' reference cannot be used in an array constructor.\n-----\nTry '${snippetType}(${snippetStr})' or 'arrayOf(${snippetType}, count)([...])' to copy the value instead.\n-----`,
177+
);
178+
}
179+
}
180+
181+
const arrayType = `array<${
182+
ctx.resolve(this.type.elementType).value
183+
}, ${this.elements.length}>`;
184+
185+
return snip(
186+
stitch`${arrayType}(${this.elements})`,
187+
this.type,
188+
/* origin */ 'runtime',
189+
);
190+
}
191+
}

packages/typegpu/src/tgsl/wgslGenerator.ts

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import {
3030
tryConvertSnippet,
3131
} from './conversion.ts';
3232
import {
33+
ArrayExpression,
3334
concretize,
3435
type GenerationCtx,
3536
numericLiteralToSnippet,
@@ -503,21 +504,21 @@ ${this.ctx.pre}}`;
503504
const [_, calleeNode, argNodes] = expression;
504505
const callee = this.expression(calleeNode);
505506

506-
if (wgsl.isWgslStruct(callee.value) || wgsl.isWgslArray(callee.value)) {
507-
// Struct/array schema call.
507+
if (wgsl.isWgslStruct(callee.value)) {
508+
// Struct schema call.
508509
if (argNodes.length > 1) {
509510
throw new WgslTypeError(
510-
'Array and struct schemas should always be called with at most 1 argument',
511+
'Struct schemas should always be called with at most 1 argument',
511512
);
512513
}
513514

514515
// No arguments `Struct()`, resolve struct name and return.
515516
if (!argNodes[0]) {
516-
// the schema becomes the data type
517+
// The schema becomes the data type.
517518
return snip(
518519
`${this.ctx.resolve(callee.value).value}()`,
519520
callee.value,
520-
// A new struct, so not a reference
521+
// A new struct, so not a reference.
521522
/* origin */ 'runtime',
522523
);
523524
}
@@ -532,7 +533,54 @@ ${this.ctx.pre}}`;
532533
return snip(
533534
this.ctx.resolve(arg.value, callee.value).value,
534535
callee.value,
535-
// A new struct, so not a reference
536+
// A new struct, so not a reference.
537+
/* origin */ 'runtime',
538+
);
539+
}
540+
541+
if (wgsl.isWgslArray(callee.value)) {
542+
// Array schema call.
543+
if (argNodes.length > 1) {
544+
throw new WgslTypeError(
545+
'Array schemas should always be called with at most 1 argument',
546+
);
547+
}
548+
549+
// No arguments `array<...>()`, resolve array type and return.
550+
if (!argNodes[0]) {
551+
// The schema becomes the data type.
552+
return snip(
553+
`${this.ctx.resolve(callee.value).value}()`,
554+
callee.value,
555+
// A new array, so not a reference.
556+
/* origin */ 'runtime',
557+
);
558+
}
559+
560+
const arg = this.typedExpression(
561+
argNodes[0],
562+
callee.value,
563+
);
564+
565+
// `d.arrayOf(...)([...])`.
566+
// We don't resolve the ArrayExpression object itself to
567+
// avoid reference checks (we're copying so it's fine)
568+
if (arg.value instanceof ArrayExpression) {
569+
return snip(
570+
stitch`${
571+
this.ctx.resolve(callee.value).value
572+
}(${arg.value.elements})`,
573+
arg.dataType,
574+
/* origin */ 'runtime',
575+
);
576+
}
577+
578+
// `d.arrayOf(...)(otherArr)`.
579+
// We just let the argument resolve everything.
580+
return snip(
581+
this.ctx.resolve(arg.value, callee.value).value,
582+
callee.value,
583+
// A new array, so not a reference.
536584
/* origin */ 'runtime',
537585
);
538586
}
@@ -729,25 +777,9 @@ ${this.ctx.pre}}`;
729777
}
730778
} else {
731779
// The array is not typed, so we try to guess the types.
732-
const valuesSnippets = valueNodes.map((value) => {
733-
const snippet = this.expression(value as tinyest.Expression);
734-
// We check if there are no references among the elements
735-
if (
736-
(snippet.origin === 'argument' &&
737-
!wgsl.isNaturallyEphemeral(snippet.dataType)) ||
738-
!isEphemeralSnippet(snippet)
739-
) {
740-
const snippetStr =
741-
this.ctx.resolve(snippet.value, snippet.dataType).value;
742-
const snippetType =
743-
this.ctx.resolve(concretize(snippet.dataType as wgsl.BaseData))
744-
.value;
745-
throw new WgslTypeError(
746-
`'${snippetStr}' reference cannot be used in an array constructor.\n-----\nTry '${snippetType}(${snippetStr})' or 'arrayOf(${snippetType}, count)([...])' to copy the value instead.\n-----`,
747-
);
748-
}
749-
return snippet;
750-
});
780+
const valuesSnippets = valueNodes.map((value) =>
781+
this.expression(value as tinyest.Expression)
782+
);
751783

752784
if (valuesSnippets.length === 0) {
753785
throw new WgslTypeError(
@@ -766,13 +798,17 @@ ${this.ctx.pre}}`;
766798
elemType = concretize(values[0]?.dataType as wgsl.AnyWgslData);
767799
}
768800

769-
const arrayType = `array<${
770-
this.ctx.resolve(elemType).value
771-
}, ${values.length}>`;
801+
const arrayType = arrayOf(
802+
elemType as wgsl.AnyWgslData,
803+
values.length,
804+
);
772805

773806
return snip(
774-
stitch`${arrayType}(${values})`,
775-
arrayOf(elemType as wgsl.AnyWgslData, values.length),
807+
new ArrayExpression(
808+
arrayType,
809+
values,
810+
),
811+
arrayType,
776812
/* origin */ 'runtime',
777813
);
778814
}

packages/typegpu/tests/array.test.ts

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,24 @@ describe('array', () => {
139139
);
140140
});
141141

142+
it('throws when invalid number of arguments during code generation', () => {
143+
const ArraySchema = d.arrayOf(d.u32, 2);
144+
145+
const f = () => {
146+
'use gpu';
147+
// @ts-expect-error
148+
const arr = ArraySchema([1, 1], [6, 7]);
149+
return;
150+
};
151+
152+
expect(() => tgpu.resolve([f])).toThrowErrorMatchingInlineSnapshot(`
153+
[Error: Resolution of the following tree failed:
154+
- <root>
155+
- fn*:f
156+
- fn*:f(): Array schemas should always be called with at most 1 argument]
157+
`);
158+
});
159+
142160
it('can be called to create a default value', () => {
143161
const ArraySchema = d.arrayOf(d.vec3f, 2);
144162

@@ -187,16 +205,28 @@ describe('array', () => {
187205
it('generates correct code when array clone is used', () => {
188206
const ArraySchema = d.arrayOf(d.u32, 1);
189207

190-
const testFn = tgpu.fn([])(() => {
208+
const f = (arr: d.Infer<typeof ArraySchema>) => {
209+
'use gpu';
210+
const clone = ArraySchema(arr);
211+
};
212+
213+
const testFn = () => {
214+
'use gpu';
191215
const myArray = ArraySchema([d.u32(10)]);
192216
const myClone = ArraySchema(myArray);
217+
f(myArray);
193218
return;
194-
});
219+
};
195220

196221
expect(tgpu.resolve([testFn])).toMatchInlineSnapshot(`
197-
"fn testFn() {
222+
"fn f(arr: array<u32, 1>) {
223+
var clone = arr;
224+
}
225+
226+
fn testFn() {
198227
var myArray = array<u32, 1>(10u);
199228
var myClone = myArray;
229+
f(myArray);
200230
return;
201231
}"
202232
`);
@@ -220,6 +250,65 @@ describe('array', () => {
220250
`);
221251
});
222252

253+
it('generates correct code when array expression with ephemeral element type clone is used', () => {
254+
const f = () => {
255+
'use gpu';
256+
const arr = d.arrayOf(d.f32, 2)([6, 7]);
257+
return;
258+
};
259+
260+
expect(tgpu.resolve([f])).toMatchInlineSnapshot(`
261+
"fn f() {
262+
var arr = array<f32, 2>(6f, 7f);
263+
return;
264+
}"
265+
`);
266+
});
267+
268+
it('generates correct code when array expression with reference element type clone is used', () => {
269+
const f = (v: d.v4f) => {
270+
'use gpu';
271+
const v2 = d.vec4f(3);
272+
const v3 = v2;
273+
const arr = d.arrayOf(d.vec4f, 3)([v, v2, v3]);
274+
};
275+
276+
const main = tgpu.fn([])(() => {
277+
const v1 = d.vec4f(7);
278+
f(v1);
279+
return;
280+
});
281+
282+
expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
283+
"fn f(v: vec4f) {
284+
var v2 = vec4f(3);
285+
let v3 = (&v2);
286+
var arr = array<vec4f, 3>(v, v2, (*v3));
287+
}
288+
289+
fn main() {
290+
var v1 = vec4f(7);
291+
f(v1);
292+
return;
293+
}"
294+
`);
295+
});
296+
297+
it('generates correct code when array expression with mixed element types clone is used', () => {
298+
const f = () => {
299+
'use gpu';
300+
const arr = d.arrayOf(d.f32, 3)([5, 6.7, 8.0]);
301+
return;
302+
};
303+
304+
expect(tgpu.resolve([f])).toMatchInlineSnapshot(`
305+
"fn f() {
306+
var arr = array<f32, 3>(5f, 6.7f, 8f);
307+
return;
308+
}"
309+
`);
310+
});
311+
223312
it('can be immediately-invoked in TGSL', () => {
224313
const foo = tgpu.fn([])(() => {
225314
const result = d.arrayOf(d.f32, 4)();
@@ -334,7 +423,8 @@ describe('array', () => {
334423
expect(() => tgpu.resolve([foo])).toThrowErrorMatchingInlineSnapshot(`
335424
[Error: Resolution of the following tree failed:
336425
- <root>
337-
- fn:foo: 'myVec' reference cannot be used in an array constructor.
426+
- fn:foo
427+
- ArrayExpression: 'myVec' reference cannot be used in an array constructor.
338428
-----
339429
Try 'vec2f(myVec)' or 'arrayOf(vec2f, count)([...])' to copy the value instead.
340430
-----]
@@ -349,7 +439,8 @@ describe('array', () => {
349439
expect(() => tgpu.resolve([foo])).toThrowErrorMatchingInlineSnapshot(`
350440
[Error: Resolution of the following tree failed:
351441
- <root>
352-
- fn:foo: 'myVec' reference cannot be used in an array constructor.
442+
- fn:foo
443+
- ArrayExpression: 'myVec' reference cannot be used in an array constructor.
353444
-----
354445
Try 'vec2f(myVec)' or 'arrayOf(vec2f, count)([...])' to copy the value instead.
355446
-----]

0 commit comments

Comments
 (0)