Skip to content

Commit 9296a77

Browse files
tausbn and Copilot committed
yeast: Integrate yeast with shared tree-sitter extractor
extract() gains a yeast_runner parameter. When None, uses tree-sitter native traversal (no behavior change). When Some, runs yeast desugaring and extracts via traverse_yeast. Adds AstNode trait abstracting over tree_sitter::Node and yeast::Node, with minimal changes to existing Visitor methods (Node -> &N in 6 signatures). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent f0a79ff commit 9296a77

7 files changed

Lines changed: 178 additions & 18 deletions

File tree

ql/extractor/src/extractor.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,28 @@ pub fn run(options: Options) -> std::io::Result<()> {
2929
prefix: "ql",
3030
ts_language: tree_sitter_ql::LANGUAGE.into(),
3131
node_types: tree_sitter_ql::NODE_TYPES,
32+
desugar: None,
3233
file_globs: vec!["*.ql".into(), "*.qll".into()],
3334
},
3435
simple::LanguageSpec {
3536
prefix: "dbscheme",
3637
ts_language: tree_sitter_ql_dbscheme::LANGUAGE.into(),
3738
node_types: tree_sitter_ql_dbscheme::NODE_TYPES,
39+
desugar: None,
3840
file_globs: vec!["*.dbscheme".into()],
3941
},
4042
simple::LanguageSpec {
4143
prefix: "json",
4244
ts_language: tree_sitter_json::LANGUAGE.into(),
4345
node_types: tree_sitter_json::NODE_TYPES,
46+
desugar: None,
4447
file_globs: vec!["*.json".into(), "*.jsonl".into(), "*.jsonc".into()],
4548
},
4649
simple::LanguageSpec {
4750
prefix: "blame",
4851
ts_language: tree_sitter_blame::LANGUAGE.into(),
4952
node_types: tree_sitter_blame::NODE_TYPES,
53+
desugar: None,
5054
file_globs: vec!["*.blame".into()],
5155
},
5256
],

ruby/extractor/src/extractor.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ pub fn run(options: Options) -> std::io::Result<()> {
123123
&path,
124124
&source,
125125
&[],
126+
None,
126127
);
127128

