Skip to content

Commit 33586a4

Browse files
authored
Merge pull request #2470 from jakkdl/named_threads
Add name parameter to to_thread.run_sync [SQUASH]
2 parents 55117ad + 235f841 commit 33586a4

4 files changed

Lines changed: 232 additions & 24 deletions

File tree

newsfragments/1148.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added support for naming threads created with `trio.to_thread.run_sync`, requires pthreads so is only available on POSIX platforms with glibc installed.

trio/_core/_thread_cache.py

Lines changed: 88 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,69 @@
22
import traceback
33
from threading import Thread, Lock
44
import outcome
5+
import ctypes
6+
import ctypes.util
57
from itertools import count
68

9+
from typing import Callable, Optional, Tuple
10+
from functools import partial
11+
12+
13+
def _to_os_thread_name(name: str) -> bytes:
14+
# ctypes handles the trailing \00
15+
return name.encode("ascii", errors="replace")[:15]
16+
17+
18+
# used to construct the method used to set os thread name, or None, depending on platform.
19+
# called once on import
20+
def get_os_thread_name_func() -> Optional[Callable[[Optional[int], str], None]]:
21+
def namefunc(setname: Callable[[int, bytes], int], ident: Optional[int], name: str):
22+
# Thread.ident is None "if it has not been started". Unclear if that can happen
23+
# with current usage.
24+
if ident is not None: # pragma: no cover
25+
setname(ident, _to_os_thread_name(name))
26+
27+
# namefunc on mac also takes an ident, even if pthread_setname_np doesn't/can't use it
28+
# so the caller don't need to care about platform.
29+
def darwin_namefunc(
30+
setname: Callable[[bytes], int], ident: Optional[int], name: str
31+
):
32+
# I don't know if Mac can rename threads that hasn't been started, but default
33+
# to no to be on the safe side.
34+
if ident is not None: # pragma: no cover
35+
setname(_to_os_thread_name(name))
36+
37+
# find the pthread library
38+
# this will fail on windows
39+
libpthread_path = ctypes.util.find_library("pthread")
40+
if not libpthread_path:
41+
return None
42+
libpthread = ctypes.CDLL(libpthread_path)
43+
44+
# get the setname method from it
45+
# afaik this should never fail
46+
pthread_setname_np = getattr(libpthread, "pthread_setname_np", None)
47+
if pthread_setname_np is None: # pragma: no cover
48+
return None
49+
50+
# specify function prototype
51+
pthread_setname_np.restype = ctypes.c_int
52+
53+
# on mac OSX pthread_setname_np does not take a thread id,
54+
# it only lets threads name themselves, which is not a problem for us.
55+
# Just need to make sure to call it correctly
56+
if sys.platform == "darwin":
57+
pthread_setname_np.argtypes = [ctypes.c_char_p]
58+
return partial(darwin_namefunc, pthread_setname_np)
59+
60+
# otherwise assume linux parameter conventions. Should also work on *BSD
61+
pthread_setname_np.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
62+
return partial(namefunc, pthread_setname_np)
63+
64+
65+
# construct os thread name method
66+
set_os_thread_name = get_os_thread_name_func()
67+
768
# The "thread cache" is a simple unbounded thread pool, i.e., it automatically
869
# spawns as many threads as needed to handle all the requests its given. Its
970
# only purpose is to cache worker threads so that they don't have to be
@@ -44,7 +105,7 @@
44105

45106
class WorkerThread:
46107
def __init__(self, thread_cache):
47-
self._job = None
108+
self._job: Optional[Tuple[Callable, Callable, str]] = None
48109
self._thread_cache = thread_cache
49110
# This Lock is used in an unconventional way.
50111
#
@@ -54,16 +115,34 @@ def __init__(self, thread_cache):
54115
# Initially we have no job, so it starts out in locked state.
55116
self._worker_lock = Lock()
56117
self._worker_lock.acquire()
57-
thread = Thread(target=self._work, daemon=True)
58-
thread.name = f"Trio worker thread {next(name_counter)}"
59-
thread.start()
118+
self._default_name = f"Trio thread {next(name_counter)}"
119+
120+
self._thread = Thread(target=self._work, name=self._default_name, daemon=True)
121+
122+
if set_os_thread_name:
123+
set_os_thread_name(self._thread.ident, self._default_name)
124+
self._thread.start()
60125

61126
def _handle_job(self):
62127
# Handle job in a separate method to ensure user-created
63128
# objects are cleaned up in a consistent manner.
64-
fn, deliver = self._job
129+
assert self._job is not None
130+
fn, deliver, name = self._job
65131
self._job = None
132+
133+
# set name
134+
if name is not None:
135+
self._thread.name = name
136+
if set_os_thread_name:
137+
set_os_thread_name(self._thread.ident, name)
66138
result = outcome.capture(fn)
139+
140+
# reset name if it was changed
141+
if name is not None:
142+
self._thread.name = self._default_name
143+
if set_os_thread_name:
144+
set_os_thread_name(self._thread.ident, self._default_name)
145+
67146
# Tell the cache that we're available to be assigned a new
68147
# job. We do this *before* calling 'deliver', so that if
69148
# 'deliver' triggers a new job, it can be assigned to us
@@ -102,19 +181,19 @@ class ThreadCache:
102181
def __init__(self):
103182
self._idle_workers = {}
104183

