Skip to content

Commit fafc687

Browse files
TrueDoctorKeavon
andauthored
Migrate memo nodes to node macro and make implementing other persistent nodes easier (#3552)
* Add #[data] and #[serialize] attributes to node macro - Add #[data] attribute for struct fields that aren't node parameters - Data fields are initialized with Default::default() - Passed as references to the underlying function - Excluded from registry metadata (internal state) - Generic types in data fields allowed without #[implementations] - Add #[serialize] attribute for custom Node::serialize() implementation - Receives references to all data fields - Generates serialize() method in Node trait impl - Conditional derives based on data field presence - With data fields: Debug, Clone only - Without data fields: Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash * Refactor Memo and Monitor Node to use node macro * Move Complex type into type alias * Fix format * Update node-graph/nodes/gcore/src/memo.rs Co-authored-by: Keavon Chambers <keavon@keavon.com> * Update node-graph/nodes/gcore/src/memo.rs Co-authored-by: Keavon Chambers <keavon@keavon.com> --------- Co-authored-by: Keavon Chambers <keavon@keavon.com>
1 parent 8f25eb6 commit fafc687

5 files changed

Lines changed: 252 additions & 118 deletions

File tree

node-graph/node-macro/src/codegen.rs

Lines changed: 156 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -37,33 +37,102 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
3737
};
3838
let struct_name = format_ident!("{}Node", struct_name);
3939

