Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 34 additions & 137 deletions tools/clang/unittests/HLSLExec/LinAlgTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ static std::string buildCompilerArgs(const MatrixParams &Params,
SS << " -DELEM_TYPE=uint";
break;
}
if (Params.EmulateTest)
SS << " -DEMULATE_TEST";
if (Params.Enable16Bit)
SS << " -enable-16bit-types";
if (ExtraDefines)
Expand Down Expand Up @@ -282,12 +280,6 @@ static VariantCompType makeExpected(ComponentType CompType, int32_t M,
}
}

static void logCompiledButSkipping() {
hlsl_test::LogCommentFmt(
L"Shader compiled OK; skipping execution (no SM 6.10 device)");
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
}

class DxilConf_SM610_LinAlg {
public:
BEGIN_TEST_CLASS(DxilConf_SM610_LinAlg)
Expand Down Expand Up @@ -316,36 +308,16 @@ class DxilConf_SM610_LinAlg {
TEST_METHOD(ElementAccess_Wave_16x16_F16);

private:
D3D_SHADER_MODEL createDevice();

CComPtr<ID3D12Device> D3DDevice;
dxc::SpecificDllLoader DxcSupport;
bool VerboseLogging = false;
bool EmulateTest = false;
bool Initialized = false;
bool CompileOnly = false;
std::optional<D3D12SDKSelector> D3D12SDK;

WEX::TestExecution::SetVerifyOutput VerifyOutput{
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures};
};

/// Attempts to create a device. If shaders are being emulated then a SM6.8
/// device is attempted. Otherwise a SM6.10 device is attempted
D3D_SHADER_MODEL DxilConf_SM610_LinAlg::createDevice() {
if (EmulateTest) {
if (D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_8, false))
return D3D_SHADER_MODEL_6_8;

return D3D_SHADER_MODEL_NONE;
}

if (D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10, false))
return D3D_SHADER_MODEL_6_10;

return D3D_SHADER_MODEL_NONE;
}

