Skip to content

Commit fda52ff

Browse files
committed
FEAT: Access Token Login
1 parent d2b82b5 commit fda52ff

4 files changed

Lines changed: 129 additions & 37 deletions

File tree

mssql_python/connection.py

Lines changed: 81 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class Connection:
3333
close() -> None:
3434
"""
3535

36-
def __init__(self, connection_str: str, autocommit: bool = False, **kwargs) -> None:
36+
def __init__(self, connection_str: str, autocommit: bool = False, attrs_before: dict = {}, **kwargs) -> None:
3737
"""
3838
Initialize the connection object with the specified connection string and parameters.
3939
@@ -58,8 +58,9 @@ def __init__(self, connection_str: str, autocommit: bool = False, **kwargs) -> N
5858
self.connection_str = self._construct_connection_string(
5959
connection_str, **kwargs
6060
)
61+
self._attrs_before = attrs_before
62+
self._autocommit = autocommit # Initialize _autocommit before calling _initializer
6163
self._initializer()
62-
self._autocommit = autocommit
6364
self.setautocommit(autocommit)
6465

6566
def _construct_connection_string(self, connection_str: str, **kwargs) -> str:
@@ -76,23 +77,34 @@ def _construct_connection_string(self, connection_str: str, **kwargs) -> str:
7677
"""
7778
# Add the driver attribute to the connection string
7879
conn_str = add_driver_to_connection_str(connection_str)
79-
# Add additional key-value pairs to the connection string
80-
for key, value in kwargs.items():
81-
if key.lower() == "host":
82-
key = "Server"
83-
elif key.lower() == "user":
84-
key = "Uid"
85-
elif key.lower() == "password":
86-
key = "Pwd"
87-
elif key.lower() == "database":
88-
key = "Database"
89-
elif key.lower() == "encrypt":
90-
key = "Encrypt"
91-
elif key.lower() == "trust_server_certificate":
92-
key = "TrustServerCertificate"
93-
else:
94-
continue
95-
conn_str += f"{key}={value};"
80+
81+
# Check if access token authentication is being used
82+
if "attrs_before" in kwargs:
83+
# Skip adding Uid and Pwd for access token authentication
84+
if ENABLE_LOGGING:
85+
logger.info("Using access token authentication. Skipping Uid and Pwd.")
86+
else:
87+
# Add additional key-value pairs to the connection string
88+
for key, value in kwargs.items():
89+
if key.lower() == "host":
90+
key = "Server"
91+
elif key.lower() == "user":
92+
key = "Uid"
93+
elif key.lower() == "password":
94+
key = "Pwd"
95+
elif key.lower() == "database":
96+
key = "Database"
97+
elif key.lower() == "encrypt":
98+
key = "Encrypt"
99+
elif key.lower() == "trust_server_certificate":
100+
key = "TrustServerCertificate"
101+
else:
102+
continue
103+
conn_str += f"{key}={value};"
104+
105+
if ENABLE_LOGGING:
106+
logger.info("Final connection string: %s", conn_str)
107+
96108
return conn_str
97109

98110
def _is_closed(self) -> bool:
@@ -103,7 +115,7 @@ def _is_closed(self) -> bool:
103115
bool: True if the connection is closed, False otherwise.
104116
"""
105117
return self.hdbc is None
106-
118+
107119
def _initializer(self) -> None:
108120
"""
109121
Initialize the environment and connection handles.
@@ -115,9 +127,41 @@ def _initializer(self) -> None:
115127
self._allocate_environment_handle()
116128
self._set_environment_attributes()
117129
self._allocate_connection_handle()
118-
self._set_connection_attributes()
130+
if self._attrs_before != {}:
131+
self._apply_attrs_before() # Apply pre-connection attributes
132+
if self._autocommit:
133+
self._set_connection_attributes(
134+
ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value,
135+
ddbc_sql_const.SQL_AUTOCOMMIT_ON.value,
136+
)
119137
self._connect_to_db()
120138

139+
140+
def _apply_attrs_before(self):
141+
"""
142+
Apply a dictionary of attributes to the database connection before connecting.
143+
144+
Returns:
145+
bool: True if all attributes were successfully applied, False otherwise.
146+
"""
147+
strencoding = "utf-16le"
148+
149+
if ENABLE_LOGGING:
150+
logger.info("Applying attrs_before: %s", self._attrs_before)
151+
152+
for key, value in self._attrs_before.items():
153+
if isinstance(key, int):
154+
ikey = key
155+
elif isinstance(key, str) and key.isdigit():
156+
ikey = int(key)
157+
else:
158+
raise TypeError(f"Unsupported key type: {type(key).__name__}")
159+
160+
if not self._set_connection_attributes(ikey, value):
161+
return False
162+
163+
return True
164+
121165
def _allocate_environment_handle(self):
122166
"""
123167
Allocate the environment handle.
@@ -152,18 +196,25 @@ def _allocate_connection_handle(self):
152196
check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, handle, ret)
153197
self.hdbc = handle
154198

155-
def _set_connection_attributes(self):
199+
def _set_connection_attributes(self, ikey: int, ivalue: any) -> None:
156200
"""
157201
Set the connection attributes before connecting.
202+
203+
Args:
204+
ikey (int): The attribute key to set.
205+
ivalue (Any): The value to set for the attribute. Can be bytes, bytearray, int, or unicode.
206+
vallen (int): The length of the value.
207+
208+
Raises:
209+
DatabaseError: If there is an error while setting the connection attribute.
158210
"""
159-
if self.autocommit:
160-
ret = ddbc_bindings.DDBCSQLSetConnectAttr(
161-
self.hdbc, # Using the wrapper class
162-
ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value,
163-
ddbc_sql_const.SQL_AUTOCOMMIT_ON.value,
164-
0
165-
)
166-
check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret)
211+
212+
ret = ddbc_bindings.DDBCSQLSetConnectAttr(
213+
self.hdbc, # Connection handle
214+
ikey, # Attribute
215+
ivalue, # Value
216+
)
217+
check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret)
167218

