Skip to content

Commit e03d927

Browse files
add mock executor; fix loader; adapt unit tests
1 parent 6b8c2e5 commit e03d927

5 files changed

Lines changed: 91 additions & 41 deletions

File tree

sqlmesh/core/loader.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
from __future__ import annotations
22

33
import abc
4+
import concurrent.futures
45
import glob
56
import itertools
67
import linecache
78
import multiprocessing as mp
89
import os
910
import re
1011
import typing as t
12+
import concurrent
1113
from collections import Counter, defaultdict
1214
from dataclasses import dataclass
1315
from pathlib import Path
14-
from concurrent.futures import ProcessPoolExecutor, as_completed
1516

1617
from sqlglot.errors import SqlglotError
1718
from sqlglot import exp
@@ -310,11 +311,18 @@ def _load_external_models(
310311
# external models with no explicit gateway defined form the base set
311312
for model in external_models:
312313
if model.gateway is None:
314+
<<<<<<< HEAD
313315
if model.fqn in models:
314316
self._raise_failed_to_load_model_error(
315317
path, f"Duplicate external model name: '{model.name}'."
316318
)
317319
models[model.fqn] = model
320+
=======
321+
try:
322+
models[model.fqn] = model
323+
except Exception as ex:
324+
raise ConfigError(f"Failed to add model: {model.fqn}\n\n{ex}")
325+
>>>>>>> 80487e4f (add mock executor; fix loader; adapt unit tests)
318326

319327
# however, if there is a gateway defined, gateway-specific models take precedence
320328
if gateway:
@@ -473,20 +481,15 @@ def _load_models(
473481
audits into a Dict and creates the dag
474482
"""
475483
cache = SqlMeshLoader._Cache(self, self.config_path)
476-
import time
477484

478-
now = time.time()
479485
sql_models = self._load_sql_models(macros, jinja_macros, audits, signals, cache, gateway)
480-
print("sql models", time.time() - now)
481-
now = time.time()
482486
external_models = self._load_external_models(audits, cache, gateway)
483-
print("external models", time.time() - now)
484487
python_models = self._load_python_models(macros, jinja_macros, audits, signals)
485488

486489
all_model_names = list(sql_models) + list(external_models) + list(python_models)
487490
duplicates = [name for name, count in Counter(all_model_names).items() if count > 1]
488491
if duplicates:
489-
raise ValueError(f"Duplicate model name(s) found: {', '.join(duplicates)}.")
492+
raise ConfigError(f"Duplicate model name(s) found: {', '.join(duplicates)}.")
490493

491494
return UniqueKeyDict("models", **sql_models, **external_models, **python_models)
492495

@@ -501,8 +504,7 @@ def _load_sql_models(
501504
) -> UniqueKeyDict[str, Model]:
502505
"""Loads the sql models into a Dict"""
503506
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
504-
505-
paths = set()
507+
paths: t.Set[Path] = set()
506508

507509
for path in self._glob_paths(
508510
self.config_path / c.MODELS,
@@ -517,14 +519,11 @@ def _load_sql_models(
517519

518520
for path in paths.copy():
519521
cached_models = cache.get(path)
520-
521522
if cached_models:
522523
paths.remove(path)
523-
524524
for model in cached_models:
525-
models[model.fqn] = model
526-
527-
error = False
525+
if model.enabled:
526+
models[model.fqn] = model
528527

529528
if paths:
530529
defaults = dict(
@@ -545,31 +544,31 @@ def _load_sql_models(
545544
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
546545
)
547546

548-
with ProcessPoolExecutor(
547+
errors: t.List[str] = []
548+
with concurrent.futures.ProcessPoolExecutor(
549549
mp_context=mp.get_context("fork"),
550550
initializer=_init_model_defaults,
551551
initargs=(self.config, gateway, defaults, cache),
552552
max_workers=c.MAX_FORK_WORKERS,
553553
) as pool:
554-
for fut in as_completed(pool.submit(load_sql_models, path) for path in paths):
554+
futures_to_paths = {pool.submit(load_sql_models, path): path for path in paths}
555+
for fut, path in futures_to_paths.items():
555556
try:
556-
path, loaded = fut.result()
557-
557+
_, loaded = fut.result()
558558
if loaded:
559559
for model in loaded:
560-
model._path = path
561-
models[model.fqn] = model
560+
if model.enabled:
561+
model._path = path
562+
models[model.fqn] = model
562563
else:
563564
for model in cache.get(path):
564-
models[model.fqn] = model
565+
if model.enabled:
566+
models[model.fqn] = model
565567
except Exception as ex:
566-
self._console.log_error(
567-
f"Failed to load model definition at '{path}'.\n{ex}"
568-
)
569-
error = True
568+
errors.append(f"Failed to load model definition at '{path}'.\n\n{ex}")
570569

571-
if error:
572-
raise ConfigError("Failed to load models")
570+
if errors:
571+
raise ConfigError(f"Failed to load models\n\n{'\n'.join(errors)}")
573572

574573
return models
575574

tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,3 +506,15 @@ def _make_function(table_name: str, random_id: str) -> exp.Table:
506506
return temp_table
507507

508508
return _make_function
509+
510+
511+
@pytest.fixture(autouse=True)
512+
def patch_process_pool_executor(mocker: MockerFixture, request):
513+
"""Patch ProcessPoolExecutor with MockProcessPoolExecutor in all tests except test_forking.py."""
514+
# Skip mocking for test_forking.py
515+
if request.node.fspath.basename == "test_forking.py":
516+
return
517+
518+
from tests.mock_executor import MockProcessPoolExecutor
519+
520+
mocker.patch("concurrent.futures.ProcessPoolExecutor", MockProcessPoolExecutor)

tests/core/test_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@ def test_duplicate_model_names_different_kind(tmp_path: Path, sample_models):
8585
path_3.write_text(model_3["contents"])
8686

8787
with pytest.raises(
88-
ValueError, match=r'Duplicate model name\(s\) found: "memory"."test_schema"."test_model".'
88+
ConfigError, match=r'Duplicate model name\(s\) found: "memory"."test_schema"."test_model".'
8989
):
9090
Context(paths=tmp_path, config=config)
9191

9292

9393
@pytest.mark.parametrize("sample_models", ["sql", "external"], indirect=True)
9494
def test_duplicate_model_names_same_kind(tmp_path: Path, sample_models):
95-
"""Test same (SQL and external) models with duplicate model names raises ValueError."""
95+
"""Test same (SQL and external) models with duplicate model names raises ConfigError."""
9696

9797
def duplicate_model_path(fpath):
9898
return Path(fpath).parent / ("duplicate" + Path(fpath).suffix)

tests/core/test_model.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2941,26 +2941,33 @@ def test_model_cache(tmp_path: Path, mocker: MockerFixture):
29412941
expressions = d.parse(
29422942
"""
29432943
MODEL (
2944-
name db.seed,
2944+
name db.model_sql,
29452945
);
29462946
SELECT 1, ds;
29472947
"""
29482948
)
29492949

29502950
model = load_sql_based_model([e for e in expressions if e])
29512951

2952-
loader = mocker.Mock(return_value=[model])
2953-
2954-
assert cache.get_or_load("test_model", "test_entry_a", loader=loader)[0].dict() == model.dict()
2955-
assert cache.get_or_load("test_model", "test_entry_a", loader=loader)[0].dict() == model.dict()
2952+
assert cache.put([model], "test_model", "test_entry_a")
2953+
assert cache.get("test_model", "test_entry_a")[0].dict() == model.dict()
29562954

2957-
assert cache.get_or_load("test_model", "test_entry_b", loader=loader)[0].dict() == model.dict()
2958-
assert cache.get_or_load("test_model", "test_entry_b", loader=loader)[0].dict() == model.dict()
2955+
expressions = d.parse(
2956+
"""
2957+
MODEL (
2958+
name db.model_seed,
2959+
kind SEED (
2960+
path '../seeds/waiter_names.csv',
2961+
),
2962+
);
2963+
"""
2964+
)
29592965

2960-
assert cache.get_or_load("test_model", "test_entry_a", loader=loader)[0].dict() == model.dict()
2961-
assert cache.get_or_load("test_model", "test_entry_a", loader=loader)[0].dict() == model.dict()
2966+
seed_model = load_sql_based_model(
2967+
expressions, path=Path("./examples/sushi/models/test_model.sql")
2968+
)
29622969

2963-
assert loader.call_count == 2
2970+
assert not cache.put([seed_model], "test_model", "test_entry_b")
29642971

29652972

29662973
@pytest.mark.slow
@@ -2983,7 +2990,7 @@ def test_model_cache_gateway(tmp_path: Path, mocker: MockerFixture):
29832990
assert patched_cache_put.call_count == 0
29842991

29852992
Context(paths=tmp_path, config=config, gateway="secondary")
2986-
assert patched_cache_put.call_count == 4
2993+
assert patched_cache_put.call_count == 2
29872994

29882995

29892996
@pytest.mark.slow
@@ -3001,7 +3008,7 @@ def test_model_cache_default_catalog(tmp_path: Path, mocker: MockerFixture):
30013008
PropertyMock(return_value=None),
30023009
):
30033010
Context(paths=tmp_path)
3004-
assert patched_cache_put.call_count == 4
3011+
assert patched_cache_put.call_count == 2
30053012

30063013

30073014
def test_model_ctas_query():

tests/mock_executor.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from concurrent.futures import Future
2+
3+
4+
class MockProcessPoolExecutor:
5+
"""A mock implementation of ProcessPoolExecutor for use in tests.
6+
7+
This executor runs functions synchronously in the same process, avoiding the issues
8+
with forking in test environments.
9+
"""
10+
11+
def __init__(self, max_workers=None, mp_context=None, initializer=None, initargs=()):
12+
if initializer is not None:
13+
try:
14+
initializer(*initargs)
15+
except BaseException as ex:
16+
raise RuntimeError(f"Exception in initializer: {ex}")
17+
18+
def __enter__(self):
19+
return self
20+
21+
def __exit__(self, *args):
22+
return True
23+
24+
def submit(self, fn, *args, **kwargs):
25+
"""Execute the function synchronously and return a Future with the result."""
26+
future = Future()
27+
try:
28+
result = fn(*args, **kwargs)
29+
future.set_result(result)
30+
except Exception as e:
31+
future.set_exception(e)
32+
return future

0 commit comments

Comments
 (0)