Skip to content

Commit eb8a8fc

Browse files
committed
changes for ddbc_bindings
1 parent cb53741 commit eb8a8fc

1 file changed

Lines changed: 107 additions & 41 deletions

File tree

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 107 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
#include <string>
1313
#include <utility> // std::forward
1414

15+
// Replace std::filesystem usage with Windows-specific headers
16+
#include <shlwapi.h>
17+
#pragma comment(lib, "shlwapi.lib")
18+
1519
#include <pybind11/chrono.h>
1620
#include <pybind11/complex.h>
1721
#include <pybind11/functional.h>
@@ -38,6 +42,11 @@ using namespace pybind11::literals;
3842
case x: \
3943
return #x
4044

45+
// Architecture-specific defines
46+
#ifndef ARCHITECTURE
47+
#define ARCHITECTURE "win64" // Default to win64 if not defined during compilation
48+
#endif
49+
4150
//-------------------------------------------------------------------------------------------------
4251
// Class definitions
4352
//-------------------------------------------------------------------------------------------------
@@ -201,6 +210,18 @@ SQLFreeStmtFunc SQLFreeStmt_ptr = nullptr;
201210
// Diagnostic APIs
202211
SQLGetDiagRecFunc SQLGetDiagRec_ptr = nullptr;
203212

