Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 58 additions & 21 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from collections.abc import Callable, Iterator, Mapping, Sequence, Set as AbstractSet
from heapq import heappop, heappush
from textwrap import dedent
from threading import Thread
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -371,6 +372,7 @@ def default_flush_errors(
extra_plugins = extra_plugins or []

workers = []
connect_threads = []
if options.num_workers > 0:
# TODO: switch to something more efficient than pickle (also in the daemon).
pickled_options = pickle.dumps(options.snapshot())
Expand All @@ -383,10 +385,17 @@ def default_flush_errors(
buf = WriteBuffer()
sources_message.write(buf)
sources_data = buf.getvalue()

def connect(wc: WorkerClient, data: bytes) -> None:
    """Establish the IPC connection to one worker and push the serialized sources.

    Runs on a background thread so the main build can proceed while workers
    come up; the caller joins these threads before the workers are needed.
    """
    # Start loading sources in each worker as soon as it is up.
    wc.connect()
    wc.conn.write_bytes(data)

# We don't wait for workers to be ready until they are actually needed.
for worker in workers:
# Start loading graph in each worker as soon as it is up.
worker.connect()
worker.conn.write_bytes(sources_data)
thread = Thread(target=connect, args=(worker, sources_data))
thread.start()
connect_threads.append(thread)

try:
result = build_inner(
Expand All @@ -399,6 +408,7 @@ def default_flush_errors(
stderr,
extra_plugins,
workers,
connect_threads,
)
result.errors = messages
return result
Expand All @@ -412,6 +422,10 @@ def default_flush_errors(
e.messages = messages
raise
finally:
# In case of an early crash it is better to wait for workers to become ready, and
# shut them down cleanly. Otherwise, they will linger until connection timeout.
for thread in connect_threads:
thread.join()
for worker in workers:
try:
send(worker.conn, SccRequestMessage(scc_id=None, import_errors={}, mod_data={}))
Expand All @@ -431,6 +445,7 @@ def build_inner(
stderr: TextIO,
extra_plugins: Sequence[Plugin],
workers: list[WorkerClient],
connect_threads: list[Thread],
) -> BuildResult:
if platform.python_implementation() == "CPython":
# Run gc less frequently, as otherwise we can spend a large fraction of
Expand Down Expand Up @@ -486,7 +501,7 @@ def build_inner(

reset_global_state()
try:
graph = dispatch(sources, manager, stdout)
graph = dispatch(sources, manager, stdout, connect_threads)
if not options.fine_grained_incremental:
type_state.reset_all_subtype_caches()
if options.timing_stats is not None:
Expand Down Expand Up @@ -1156,6 +1171,24 @@ def add_stats(self, **kwds: Any) -> None:
def stats_summary(self) -> Mapping[str, object]:
return self.stats

def broadcast(self, message: bytes) -> None:
    """Send the same serialized message to every worker connection concurrently.

    One short-lived thread per worker writes the bytes, so a slow worker
    socket does not serialize the whole broadcast. Records elapsed time
    under the ``broadcast_time`` stat.
    """
    start = time.time()
    senders = [
        Thread(target=w.conn.write_bytes, args=(message,)) for w in self.workers
    ]
    for sender in senders:
        sender.start()
    for sender in senders:
        sender.join()
    self.add_stats(broadcast_time=time.time() - start)

def wait_ack(self) -> None:
    """Block until every worker has acknowledged the previously sent message.

    Reads one frame per worker and verifies it carries the ACK tag; workers
    that have not responded yet simply make the corresponding receive() block.
    """
    for worker in self.workers:
        tag = read_tag(receive(worker.conn))
        assert tag == ACK_MESSAGE

def submit(self, graph: Graph, sccs: list[SCC]) -> None:
"""Submit a stale SCC for processing in current process or parallel workers."""
if self.workers:
Expand All @@ -1176,6 +1209,7 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None
for mod_id in scc.mod_ids
if (path := graph[mod_id].xpath) in self.errors.recorded
}
t0 = time.time()
send(
self.workers[idx].conn,
SccRequestMessage(
Expand All @@ -1193,6 +1227,7 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None
},
),
)
self.add_stats(scc_send_time=time.time() - t0)

def wait_for_done(
self, graph: Graph
Expand Down Expand Up @@ -3685,7 +3720,12 @@ def log_configuration(manager: BuildManager, sources: list[BuildSource]) -> None
# The driver


def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO) -> Graph:
def dispatch(
sources: list[BuildSource],
manager: BuildManager,
stdout: TextIO,
connect_threads: list[Thread],
) -> Graph:
log_configuration(manager, sources)

t0 = time.time()
Expand Down Expand Up @@ -3742,7 +3782,7 @@ def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO)
dump_graph(graph, stdout)
return graph

# Fine grained dependencies that didn't have an associated module in the build
# Fine-grained dependencies that didn't have an associated module in the build
# are serialized separately, so we read them after we load the graph.
# We need to read them both for running in daemon mode and if we are generating
# a fine-grained cache (so that we can properly update them incrementally).
Expand All @@ -3755,25 +3795,28 @@ def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO)
if fg_deps_meta is not None:
manager.fg_deps_meta = fg_deps_meta
elif manager.stats.get("fresh_metas", 0) > 0:
# Clear the stats so we don't infinite loop because of positive fresh_metas
# Clear the stats, so we don't infinite loop because of positive fresh_metas
manager.stats.clear()
# There were some cache files read, but no fine-grained dependencies loaded.
manager.log("Error reading fine-grained dependencies cache -- aborting cache load")
manager.cache_enabled = False
manager.log("Falling back to full run -- reloading graph...")
return dispatch(sources, manager, stdout)
return dispatch(sources, manager, stdout, connect_threads)

# If we are loading a fine-grained incremental mode cache, we
# don't want to do a real incremental reprocess of the
# graph---we'll handle it all later.
if not manager.use_fine_grained_cache():
# Wait for workers since they may be needed at this point.
for thread in connect_threads:
thread.join()
process_graph(graph, manager)
# Update plugins snapshot.
write_plugins_snapshot(manager)
manager.old_plugins_snapshot = manager.plugins_snapshot
if manager.options.cache_fine_grained or manager.options.fine_grained_incremental:
# If we are running a daemon or are going to write cache for further fine grained use,
# then we need to collect fine grained protocol dependencies.
# If we are running a daemon or are going to write cache for further fine-grained use,
# then we need to collect fine-grained protocol dependencies.
# Since these are a global property of the program, they are calculated after we
# processed the whole graph.
type_state.add_all_protocol_deps(manager.fg_deps)
Expand Down Expand Up @@ -4166,10 +4209,8 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
buf = WriteBuffer()
graph_message.write(buf)
graph_data = buf.getvalue()
for worker in manager.workers:
buf = receive(worker.conn)
assert read_tag(buf) == ACK_MESSAGE
worker.conn.write_bytes(graph_data)
manager.wait_ack()
manager.broadcast(graph_data)

sccs = sorted_components(graph)
manager.log(
Expand All @@ -4187,13 +4228,9 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
buf = WriteBuffer()
sccs_message.write(buf)
sccs_data = buf.getvalue()
for worker in manager.workers:
buf = receive(worker.conn)
assert read_tag(buf) == ACK_MESSAGE
worker.conn.write_bytes(sccs_data)
for worker in manager.workers:
buf = receive(worker.conn)
assert read_tag(buf) == ACK_MESSAGE
manager.wait_ack()
manager.broadcast(sccs_data)
manager.wait_ack()

manager.free_workers = set(range(manager.options.num_workers))

Expand Down
2 changes: 2 additions & 0 deletions mypy/build_worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:
except CompileError as blocker:
send(server, SccResponseMessage(scc_id=scc_id, blocker=blocker))
else:
t1 = time.time()
send(server, SccResponseMessage(scc_id=scc_id, result=result))
manager.add_stats(scc_send_time=time.time() - t1)
manager.add_stats(total_process_stale_time=time.time() - t0, stale_sccs_processed=1)


Expand Down
10 changes: 7 additions & 3 deletions mypy/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
_IPCHandle = socket.socket

# Size of the message packed as !L, i.e. 4 bytes in network order (big-endian).
HEADER_SIZE = 4
HEADER_SIZE: Final = 4

# This is Linux default socket buffer size (for 64 bit), so we will not
# introduce an additional obstacle when exchanging a large IPC message.
MAX_READ: Final = 212992


# TODO: we should make sure consistent exceptions are raised on different platforms.
Expand Down Expand Up @@ -80,10 +84,10 @@ def frame_from_buffer(self) -> bytes | None:
self.message_size = None
return bytes(bdata)

def read(self, size: int = 100000) -> str:
def read(self, size: int = MAX_READ) -> str:
    """Read one full frame from the connection and decode it as UTF-8.

    ``size`` is the per-recv chunk size forwarded to read_bytes; MAX_READ
    matches the default Linux socket buffer so large frames need fewer reads.
    """
    return self.read_bytes(size).decode("utf-8")

def read_bytes(self, size: int = 100000) -> bytes:
def read_bytes(self, size: int = MAX_READ) -> bytes:
"""Read bytes from an IPC connection until we have a full frame."""
if sys.platform == "win32":
while True:
Expand Down
2 changes: 1 addition & 1 deletion mypy/metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def close(self) -> None:
def connect_db(db_file: str) -> sqlite3.Connection:
import sqlite3.dbapi2

db = sqlite3.dbapi2.connect(db_file)
db = sqlite3.dbapi2.connect(db_file, check_same_thread=False)
# This is a bit unfortunate (as we may get corrupt cache after e.g. Ctrl + C),
# but without this flag, commits are *very* slow, especially when using HDDs,
# see https://www.sqlite.org/faq.html#q19 for details.
Expand Down
Loading