@@ -228,37 +228,47 @@ static bool fillInputBuffer(LPCSTR Name, std::vector<BYTE> &Data,
228228 return false ;
229229}
230230
231- static VariantCompType makeExpected (ComponentType CompType, size_t NumElements,
232- float StartingVal, bool Increment) {
233- switch (CompType) {
234- case ComponentType::F32: {
235- std::vector<float > Floats (NumElements);
236- for (size_t I = 0 ; I < NumElements; I++)
237- Floats[I] = StartingVal + static_cast <float >(Increment ? I : 0 );
238- return Floats;
239- }
240- case ComponentType::I32: {
241- DXASSERT (StartingVal < static_cast <float >(INT_MAX),
242- " Value too large to cast to int32_t" );
243- std::vector<int32_t > Ints (NumElements);
244- for (size_t I = 0 ; I < NumElements; I++)
245- Ints[I] = static_cast <int32_t >(StartingVal) +
246- static_cast <int32_t >(Increment ? I : 0 );
247- return Ints;
248- }
249- case ComponentType::F16: {
250- std::vector<HLSLHalf_t> Halfs (NumElements);
251- for (size_t I = 0 ; I < NumElements; I++) {
252- // Downcasting is safe here since HLSLHalf_t will clamp if F is too large.
253- float F = StartingVal + static_cast <float >(Increment ? I : 0 );
254- Halfs[I] = HLSLHalf_t (F);
231+ static VariantCompType makeExpected (ComponentType CompType, int32_t M, int32_t N,
232+ float StartingVal, bool Increment = true , bool Transpose = false ) {
233+ int32_t NumElements = M * N;
234+ std::vector<float > Floats (NumElements);
235+ std::vector<int32_t > Ints (NumElements);
236+ std::vector<HLSLHalf_t> Halfs (NumElements);
237+
238+ for (int32_t I = 0 ; I < M; ++I) {
239+ for (int32_t J = 0 ; J < M; ++J) {
240+ int32_t Value = I * M + J;
241+ int32_t Idx = Transpose ? J * N + I : Value;
242+ switch (CompType) {
243+ case ComponentType::F32:
244+ Floats[Idx] = StartingVal + static_cast <float >(Increment ? Value : 0 );
245+ break ;
246+ case ComponentType::I32:
247+ DXASSERT (StartingVal < static_cast <float >(INT_MAX),
248+ " Value too large to cast to int32_t" );
249+ Ints[Idx] = static_cast <int32_t >(StartingVal) + (Increment ? Value : 0 );
250+ break ;
251+ case ComponentType::F16: {
252+ // Downcasting is safe here since HLSLHalf_t will clamp if F is too large.
253+ float F = StartingVal + static_cast <float >(Increment ? Value : 0 );
254+ Halfs[Idx] = HLSLHalf_t (F);
255+ break ;
256+ }
257+ }
255258 }
256- return Halfs;
257- }
258259 }
259260
260- DXASSERT (false , " Unable to fill unexpected ComponentType" );
261- return std::vector<float >();
261+ switch (CompType) {
262+ case ComponentType::F32:
263+ return Floats;
264+ case ComponentType::I32:
265+ return Ints;
266+ case ComponentType::F16:
267+ return Halfs;
268+ default :
269+ DXASSERT (false , " Unable to fill unexpected ComponentType" );
270+ return Floats;
271+ }
262272}
263273
264274static void logCompiledButSkipping () {
@@ -429,7 +439,7 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
429439 return ;
430440 }
431441
432- auto Expected = makeExpected (Params.CompType , NumElements, 1 , true );
442+ auto Expected = makeExpected (Params.CompType , Params. M , Params. N , 1 );
433443
434444 // Construct the ShaderOp: two UAV buffers, load from one, store to other.
435445 auto Op = createComputeOp (LoadStoreShader, Target.c_str (), " UAV(u0), UAV(u1)" ,
@@ -517,7 +527,7 @@ static void runSplatStore(ID3D12Device *Device,
517527 return ;
518528 }
519529
520- auto Expected = makeExpected (Params.CompType , NumElements , FillValue, false );
530+ auto Expected = makeExpected (Params.CompType , Params. M , Params. N , FillValue, false );
521531
522532 auto Op = createComputeOp (SplatStoreShader, Target.c_str (), " UAV(u0)" ,
523533 Args.c_str ());
@@ -553,11 +563,13 @@ static const char ElementAccessShader[] = R"(
553563 RWByteAddressBuffer Output : register(u1);
554564
555565 // flatten the 2D index into a 1D index then scale by element size
566+ // Always store row-major and work it out in the test runner
556567 uint coordToByteOffset(uint2 coord) {
557- return (coord.x * MAJOR_DIM + coord.y ) * ELEM_SIZE;
568+ return (coord.y * N_DIM + coord.x ) * ELEM_SIZE;
558569 }
559570
560571#ifndef EMULATE_TEST
572+ [WaveSize(4, 64)]
561573 [numthreads(NUMTHREADS, 1, 1)]
562574 void main(uint threadIndex : SV_GroupIndex) {
563575 __builtin_LinAlgMatrix
@@ -605,8 +617,7 @@ static void runElementAccess(ID3D12Device *Device,
605617 const size_t NumThreads = Params.NumThreads ;
606618 const size_t InputBufSize = Params.totalBytes ();
607619 const size_t ElementSize = elementSize (Params.CompType );
608- const size_t MajorDim =
609- Params.Layout == LinalgMatrixLayout::RowMajor ? Params.M : Params.N ;
620+
610621 // Output: ElementSize bytes per element
611622 // 1 element for each mat idx
612623 // 1 uint for each thread's length
@@ -618,7 +629,6 @@ static void runElementAccess(ID3D12Device *Device,
618629 Target = " cs_6_8" ;
619630
620631 std::stringstream ExtraDefs;
621- ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
622632 std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
623633
624634 compileShader (DxcSupport, ElementAccessShader, Target.c_str (), Args, Verbose);
@@ -628,7 +638,7 @@ static void runElementAccess(ID3D12Device *Device,
628638 return ;
629639 }
630640
631- auto Expected = makeExpected (Params.CompType , NumElements, 1 , true );
641+ auto Expected = makeExpected (Params.CompType , Params. M , Params. N , 1 );
632642
633643 auto Op = createComputeOp (ElementAccessShader, Target.c_str (),
634644 " UAV(u0), UAV(u1)" , Args.c_str ());
@@ -674,7 +684,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
674684 Params.Use = MatrixUse::Accumulator;
675685 Params.Scope = MatrixScope::Wave;
676686 Params.Layout = LinalgMatrixLayout::RowMajor;
677- Params.NumThreads = 4 ;
687+ Params.NumThreads = 64 ;
678688 Params.Enable16Bit = true ;
679689 Params.EmulateTest = EmulateTest;
680690 runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);
0 commit comments