Skip to content

Commit 60ea20f

Browse files
committed
[SM6.10][HLK] Fix GetElement test, add transpose to helper
1 parent e9d74b2 commit 60ea20f

1 file changed

Lines changed: 46 additions & 36 deletions

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -228,37 +228,47 @@ static bool fillInputBuffer(LPCSTR Name, std::vector<BYTE> &Data,
228228
return false;
229229
}
230230

231-
static VariantCompType makeExpected(ComponentType CompType, size_t NumElements,
232-
float StartingVal, bool Increment) {
233-
switch (CompType) {
234-
case ComponentType::F32: {
235-
std::vector<float> Floats(NumElements);
236-
for (size_t I = 0; I < NumElements; I++)
237-
Floats[I] = StartingVal + static_cast<float>(Increment ? I : 0);
238-
return Floats;
239-
}
240-
case ComponentType::I32: {
241-
DXASSERT(StartingVal < static_cast<float>(INT_MAX),
242-
"Value too large to cast to int32_t");
243-
std::vector<int32_t> Ints(NumElements);
244-
for (size_t I = 0; I < NumElements; I++)
245-
Ints[I] = static_cast<int32_t>(StartingVal) +
246-
static_cast<int32_t>(Increment ? I : 0);
247-
return Ints;
248-
}
249-
case ComponentType::F16: {
250-
std::vector<HLSLHalf_t> Halfs(NumElements);
251-
for (size_t I = 0; I < NumElements; I++) {
252-
// Downcasting is safe here since HLSLHalf_t will clamp if F is too large.
253-
float F = StartingVal + static_cast<float>(Increment ? I : 0);
254-
Halfs[I] = HLSLHalf_t(F);
231+
/// Builds the reference contents of an M x N matrix used to validate shader
/// output.
///
/// The element at row I, column J carries the value
/// StartingVal + (I * N + J) when Increment is true, or StartingVal for every
/// element when Increment is false. Elements are laid out row-major; when
/// Transpose is true the element at (I, J) is instead written to slot
/// J * M + I, i.e. the row-major layout of the N x M transpose.
///
/// \param CompType    Selects which element type is generated and returned.
/// \param M           Number of rows.
/// \param N           Number of columns.
/// \param StartingVal Value of the first element (and of every element when
///                    Increment is false).
/// \param Increment   When true, each element adds its row-major linear index
///                    to StartingVal.
/// \param Transpose   When true, store the values in transposed order.
/// \returns A vector of M * N elements of the type matching CompType. For an
///          unrecognized CompType this asserts and returns M * N zero floats.
static VariantCompType makeExpected(ComponentType CompType, int32_t M, int32_t N,
                                    float StartingVal, bool Increment = true,
                                    bool Transpose = false) {
  DXASSERT(M >= 0 && N >= 0, "Matrix dimensions must be non-negative");
  // The I32 path truncates StartingVal to int32_t; check once, not per element.
  DXASSERT(CompType != ComponentType::I32 ||
               StartingVal < static_cast<float>(INT_MAX),
           "Value too large to cast to int32_t");

  const size_t NumElements = static_cast<size_t>(M) * static_cast<size_t>(N);

  // Only the vector matching CompType is sized; the other two stay empty.
  std::vector<float> Floats(CompType == ComponentType::F32 ? NumElements : 0);
  std::vector<int32_t> Ints(CompType == ComponentType::I32 ? NumElements : 0);
  std::vector<HLSLHalf_t> Halfs(CompType == ComponentType::F16 ? NumElements
                                                               : 0);

  for (int32_t I = 0; I < M; ++I) {
    // Columns run over N (not M) so non-square matrices are covered exactly
    // once and never indexed past M * N.
    for (int32_t J = 0; J < N; ++J) {
      // Row-major linear index of (I, J) in an M x N matrix.
      const int32_t Value = I * N + J;
      // In the N x M transpose, (I, J) lands at row J, column I.
      const size_t Idx = static_cast<size_t>(Transpose ? J * M + I : Value);
      const int32_t Offset = Increment ? Value : 0;

      switch (CompType) {
      case ComponentType::F32:
        Floats[Idx] = StartingVal + static_cast<float>(Offset);
        break;
      case ComponentType::I32:
        Ints[Idx] = static_cast<int32_t>(StartingVal) + Offset;
        break;
      case ComponentType::F16: {
        // Downcasting is safe here since HLSLHalf_t will clamp if F is too
        // large.
        const float F = StartingVal + static_cast<float>(Offset);
        Halfs[Idx] = HLSLHalf_t(F);
        break;
      }
      default:
        // Unsupported types are reported after the loop.
        break;
      }
    }
  }

  switch (CompType) {
  case ComponentType::F32:
    return Floats;
  case ComponentType::I32:
    return Ints;
  case ComponentType::F16:
    return Halfs;
  default:
    DXASSERT(false, "Unable to fill unexpected ComponentType");
    return std::vector<float>(NumElements);
  }
}
263273

264274
static void logCompiledButSkipping() {
@@ -429,7 +439,7 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
429439
return;
430440
}
431441

432-
auto Expected = makeExpected(Params.CompType, NumElements, 1, true);
442+
auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1);
433443

434444
// Construct the ShaderOp: two UAV buffers, load from one, store to other.
435445
auto Op = createComputeOp(LoadStoreShader, Target.c_str(), "UAV(u0), UAV(u1)",
@@ -517,7 +527,7 @@ static void runSplatStore(ID3D12Device *Device,
517527
return;
518528
}
519529

520-
auto Expected = makeExpected(Params.CompType, NumElements, FillValue, false);
530+
auto Expected = makeExpected(Params.CompType, Params.M, Params.N, FillValue, false);
521531

522532
auto Op = createComputeOp(SplatStoreShader, Target.c_str(), "UAV(u0)",
523533
Args.c_str());
@@ -553,11 +563,13 @@ static const char ElementAccessShader[] = R"(
553563
RWByteAddressBuffer Output : register(u1);
554564
555565
// flatten the 2D index into a 1D index then scale by element size
566+
// Always store row-major and work it out in the test runner
556567
uint coordToByteOffset(uint2 coord) {
557-
return (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE;
568+
return (coord.y * N_DIM + coord.x) * ELEM_SIZE;
558569
}
559570
560571
#ifndef EMULATE_TEST
572+
[WaveSize(4, 64)]
561573
[numthreads(NUMTHREADS, 1, 1)]
562574
void main(uint threadIndex : SV_GroupIndex) {
563575
__builtin_LinAlgMatrix
@@ -605,8 +617,7 @@ static void runElementAccess(ID3D12Device *Device,
605617
const size_t NumThreads = Params.NumThreads;
606618
const size_t InputBufSize = Params.totalBytes();
607619
const size_t ElementSize = elementSize(Params.CompType);
608-
const size_t MajorDim =
609-
Params.Layout == LinalgMatrixLayout::RowMajor ? Params.M : Params.N;
620+
610621
// Output: ElementSize bytes per element
611622
// 1 element for each mat idx
612623
// 1 uint for each thread's length
@@ -618,7 +629,6 @@ static void runElementAccess(ID3D12Device *Device,
618629
Target = "cs_6_8";
619630

620631
std::stringstream ExtraDefs;
621-
ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
622632
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
623633

624634
compileShader(DxcSupport, ElementAccessShader, Target.c_str(), Args, Verbose);
@@ -628,7 +638,7 @@ static void runElementAccess(ID3D12Device *Device,
628638
return;
629639
}
630640

631-
auto Expected = makeExpected(Params.CompType, NumElements, 1, true);
641+
auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1);
632642

633643
auto Op = createComputeOp(ElementAccessShader, Target.c_str(),
634644
"UAV(u0), UAV(u1)", Args.c_str());
@@ -674,7 +684,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
674684
Params.Use = MatrixUse::Accumulator;
675685
Params.Scope = MatrixScope::Wave;
676686
Params.Layout = LinalgMatrixLayout::RowMajor;
677-
Params.NumThreads = 4;
687+
Params.NumThreads = 64;
678688
Params.Enable16Bit = true;
679689
Params.EmulateTest = EmulateTest;
680690
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);

0 commit comments

Comments
 (0)