Skip to content

Commit 3b0f825

Browse files
committed
feat: add go to definition to lsp
1 parent 3271ae1 commit 3b0f825

6 files changed

Lines changed: 289 additions & 17 deletions

File tree

sqlmesh/lsp/context.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from collections import defaultdict
2+
from pathlib import Path
3+
from sqlmesh.core.context import Context
4+
import typing as t
5+
6+
7+
class LSPContext:
8+
"""
9+
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
10+
"""
11+
12+
def __init__(self, context: Context) -> None:
13+
self.context = context
14+
map: t.Dict[str, t.List[str]] = defaultdict(list)
15+
for model in context.models.values():
16+
if model._path is not None:
17+
path = Path(model._path).resolve()
18+
map[f"file://{path.as_posix()}"].append(model.name)
19+
self.map = map

sqlmesh/lsp/main.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python
22
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration, refactored without globals."""
33

4-
from collections import defaultdict
54
import logging
65
import typing as t
76
from pathlib import Path
@@ -12,21 +11,8 @@
1211
from sqlmesh._version import __version__
1312
from sqlmesh.core.context import Context
1413
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
15-
16-
17-
class LSPContext:
18-
"""
19-
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
20-
"""
21-
22-
def __init__(self, context: Context) -> None:
23-
self.context = context
24-
map: t.Dict[str, t.List[str]] = defaultdict(list)
25-
for model in context.models.values():
26-
if model._path is not None:
27-
path = Path(model._path).resolve()
28-
map[f"file://{path.as_posix()}"].append(model.name)
29-
self.map = map
14+
from sqlmesh.lsp.context import LSPContext
15+
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
3016

3117

3218
class SQLMeshLanguageServer:
@@ -144,6 +130,43 @@ def formatting(
144130
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
145131
return []
146132

133+
@self.server.feature(types.TEXT_DOCUMENT_DEFINITION)
134+
def goto_definition(
135+
ls: LanguageServer, params: types.DefinitionParams
136+
) -> t.List[types.LocationLink]:
137+
"""Jump to an object's definition."""
138+
try:
139+
self._ensure_context_for_document(params.text_document.uri)
140+
document = ls.workspace.get_document(params.text_document.uri)
141+
if self.lsp_context is None:
142+
raise RuntimeError(f"No context found for document: {document.path}")
143+
144+
references = get_model_definitions_for_a_path(
145+
self.lsp_context, params.text_document.uri
146+
)
147+
if len(references) == 0:
148+
return []
149+
150+
return [
151+
types.LocationLink(
152+
target_uri=reference.uri,
153+
target_selection_range=types.Range(
154+
start=types.Position(line=0, character=0),
155+
end=types.Position(line=0, character=0),
156+
),
157+
target_range=types.Range(
158+
start=types.Position(line=0, character=0),
159+
end=types.Position(line=0, character=0),
160+
),
161+
origin_selection_range=reference.range,
162+
)
163+
for reference in references
164+
]
165+
166+
except Exception as e:
167+
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
168+
return []
169+
147170
def _context_get_or_load(self, document_uri: str) -> LSPContext:
148171
if self.lsp_context is None:
149172
self._ensure_context_for_document(document_uri)