128129
let (ranges, line_breaks) = scan_erb(
@@ -211,6 +212,7 @@ pub fn run(options: Options) -> std::io::Result<()> {
211212
&path,
212213
&source,
213214
&code_ranges,
215+
None,
214216
);
215217
std::fs::create_dir_all(src_archive_file.parent().unwrap())?;
216218
if needs_conversion {

shared/tree-sitter-extractor/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ serde_json = "1.0"
2020
chrono = { version = "0.4.42", features = ["serde"] }
2121
num_cpus = "1.17.0"
2222
zstd = "0.13.3"
23+
yeast = { path = "../yeast" }
2324

2425
[dev-dependencies]
2526
tree-sitter-ql = "0.23.1"

shared/tree-sitter-extractor/src/extractor/mod.rs

Lines changed: 134 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,82 @@ use tree_sitter::{Language, Node, Parser, Range, Tree};
1818

1919
pub mod simple;
2020

21+
/// Trait abstracting over tree-sitter and yeast node types for extraction.
///
/// It is implemented in this file for `tree_sitter::Node` and `yeast::Node`,
/// which lets the `Visitor` methods and the location/source helpers take a
/// generic `&N` where `N: AstNode` instead of a concrete `Node`.
22+
trait AstNode {
23+
/// Returns the node's kind as a string.
fn kind(&self) -> &str;
24+
/// Returns whether the node is a named node.
fn is_named(&self) -> bool;
25+
/// Returns whether the node is a "missing" node; `enter_node` reports
/// these as parse errors and `leave_node` skips them.
fn is_missing(&self) -> bool;
26+
/// Returns whether the node is an error node; `leave_node` skips these.
fn is_error(&self) -> bool;
27+
/// Returns whether the node is an "extra" node.
fn is_extra(&self) -> bool;
28+
/// Start position of the node, expressed as a tree-sitter `Point`.
fn start_position(&self) -> tree_sitter::Point;
29+
/// End position of the node, expressed as a tree-sitter `Point`.
fn end_position(&self) -> tree_sitter::Point;
30+
/// Byte range of the node; `sliced_source_arg` uses it to index into the
/// source buffer.
fn byte_range(&self) -> std::ops::Range<usize>;
31+
/// End byte offset of the node. The default is derived from
/// `byte_range()`, so implementors only need to provide the range.
fn end_byte(&self) -> usize {
32+
self.byte_range().end
33+
}
34+
/// For yeast nodes with synthetic content, return it. Otherwise None.
///
/// The default returns `None`; `sliced_source_arg` then falls back to
/// slicing the source bytes by `byte_range()`.
35+
fn opt_string_content(&self) -> Option<String> {
36+
None
37+
}
38+
}
39+
40+
// `AstNode` for tree-sitter's own `Node`. Every method forwards to the
// inherent `Node` method of the same name; the fully qualified `Node::...`
// call form makes it explicit that the inherent method is the one invoked.
// `end_byte` and `opt_string_content` are not overridden, so the trait
// defaults apply (in particular, `opt_string_content` yields `None`).
impl<'a> AstNode for Node<'a> {
41+
fn kind(&self) -> &str {
42+
Node::kind(self)
43+
}
44+
fn is_named(&self) -> bool {
45+
Node::is_named(self)
46+
}
47+
fn is_missing(&self) -> bool {
48+
Node::is_missing(self)
49+
}
50+
fn is_error(&self) -> bool {
51+
Node::is_error(self)
52+
}
53+
fn is_extra(&self) -> bool {
54+
Node::is_extra(self)
55+
}
56+
fn start_position(&self) -> tree_sitter::Point {
57+
Node::start_position(self)
58+
}
59+
fn end_position(&self) -> tree_sitter::Point {
60+
Node::end_position(self)
61+
}
62+
fn byte_range(&self) -> std::ops::Range<usize> {
63+
Node::byte_range(self)
64+
}
65+
}
66+
67+
// `AstNode` for yeast's node type. Every method forwards to the
// `yeast::Node` method of the same name. Unlike the tree-sitter impl, this
// one also overrides `opt_string_content`, so whatever content
// `yeast::Node::opt_string_content` reports is used by `sliced_source_arg`
// in place of a slice of the original source.
impl AstNode for yeast::Node {
68+
fn kind(&self) -> &str {
69+
yeast::Node::kind(self)
70+
}
71+
fn is_named(&self) -> bool {
72+
yeast::Node::is_named(self)
73+
}
74+
fn is_missing(&self) -> bool {
75+
yeast::Node::is_missing(self)
76+
}
77+
fn is_error(&self) -> bool {
78+
yeast::Node::is_error(self)
79+
}
80+
fn is_extra(&self) -> bool {
81+
yeast::Node::is_extra(self)
82+
}
83+
fn start_position(&self) -> tree_sitter::Point {
84+
yeast::Node::start_position(self)
85+
}
86+
fn end_position(&self) -> tree_sitter::Point {
87+
yeast::Node::end_position(self)
88+
}
89+
fn byte_range(&self) -> std::ops::Range<usize> {
90+
yeast::Node::byte_range(self)
91+
}
92+
fn opt_string_content(&self) -> Option<String> {
93+
yeast::Node::opt_string_content(self)
94+
}
95+
}
96+
2197
/// Sets the tracing level based on the environment variables
2298
/// `RUST_LOG` and `CODEQL_VERBOSITY` (prioritized in that order),
2399
/// falling back to `warn` if neither is set.
@@ -204,6 +280,11 @@ pub fn location_label(writer: &mut trap::Writer, location: trap::Location) -> tr
204280
}
205281

206282
/// Extracts the source file at `path`, which is assumed to be canonicalized.
283+
/// When `yeast_runner` is `Some`, the parsed tree is first transformed
284+
/// through the supplied yeast `Runner` before TRAP extraction. Building the
285+
/// `Runner` (which parses YAML and constructs the schema) is the caller's
286+
/// responsibility, allowing it to be done once and shared across files.
287+
#[allow(clippy::too_many_arguments)]
207288
pub fn extract(
208289
language: &Language,
209290
language_prefix: &str,
@@ -214,6 +295,7 @@ pub fn extract(
214295
path: &Path,
215296
source: &[u8],
216297
ranges: &[Range],
298+
yeast_runner: Option<&yeast::Runner<'_>>,
217299
) {
218300
let path_str = file_paths::normalize_and_transform_path(path, transformer);
219301
let span = tracing::span!(
@@ -236,13 +318,20 @@ pub fn extract(
236318
source,
237319
diagnostics_writer,
238320
trap_writer,
239-
// TODO: should we handle path strings that are not valid UTF8 better?
240321
&path_str,
241322
file_label,
242323
language_prefix,
243324
schema,
244325
);
245-
traverse(&tree, &mut visitor);
326+
327+
if let Some(yeast_runner) = yeast_runner {
328+
let ast = yeast_runner
329+
.run_from_tree(&tree)
330+
.unwrap_or_else(|e| panic!("Desugaring failed for {path_str}: {e}"));
331+
traverse_yeast(&ast, &mut visitor);
332+
} else {
333+
traverse(&tree, &mut visitor);
334+
}
246335

247336
parser.reset();
248337
}
@@ -329,11 +418,11 @@ impl<'a> Visitor<'a> {
329418
);
330419
}
331420

332-
fn record_parse_error_for_node(
421+
fn record_parse_error_for_node<N: AstNode>(
333422
&mut self,
334423
message: &str,
335424
args: &[diagnostics::MessageArg],
336-
node: Node,
425+
node: &N,
337426
status_page: bool,
338427
) {
339428
let loc = location_for(self, self.file_label, node);
@@ -357,7 +446,7 @@ impl<'a> Visitor<'a> {
357446
self.record_parse_error(loc_label, &mesg);
358447
}
359448

360-
fn enter_node(&mut self, node: Node) -> bool {
449+
fn enter_node<N: AstNode>(&mut self, node: &N) -> bool {
361450
if node.is_missing() {
362451
self.record_parse_error_for_node(
363452
"A parse error occurred (expected {} symbol). Check the syntax of the file. If the file is invalid, correct the error or {} the file from analysis.",
@@ -383,7 +472,7 @@ impl<'a> Visitor<'a> {
383472
true
384473
}
385474

386-
fn leave_node(&mut self, field_name: Option<&'static str>, node: Node) {
475+
fn leave_node<N: AstNode>(&mut self, field_name: Option<&'static str>, node: &N) {
387476
if node.is_error() || node.is_missing() {
388477
return;
389478
}
@@ -434,7 +523,7 @@ impl<'a> Visitor<'a> {
434523
fields,
435524
name: table_name,
436525
} => {
437-
if let Some(args) = self.complex_node(&node, fields, &child_nodes, id) {
526+
if let Some(args) = self.complex_node(node, fields, &child_nodes, id) {
438527
self.trap_writer.add_tuple(
439528
&self.ast_node_location_table_name,
440529
vec![trap::Arg::Label(id), trap::Arg::Label(loc_label)],
@@ -495,9 +584,9 @@ impl<'a> Visitor<'a> {
495584
}
496585
}
497586

498-
fn complex_node(
587+
fn complex_node<N: AstNode>(
499588
&mut self,
500-
node: &Node,
589+
node: &N,
501590
fields: &[Field],
502591
child_nodes: &[ChildNode],
503592
parent_id: trap::Label,
@@ -529,7 +618,7 @@ impl<'a> Visitor<'a> {
529618
diagnostics::MessageArg::Code(&format!("{:?}", child_node.type_name)),
530619
diagnostics::MessageArg::Code(&format!("{:?}", field.type_info)),
531620
],
532-
*node,
621+
node,
533622
false,
534623
);
535624
}
@@ -541,7 +630,7 @@ impl<'a> Visitor<'a> {
541630
diagnostics::MessageArg::Code(child_node.field_name.unwrap_or("child")),
542631
diagnostics::MessageArg::Code(&format!("{:?}", child_node.type_name)),
543632
],
544-
*node,
633+
node,
545634
false,
546635
);
547636
}
@@ -566,7 +655,7 @@ impl<'a> Visitor<'a> {
566655
node.kind(),
567656
column_name
568657
);
569-
self.record_parse_error_for_node(&error_message, &[], *node, false);
658+
self.record_parse_error_for_node(&error_message, &[], node, false);
570659
}
571660
}
572661
Storage::Table {
@@ -582,7 +671,7 @@ impl<'a> Visitor<'a> {
582671
diagnostics::MessageArg::Code(node.kind()),
583672
diagnostics::MessageArg::Code(table_name),
584673
],
585-
*node,
674+
node,
586675
false,
587676
);
588677
break;
@@ -639,15 +728,21 @@ impl<'a> Visitor<'a> {
639728
}
640729

641730
// Emit a slice of a source file as an Arg.
642-
fn sliced_source_arg(source: &[u8], n: Node) -> trap::Arg {
643-
let range = n.byte_range();
644-
trap::Arg::String(String::from_utf8_lossy(&source[range.start..range.end]).into_owned())
731+
fn sliced_source_arg<N: AstNode>(source: &[u8], n: &N) -> trap::Arg {
732+
trap::Arg::String(n.opt_string_content().unwrap_or_else(|| {
733+
let range = n.byte_range();
734+
String::from_utf8_lossy(&source[range.start..range.end]).into_owned()
735+
}))
645736
}
646737

647738
// Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated.
648739
// The first is the location and label definition, and the second is the
649740
// 'Located' entry.
650-
fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap::Location {
741+
fn location_for<N: AstNode>(
742+
visitor: &mut Visitor,
743+
file_label: trap::Label,
744+
n: &N,
745+
) -> trap::Location {
651746
// Tree-sitter row, column values are 0-based while CodeQL starts
652747
// counting at 1. In addition Tree-sitter's row and column for the
653748
// end position are exclusive while CodeQL's end positions are inclusive.
@@ -715,6 +810,28 @@ fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap
715810

716811
// Walks a tree-sitter `Tree` with a cursor, calling `visitor.enter_node` when
// a node is first reached and `visitor.leave_node` once the walk is done with
// that node (after its children, if they were visited).
fn traverse(tree: &Tree, visitor: &mut Visitor) {
717812
let cursor = &mut tree.walk();
813+
// Enter the root. Its return value is ignored: the walk always attempts to
// descend from the root because `recurse` starts out `true`.
visitor.enter_node(&cursor.node());
814+
// `recurse` records whether the walk may descend into the current node's
// children on the next iteration.
let mut recurse = true;
815+
loop {
816+
if recurse && cursor.goto_first_child() {
817+
// Moved down to the first child; `enter_node` decides whether that
// child's own children get visited.
recurse = visitor.enter_node(&cursor.node());
818+
} else {
819+
// Either descent was declined or there are no children: this node
// is finished, so leave it (passing the field name it has in its
// parent, if any).
visitor.leave_node(cursor.field_name(), &cursor.node());
820+
821+
if cursor.goto_next_sibling() {
822+
recurse = visitor.enter_node(&cursor.node());
823+
} else if cursor.goto_parent() {
824+
// Back at the parent, whose children are now all done. Clearing
// `recurse` sends the next iteration to the `else` branch so the
// parent is left rather than descended into again.
recurse = false;
825+
} else {
826+
// No sibling and no parent: the root has been left; stop.
break;
827+
}
828+
}
829+
}
830+
}
831+
832+
fn traverse_yeast(tree: &yeast::Ast, visitor: &mut Visitor) {
833+
use yeast::Cursor;
834+
let mut cursor = tree.walk();
718835
visitor.enter_node(cursor.node());
719836
let mut recurse = true;
720837
loop {

shared/tree-sitter-extractor/src/extractor/simple.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,17 @@ use std::path::{Path, PathBuf};
77

88
use crate::diagnostics;
99
use crate::node_types;
10+
use yeast;
1011

1112
/// Per-language configuration for the simple extractor: which grammar to
/// parse with, which schema to validate against, and which files to pick up.
pub struct LanguageSpec {
1213
/// Short language identifier (e.g. "ql", "json"); it is passed to
/// `node_types::read_node_types_str` and appears in error messages.
pub prefix: &'static str,
1314
/// The tree-sitter grammar for this language. It is also handed to
/// `yeast::Runner::from_config` when `desugar` is set.
pub ts_language: tree_sitter::Language,
1415
/// Node-types description for the grammar. Used to build the schema
/// unless `desugar` supplies `output_node_types_yaml`, in which case the
/// converted YAML is used instead.
/// NOTE(review): presumably tree-sitter's node-types JSON — confirm.
pub node_types: &'static str,
16+
/// Optional yeast desugaring configuration. When set, the parsed
17+
/// tree is rewritten through yeast before TRAP extraction. The
18+
/// config's `output_node_types_yaml` (if set) provides the schema
19+
/// used both at runtime (for the rewriter) and for TRAP validation.
20+
pub desugar: Option<yeast::DesugaringConfig>,
1521
/// Glob patterns (e.g. "*.ql") selecting the files that belong to this
/// language.
pub file_globs: Vec<String>,
1622
}
1723

@@ -85,9 +91,35 @@ impl Extractor {
8591
.collect();
8692

8793
let mut schemas = vec![];
94+
let mut yeast_runners: Vec<Option<yeast::Runner>> = vec![];
8895
for lang in &self.languages {
89-
let schema = node_types::read_node_types_str(lang.prefix, lang.node_types)?;
96+
let effective_node_types: String =
97+
match lang.desugar.as_ref().and_then(|c| c.output_node_types_yaml) {
98+
Some(yaml) => yeast::node_types_yaml::convert(yaml).map_err(|e| {
99+
std::io::Error::other(format!(
100+
"Failed to convert YAML node-types to JSON for {}: {e}",
101+
lang.prefix
102+
))
103+
})?,
104+
None => lang.node_types.to_string(),
105+
};
106+
let schema = node_types::read_node_types_str(lang.prefix, &effective_node_types)?;
90107
schemas.push(schema);
108+
109+
// Build the yeast runner once per language so the YAML schema
110+
// isn't re-parsed for every file.
111+
let yeast_runner = lang
112+
.desugar
113+
.as_ref()
114+
.map(|config| yeast::Runner::from_config(lang.ts_language.clone(), config))
115+
.transpose()
116+
.map_err(|e| {
117+
std::io::Error::other(format!(
118+
"Failed to build desugaring runner for {}: {e}",
119+
lang.prefix
120+
))
121+
})?;
122+
yeast_runners.push(yeast_runner);
91123
}
92124

93125
// Construct a single globset containing all language globs,
@@ -162,6 +194,7 @@ impl Extractor {
162194
&path,
163195
&source,
164196
&[],
197+
yeast_runners[i].as_ref(),
165198
);
166199
std::fs::create_dir_all(src_archive_file.parent().unwrap())?;
167200
std::fs::copy(&path, &src_archive_file)?;

shared/tree-sitter-extractor/tests/integration_test.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ fn simple_extractor() {
1313
prefix: "ql",
1414
ts_language: tree_sitter_ql::LANGUAGE.into(),
1515
node_types: tree_sitter_ql::NODE_TYPES,
16+
desugar: None,
1617
file_globs: vec!["*.qll".into()],
1718
};
1819

0 commit comments

Comments
 (0)