Skip to content

Commit 09e65e1

Browse files
committed
Make ThreadPool a context manager to prevent memory leaks
1 parent c4f65a5 commit 09e65e1

3 files changed

Lines changed: 104 additions & 104 deletions

File tree

src/Test/TestNoparallel.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -149,21 +149,19 @@ def raiseException():
149149

150150
def testMultithreadMix(self, queue_spawn):
151151
obj1 = ExampleClass()
152-
thread_pool = ThreadPool.ThreadPool(10)
153-
154-
s = time.time()
155-
t1 = queue_spawn(obj1.countBlocking, 5)
156-
time.sleep(0.01)
157-
t2 = thread_pool.spawn(obj1.countBlocking, 5)
158-
time.sleep(0.01)
159-
t3 = thread_pool.spawn(obj1.countBlocking, 5)
160-
time.sleep(0.3)
161-
t4 = gevent.spawn(obj1.countBlocking, 5)
162-
threads = [t1, t2, t3, t4]
163-
for thread in threads:
164-
assert thread.get() == "counted:5"
165-
166-
time_taken = time.time() - s
167-
assert obj1.counted == 5
168-
assert 0.5 < time_taken < 0.7
169-
thread_pool.kill()
152+
with ThreadPool.ThreadPool(10) as thread_pool:
153+
s = time.time()
154+
t1 = queue_spawn(obj1.countBlocking, 5)
155+
time.sleep(0.01)
156+
t2 = thread_pool.spawn(obj1.countBlocking, 5)
157+
time.sleep(0.01)
158+
t3 = thread_pool.spawn(obj1.countBlocking, 5)
159+
time.sleep(0.3)
160+
t4 = gevent.spawn(obj1.countBlocking, 5)
161+
threads = [t1, t2, t3, t4]
162+
for thread in threads:
163+
assert thread.get() == "counted:5"
164+
165+
time_taken = time.time() - s
166+
assert obj1.counted == 5
167+
assert 0.5 < time_taken < 0.7

src/Test/TestThreadPool.py

Lines changed: 82 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,29 @@
99

1010
class TestThreadPool:
1111
def testExecutionOrder(self):
12-
pool = ThreadPool.ThreadPool(4)
13-
14-
events = []
15-
16-
@pool.wrap
17-
def blocker():
18-
events.append("S")
19-
out = 0
20-
for i in range(10000000):
21-
if i == 3000000:
22-
events.append("M")
23-
out += 1
24-
events.append("D")
25-
return out
26-
27-
threads = []
28-
for i in range(3):
29-
threads.append(gevent.spawn(blocker))
30-
gevent.joinall(threads)
31-
32-
assert events == ["S"] * 3 + ["M"] * 3 + ["D"] * 3
33-
34-
res = blocker()
35-
assert res == 10000000
36-
pool.kill()
12+
with ThreadPool.ThreadPool(4) as pool:
13+
events = []
14+
15+
@pool.wrap
16+
def blocker():
17+
events.append("S")
18+
out = 0
19+
for i in range(10000000):
20+
if i == 3000000:
21+
events.append("M")
22+
out += 1
23+
events.append("D")
24+
return out
25+
26+
threads = []
27+
for i in range(3):
28+
threads.append(gevent.spawn(blocker))
29+
gevent.joinall(threads)
30+
31+
assert events == ["S"] * 3 + ["M"] * 3 + ["D"] * 3
32+
33+
res = blocker()
34+
assert res == 10000000
3735

3836
def testLockBlockingSameThread(self):
3937
lock = ThreadPool.Lock()
@@ -60,89 +58,88 @@ def locker():
6058
time.sleep(0.5)
6159
lock.release()
6260

63-
pool = ThreadPool.ThreadPool(10)
64-
threads = [
65-
pool.spawn(locker),
66-
pool.spawn(locker),
67-
gevent.spawn(locker),
68-
pool.spawn(locker)
69-
]
70-
time.sleep(0.1)
61+
with ThreadPool.ThreadPool(10) as pool:
62+
threads = [
63+
pool.spawn(locker),
64+
pool.spawn(locker),
65+
gevent.spawn(locker),
66+
pool.spawn(locker)
67+
]
68+
time.sleep(0.1)
7169

72-
s = time.time()
70+
s = time.time()
7371

74-
lock.acquire(True, 5.0)
72+
lock.acquire(True, 5.0)
7573

76-
unlock_taken = time.time() - s
74+
unlock_taken = time.time() - s
7775

78-
assert 1.8 < unlock_taken < 2.2
76+
assert 1.8 < unlock_taken < 2.2
7977

80-
gevent.joinall(threads)
78+
gevent.joinall(threads)
8179

8280
def testMainLoopCallerThreadId(self):
8381
main_thread_id = threading.current_thread().ident
84-
pool = ThreadPool.ThreadPool(5)
85-
86-
def getThreadId(*args, **kwargs):
87-
return threading.current_thread().ident
82+
with ThreadPool.ThreadPool(5) as pool:
83+
def getThreadId(*args, **kwargs):
84+
return threading.current_thread().ident
8885

