|
18 | 18 | # However, if you have executed another commercial license agreement |
19 | 19 | # with Crate these terms will supersede the license and you may use the |
20 | 20 | # software solely pursuant to the terms of the relevant commercial agreement. |
| 21 | +import re |
21 | 22 | import typing as t |
22 | 23 | import warnings |
23 | 24 | from datetime import datetime, timedelta, timezone |
24 | 25 |
|
25 | 26 | from .converter import Converter, DataType |
26 | 27 | from .exceptions import ProgrammingError |
27 | 28 |
|
| 29 | +_NAMED_PARAM_RE = re.compile(r"%\((\w+)\)s") |
| 30 | + |
| 31 | + |
| 32 | +def _convert_named_to_positional( |
| 33 | + sql: str, params: t.Dict[str, t.Any] |
| 34 | +) -> t.Tuple[str, t.List[t.Any]]: |
| 35 | + """Convert pyformat-style named parameters to positional qmark parameters. |
| 36 | +
|
| 37 | + Converts ``%(name)s`` placeholders to ``?`` and returns an ordered list |
| 38 | + of corresponding values extracted from ``params``. |
| 39 | +
|
| 40 | + The same name may appear multiple times; each occurrence appends the |
| 41 | + value to the positional list independently. |
| 42 | +
|
| 43 | + Raises ``ProgrammingError`` if a placeholder name is absent from ``params``. |
| 44 | + Extra keys in ``params`` are silently ignored. |
| 45 | +
|
| 46 | + Example:: |
| 47 | +
|
| 48 | + sql = "SELECT * FROM t WHERE a = %(a)s AND b = %(b)s" |
| 49 | + params = {"a": 1, "b": 2} |
| 50 | + # returns: ("SELECT * FROM t WHERE a = ? AND b = ?", [1, 2]) |
| 51 | + """ |
| 52 | + positional: t.List[t.Any] = [] |
| 53 | + |
| 54 | + def _replace(match: "re.Match[str]") -> str: |
| 55 | + name = match.group(1) |
| 56 | + if name not in params: |
| 57 | + raise ProgrammingError( |
| 58 | + f"Named parameter '{name}' not found in the parameters dict" |
| 59 | + ) |
| 60 | + positional.append(params[name]) |
| 61 | + return "?" |
| 62 | + |
| 63 | + converted_sql = _NAMED_PARAM_RE.sub(_replace, sql) |
| 64 | + return converted_sql, positional |
| 65 | + |
28 | 66 |
|
29 | 67 | class Cursor: |
30 | 68 | """ |
@@ -54,6 +92,9 @@ def execute(self, sql, parameters=None, bulk_parameters=None): |
54 | 92 | if self._closed: |
55 | 93 | raise ProgrammingError("Cursor closed") |
56 | 94 |
|
| 95 | + if isinstance(parameters, dict): |
| 96 | + sql, parameters = _convert_named_to_positional(sql, parameters) |
| 97 | + |
57 | 98 | self._result = self.connection.client.sql( |
58 | 99 | sql, parameters, bulk_parameters |
59 | 100 | ) |
|
0 commit comments