168219
def _connect_to_db(self) -> None:
169220
"""
@@ -224,7 +275,6 @@ def autocommit(self, value: bool) -> None:
224275
if value
225276
else ddbc_sql_const.SQL_AUTOCOMMIT_OFF.value
226277
), # Value
227-
0, # String length
228278
)
229279
check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret)
230280
self._autocommit = value

mssql_python/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,5 @@ class ConstantsDDBC(Enum):
116116
SQL_C_WCHAR = -8
117117
SQL_NULLABLE = 1
118118
SQL_MAX_NUMERIC_LEN = 16
119+
SQL_IS_POINTER = -4
120+

mssql_python/db_connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mssql_python.connection import Connection
77

88

9-
def connect(connection_str: str, autocommit: bool = True, **kwargs) -> Connection:
9+
def connect(connection_str: str, autocommit: bool = True, attrs_before: dict = {}, **kwargs) -> Connection:
1010
"""
1111
Constructor for creating a connection to the database.
1212
@@ -34,5 +34,5 @@ def connect(connection_str: str, autocommit: bool = True, **kwargs) -> Connectio
3434
be used to perform database operations such as executing queries, committing
3535
transactions, and closing the connection.
3636
"""
37-
conn = Connection(connection_str, autocommit=autocommit, **kwargs)
37+
conn = Connection(connection_str, autocommit=autocommit, attrs_before=attrs_before, **kwargs)
3838
return conn

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -692,18 +692,58 @@ SQLRETURN SQLSetEnvAttr_wrap(SqlHandlePtr EnvHandle, SQLINTEGER Attribute, intpt
692692
}
693693

694694
// Wrap SQLSetConnectAttr
695-
SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute, intptr_t ValuePtr,
696-
SQLINTEGER StringLength) {
695+
SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute,
696+
py::object ValuePtr) {
697697
LOG("Set SQL Connection Attribute");
698698
if (!SQLSetConnectAttr_ptr) {
699699
LoadDriverOrThrowException();
700700
}
701701

702-
// TODO: Does ValuePtr need to be converted from Python to C++ object?
703-
SQLRETURN ret = SQLSetConnectAttr_ptr(ConnectionHandle->get(), Attribute, reinterpret_cast<SQLPOINTER>(ValuePtr), StringLength);
702+
// Print the type of ValuePtr and attribute value - helpful for debugging
703+
LOG("Type of ValuePtr: {}, Attribute: {}", py::type::of(ValuePtr).attr("__name__").cast<std::string>(), Attribute);
704+
705+
SQLPOINTER value = 0;
706+
SQLINTEGER length = 0;
707+
708+
if (py::isinstance<py::int_>(ValuePtr)) {
709+
// Handle integer values
710+
int intValue = ValuePtr.cast<int>();
711+
value = reinterpret_cast<SQLPOINTER>(intValue);
712+
length = SQL_IS_INTEGER; // Integer values don't require a length
713+
} else if (py::isinstance<py::str>(ValuePtr)) {
714+
// Handle Unicode string values
715+
static std::wstring unicodeValueBuffer;
716+
unicodeValueBuffer = ValuePtr.cast<std::wstring>();
717+
value = const_cast<SQLWCHAR*>(unicodeValueBuffer.c_str());
718+
length = SQL_NTS; // Indicates null-terminated string
719+
} else if (py::isinstance<py::bytes>(ValuePtr) || py::isinstance<py::bytearray>(ValuePtr)) {
720+
// Handle byte or bytearray values (like access tokens)
721+
// Store in static buffer to ensure memory remains valid during connection
722+
static std::vector<std::string> bytesBuffers;
723+
bytesBuffers.push_back(ValuePtr.cast<std::string>());
724+
value = const_cast<char*>(bytesBuffers.back().c_str());
725+
length = SQL_IS_POINTER; // Indicates we're passing a pointer (required for token)
726+
} else if (py::isinstance<py::list>(ValuePtr) || py::isinstance<py::tuple>(ValuePtr)) {
727+
// Handle list or tuple values
728+
LOG("ValuePtr is a sequence (list or tuple)");
729+
for (py::handle item : ValuePtr) {
730+
LOG("Processing item in sequence");
731+
SQLRETURN ret = SQLSetConnectAttr_wrap(ConnectionHandle, Attribute, py::reinterpret_borrow<py::object>(item));
732+
if (!SQL_SUCCEEDED(ret)) {
733+
LOG("Failed to set attribute for item in sequence");
734+
return ret;
735+
}
736+
}
737+
} else {
738+
LOG("Unsupported ValuePtr type");
739+
return SQL_ERROR;
740+
}
741+
742+
SQLRETURN ret = SQLSetConnectAttr_ptr(ConnectionHandle->get(), Attribute, value, length);
704743
if (!SQL_SUCCEEDED(ret)) {
705744
LOG("Failed to set Connection attribute");
706745
}
746+
LOG("Set Connection attribute successfully");
707747
return ret;
708748
}
709749

0 commit comments

Comments
 (0)