3636--- Built-in are test runners for unittest, pytest and django.
3737--- The key is the test runner name, the value a function to generate the
3838--- module name to run and its arguments. See |dap-python.TestRunner|
39- --- @type table<string,TestRunner>
39+ --- @type table<string , TestRunner>
4040M .test_runners = {}
4141
42- local function prune_nil (items )
43- return vim .tbl_filter (function (x ) return x end , items )
44- end
4542
4643local is_windows = function ()
4744 return vim .fn .has (" win32" ) == 1
@@ -146,29 +143,51 @@ local function get_module_path()
146143 end
147144end
148145
146+
147+ --- @return string[]
148+ local function flatten (...)
149+ local argc = select (" #" , ... )
150+ local result = {}
151+ for i = 1 , argc do
152+ local arg = select (i , ... )
153+ if type (arg ) == " table" then
154+ vim .list_extend (result , arg )
155+ else
156+ table.insert (result , arg )
157+ end
158+ end
159+ return result
160+ end
161+
162+
149163--- @private
150- function M .test_runners .unittest (classname , methodname )
151- local path = get_module_path ()
152- local test_path = table.concat (prune_nil ({path , classname , methodname }), ' .' )
164+ --- @param classnames string[] | string
165+ --- @param methodname string ?
166+ function M .test_runners .unittest (classnames , methodname )
167+ local test_path = table.concat (flatten (get_module_path (), classnames , methodname ), ' .' )
153168 local args = {' -v' , test_path }
154169 return ' unittest' , args
155170end
156171
157172
158173--- @private
159- function M .test_runners .pytest (classname , methodname )
174+ --- @param classnames string[] | string
175+ --- @param methodname string ?
176+ function M .test_runners .pytest (classnames , methodname )
160177 local path = vim .fn .expand (' %:p' )
161- local test_path = table.concat (prune_nil ({path , classname , methodname }), ' ::' )
178+ local test_path = table.concat (flatten ({path , classnames , methodname }), ' ::' )
162179 -- -s "allow output to stdout of test"
163180 local args = {' -s' , test_path }
164181 return ' pytest' , args
165182end
166183
167184
168185--- @private
169- function M .test_runners .django (classname , methodname )
186+ --- @param classnames string[] | string
187+ --- @param methodname string ?
188+ function M .test_runners .django (classnames , methodname )
170189 local path = get_module_path ()
171- local test_path = table.concat (prune_nil ({path , classname , methodname }), ' .' )
190+ local test_path = table.concat (flatten ({path , classnames , methodname }), ' .' )
172191 local args = {' test' , test_path }
173192 return ' django' , args
174193end
@@ -258,50 +277,10 @@ function M.setup(adapter_python_path, opts)
258277end
259278
260279
261- local function get_nodes (query_text , predicate )
262- local end_row = api .nvim_win_get_cursor (0 )[1 ]
263- local ft = api .nvim_buf_get_option (0 , ' filetype' )
264- assert (ft == ' python' , ' test_method of dap-python only works for python files, not ' .. ft )
265- local query = (vim .treesitter .query .parse
266- and vim .treesitter .query .parse (ft , query_text )
267- or vim .treesitter .parse_query (ft , query_text )
268- )
269- assert (query , ' Could not parse treesitter query. Cannot find test' )
270- local parser = vim .treesitter .get_parser (0 )
271- local root = (parser :parse ()[1 ]):root ()
272- local nodes = {}
273- for _ , node in query :iter_captures (root , 0 , 0 , end_row ) do
274- if predicate (node ) then
275- table.insert (nodes , node )
276- end
277- end
278- return nodes
279- end
280-
281-
282- local function get_function_nodes ()
283- local query_text = [[
284- (function_definition
285- name: (identifier) @name) @definition.function
286- ]]
287- return get_nodes (query_text , function (node )
288- return node :type () == ' identifier'
289- end )
290- end
291-
292-
293- local function get_class_nodes ()
294- local query_text = [[
295- (class_definition
296- name: (identifier) @name) @definition.class
297- ]]
298- return get_nodes (query_text , function (node )
299- return node :type () == ' identifier'
300- end )
301- end
302-
303-
304280local function get_node_text (node )
281+ if vim .treesitter .get_node_text then
282+ return vim .treesitter .get_node_text (node , 0 )
283+ end
305284 local row1 , col1 , row2 , col2 = node :range ()
306285 if row1 == row2 then
307286 row2 = row2 + 1
@@ -314,24 +293,90 @@ local function get_node_text(node)
314293end
315294
316295
317- local function get_parent_classname (node )
318- local parent = node :parent ()
319- while parent do
320- local type = parent :type ()
321- if type == ' class_definition' then
322- for child in parent :iter_children () do
323- if child :type () == ' identifier' then
324- return get_node_text (child )
325- end
296+ --- Reverse list inline
297+ --- @param list any[]
298+ local function reverse (list )
299+ local len = # list
300+ for i = 1 , math.floor (len * 0.5 ) do
301+ local opposite = len - i + 1
302+ list [i ], list [opposite ] = list [opposite ], list [i ]
303+ end
304+ end
305+
306+
307+ --- @param source string | integer
308+ --- @param subject " function" | " class"
309+ --- @param end_row integer ? defaults to cursor
310+ --- @return TSNode[]
311+ function M ._get_nodes (source , subject , end_row )
312+ end_row = end_row or api .nvim_win_get_cursor (0 )[1 ]
313+ local query_text = [[
314+ (function_definition
315+ name: (identifier) @function
316+ )
317+
318+ (class_definition
319+ name: (identifier) @class
320+ )
321+ ]]
322+ local lang = " python"
323+ local query = (vim .treesitter .query .parse
324+ and vim .treesitter .query .parse (lang , query_text )
325+ or vim .treesitter .parse_query (lang , query_text )
326+ )
327+ local parser = (
328+ type (source ) == " number"
329+ and vim .treesitter .get_parser (source , lang )
330+ or vim .treesitter .get_string_parser (source --[[ @as string]] , lang )
331+ )
332+ local trees = parser :parse ()
333+ local root = trees [1 ]:root ()
334+ local nodes = {}
335+ for id , node in query :iter_captures (root , source , 0 , end_row ) do
336+ local capture = query .captures [id ]
337+ if capture == subject then
338+ table.insert (nodes , node )
339+ end
340+ end
341+ if not next (nodes ) then
342+ return nodes
343+ end
344+ if subject == " function" then
345+ local result = nodes [# nodes ]
346+ local parent = result
347+ while parent ~= nil do
348+ if parent :type () == " function_definition" then
349+ local ident = parent :child (1 )
350+ assert (ident :type () == " identifier" )
351+ result = ident
326352 end
353+ parent = parent :parent ()
327354 end
328- parent = parent :parent ()
355+ return { result }
356+ elseif subject == " class" then
357+ local last = nodes [# nodes ]
358+ local parent = last
359+ local results = {}
360+ while parent ~= nil do
361+ if parent :type () == " class_definition" then
362+ local ident = parent :child (1 )
363+ assert (ident :type () == " identifier" )
364+ table.insert (results , ident )
365+ end
366+ parent = parent :parent ()
367+ end
368+ reverse (results )
369+ return results
370+ else
371+ error (" Expected subject 'function' or 'class', not: " .. subject )
329372 end
330373end
331374
332375
376+ --- @param classnames string[]
377+ --- @param methodname string ?
333378--- @param opts DebugOpts
334- local function trigger_test (classname , methodname , opts )
379+ local function trigger_test (classnames , methodname , opts )
335380 local test_runner = opts .test_runner or (M .test_runner or default_runner )
336381 if type (test_runner ) == " function" then
337382 test_runner = test_runner ()
@@ -342,9 +387,11 @@ local function trigger_test(classname, methodname, opts)
342387 return
343388 end
344389 assert (type (runner ) == " function" , " Test runner must be a function" )
345- local module , args = runner (classname , methodname )
390+ -- for BWC with custom runners which expect a string instead of a list of strings
391+ local classes = # classnames == 1 and classnames [1 ] or classnames
392+ local module , args = runner (classes , methodname )
346393 local config = {
347- name = table.concat (prune_nil ({ classname , methodname } ), ' .' ),
394+ name = table.concat (flatten ( classnames , methodname ), ' .' ),
348395 type = ' python' ,
349396 request = ' launch' ,
350397 module = module ,
@@ -355,49 +402,51 @@ local function trigger_test(classname, methodname, opts)
355402end
356403
357404
358- local function closest_above_cursor (nodes )
359- local result
360- for _ , node in pairs (nodes ) do
361- if not result then
362- result = node
363- else
364- local node_row1 , _ , _ , _ = node :range ()
365- local result_row1 , _ , _ , _ = result :range ()
366- if node_row1 > result_row1 then
367- result = node
368- end
369- end
370- end
371- return result
372- end
373-
374-
375405--- Run test class above cursor
376406--- @param opts ? DebugOpts See | dap-python.DebugOpts |
377407function M .test_class (opts )
378408 opts = vim .tbl_extend (' keep' , opts or {}, default_test_opts )
379- local class_node = closest_above_cursor ( get_class_nodes () )
380- if not class_node then
381- print (' No suitable test class found' )
409+ local candidates = M . _get_nodes ( 0 , " class " )
410+ if not candidates then
411+ print (' No test class found near cursor ' )
382412 return
383413 end
384- local class = get_node_text (class_node )
385- trigger_test (class , nil , opts )
414+ local names = vim .tbl_map (get_node_text , candidates )
415+ trigger_test (names , nil , opts )
416+ end
417+
418+
419+ --- @param node TSNode
420+ --- @result TSNode[]
421+ local function get_parent_classes (node )
422+ local parent = node :parent ()
423+ local result = {}
424+ while parent ~= nil do
425+ if parent :type () == " class_definition" then
426+ local ident = parent :child (1 )
427+ assert (ident and ident :type () == " identifier" )
428+ table.insert (result , ident )
429+ end
430+ parent = parent :parent ()
431+ end
432+ reverse (result )
433+ return result
386434end
387435
388436
389437--- Run the test method above cursor
390438--- @param opts ? DebugOpts See | dap-python.DebugOpts |
391439function M .test_method (opts )
392440 opts = vim .tbl_extend (' keep' , opts or {}, default_test_opts )
393- local function_node = closest_above_cursor ( get_function_nodes () )
394- if not function_node then
395- print (' No suitable test method found' )
441+ local functions = M . _get_nodes ( 0 , " function " )
442+ if not functions then
443+ print (' No test method found near cursor ' )
396444 return
397445 end
398- local class = get_parent_classname (function_node )
399- local function_name = get_node_text (function_node )
400- trigger_test (class , function_name , opts )
446+ local fn = functions [1 ]
447+ local parent_classes = get_parent_classes (fn )
448+ local classnames = vim .tbl_map (get_node_text , parent_classes )
449+ trigger_test (classnames , get_node_text (fn ), opts )
401450end
402451
403452
@@ -414,6 +463,7 @@ local function remove_indent(lines)
414463 end
415464 end
416465 if offset > 1 then
466+ assert (offset )
417467 return vim .tbl_map (function (x ) return string.sub (x , offset ) end , lines )
418468 else
419469 return lines
479529--- @field pythonPath string | nil Path to python interpreter. Uses interpreter from ` VIRTUAL_ENV` environment variable or ` adapter_python_path` by default
480530
481531
482- --- @alias TestRunner fun ( classname : string , methodname : string ): string , string[]
532+ --- @alias TestRunner fun ( classname : string | string[] , methodname : string ? ):string , string[]
483533
484534--- @alias DebugpyConsole " internalConsole" | " integratedTerminal" | " externalTerminal" | nil
485535
0 commit comments