|
1 | 1 | import requests |
2 | 2 | import time |
3 | 3 | import copy |
| 4 | +import tempfile |
4 | 5 | from pagination import Pagination |
5 | 6 | from safe_logger import SafeLogger |
6 | 7 | from loop_detector import LoopDetector |
@@ -184,14 +185,35 @@ def request(self, method, url, can_raise_exeption=True, **kwargs): |
184 | 185 | def request_with_redirect_retry(self, method, url, **kwargs): |
185 | 186 | # In case of redirection to another domain, the authorization header is not kept |
186 | 187 | # If redirect_auth_header is true, another attempt is made with initial headers to the redirected url |
187 | | - response = self.session.request(method, url, **kwargs) |
| 188 | + response = self.request_with_cert(method, url, **kwargs) |
188 | 189 | if self.redirect_auth_header and not response.url.startswith(url): |
189 | 190 | redirection_kwargs = copy.deepcopy(kwargs) |
190 | 191 | redirection_kwargs.pop("params", None) # params are contained in the redirected url |
191 | 192 | logger.warning("Redirection ! Accessing endpoint {} with initial authorization headers".format(response.url)) |
192 | | - response = self.session.request(method, response.url, **redirection_kwargs) |
| 193 | + response = self.request_with_cert(method, response.url, **redirection_kwargs) |
193 | 194 | return response |
194 | 195 |
|
| 196 | + def request_with_cert(self, method, url, **kwargs): |
| 197 | + cert = kwargs.get("cert", None) |
| 198 | + if cert and len(cert) == 2: |
| 199 | + if cert[0].startswith("-----BEGIN CERTIFICATE") and cert[1].startswith("-----BEGIN PRIVATE KEY"): |
| 200 | + logger.info("mTLS certificate and key are strings") |
| 201 | + response = None |
| 202 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".crt") as tmp_certificate: |
| 203 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".key") as tmp_key: |
| 204 | + tmp_certificate.write( |
| 205 | + normalize_key(cert[0]) |
| 206 | + ) |
| 207 | + tmp_certificate.seek(0) |
| 208 | + tmp_key.write( |
| 209 | + normalize_key(cert[1]) |
| 210 | + ) |
| 211 | + tmp_key.seek(0) |
| 212 | + kwargs["cert"] = (tmp_certificate.name, tmp_key.name) |
| 213 | + response = self.session.request(method, url, **kwargs) |
| 214 | + return response |
| 215 | + return self.session.request(method, url, **kwargs) |
| 216 | + |
195 | 217 | def paginated_api_call(self, can_raise_exeption=True): |
196 | 218 | if self.pagination.params_must_be_blanked: |
197 | 219 | self.requests_kwargs["params"] = {} |
@@ -278,3 +300,17 @@ def get_headers(response): |
278 | 300 | if isinstance(response, requests.Response): |
279 | 301 | return response.headers |
280 | 302 | return None |
| 303 | + |
| 304 | + |
| 305 | +def normalize_key(key): |
| 306 | + tempo_text = str(key) |
| 307 | + tempo_text = tempo_text.replace("BEGIN CERTIFICATE", "BEGINCERTIFICATE") |
| 308 | + tempo_text = tempo_text.replace("END CERTIFICATE", "ENDCERTIFICATE") |
| 309 | + tempo_text = tempo_text.replace("-----BEGIN PRIVATE KEY-----", "-----BEGINPRIVATEKEY-----") |
| 310 | + tempo_text = tempo_text.replace("-----END PRIVATE KEY-----", "-----ENDPRIVATEKEY-----") |
| 311 | + tempo_text = tempo_text.replace(" ", "\n") |
| 312 | + tempo_text = tempo_text.replace("BEGINCERTIFICATE", "BEGIN CERTIFICATE") |
| 313 | + tempo_text = tempo_text.replace("ENDCERTIFICATE", "END CERTIFICATE") |
| 314 | + tempo_text = tempo_text.replace("-----BEGINPRIVATEKEY-----", "-----BEGIN PRIVATE KEY-----") |
| 315 | + tempo_text = tempo_text.replace("-----ENDPRIVATEKEY-----", "-----END PRIVATE KEY-----") |
| 316 | + return tempo_text |
0 commit comments