@@ -108,8 +108,6 @@ static std::string buildCompilerArgs(const MatrixParams &Params,
108108 SS << " -DELEM_TYPE=uint" ;
109109 break ;
110110 }
111- if (Params.EmulateTest )
112- SS << " -DEMULATE_TEST" ;
113111 if (Params.Enable16Bit )
114112 SS << " -enable-16bit-types" ;
115113 if (ExtraDefines)
@@ -282,12 +280,6 @@ static VariantCompType makeExpected(ComponentType CompType, int32_t M,
282280 }
283281}
284282
285- static void logCompiledButSkipping () {
286- hlsl_test::LogCommentFmt (
287- L" Shader compiled OK; skipping execution (no SM 6.10 device)" );
288- WEX::Logging::Log::Result (WEX::Logging::TestResults::Skipped);
289- }
290-
291283class DxilConf_SM610_LinAlg {
292284public:
293285 BEGIN_TEST_CLASS (DxilConf_SM610_LinAlg)
@@ -316,36 +308,16 @@ class DxilConf_SM610_LinAlg {
316308 TEST_METHOD (ElementAccess_Wave_16x16_F16);
317309
318310private:
319- D3D_SHADER_MODEL createDevice ();
320-
321311 CComPtr<ID3D12Device> D3DDevice;
322312 dxc::SpecificDllLoader DxcSupport;
323313 bool VerboseLogging = false ;
324- bool EmulateTest = false ;
325314 bool Initialized = false ;
326- bool CompileOnly = false ;
327315 std::optional<D3D12SDKSelector> D3D12SDK;
328316
329317 WEX::TestExecution::SetVerifyOutput VerifyOutput{
330318 WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures};
331319};
332320
333- /// Attempts to create a device. If shaders are being emulated, an SM6.8
334- /// device is requested; otherwise an SM6.10 device is requested.
335- D3D_SHADER_MODEL DxilConf_SM610_LinAlg::createDevice () {
336- if (EmulateTest) {
337- if (D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_8, false ))
338- return D3D_SHADER_MODEL_6_8;
339-
340- return D3D_SHADER_MODEL_NONE;
341- }
342-
343- if (D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_10, false ))
344- return D3D_SHADER_MODEL_6_10;
345-
346- return D3D_SHADER_MODEL_NONE;
347- }
348-
349321bool DxilConf_SM610_LinAlg::setupClass () {
350322 if (!Initialized) {
351323 Initialized = true ;
@@ -354,28 +326,12 @@ bool DxilConf_SM610_LinAlg::setupClass() {
354326 D3D12SDK = D3D12SDKSelector ();
355327 WEX::TestExecution::RuntimeParameters::TryGetValue (L" VerboseLogging" ,
356328 VerboseLogging);
357- WEX::TestExecution::RuntimeParameters::TryGetValue (L" EmulateTest" ,
358- EmulateTest);
359- D3D_SHADER_MODEL SupportedSM = createDevice ();
360-
361- if (EmulateTest) {
362- hlsl_test::LogWarningFmt (L" EmulateTest flag set. Tests are NOT REAL" );
363- if (SupportedSM != D3D_SHADER_MODEL_6_8) {
364- hlsl_test::LogErrorFmt (
365- L" Device creation failed. Expected a driver supporting SM6.8" );
366- return false ;
367- }
368- }
369329
370- #ifdef _HLK_CONF
371- if (SupportedSM != D3D_SHADER_MODEL_6_10) {
330+ if (!D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_10, false )) {
372331 hlsl_test::LogErrorFmt (
373332 L" Device creation failed. Expected a driver supporting SM6.10" );
374333 return false ;
375334 }
376- #endif
377-
378- CompileOnly = SupportedSM == D3D_SHADER_MODEL_NONE;
379335 }
380336
381337 return true ;
@@ -387,27 +343,24 @@ bool DxilConf_SM610_LinAlg::setupMethod() {
387343 if (D3DDevice && D3DDevice->GetDeviceRemovedReason () == S_OK)
388344 return true ;
389345
390- // Device is expected to be null. No point in recreating it
391- if (CompileOnly)
392- return true ;
393-
394346 hlsl_test::LogCommentFmt (L" Device was lost!" );
395347 D3DDevice.Release ();
396348
397349 hlsl_test::LogCommentFmt (L" Recreating device" );
398350
399- // !CompileOnly implies we expect it to succeed
400- return createDevice () != D3D_SHADER_MODEL_NONE;
351+ return D3D12SDK->createDevice (&D3DDevice, D3D_SHADER_MODEL_6_10, false );
401352}
402353
403354static const char LoadStoreShader[] = R"(
404355 RWByteAddressBuffer Input : register(u0);
405356 RWByteAddressBuffer Output : register(u1);
406357
407- #ifndef EMULATE_TEST
408358 [WaveSize(4, 64)]
409359 [numthreads(NUMTHREADS, 1, 1)]
410- void main() {
360+ void main(uint threadID : SV_GroupIndex) {
361+ if (WaveReadLaneFirst(threadID) != 0)
362+ return;
363+
411364 __builtin_LinAlgMatrix
412365 [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
413366 Mat;
@@ -416,45 +369,26 @@ static const char LoadStoreShader[] = R"(
416369 __builtin_LinAlg_MatrixStoreToDescriptor(
417370 Mat, Output, OFFSET, STRIDE, LAYOUT, 128);
418371 }
419- #else
420- [numthreads(NUMTHREADS, 1, 1)]
421- void main() {
422- for (uint I = 0; I < M_DIM*N_DIM; ++I) {
423- Output.Store<ELEM_TYPE>(I*ELEM_SIZE, Input.Load<ELEM_TYPE>(I*ELEM_SIZE));
424- }
425- }
426- #endif
427372)" ;
428373
429374static void runLoadStoreRoundtrip (ID3D12Device *Device,
430375 dxc::SpecificDllLoader &DxcSupport,
431- const MatrixParams &Params, bool Verbose,
432- bool CompileOnly) {
376+ const MatrixParams &Params, bool Verbose) {
433377 const size_t NumElements = Params.totalElements ();
434378 const size_t BufferSize = Params.totalBytes ();
435379
436- std::string Target = " cs_6_10" ;
437- if (Params.EmulateTest )
438- Target = " cs_6_8" ;
439-
440380 // TODO: these should be varied by test to ensure full coverage
441381 std::stringstream ExtraDefs;
442382 ExtraDefs << " -DOFFSET=" << 0 ;
443383
444384 std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
445385
446- // Always verify the shader compiles.
447- compileShader (DxcSupport, LoadStoreShader, Target.c_str (), Args, Verbose);
448-
449- if (CompileOnly) {
450- logCompiledButSkipping ();
451- return ;
452- }
386+ compileShader (DxcSupport, LoadStoreShader, " cs_6_10" , Args, Verbose);
453387
454388 auto Expected = makeExpected (Params.CompType , Params.M , Params.N , 1 );
455389
456390 // Construct the ShaderOp: two UAV buffers, load from one, store to other.
457- auto Op = createComputeOp (LoadStoreShader, Target. c_str () , " UAV(u0), UAV(u1)" ,
391+ auto Op = createComputeOp (LoadStoreShader, " cs_6_10 " , " UAV(u0), UAV(u1)" ,
458392 Args.c_str ());
459393 addUAVBuffer (Op.get (), " Input" , BufferSize, false , " byname" );
460394 addUAVBuffer (Op.get (), " Output" , BufferSize, true );
@@ -487,63 +421,45 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() {
487421 Params.Layout = LinalgMatrixLayout::RowMajor;
488422 Params.NumThreads = 64 ;
489423 Params.Enable16Bit = true ;
490- Params.EmulateTest = EmulateTest;
491- runLoadStoreRoundtrip (D3DDevice, DxcSupport, Params, VerboseLogging,
492- CompileOnly);
424+ runLoadStoreRoundtrip (D3DDevice, DxcSupport, Params, VerboseLogging);
493425}
494426
495427static const char SplatStoreShader[] = R"(
496428 RWByteAddressBuffer Output : register(u0);
497429
498- #ifndef EMULATE_TEST
499430 [WaveSize(4, 64)]
500431 [numthreads(NUMTHREADS, 1, 1)]
501- void main() {
432+ void main(uint threadID : SV_GroupIndex) {
433+ if (WaveReadLaneFirst(threadID) != 0)
434+ return;
435+
502436 __builtin_LinAlgMatrix
503437 [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
504438 Mat;
505439 __builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);
506440 __builtin_LinAlg_MatrixStoreToDescriptor(
507441 Mat, Output, 0, STRIDE, LAYOUT, 128);
508442 }
509- #else
510- [numthreads(NUMTHREADS, 1, 1)]
511- void main() {
512- ELEM_TYPE fill = FILL_VALUE;
513- for (uint I = 0; I < M_DIM*N_DIM; ++I) {
514- Output.Store<ELEM_TYPE>(I*ELEM_SIZE, fill);
515- }
516- }
517- #endif
518443)" ;
519444
520445static void runSplatStore (ID3D12Device *Device,
521446 dxc::SpecificDllLoader &DxcSupport,
522447 const MatrixParams &Params, float FillValue,
523- bool Verbose, bool CompileOnly ) {
448+ bool Verbose) {
524449 const size_t NumElements = Params.totalElements ();
525450 const size_t BufferSize = Params.totalBytes ();
526- std::string Target = " cs_6_10" ;
527- if (Params.EmulateTest )
528- Target = " cs_6_8" ;
529451
530452 std::stringstream ExtraDefs;
531453 ExtraDefs << " -DFILL_VALUE=" << FillValue;
532454
533455 std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
534456
535- // Always verify the shader compiles.
536- compileShader (DxcSupport, SplatStoreShader, Target.c_str (), Args, Verbose);
537-
538- if (CompileOnly) {
539- logCompiledButSkipping ();
540- return ;
541- }
457+ compileShader (DxcSupport, SplatStoreShader, " cs_6_10" , Args, Verbose);
542458
543459 auto Expected =
544460 makeExpected (Params.CompType , Params.M , Params.N , FillValue, false );
545461
546- auto Op = createComputeOp (SplatStoreShader, Target. c_str () , " UAV(u0)" ,
462+ auto Op = createComputeOp (SplatStoreShader, " cs_6_10 " , " UAV(u0)" ,
547463 Args.c_str ());
548464 addUAVBuffer (Op.get (), " Output" , BufferSize, true );
549465 addRootUAV (Op.get (), 0 , " Output" );
@@ -567,9 +483,7 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
567483 Params.Layout = LinalgMatrixLayout::RowMajor;
568484 Params.NumThreads = 64 ;
569485 Params.Enable16Bit = true ;
570- Params.EmulateTest = EmulateTest;
571- runSplatStore (D3DDevice, DxcSupport, Params, 42 .0f , VerboseLogging,
572- CompileOnly);
486+ runSplatStore (D3DDevice, DxcSupport, Params, 42 .0f , VerboseLogging);
573487}
574488
575489static const char ElementAccessShader[] = R"(
@@ -582,10 +496,12 @@ static const char ElementAccessShader[] = R"(
582496 return (coord.y * N_DIM + coord.x) * ELEM_SIZE;
583497 }
584498
585- #ifndef EMULATE_TEST
586499 [WaveSize(4, 64)]
587500 [numthreads(NUMTHREADS, 1, 1)]
588- void main(uint threadIndex : SV_GroupIndex) {
501+ void main(uint threadID : SV_GroupIndex) {
502+ if (WaveReadLaneFirst(threadID) != 0)
503+ return;
504+
589505 __builtin_LinAlgMatrix
590506 [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
591507 Mat;
@@ -603,30 +519,15 @@ static const char ElementAccessShader[] = R"(
603519
604520 // Save the matrix length that this thread saw. The length is written
605521 // to the output right after the matrix, offset by the thread index
606- uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
522+ uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadID * sizeof(uint));
607523 uint Len = __builtin_LinAlg_MatrixLength(Mat);
608524 Output.Store<uint>(LenIdx, Len);
609525 }
610- #else
611- [numthreads(NUMTHREADS, 1, 1)]
612- void main(uint threadIndex : SV_GroupIndex) {
613- uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
614- Output.Store<uint>(LenIdx, M_DIM * N_DIM / NUMTHREADS);
615-
616- if (threadIndex != 0)
617- return;
618-
619- for (uint I = 0; I < M_DIM*N_DIM; ++I) {
620- Output.Store<ELEM_TYPE>(I*ELEM_SIZE, Input.Load<ELEM_TYPE>(I*ELEM_SIZE));
621- }
622- }
623- #endif
624526)" ;
625527
626528static void runElementAccess (ID3D12Device *Device,
627529 dxc::SpecificDllLoader &DxcSupport,
628- const MatrixParams &Params, bool Verbose,
629- bool CompileOnly) {
530+ const MatrixParams &Params, bool Verbose) {
630531 const size_t NumElements = Params.totalElements ();
631532 const size_t NumThreads = Params.NumThreads ;
632533 const size_t InputBufSize = Params.totalBytes ();
@@ -638,23 +539,14 @@ static void runElementAccess(ID3D12Device *Device,
638539 const size_t OutputBufSize =
639540 NumElements * ElementSize + NumThreads * sizeof (uint32_t );
640541
641- std::string Target = " cs_6_10" ;
642- if (Params.EmulateTest )
643- Target = " cs_6_8" ;
644-
645542 std::stringstream ExtraDefs;
646543 std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
647544
648- compileShader (DxcSupport, ElementAccessShader, Target.c_str (), Args, Verbose);
649-
650- if (CompileOnly) {
651- logCompiledButSkipping ();
652- return ;
653- }
545+ compileShader (DxcSupport, ElementAccessShader, " cs_6_10" , Args, Verbose);
654546
655547 auto Expected = makeExpected (Params.CompType , Params.M , Params.N , 1 );
656548
657- auto Op = createComputeOp (ElementAccessShader, Target. c_str () ,
549+ auto Op = createComputeOp (ElementAccessShader, " cs_6_10 " ,
658550 " UAV(u0), UAV(u1)" , Args.c_str ());
659551 addUAVBuffer (Op.get (), " Input" , InputBufSize, false , " byname" );
660552 addUAVBuffer (Op.get (), " Output" , OutputBufSize, true );
@@ -700,8 +592,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
700592 Params.Layout = LinalgMatrixLayout::RowMajor;
701593 Params.NumThreads = 64 ;
702594 Params.Enable16Bit = true ;
703- Params.EmulateTest = EmulateTest;
704- runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);
595+ runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging);
705596}
706597
707598} // namespace LinAlg
0 commit comments