@@ -8,6 +8,7 @@ use rustc_middle::ty::{self as mir_ty, TyCtxt};
88use crate :: analyze:: did_cache:: DefIdCache ;
99use crate :: annot:: AnnotFormula ;
1010use crate :: chc;
11+ use crate :: refine:: TypeBuilder ;
1112use crate :: rty;
1213
1314#[ derive( Debug , Clone ) ]
@@ -118,6 +119,7 @@ impl<T> FormulaOrTerm<T> {
118119 }
119120}
120121
122+ #[ derive( Clone ) ]
121123pub struct AnnotFnTranslator < ' tcx > {
122124 tcx : TyCtxt < ' tcx > ,
123125 local_def_id : LocalDefId ,
@@ -127,6 +129,7 @@ pub struct AnnotFnTranslator<'tcx> {
127129 generic_args : mir_ty:: GenericArgsRef < ' tcx > ,
128130
129131 def_ids : DefIdCache < ' tcx > ,
132+ type_builder : TypeBuilder < ' tcx > ,
130133 env : HashMap < HirId , chc:: Term < rty:: FunctionParamIdx > > ,
131134}
132135
@@ -138,13 +141,15 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
138141 let generic_args = tcx. mk_args ( & [ ] ) ;
139142 let typeck = tcx. typeck ( local_def_id) ;
140143 let def_ids = DefIdCache :: new ( tcx) ;
144+ let type_builder = TypeBuilder :: new ( tcx, def_ids. clone ( ) , local_def_id. to_def_id ( ) ) ;
141145 let mut translator = Self {
142146 tcx,
143147 local_def_id,
144148 typeck,
145149 body,
146150 generic_args,
147151 def_ids,
152+ type_builder,
148153 env : HashMap :: default ( ) ,
149154 } ;
150155 translator. build_env_from_params ( ) ;
@@ -158,6 +163,11 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
158163
159164 pub fn with_def_id_cache ( mut self , def_ids : DefIdCache < ' tcx > ) -> Self {
160165 self . def_ids = def_ids;
166+ self . type_builder = TypeBuilder :: new (
167+ self . tcx ,
168+ self . def_ids . clone ( ) ,
169+ self . local_def_id . to_def_id ( ) ,
170+ ) ;
161171 self
162172 }
163173
@@ -197,6 +207,13 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
197207 self . tcx . normalize_erasing_regions ( param_env, instantiated)
198208 }
199209
210+ fn pat_ty ( & self , pat : & ' tcx rustc_hir:: Pat < ' tcx > ) -> mir_ty:: Ty < ' tcx > {
211+ let ty = self . typeck . pat_ty ( pat) ;
212+ let instantiated = mir_ty:: EarlyBinder :: bind ( ty) . instantiate ( self . tcx , self . generic_args ) ;
213+ let param_env = mir_ty:: ParamEnv :: reveal_all ( ) ;
214+ self . tcx . normalize_erasing_regions ( param_env, instantiated)
215+ }
216+
200217 pub fn to_formula_fn ( & self ) -> FormulaFn < ' tcx > {
201218 let formula = self . to_formula ( self . body . value ) ;
202219 let params = self
@@ -224,6 +241,28 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
224241 . expect ( "expected a term" )
225242 }
226243
244+ fn variant_ctor_term (
245+ & self ,
246+ ctor_did : rustc_span:: def_id:: DefId ,
247+ result_ty : mir_ty:: Ty < ' tcx > ,
248+ field_terms : Vec < chc:: Term < rty:: FunctionParamIdx > > ,
249+ ) -> chc:: Term < rty:: FunctionParamIdx > {
250+ let variant_did = self . tcx . parent ( ctor_did) ;
251+ let adt_did = self . tcx . parent ( variant_did) ;
252+ let d_sym = crate :: refine:: datatype_symbol ( self . tcx , adt_did) ;
253+ let variant_name = self . tcx . item_name ( variant_did) ;
254+ let v_sym = chc:: DatatypeSymbol :: new ( format ! ( "{}.{}" , d_sym, variant_name) ) ;
255+ let sort_args = if let mir_ty:: TyKind :: Adt ( _, generic_args) = result_ty. kind ( ) {
256+ generic_args
257+ . types ( )
258+ . map ( |ty| self . type_builder . build ( ty) . to_sort ( ) )
259+ . collect ( )
260+ } else {
261+ panic ! ( "expected an ADT type for variant constructor" )
262+ } ;
263+ chc:: Term :: datatype_ctor ( d_sym, sort_args, v_sym, field_terms)
264+ }
265+
227266 fn to_formula_or_term (
228267 & self ,
229268 hir : & ' tcx rustc_hir:: Expr < ' tcx > ,
@@ -319,20 +358,21 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
319358 rustc_ast:: LitKind :: Bool ( b) => FormulaOrTerm :: Literal ( b) ,
320359 _ => unimplemented ! ( "unsupported literal in formula: {:?}" , lit) ,
321360 } ,
322- ExprKind :: Path ( qpath) => {
323- if let rustc_hir:: def:: Res :: Local ( hir_id) =
324- self . typeck . qpath_res ( & qpath , hir . hir_id )
325- {
326- FormulaOrTerm :: Term (
327- self . env
328- . get ( & hir_id )
329- . expect ( "unbound variable in formula" )
330- . clone ( ) ,
331- )
332- } else {
333- unimplemented ! ( "unsupported path in formula: {:?}" , qpath ) ;
361+ ExprKind :: Path ( qpath) => match self . typeck . qpath_res ( & qpath , hir . hir_id ) {
362+ rustc_hir:: def:: Res :: Local ( hir_id) => FormulaOrTerm :: Term (
363+ self . env
364+ . get ( & hir_id )
365+ . expect ( "unbound variable in formula" )
366+ . clone ( ) ,
367+ ) ,
368+ rustc_hir :: def :: Res :: Def (
369+ rustc_hir :: def :: DefKind :: Ctor ( rustc_hir :: def :: CtorOf :: Variant , _ ) ,
370+ ctor_did ,
371+ ) => {
372+ FormulaOrTerm :: Term ( self . variant_ctor_term ( ctor_did , self . expr_ty ( hir ) , vec ! [ ] ) )
334373 }
335- }
374+ _ => unimplemented ! ( "unsupported path in formula: {:?}" , qpath) ,
375+ } ,
336376 ExprKind :: Tup ( exprs) => {
337377 let terms = exprs. iter ( ) . map ( |e| self . to_term ( e) ) . collect ( ) ;
338378 FormulaOrTerm :: Term ( chc:: Term :: tuple ( terms) )
@@ -349,7 +389,40 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
349389 ExprKind :: Call ( func_expr, args) => {
350390 if let ExprKind :: Path ( qpath) = & func_expr. kind {
351391 let res = self . typeck . qpath_res ( qpath, func_expr. hir_id ) ;
352- if let rustc_hir:: def:: Res :: Def ( _, def_id) = res {
392+ if let rustc_hir:: def:: Res :: Def ( def_kind, def_id) = res {
393+ if Some ( def_id) == self . def_ids . exists ( ) {
394+ assert_eq ! ( args. len( ) , 1 , "exists takes exactly 1 argument" ) ;
395+ let ExprKind :: Closure ( closure) = args[ 0 ] . kind else {
396+ panic ! ( "exists argument must be a closure" ) ;
397+ } ;
398+ let closure_body = self . tcx . hir ( ) . body ( closure. body ) ;
399+
400+ let mut inner_translator = self . clone ( ) ;
401+ let mut vars = Vec :: new ( ) ;
402+ for param in closure_body. params {
403+ let rustc_hir:: PatKind :: Binding ( _, hir_id, ident, None ) =
404+ param. pat . kind
405+ else {
406+ panic ! (
407+ "exists closure parameter must be a simple binding: {:?}" ,
408+ param. pat
409+ ) ;
410+ } ;
411+ let param_ty = self . pat_ty ( param. pat ) ;
412+ let sort = self . type_builder . build ( param_ty) . to_sort ( ) ;
413+ let var_term = chc:: Term :: FormulaExistentialVar (
414+ sort. clone ( ) ,
415+ ident. name . to_string ( ) ,
416+ ) ;
417+ inner_translator. env . insert ( hir_id, var_term) ;
418+ vars. push ( ( ident. name . to_string ( ) , sort) ) ;
419+ }
420+ let body_formula = inner_translator. to_formula ( closure_body. value ) ;
421+ return FormulaOrTerm :: Formula ( chc:: Formula :: exists (
422+ vars,
423+ body_formula,
424+ ) ) ;
425+ }
353426 if Some ( def_id) == self . def_ids . mut_model_new ( ) {
354427 assert_eq ! ( args. len( ) , 2 , "Mut::new takes exactly 2 arguments" ) ;
355428 let t1 = self . to_term ( & args[ 0 ] ) ;
@@ -361,6 +434,17 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
361434 let t = self . to_term ( & args[ 0 ] ) ;
362435 return FormulaOrTerm :: Term ( chc:: Term :: box_ ( t) ) ;
363436 }
437+ if matches ! (
438+ def_kind,
439+ rustc_hir:: def:: DefKind :: Ctor ( rustc_hir:: def:: CtorOf :: Variant , _)
440+ ) {
441+ let field_terms = args. iter ( ) . map ( |arg| self . to_term ( arg) ) . collect ( ) ;
442+ return FormulaOrTerm :: Term ( self . variant_ctor_term (
443+ def_id,
444+ self . expr_ty ( hir) ,
445+ field_terms,
446+ ) ) ;
447+ }
364448 }
365449 }
366450 unimplemented ! ( "unsupported call in formula: {:?}" , func_expr)
0 commit comments