bool DxilConf_SM610_LinAlg::setupClass() {
if (!Initialized) {
Initialized = true;
Expand All @@ -354,28 +326,18 @@ bool DxilConf_SM610_LinAlg::setupClass() {
D3D12SDK = D3D12SDKSelector();
WEX::TestExecution::RuntimeParameters::TryGetValue(L"VerboseLogging",
VerboseLogging);
WEX::TestExecution::RuntimeParameters::TryGetValue(L"EmulateTest",
EmulateTest);
D3D_SHADER_MODEL SupportedSM = createDevice();

if (EmulateTest) {
hlsl_test::LogWarningFmt(L"EmulateTest flag set. Tests are NOT REAL");
if (SupportedSM != D3D_SHADER_MODEL_6_8) {
hlsl_test::LogErrorFmt(
L"Device creation failed. Expected a driver supporting SM6.8");
return false;
}
}

if (!D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10, false)) {
#ifdef _HLK_CONF
if (SupportedSM != D3D_SHADER_MODEL_6_10) {
hlsl_test::LogErrorFmt(
L"Device creation failed. Expected a driver supporting SM6.10");
#else
hlsl_test::LogWarningFmt(
L"Device creation failed. Expected a driver supporting SM6.10");
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
#endif
return false;
}
#endif

CompileOnly = SupportedSM == D3D_SHADER_MODEL_NONE;
}

return true;
Expand All @@ -387,27 +349,24 @@ bool DxilConf_SM610_LinAlg::setupMethod() {
if (D3DDevice && D3DDevice->GetDeviceRemovedReason() == S_OK)
return true;

// Device is expected to be null. No point in recreating it
if (CompileOnly)
return true;

hlsl_test::LogCommentFmt(L"Device was lost!");
D3DDevice.Release();

hlsl_test::LogCommentFmt(L"Recreating device");

// !CompileOnly implies we expect it to succeeded
return createDevice() != D3D_SHADER_MODEL_NONE;
return D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10, false);
}

static const char LoadStoreShader[] = R"(
RWByteAddressBuffer Input : register(u0);
RWByteAddressBuffer Output : register(u1);

#ifndef EMULATE_TEST
[WaveSize(4, 64)]
[numthreads(NUMTHREADS, 1, 1)]
void main() {
void main(uint threadID : SV_GroupIndex) {
if (WaveReadLaneFirst(threadID) != 0)
return;

__builtin_LinAlgMatrix
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
Mat;
Expand All @@ -416,45 +375,26 @@ static const char LoadStoreShader[] = R"(
__builtin_LinAlg_MatrixStoreToDescriptor(
Mat, Output, OFFSET, STRIDE, LAYOUT, 128);
}
#else
[numthreads(NUMTHREADS, 1, 1)]
void main() {
for (uint I = 0; I < M_DIM*N_DIM; ++I) {
Output.Store<ELEM_TYPE>(I*ELEM_SIZE, Input.Load<ELEM_TYPE>(I*ELEM_SIZE));
}
}
#endif
)";

static void runLoadStoreRoundtrip(ID3D12Device *Device,
dxc::SpecificDllLoader &DxcSupport,
const MatrixParams &Params, bool Verbose,
bool CompileOnly) {
const MatrixParams &Params, bool Verbose) {
const size_t NumElements = Params.totalElements();
const size_t BufferSize = Params.totalBytes();

std::string Target = "cs_6_10";
if (Params.EmulateTest)
Target = "cs_6_8";

// TODO: these should be varied by test to ensure full coverage
std::stringstream ExtraDefs;
ExtraDefs << " -DOFFSET=" << 0;

std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

// Always verify the shader compiles.
compileShader(DxcSupport, LoadStoreShader, Target.c_str(), Args, Verbose);

if (CompileOnly) {
logCompiledButSkipping();
return;
}
compileShader(DxcSupport, LoadStoreShader, "cs_6_10", Args, Verbose);

auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1);

// Construct the ShaderOp: two UAV buffers, load from one, store to other.
auto Op = createComputeOp(LoadStoreShader, Target.c_str(), "UAV(u0), UAV(u1)",
auto Op = createComputeOp(LoadStoreShader, "cs_6_10", "UAV(u0), UAV(u1)",
Args.c_str());
addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname");
addUAVBuffer(Op.get(), "Output", BufferSize, true);
Expand Down Expand Up @@ -487,64 +427,46 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() {
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 64;
Params.Enable16Bit = true;
Params.EmulateTest = EmulateTest;
runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging,
CompileOnly);
runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging);
}

static const char SplatStoreShader[] = R"(
RWByteAddressBuffer Output : register(u0);

#ifndef EMULATE_TEST
[WaveSize(4, 64)]
[numthreads(NUMTHREADS, 1, 1)]
void main() {
void main(uint threadID : SV_GroupIndex) {
if (WaveReadLaneFirst(threadID) != 0)
return;

__builtin_LinAlgMatrix
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
Mat;
__builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);
__builtin_LinAlg_MatrixStoreToDescriptor(
Mat, Output, 0, STRIDE, LAYOUT, 128);
}
#else
[numthreads(NUMTHREADS, 1, 1)]
void main() {
ELEM_TYPE fill = FILL_VALUE;
for (uint I = 0; I < M_DIM*N_DIM; ++I) {
Output.Store<ELEM_TYPE>(I*ELEM_SIZE, fill);
}
}
#endif
)";