105-
def start_thread_soon(self, fn, deliver):
184+
def start_thread_soon(self, fn, deliver, name: Optional[str] = None):
106185
try:
107186
worker, _ = self._idle_workers.popitem()
108187
except KeyError:
109188
worker = WorkerThread(self)
110-
worker._job = (fn, deliver)
189+
worker._job = (fn, deliver, name)
111190
worker._worker_lock.release()
112191

113192

114193
THREAD_CACHE = ThreadCache()
115194

116195

117-
def start_thread_soon(fn, deliver):
196+
def start_thread_soon(fn, deliver, name: Optional[str] = None):
118197
"""Runs ``deliver(outcome.capture(fn))`` in a worker thread.
119198
120199
Generally ``fn`` does some blocking work, and ``deliver`` delivers the
@@ -174,4 +253,4 @@ def start_thread_soon(fn, deliver):
174253
limit how many threads they're using then it's polite to respect that.
175254
176255
"""
177-
THREAD_CACHE.start_thread_soon(fn, deliver)
256+
THREAD_CACHE.start_thread_soon(fn, deliver, name)

trio/_threads.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
import contextvars
2-
import threading
3-
import queue as stdlib_queue
42
import functools
3+
import inspect
4+
import queue as stdlib_queue
5+
import threading
56
from itertools import count
7+
from typing import Optional
68

79
import attr
8-
import inspect
910
import outcome
1011
from sniffio import current_async_library_cvar
1112

1213
import trio
1314

14-
from ._sync import CapacityLimiter
1515
from ._core import (
16-
enable_ki_protection,
17-
disable_ki_protection,
1816
RunVar,
1917
TrioToken,
18+
disable_ki_protection,
19+
enable_ki_protection,
2020
start_thread_soon,
2121
)
22+
from ._sync import CapacityLimiter
2223
from ._util import coroutine_or_error
2324

2425
# Global due to Threading API, thread local storage for trio token
@@ -57,7 +58,9 @@ class ThreadPlaceholder:
5758

5859