sqlmesh/lsp/reference.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from pathlib import Path
2+
3+
from lsprotocol.types import Range, Position
4+
import typing as t
5+
6+
from sqlmesh.core.dialect import normalize_model_name
7+
from sqlmesh.core.model.definition import SqlModel
8+
from sqlmesh.lsp.context import LSPContext
9+
from sqlglot import exp
10+
11+
from sqlmesh.utils.pydantic import PydanticModel
12+
13+
14+
class Reference(PydanticModel):
15+
range: Range
16+
uri: str
17+
18+
19+
def get_model_definitions_for_a_path(
20+
lint_context: LSPContext, document_uri: str
21+
) -> t.List[Reference]:
22+
"""
23+
Get the model references for a given path.
24+
25+
Works for models and audits.
26+
Works for targeting sql and python models.
27+
28+
Steps:
29+
- Get the parsed query
30+
- Find all table objects using find_all exp.Table
31+
- Match the string against all model names
32+
- Need to normalize it before matching
33+
- Try get_model before normalization
34+
- Match to models that the model refers to
35+
"""
36+
# Ensure the path is a sql model
37+
if not document_uri.endswith(".sql"):
38+
return []
39+
40+
# Get the model
41+
models = lint_context.map[document_uri]
42+
if models is None:
43+
return []
44+
if len(models) == 0:
45+
return []
46+
model_name = models[0]
47+
model = lint_context.context.get_model(model_or_snapshot=model_name, raise_if_missing=False)
48+
if model is None:
49+
return []
50+
if not isinstance(model, SqlModel):
51+
return []
52+
53+
# Find all possible references
54+
references = []
55+
for table in model.query.find_all(exp.Table):
56+
depends_on = model.depends_on
57+
58+
# Normalize the table reference
59+
reference_name = table.sql(dialect=model.dialect)
60+
normalized_reference_name = normalize_model_name(
61+
reference_name,
62+
default_catalog=lint_context.context.default_catalog,
63+
dialect=model.dialect,
64+
)
65+
if normalized_reference_name not in depends_on:
66+
continue
67+
68+
# Get the referenced model uri
69+
referenced_model = lint_context.context.get_model(
70+
model_or_snapshot=normalized_reference_name, raise_if_missing=False
71+
)
72+
if referenced_model is None:
73+
continue
74+
# Get the model uri
75+
referenced_model_path = referenced_model._path
76+
if referenced_model_path is None:
77+
continue
78+
# Fully qualify the path in case
79+
path = Path.resolve(Path(referenced_model_path))
80+
referenced_model_uri = f"file://{path}"
81+
read_file = open(path, "r").readlines()
82+
83+
# Extract metadata for positioning
84+
table_meta = TokenPositionDetails.from_meta(table.this.meta)
85+
table_range = _range_from_token_position_details(table_meta, read_file)
86+
start_pos = table_range.start
87+
end_pos = table_range.end
88+
89+
# If there's a catalog or database qualifier, adjust the start position
90+
catalog_or_db = table.args.get("catalog") or table.args.get("db")
91+
if catalog_or_db is not None:
92+
catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta)
93+
catalog_or_db_range = _range_from_token_position_details(catalog_or_db_meta, read_file)
94+
start_pos = catalog_or_db_range.start
95+
96+
references.append(
97+
Reference(uri=referenced_model_uri, range=Range(start=start_pos, end=end_pos))
98+
)
99+
100+
return references
101+
102+
103+
class TokenPositionDetails(PydanticModel):
104+
"""
105+
Details about a token's position in the source code.
106+
107+
Attributes:
108+
line (int): The line that the token ends on.
109+
col (int): The column that the token ends on.
110+
start (int): The start index of the token.
111+
end (int): The ending index of the token.
112+
"""
113+
114+
line: int
115+
col: int
116+
start: int
117+
end: int
118+
119+
@staticmethod
120+
def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails":
121+
return TokenPositionDetails(
122+
line=meta["line"],
123+
col=meta["col"],
124+
start=meta["start"],
125+
end=meta["end"],
126+
)
127+
128+
129+
def _range_from_token_position_details(
130+
token_position_details: TokenPositionDetails, read_file: t.List[str]
131+
) -> Range:
132+
"""
133+
Convert a TokenPositionDetails object to a Range object.
134+
135+
:param token_position_details: Details about a token's position
136+
:param read_file: List of lines from the file
137+
:return: A Range object representing the token's position
138+
"""
139+
# Convert from 1-indexed to 0-indexed for line and column
140+
end_line_0 = token_position_details.line - 1
141+
end_col_0 = token_position_details.col
142+
143+
# Find the start line and column by counting backwards from the end position
144+
start_pos = token_position_details.start
145+
end_pos = token_position_details.end
146+
147+
# Initialize with the end position
148+
start_line_0 = end_line_0
149+
start_col_0 = end_col_0 - (end_pos - start_pos + 1)
150+
151+
# If start_col_0 is negative, we need to go back to previous lines
152+
while start_col_0 < 0 and start_line_0 > 0:
153+
start_line_0 -= 1
154+
start_col_0 += len(read_file[start_line_0])
155+
# Account for newline character
156+
if start_col_0 >= 0:
157+
break
158+
start_col_0 += 1 # For the newline character
159+
160+
# Ensure we don't have negative values
161+
start_col_0 = max(0, start_col_0)
162+
return Range(
163+
start=Position(line=start_line_0, character=start_col_0),
164+
end=Position(line=end_line_0, character=end_col_0),
165+
)

