11import { stitch } from '../core/resolve/stitch.ts' ;
22import { createDualImpl } from '../core/function/dualImpl.ts' ;
3- import { f16 , f32 } from '../data/numeric.ts' ;
43import { isSnippetNumeric , snip , type Snippet } from '../data/snippet.ts' ;
54import { vecTypeToConstructor } from '../data/vector.ts' ;
65import { VectorOps } from '../data/vectorOps.ts' ;
76import {
87 type AnyMatInstance ,
98 type AnyNumericVecInstance ,
9+ type AnyWgslData ,
1010 isFloat32VecInstance ,
1111 isMatInstance ,
1212 isVecInstance ,
@@ -17,6 +17,24 @@ import { convertToCommonType } from '../tgsl/generationHelpers.ts';
1717import { getResolutionCtx } from '../execMode.ts' ;
1818import type { ResolutionCtx } from '../types.ts' ;
1919import { $internal } from '../shared/symbols.ts' ;
20+ import { f16 , f32 } from '../data/numeric.ts' ;
21+
22+ function tryUnify < T extends Snippet [ ] > (
23+ values : T ,
24+ restrictTo ?: AnyWgslData [ ] ,
25+ concretizeTypes = false ,
26+ verbose = true ,
27+ ) : T {
28+ const ctx = getResolutionCtx ( ) as ResolutionCtx ;
29+ const converted = convertToCommonType ( {
30+ ctx,
31+ values,
32+ restrictTo,
33+ concretizeTypes,
34+ verbose,
35+ } ) ;
36+ return converted ?? values ;
37+ }
2038
2139type NumVec = AnyNumericVecInstance ;
2240type Mat = AnyMatInstance ;
@@ -58,7 +76,10 @@ export const add = createDualImpl(
5876 cpuAdd ,
5977 // CODEGEN implementation
6078 ( lhs , rhs ) => {
61- const resultType = isSnippetNumeric ( lhs ) ? rhs . dataType : lhs . dataType ;
79+ const [ convLhs , convRhs ] = tryUnify ( [ lhs , rhs ] ) ;
80+ const resultType = isSnippetNumeric ( convLhs )
81+ ? convRhs . dataType
82+ : convLhs . dataType ;
6283
6384 if (
6485 ( typeof lhs . value === 'number' ||
@@ -72,7 +93,7 @@ export const add = createDualImpl(
7293 return snip ( cpuAdd ( lhs . value as never , rhs . value as never ) , resultType ) ;
7394 }
7495
75- return snip ( stitch `(${ lhs } + ${ rhs } )` , resultType ) ;
96+ return snip ( stitch `(${ convLhs } + ${ convRhs } )` , resultType ) ;
7697 } ,
7798 'add' ,
7899) ;
@@ -99,7 +120,10 @@ export const sub = createDualImpl(
99120 cpuSub ,
100121 // CODEGEN implementation
101122 ( lhs , rhs ) => {
102- const resultType = isSnippetNumeric ( lhs ) ? rhs . dataType : lhs . dataType ;
123+ const [ convLhs , convRhs ] = tryUnify ( [ lhs , rhs ] ) ;
124+ const resultType = isSnippetNumeric ( convLhs )
125+ ? convRhs . dataType
126+ : convLhs . dataType ;
103127
104128 if (
105129 ( typeof lhs . value === 'number' ||
@@ -113,7 +137,7 @@ export const sub = createDualImpl(
113137 return snip ( cpuSub ( lhs . value as never , rhs . value as never ) , resultType ) ;
114138 }
115139
116- return snip ( stitch `(${ lhs } - ${ rhs } )` , resultType ) ;
140+ return snip ( stitch `(${ convLhs } - ${ convRhs } )` , resultType ) ;
117141 } ,
118142 'sub' ,
119143) ;
@@ -164,20 +188,21 @@ export const mul = createDualImpl(
164188 cpuMul ,
165189 // GPU implementation
166190 ( lhs , rhs ) => {
167- const returnType = isSnippetNumeric ( lhs )
191+ const [ convLhs , convRhs ] = tryUnify ( [ lhs , rhs ] ) ;
192+ const returnType = isSnippetNumeric ( convLhs )
168193 // Scalar * Scalar/Vector/Matrix
169- ? rhs . dataType
170- : isSnippetNumeric ( rhs )
194+ ? convRhs . dataType
195+ : isSnippetNumeric ( convRhs )
171196 // Vector/Matrix * Scalar
172- ? lhs . dataType
173- : lhs . dataType . type . startsWith ( 'vec' )
197+ ? convLhs . dataType
198+ : convLhs . dataType . type . startsWith ( 'vec' )
174199 // Vector * Vector/Matrix
175- ? lhs . dataType
176- : rhs . dataType . type . startsWith ( 'vec' )
200+ ? convLhs . dataType
201+ : convRhs . dataType . type . startsWith ( 'vec' )
177202 // Matrix * Vector
178- ? rhs . dataType
203+ ? convRhs . dataType
179204 // Matrix * Matrix
180- : lhs . dataType ;
205+ : convLhs . dataType ;
181206
182207 if (
183208 ( typeof lhs . value === 'number' ||
@@ -191,7 +216,7 @@ export const mul = createDualImpl(
191216 return snip ( cpuMul ( lhs . value as never , rhs . value as never ) , returnType ) ;
192217 }
193218
194- return snip ( stitch `(${ lhs } * ${ rhs } )` , returnType ) ;
219+ return snip ( stitch `(${ convLhs } * ${ convRhs } )` , returnType ) ;
195220 } ,
196221 'mul' ,
197222) ;
@@ -223,35 +248,20 @@ export const div = createDualImpl(
223248 cpuDiv ,
224249 // CODEGEN implementation
225250 ( lhs , rhs ) => {
226- let conv : [ Snippet , Snippet ] = [ lhs , rhs ] ;
227-
228- if ( isSnippetNumeric ( lhs ) && isSnippetNumeric ( rhs ) ) {
229- const ctx = getResolutionCtx ( ) as ResolutionCtx ;
230- const converted = convertToCommonType ( {
231- ctx,
232- values : [ lhs , rhs ] ,
233- restrictTo : [ f32 , f16 ] ,
234- concretizeTypes : true ,
235- } ) as
236- | [ Snippet , Snippet ]
237- | undefined ;
238- if ( converted ) {
239- conv = converted ;
240- }
241- }
242-
243- const lhsVal = conv [ 0 ] . value ;
244- const rhsVal = conv [ 1 ] . value ;
251+ const [ convLhs , convRhs ] = tryUnify ( [ lhs , rhs ] , [ f32 , f16 ] , true , false ) ;
245252
246253 if (
247- ( typeof lhsVal === 'number' || isVecInstance ( lhsVal ) ) &&
248- ( typeof rhsVal === 'number' || isVecInstance ( rhsVal ) )
254+ ( typeof lhs . value === 'number' || isVecInstance ( lhs . value ) ) &&
255+ ( typeof rhs . value === 'number' || isVecInstance ( rhs . value ) )
249256 ) {
250257 // Precomputing
251- return snip ( cpuDiv ( lhsVal as never , rhsVal as never ) , conv [ 0 ] . dataType ) ;
258+ return snip (
259+ cpuDiv ( lhs . value as never , rhs . value as never ) ,
260+ convLhs . dataType ,
261+ ) ;
252262 }
253263
254- return snip ( stitch `(${ conv [ 0 ] } / ${ conv [ 1 ] } )` , conv [ 0 ] . dataType ) ;
264+ return snip ( stitch `(${ convLhs } / ${ convRhs } )` , convLhs . dataType ) ;
255265 } ,
256266 'div' ,
257267) ;
@@ -294,8 +304,9 @@ export const mod: ModOverload = createDualImpl(
294304 } ,
295305 // GPU implementation
296306 ( a , b ) => {
297- const type = isSnippetNumeric ( a ) ? b . dataType : a . dataType ;
298- return snip ( stitch `(${ a } % ${ b } )` , type ) ;
307+ const [ convA , convB ] = tryUnify ( [ a , b ] ) ;
308+ const type = isSnippetNumeric ( convA ) ? convB . dataType : convA . dataType ;
309+ return snip ( stitch `(${ convA } % ${ convB } )` , type ) ;
299310 } ,
300311 'mod' ,
301312) ;
0 commit comments