89-
t = pool.spawn(getThreadId)
90-
assert t.get() != main_thread_id
86+
t = pool.spawn(getThreadId)
87+
assert t.get() != main_thread_id
9188

92-
t = pool.spawn(lambda: ThreadPool.main_loop.call(getThreadId))
93-
assert t.get() == main_thread_id
89+
t = pool.spawn(lambda: ThreadPool.main_loop.call(getThreadId))
90+
assert t.get() == main_thread_id
9491

9592
def testMainLoopCallerGeventSpawn(self):
9693
main_thread_id = threading.current_thread().ident
97-
pool = ThreadPool.ThreadPool(5)
98-
def waiter():
99-
time.sleep(1)
100-
return threading.current_thread().ident
94+
with ThreadPool.ThreadPool(5) as pool:
95+
def waiter():
96+
time.sleep(1)
97+
return threading.current_thread().ident
10198

102-
def geventSpawner():
103-
event = ThreadPool.main_loop.call(gevent.spawn, waiter)
99+
def geventSpawner():
100+
event = ThreadPool.main_loop.call(gevent.spawn, waiter)
104101

105-
with pytest.raises(Exception) as greenlet_err:
106-
event.get()
107-
assert str(greenlet_err.value) == "cannot switch to a different thread"
102+
with pytest.raises(Exception) as greenlet_err:
103+
event.get()
104+
assert str(greenlet_err.value) == "cannot switch to a different thread"
108105

109-
waiter_thread_id = ThreadPool.main_loop.call(event.get)
110-
return waiter_thread_id
106+
waiter_thread_id = ThreadPool.main_loop.call(event.get)
107+
return waiter_thread_id
111108

112-
s = time.time()
113-
waiter_thread_id = pool.apply(geventSpawner)
114-
assert main_thread_id == waiter_thread_id
115-
time_taken = time.time() - s
116-
assert 0.9 < time_taken < 1.2
109+
s = time.time()
110+
waiter_thread_id = pool.apply(geventSpawner)
111+
assert main_thread_id == waiter_thread_id
112+
time_taken = time.time() - s
113+
assert 0.9 < time_taken < 1.2
117114

118115
def testEvent(self):
119-
pool = ThreadPool.ThreadPool(5)
120-
event = ThreadPool.Event()
116+
with ThreadPool.ThreadPool(5) as pool:
117+
event = ThreadPool.Event()
121118

122-
def setter():
123-
time.sleep(1)
124-
event.set("done!")
119+
def setter():
120+
time.sleep(1)
121+
event.set("done!")
125122

126-
def getter():
127-
return event.get()
123+
def getter():
124+
return event.get()
128125

129-
pool.spawn(setter)
130-
t_gevent = gevent.spawn(getter)
131-
t_pool = pool.spawn(getter)
132-
s = time.time()
133-
assert event.get() == "done!"
134-
time_taken = time.time() - s
135-
gevent.joinall([t_gevent, t_pool])
126+
pool.spawn(setter)
127+
t_gevent = gevent.spawn(getter)
128+
t_pool = pool.spawn(getter)
129+
s = time.time()
130+
assert event.get() == "done!"
131+
time_taken = time.time() - s
132+
gevent.joinall([t_gevent, t_pool])
136133

137-
assert t_gevent.get() == "done!"
138-
assert t_pool.get() == "done!"
134+
assert t_gevent.get() == "done!"
135+
assert t_pool.get() == "done!"
139136

140-
assert 0.9 < time_taken < 1.2
137+
assert 0.9 < time_taken < 1.2
141138

142-
with pytest.raises(Exception) as err:
143-
event.set("another result")
139+
with pytest.raises(Exception) as err:
140+
event.set("another result")
144141

145-
assert "Event already has value" in str(err.value)
142+
assert "Event already has value" in str(err.value)
146143

147144
def testMemoryLeak(self):
148145
import gc
@@ -153,10 +150,9 @@ def worker():
153150
return "ok"
154151

155152
def poolTest():
156-
pool = ThreadPool.ThreadPool(5)
157-
for i in range(20):
158-
pool.spawn(worker)
159-
pool.kill()
153+
with ThreadPool.ThreadPool(5) as pool:
154+
for i in range(20):
155+
pool.spawn(worker)
160156

161157
for i in range(5):
162158
poolTest()

src/util/ThreadPool.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ def kill(self):
5555
del self.pool
5656
self.pool = None
5757

58+
def __enter__(self):
59+
return self
60+
61+
def __exit__(self, *args):
62+
self.kill()
63+
5864

5965
lock_pool = gevent.threadpool.ThreadPool(50)
6066
main_thread_id = threading.current_thread().ident

0 commit comments

Comments
 (0)