5960
@enable_ki_protection
60-
async def to_thread_run_sync(sync_fn, *args, cancellable=False, limiter=None):
61+
async def to_thread_run_sync(
62+
sync_fn, *args, thread_name: Optional[str] = None, cancellable=False, limiter=None
63+
):
6164
"""Convert a blocking operation into an async operation using a thread.
6265
6366
These two lines are equivalent::
@@ -79,6 +82,12 @@ async def to_thread_run_sync(sync_fn, *args, cancellable=False, limiter=None):
7982
arguments, use :func:`functools.partial`.
8083
cancellable (bool): Whether to allow cancellation of this operation. See
8184
discussion below.
85+
thread_name (str): Optional string to set the name of the thread.
86+
Will always set `threading.Thread.name`, but only set the os name
87+
if pthread.h is available (i.e. most POSIX installations).
88+
pthread names are limited to 15 characters, and can be read from
89+
``/proc/<PID>/task/<SPID>/comm`` or with ``ps -eT``, among others.
90+
Defaults to ``{sync_fn.__name__|None} from {trio.lowlevel.current_task().name}``.
8291
limiter (None, or CapacityLimiter-like object):
8392
An object used to limit the number of simultaneous threads. Most
8493
commonly this will be a `~trio.CapacityLimiter`, but it could be
@@ -166,6 +175,9 @@ def do_release_then_return_result():
166175

167176
current_trio_token = trio.lowlevel.current_trio_token()
168177

178+
if thread_name is None:
179+
thread_name = f"{getattr(sync_fn, '__name__', None)} from {trio.lowlevel.current_task().name}"
180+
169181
def worker_fn():
170182
current_async_library_cvar.set(None)
171183
TOKEN_LOCAL.token = current_trio_token
@@ -198,7 +210,9 @@ def deliver_worker_fn_result(result):
198210

199211
await limiter.acquire_on_behalf_of(placeholder)
200212
try:
201-
start_thread_soon(contextvars_aware_worker_fn, deliver_worker_fn_result)
213+
start_thread_soon(
214+
contextvars_aware_worker_fn, deliver_worker_fn_result, thread_name
215+
)
202216
except:
203217
limiter.release_on_behalf_of(placeholder)
204218
raise

trio/tests/test_threads.py

Lines changed: 121 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,28 @@
11
import contextvars
2-
import threading
32
import queue as stdlib_queue
3+
import re
4+
import sys
5+
import threading
46
import time
57
import weakref
8+
from functools import partial
9+
from typing import Callable, Optional
610

711
import pytest
812
from sniffio import current_async_library_cvar
13+
914
from trio._core import TrioToken, current_trio_token
1015

11-
from .. import _core
12-
from .. import Event, CapacityLimiter, sleep
13-
from ..testing import wait_all_tasks_blocked
16+
from .. import CapacityLimiter, Event, _core, sleep
17+
from .._core.tests.test_ki import ki_self
1418
from .._core.tests.tutil import buggy_pypy_asyncgens
1519
from .._threads import (
16-
to_thread_run_sync,
1720
current_default_thread_limiter,
1821
from_thread_run,
1922
from_thread_run_sync,
23+
to_thread_run_sync,
2024
)
21-
22-
from .._core.tests.test_ki import ki_self
25+
from ..testing import wait_all_tasks_blocked
2326

2427

2528
async def test_do_in_trio_thread():
@@ -164,6 +167,117 @@ async def main():
164167
assert record == ["sleeping", "cancelled"]
165168

166169

170+
async def test_named_thread():
171+
ending = " from trio.tests.test_threads.test_named_thread"
172+
173+
def inner(name="inner" + ending) -> threading.Thread:
174+
assert threading.current_thread().name == name
175+
return threading.current_thread()
176+
177+
def f(name: str) -> Callable[[None], threading.Thread]:
178+
return partial(inner, name)
179+
180+
# test defaults
181+
await to_thread_run_sync(inner)
182+
await to_thread_run_sync(inner, thread_name=None)
183+
184+
# functools.partial doesn't have __name__, so defaults to None
185+
await to_thread_run_sync(f("None" + ending))
186+
187+
# test that you can set a custom name, and that it's reset afterwards
188+
async def test_thread_name(name: str):
189+
thread = await to_thread_run_sync(f(name), thread_name=name)
190+
assert re.match("Trio thread [0-9]*", thread.name)
191+
192+
await test_thread_name("")
193+
await test_thread_name("fobiedoo")
194+
await test_thread_name("name_longer_than_15_characters")
195+
196+
await test_thread_name("💙")
197+
198+
199+
def _get_thread_name(ident: Optional[int] = None) -> Optional[str]:
200+
import ctypes
201+
import ctypes.util
202+
203+
libpthread_path = ctypes.util.find_library("pthread")
204+
if not libpthread_path:
205+
print(f"no pthread on {sys.platform})")
206+
return None
207+
libpthread = ctypes.CDLL(libpthread_path)
208+
209+
pthread_getname_np = getattr(libpthread, "pthread_getname_np", None)
210+
211+
# this should never fail on any platforms afaik
212+
assert pthread_getname_np
213+
214+
# thankfully getname signature doesn't differ between platforms
215+
pthread_getname_np.argtypes = [
216+
ctypes.c_void_p,
217+
ctypes.c_char_p,
218+
ctypes.c_size_t,
219+
]
220+
pthread_getname_np.restype = ctypes.c_int
221+
222+
name_buffer = ctypes.create_string_buffer(b"", size=16)
223+
if ident is None:
224+
ident = threading.get_ident()
225+
assert pthread_getname_np(ident, name_buffer, 16) == 0
226+
try:
227+
return name_buffer.value.decode()
228+
except UnicodeDecodeError as e: # pragma: no cover
229+
# used for debugging when testing via CI
230+
pytest.fail(f"value: {name_buffer.value!r}, exception: {e}")
231+
232+
233+
# test os thread naming
234+
# this depends on pthread being available, which is the case on 99.9% of linux machines
235+
# and most mac machines. So unless the platform is linux it will just skip
236+
# in case it fails to fetch the os thread name.
237+
async def test_named_thread_os():
238+
def inner(name) -> threading.Thread:
239+
os_thread_name = _get_thread_name()
240+
if os_thread_name is None and sys.platform != "linux":
241+
pytest.skip(f"no pthread OS support on {sys.platform}")
242+
else:
243+
assert os_thread_name == name[:15]
244+
245+
return threading.current_thread()
246+
247+
def f(name: str) -> Callable[[None], threading.Thread]:
248+
return partial(inner, name)
249+
250+
# test defaults
251+
default = "None from trio.tests.test_threads.test_named_thread"
252+
await to_thread_run_sync(f(default))
253+
await to_thread_run_sync(f(default), thread_name=None)
254+
255+
# test that you can set a custom name, and that it's reset afterwards
256+
async def test_thread_name(name: str, expected: Optional[str] = None):
257+
if expected is None:
258+
expected = name
259+
thread = await to_thread_run_sync(f(expected), thread_name=name)
260+
261+
os_thread_name = _get_thread_name(thread.ident)
262+
assert os_thread_name is not None, "should skip earlier if this is the case"
263+
assert re.match("Trio thread [0-9]*", os_thread_name)
264+
265+
await test_thread_name("")
266+
await test_thread_name("fobiedoo")
267+
await test_thread_name("name_longer_than_15_characters")
268+
269+
await test_thread_name("💙", expected="?")
270+
271+
272+
async def test_has_pthread_setname_np():
273+
from trio._core._thread_cache import get_os_thread_name_func
274+
275+
k = get_os_thread_name_func()
276+
if k is None:
277+
assert sys.platform != "linux"
278+
pytest.skip(f"no pthread_setname_np on {sys.platform}")
279+
280+
167281
async def test_run_in_worker_thread():
168282
trio_thread = threading.current_thread()
169283

0 commit comments

Comments
 (0)