From 2adf9e6ed48b7288540a877d86d134e0224a2943 Mon Sep 17 00:00:00 2001 From: Ben <9087625+benfdking@users.noreply.github.com> Date: Wed, 11 Jun 2025 10:42:19 +0100 Subject: [PATCH 1/3] feat(lsp): include model descriptions in completions --- sqlmesh/lsp/completions.py | 62 ++++++---- sqlmesh/lsp/custom.py | 9 +- sqlmesh/lsp/main.py | 114 +++++++++++++----- vscode/extension/src/completion/completion.ts | 15 ++- vscode/extension/src/lsp/custom.ts | 13 +- 5 files changed, 158 insertions(+), 55 deletions(-) diff --git a/sqlmesh/lsp/completions.py b/sqlmesh/lsp/completions.py index 7e3781a550..ad7e0c2a81 100644 --- a/sqlmesh/lsp/completions.py +++ b/sqlmesh/lsp/completions.py @@ -4,6 +4,8 @@ from sqlmesh import macro import typing as t from sqlmesh.lsp.context import AuditTarget, LSPContext, ModelTarget +from sqlmesh.lsp.custom import ModelCompletion +from sqlmesh.lsp.description import generate_markdown_description from sqlmesh.lsp.uri import URI @@ -21,7 +23,9 @@ def get_sql_completions( # Get keywords from file content if provided file_keywords = set() if content: - file_keywords = extract_keywords_from_content(content, get_dialect(context, file_uri)) + file_keywords = extract_keywords_from_content( + content, get_dialect(context, file_uri) + ) # Combine keywords - SQL keywords first, then file keywords all_keywords = list(sql_keywords) + list(file_keywords - sql_keywords) @@ -33,7 +37,9 @@ def get_sql_completions( ) -def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]: +def get_models( + context: t.Optional[LSPContext], file_uri: t.Optional[URI] +) -> t.List[ModelCompletion]: """ Return a list of models for a given file. @@ -41,23 +47,23 @@ def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t. If there is a context, return a list of all models bar the ones the file itself defines. """ if context is None: - return set() + return [] - all_models = set() - # Extract model names from ModelInfo objects - for file_info in context.map.values(): - if isinstance(file_info, ModelTarget): - all_models.update(file_info.names) + current_path = file_uri.to_path() if file_uri is not None else None - # Remove models from the current file - path = file_uri.to_path() if file_uri is not None else None - if path is not None and path in context.map: - file_info = context.map[path] - if isinstance(file_info, ModelTarget): - for model in file_info.names: - all_models.discard(model) + completions: t.List[ModelCompletion] = [] + for model in context.context.models.values(): + if current_path is not None and model._path == current_path: + continue + description = None + try: + description = generate_markdown_description(model) + except Exception: + description = getattr(model, "description", None) - return all_models + completions.append(ModelCompletion(name=model.name, description=description)) + + return completions def get_macros( @@ -79,7 +85,9 @@ def get_macros( return [MacroCompletion(name=name, description=doc) for name, doc in macros.items()] -def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]: +def get_keywords( + context: t.Optional[LSPContext], file_uri: t.Optional[URI] +) -> t.Set[str]: """ Return a list of sql keywords for a given file. If no context is provided, return ANSI SQL keywords. @@ -90,7 +98,11 @@ def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> If both a context and a file_uri are provided, returns the keywords for the dialect of the model that the file belongs to. """ - if file_uri is not None and context is not None and file_uri.to_path() in context.map: + if ( + file_uri is not None + and context is not None + and file_uri.to_path() in context.map + ): file_info = context.map[file_uri.to_path()] # Handle ModelInfo objects @@ -133,11 +145,17 @@ def get_keywords_from_tokenizer(dialect: t.Optional[str] = None) -> t.Set[str]: return expanded_keywords -def get_dialect(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Optional[str]: +def get_dialect( + context: t.Optional[LSPContext], file_uri: t.Optional[URI] +) -> t.Optional[str]: """ Get the dialect for a given file. """ - if file_uri is not None and context is not None and file_uri.to_path() in context.map: + if ( + file_uri is not None + and context is not None + and file_uri.to_path() in context.map + ): file_info = context.map[file_uri.to_path()] # Handle ModelInfo objects @@ -158,7 +176,9 @@ def get_dialect(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t return None -def extract_keywords_from_content(content: str, dialect: t.Optional[str] = None) -> t.Set[str]: +def extract_keywords_from_content( + content: str, dialect: t.Optional[str] = None +) -> t.Set[str]: """ Extract identifiers from SQL content using the tokenizer. diff --git a/sqlmesh/lsp/custom.py b/sqlmesh/lsp/custom.py index 5c8123b7a0..5d611729d9 100644 --- a/sqlmesh/lsp/custom.py +++ b/sqlmesh/lsp/custom.py @@ -30,12 +30,19 @@ class MacroCompletion(PydanticModel): description: t.Optional[str] = None +class ModelCompletion(PydanticModel): + """Information about a model for autocompletion.""" + + name: str + description: t.Optional[str] = None + + class AllModelsResponse(CustomMethodResponseBaseClass): """ Response to get all the models that are in the current project. """ - models: t.List[str] + models: t.List[ModelCompletion] keywords: t.List[str] macros: t.List[MacroCompletion] diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index b028ea7766..a1a3d212f1 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -115,7 +115,9 @@ def _create_lsp_context(self, paths: t.List[Path]) -> t.Optional[LSPContext]: return None # All the custom LSP methods are registered here and prefixed with _custom - def _custom_all_models(self, ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse: + def _custom_all_models( + self, ls: LanguageServer, params: AllModelsRequest + ) -> AllModelsResponse: uri = URI(params.textDocument.uri) # Get the document content content = None @@ -147,7 +149,9 @@ def _custom_all_models_for_render( self._ensure_context_in_folder(current_path) if self.lsp_context is None: raise RuntimeError("No context found") - return AllModelsForRenderResponse(models=self.lsp_context.list_of_models_for_rendering()) + return AllModelsForRenderResponse( + models=self.lsp_context.list_of_models_for_rendering() + ) def _custom_format_project( self, ls: LanguageServer, params: FormatProjectRequest @@ -169,7 +173,9 @@ def _custom_format_project( def _custom_api( self, ls: LanguageServer, request: ApiRequest - ) -> t.Union[ApiResponseGetModels, ApiResponseGetColumnLineage, ApiResponseGetLineage]: + ) -> t.Union[ + ApiResponseGetModels, ApiResponseGetColumnLineage, ApiResponseGetLineage + ]: ls.log_trace(f"API request: {request}") if self.lsp_context is None: current_path = Path.cwd() @@ -191,7 +197,9 @@ def _custom_api( # /api/lineage/{model} model_name = urllib.parse.unquote(path_parts[2]) lineage = model_lineage(model_name, self.lsp_context.context) - non_set_lineage = {k: v for k, v in lineage.items() if v is not None} + non_set_lineage = { + k: v for k, v in lineage.items() if v is not None + } return ApiResponseGetLineage(data=non_set_lineage) if len(path_parts) == 4: @@ -200,7 +208,9 @@ def _custom_api( column = urllib.parse.unquote(path_parts[3]) models_only = False if hasattr(request, "params"): - models_only = bool(getattr(request.params, "models_only", False)) + models_only = bool( + getattr(request.params, "models_only", False) + ) column_lineage_response = column_lineage( model_name, column, models_only, self.lsp_context.context ) @@ -226,7 +236,9 @@ def _register_features(self) -> None: for name, method in self._supported_custom_methods.items(): def create_function_call(method_func: t.Callable) -> t.Callable: - def function_call(ls: LanguageServer, params: t.Any) -> t.Dict[str, t.Any]: + def function_call( + ls: LanguageServer, params: t.Any + ) -> t.Dict[str, t.Any]: try: response = method_func(ls, params) except Exception as e: @@ -243,7 +255,9 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None: try: # Check if the client supports pull diagnostics if params.capabilities and params.capabilities.text_document: - diagnostics = getattr(params.capabilities.text_document, "diagnostic", None) + diagnostics = getattr( + params.capabilities.text_document, "diagnostic", None + ) if diagnostics: self.client_supports_pull_diagnostics = True ls.log_trace("Client supports pull diagnostics") @@ -256,7 +270,8 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None: if params.workspace_folders: # Store all workspace folders for later use self.workspace_folders = [ - Path(self._uri_to_path(folder.uri)) for folder in params.workspace_folders + Path(self._uri_to_path(folder.uri)) + for folder in params.workspace_folders ] # Try to find a SQLMesh config file in any workspace folder (only at the root level) @@ -274,7 +289,9 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None: ) @self.server.feature(types.TEXT_DOCUMENT_DID_OPEN) - def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None: + def did_open( + ls: LanguageServer, params: types.DidOpenTextDocumentParams + ) -> None: uri = URI(params.text_document.uri) context = self._context_get_or_load(uri) models = context.map[uri.to_path()] @@ -292,7 +309,9 @@ def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> Non ) @self.server.feature(types.TEXT_DOCUMENT_DID_CHANGE) - def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None: + def did_change( + ls: LanguageServer, params: types.DidChangeTextDocumentParams + ) -> None: uri = URI(params.text_document.uri) context = self._context_get_or_load(uri) models = context.map[uri.to_path()] @@ -310,7 +329,9 @@ def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> ) @self.server.feature(types.TEXT_DOCUMENT_DID_SAVE) - def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None: + def did_save( + ls: LanguageServer, params: types.DidSaveTextDocumentParams + ) -> None: uri = URI(params.text_document.uri) # Reload the entire context and create a new LSPContext @@ -344,7 +365,9 @@ def formatting( document = ls.workspace.get_text_document(params.text_document.uri) before = document.source if self.lsp_context is None: - raise RuntimeError(f"No context found for document: {document.path}") + raise RuntimeError( + f"No context found for document: {document.path}" + ) target = next( ( @@ -371,7 +394,9 @@ def formatting( start=types.Position(line=0, character=0), end=types.Position( line=len(document.lines), - character=len(document.lines[-1]) if document.lines else 0, + character=len(document.lines[-1]) + if document.lines + else 0, ), ), new_text=after, @@ -382,20 +407,27 @@ def formatting( return [] @self.server.feature(types.TEXT_DOCUMENT_HOVER) - def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hover]: + def hover( + ls: LanguageServer, params: types.HoverParams + ) -> t.Optional[types.Hover]: """Provide hover information for an object.""" try: uri = URI(params.text_document.uri) self._ensure_context_for_document(uri) document = ls.workspace.get_text_document(params.text_document.uri) if self.lsp_context is None: - raise RuntimeError(f"No context found for document: {document.path}") + raise RuntimeError( + f"No context found for document: {document.path}" + ) references = get_references(self.lsp_context, uri, params.position) if not references: return None reference = references[0] - if isinstance(reference, LSPCteReference) or not reference.markdown_description: + if ( + isinstance(reference, LSPCteReference) + or not reference.markdown_description + ): return None return types.Hover( contents=types.MarkupContent( @@ -440,7 +472,9 @@ def goto_definition( self._ensure_context_for_document(uri) document = ls.workspace.get_text_document(params.text_document.uri) if self.lsp_context is None: - raise RuntimeError(f"No context found for document: {document.path}") + raise RuntimeError( + f"No context found for document: {document.path}" + ) references = get_references(self.lsp_context, uri, params.position) location_links = [] @@ -484,7 +518,9 @@ def goto_definition( ) return location_links except Exception as e: - ls.show_message(f"Error getting references: {e}", types.MessageType.Error) + ls.show_message( + f"Error getting references: {e}", types.MessageType.Error + ) return [] @self.server.feature(types.TEXT_DOCUMENT_REFERENCES) @@ -497,16 +533,25 @@ def find_references( self._ensure_context_for_document(uri) document = ls.workspace.get_text_document(params.text_document.uri) if self.lsp_context is None: - raise RuntimeError(f"No context found for document: {document.path}") + raise RuntimeError( + f"No context found for document: {document.path}" + ) - all_references = get_all_references(self.lsp_context, uri, params.position) + all_references = get_all_references( + self.lsp_context, uri, params.position + ) # Convert references to Location objects - locations = [types.Location(uri=ref.uri, range=ref.range) for ref in all_references] + locations = [ + types.Location(uri=ref.uri, range=ref.range) + for ref in all_references + ] return locations if locations else None except Exception as e: - ls.show_message(f"Error getting locations: {e}", types.MessageType.Error) + ls.show_message( + f"Error getting locations: {e}", types.MessageType.Error + ) return None @self.server.feature(types.TEXT_DOCUMENT_DIAGNOSTIC) @@ -519,7 +564,10 @@ def diagnostic( diagnostics, result_id = self._get_diagnostics_for_uri(uri) # Check if client provided a previous result ID - if hasattr(params, "previous_result_id") and params.previous_result_id == result_id: + if ( + hasattr(params, "previous_result_id") + and params.previous_result_id == result_id + ): # Return unchanged report if diagnostics haven't changed return types.RelatedUnchangedDocumentDiagnosticReport( kind=types.DocumentDiagnosticReportKind.Unchanged, @@ -568,7 +616,10 @@ def workspace_diagnostic( # Check if we have a previous result ID for this file previous_result_id = None - if hasattr(params, "previous_result_ids") and params.previous_result_ids: + if ( + hasattr(params, "previous_result_ids") + and params.previous_result_ids + ): for prev in params.previous_result_ids: if prev.uri == uri.value: previous_result_id = prev.value @@ -604,7 +655,9 @@ def workspace_diagnostic( @self.server.feature( types.TEXT_DOCUMENT_COMPLETION, - types.CompletionOptions(trigger_characters=["@"]), # advertise "@" for macros + types.CompletionOptions( + trigger_characters=["@"] + ), # advertise "@" for macros ) def completion( ls: LanguageServer, params: types.CompletionParams @@ -630,9 +683,10 @@ def completion( for model in completion_response.models: completion_items.append( types.CompletionItem( - label=model, + label=model.name, kind=types.CompletionItemKind.Reference, detail="SQLMesh Model", + documentation=model.description, ) ) # Add macro completions @@ -675,7 +729,9 @@ def completion( get_sql_completions(None, URI(params.text_document.uri)) return None - def _get_diagnostics_for_uri(self, uri: URI) -> t.Tuple[t.List[types.Diagnostic], int]: + def _get_diagnostics_for_uri( + self, uri: URI + ) -> t.Tuple[t.List[types.Diagnostic], int]: """Get diagnostics for a specific URI, returning (diagnostics, result_id). Since we no longer track version numbers, we always return 0 as the result_id. @@ -810,7 +866,9 @@ def _diagnostics_to_lsp_diagnostics( """ lsp_diagnostics = {} for diagnostic in diagnostics: - lsp_diagnostic = SQLMeshLanguageServer._diagnostic_to_lsp_diagnostic(diagnostic) + lsp_diagnostic = SQLMeshLanguageServer._diagnostic_to_lsp_diagnostic( + diagnostic + ) if lsp_diagnostic is not None: # Create a unique key combining message and range diagnostic_key = ( diff --git a/vscode/extension/src/completion/completion.ts b/vscode/extension/src/completion/completion.ts index 8e8a101c50..004817f583 100644 --- a/vscode/extension/src/completion/completion.ts +++ b/vscode/extension/src/completion/completion.ts @@ -19,10 +19,17 @@ export const completionProvider = ( if (isErr(result)) { return [] } - const modelCompletions = result.value.models.map( - model => - new vscode.CompletionItem(model, vscode.CompletionItemKind.Reference), - ) + const modelCompletions = result.value.models.map(model => { + const item = new vscode.CompletionItem( + model.name, + vscode.CompletionItemKind.Reference, + ) + item.detail = 'SQLMesh Model' + if (model.description) { + item.documentation = new vscode.MarkdownString(model.description) + } + return item + }) const keywordCompletions = result.value.keywords.map( keyword => new vscode.CompletionItem(keyword, vscode.CompletionItemKind.Keyword), diff --git a/vscode/extension/src/lsp/custom.ts b/vscode/extension/src/lsp/custom.ts index be11419b79..71f0c2eae3 100644 --- a/vscode/extension/src/lsp/custom.ts +++ b/vscode/extension/src/lsp/custom.ts @@ -25,6 +25,16 @@ export interface RenderModelEntry { rendered_query: string } +export interface ModelCompletion { + name: string + description: string | null | undefined +} + +export interface MacroCompletion { + name: string + description: string | null | undefined +} + // @eslint-disable-next-line @typescript-eslint/consistent-type-definition export type CustomLSPMethods = | AllModelsMethod @@ -41,8 +51,9 @@ interface AllModelsRequest { } interface AllModelsResponse { - models: string[] + models: ModelCompletion[] keywords: string[] + macros: MacroCompletion[] } export interface AbstractAPICallRequest { From 695d1eee6c55b5e9fb4535d9b3a5f93281e1eb92 Mon Sep 17 00:00:00 2001 From: Ben <9087625+benfdking@users.noreply.github.com> Date: Wed, 11 Jun 2025 13:12:43 +0100 Subject: [PATCH 2/3] feat(lsp): include model completion metadata --- sqlmesh/lsp/completions.py | 11 ++++++++--- sqlmesh/lsp/custom.py | 8 ++++---- sqlmesh/lsp/main.py | 2 +- tests/lsp/test_completions.py | 14 ++++++++++++++ vscode/extension/src/completion/completion.ts | 15 ++++----------- vscode/extension/src/lsp/custom.ts | 13 +------------ 6 files changed, 32 insertions(+), 31 deletions(-) diff --git a/sqlmesh/lsp/completions.py b/sqlmesh/lsp/completions.py index ad7e0c2a81..70cc3bcfe8 100644 --- a/sqlmesh/lsp/completions.py +++ b/sqlmesh/lsp/completions.py @@ -1,10 +1,13 @@ from functools import lru_cache from sqlglot import Dialect, Tokenizer -from sqlmesh.lsp.custom import AllModelsResponse, MacroCompletion +from sqlmesh.lsp.custom import ( + AllModelsResponse, + MacroCompletion, + ModelCompletion, +) from sqlmesh import macro import typing as t from sqlmesh.lsp.context import AuditTarget, LSPContext, ModelTarget -from sqlmesh.lsp.custom import ModelCompletion from sqlmesh.lsp.description import generate_markdown_description from sqlmesh.lsp.uri import URI @@ -30,8 +33,10 @@ def get_sql_completions( # Combine keywords - SQL keywords first, then file keywords all_keywords = list(sql_keywords) + list(file_keywords - sql_keywords) + models = list(get_models(context, file_uri)) return AllModelsResponse( - models=list(get_models(context, file_uri)), + models=[m.name for m in models], + model_completions=models, keywords=all_keywords, macros=list(get_macros(context, file_uri)), ) diff --git a/sqlmesh/lsp/custom.py b/sqlmesh/lsp/custom.py index 5d611729d9..618b4a44bc 100644 --- a/sqlmesh/lsp/custom.py +++ b/sqlmesh/lsp/custom.py @@ -38,11 +38,11 @@ class ModelCompletion(PydanticModel): class AllModelsResponse(CustomMethodResponseBaseClass): - """ - Response to get all the models that are in the current project. - """ + """Response to get all models that are in the current project.""" - models: t.List[ModelCompletion] + #: Deprecated: use ``model_completions`` instead + models: t.List[str] + model_completions: t.List[ModelCompletion] keywords: t.List[str] macros: t.List[MacroCompletion] diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index a1a3d212f1..b715a81b9c 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -680,7 +680,7 @@ def completion( completion_items = [] # Add model completions - for model in completion_response.models: + for model in completion_response.model_completions: completion_items.append( types.CompletionItem( label=model.name, diff --git a/tests/lsp/test_completions.py b/tests/lsp/test_completions.py index 7e193d77d6..e0772c1a96 100644 --- a/tests/lsp/test_completions.py +++ b/tests/lsp/test_completions.py @@ -41,6 +41,20 @@ def test_get_macros(): assert add_one_macro.description +def test_model_completions_include_descriptions(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + completions = LSPContext.get_completions(lsp_context, None) + + model_entry = next( + (m for m in completions.model_completions if m.name == "sushi.customers"), + None, + ) + assert model_entry is not None + assert model_entry.description + + def test_get_sql_completions_with_context_no_file_uri(): context = Context(paths=["examples/sushi"]) lsp_context = LSPContext(context) diff --git a/vscode/extension/src/completion/completion.ts b/vscode/extension/src/completion/completion.ts index 004817f583..8e8a101c50 100644 --- a/vscode/extension/src/completion/completion.ts +++ b/vscode/extension/src/completion/completion.ts @@ -19,17 +19,10 @@ export const completionProvider = ( if (isErr(result)) { return [] } - const modelCompletions = result.value.models.map(model => { - const item = new vscode.CompletionItem( - model.name, - vscode.CompletionItemKind.Reference, - ) - item.detail = 'SQLMesh Model' - if (model.description) { - item.documentation = new vscode.MarkdownString(model.description) - } - return item - }) + const modelCompletions = result.value.models.map( + model => + new vscode.CompletionItem(model, vscode.CompletionItemKind.Reference), + ) const keywordCompletions = result.value.keywords.map( keyword => new vscode.CompletionItem(keyword, vscode.CompletionItemKind.Keyword), diff --git a/vscode/extension/src/lsp/custom.ts b/vscode/extension/src/lsp/custom.ts index 71f0c2eae3..be11419b79 100644 --- a/vscode/extension/src/lsp/custom.ts +++ b/vscode/extension/src/lsp/custom.ts @@ -25,16 +25,6 @@ export interface RenderModelEntry { rendered_query: string } -export interface ModelCompletion { - name: string - description: string | null | undefined -} - -export interface MacroCompletion { - name: string - description: string | null | undefined -} - // @eslint-disable-next-line @typescript-eslint/consistent-type-definition export type CustomLSPMethods = | AllModelsMethod @@ -51,9 +41,8 @@ interface AllModelsRequest { } interface AllModelsResponse { - models: ModelCompletion[] + models: string[] keywords: string[] - macros: MacroCompletion[] } export interface AbstractAPICallRequest { From 264e398df06c1577e40a80d451a8dcaab088c2a0 Mon Sep 17 00:00:00 2001 From: Ben King <9087625+benfdking@users.noreply.github.com> Date: Wed, 11 Jun 2025 13:45:39 +0100 Subject: [PATCH 3/3] running style --- examples/sushi/models/latest_order.sql | 1 - sqlmesh/lsp/completions.py | 28 ++---- sqlmesh/lsp/main.py | 118 +++++++------------------ 3 files changed, 39 insertions(+), 108 deletions(-) diff --git a/examples/sushi/models/latest_order.sql b/examples/sushi/models/latest_order.sql index 31293f19f9..4523537505 100644 --- a/examples/sushi/models/latest_order.sql +++ b/examples/sushi/models/latest_order.sql @@ -12,4 +12,3 @@ MODEL ( SELECT id, customer_id, start_ts, end_ts, event_date FROM sushi.orders ORDER BY event_date DESC LIMIT 1 - diff --git a/sqlmesh/lsp/completions.py b/sqlmesh/lsp/completions.py index 70cc3bcfe8..0026260481 100644 --- a/sqlmesh/lsp/completions.py +++ b/sqlmesh/lsp/completions.py @@ -26,9 +26,7 @@ def get_sql_completions( # Get keywords from file content if provided file_keywords = set() if content: - file_keywords = extract_keywords_from_content( - content, get_dialect(context, file_uri) - ) + file_keywords = extract_keywords_from_content(content, get_dialect(context, file_uri)) # Combine keywords - SQL keywords first, then file keywords all_keywords = list(sql_keywords) + list(file_keywords - sql_keywords) @@ -90,9 +88,7 @@ def get_macros( return [MacroCompletion(name=name, description=doc) for name, doc in macros.items()] -def get_keywords( - context: t.Optional[LSPContext], file_uri: t.Optional[URI] -) -> t.Set[str]: +def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]: """ Return a list of sql keywords for a given file. If no context is provided, return ANSI SQL keywords. @@ -103,11 +99,7 @@ def get_keywords( If both a context and a file_uri are provided, returns the keywords for the dialect of the model that the file belongs to. """ - if ( - file_uri is not None - and context is not None - and file_uri.to_path() in context.map - ): + if file_uri is not None and context is not None and file_uri.to_path() in context.map: file_info = context.map[file_uri.to_path()] # Handle ModelInfo objects @@ -150,17 +142,11 @@ def get_keywords_from_tokenizer(dialect: t.Optional[str] = None) -> t.Set[str]: return expanded_keywords -def get_dialect( - context: t.Optional[LSPContext], file_uri: t.Optional[URI] -) -> t.Optional[str]: +def get_dialect(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Optional[str]: """ Get the dialect for a given file. """ - if ( - file_uri is not None - and context is not None - and file_uri.to_path() in context.map - ): + if file_uri is not None and context is not None and file_uri.to_path() in context.map: file_info = context.map[file_uri.to_path()] # Handle ModelInfo objects @@ -181,9 +167,7 @@ def get_dialect( return None -def extract_keywords_from_content( - content: str, dialect: t.Optional[str] = None -) -> t.Set[str]: +def extract_keywords_from_content(content: str, dialect: t.Optional[str] = None) -> t.Set[str]: """ Extract identifiers from SQL content using the tokenizer. diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index b715a81b9c..75ac9b70a2 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -115,9 +115,7 @@ def _create_lsp_context(self, paths: t.List[Path]) -> t.Optional[LSPContext]: return None # All the custom LSP methods are registered here and prefixed with _custom - def _custom_all_models( - self, ls: LanguageServer, params: AllModelsRequest - ) -> AllModelsResponse: + def _custom_all_models(self, ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse: uri = URI(params.textDocument.uri) # Get the document content content = None @@ -149,9 +147,7 @@ def _custom_all_models_for_render( self._ensure_context_in_folder(current_path) if self.lsp_context is None: raise RuntimeError("No context found") - return AllModelsForRenderResponse( - models=self.lsp_context.list_of_models_for_rendering() - ) + return AllModelsForRenderResponse(models=self.lsp_context.list_of_models_for_rendering()) def _custom_format_project( self, ls: LanguageServer, params: FormatProjectRequest @@ -173,9 +169,7 @@ def _custom_format_project( def _custom_api( self, ls: LanguageServer, request: ApiRequest - ) -> t.Union[ - ApiResponseGetModels, ApiResponseGetColumnLineage, ApiResponseGetLineage - ]: + ) -> t.Union[ApiResponseGetModels, ApiResponseGetColumnLineage, ApiResponseGetLineage]: ls.log_trace(f"API request: {request}") if self.lsp_context is None: current_path = Path.cwd() @@ -197,9 +191,7 @@ def _custom_api( # /api/lineage/{model} model_name = urllib.parse.unquote(path_parts[2]) lineage = model_lineage(model_name, self.lsp_context.context) - non_set_lineage = { - k: v for k, v in lineage.items() if v is not None - } + non_set_lineage = {k: v for k, v in lineage.items() if v is not None} return ApiResponseGetLineage(data=non_set_lineage) if len(path_parts) == 4: @@ -208,9 +200,7 @@ def _custom_api( column = urllib.parse.unquote(path_parts[3]) models_only = False if hasattr(request, "params"): - models_only = bool( - getattr(request.params, "models_only", False) - ) + models_only = bool(getattr(request.params, "models_only", False)) column_lineage_response = column_lineage( model_name, column, models_only, self.lsp_context.context ) @@ -236,9 +226,7 @@ def _register_features(self) -> None: for name, method in self._supported_custom_methods.items(): def create_function_call(method_func: t.Callable) -> t.Callable: - def function_call( - ls: LanguageServer, params: t.Any - ) -> t.Dict[str, t.Any]: + def function_call(ls: LanguageServer, params: t.Any) -> t.Dict[str, t.Any]: try: response = method_func(ls, params) except Exception as e: @@ -255,9 +243,7 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None: try: # Check if the client supports pull diagnostics if params.capabilities and params.capabilities.text_document: - diagnostics = getattr( - params.capabilities.text_document, "diagnostic", None - ) + diagnostics = getattr(params.capabilities.text_document, "diagnostic", None) if diagnostics: self.client_supports_pull_diagnostics = True ls.log_trace("Client supports pull diagnostics") @@ -270,8 +256,7 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None: if params.workspace_folders: # Store all workspace folders for later use self.workspace_folders = [ - Path(self._uri_to_path(folder.uri)) - for folder in params.workspace_folders + Path(self._uri_to_path(folder.uri)) for folder in params.workspace_folders ] # Try to find a SQLMesh config file in any workspace folder (only at the root level) @@ -289,9 +274,7 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None: ) @self.server.feature(types.TEXT_DOCUMENT_DID_OPEN) - def did_open( - ls: LanguageServer, params: types.DidOpenTextDocumentParams - ) -> None: + def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None: uri = URI(params.text_document.uri) context = self._context_get_or_load(uri) models = context.map[uri.to_path()] @@ -309,9 +292,7 @@ def did_open( ) @self.server.feature(types.TEXT_DOCUMENT_DID_CHANGE) - def did_change( - ls: LanguageServer, params: types.DidChangeTextDocumentParams - ) -> None: + def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None: uri = URI(params.text_document.uri) context = self._context_get_or_load(uri) models = context.map[uri.to_path()] @@ -329,9 +310,7 @@ def did_change( ) @self.server.feature(types.TEXT_DOCUMENT_DID_SAVE) - def did_save( - ls: LanguageServer, params: types.DidSaveTextDocumentParams - ) -> None: + def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None: uri = URI(params.text_document.uri) # Reload the entire context and create a new LSPContext @@ -365,9 +344,7 @@ def formatting( document = ls.workspace.get_text_document(params.text_document.uri) before = document.source if self.lsp_context is None: - raise RuntimeError( - f"No context found for document: {document.path}" - ) + raise RuntimeError(f"No context found for document: {document.path}") target = next( ( @@ -394,9 +371,7 @@ def formatting( start=types.Position(line=0, character=0), end=types.Position( line=len(document.lines), - character=len(document.lines[-1]) - if document.lines - else 0, + character=len(document.lines[-1]) if document.lines else 0, ), ), new_text=after, @@ -407,27 +382,20 @@ def formatting( return [] @self.server.feature(types.TEXT_DOCUMENT_HOVER) - def hover( - ls: LanguageServer, params: types.HoverParams - ) -> t.Optional[types.Hover]: + def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hover]: """Provide hover information for an object.""" try: uri = URI(params.text_document.uri) self._ensure_context_for_document(uri) document = ls.workspace.get_text_document(params.text_document.uri) if self.lsp_context is None: - raise RuntimeError( - f"No context found for document: {document.path}" - ) + raise RuntimeError(f"No context found for document: {document.path}") references = get_references(self.lsp_context, uri, params.position) if not references: return None reference = references[0] - if ( - isinstance(reference, LSPCteReference) - or not reference.markdown_description - ): + if isinstance(reference, LSPCteReference) or not reference.markdown_description: return None return types.Hover( contents=types.MarkupContent( @@ -472,9 +440,7 @@ def goto_definition( self._ensure_context_for_document(uri) document = ls.workspace.get_text_document(params.text_document.uri) if self.lsp_context is None: - raise RuntimeError( - f"No context found for document: {document.path}" - ) + raise RuntimeError(f"No context found for document: {document.path}") references = get_references(self.lsp_context, uri, params.position) location_links = [] @@ -518,9 +484,7 @@ def goto_definition( ) return location_links except Exception as e: - ls.show_message( - f"Error getting references: {e}", types.MessageType.Error - ) + ls.show_message(f"Error getting references: {e}", types.MessageType.Error) return [] @self.server.feature(types.TEXT_DOCUMENT_REFERENCES) @@ -533,25 +497,16 @@ def find_references( self._ensure_context_for_document(uri) document = ls.workspace.get_text_document(params.text_document.uri) if self.lsp_context is None: - raise RuntimeError( - f"No context found for document: {document.path}" - ) + raise RuntimeError(f"No context found for document: {document.path}") - all_references = get_all_references( - self.lsp_context, uri, params.position - ) + all_references = get_all_references(self.lsp_context, uri, params.position) # Convert references to Location objects - locations = [ - types.Location(uri=ref.uri, range=ref.range) - for ref in all_references - ] + locations = [types.Location(uri=ref.uri, range=ref.range) for ref in all_references] return locations if locations else None except Exception as e: - ls.show_message( - f"Error getting locations: {e}", types.MessageType.Error - ) + ls.show_message(f"Error getting locations: {e}", types.MessageType.Error) return None @self.server.feature(types.TEXT_DOCUMENT_DIAGNOSTIC) @@ -564,10 +519,7 @@ def diagnostic( diagnostics, result_id = self._get_diagnostics_for_uri(uri) # Check if client provided a previous result ID - if ( - hasattr(params, "previous_result_id") - and params.previous_result_id == result_id - ): + if hasattr(params, "previous_result_id") and params.previous_result_id == result_id: # Return unchanged report if diagnostics haven't changed return types.RelatedUnchangedDocumentDiagnosticReport( kind=types.DocumentDiagnosticReportKind.Unchanged, @@ -616,10 +568,7 @@ def workspace_diagnostic( # Check if we have a previous result ID for this file previous_result_id = None - if ( - hasattr(params, "previous_result_ids") - and params.previous_result_ids - ): + if hasattr(params, "previous_result_ids") and params.previous_result_ids: for prev in params.previous_result_ids: if prev.uri == uri.value: previous_result_id = prev.value @@ -655,9 +604,7 @@ def workspace_diagnostic( @self.server.feature( types.TEXT_DOCUMENT_COMPLETION, - types.CompletionOptions( - trigger_characters=["@"] - ), # advertise "@" for macros + types.CompletionOptions(trigger_characters=["@"]), # advertise "@" for macros ) def completion( ls: LanguageServer, params: types.CompletionParams @@ -686,7 +633,12 @@ def completion( label=model.name, kind=types.CompletionItemKind.Reference, detail="SQLMesh Model", - documentation=model.description, + documentation=types.MarkupContent( + kind=types.MarkupKind.Markdown, + value=model.description or "No description available", + ) + if model.description + else None, ) ) # Add macro completions @@ -729,9 +681,7 @@ def completion( get_sql_completions(None, URI(params.text_document.uri)) return None - def _get_diagnostics_for_uri( - self, uri: URI - ) -> t.Tuple[t.List[types.Diagnostic], int]: + def _get_diagnostics_for_uri(self, uri: URI) -> t.Tuple[t.List[types.Diagnostic], int]: """Get diagnostics for a specific URI, returning (diagnostics, result_id). Since we no longer track version numbers, we always return 0 as the result_id. @@ -866,9 +816,7 @@ def _diagnostics_to_lsp_diagnostics( """ lsp_diagnostics = {} for diagnostic in diagnostics: - lsp_diagnostic = SQLMeshLanguageServer._diagnostic_to_lsp_diagnostic( - diagnostic - ) + lsp_diagnostic = SQLMeshLanguageServer._diagnostic_to_lsp_diagnostic(diagnostic) if lsp_diagnostic is not None: # Create a unique key combining message and range diagnostic_key = (