Skip to content

Commit 3b9c327

Browse files
committed
Handle nested classes and nested methods
1 parent 3dffa58 commit 3b9c327

2 files changed

Lines changed: 222 additions & 96 deletions

File tree

lua/dap-python.lua

Lines changed: 146 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,9 @@ end
3636
--- Built-in are test runners for unittest, pytest and django.
3737
--- The key is the test runner name, the value a function to generate the
3838
--- module name to run and its arguments. See |dap-python.TestRunner|
39-
---@type table<string,TestRunner>
39+
---@type table<string, TestRunner>
4040
M.test_runners = {}
4141

42-
local function prune_nil(items)
43-
return vim.tbl_filter(function(x) return x end, items)
44-
end
4542

4643
local is_windows = function()
4744
return vim.fn.has("win32") == 1
@@ -146,29 +143,51 @@ local function get_module_path()
146143
end
147144
end
148145

146+
147+
---@return string[]
148+
local function flatten(...)
149+
local argc = select("#", ...)
150+
local result = {}
151+
for i = 1, argc do
152+
local arg = select(i, ...)
153+
if type(arg) == "table" then
154+
vim.list_extend(result, arg)
155+
else
156+
table.insert(result, arg)
157+
end
158+
end
159+
return result
160+
end
161+
162+
149163
---@private
150-
function M.test_runners.unittest(classname, methodname)
151-
local path = get_module_path()
152-
local test_path = table.concat(prune_nil({path, classname, methodname}), '.')
164+
---@param classnames string[]|string
165+
---@param methodname string?
166+
function M.test_runners.unittest(classnames, methodname)
167+
local test_path = table.concat(flatten(get_module_path(), classnames, methodname), '.')
153168
local args = {'-v', test_path}
154169
return 'unittest', args
155170
end
156171

157172

158173
---@private
159-
function M.test_runners.pytest(classname, methodname)
174+
---@param classnames string[]|string
175+
---@param methodname string?
176+
function M.test_runners.pytest(classnames, methodname)
160177
local path = vim.fn.expand('%:p')
161-
local test_path = table.concat(prune_nil({path, classname, methodname}), '::')
178+
local test_path = table.concat(flatten({path, classnames, methodname}), '::')
162179
-- -s "allow output to stdout of test"
163180
local args = {'-s', test_path}
164181
return 'pytest', args
165182
end
166183

167184

168185
---@private
169-
function M.test_runners.django(classname, methodname)
186+
---@param classnames string[]|string
187+
---@param methodname string?
188+
function M.test_runners.django(classnames, methodname)
170189
local path = get_module_path()
171-
local test_path = table.concat(prune_nil({path, classname, methodname}), '.')
190+
local test_path = table.concat(flatten({path, classnames, methodname}), '.')
172191
local args = {'test', test_path}
173192
return 'django', args
174193
end
@@ -258,50 +277,10 @@ function M.setup(adapter_python_path, opts)
258277
end
259278

260279

261-
local function get_nodes(query_text, predicate)
262-
local end_row = api.nvim_win_get_cursor(0)[1]
263-
local ft = api.nvim_buf_get_option(0, 'filetype')
264-
assert(ft == 'python', 'test_method of dap-python only works for python files, not ' .. ft)
265-
local query = (vim.treesitter.query.parse
266-
and vim.treesitter.query.parse(ft, query_text)
267-
or vim.treesitter.parse_query(ft, query_text)
268-
)
269-
assert(query, 'Could not parse treesitter query. Cannot find test')
270-
local parser = vim.treesitter.get_parser(0)
271-
local root = (parser:parse()[1]):root()
272-
local nodes = {}
273-
for _, node in query:iter_captures(root, 0, 0, end_row) do
274-
if predicate(node) then
275-
table.insert(nodes, node)
276-
end
277-
end
278-
return nodes
279-
end
280-
281-
282-
local function get_function_nodes()
283-
local query_text = [[
284-
(function_definition
285-
name: (identifier) @name) @definition.function
286-
]]
287-
return get_nodes(query_text, function(node)
288-
return node:type() == 'identifier'
289-
end)
290-
end
291-
292-
293-
local function get_class_nodes()
294-
local query_text = [[
295-
(class_definition
296-
name: (identifier) @name) @definition.class
297-
]]
298-
return get_nodes(query_text, function(node)
299-
return node:type() == 'identifier'
300-
end)
301-
end
302-
303-
304280
local function get_node_text(node)
281+
if vim.treesitter.get_node_text then
282+
return vim.treesitter.get_node_text(node, 0)
283+
end
305284
local row1, col1, row2, col2 = node:range()
306285
if row1 == row2 then
307286
row2 = row2 + 1
@@ -314,24 +293,90 @@ local function get_node_text(node)
314293
end
315294

