Skip to content

Commit 0081b75

Browse files
committed
address review comments -
- make fork safety complete in the client - add shutdown mechanism to the integration - better test coverage - better docs on usage
1 parent c519ef9 commit 0081b75

7 files changed

Lines changed: 572 additions & 172 deletions

File tree

examples/celery_integration.py

Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,35 @@
22
Celery integration example for the PostHog Python SDK.
33
44
Demonstrates how to use ``PosthogCeleryIntegration`` with:
5-
- producer-side instrumentation (publishing events and context propagation)
6-
- worker-side instrumentation via ``worker_process_init`` (prefork-safe)
5+
- producer-side and worker-side instrumentation (publishing events and context propagation)
76
- context propagation (distinct ID, session ID, tags) from producer to worker
87
- task lifecycle events (published, started, success, failure, retry)
98
- exception capture from failed tasks
109
- ``task_filter`` customization hook
1110
1211
Setup:
13-
1. Update POSTHOG_PROJECT_API_KEY and POSTHOG_HOST here with your credentials
14-
(environment variables won't work as it's better if Celery forks worker into
15-
separate process for the example to prove context propagation)
12+
1. Set ``POSTHOG_PROJECT_API_KEY`` and ``POSTHOG_HOST`` in your environment
1613
2. Install dependencies: pip install posthog celery redis
1714
3. Start Redis: redis-server
1815
4. Start the worker: celery -A examples.celery_integration worker --loglevel=info
1916
5. Run the producer: python -m examples.celery_integration
2017
"""
2118

19+
import os
2220
import time
2321
from typing import Any, Optional
2422

2523
from celery import Celery
2624
from celery.signals import worker_process_init, worker_process_shutdown
2725

2826
import posthog
29-
from posthog.client import Client
3027
from posthog.integrations.celery import PosthogCeleryIntegration
3128

3229

3330
# --- Configuration ---
3431

35-
POSTHOG_PROJECT_API_KEY = "phc_..."
36-
POSTHOG_HOST = "http://localhost:8000"
32+
POSTHOG_PROJECT_API_KEY = os.getenv("POSTHOG_PROJECT_API_KEY", "phc_...")
33+
POSTHOG_HOST = os.getenv("POSTHOG_HOST", "http://localhost:8000")
3734

3835
app = Celery(
3936
"examples.celery_integration",
@@ -43,11 +40,11 @@
4340

4441
# --- Integration wiring ---
4542

46-
def create_client() -> Client:
47-
return Client(
48-
project_api_key=POSTHOG_PROJECT_API_KEY,
49-
host=POSTHOG_HOST
50-
)
43+
def configure_posthog() -> None:
44+
posthog.api_key = POSTHOG_PROJECT_API_KEY
45+
posthog.host = POSTHOG_HOST
46+
posthog.enable_local_evaluation = False # to not require personal_api_key for this example
47+
posthog.setup()
5148

5249

5350
def task_filter(task_name: Optional[str], task_properties: dict[str, Any]) -> bool:
@@ -56,40 +53,42 @@ def task_filter(task_name: Optional[str], task_properties: dict[str, Any]) -> bo
5653
return True
5754

5855

59-
def create_integration(client: Client) -> PosthogCeleryIntegration:
56+
def create_integration() -> PosthogCeleryIntegration:
6057
return PosthogCeleryIntegration(
61-
client=client,
6258
capture_exceptions=True,
6359
capture_task_lifecycle_events=True,
6460
propagate_context=True,
6561
task_filter=task_filter,
6662
)
6763

68-
69-
# Worker process setup.
70-
# Celery's default prefork pool runs tasks in child processes, so initialize
71-
# PostHog per child using worker_process_init.
64+
configure_posthog()
65+
integration = create_integration()
66+
integration.instrument()
7267

7368

69+
# --- Worker process setup ---
70+
# Celery's default prefork pool runs tasks in child processes. This example
71+
# runs on a single host, so the inherited PostHog client and Celery
72+
# integration are fork-safe and do not need to be recreated in each child.
73+
# If workers run across multiple hosts, configure PostHog and instrument a
74+
# worker-local integration in worker_process_init.
7475
@worker_process_init.connect
7576
def on_worker_process_init(**kwargs) -> None:
76-
worker_posthog_client = create_client()
77-
worker_integration = create_integration(worker_posthog_client)
78-
worker_integration.instrument()
79-
80-
app._posthog_client = worker_posthog_client
81-
app._posthog_integration = worker_integration
77+
# global integration
78+
79+
# configure_posthog()
80+
# integration = create_integration()
81+
# integration.instrument()
82+
return
8283

8384

85+
# Use this signal to shutdown the integration and PostHog client
86+
# Calling shutdown() is important to flush any pending events
8487
@worker_process_shutdown.connect
8588
def on_worker_process_shutdown(**kwargs) -> None:
86-
worker_integration = getattr(app, "_posthog_integration", None)
87-
if worker_integration:
88-
worker_integration.uninstrument()
89+
integration.shutdown()
90+
posthog.shutdown()
8991

90-
worker_posthog_client = getattr(app, "_posthog_client", None)
91-
if worker_posthog_client:
92-
worker_posthog_client.shutdown()
9392

9493
# --- Example tasks ---
9594

@@ -98,8 +97,8 @@ def health_check() -> dict[str, str]:
9897
return {"status": "ok"}
9998

10099

101-
@app.task(bind=True, max_retries=3)
102-
def process_order(self, order_id: str) -> dict:
100+
@app.task(max_retries=3)
101+
def process_order(order_id: str) -> dict:
103102
"""A task that processes an order successfully."""
104103

105104
# simulate work
@@ -108,7 +107,7 @@ def process_order(self, order_id: str) -> dict:
108107
# Custom event inside the task - context tags propagated from the
109108
# producer (e.g. "source", "release") should appear on this event
110109
# and this should be attributed to the correct distinct ID and session.
111-
app._posthog_client.capture(
110+
posthog.capture(
112111
"celery example order processed",
113112
properties={"order_id": order_id, "amount": 99.99},
114113
)
@@ -136,17 +135,13 @@ def failing_task() -> None:
136135
# --- Producer code ---
137136

138137
if __name__ == "__main__":
139-
posthog_client = create_client()
140-
integration = create_integration(posthog_client)
141-
integration.instrument()
142-
143138
print("PostHog Celery Integration Example")
144139
print("=" * 40)
145140
print()
146141

147142
# Set up PostHog context before dispatching tasks.
148143
# The integration propagates this context to workers via task headers.
149-
with posthog.new_context(fresh=True, client=posthog_client):
144+
with posthog.new_context(fresh=True):
150145
posthog.identify_context("user-123")
151146
posthog.set_context_session("session-user-123-abc")
152147
posthog.tag("source", "celery_integration_example_script")
@@ -186,6 +181,5 @@ def failing_task() -> None:
186181
print("Tasks dispatched. Check your Celery worker logs and PostHog for events.")
187182
print()
188183

189-
posthog_client.flush()
190-
integration.uninstrument()
191-
posthog_client.shutdown()
184+
integration.shutdown()
185+
posthog.shutdown()

posthog/client.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
flags,
5858
get,
5959
remote_config,
60+
reset_sessions,
6061
)
6162
from posthog.types import (
6263
FeatureFlag,
@@ -245,6 +246,7 @@ def __init__(
245246
)
246247
self.poller = None
247248
self.distinct_ids_feature_flags_reported = SizeLimitedDict(MAX_DICT_SIZE, set)
249+
self.flag_fallback_cache_url = flag_fallback_cache_url
248250
self.flag_cache = self._initialize_flag_cache(flag_fallback_cache_url)
249251
self.flag_definition_version = 0
250252
self._flags_etag: Optional[str] = None
@@ -1098,42 +1100,56 @@ def _reinit_after_fork_weak(weak_self):
10981100
self._reinit_after_fork()
10991101

11001102
def _reinit_after_fork(self):
1101-
"""Reinitialize queue and consumer threads in a forked child process.
1103+
"""Reinitialize fork-unsafe client state in a forked child process.
11021104
11031105
Registered via os.register_at_fork(after_in_child=...) so it runs
11041106
exactly once in each child, before any user code, covering all code
11051107
paths (capture, flush, join, etc.).
11061108
11071109
Python threads do not survive fork() and queue.Queue internal locks
1108-
may be in an inconsistent state, so both are replaced.
1109-
Inherited queue items are intentionally discarded as they'll be
1110-
handled by the parent process's consumers.
1110+
may be in an inconsistent state, so the event queue, consumer threads
1111+
and other state are replaced. Inherited queue items are not retained
1112+
as they'll be handled by the parent process's consumers.
11111113
"""
1112-
if self.consumers is None:
1113-
return
1114+
if self.consumers:
1115+
self.queue = queue.Queue(self._max_queue_size)
1116+
1117+
new_consumers = []
1118+
for old in self.consumers:
1119+
consumer = Consumer(
1120+
self.queue,
1121+
old.api_key,
1122+
flush_at=old.flush_at,
1123+
host=old.host,
1124+
on_error=old.on_error,
1125+
flush_interval=old.flush_interval,
1126+
gzip=old.gzip,
1127+
retries=old.retries,
1128+
timeout=old.timeout,
1129+
historical_migration=old.historical_migration,
1130+
)
1131+
new_consumers.append(consumer)
1132+
1133+
if self.send:
1134+
consumer.start()
1135+
1136+
self.consumers = new_consumers
11141137