tests/lsp/test_context.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext
4+
5+
@pytest.mark.fast
6+
def test_lsp_context():
7+
context = Context(paths=["examples/sushi"])
8+
lsp_context = LSPContext(context)
9+
10+
assert lsp_context is not None
11+
assert lsp_context.context is not None
12+
assert lsp_context.map is not None
13+
14+
# find one model in the map
15+
active_customers_key = next(
16+
key for key in lsp_context.map.keys() if key.endswith("models/active_customers.sql")
17+
)
18+
assert lsp_context.map[active_customers_key] == ["sushi.active_customers"]

tests/lsp/test_reference.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pytest
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext
4+
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
5+
6+
@pytest.mark.fast
7+
def test_reference() -> None:
8+
context = Context(paths=["examples/sushi"])
9+
lsp_context = LSPContext(context)
10+
11+
active_customers_uri = next(
12+
uri for uri, models in lsp_context.map.items() if "sushi.active_customers" in models
13+
)
14+
sushi_customers_uri = next(
15+
uri for uri, models in lsp_context.map.items() if "sushi.customers" in models
16+
)
17+
18+
references = get_model_definitions_for_a_path(lsp_context, active_customers_uri)
19+
20+
assert len(references) == 1
21+
assert references[0].uri == sushi_customers_uri
22+
23+
# Check that the reference in the correct range is sushi.customers
24+
path = active_customers_uri.removeprefix("file://")
25+
read_file = open(path, "r").readlines()
26+
# Get the string range in the read file
27+
reference_range = references[0].range
28+
start_line = reference_range.start.line
29+
end_line = reference_range.end.line
30+
start_character = reference_range.start.character
31+
end_character = reference_range.end.character
32+
# Get the string from the file
33+
34+
# If the reference spans multiple lines, handle it accordingly
35+
if start_line == end_line:
36+
# Reference is on a single line
37+
line_content = read_file[start_line]
38+
referenced_text = line_content[start_character:end_character]
39+
else:
40+
# Reference spans multiple lines
41+
referenced_text = read_file[start_line][
42+
start_character:
43+
] # First line from start_character to end
44+
for line_num in range(start_line + 1, end_line): # Middle lines (if any)
45+
referenced_text += read_file[line_num]
46+
referenced_text += read_file[end_line][:end_character] # Last line up to end_character
47+
assert referenced_text == "sushi.customers"

vscode/extension/src/lsp/lsp.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ export class LSPClient implements Disposable {
2727

2828
const sqlmesh = await sqlmesh_lsp_exec()
2929
if (isErr(sqlmesh)) {
30-
traceError(`Failed to get sqlmesh_lsp_exec, ${sqlmesh.error.type}`)
30+
traceError(`Failed to get sqlmesh_lsp_exec, ${JSON.stringify(sqlmesh.error)}`)
3131
return sqlmesh
3232
}
3333
const workspaceFolders = getWorkspaceFolders()

0 commit comments

Comments
 (0)