1414
1515using namespace gpu ;
1616
17- const char * versionToStr (int version);
18-
1917static const char *kShaderMatmul1 = R"(
2018@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
2119@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@@ -468,123 +466,6 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
468466 }
469467}
470468
471- /* 2D block-tiling with transpose
472- *
473- */
474- static const char *kShaderMatmulWithTranspose = R"(
475- @group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
476- @group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
477- @group(0) @binding(2) var<storage, read_write> c: array<vec4<{{precision}}>>;
478- var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
479- var<workgroup> tileB: array<{{precision}}, {{BK}} * {{BN}}>;
480-
481- @compute @workgroup_size({{workgroupSize}})
482- fn main(
483- @builtin(global_invocation_id) globalID : vec3<u32>,
484- @builtin(local_invocation_id) localID : vec3<u32>,
485- @builtin(workgroup_id) groupid : vec3<u32>) {
486-
487- var threadResults: array<vec4<{{precision}}>, {{TM}} * {{TN4}}>;
488- var localM: array<{{precision}}, {{TM}}>;
489- var localN: array<vec4<{{precision}}>, {{TN4}}>;
490-
491- let cRow: u32 = groupid.x;
492- let cCol: u32 = groupid.y;
493- let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});
494-
495- // position of the first c element computed by the thread
496- let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
497- let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};
498-
499- // aPtr and bPtr are the starting positions of the tiles in a and b,
500- // incremented in the bkidx loop.
501- // cPtr is the starting position of the tile in c which is fixed.
502-
503- var aPtr: u32 = cRow * {{BM}} * {{K}};
504- var bPtr: u32 = cCol * {{BN}};
505- let cPtr: u32 = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
506-
507- for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
508-
509- // Load tile
510- // Load BM x BK by numThread(BM * BN / (TM * TN))
511- // The number of iteration == BM * BK / (BM * BN / (TM * TN))
512- for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) {
513- tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}];
514- }
515- // Load BK x BN by numThread(BM * BN / (TM * TN))
516- // The number of iteration == BK * BN / (BM * BN / (TM * TN))
517- for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
518- tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BN}}) * {{N}} + ((localID.x + idx * numThread) % {{BN}})];
519- }
520-
521- aPtr += {{BK}};
522- bPtr += {{BK}} * {{N}};
523-
524- workgroupBarrier();
525- // Compute tile
526- for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
527- for (var idx: u32 = 0; idx < {{TM}}; idx++) {
528- localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
529- }
530- for (var idx: u32 = 0; idx < {{TN4}}; idx++) {
531- localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) + dotIdx * {{BN}}],
532- tileB[(threadCol + idx*4 + 1) + dotIdx * {{BN}}],
533- tileB[(threadCol + idx*4 + 2) + dotIdx * {{BN}}],
534- tileB[(threadCol + idx*4 + 3) + dotIdx * {{BN}}]);
535- }
536- for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
537- for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
538- threadResults[resIdxM * {{TN4}} + resIdxN] += localM[resIdxM] * localN[resIdxN];
539- }
540- }
541- }
542- workgroupBarrier();
543- }
544-
545- for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
546- for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
547- c[cPtr + (threadRow + resIdxM) * {{N4}} + (threadCol/4) + resIdxN] = threadResults[resIdxM * {{TN4}} + resIdxN];
548- }
549- }
550- }
551- )" ;
552-
553- inline KernelCode createMatmulWithTranspose (const char *shaderTemplate, const size_t M,
554- const size_t K, const size_t N, const size_t BM,
555- const size_t BK, const size_t BN,
556- const size_t TM, const size_t TN,
557- const Shape &workgroupSize = {256 , 1 , 1 },
558- NumType precision = kf32) {
559- assert (BM % TM == 0 );
560- assert (BN % TN == 0 );
561- assert (K % BK == 0 );
562- assert (M % BM == 0 );
563- assert (N % BN == 0 );
564- // # threads = tile A size == tile B size == # threads for computing C
565- int num_threads = BM * BN / (TM * TN);
566- std::string codeString (shaderTemplate);
567- replaceAll (codeString, {{" {{workgroupSize}}" , toString (workgroupSize)},
568- {" {{precision}}" , toString (precision)},
569- {" {{M}}" , toString (M)},
570- {" {{K}}" , toString (K)},
571- {" {{N}}" , toString (N)},
572- {" {{BM}}" , toString (BM)},
573- {" {{BK}}" , toString (BK)},
574- {" {{BN}}" , toString (BN)},
575- {" {{TM}}" , toString (TM)},
576- {" {{TN}}" , toString (TN)},
577- {" {{NUM_TILEA}}" , toString (BM * BK / num_threads)},
578- {" {{NUM_TILEB}}" , toString (BN * BK / num_threads)},
579- {" {{TN4}}" , toString (TN / 4 )},
580- {" {{N4}}" , toString (N / 4 )},
581- {" {{BN4}}" , toString (BN / 4 )},
582- });
583- std::string unrolledCode = loopUnrolling (codeString);
584- // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
585- return {unrolledCode, workgroupSize};
586- }
587-
588469/* *
589470 * @brief No-Op shader with matmul bindings for performance testing
590471 */
@@ -638,26 +519,20 @@ Kernel selectMatmul(Context &ctx, int version,
638519 size_t M, size_t K, size_t N) {
639520 Kernel kernel;
640521 if (version == 1 ) {
641- Shape wgSize = {256 , 1 , 1 };
642- Shape nWorkgroups = cdiv ({M, N, 1 }, {16 , 16 , 1 });
643- KernelCode matmul = createNoOp (kShaderNoOp , /* wgsize*/ wgSize);
644- kernel = createKernel (ctx, matmul, bindings,
645- /* nWorkgroups*/ nWorkgroups);
646- } else if (version == 2 ) {
647522 Shape wgSize = {16 , 16 , 1 };
648523 LOG (kDefLog , kInfo , " wgSize: %s" , toString (wgSize).c_str ());
649524 KernelCode matmul =
650525 createMatmul1 (kShaderMatmul1 , M, K, N, /* wgsize*/ wgSize);
651526 kernel = createKernel (ctx, matmul, bindings,
652527 /* nWorkgroups*/ cdiv ({M, N, 1 }, wgSize));
653- } else if (version == 3 ) {
528+ } else if (version == 2 ) {
654529 static constexpr size_t tileSize = 16 ;
655530 KernelCode matmul = createMatmul2 (kShaderMatmul2 , M, K, N,
656531 /* wgSize*/ {tileSize * tileSize, 1 , 1 });
657532 kernel =
658533 createKernel (ctx, matmul, bindings,
659534 /* nWorkgroups*/ cdiv ({M, N, 1 }, {tileSize, tileSize, 1 }));
660- } else if (version == 4 || version == 6 ) {
535+ } else if (version == 3 || version == 5 ) {
661536 static constexpr size_t BM = 64 ;
662537 static constexpr size_t BK = 4 ;
663538 static constexpr size_t BN = BM;
@@ -673,10 +548,10 @@ Kernel selectMatmul(Context &ctx, int version,
673548 KernelCode matmul = createMatmul3 (kShaderMatmul3 , M, K, N, BM, BK, BN, TM,
674549 /* wgSize*/ wgSize,
675550 kf32,
676- /* Loop unrolling*/ version == 6 ? true : false );
551+ /* Loop unrolling*/ version == 5 ? true : false );
677552 kernel = createKernel (ctx, matmul, bindings,
678553 /* nWorkgroups*/ nWorkgroups);
679- } else if (version == 5 || version == 7 ) {
554+ } else if (version == 4 || version == 6 ) {
680555 static constexpr size_t BM = 64 ;
681556 static constexpr size_t BK = 8 ;
682557 static constexpr size_t BN = 64 ;
@@ -691,10 +566,10 @@ Kernel selectMatmul(Context &ctx, int version,
691566 KernelCode matmul = createMatmul4 (kShaderMatmul4 , M, K, N, BM, BK, BN, TM, TN,
692567 /* wgSize*/ wgSize,
693568 kf32,
694- /* Loop unrolling*/ version == 7 ? true : false );
569+ /* Loop unrolling*/ version == 6 ? true : false );
695570 kernel = createKernel (ctx, matmul, bindings,
696571 /* nWorkgroups*/ nWorkgroups);
697- } else if (version == 8 ) {
572+ } else if (version == 7 ) {
698573 static constexpr size_t BM = 64 ;
699574 static constexpr size_t BK = 8 ;
700575 static constexpr size_t BN = 64 ;
@@ -712,21 +587,10 @@ Kernel selectMatmul(Context &ctx, int version,
712587 /* Loop unrolling*/ true );
713588 kernel = createKernel (ctx, matmul, bindings,
714589 /* nWorkgroups*/ nWorkgroups);
715- } else if (version == 9 ) {
716- static constexpr size_t BM = 64 ;
717- static constexpr size_t BK = 8 ;
718- static constexpr size_t BN = 64 ;
719- static constexpr size_t TM = BM / BK;
720- static constexpr size_t TN = BN / BK;
721- Shape wgSize = {(BM / TM) * (BN / TN), 1 , 1 }; // This is the same as BK * BK.
722- Shape nWorkgroups = {cdiv (M, BM), cdiv (N, BN), 1 };
723- LOG (kDefLog , kInfo , " M: %d, K: %d, N: %d" , M, K, N);
724- LOG (kDefLog , kInfo , " BM: %d, BK: %d, BN: %d, TM: %d, TN: %d" , BM, BK, BN, TM, TN);
725- LOG (kDefLog , kInfo , " wgSize: ( %s )" , toString (wgSize).c_str ());
726- LOG (kDefLog , kInfo , " nWorkgroups: ( %s )" , toString (nWorkgroups).c_str ());
727- KernelCode matmul = createMatmulWithTranspose (kShaderMatmulWithTranspose , M, K, N, BM, BK, BN, TM, TN,
728- /* wgSize*/ wgSize,
729- kf32);
590+ } else if (version == 8 ) {
591+ Shape wgSize = {256 , 1 , 1 };
592+ Shape nWorkgroups = cdiv ({M, N, 1 }, {16 , 16 , 1 });
593+ KernelCode matmul = createNoOp (kShaderNoOp , /* wgsize*/ wgSize);
730594 kernel = createKernel (ctx, matmul, bindings,
731595 /* nWorkgroups*/ nWorkgroups);
732596 }
@@ -762,8 +626,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
762626
763627 printf (" [ Press enter to start tests ... ]\n " );
764628 getchar ();
765- LOG (kDefLog , kInfo , " Dispatching Kernel version %d: %s , %d iterations ..." ,
766- version, versionToStr (version), nIter);
629+ LOG (kDefLog , kInfo , " Dispatching Kernel version %d, %d iterations ..." ,
630+ version, nIter);
767631
768632 // Dispatch kernel nIter times
769633 auto start = std::chrono::high_resolution_clock::now ();
@@ -798,43 +662,26 @@ void runTest(int version, size_t M, size_t K, size_t N,
798662 M, K, N, nIter, duration.count () / static_cast <double >(nIter) / 1000.0 /* us -> ms */ , gflops);
799663}
800664
801- const char * versionToStr (int version){
802- switch (version) {
803- case 1 : return " No-Op" ;
804- case 2 : return " naive matmul" ;
805- case 3 : return " tiling" ;
806- case 4 : return " 1D blocktiling" ;
807- case 5 : return " 2D blocktiling" ;
808- case 6 : return " 1D blocktiling with loop unrolling" ;
809- case 7 : return " 2D blocktiling with loop unrolling" ;
810- case 8 : return " 2D blocktiling with loop unrolling and vectorization" ;
811- case 9 : return " 2D blocktiling with loop unrolling, vectorization and transpose" ;
812- default : return " Not specified" ;
813- }
814- }
815-
816665int main () {
817666 char * version_str = getenv (" MATMUL_VERSION" );
818- char * kTestSize_str = getenv (" MATMUL_SIZE" );
819- int version = version_str == NULL ? 9 : atoi (version_str);
820- // 1 == No-Op
821- // 2 == naive matmul
822- // 3 == tiling
823- // 4 == 1D blocktiling
824- // 5 == 2D blocktiling
825- // 6 == 1D blocktiling with loop unrolling
826- // 7 == 2D blocktiling with loop unrolling
827- // 8 == 2D blocktiling with loop unrolling and vectorization
828- // 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
667+ int version = version_str == NULL ? 7 : atoi (version_str);
668+ // 1 == naive matmul
669+ // 2 == tiling
670+ // 3 == 1D blocktiling
671+ // 4 == 2D blocktiling
672+ // 5 == 1D blocktiling with loop unrolling
673+ // 6 == 2D blocktiling with loop unrolling
674+ // 7 == 2D blocktiling with loop unrolling and vectorization
675+ // 8 == No-Op
829676
830677 size_t M, K, N; // Matrix dimensions
831- int kTestSize = kTestSize_str == NULL ? 2 : atoi ( kTestSize_str ) ;
832- if (kTestSize == 0 ) {
678+ static constexpr int kTestSize = 2 ;
679+ if constexpr (kTestSize == 0 ) {
833680 // Tiny test
834681 M = 32 ;
835682 K = 32 ;
836683 N = 32 ;
837- } else if (kTestSize == 1 ) {
684+ } else if constexpr (kTestSize == 1 ) {
838685 // Small test
839686 M = 256 ;
840687 K = 128 ;
@@ -849,19 +696,11 @@ int main() {
849696 std::unique_ptr<float []> inputPtr = std::make_unique<float []>(M * K);
850697 std::unique_ptr<float []> weightsPtr = std::make_unique<float []>(N * K);
851698 std::unique_ptr<float []> outputPtr = std::make_unique<float []>(M * N);
852- bool transposedInput = version == 9 ;
853699
854700 initData (M, K, N, inputPtr, weightsPtr);
855- if (transposedInput) {
856- std::unique_ptr<float []> transposedWeightPtr = std::make_unique<float []>(K * N);
857- transpose (weightsPtr.get (), transposedWeightPtr.get (), N, K);
858- runTest (version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
859- } else {
860- runTest (version, M, K, N, inputPtr, weightsPtr, outputPtr);
861- }
862-
701+ runTest (version, M, K, N, inputPtr, weightsPtr, outputPtr);
863702
864- if (kTestSize <= 1 ) {
703+ if constexpr (kTestSize <= 1 ) {
865704 // Check result with CPU reference implementation for tiny/small tests
866705 checkCPU (M, K, N, inputPtr, weightsPtr, outputPtr);
867706 }
0 commit comments