static void runSplatStore(ID3D12Device *Device,
dxc::SpecificDllLoader &DxcSupport,
const MatrixParams &Params, float FillValue,
bool Verbose, bool CompileOnly) {
bool Verbose) {
const size_t NumElements = Params.totalElements();
const size_t BufferSize = Params.totalBytes();
std::string Target = "cs_6_10";
if (Params.EmulateTest)
Target = "cs_6_8";

std::stringstream ExtraDefs;
ExtraDefs << "-DFILL_VALUE=" << FillValue;

std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

// Always verify the shader compiles.
compileShader(DxcSupport, SplatStoreShader, Target.c_str(), Args, Verbose);

if (CompileOnly) {
logCompiledButSkipping();
return;
}
compileShader(DxcSupport, SplatStoreShader, "cs_6_10", Args, Verbose);

auto Expected =
makeExpected(Params.CompType, Params.M, Params.N, FillValue, false);

auto Op = createComputeOp(SplatStoreShader, Target.c_str(), "UAV(u0)",
Args.c_str());
auto Op =
createComputeOp(SplatStoreShader, "cs_6_10", "UAV(u0)", Args.c_str());
addUAVBuffer(Op.get(), "Output", BufferSize, true);
addRootUAV(Op.get(), 0, "Output");

Expand All @@ -567,9 +489,7 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 64;
Params.Enable16Bit = true;
Params.EmulateTest = EmulateTest;
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging,
CompileOnly);
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging);
}

static const char ElementAccessShader[] = R"(
Expand All @@ -582,10 +502,12 @@ static const char ElementAccessShader[] = R"(
return (coord.y * N_DIM + coord.x) * ELEM_SIZE;
}

#ifndef EMULATE_TEST
[WaveSize(4, 64)]
[numthreads(NUMTHREADS, 1, 1)]
void main(uint threadIndex : SV_GroupIndex) {
void main(uint threadID : SV_GroupIndex) {
if (WaveReadLaneFirst(threadID) != 0)
return;

__builtin_LinAlgMatrix
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
Mat;
Expand All @@ -603,30 +525,15 @@ static const char ElementAccessShader[] = R"(

// Save the matrix length that this thread saw. The length is written
// to the output right after the matrix, offset by the thread index
uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadID * sizeof(uint));
uint Len = __builtin_LinAlg_MatrixLength(Mat);
Output.Store<uint>(LenIdx, Len);
}
#else
[numthreads(NUMTHREADS, 1, 1)]
void main(uint threadIndex : SV_GroupIndex) {
uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
Output.Store<uint>(LenIdx, M_DIM * N_DIM / NUMTHREADS);

if (threadIndex != 0)
return;

for (uint I = 0; I < M_DIM*N_DIM; ++I) {
Output.Store<ELEM_TYPE>(I*ELEM_SIZE, Input.Load<ELEM_TYPE>(I*ELEM_SIZE));
}
}
#endif
)";

static void runElementAccess(ID3D12Device *Device,
dxc::SpecificDllLoader &DxcSupport,
const MatrixParams &Params, bool Verbose,
bool CompileOnly) {
const MatrixParams &Params, bool Verbose) {
const size_t NumElements = Params.totalElements();
const size_t NumThreads = Params.NumThreads;
const size_t InputBufSize = Params.totalBytes();
Expand All @@ -638,24 +545,15 @@ static void runElementAccess(ID3D12Device *Device,
const size_t OutputBufSize =
NumElements * ElementSize + NumThreads * sizeof(uint32_t);

std::string Target = "cs_6_10";
if (Params.EmulateTest)
Target = "cs_6_8";

std::stringstream ExtraDefs;
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

compileShader(DxcSupport, ElementAccessShader, Target.c_str(), Args, Verbose);

if (CompileOnly) {
logCompiledButSkipping();
return;
}
compileShader(DxcSupport, ElementAccessShader, "cs_6_10", Args, Verbose);

auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1);

auto Op = createComputeOp(ElementAccessShader, Target.c_str(),
"UAV(u0), UAV(u1)", Args.c_str());
auto Op = createComputeOp(ElementAccessShader, "cs_6_10", "UAV(u0), UAV(u1)",
Args.c_str());
addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname");
addUAVBuffer(Op.get(), "Output", OutputBufSize, true);
addRootUAV(Op.get(), 0, "Input");
Expand Down Expand Up @@ -700,8 +598,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 64;
Params.Enable16Bit = true;
Params.EmulateTest = EmulateTest;
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
}

} // namespace LinAlg
Loading