|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +"""Tests for bulkcopy auth field cleanup in cursor.py. |
| 5 | +
|
| 6 | +When cursor.bulkcopy() acquires an Azure AD token, it must strip stale |
| 7 | +authentication/user_name/password keys from the pycore_context dict before |
| 8 | +passing it to mssql_py_core. The Rust validator rejects access_token |
| 9 | +combined with those fields (ODBC parity). |
| 10 | +""" |
| 11 | + |
| 12 | +import pytest |
| 13 | +import secrets |
| 14 | +from unittest.mock import MagicMock, patch, PropertyMock |
| 15 | + |
| 16 | +SAMPLE_TOKEN = secrets.token_hex(44) |
| 17 | + |
| 18 | + |
| 19 | +def _make_cursor(connection_str, auth_type): |
| 20 | + """Build a mock Cursor with just enough wiring for bulkcopy's auth path.""" |
| 21 | + from mssql_python.cursor import Cursor |
| 22 | + |
| 23 | + mock_conn = MagicMock() |
| 24 | + mock_conn.connection_str = connection_str |
| 25 | + mock_conn._auth_type = auth_type |
| 26 | + mock_conn._is_connected = True |
| 27 | + |
| 28 | + cursor = Cursor.__new__(Cursor) |
| 29 | + cursor._connection = mock_conn |
| 30 | + cursor.closed = False |
| 31 | + cursor.hstmt = None |
| 32 | + return cursor |
| 33 | + |
| 34 | + |
| 35 | +class TestBulkcopyAuthCleanup: |
| 36 | + """Verify cursor.bulkcopy strips stale auth fields after token acquisition.""" |
| 37 | + |
| 38 | + @patch("mssql_python.cursor.get_settings") |
| 39 | + @patch("mssql_python.cursor.logger") |
| 40 | + def test_token_replaces_auth_fields(self, mock_logger, mock_settings): |
| 41 | + """access_token present ⇒ authentication, user_name, password removed.""" |
| 42 | + mock_settings.return_value = MagicMock(logging=False) |
| 43 | + mock_logger.is_debug_enabled = False |
| 44 | + |
| 45 | + cursor = _make_cursor( |
| 46 | + "Server=tcp:test.database.windows.net;Database=testdb;" |
| 47 | + "Authentication=ActiveDirectoryDefault;UID=user@test.com;PWD=secret", |
| 48 | + "activedirectorydefault", |
| 49 | + ) |
| 50 | + |
| 51 | + captured_context = {} |
| 52 | + |
| 53 | + mock_pycore_cursor = MagicMock() |
| 54 | + mock_pycore_cursor.bulkcopy.return_value = { |
| 55 | + "rows_copied": 1, |
| 56 | + "batch_count": 1, |
| 57 | + "elapsed_time": 0.1, |
| 58 | + } |
| 59 | + mock_pycore_conn = MagicMock() |
| 60 | + mock_pycore_conn.cursor.return_value = mock_pycore_cursor |
| 61 | + |
| 62 | + def capture_context(ctx, **kwargs): |
| 63 | + captured_context.update(ctx) |
| 64 | + return mock_pycore_conn |
| 65 | + |
| 66 | + mock_pycore_module = MagicMock() |
| 67 | + mock_pycore_module.PyCoreConnection = capture_context |
| 68 | + |
| 69 | + with ( |
| 70 | + patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}), |
| 71 | + patch("mssql_python.auth.AADAuth.get_raw_token", return_value=SAMPLE_TOKEN), |
| 72 | + ): |
| 73 | + cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10) |
| 74 | + |
| 75 | + assert captured_context.get("access_token") == SAMPLE_TOKEN |
| 76 | + assert "authentication" not in captured_context |
| 77 | + assert "user_name" not in captured_context |
| 78 | + assert "password" not in captured_context |
| 79 | + |
| 80 | + @patch("mssql_python.cursor.get_settings") |
| 81 | + @patch("mssql_python.cursor.logger") |
| 82 | + def test_no_auth_type_leaves_fields_intact(self, mock_logger, mock_settings): |
| 83 | + """No _auth_type ⇒ credentials pass through unchanged (SQL auth path).""" |
| 84 | + mock_settings.return_value = MagicMock(logging=False) |
| 85 | + mock_logger.is_debug_enabled = False |
| 86 | + |
| 87 | + cursor = _make_cursor( |
| 88 | + "Server=tcp:test.database.windows.net;Database=testdb;" |
| 89 | + "UID=sa;PWD=password123", |
| 90 | + None, # no AD auth |
| 91 | + ) |
| 92 | + |
| 93 | + captured_context = {} |
| 94 | + |
| 95 | + mock_pycore_cursor = MagicMock() |
| 96 | + mock_pycore_cursor.bulkcopy.return_value = { |
| 97 | + "rows_copied": 1, |
| 98 | + "batch_count": 1, |
| 99 | + "elapsed_time": 0.1, |
| 100 | + } |
| 101 | + mock_pycore_conn = MagicMock() |
| 102 | + mock_pycore_conn.cursor.return_value = mock_pycore_cursor |
| 103 | + |
| 104 | + def capture_context(ctx, **kwargs): |
| 105 | + captured_context.update(ctx) |
| 106 | + return mock_pycore_conn |
| 107 | + |
| 108 | + mock_pycore_module = MagicMock() |
| 109 | + mock_pycore_module.PyCoreConnection = capture_context |
| 110 | + |
| 111 | + with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}): |
| 112 | + cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10) |
| 113 | + |
| 114 | + assert "access_token" not in captured_context |
| 115 | + assert captured_context.get("user_name") == "sa" |
| 116 | + assert captured_context.get("password") == "password123" |
0 commit comments