Skip to content

Commit 226e90a

Browse files
authored
feat: for ... of ... loop support (#1976)
1 parent 34b6d60 commit 226e90a

10 files changed

Lines changed: 670 additions & 114 deletions

File tree

apps/typegpu-docs/src/examples/simulation/boids/index.ts

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,7 @@ const simulate = (index: number) => {
177177
let alignmentCount = 0;
178178
let cohesionCount = 0;
179179

180-
for (let i = d.u32(0); i < layout.$.currentTrianglePos.length; i++) {
181-
if (i === index) {
182-
continue;
183-
}
184-
const other = layout.$.currentTrianglePos[i];
180+
for (const other of layout.$.currentTrianglePos) {
185181
const dist = std.distance(instanceInfo.position, other.position);
186182
if (dist < params.$.separationDistance) {
187183
separation = std.add(

apps/typegpu-docs/src/examples/simulation/fluid-double-buffering/index.ts

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,7 @@ const time = root.createUniform(d.f32);
107107

108108
const isInsideObstacle = (x: number, y: number): boolean => {
109109
'use gpu';
110-
for (let obsIdx = 0; obsIdx < MAX_OBSTACLES; obsIdx++) {
111-
const obs = obstacles.$[obsIdx];
112-
110+
for (const obs of obstacles.$) {
113111
if (obs.enabled === 0) {
114112
continue;
115113
}
@@ -162,8 +160,7 @@ const computeVelocity = (x: number, y: number): d.v2f => {
162160
];
163161
let dirChoiceCount = 1;
164162

165-
for (let i = 0; i < 4; i++) {
166-
const offset = neighborOffsets[i];
163+
for (const offset of neighborOffsets) {
167164
const neighborDensity = getCell(x + offset.x, y + offset.y);
168165
const cost = neighborDensity.z + d.f32(offset.y) * gravityCost;
169166

packages/tinyest-for-wgsl/src/parsers.ts

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,7 @@ const Transpilers: Partial<
244244
}
245245

246246
if (node.kind === 'const') {
247-
if (init === undefined) {
248-
throw new Error(
249-
'Did not provide initial value in `const` declaration.',
250-
);
251-
}
252-
return [NODE.const, id, init];
247+
return init !== undefined ? [NODE.const, id, init] : [NODE.const, id];
253248
}
254249

255250
return init !== undefined ? [NODE.let, id, init] : [NODE.let, id];
@@ -322,6 +317,13 @@ const Transpilers: Partial<
322317
return [NODE.while, condition, body];
323318
},
324319

320+
ForOfStatement(ctx, node) {
321+
const loopVar = transpile(ctx, node.left) as tinyest.Const | tinyest.Let;
322+
const iterable = transpile(ctx, node.right) as tinyest.Expression;
323+
const body = transpile(ctx, node.body) as tinyest.Statement;
324+
return [NODE.forOf, loopVar, iterable, body];
325+
},
326+
325327
ContinueStatement() {
326328
return [NODE.continue];
327329
},

packages/tinyest/src/nodes.ts

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ export const NodeTypeCatalog = {
2323
while: 15,
2424
continue: 16,
2525
break: 17,
26+
forOf: 18,
2627

2728
// rare
2829
arrayExpr: 100,
@@ -73,11 +74,13 @@ export type Let =
7374
/**
7475
* Represents a const statement
7576
*/
76-
export type Const = readonly [
77-
type: NodeTypeCatalog['const'],
78-
identifier: string,
79-
value: Expression,
80-
];
77+
export type Const =
78+
| readonly [type: NodeTypeCatalog['const'], identifier: string]
79+
| readonly [
80+
type: NodeTypeCatalog['const'],
81+
identifier: string,
82+
value: Expression,
83+
];
8184

8285
export type For = readonly [
8386
type: NodeTypeCatalog['for'],
@@ -97,6 +100,13 @@ export type Continue = readonly [type: NodeTypeCatalog['continue']];
97100

98101
export type Break = readonly [type: NodeTypeCatalog['break']];
99102

103+
export type ForOf = readonly [
104+
type: NodeTypeCatalog['forOf'],
105+
left: Const | Let,
106+
right: Expression,
107+
body: Statement,
108+
];
109+
100110
/**
101111
* A union type of all statements
102112
*/
@@ -110,7 +120,8 @@ export type Statement =
110120
| For
111121
| While
112122
| Continue
113-
| Break;
123+
| Break
124+
| ForOf;
114125

115126
//
116127
// Expression

packages/typegpu/src/data/compiledIO.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ export function buildWriter(
196196
}
197197

198198
if (wgsl.isVec(node)) {
199+
if (wgsl.isVecBool(node)) {
200+
throw new Error('Compiled writers do not support boolean vectors');
201+
}
202+
199203
const primitive = typeToPrimitive[node.type];
200204
let code = '';
201205
const writeFunc = primitiveToWriteFunction[primitive];

packages/typegpu/src/data/wgslTypes.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1717,17 +1717,29 @@ export function isVec(
17171717
| Vec2h
17181718
| Vec2i
17191719
| Vec2u
1720+
| Vec2b
17201721
| Vec3f
17211722
| Vec3h
17221723
| Vec3i
17231724
| Vec3u
1725+
| Vec3b
17241726
| Vec4f
17251727
| Vec4h
17261728
| Vec4i
1727-
| Vec4u {
1729+
| Vec4u
1730+
| Vec4b {
17281731
return isVec2(value) || isVec3(value) || isVec4(value);
17291732
}
17301733

1734+
export function isVecBool(
1735+
value: unknown,
1736+
): value is
1737+
| Vec2b
1738+
| Vec3b
1739+
| Vec4b {
1740+
return isVec(value) && value.type.includes('b');
1741+
}
1742+
17311743
export function isMatInstance(value: unknown): value is AnyMatInstance {
17321744
const v = value as AnyMatInstance | undefined;
17331745
return isMarkedInternal(v) &&

packages/typegpu/src/tgsl/wgslGenerator.ts

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ import type { ShaderGenerator } from './shaderGenerator.ts';
4141
import { createPtrFromOrigin, implicitFrom, ptrFn } from '../data/ptr.ts';
4242
import { RefOperator } from '../data/ref.ts';
4343
import { constant } from '../core/constant/tgpuConstant.ts';
44+
import { arrayLength } from '../std/array.ts';
4445
import { AutoStruct } from '../data/autoStruct.ts';
4546
import { mathToStd } from './math.ts';
4647

@@ -1137,6 +1138,131 @@ ${this.ctx.pre}else ${alternate}`;
11371138
return `${this.ctx.pre}while (${conditionStr}) ${bodyStr}`;
11381139
}
11391140

1141+
if (statement[0] === NODE.forOf) {
1142+
const [_, loopVar, iterable, body] = statement;
1143+
const iterableSnippet = this.expression(iterable);
1144+
1145+
if (isEphemeralSnippet(iterableSnippet)) {
1146+
throw new Error(
1147+
'`for ... of ...` loops only support iterables stored in variables',
1148+
);
1149+
}
1150+
1151+
// Our index name will be some element from infinite sequence (i, ii, iii, ...).
1152+
// If user defines `i` and `ii` before `for ... of ...` loop, then our index name will be `iii`.
1153+
// If user defines `i` inside `for ... of ...` then it will be scoped to a new block,
1154+
// so we can safely use `i`.
1155+
let index = 'i'; // it will be valid name, no need to call this.ctx.makeNameValid
1156+
while (this.ctx.getById(index) !== null) {
1157+
index += 'i';
1158+
}
1159+
1160+
const elementSnippet = accessIndex(
1161+
iterableSnippet,
1162+
snip(index, u32, 'runtime'),
1163+
);
1164+
if (!elementSnippet) {
1165+
throw new WgslTypeError(
1166+
'`for ... of ...` loops only support array or vector iterables',
1167+
);
1168+
}
1169+
1170+
const iterableDataType = iterableSnippet.dataType;
1171+
let elementCountSnippet: Snippet;
1172+
let elementType = elementSnippet.dataType;
1173+
1174+
if (elementType === UnknownData) {
1175+
throw new WgslTypeError(
1176+
stitch`The elements in iterable ${iterableSnippet} are of unknown type`,
1177+
);
1178+
}
1179+
1180+
if (wgsl.isWgslArray(iterableDataType)) {
1181+
elementCountSnippet = iterableDataType.elementCount > 0
1182+
? snip(
1183+
`${iterableDataType.elementCount}`,
1184+
u32,
1185+
'constant',
1186+
)
1187+
: arrayLength[$gpuCallable].call(this.ctx, [iterableSnippet]);
1188+
} else if (wgsl.isVec(iterableDataType)) {
1189+
elementCountSnippet = snip(
1190+
`${Number(iterableDataType.type.match(/\d/))}`,
1191+
u32,
1192+
'constant',
1193+
);
1194+
} else {
1195+
throw new WgslTypeError(
1196+
'`for ... of ...` loops only support array or vector iterables',
1197+
);
1198+
}
1199+
1200+
if (loopVar[0] !== NODE.const) {
1201+
throw new WgslTypeError(
1202+
'Only `for (const ... of ... )` loops are supported',
1203+
);
1204+
}
1205+
1206+
// If it's ephemeral, it's a value that cannot change. If it's a reference, we take
1207+
// an implicit pointer to it
1208+
let loopVarKind = 'let';
1209+
const loopVarName = this.ctx.makeNameValid(loopVar[1]);
1210+
1211+
if (!isEphemeralSnippet(elementSnippet)) {
1212+
if (elementSnippet.origin === 'constant-tgpu-const-ref') {
1213+
loopVarKind = 'const';
1214+
} else if (elementSnippet.origin === 'runtime-tgpu-const-ref') {
1215+
loopVarKind = 'let';
1216+
} else {
1217+
loopVarKind = 'let';
1218+
if (!wgsl.isPtr(elementType)) {
1219+
const ptrType = createPtrFromOrigin(
1220+
elementSnippet.origin,
1221+
concretize(elementType as wgsl.AnyWgslData) as wgsl.StorableData,
1222+
);
1223+
invariant(
1224+
ptrType !== undefined,
1225+
`Creating pointer type from origin ${elementSnippet.origin}`,
1226+
);
1227+
elementType = ptrType;
1228+
}
1229+
1230+
elementType = implicitFrom(elementType as wgsl.Ptr);
1231+
}
1232+
}
1233+
1234+
const loopVarSnippet = snip(
1235+
loopVarName,
1236+
elementType,
1237+
elementSnippet.origin,
1238+
);
1239+
this.ctx.defineVariable(loopVarName, loopVarSnippet);
1240+
1241+
const forStr = stitch`${this.ctx.pre}for (var ${index} = 0u; ${index} < ${
1242+
tryConvertSnippet(this.ctx, elementCountSnippet, u32, false)
1243+
}; ${index}++) {`;
1244+
1245+
this.ctx.indent();
1246+
1247+
const loopVarDeclStr =
1248+
stitch`${this.ctx.pre}${loopVarKind} ${loopVarName} = ${
1249+
tryConvertSnippet(
1250+
this.ctx,
1251+
elementSnippet,
1252+
elementType,
1253+
false,
1254+
)
1255+
};`;
1256+
1257+
const bodyStr = `${this.ctx.pre}${
1258+
this.block(blockifySingleStatement(body))
1259+
}`;
1260+
1261+
this.ctx.dedent();
1262+
1263+
return stitch`${forStr}\n${loopVarDeclStr}\n${bodyStr}\n${this.ctx.pre}}`;
1264+
}
1265+
11401266
if (statement[0] === NODE.continue) {
11411267
return `${this.ctx.pre}continue;`;
11421268
}

packages/typegpu/tests/examples/individual/boids.test.ts

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,22 +46,21 @@ describe('boids example', () => {
4646
var cohesion = vec2f();
4747
var alignmentCount = 0;
4848
var cohesionCount = 0;
49-
for (var i = 0u; (i < arrayLength(&currentTrianglePos)); i++) {
50-
if ((i == index)) {
51-
continue;
52-
}
49+
for (var i = 0u; i < arrayLength((&currentTrianglePos)); i++) {
5350
let other = (&currentTrianglePos[i]);
54-
let dist = distance(instanceInfo.position, (*other).position);
55-
if ((dist < paramsBuffer.separationDistance)) {
56-
separation = (separation + (instanceInfo.position - (*other).position));
57-
}
58-
if ((dist < paramsBuffer.alignmentDistance)) {
59-
alignment = (alignment + (*other).velocity);
60-
alignmentCount++;
61-
}
62-
if ((dist < paramsBuffer.cohesionDistance)) {
63-
cohesion = (cohesion + (*other).position);
64-
cohesionCount++;
51+
{
52+
let dist = distance(instanceInfo.position, (*other).position);
53+
if ((dist < paramsBuffer.separationDistance)) {
54+
separation = (separation + (instanceInfo.position - (*other).position));
55+
}
56+
if ((dist < paramsBuffer.alignmentDistance)) {
57+
alignment = (alignment + (*other).velocity);
58+
alignmentCount++;
59+
}
60+
if ((dist < paramsBuffer.cohesionDistance)) {
61+
cohesion = (cohesion + (*other).position);
62+
cohesionCount++;
63+
}
6564
}
6665
}
6766
if ((alignmentCount > 0i)) {

0 commit comments

Comments
 (0)