40-
let struct_generics: Vec<Ident> = fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect();
40+
// Separate data fields from regular fields
41+
let (data_fields, regular_fields): (Vec<_>, Vec<_>) = fields.iter().partition(|f| f.is_data_field);
42+
43+
// Extract function generics used by data fields
44+
let data_field_generics: Vec<_> = fn_generics
45+
.iter()
46+
.filter(|generic| {
47+
let generic_ident = match generic {
48+
syn::GenericParam::Type(type_param) => &type_param.ident,
49+
_ => return false,
50+
};
51+
52+
// Check if this generic is used in any data field type
53+
data_fields.iter().any(|field| match &field.ty {
54+
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => type_contains_ident(ty, generic_ident),
55+
_ => false,
56+
})
57+
})
58+
.cloned()
59+
.collect();
60+
61+
// Node generics for regular fields (Node0, Node1, ...)
62+
let node_generics: Vec<Ident> = regular_fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect();
63+
64+
// Extract just the idents from data_field_generics for struct type parameters
65+
let data_field_generic_idents: Vec<Ident> = data_field_generics
66+
.iter()
67+
.filter_map(|gp| match gp {
68+
syn::GenericParam::Type(tp) => Some(tp.ident.clone()),
69+
_ => None,
70+
})
71+
.collect();
72+
73+
// Combined struct type parameters: data field generic idents (T, U, ...) + node generics (Node0, Node1, ...)
74+
// For struct type instantiation: MemoNode<T, Node0>
75+
let struct_type_params: Vec<Ident> = data_field_generic_idents.iter().cloned().chain(node_generics.iter().cloned()).collect();
76+
77+
// Combined struct generic parameters with bounds for struct definition
78+
// struct MemoNode<T: Clone, Node0>
79+
let struct_generic_params: Vec<TokenStream2> = data_field_generics.iter().map(|gp| quote!(#gp)).chain(node_generics.iter().map(|id| quote!(#id))).collect();
4180
let input_ident = &input.pat_ident;
4281

4382
let context_features = &input.context_features;
4483

45-
let field_idents: Vec<_> = fields.iter().map(|f| &f.pat_ident).collect();
84+
// Regular field idents and names (for function parameters)
85+
let field_idents: Vec<_> = regular_fields.iter().map(|f| &f.pat_ident).collect();
4686
let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect();
87+
let regular_field_names: Vec<_> = regular_fields.iter().map(|f| &f.pat_ident.ident).collect();
88+
let data_field_names: Vec<_> = data_fields.iter().map(|f| &f.pat_ident.ident).collect();
4789

48-
let input_names: Vec<_> = fields
90+
// Only regular fields have input names/descriptions (for UI)
91+
let input_names: Vec<_> = regular_fields
4992
.iter()
5093
.map(|f| &f.name)
51-
.zip(field_names.iter())
94+
.zip(regular_field_names.iter())
5295
.map(|zipped| match zipped {
5396
(Some(name), _) => name.value(),
5497
(_, name) => name.to_string().to_case(Case::Title),
5598
})
5699
.collect();
57100

58-
let input_descriptions: Vec<_> = fields.iter().map(|f| &f.description).collect();
101+
let input_descriptions: Vec<_> = regular_fields.iter().map(|f| &f.description).collect();
102+
103+
// Generate struct fields: data fields (concrete types) + regular fields (generic types)
104+
let data_field_defs = data_fields.iter().map(|field| {
105+
let name = &field.pat_ident.ident;
106+
let ty = match &field.ty {
107+
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty,
108+
_ => unreachable!("Data fields must be Regular types, not Node types"),
109+
};
110+
quote! { pub(super) #name: #ty }
111+
});
59112

60-
let struct_fields = field_names.iter().zip(struct_generics.iter()).map(|(name, r#gen)| {
113+
let regular_field_defs = regular_field_names.iter().zip(node_generics.iter()).map(|(name, r#gen)| {
61114
quote! { pub(super) #name: #r#gen }
62115
});
63116

117+
let struct_fields = data_field_defs.chain(regular_field_defs);
118+
64119
let mut future_idents = Vec::new();
65120

66-
let field_types: Vec<_> = fields
121+
// Data fields get passed as references to the underlying function
122+
let data_field_idents: Vec<_> = data_fields.iter().map(|f| &f.pat_ident).collect();
123+
let data_field_types: Vec<_> = data_fields
124+
.iter()
125+
.map(|field| match &field.ty {
126+
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => {
127+
let ty = ty.clone();
128+
quote!(&#ty)
129+
}
130+
_ => unreachable!("Data fields must be Regular types, not Node types"),
131+
})
132+
.collect();
133+
134+
// Regular fields have types passed to the function
135+
let field_types: Vec<_> = regular_fields
67136
.iter()
68137
.map(|field| match &field.ty {
69138
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty.clone(),
@@ -74,7 +143,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
74143
})
75144
.collect();
76145

77-
let widget_override: Vec<_> = fields
146+
// Only regular fields have UI metadata (data fields are internal state)
147+
let widget_override: Vec<_> = regular_fields
78148
.iter()
79149
.map(|field| match &field.widget_override {
80150
ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None),
@@ -84,7 +154,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
84154
})
85155
.collect();
86156

87-
let value_sources: Vec<_> = fields
157+
let value_sources: Vec<_> = regular_fields
88158
.iter()
89159
.map(|field| match &field.ty {
90160
ParsedFieldType::Regular(RegularParsedField { value_source, .. }) => match value_source {
@@ -104,7 +174,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
104174
})
105175
.collect();
106176

107-
let default_types: Vec<_> = fields
177+
let default_types: Vec<_> = regular_fields
108178
.iter()
109179
.map(|field| match &field.ty {
110180
ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => match implementations.first() {
@@ -115,7 +185,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
115185
})
116186
.collect();
117187

118-
let number_min_values: Vec<_> = fields
188+
let number_min_values: Vec<_> = regular_fields
119189
.iter()
120190
.map(|field| match &field.ty {
121191
ParsedFieldType::Regular(RegularParsedField { number_soft_min, number_hard_min, .. }) => match (number_soft_min, number_hard_min) {
@@ -126,7 +196,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
126196
_ => quote!(None),
127197
})
128198
.collect();
129-
let number_max_values: Vec<_> = fields
199+
let number_max_values: Vec<_> = regular_fields
130200
.iter()
131201
.map(|field| match &field.ty {
132202
ParsedFieldType::Regular(RegularParsedField { number_soft_max, number_hard_max, .. }) => match (number_soft_max, number_hard_max) {
@@ -137,7 +207,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
137207
_ => quote!(None),
138208
})
139209
.collect();
140-
let number_mode_range_values: Vec<_> = fields
210+
let number_mode_range_values: Vec<_> = regular_fields
141211
.iter()
142212
.map(|field| match &field.ty {
143213
ParsedFieldType::Regular(RegularParsedField {
@@ -147,23 +217,24 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
147217
_ => quote!(None),
148218
})
149219
.collect();
150-
let number_display_decimal_places: Vec<_> = fields
220+
let number_display_decimal_places: Vec<_> = regular_fields
151221
.iter()
152222
.map(|field| field.number_display_decimal_places.as_ref().map_or(quote!(None), |i| quote!(Some(#i))))
153223
.collect();
154-
let number_step: Vec<_> = fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
224+
let number_step: Vec<_> = regular_fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
155225

156-
let unit_suffix: Vec<_> = fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
226+
let unit_suffix: Vec<_> = regular_fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
157227

158-
let exposed: Vec<_> = fields
228+
let exposed: Vec<_> = regular_fields
159229
.iter()
160230
.map(|field| match &field.ty {
161231
ParsedFieldType::Regular(RegularParsedField { exposed, .. }) => quote!(#exposed),
162232
_ => quote!(true),
163233
})
164234
.collect();
165235

166-
let eval_args = fields.iter().map(|field| {
236+
// Only eval regular fields (data fields are accessed directly as self.field_name)
237+
let eval_args = regular_fields.iter().map(|field| {
167238
let name = &field.pat_ident.ident;
168239
match &field.ty {
169240
ParsedFieldType::Regular { .. } => {
@@ -175,7 +246,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
175246
}
176247
});
177248

178-
let min_max_args = fields.iter().map(|field| match &field.ty {
249+
// Only regular fields can have min/max constraints
250+
let min_max_args = regular_fields.iter().map(|field| match &field.ty {
179251
ParsedFieldType::Regular(RegularParsedField { number_hard_min, number_hard_max, .. }) => {
180252
let name = &field.pat_ident.ident;
181253
let mut tokens = quote!();
@@ -208,7 +280,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
208280
let mut clauses = Vec::new();
209281
let mut clampable_clauses = Vec::new();
210282

211-
for (field, name) in fields.iter().zip(struct_generics.iter()) {
283+
for (field, name) in regular_fields.iter().zip(node_generics.iter()) {
212284
clauses.push(match (&field.ty, *is_async) {
213285
(
214286
ParsedFieldType::Regular(RegularParsedField {
@@ -259,13 +331,42 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
259331
);
260332
struct_where_clause.predicates.extend(extra_where);
261333

262-
let new_args = struct_generics.iter().zip(field_names.iter()).map(|(r#gen, name)| {
334+
// Only regular fields are parameters to new()
335+
let new_args = node_generics.iter().zip(regular_field_names.iter()).map(|(r#gen, name)| {
263336
quote! { #name: #r#gen }
264337
});
265338

339+
// Initialize data fields with Default, regular fields with parameters
340+
let data_inits = data_field_names.iter().map(|name| {
341+
quote! { #name: Default::default() }
342+
});
343+
let regular_inits = regular_field_names.iter().map(|name| {
344+
quote! { #name }
345+
});
346+
let all_field_inits = data_inits.chain(regular_inits);
347+
266348
let async_keyword = is_async.then(|| quote!(async));
267349
let await_keyword = is_async.then(|| quote!(.await));
268350

351+
// Data fields may not implement Copy, PartialEq, etc., so only derive Debug and Clone
352+
let struct_derives = if data_fields.is_empty() {
353+
quote!(#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)])
354+
} else {
355+
quote!(#[derive(Debug, Clone)])
356+
};
357+
358+
// Generate serialize method if serialize attribute is specified
359+
let serialize_impl = if let Some(serialize_fn) = &parsed.attributes.serialize {
360+
let data_field_refs = data_field_names.iter().map(|name| quote!(&self.#name));
361+
quote! {
362+
fn serialize(&self) -> Option<std::sync::Arc<dyn std::any::Any + Send + Sync>> {
363+
#serialize_fn(#(#data_field_refs),*)
364+
}
365+
}
366+
} else {
367+
quote!()
368+
};
369+
269370
let eval_impl = quote! {
270371
type Output = #core_types::registry::DynFuture<'n, #output_type>;
271372
#[inline]
@@ -275,9 +376,11 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
275376

276377
#(#eval_args)*
277378
#(#min_max_args)*
278-
self::#fn_name(__input #(, #field_names)*) #await_keyword
379+
self::#fn_name(__input #(, &self.#data_field_names)* #(, #regular_field_names)*) #await_keyword
279380
})
280381
}
382+
383+
#serialize_impl
281384
};
282385

283386
let identifier = format_ident!("{}_proto_ident", fn_name);
@@ -302,11 +405,11 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
302405
/// Underlying implementation for [#struct_name]
303406
#[inline]
304407
#[allow(clippy::too_many_arguments)]
305-
#vis #async_keyword fn #fn_name <'n, #(#fn_generics,)*> (#input_ident: #input_type #(, #field_idents: #field_types)*) -> #output_type #where_clause #body
408+
#vis #async_keyword fn #fn_name <'n, #(#fn_generics,)*> (#input_ident: #input_type #(, #data_field_idents: #data_field_types)* #(, #field_idents: #field_types)*) -> #output_type #where_clause #body
306409

307410
#cfg
308411
#[automatically_derived]
309-
impl<'n, #(#fn_generics,)* #(#struct_generics,)* #(#future_idents,)*> #core_types::Node<'n, #input_type> for #mod_name::#struct_name<#(#struct_generics,)*>
412+
impl<'n, #(#fn_generics,)* #(#node_generics,)* #(#future_idents,)*> #core_types::Node<'n, #input_type> for #mod_name::#struct_name<#(#struct_type_params,)*>
310413
#struct_where_clause
311414
{
312415
#eval_impl
@@ -340,18 +443,18 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
340443

341444
static #import_name: core::marker::PhantomData<(#(#all_implementation_types,)*)> = core::marker::PhantomData;
342445

343-
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
344-
pub struct #struct_name<#(#struct_generics,)*> {
446+
#struct_derives
447+
pub struct #struct_name<#(#struct_generic_params,)*> {
345448
#(#struct_fields,)*
346449
}
347450

348451
#[automatically_derived]
349-
impl<'n, #(#struct_generics,)*> #struct_name<#(#struct_generics,)*>
452+
impl<'n, #(#struct_generic_params,)*> #struct_name<#(#struct_type_params,)*>
350453
{
351454
#[allow(clippy::too_many_arguments)]
352455
pub fn new(#(#new_args,)*) -> Self {
353456
Self {
354-
#(#field_names,)*
457+
#(#all_field_inits,)*
355458
}
356459
}
357460
}
@@ -493,8 +596,10 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
493596

494597
let mut constructors = Vec::new();
495598
let unit = parse_quote!(gcore::Context);
496-
let parameter_types: Vec<_> = parsed
497-
.fields
599+
600+
let regular_fields: Vec<_> = parsed.fields.iter().filter(|f| !f.is_data_field).collect();
601+
602+
let parameter_types: Vec<_> = regular_fields
498603
.iter()
499604
.map(|field| {
500605
match &field.ty {
@@ -535,7 +640,7 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
535640
let field_name = field_names[j];
536641
let (input_type, output_type) = &types[i.min(types.len() - 1)];
537642

538-
let node = matches!(parsed.fields[j].ty, ParsedFieldType::Node { .. });
643+
let node = matches!(regular_fields[j].ty, ParsedFieldType::Node { .. });
539644

540645
let downcast_node = quote!(
541646
let #field_name: DowncastBothNode<#input_type, #output_type> = DowncastBothNode::new(args[#j].clone());
@@ -712,3 +817,23 @@ impl FilterUsedGenerics {
712817
self.used(&*modified).cloned().collect()
713818
}
714819
}
820+
821+
/// Check if a type contains a reference to a specific identifier (e.g., a generic type parameter)
822+
fn type_contains_ident(ty: &Type, ident: &Ident) -> bool {
823+
struct IdentChecker<'a> {
824+
target: &'a Ident,
825+
found: bool,
826+
}
827+
828+
impl<'a, 'ast> syn::visit::Visit<'ast> for IdentChecker<'a> {
829+
fn visit_ident(&mut self, i: &'ast Ident) {
830+
if i == self.target {
831+
self.found = true;
832+
}
833+
}
834+
}
835+
836+
let mut checker = IdentChecker { target: ident, found: false };
837+
syn::visit::visit_type(&mut checker, ty);
838+
checker.found
839+
}

0 commit comments

Comments
 (0)