1212#include < string>
1313#include < utility> // std::forward
1414
15+ // Replace std::filesystem usage with Windows-specific headers
16+ #include < shlwapi.h>
17+ #pragma comment(lib, "shlwapi.lib")
18+
1519#include < pybind11/chrono.h>
1620#include < pybind11/complex.h>
1721#include < pybind11/functional.h>
@@ -38,6 +42,11 @@ using namespace pybind11::literals;
3842 case x: \
3943 return #x
4044
45+ // Architecture-specific defines
46+ #ifndef ARCHITECTURE
47+ #define ARCHITECTURE " win64" // Default to win64 if not defined during compilation
48+ #endif
49+
4150// -------------------------------------------------------------------------------------------------
4251// Class definitions
4352// -------------------------------------------------------------------------------------------------
@@ -201,6 +210,18 @@ SQLFreeStmtFunc SQLFreeStmt_ptr = nullptr;
201210// Diagnostic APIs
202211SQLGetDiagRecFunc SQLGetDiagRec_ptr = nullptr ;
203212
213+ // Move GetModuleDirectory outside namespace to resolve ambiguity
214+ std::string GetModuleDirectory () {
215+ py::object module = py::module::import (" mssql_python" );
216+ py::object module_path = module .attr (" __file__" );
217+ std::string module_file = module_path.cast <std::string>();
218+
219+ char path[MAX_PATH];
220+ strncpy_s (path, MAX_PATH, module_file.c_str (), module_file.length ());
221+ PathRemoveFileSpecA (path);
222+ return std::string (path);
223+ }
224+
204225// Smart wrapper around SQLHANDLE
205226class SqlHandle {
206227public:
@@ -247,48 +268,59 @@ void ThrowStdException(const std::string& message) { throw std::runtime_error(me
247268// Helper to load the driver
248269// TODO: We don't need to do explicit linking using LoadLibrary. We can just use implicit
249270// linking to load this DLL. It will simplify the code a lot.
250- void LoadDriverOrThrowException () {
251- HMODULE hDdbcModule;
252- wchar_t ddbcModulePath[MAX_PATH];
253- // Get the path to DDBC module:
254- // GetModuleHandleExW returns a handle to current shared library (ddbc_bindings.pyd) given a
255- // function from the library (LoadDriverOrThrowException). GetModuleFileNameW takes in the
256- // library handle (hDdbcModule) & returns the full path to this library (ddbcModulePath)
257- if (GetModuleHandleExW (
258- GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
259- (LPWSTR)&LoadDriverOrThrowException, &hDdbcModule) &&
260- GetModuleFileNameW (hDdbcModule, ddbcModulePath, MAX_PATH)) {
261- // Look for last occurence of '\' in the path and set it to null
262- wchar_t * lastBackSlash = wcsrchr (ddbcModulePath, L' \\ ' );
263- if (lastBackSlash == nullptr ) {
264- LOG (" Invalid DDBC module path - %S" , ddbcModulePath);
265- ThrowStdException (" Failed to load driver" );
266- }
267- *lastBackSlash = 0 ;
268- } else {
269- LOG (" Failed to get DDBC module path. Error code - %d" , GetLastError ());
270- ThrowStdException (" Failed to load driver" );
271+ std::wstring LoadDriverOrThrowException (const std::wstring& modulePath = L" " ) {
272+ std::wstring ddbcModulePath = modulePath;
273+ if (ddbcModulePath.empty ()) {
274+ // Get the module path if not provided
275+ std::string path = GetModuleDirectory ();
276+ ddbcModulePath = std::wstring (path.begin (), path.end ());
271277 }
272278
273- // Preload mssql-auth.dll from the same path if available
274- // TODO: Only load mssql-auth.dll if using Entra ID Authentication modes (Active Directory modes)
275- std::wstring authDllDir = std::wstring (ddbcModulePath) + L" \\ libs\\ win\\ mssql-auth.dll" ;
276- HMODULE hAuthModule = LoadLibraryW (authDllDir.c_str ());
277- if (hAuthModule) {
278- LOG (" Authentication library loaded successfully from - {}" , authDllDir.c_str ());
279+ std::wstring dllDir = ddbcModulePath;
280+ dllDir += L" \\ libs\\ " ;
281+
282+ // Convert ARCHITECTURE macro to wstring
283+ std::wstring archStr (ARCHITECTURE, ARCHITECTURE + strlen (ARCHITECTURE));
284+
285+ // Map architecture identifiers to correct subdirectory names
286+ std::wstring archDir;
287+ if (archStr == L" win64" || archStr == L" amd64" || archStr == L" x64" ) {
288+ archDir = L" x64" ;
289+ } else if (archStr == L" arm64" ) {
290+ archDir = L" arm64" ;
279291 } else {
280- LOG ( " Note: Authentication library not found at - {}. This is OK if you're not using Entra ID Authentication. " , authDllDir. c_str ()) ;
292+ archDir = L" x86 " ;
281293 }
282-
283- // Look for msodbcsql18.dll in a path relative to DDBC module
284- std::wstring dllDir = std::wstring (ddbcModulePath) + L" \\ libs\\ win\\ msodbcsql18.dll" ;
294+ dllDir += archDir;
295+ dllDir += L" \\ msodbcsql18.dll" ;
296+
297+ // Convert wstring to string for logging
298+ std::string dllDirStr (dllDir.begin (), dllDir.end ());
299+ LOG (" Attempting to load driver from - {}" , dllDirStr);
300+
285301 HMODULE hModule = LoadLibraryW (dllDir.c_str ());
286302 if (!hModule) {
287- LOG (" LoadLibraryW failed to load driver from - %S" , dllDir.c_str ());
288- ThrowStdException (" Failed to load driver" );
303+ // Failed to load the DLL, get the error message
304+ DWORD error = GetLastError ();
305+ char * messageBuffer = nullptr ;
306+ size_t size = FormatMessageA (
307+ FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
308+ NULL ,
309+ error,
310+ MAKELANGID (LANG_NEUTRAL, SUBLANG_DEFAULT),
311+ (LPSTR)&messageBuffer,
312+ 0 ,
313+ NULL
314+ );
315+ std::string errorMessage = messageBuffer ? std::string (messageBuffer, size) : " Unknown error" ;
316+ LocalFree (messageBuffer);
317+
318+ // Log the error message
319+ LOG (" Failed to load the driver with error code: {} - {}" , error, errorMessage);
320+ ThrowStdException (" Failed to load the ODBC driver. Please check that it is installed correctly." );
289321 }
290- LOG (" Driver loaded successfully from - {}" , dllDir.c_str ());
291322
323+ // If we got here, we've successfully loaded the DLL. Now get the function pointers.
292324 // Environment and handle function loading
293325 SQLAllocHandle_ptr = (SQLAllocHandleFunc)GetProcAddress (hModule, " SQLAllocHandle" );
294326 SQLSetEnvAttr_ptr = (SQLSetEnvAttrFunc)GetProcAddress (hModule, " SQLSetEnvAttr" );
@@ -331,16 +363,18 @@ void LoadDriverOrThrowException() {
331363 SQLSetStmtAttr_ptr && SQLGetConnectAttr_ptr && SQLDriverConnect_ptr &&
332364 SQLExecDirect_ptr && SQLPrepare_ptr && SQLBindParameter_ptr && SQLExecute_ptr &&
333365 SQLRowCount_ptr && SQLGetStmtAttr_ptr && SQLSetDescField_ptr && SQLFetch_ptr &&
334- SQLFetchScroll_ptr && SQLGetData_ptr && SQLNumResultCols_ptr &&
335- SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr &&
336- SQLColAttribute_ptr && SQLEndTran_ptr && SQLFreeHandle_ptr &&
337- SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr;
366+ SQLFetchScroll_ptr && SQLGetData_ptr && SQLNumResultCols_ptr &&
367+ SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr &&
368+ SQLColAttribute_ptr && SQLEndTran_ptr && SQLFreeHandle_ptr &&
369+ SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr;
338370
339371 if (!success) {
340- LOG (" Failed to load required function pointers from driver - %S " , dllDir. c_str () );
372+ LOG (" Failed to load required function pointers from driver - {} " , dllDirStr );
341373 ThrowStdException (" Failed to load required function pointers from driver" );
342374 }
343- LOG (" Sucessfully loaded function pointers from driver" );
375+ LOG (" Successfully loaded function pointers from driver" );
376+
377+ return dllDir;
344378}
345379
346380const char * GetSqlCTypeAsString (const SQLSMALLINT cType) {
@@ -369,7 +403,7 @@ const char* GetSqlCTypeAsString(const SQLSMALLINT cType) {
369403 STRINGIFY_FOR_CASE (SQL_C_GUID);
370404 STRINGIFY_FOR_CASE (SQL_C_DEFAULT);
371405 default :
372- return " Unkown " ;
406+ return " Unknown " ;
373407 }
374408}
375409
@@ -1988,29 +2022,50 @@ SQLLEN SQLRowCount_wrap(SqlHandlePtr StatementHandle) {
19882022 return rowCount;
19892023}
19902024
2025+ // Architecture-specific defines
2026+ #ifndef ARCHITECTURE
2027+ #define ARCHITECTURE " win64" // Default to win64 if not defined during compilation
2028+ #endif
2029+
19912030// Functions/data to be exposed to Python as a part of ddbc_bindings module
19922031PYBIND11_MODULE (ddbc_bindings, m) {
19932032 m.doc () = " msodbcsql driver api bindings for Python" ;
2033+
2034+ // Add architecture information as module attribute
2035+ m.attr (" __architecture__" ) = ARCHITECTURE;
2036+
2037+ // Expose architecture-specific constants
2038+ m.attr (" ARCHITECTURE" ) = ARCHITECTURE;
2039+
2040+ // Expose the C++ functions to Python
19942041 m.def (" ThrowStdException" , &ThrowStdException);
2042+
2043+ // Define parameter info class
19952044 py::class_<ParamInfo>(m, " ParamInfo" )
19962045 .def (py::init<>())
19972046 .def_readwrite (" inputOutputType" , &ParamInfo::inputOutputType)
19982047 .def_readwrite (" paramCType" , &ParamInfo::paramCType)
19992048 .def_readwrite (" paramSQLType" , &ParamInfo::paramSQLType)
20002049 .def_readwrite (" columnSize" , &ParamInfo::columnSize)
20012050 .def_readwrite (" decimalDigits" , &ParamInfo::decimalDigits);
2051+
2052+ // Define numeric data class
20022053 py::class_<NumericData>(m, " NumericData" )
20032054 .def (py::init<>())
20042055 .def (py::init<SQLCHAR, SQLSCHAR, SQLCHAR, std::uint64_t >())
20052056 .def_readwrite (" precision" , &NumericData::precision)
20062057 .def_readwrite (" scale" , &NumericData::scale)
20072058 .def_readwrite (" sign" , &NumericData::sign)
20082059 .def_readwrite (" val" , &NumericData::val);
2060+
2061+ // Define error info class
20092062 py::class_<ErrorInfo>(m, " ErrorInfo" )
20102063 .def_readwrite (" sqlState" , &ErrorInfo::sqlState)
20112064 .def_readwrite (" ddbcErrorMsg" , &ErrorInfo::ddbcErrorMsg);
2065+
20122066 py::class_<SqlHandle, SqlHandlePtr>(m, " SqlHandle" )
20132067 .def (" free" , &SqlHandle::free);
2068+
20142069 m.def (" DDBCSQLAllocHandle" , [](SQLSMALLINT HandleType, SqlHandlePtr InputHandle = nullptr ) {
20152070 SqlHandlePtr OutputHandle;
20162071 SQLRETURN rc = SQLAllocHandle_wrap (HandleType, InputHandle, OutputHandle);
@@ -2045,4 +2100,15 @@ PYBIND11_MODULE(ddbc_bindings, m) {
20452100 m.def (" DDBCSQLFreeHandle" , &SQLFreeHandle_wrap, " Free a handle" );
20462101 m.def (" DDBCSQLDisconnect" , &SQLDisconnect_wrap, " Disconnect from a data source" );
20472102 m.def (" DDBCSQLCheckError" , &SQLCheckError_Wrap, " Check for driver errors" );
2103+
2104+ // Add a version attribute
2105+ m.attr (" __version__" ) = " 1.0.0" ;
2106+
2107+ try {
2108+ // Try loading the ODBC driver when the module is imported
2109+ LoadDriverOrThrowException ();
2110+ } catch (const std::exception& e) {
2111+ // Log the error but don't throw - let the error happen when functions are called
2112+ LOG (" Failed to load ODBC driver during module initialization: {}" , e.what ());
2113+ }
20482114}
0 commit comments