213+
// Move GetModuleDirectory outside namespace to resolve ambiguity
214+
std::string GetModuleDirectory() {
215+
py::object module = py::module::import("mssql_python");
216+
py::object module_path = module.attr("__file__");
217+
std::string module_file = module_path.cast<std::string>();
218+
219+
char path[MAX_PATH];
220+
strncpy_s(path, MAX_PATH, module_file.c_str(), module_file.length());
221+
PathRemoveFileSpecA(path);
222+
return std::string(path);
223+
}
224+
204225
// Smart wrapper around SQLHANDLE
205226
class SqlHandle {
206227
public:
@@ -247,48 +268,59 @@ void ThrowStdException(const std::string& message) { throw std::runtime_error(me
247268
// Helper to load the driver
248269
// TODO: We don't need to do explicit linking using LoadLibrary. We can just use implicit
249270
// linking to load this DLL. It will simplify the code a lot.
250-
void LoadDriverOrThrowException() {
251-
HMODULE hDdbcModule;
252-
wchar_t ddbcModulePath[MAX_PATH];
253-
// Get the path to DDBC module:
254-
// GetModuleHandleExW returns a handle to current shared library (ddbc_bindings.pyd) given a
255-
// function from the library (LoadDriverOrThrowException). GetModuleFileNameW takes in the
256-
// library handle (hDdbcModule) & returns the full path to this library (ddbcModulePath)
257-
if (GetModuleHandleExW(
258-
GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
259-
(LPWSTR)&LoadDriverOrThrowException, &hDdbcModule) &&
260-
GetModuleFileNameW(hDdbcModule, ddbcModulePath, MAX_PATH)) {
261-
// Look for last occurence of '\' in the path and set it to null
262-
wchar_t* lastBackSlash = wcsrchr(ddbcModulePath, L'\\');
263-
if (lastBackSlash == nullptr) {
264-
LOG("Invalid DDBC module path - %S", ddbcModulePath);
265-
ThrowStdException("Failed to load driver");
266-
}
267-
*lastBackSlash = 0;
268-
} else {
269-
LOG("Failed to get DDBC module path. Error code - %d", GetLastError());
270-
ThrowStdException("Failed to load driver");
271+
std::wstring LoadDriverOrThrowException(const std::wstring& modulePath = L"") {
272+
std::wstring ddbcModulePath = modulePath;
273+
if (ddbcModulePath.empty()) {
274+
// Get the module path if not provided
275+
std::string path = GetModuleDirectory();
276+
ddbcModulePath = std::wstring(path.begin(), path.end());
271277
}
272278

273-
// Preload mssql-auth.dll from the same path if available
274-
// TODO: Only load mssql-auth.dll if using Entra ID Authentication modes (Active Directory modes)
275-
std::wstring authDllDir = std::wstring(ddbcModulePath) + L"\\libs\\win\\mssql-auth.dll";
276-
HMODULE hAuthModule = LoadLibraryW(authDllDir.c_str());
277-
if (hAuthModule) {
278-
LOG("Authentication library loaded successfully from - {}", authDllDir.c_str());
279+
std::wstring dllDir = ddbcModulePath;
280+
dllDir += L"\\libs\\";
281+
282+
// Convert ARCHITECTURE macro to wstring
283+
std::wstring archStr(ARCHITECTURE, ARCHITECTURE + strlen(ARCHITECTURE));
284+
285+
// Map architecture identifiers to correct subdirectory names
286+
std::wstring archDir;
287+
if (archStr == L"win64" || archStr == L"amd64" || archStr == L"x64") {
288+
archDir = L"x64";
289+
} else if (archStr == L"arm64") {
290+
archDir = L"arm64";
279291
} else {
280-
LOG("Note: Authentication library not found at - {}. This is OK if you're not using Entra ID Authentication.", authDllDir.c_str());
292+
archDir = L"x86";
281293
}
282-
283-
// Look for msodbcsql18.dll in a path relative to DDBC module
284-
std::wstring dllDir = std::wstring(ddbcModulePath) + L"\\libs\\win\\msodbcsql18.dll";
294+
dllDir += archDir;
295+
dllDir += L"\\msodbcsql18.dll";
296+
297+
// Convert wstring to string for logging
298+
std::string dllDirStr(dllDir.begin(), dllDir.end());
299+
LOG("Attempting to load driver from - {}", dllDirStr);
300+
285301
HMODULE hModule = LoadLibraryW(dllDir.c_str());
286302
if (!hModule) {
287-
LOG("LoadLibraryW failed to load driver from - %S", dllDir.c_str());
288-
ThrowStdException("Failed to load driver");
303+
// Failed to load the DLL, get the error message
304+
DWORD error = GetLastError();
305+
char* messageBuffer = nullptr;
306+
size_t size = FormatMessageA(
307+
FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
308+
NULL,
309+
error,
310+
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
311+
(LPSTR)&messageBuffer,
312+
0,
313+
NULL
314+
);
315+
std::string errorMessage = messageBuffer ? std::string(messageBuffer, size) : "Unknown error";
316+
LocalFree(messageBuffer);
317+
318+
// Log the error message
319+
LOG("Failed to load the driver with error code: {} - {}", error, errorMessage);
320+
ThrowStdException("Failed to load the ODBC driver. Please check that it is installed correctly.");
289321
}
290-
LOG("Driver loaded successfully from - {}", dllDir.c_str());
291322

323+
// If we got here, we've successfully loaded the DLL. Now get the function pointers.
292324
// Environment and handle function loading
293325
SQLAllocHandle_ptr = (SQLAllocHandleFunc)GetProcAddress(hModule, "SQLAllocHandle");
294326
SQLSetEnvAttr_ptr = (SQLSetEnvAttrFunc)GetProcAddress(hModule, "SQLSetEnvAttr");
@@ -331,16 +363,18 @@ void LoadDriverOrThrowException() {
331363
SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr &&
332364
SQLExecDirect_ptr && SQLPrepare_ptr && SQLBindParameter_ptr && SQLExecute_ptr &&
333365
SQLRowCount_ptr && SQLGetStmtAttr_ptr && SQLSetDescField_ptr && SQLFetch_ptr &&
334-
SQLFetchScroll_ptr && SQLGetData_ptr && SQLNumResultCols_ptr &&
335-
SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr &&
336-
SQLColAttribute_ptr && SQLEndTran_ptr && SQLFreeHandle_ptr &&
337-
SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr;
366+
SQLFetchScroll_ptr && SQLGetData_ptr && SQLNumResultCols_ptr &&
367+
SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr &&
368+
SQLColAttribute_ptr && SQLEndTran_ptr && SQLFreeHandle_ptr &&
369+
SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr;
338370

339371
if (!success) {
340-
LOG("Failed to load required function pointers from driver - %S", dllDir.c_str());
372+
LOG("Failed to load required function pointers from driver - {}", dllDirStr);
341373
ThrowStdException("Failed to load required function pointers from driver");
342374
}
343-
LOG("Sucessfully loaded function pointers from driver");
375+
LOG("Successfully loaded function pointers from driver");
376+
377+
return dllDir;
344378
}
345379

346380
const char* GetSqlCTypeAsString(const SQLSMALLINT cType) {
@@ -369,7 +403,7 @@ const char* GetSqlCTypeAsString(const SQLSMALLINT cType) {
369403
STRINGIFY_FOR_CASE(SQL_C_GUID);
370404
STRINGIFY_FOR_CASE(SQL_C_DEFAULT);
371405
default:
372-
return "Unkown";
406+
return "Unknown";
373407
}
374408
}
375409

@@ -1988,29 +2022,50 @@ SQLLEN SQLRowCount_wrap(SqlHandlePtr StatementHandle) {
19882022
return rowCount;
19892023
}
19902024

2025+
// Architecture-specific defines
2026+
#ifndef ARCHITECTURE
2027+
#define ARCHITECTURE "win64" // Default to win64 if not defined during compilation
2028+
#endif
2029+
19912030
// Functions/data to be exposed to Python as a part of ddbc_bindings module
19922031
PYBIND11_MODULE(ddbc_bindings, m) {
19932032
m.doc() = "msodbcsql driver api bindings for Python";
2033+
2034+
// Add architecture information as module attribute
2035+
m.attr("__architecture__") = ARCHITECTURE;
2036+
2037+
// Expose architecture-specific constants
2038+
m.attr("ARCHITECTURE") = ARCHITECTURE;
2039+
2040+
// Expose the C++ functions to Python
19942041
m.def("ThrowStdException", &ThrowStdException);
2042+
2043+
// Define parameter info class
19952044
py::class_<ParamInfo>(m, "ParamInfo")
19962045
.def(py::init<>())
19972046
.def_readwrite("inputOutputType", &ParamInfo::inputOutputType)
19982047
.def_readwrite("paramCType", &ParamInfo::paramCType)
19992048
.def_readwrite("paramSQLType", &ParamInfo::paramSQLType)
20002049
.def_readwrite("columnSize", &ParamInfo::columnSize)
20012050
.def_readwrite("decimalDigits", &ParamInfo::decimalDigits);
2051+
2052+
// Define numeric data class
20022053
py::class_<NumericData>(m, "NumericData")
20032054
.def(py::init<>())
20042055
.def(py::init<SQLCHAR, SQLSCHAR, SQLCHAR, std::uint64_t>())
20052056
.def_readwrite("precision", &NumericData::precision)
20062057
.def_readwrite("scale", &NumericData::scale)
20072058
.def_readwrite("sign", &NumericData::sign)
20082059
.def_readwrite("val", &NumericData::val);
2060+
2061+
// Define error info class
20092062
py::class_<ErrorInfo>(m, "ErrorInfo")
20102063
.def_readwrite("sqlState", &ErrorInfo::sqlState)
20112064
.def_readwrite("ddbcErrorMsg", &ErrorInfo::ddbcErrorMsg);
2065+
20122066
py::class_<SqlHandle, SqlHandlePtr>(m, "SqlHandle")
20132067
.def("free", &SqlHandle::free);
2068+
20142069
m.def("DDBCSQLAllocHandle", [](SQLSMALLINT HandleType, SqlHandlePtr InputHandle = nullptr) {
20152070
SqlHandlePtr OutputHandle;
20162071
SQLRETURN rc = SQLAllocHandle_wrap(HandleType, InputHandle, OutputHandle);
@@ -2045,4 +2100,15 @@ PYBIND11_MODULE(ddbc_bindings, m) {
20452100
m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle");
20462101
m.def("DDBCSQLDisconnect", &SQLDisconnect_wrap, "Disconnect from a data source");
20472102
m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors");
2103+
2104+
// Add a version attribute
2105+
m.attr("__version__") = "1.0.0";
2106+
2107+
try {
2108+
// Try loading the ODBC driver when the module is imported
2109+
LoadDriverOrThrowException();
2110+
} catch (const std::exception& e) {
2111+
// Log the error but don't throw - let the error happen when functions are called
2112+
LOG("Failed to load ODBC driver during module initialization: {}", e.what());
2113+
}
20482114
}

0 commit comments

Comments
 (0)