Skip to content

Commit 82a7840

Browse files
coord-eclaude
andcommitted
Change thrust-macros expansion for extern_spec_fn functions
When the annotated function already has #[thrust::extern_spec_fn], inject requires_path/ensures_path into the original function body instead of generating a separate _thrust_extern_spec_ wrapper. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent d793249 commit 82a7840

1 file changed

Lines changed: 68 additions & 1 deletion

File tree

thrust-macros/src/lib.rs

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,19 @@ pub fn requires(attr: TokenStream, item: TokenStream) -> TokenStream {
5858
.map_or(true, |s| s.ident != "requires")
5959
});
6060

61-
transform(func, req_expr, ens_expr).unwrap_or_else(|e| e.to_compile_error().into())
61+
let is_extern_spec = func.attrs.iter().any(|a| {
62+
a.path()
63+
.segments
64+
.last()
65+
.map_or(false, |s| s.ident == "extern_spec_fn")
66+
});
67+
68+
if is_extern_spec {
69+
transform_extern_spec(func, req_expr, ens_expr)
70+
.unwrap_or_else(|e| e.to_compile_error().into())
71+
} else {
72+
transform(func, req_expr, ens_expr).unwrap_or_else(|e| e.to_compile_error().into())
73+
}
6274
}
6375

6476
/// Pass-through no-op. When `#[ensures]` is consumed by `#[requires]`, this is never
@@ -69,6 +81,61 @@ pub fn ensures(_attr: TokenStream, item: TokenStream) -> TokenStream {
6981
item
7082
}
7183

84+
fn transform_extern_spec(
85+
mut func: ItemFn,
86+
req_expr: TokenStream2,
87+
ens_expr: TokenStream2,
88+
) -> syn::Result<TokenStream> {
89+
let name = &func.sig.ident;
90+
let requires_name = format_ident!("_thrust_requires_{}", name);
91+
let ensures_name = format_ident!("_thrust_ensures_{}", name);
92+
93+
let generics = &func.sig.generics;
94+
95+
let def_generics = generic_params_tokens(generics);
96+
let turbofish = generic_turbofish(generics);
97+
let model_preds = model_where_predicates(generics);
98+
let extended_where = extended_where_clause(generics, &model_preds);
99+
100+
// Parameters with types replaced by <T as thrust_models::Model>::Ty
101+
let model_ty_params = fn_params_with_model_ty(&func.sig.inputs);
102+
103+
// Return type for ensures helper (wrapped in ::Ty)
104+
let ret_model_ty: Type = match &func.sig.output {
105+
ReturnType::Default => syn::parse_quote!(<() as thrust_models::Model>::Ty),
106+
ReturnType::Type(_, ty) => syn::parse_quote!(<#ty as thrust_models::Model>::Ty),
107+
};
108+
109+
// Prepend requires_path/ensures_path statements to the original block
110+
let orig_stmts = func.block.stmts.clone();
111+
func.block = syn::parse_quote!({
112+
#[thrust::requires_path]
113+
#requires_name #turbofish;
114+
#[thrust::ensures_path]
115+
#ensures_name #turbofish;
116+
#(#orig_stmts)*
117+
});
118+
119+
let output = quote! {
120+
#[allow(unused_variables)]
121+
#[thrust::formula_fn]
122+
fn #requires_name #def_generics(#model_ty_params) -> bool #extended_where {
123+
#req_expr
124+
}
125+
126+
#[allow(unused_variables)]
127+
#[thrust::formula_fn]
128+
fn #ensures_name #def_generics(result: #ret_model_ty, #model_ty_params) -> bool #extended_where {
129+
#ens_expr
130+
}
131+
132+
#[allow(path_statements)]
133+
#func
134+
};
135+
136+
Ok(output.into())
137+
}
138+
72139
fn transform(
73140
func: ItemFn,
74141
req_expr: TokenStream2,

0 commit comments

Comments
 (0)