@@ -190,6 +190,11 @@ impl refine::EnumDefProvider for Rc<RefCell<EnumDefs>> {
190190
191191pub type Env = refine:: Env < Rc < RefCell < EnumDefs > > > ;
192192
193+ #[ derive( Debug , Clone ) ]
194+ struct DeferredFormulaFnDef < ' tcx > {
195+ cache : Rc < RefCell < HashMap < mir_ty:: GenericArgsRef < ' tcx > , annot_fn:: FormulaFn < ' tcx > > > > ,
196+ }
197+
193198#[ derive( Clone ) ]
194199pub struct Analyzer < ' tcx > {
195200 tcx : TyCtxt < ' tcx > ,
@@ -202,7 +207,7 @@ pub struct Analyzer<'tcx> {
202207 defs : HashMap < DefId , DefTy < ' tcx > > ,
203208
204209 /// Collection of functions with `#[thrust::formula_fn]` attribute.
205- formula_fns : HashMap < DefId , annot_fn :: FormulaFn < ' tcx > > ,
210+ formula_fns : HashMap < LocalDefId , DeferredFormulaFnDef < ' tcx > > ,
206211
207212 /// Resulting CHC system.
208213 system : Rc < RefCell < chc:: System > > ,
@@ -391,6 +396,30 @@ impl<'tcx> Analyzer<'tcx> {
391396 } )
392397 }
393398
399+ pub fn formula_fn_with_args (
400+ & self ,
401+ local_def_id : LocalDefId ,
402+ generic_args : mir_ty:: GenericArgsRef < ' tcx > ,
403+ ) -> Option < annot_fn:: FormulaFn < ' tcx > > {
404+ let deferred_formula_fn = self . formula_fns . get ( & local_def_id) ?;
405+
406+ let deferred_formula_fn_cache = Rc :: clone ( & deferred_formula_fn. cache ) ;
407+ if let Some ( formula_fn) = deferred_formula_fn_cache. borrow ( ) . get ( & generic_args) {
408+ return Some ( formula_fn. clone ( ) ) ;
409+ }
410+
411+ let translator = annot_fn:: AnnotFnTranslator :: new ( self . tcx , local_def_id)
412+ . with_generic_args ( generic_args)
413+ . with_def_id_cache ( self . def_ids ( ) ) ;
414+ let formula_fn = translator. to_formula_fn ( ) ;
415+ deferred_formula_fn_cache
416+ . borrow_mut ( )
417+ . insert ( generic_args, formula_fn. clone ( ) ) ;
418+
419+ tracing:: info!( ?local_def_id, formula_fn = %formula_fn. display( ) , ?generic_args, "formula_fn_with_args" ) ;
420+ Some ( formula_fn)
421+ }
422+
394423 pub fn def_ty_with_args (
395424 & mut self ,
396425 def_id : DefId ,
@@ -443,9 +472,14 @@ impl<'tcx> Analyzer<'tcx> {
443472 Some ( expected)
444473 }
445474
446- pub fn register_formula_fn ( & mut self , def_id : DefId , formula_fn : annot_fn:: FormulaFn < ' tcx > ) {
447- tracing:: info!( def_id = ?def_id, formula_fn = %formula_fn. display( ) , "register_formula_fn" ) ;
448- self . formula_fns . insert ( def_id, formula_fn) ;
475+ pub fn register_formula_fn ( & mut self , local_def_id : LocalDefId ) {
476+ tracing:: info!( ?local_def_id, "register_formula_fn" ) ;
477+ self . formula_fns . insert (
478+ local_def_id,
479+ DeferredFormulaFnDef {
480+ cache : Rc :: new ( RefCell :: new ( HashMap :: new ( ) ) ) ,
481+ } ,
482+ ) ;
449483 }
450484
451485 pub fn register_basic_block_ty (
@@ -585,11 +619,13 @@ impl<'tcx> Analyzer<'tcx> {
585619 None
586620 }
587621
622+ // TODO: reduce number of args
588623 fn extract_require_annot < T > (
589624 & self ,
590625 local_def_id : LocalDefId ,
591626 resolver : T ,
592627 self_type_name : Option < String > ,
628+ generic_args : mir_ty:: GenericArgsRef < ' tcx > ,
593629 ) -> Option < AnnotFormula < T :: Output > >
594630 where
595631 T : Resolver < Output = rty:: FunctionParamIdx > ,
@@ -611,10 +647,16 @@ impl<'tcx> Analyzer<'tcx> {
611647 if let Some ( formula_def_id) =
612648 self . extract_path_with_attr ( local_def_id, & analyze:: annot:: requires_path_path ( ) )
613649 {
650+ let Some ( formula_def_id) = formula_def_id. as_local ( ) else {
651+ panic ! (
652+ "require annotation with path is expected to refer to a local def, but found: {:?}" ,
653+ formula_def_id
654+ ) ;
655+ } ;
614656 if require_annot. is_some ( ) {
615657 unimplemented ! ( ) ;
616658 }
617- let Some ( formula_fn) = self . formula_fns . get ( & formula_def_id) else {
659+ let Some ( formula_fn) = self . formula_fn_with_args ( formula_def_id, generic_args ) else {
618660 panic ! (
619661 "require annotation {:?} is not a formula function" ,
620662 formula_def_id
@@ -626,11 +668,13 @@ impl<'tcx> Analyzer<'tcx> {
626668 require_annot
627669 }
628670
671+ // TODO: reduce number of args
629672 fn extract_ensure_annot < T > (
630673 & self ,
631674 local_def_id : LocalDefId ,
632675 resolver : T ,
633676 self_type_name : Option < String > ,
677+ generic_args : mir_ty:: GenericArgsRef < ' tcx > ,
634678 ) -> Option < AnnotFormula < T :: Output > >
635679 where
636680 T : Resolver < Output = rty:: RefinedTypeVar < rty:: FunctionParamIdx > > ,
@@ -653,10 +697,16 @@ impl<'tcx> Analyzer<'tcx> {
653697 if let Some ( formula_def_id) =
654698 self . extract_path_with_attr ( local_def_id, & analyze:: annot:: ensures_path_path ( ) )
655699 {
700+ let Some ( formula_def_id) = formula_def_id. as_local ( ) else {
701+ panic ! (
702+ "require annotation with path is expected to refer to a local def, but found: {:?}" ,
703+ formula_def_id
704+ ) ;
705+ } ;
656706 if ensure_annot. is_some ( ) {
657707 unimplemented ! ( ) ;
658708 }
659- let Some ( formula_fn) = self . formula_fns . get ( & formula_def_id) else {
709+ let Some ( formula_fn) = self . formula_fn_with_args ( formula_def_id, generic_args ) else {
660710 panic ! (
661711 "ensure annotation {:?} is not a formula function" ,
662712 formula_def_id
0 commit comments