Skip to content

Commit 353cf1b

Browse files
committed
review comments
1 parent de176fe commit 353cf1b

4 files changed

Lines changed: 60 additions & 13 deletions

File tree

posthog/client.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import sys
55
import warnings
6+
import weakref
67
from datetime import datetime, timedelta
78
from typing import Any, Dict, Optional, Union
89
from uuid import uuid4
@@ -334,7 +335,8 @@ def __init__(
334335
consumer.start()
335336

336337
if hasattr(os, "register_at_fork"):
337-
os.register_at_fork(after_in_child=self._reinit_after_fork)
338+
weak_self = weakref.ref(self)
339+
os.register_at_fork(after_in_child=lambda: Client._reinit_after_fork_weak(weak_self))
338340

339341
def new_context(self, fresh=False, capture_exceptions=True):
340342
"""
@@ -1084,6 +1086,17 @@ def capture_exception(
10841086
except Exception as e:
10851087
self.log.exception(f"Failed to capture exception: {e}")
10861088

1089+
@staticmethod
1090+
def _reinit_after_fork_weak(weak_self):
1091+
"""
1092+
Reinitialize the client after a fork.
1093+
Garbage collected if the client is deleted.
1094+
"""
1095+
self = weak_self()
1096+
if self is None:
1097+
return
1098+
self._reinit_after_fork()
1099+
10871100
def _reinit_after_fork(self):
10881101
"""Reinitialize queue and consumer threads in a forked child process.
10891102

posthog/integrations/celery.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def _on_after_task_publish(self, *args, **kwargs):
207207
logger.exception("Failed to capture Celery after_task_publish lifecycle event")
208208

209209
def _on_task_prerun(self, *args, **kwargs):
210+
context_manager = None
210211
try:
211212
task_id = kwargs.get("task_id")
212213
if not task_id:
@@ -222,17 +223,16 @@ def _on_task_prerun(self, *args, **kwargs):
222223
)
223224
task_name = task_properties.get("celery_task_name")
224225

225-
context_manager = contexts.new_context(
226-
fresh=True, # to prevent context bleed across tasks
227-
capture_exceptions=False, # Celery catches task exceptions internally and
228-
# delivers them via task_failure signal, so they
229-
# never propagate through the context manager.
230-
# We capture them in _on_task_failure.
231-
client=self.client,
232-
)
233-
context_manager.__enter__()
234-
235226
if request is not None:
227+
context_manager = contexts.new_context(
228+
fresh=True, # to prevent context bleed across tasks
229+
capture_exceptions=False, # Celery catches task exceptions internally and
230+
# delivers them via task_failure signal, so they
231+
# never propagate through the context manager.
232+
# We capture them in _on_task_failure.
233+
client=self.client,
234+
)
235+
context_manager.__enter__()
236236
request._posthog_ctx = context_manager
237237
request._posthog_start = time.monotonic()
238238

@@ -246,6 +246,11 @@ def _on_task_prerun(self, *args, **kwargs):
246246
self._capture_event("celery task started", properties=task_properties)
247247
except Exception:
248248
logger.exception("Failed to process Celery task_prerun")
249+
if context_manager is not None:
250+
try:
251+
context_manager.__exit__(None, None, None)
252+
except Exception:
253+
pass
249254

250255
def _on_task_success(self, *args, **kwargs):
251256
self._handle_task_end("success", **kwargs)

posthog/test/integrations/test_celery_integration.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,26 @@ def test_extract_headers_supports_request_dict_shape(self):
463463

464464
self.assertEqual(headers, {CONTEXT_DISTINCT_ID_HEADER: "user-1"})
465465

466+
def test_prerun_exits_context_on_failure_after_entry(self):
467+
mock_client = Mock()
468+
integration = PosthogCeleryIntegration(client=mock_client)
469+
470+
request = SimpleNamespace(
471+
headers={},
472+
delivery_info={},
473+
hostname="worker-1",
474+
retries=0,
475+
)
476+
task = SimpleNamespace(name="app.tasks.boom", request=request)
477+
478+
ctx_before = contexts._get_current_context()
479+
480+
with patch.object(integration, "_apply_propagated_identity", side_effect=RuntimeError("boom")):
481+
integration._on_task_prerun(sender=task, task_id="task-leak")
482+
483+
ctx_after = contexts._get_current_context()
484+
self.assertIs(ctx_after, ctx_before)
485+
466486
def test_extract_propagated_tags_invalid_json_returns_empty_dict(self):
467487
integration = PosthogCeleryIntegration()
468488
request = SimpleNamespace(headers={CONTEXT_TAGS_HEADER: "{bad json"})

posthog/test/test_client.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2735,8 +2735,17 @@ def test_registers_at_fork_hook(self, mock_register_at_fork):
27352735

27362736
mock_register_at_fork.assert_called_once()
27372737
after_in_child = mock_register_at_fork.call_args.kwargs["after_in_child"]
2738-
self.assertEqual(after_in_child.__self__, client)
2739-
self.assertEqual(after_in_child.__name__, "_reinit_after_fork")
2738+
2739+
with mock.patch.object(client, "_reinit_after_fork") as mock_reinit:
2740+
after_in_child()
2741+
mock_reinit.assert_called_once()
2742+
2743+
@mock.patch("posthog.client.os.register_at_fork")
2744+
def test_register_at_fork_noop_after_client_gc(self, mock_register_at_fork):
2745+
client = Client(FAKE_TEST_API_KEY, on_error=self.set_fail)
2746+
after_in_child = mock_register_at_fork.call_args.kwargs["after_in_child"]
2747+
del client
2748+
after_in_child()
27402749

27412750
@parameterized.expand([(True, 1), (False, 0)])
27422751
def test_reinit_after_fork_replaces_queue_and_consumers(self, send, expected_starts):

0 commit comments

Comments
 (0)