@@ -238,7 +238,7 @@ static VariantCompType makeExpected(ComponentType CompType, int32_t M,
238238 std::vector<HLSLHalf_t> Halfs (NumElements);
239239
240240 for (int32_t I = 0 ; I < M; ++I) {
241- for (int32_t J = 0 ; J < M ; ++J) {
241+ for (int32_t J = 0 ; J < N ; ++J) {
242242 int32_t Value = I * M + J;
243243 int32_t Idx = Transpose ? J * N + I : Value;
244244 switch (CompType) {
@@ -397,6 +397,7 @@ static const char LoadStoreShader[] = R"(
397397 RWByteAddressBuffer Output : register(u1);
398398
399399#ifndef EMULATE_TEST
400+ [WaveSize(4, 64)]
400401 [numthreads(NUMTHREADS, 1, 1)]
401402 void main() {
402403 __builtin_LinAlgMatrix
@@ -476,7 +477,7 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() {
476477 Params.Use = MatrixUse::A;
477478 Params.Scope = MatrixScope::Wave;
478479 Params.Layout = LinalgMatrixLayout::RowMajor;
479- Params.NumThreads = 4 ;
480+ Params.NumThreads = 64 ;
480481 Params.Enable16Bit = true ;
481482 Params.EmulateTest = EmulateTest;
482483 runLoadStoreRoundtrip (D3DDevice, DxcSupport, Params, VerboseLogging,
@@ -487,6 +488,7 @@ static const char SplatStoreShader[] = R"(
487488 RWByteAddressBuffer Output : register(u0);
488489
489490#ifndef EMULATE_TEST
491+ [WaveSize(4, 64)]
490492 [numthreads(NUMTHREADS, 1, 1)]
491493 void main() {
492494 __builtin_LinAlgMatrix
@@ -555,7 +557,7 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
555557 Params.Use = MatrixUse::Accumulator;
556558 Params.Scope = MatrixScope::Wave;
557559 Params.Layout = LinalgMatrixLayout::RowMajor;
558- Params.NumThreads = 4 ;
560+ Params.NumThreads = 64 ;
559561 Params.Enable16Bit = true ;
560562 Params.EmulateTest = EmulateTest;
561563 runSplatStore (D3DDevice, DxcSupport, Params, 42 .0f , VerboseLogging,
0 commit comments