316295

317-
local function get_parent_classname(node)
318-
local parent = node:parent()
319-
while parent do
320-
local type = parent:type()
321-
if type == 'class_definition' then
322-
for child in parent:iter_children() do
323-
if child:type() == 'identifier' then
324-
return get_node_text(child)
325-
end
296+
--- Reverse list inline
297+
---@param list any[]
298+
local function reverse(list)
299+
local len = #list
300+
for i = 1, math.floor(len * 0.5) do
301+
local opposite = len - i + 1
302+
list[i], list[opposite] = list[opposite], list[i]
303+
end
304+
end
305+
306+
307+
---@param source string|integer
308+
---@param subject "function"|"class"
309+
---@param end_row integer? defaults to cursor
310+
---@return TSNode[]
311+
function M._get_nodes(source, subject, end_row)
312+
end_row = end_row or api.nvim_win_get_cursor(0)[1]
313+
local query_text = [[
314+
(function_definition
315+
name: (identifier) @function
316+
)
317+
318+
(class_definition
319+
name: (identifier) @class
320+
)
321+
]]
322+
local lang = "python"
323+
local query = (vim.treesitter.query.parse
324+
and vim.treesitter.query.parse(lang, query_text)
325+
or vim.treesitter.parse_query(lang, query_text)
326+
)
327+
local parser = (
328+
type(source) == "number"
329+
and vim.treesitter.get_parser(source, lang)
330+
or vim.treesitter.get_string_parser(source --[[@as string]], lang)
331+
)
332+
local trees = parser:parse()
333+
local root = trees[1]:root()
334+
local nodes = {}
335+
for id, node in query:iter_captures(root, source, 0, end_row) do
336+
local capture = query.captures[id]
337+
if capture == subject then
338+
table.insert(nodes, node)
339+
end
340+
end
341+
if not next(nodes) then
342+
return nodes
343+
end
344+
if subject == "function" then
345+
local result = nodes[#nodes]
346+
local parent = result
347+
while parent ~= nil do
348+
if parent:type() == "function_definition" then
349+
local ident = parent:child(1)
350+
assert(ident:type() == "identifier")
351+
result = ident
326352
end
353+
parent = parent:parent()
327354
end
328-
parent = parent:parent()
355+
return { result }
356+
elseif subject == "class" then
357+
local last = nodes[#nodes]
358+
local parent = last
359+
local results = {}
360+
while parent ~= nil do
361+
if parent:type() == "class_definition" then
362+
local ident = parent:child(1)
363+
assert(ident:type() == "identifier")
364+
table.insert(results, ident)
365+
end
366+
parent = parent:parent()
367+
end
368+
reverse(results)
369+
return results
370+
else
371+
error("Expected subject 'function' or 'class', not: " .. subject)
329372
end
330373
end
331374

332375

376+
---@param classnames string[]
377+
---@param methodname string?
333378
---@param opts DebugOpts
334-
local function trigger_test(classname, methodname, opts)
379+
local function trigger_test(classnames, methodname, opts)
335380
local test_runner = opts.test_runner or (M.test_runner or default_runner)
336381
if type(test_runner) == "function" then
337382
test_runner = test_runner()
@@ -342,9 +387,11 @@ local function trigger_test(classname, methodname, opts)
342387
return
343388
end
344389
assert(type(runner) == "function", "Test runner must be a function")
345-
local module, args = runner(classname, methodname)
390+
-- for BWC with custom runners which expect a string instead of a list of strings
391+
local classes = #classnames == 1 and classnames[1] or classnames
392+
local module, args = runner(classes, methodname)
346393
local config = {
347-
name = table.concat(prune_nil({classname, methodname}), '.'),
394+
name = table.concat(flatten(classnames, methodname), '.'),
348395
type = 'python',
349396
request = 'launch',
350397
module = module,
@@ -355,49 +402,51 @@ local function trigger_test(classname, methodname, opts)
355402
end
356403

357404

358-
local function closest_above_cursor(nodes)
359-
local result
360-
for _, node in pairs(nodes) do
361-
if not result then
362-
result = node
363-
else
364-
local node_row1, _, _, _ = node:range()
365-
local result_row1, _, _, _ = result:range()
366-
if node_row1 > result_row1 then
367-
result = node
368-
end
369-
end
370-
end
371-
return result
372-
end
373-
374-
375405
--- Run test class above cursor
376406
---@param opts? DebugOpts See |dap-python.DebugOpts|
377407
function M.test_class(opts)
378408
opts = vim.tbl_extend('keep', opts or {}, default_test_opts)
379-
local class_node = closest_above_cursor(get_class_nodes())
380-
if not class_node then
381-
print('No suitable test class found')
409+
local candidates = M._get_nodes(0, "class")
410+
if not candidates then
411+
print('No test class found near cursor')
382412
return
383413
end
384-
local class = get_node_text(class_node)
385-
trigger_test(class, nil, opts)
414+
local names = vim.tbl_map(get_node_text, candidates)
415+
trigger_test(names, nil, opts)
416+
end
417+
418+
419+
---@param node TSNode
420+
---@result TSNode[]
421+
local function get_parent_classes(node)
422+
local parent = node:parent()
423+
local result = {}
424+
while parent ~= nil do
425+
if parent:type() == "class_definition" then
426+
local ident = parent:child(1)
427+
assert(ident and ident:type() == "identifier")
428+
table.insert(result, ident)
429+
end
430+
parent = parent:parent()
431+
end
432+
reverse(result)
433+
return result
386434
end
387435

388436

389437
--- Run the test method above cursor
390438
---@param opts? DebugOpts See |dap-python.DebugOpts|
391439
function M.test_method(opts)
392440
opts = vim.tbl_extend('keep', opts or {}, default_test_opts)
393-
local function_node = closest_above_cursor(get_function_nodes())
394-
if not function_node then
395-
print('No suitable test method found')
441+
local functions = M._get_nodes(0, "function")
442+
if not functions then
443+
print('No test method found near cursor')
396444
return
397445
end
398-
local class = get_parent_classname(function_node)
399-
local function_name = get_node_text(function_node)
400-
trigger_test(class, function_name, opts)
446+
local fn = functions[1]
447+
local parent_classes = get_parent_classes(fn)
448+
local classnames = vim.tbl_map(get_node_text, parent_classes)
449+
trigger_test(classnames, get_node_text(fn), opts)
401450
end
402451

403452

@@ -414,6 +463,7 @@ local function remove_indent(lines)
414463
end
415464
end
416465
if offset > 1 then
466+
assert(offset)
417467
return vim.tbl_map(function(x) return string.sub(x, offset) end, lines)
418468
else
419469
return lines
@@ -479,7 +529,7 @@ end
479529
---@field pythonPath string|nil Path to python interpreter. Uses interpreter from `VIRTUAL_ENV` environment variable or `adapter_python_path` by default
480530

481531

482-
---@alias TestRunner fun(classname: string, methodname: string):string, string[]
532+
---@alias TestRunner fun(classname: string|string[], methodname: string?):string, string[]
483533

484534
---@alias DebugpyConsole "internalConsole"|"integratedTerminal"|"externalTerminal"|nil
485535

0 commit comments

Comments
 (0)