diff --git a/mypy/build.py b/mypy/build.py index 96ba59dd1095..23396c76a8b8 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -29,6 +29,7 @@ from collections.abc import Callable, Iterator, Mapping, Sequence, Set as AbstractSet from heapq import heappop, heappush from textwrap import dedent +from threading import Thread from typing import ( TYPE_CHECKING, Any, @@ -371,6 +372,7 @@ def default_flush_errors( extra_plugins = extra_plugins or [] workers = [] + connect_threads = [] if options.num_workers > 0: # TODO: switch to something more efficient than pickle (also in the daemon). pickled_options = pickle.dumps(options.snapshot()) @@ -383,10 +385,17 @@ def default_flush_errors( buf = WriteBuffer() sources_message.write(buf) sources_data = buf.getvalue() + + def connect(wc: WorkerClient, data: bytes) -> None: + # Start loading sources in each worker as soon as it is up. + wc.connect() + wc.conn.write_bytes(data) + + # We don't wait for workers to be ready until they are actually needed. for worker in workers: - # Start loading graph in each worker as soon as it is up. - worker.connect() - worker.conn.write_bytes(sources_data) + thread = Thread(target=connect, args=(worker, sources_data)) + thread.start() + connect_threads.append(thread) try: result = build_inner( @@ -399,6 +408,7 @@ def default_flush_errors( stderr, extra_plugins, workers, + connect_threads, ) result.errors = messages return result @@ -412,6 +422,10 @@ def default_flush_errors( e.messages = messages raise finally: + # In case of an early crash it is better to wait for workers to become ready, and + # shut them down cleanly. Otherwise, they will linger until connection timeout. + for thread in connect_threads: + thread.join() for worker in workers: try: send(worker.conn, SccRequestMessage(scc_id=None, import_errors={}, mod_data={})) @@ -431,6 +445,7 @@ def build_inner( stderr: TextIO, extra_plugins: Sequence[Plugin], workers: list[WorkerClient], + connect_threads: list[Thread], ) -> BuildResult: if platform.python_implementation() == "CPython": # Run gc less frequently, as otherwise we can spend a large fraction of @@ -486,7 +501,7 @@ def build_inner( reset_global_state() try: - graph = dispatch(sources, manager, stdout) + graph = dispatch(sources, manager, stdout, connect_threads) if not options.fine_grained_incremental: type_state.reset_all_subtype_caches() if options.timing_stats is not None: @@ -496,9 +511,7 @@ def build_inner( warn_unused_configs(options, flush_errors) return BuildResult(manager, graph) finally: - t0 = time.time() - manager.metastore.commit() - manager.add_stats(cache_commit_time=time.time() - t0) + manager.commit() manager.log( "Build finished in %.3f seconds with %d modules, and %d errors" % ( @@ -1119,6 +1132,11 @@ def report_file( if self.reports is not None and self.source_set.is_source(file): self.reports.file(file, self.modules, type_map, options) + def commit(self) -> None: + t0 = time.time() + self.metastore.commit() + self.add_stats(cache_commit_time=time.time() - t0) + def verbosity(self) -> int: return self.options.verbosity @@ -1156,6 +1174,24 @@ def add_stats(self, **kwds: Any) -> None: def stats_summary(self) -> Mapping[str, object]: return self.stats + def broadcast(self, message: bytes) -> None: + """Broadcast same message to all workers in parallel.""" + t0 = time.time() + threads = [] + for worker in self.workers: + thread = Thread(target=worker.conn.write_bytes, args=(message,)) + thread.start() + threads.append(thread) + for thread in threads: + thread.join() + self.add_stats(broadcast_time=time.time() - t0) + + def wait_ack(self) -> None: + """Wait for an ack from all workers.""" + for worker in self.workers: + buf = receive(worker.conn) + assert read_tag(buf) == ACK_MESSAGE + def submit(self, graph: Graph, sccs: list[SCC]) -> None: """Submit a stale SCC for processing in current process or parallel workers.""" if self.workers: @@ -1176,6 +1212,7 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None for mod_id in scc.mod_ids if (path := graph[mod_id].xpath) in self.errors.recorded } + t0 = time.time() send( self.workers[idx].conn, SccRequestMessage( @@ -1193,6 +1230,7 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None }, ), ) + self.add_stats(scc_send_time=time.time() - t0) def wait_for_done( self, graph: Graph @@ -1221,7 +1259,10 @@ def wait_for_done_workers( done_sccs = [] results = {} - for idx in ready_to_read([w.conn for w in self.workers], WORKER_DONE_TIMEOUT): + t0 = time.time() + ready = ready_to_read([w.conn for w in self.workers], WORKER_DONE_TIMEOUT) + t1 = time.time() + for idx in ready: buf = receive(self.workers[idx].conn) assert read_tag(buf) == SCC_RESPONSE_MESSAGE data = SccResponseMessage.read(buf) @@ -1232,6 +1273,7 @@ def wait_for_done_workers( assert data.result is not None results.update(data.result) done_sccs.append(self.scc_by_id[scc_id]) + self.add_stats(scc_wait_time=t1 - t0, scc_receive_time=time.time() - t1) self.submit_to_workers(graph) # advance after some workers are free. return ( done_sccs, @@ -3685,7 +3727,12 @@ def log_configuration(manager: BuildManager, sources: list[BuildSource]) -> None # The driver -def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO) -> Graph: +def dispatch( + sources: list[BuildSource], + manager: BuildManager, + stdout: TextIO, + connect_threads: list[Thread], +) -> Graph: log_configuration(manager, sources) t0 = time.time() @@ -3742,7 +3789,7 @@ def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO) dump_graph(graph, stdout) return graph - # Fine grained dependencies that didn't have an associated module in the build + # Fine-grained dependencies that didn't have an associated module in the build # are serialized separately, so we read them after we load the graph. # We need to read them both for running in daemon mode and if we are generating # a fine-grained cache (so that we can properly update them incrementally). @@ -3755,25 +3802,28 @@ def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO) if fg_deps_meta is not None: manager.fg_deps_meta = fg_deps_meta elif manager.stats.get("fresh_metas", 0) > 0: - # Clear the stats so we don't infinite loop because of positive fresh_metas + # Clear the stats, so we don't infinite loop because of positive fresh_metas manager.stats.clear() # There were some cache files read, but no fine-grained dependencies loaded. manager.log("Error reading fine-grained dependencies cache -- aborting cache load") manager.cache_enabled = False manager.log("Falling back to full run -- reloading graph...") - return dispatch(sources, manager, stdout) + return dispatch(sources, manager, stdout, connect_threads) # If we are loading a fine-grained incremental mode cache, we # don't want to do a real incremental reprocess of the # graph---we'll handle it all later. if not manager.use_fine_grained_cache(): + # Wait for workers since they may be needed at this point. + for thread in connect_threads: + thread.join() process_graph(graph, manager) # Update plugins snapshot. write_plugins_snapshot(manager) manager.old_plugins_snapshot = manager.plugins_snapshot if manager.options.cache_fine_grained or manager.options.fine_grained_incremental: - # If we are running a daemon or are going to write cache for further fine grained use, - # then we need to collect fine grained protocol dependencies. + # If we are running a daemon or are going to write cache for further fine-grained use, + # then we need to collect fine-grained protocol dependencies. # Since these are a global property of the program, they are calculated after we # processed the whole graph. type_state.add_all_protocol_deps(manager.fg_deps) @@ -4166,10 +4216,8 @@ def process_graph(graph: Graph, manager: BuildManager) -> None: buf = WriteBuffer() graph_message.write(buf) graph_data = buf.getvalue() - for worker in manager.workers: - buf = receive(worker.conn) - assert read_tag(buf) == ACK_MESSAGE - worker.conn.write_bytes(graph_data) + manager.wait_ack() + manager.broadcast(graph_data) sccs = sorted_components(graph) manager.log( @@ -4187,13 +4235,9 @@ def process_graph(graph: Graph, manager: BuildManager) -> None: buf = WriteBuffer() sccs_message.write(buf) sccs_data = buf.getvalue() - for worker in manager.workers: - buf = receive(worker.conn) - assert read_tag(buf) == ACK_MESSAGE - worker.conn.write_bytes(sccs_data) - for worker in manager.workers: - buf = receive(worker.conn) - assert read_tag(buf) == ACK_MESSAGE + manager.wait_ack() + manager.broadcast(sccs_data) + manager.wait_ack() manager.free_workers = set(range(manager.options.num_workers)) diff --git a/mypy/build_worker/worker.py b/mypy/build_worker/worker.py index 66cfec6f6a36..b08662e453a0 100644 --- a/mypy/build_worker/worker.py +++ b/mypy/build_worker/worker.py @@ -45,10 +45,10 @@ process_stale_scc, ) from mypy.cache import Tag, read_int_opt -from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT +from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT, WORKER_IDLE_TIMEOUT from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error from mypy.fscache import FileSystemCache -from mypy.ipc import IPCException, IPCServer, receive, send +from mypy.ipc import IPCException, IPCServer, ready_to_read, receive, send from mypy.modulefinder import BuildSource, BuildSourceSet, compute_search_paths from mypy.nodes import FileRawData from mypy.options import Options @@ -170,9 +170,13 @@ def serve(server: IPCServer, ctx: ServerContext) -> None: # Notify coordinator we are ready to start processing SCCs. send(server, AckMessage()) while True: + t0 = time.time() + ready_to_read([server], WORKER_IDLE_TIMEOUT) + t1 = time.time() buf = receive(server) assert read_tag(buf) == SCC_REQUEST_MESSAGE scc_message = SccRequestMessage.read(buf) + manager.add_stats(scc_wait_time=t1 - t0, scc_receive_time=time.time() - t1) scc_id = scc_message.scc_id if scc_id is None: manager.dump_stats() @@ -193,11 +197,13 @@ def serve(server: IPCServer, ctx: ServerContext) -> None: gc.enable() result = process_stale_scc(graph, scc, manager, from_cache=graph_data.from_cache) # We must commit after each SCC, otherwise we break --sqlite-cache. - manager.metastore.commit() + manager.commit() except CompileError as blocker: send(server, SccResponseMessage(scc_id=scc_id, blocker=blocker)) else: + t1 = time.time() send(server, SccResponseMessage(scc_id=scc_id, result=result)) + manager.add_stats(scc_send_time=time.time() - t1) manager.add_stats(total_process_stale_time=time.time() - t0, stale_sccs_processed=1) diff --git a/mypy/defaults.py b/mypy/defaults.py index a39a577d03ac..749879861fbf 100644 --- a/mypy/defaults.py +++ b/mypy/defaults.py @@ -48,4 +48,5 @@ WORKER_START_INTERVAL: Final = 0.01 WORKER_START_TIMEOUT: Final = 3 WORKER_CONNECTION_TIMEOUT: Final = 10 +WORKER_IDLE_TIMEOUT: Final = 600 WORKER_DONE_TIMEOUT: Final = 600 diff --git a/mypy/ipc.py b/mypy/ipc.py index 29710cd57a91..2911d5f77db2 100644 --- a/mypy/ipc.py +++ b/mypy/ipc.py @@ -13,7 +13,7 @@ import sys import tempfile from abc import abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Sequence from select import select from types import TracebackType from typing import Final @@ -38,7 +38,11 @@ _IPCHandle = socket.socket # Size of the message packed as !L, i.e. 4 bytes in network order (big-endian). -HEADER_SIZE = 4 +HEADER_SIZE: Final = 4 + +# This is Linux default socket buffer size (for 64 bit), so we will not +# introduce an additional obstacle when exchanging a large IPC message. +MAX_READ: Final = 212992 # TODO: we should make sure consistent exceptions are raised on different platforms. @@ -80,10 +84,10 @@ def frame_from_buffer(self) -> bytes | None: self.message_size = None return bytes(bdata) - def read(self, size: int = 100000) -> str: + def read(self, size: int = MAX_READ) -> str: return self.read_bytes(size).decode("utf-8") - def read_bytes(self, size: int = 100000) -> bytes: + def read_bytes(self, size: int = MAX_READ) -> bytes: """Read bytes from an IPC connection until we have a full frame.""" if sys.platform == "win32": while True: @@ -215,6 +219,10 @@ def __init__(self, name: str, timeout: float | None) -> None: ) else: self.connection = socket.socket(socket.AF_UNIX) + # This is already default on Linux, we set same buffer size + # for macOS vs Linux consistency to simplify reasoning. + self.connection.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, MAX_READ) + self.connection.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, MAX_READ) self.connection.settimeout(timeout) self.connection.connect(name) @@ -291,6 +299,10 @@ def __enter__(self) -> IPCServer: else: try: self.connection, _ = self.sock.accept() + # This is already default on Linux, we set same buffer size + # for macOS vs Linux consistency to simplify reasoning. + self.connection.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, MAX_READ) + self.connection.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, MAX_READ) except TimeoutError as e: raise IPCException("The socket timed out") from e return self @@ -361,7 +373,7 @@ def read_status(status_file: str) -> dict[str, object]: return data -def ready_to_read(conns: list[IPCClient], timeout: float | None = None) -> list[int]: +def ready_to_read(conns: Sequence[IPCBase], timeout: float | None = None) -> list[int]: """Wait until some connections are readable. Return index of each readable connection in the original list. diff --git a/mypy/metastore.py b/mypy/metastore.py index 3d32ba29ae10..1a2a7b335e72 100644 --- a/mypy/metastore.py +++ b/mypy/metastore.py @@ -157,7 +157,7 @@ def close(self) -> None: def connect_db(db_file: str) -> sqlite3.Connection: import sqlite3.dbapi2 - db = sqlite3.dbapi2.connect(db_file) + db = sqlite3.dbapi2.connect(db_file, check_same_thread=False) # This is a bit unfortunate (as we may get corrupt cache after e.g. Ctrl + C), # but without this flag, commits are *very* slow, especially when using HDDs, # see https://www.sqlite.org/faq.html#q19 for details.