Skip to content

Commit ec1a771

Browse files
authored
fix: Unify operand types in arithmetic operators (#1603)
* Unify operand types in arithmetic operators * Remove rogue console logs * Tweak concretizeTypes logic in convertToCommonType helper * Make convertToCommonType generic over input tuple type * chefs kiss
1 parent ef1f059 commit ec1a771

6 files changed

Lines changed: 68 additions & 56 deletions

File tree

packages/typegpu/src/std/operators.ts

Lines changed: 51 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import { stitch } from '../core/resolve/stitch.ts';
22
import { createDualImpl } from '../core/function/dualImpl.ts';
3-
import { f16, f32 } from '../data/numeric.ts';
43
import { isSnippetNumeric, snip, type Snippet } from '../data/snippet.ts';
54
import { vecTypeToConstructor } from '../data/vector.ts';
65
import { VectorOps } from '../data/vectorOps.ts';
76
import {
87
type AnyMatInstance,
98
type AnyNumericVecInstance,
9+
type AnyWgslData,
1010
isFloat32VecInstance,
1111
isMatInstance,
1212
isVecInstance,
@@ -17,6 +17,24 @@ import { convertToCommonType } from '../tgsl/generationHelpers.ts';
1717
import { getResolutionCtx } from '../execMode.ts';
1818
import type { ResolutionCtx } from '../types.ts';
1919
import { $internal } from '../shared/symbols.ts';
20+
import { f16, f32 } from '../data/numeric.ts';
21+
22+
function tryUnify<T extends Snippet[]>(
23+
values: T,
24+
restrictTo?: AnyWgslData[],
25+
concretizeTypes = false,
26+
verbose = true,
27+
): T {
28+
const ctx = getResolutionCtx() as ResolutionCtx;
29+
const converted = convertToCommonType({
30+
ctx,
31+
values,
32+
restrictTo,
33+
concretizeTypes,
34+
verbose,
35+
});
36+
return converted ?? values;
37+
}
2038

2139
type NumVec = AnyNumericVecInstance;
2240
type Mat = AnyMatInstance;
@@ -58,7 +76,10 @@ export const add = createDualImpl(
5876
cpuAdd,
5977
// CODEGEN implementation
6078
(lhs, rhs) => {
61-
const resultType = isSnippetNumeric(lhs) ? rhs.dataType : lhs.dataType;
79+
const [convLhs, convRhs] = tryUnify([lhs, rhs]);
80+
const resultType = isSnippetNumeric(convLhs)
81+
? convRhs.dataType
82+
: convLhs.dataType;
6283

6384
if (
6485
(typeof lhs.value === 'number' ||
@@ -72,7 +93,7 @@ export const add = createDualImpl(
7293
return snip(cpuAdd(lhs.value as never, rhs.value as never), resultType);
7394
}
7495

75-
return snip(stitch`(${lhs} + ${rhs})`, resultType);
96+
return snip(stitch`(${convLhs} + ${convRhs})`, resultType);
7697
},
7798
'add',
7899
);
@@ -99,7 +120,10 @@ export const sub = createDualImpl(
99120
cpuSub,
100121
// CODEGEN implementation
101122
(lhs, rhs) => {
102-
const resultType = isSnippetNumeric(lhs) ? rhs.dataType : lhs.dataType;
123+
const [convLhs, convRhs] = tryUnify([lhs, rhs]);
124+
const resultType = isSnippetNumeric(convLhs)
125+
? convRhs.dataType
126+
: convLhs.dataType;
103127

104128
if (
105129
(typeof lhs.value === 'number' ||
@@ -113,7 +137,7 @@ export const sub = createDualImpl(
113137
return snip(cpuSub(lhs.value as never, rhs.value as never), resultType);
114138
}
115139

116-
return snip(stitch`(${lhs} - ${rhs})`, resultType);
140+
return snip(stitch`(${convLhs} - ${convRhs})`, resultType);
117141
},
118142
'sub',
119143
);
@@ -164,20 +188,21 @@ export const mul = createDualImpl(
164188
cpuMul,
165189
// GPU implementation
166190
(lhs, rhs) => {
167-
const returnType = isSnippetNumeric(lhs)
191+
const [convLhs, convRhs] = tryUnify([lhs, rhs]);
192+
const returnType = isSnippetNumeric(convLhs)
168193
// Scalar * Scalar/Vector/Matrix
169-
? rhs.dataType
170-
: isSnippetNumeric(rhs)
194+
? convRhs.dataType
195+
: isSnippetNumeric(convRhs)
171196
// Vector/Matrix * Scalar
172-
? lhs.dataType
173-
: lhs.dataType.type.startsWith('vec')
197+
? convLhs.dataType
198+
: convLhs.dataType.type.startsWith('vec')
174199
// Vector * Vector/Matrix
175-
? lhs.dataType
176-
: rhs.dataType.type.startsWith('vec')
200+
? convLhs.dataType
201+
: convRhs.dataType.type.startsWith('vec')
177202
// Matrix * Vector
178-
? rhs.dataType
203+
? convRhs.dataType
179204
// Matrix * Matrix
180-
: lhs.dataType;
205+
: convLhs.dataType;
181206

182207
if (
183208
(typeof lhs.value === 'number' ||
@@ -191,7 +216,7 @@ export const mul = createDualImpl(
191216
return snip(cpuMul(lhs.value as never, rhs.value as never), returnType);
192217
}
193218

194-
return snip(stitch`(${lhs} * ${rhs})`, returnType);
219+
return snip(stitch`(${convLhs} * ${convRhs})`, returnType);
195220
},
196221
'mul',
197222
);
@@ -223,35 +248,20 @@ export const div = createDualImpl(
223248
cpuDiv,
224249
// CODEGEN implementation
225250
(lhs, rhs) => {
226-
let conv: [Snippet, Snippet] = [lhs, rhs];
227-
228-
if (isSnippetNumeric(lhs) && isSnippetNumeric(rhs)) {
229-
const ctx = getResolutionCtx() as ResolutionCtx;
230-
const converted = convertToCommonType({
231-
ctx,
232-
values: [lhs, rhs],
233-
restrictTo: [f32, f16],
234-
concretizeTypes: true,
235-
}) as
236-
| [Snippet, Snippet]
237-
| undefined;
238-
if (converted) {
239-
conv = converted;
240-
}
241-
}
242-
243-
const lhsVal = conv[0].value;
244-
const rhsVal = conv[1].value;
251+
const [convLhs, convRhs] = tryUnify([lhs, rhs], [f32, f16], true, false);
245252

246253
if (
247-
(typeof lhsVal === 'number' || isVecInstance(lhsVal)) &&
248-
(typeof rhsVal === 'number' || isVecInstance(rhsVal))
254+
(typeof lhs.value === 'number' || isVecInstance(lhs.value)) &&
255+
(typeof rhs.value === 'number' || isVecInstance(rhs.value))
249256
) {
250257
// Precomputing
251-
return snip(cpuDiv(lhsVal as never, rhsVal as never), conv[0].dataType);
258+
return snip(
259+
cpuDiv(lhs.value as never, rhs.value as never),
260+
convLhs.dataType,
261+
);
252262
}
253263

254-
return snip(stitch`(${conv[0]} / ${conv[1]})`, conv[0].dataType);
264+
return snip(stitch`(${convLhs} / ${convRhs})`, convLhs.dataType);
255265
},
256266
'div',
257267
);
@@ -294,8 +304,9 @@ export const mod: ModOverload = createDualImpl(
294304
},
295305
// GPU implementation
296306
(a, b) => {
297-
const type = isSnippetNumeric(a) ? b.dataType : a.dataType;
298-
return snip(stitch`(${a} % ${b})`, type);
307+
const [convA, convB] = tryUnify([a, b]);
308+
const type = isSnippetNumeric(convA) ? convB.dataType : convA.dataType;
309+
return snip(stitch`(${convA} % ${convB})`, type);
299310
},
300311
'mod',
301312
);

packages/typegpu/src/tgsl/generationHelpers.ts

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -506,23 +506,31 @@ function applyActionToSnippet(
506506
}
507507
}
508508

509-
export type ConvertToCommonTypeOptions = {
509+
export type ConvertToCommonTypeOptions<T extends Snippet[]> = {
510510
ctx: ResolutionCtx;
511-
values: Snippet[];
511+
values: T;
512512
restrictTo?: AnyData[] | undefined;
513513
concretizeTypes?: boolean | undefined;
514514
verbose?: boolean | undefined;
515515
};
516516

517-
export function convertToCommonType({
517+
export function convertToCommonType<T extends Snippet[]>({
518518
ctx,
519519
values,
520520
restrictTo,
521521
concretizeTypes = false,
522522
verbose = true,
523-
}: ConvertToCommonTypeOptions): Snippet[] | undefined {
523+
}: ConvertToCommonTypeOptions<T>): T | undefined {
524+
const needsConcretization = concretizeTypes &&
525+
// If we have any concrete type among the values, we don't need to concretize
526+
!values.some((value) =>
527+
concretize(value.dataType as AnyWgslData) === value.dataType
528+
);
529+
524530
const types = values.map((value) =>
525-
concretizeTypes ? concretize(value.dataType as AnyWgslData) : value.dataType
531+
needsConcretization
532+
? concretize(value.dataType as AnyWgslData)
533+
: value.dataType
526534
);
527535

528536
if (types.some((type) => type === UnknownData)) {
@@ -557,7 +565,7 @@ Consider using explicit conversions instead.`,
557565
const action = conversion.actions[index];
558566
invariant(action, 'Action should not be undefined');
559567
return applyActionToSnippet(ctx, value, action, conversion.targetType);
560-
});
568+
}) as T;
561569
}
562570

563571
export function tryConvertSnippet(

packages/typegpu/src/tgsl/wgslGenerator.ts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,9 @@ export function generateExpression(
215215

216216
const converted = convertToCommonType({
217217
ctx,
218-
values: [lhsExpr, rhsExpr],
218+
values: [lhsExpr, rhsExpr] as const,
219219
restrictTo: forcedType,
220-
}) as
221-
| [Snippet, Snippet]
222-
| undefined;
220+
});
223221
const [convLhs, convRhs] = converted || [lhsExpr, rhsExpr];
224222

225223
const lhsStr = ctx.resolve(convLhs.value);

packages/typegpu/tests/compiledIO.test.ts

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import {
66
import * as d from '../src/data/index.ts';
77
import { sizeOf } from '../src/data/sizeOf.ts';
88
import { it } from './utils/extendedIt.ts';
9-
import tgpu from '../src/index.ts';
109

1110
describe('buildWriter', () => {
1211
it('should compile a writer for a struct', () => {
@@ -293,8 +292,6 @@ describe('createCompileInstructions', () => {
293292
b: d.align(64, d.arrayOf(d.f32, 2)),
294293
});
295294

296-
console.log(tgpu.resolve({ externals: { schema } }));
297-
298295
const builtWriter = buildWriter(schema, 'offset', 'value');
299296
expect(builtWriter).toMatchInlineSnapshot(`
300297
"output.setFloat32((offset + 0), value.a, littleEndian);

packages/typegpu/tests/struct.test.ts

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,6 @@ describe('abstruct', () => {
388388
return result.exp;
389389
});
390390

391-
console.log(tgpu.resolve({ externals: { testFn } }));
392-
393391
expect(parseResolved({ testFn })).toBe(
394392
parse(`
395393
fn testFn(x: f32) -> f32 {

packages/typegpu/tests/tgsl/wgslGenerator.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -997,7 +997,7 @@ describe('wgslGenerator division operator', () => {
997997
});
998998
expect(div()).toBe(0.1);
999999
expect(parseResolved({ divide1: div })).toMatchInlineSnapshot(
1000-
`"fn div ( ) -> f32 { return ( f32 ( f16 ( ( f32 ( 1 ) / f32 ( 2 ) ) ) ) / f32 ( 5 ) ) ; }"`,
1000+
`"fn div ( ) -> f32 { return ( f32 ( f16 ( 0.5 ) ) / f32 ( 5 ) ) ; }"`,
10011001
);
10021002
});
10031003

0 commit comments

Comments
 (0)