1115-
self.queue = queue.Queue(self._max_queue_size)
1116-
1117-
new_consumers = []
1118-
for old in self.consumers:
1119-
consumer = Consumer(
1120-
self.queue,
1121-
old.api_key,
1122-
flush_at=old.flush_at,
1123-
host=old.host,
1124-
on_error=old.on_error,
1125-
flush_interval=old.flush_interval,
1126-
gzip=old.gzip,
1127-
retries=old.retries,
1128-
timeout=old.timeout,
1129-
historical_migration=old.historical_migration,
1138+
if self.enable_local_evaluation:
1139+
self.poller = Poller(
1140+
interval=timedelta(seconds=self.poll_interval),
1141+
execute=self._load_feature_flags,
11301142
)
1131-
new_consumers.append(consumer)
1143+
self.poller.start()
1144+
else:
1145+
self.poller = None
11321146

1133-
if self.send:
1134-
consumer.start()
1147+
# If using Redis cache, we must reinitialize to get a fresh connection (fork-safe).
1148+
# If using Memory cache, we keep it as-is to benefit from the inherited warm cache.
1149+
if isinstance(self.flag_cache, RedisFlagCache):
1150+
self.flag_cache = self._initialize_flag_cache(self.flag_fallback_cache_url)
11351151

1136-
self.consumers = new_consumers
1152+
reset_sessions()
11371153

11381154
def _enqueue(self, msg, disable_geoip):
11391155
# type: (...) -> Optional[str]

posthog/integrations/celery.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,10 @@
2828
integration = PosthogCeleryIntegration(client=posthog)
2929
integration.instrument()
3030
31-
Both the producer process and each worker process must initialize the
32-
PostHog client and instrument the integration because the worker needs
33-
to bind to Celery signals, and the PostHog client may use background threads
34-
to send captured events (depending on ``sync_mode``). Celery provides a signal
35-
called ``worker_process_init`` that can be used to accomplish this.
31+
# ... publish tasks or run workers ...
32+
33+
integration.shutdown()
34+
posthog.shutdown()
3635
3736
See ``examples/celery_integration.py`` for a complete working example.
3837
@@ -64,6 +63,7 @@
6463
- **retry**: ``celery_reason``
6564
"""
6665

66+
import atexit
6767
import json
6868
import logging
6969
import time
@@ -113,16 +113,28 @@ def __init__(
113113
self.task_filter = task_filter
114114

115115
self._instrumented = False
116+
self._shut_down = False
116117
self._signals: Optional[Any] = None
117118
self._celery_version: Optional[str] = None
118119

119120
def instrument(self) -> None:
121+
"""Connect Celery signal handlers to capture task events and exceptions.
122+
Call this after initializing the PostHog client and this integration.
123+
124+
If Celery runs on a single host, reinstrumenting in worker children is
125+
not strictly necessary because the PostHog client and this integration
126+
are fork-safe. If Celery workers run across multiple hosts, each worker
127+
process must initialize PostHog, this integration, and call
128+
``instrument()``. Celery provides ``worker_process_init`` signal to help
129+
with this.
130+
"""
120131
if self._instrumented:
121132
return
122133

123134
from celery import signals
124135
from celery import __version__ as celery_version
125136

137+
self._shut_down = False
126138
self._signals = signals
127139
self._celery_version = celery_version
128140

@@ -133,9 +145,12 @@ def instrument(self) -> None:
133145
signals.before_task_publish.connect(self._on_before_task_publish, weak=False)
134146
signals.after_task_publish.connect(self._on_after_task_publish, weak=False)
135147

148+
signals.worker_process_shutdown.connect(self._on_worker_process_shutdown, weak=False)
149+
atexit.register(self.shutdown)
150+
136151
self._instrumented = True
137152

138-
def uninstrument(self) -> None:
153+
def _disconnect_signals(self) -> None:
139154
if not self._instrumented or not self._signals:
140155
return
141156

@@ -146,9 +161,48 @@ def uninstrument(self) -> None:
146161
self._signals.before_task_publish.disconnect(self._on_before_task_publish)
147162
self._signals.after_task_publish.disconnect(self._on_after_task_publish)
148163

164+
self._signals.worker_process_shutdown.disconnect(self._on_worker_process_shutdown)
165+
149166
self._signals = None
150167
self._instrumented = False
151168

169+
def uninstrument(self) -> None:
170+
"""Disconnect Celery signal handlers and unregister exit cleanup.
171+
172+
Do not use directly, call `shutdown()` instead.
173+
"""
174+
self._disconnect_signals()
175+
atexit.unregister(self.shutdown)
176+
177+
def shutdown(self) -> None:
178+
"""Disconnect all signal handlers registered by ``instrument()``, flush all pending events
179+
and cleanly shutdown the integration.
180+
181+
``shutdown()`` is also registered on ``worker_process_shutdown`` and ``atexit``, but
182+
there is no guarantee those will always be called, so we strongly recommend calling
183+
it manually when the integration is no longer needed to avoid data loss.
184+
"""
185+
if self._shut_down:
186+
return
187+
188+
try:
189+
self._disconnect_signals()
190+
191+
if self.client:
192+
self.client.flush()
193+
else:
194+
import posthog
195+
196+
posthog.flush()
197+
198+
self.uninstrument()
199+
self._shut_down = True
200+
except Exception:
201+
logger.exception("Failed to shut down PostHog Celery integration")
202+
203+
def _on_worker_process_shutdown(self, *args, **kwargs) -> None:
204+
self.shutdown()
205+
152206
def _on_before_task_publish(self, *args, **kwargs):
153207
try:
154208
if not self.propagate_context:
@@ -298,7 +352,7 @@ def _handle_task_end(
298352
if self.capture_task_lifecycle_events and self._should_track(task_name, task_properties):
299353
self._capture_event(f"celery task {state}", properties=task_properties)
300354
except Exception:
301-
logger.exception("Failed to process Celery %s", state)
355+
logger.exception("Failed to process Celery %s state", state)
302356
finally:
303357
ctx = getattr(request, "_posthog_ctx", None)
304358
if ctx is not None:

0 commit comments

Comments
 (0)