@@ -110,7 +110,50 @@ impl<'tcx> TypeBuilder<'tcx> {
110110 rty:: ParamType :: new ( index) . into ( )
111111 }
112112
113- fn resolve_model_ty ( & self , ty : mir_ty:: Ty < ' tcx > ) -> mir_ty:: Ty < ' tcx > {
113+ /// Replaces {closure} types with thrust_models::Closure<{closure}>.
114+ ///
115+ /// Ideally, we want to have `impl<F> Model for F where F: Fn` instead of this and let
116+ /// normalization do the work. However, it doesn't work:
117+ /// 1. it simply conflicts with other blanket impls
118+ /// 2. attempts to fake the impl via override_queries have failed so far
119+ fn replace_closure_model ( & self , ty : mir_ty:: Ty < ' tcx > ) -> mir_ty:: Ty < ' tcx > {
120+ let Some ( closure_model_id) = self . def_ids . closure_model ( ) else {
121+ return ty;
122+ } ;
123+
124+ struct ReplaceClosureModel < ' tcx > {
125+ tcx : mir_ty:: TyCtxt < ' tcx > ,
126+ closure_model_id : DefId ,
127+ }
128+
129+ use mir_ty:: fold:: TypeFoldable ;
130+ impl < ' tcx > mir_ty:: fold:: TypeFolder < mir_ty:: TyCtxt < ' tcx > > for ReplaceClosureModel < ' tcx > {
131+ fn interner ( & self ) -> mir_ty:: TyCtxt < ' tcx > {
132+ self . tcx
133+ }
134+
135+ fn fold_ty ( & mut self , ty : mir_ty:: Ty < ' tcx > ) -> mir_ty:: Ty < ' tcx > {
136+ use mir_ty:: fold:: TypeSuperFoldable ;
137+ if let mir_ty:: TyKind :: Closure ( _, args) = ty. kind ( ) {
138+ let args = self
139+ . tcx
140+ . mk_args ( & [ args. as_closure ( ) . tupled_upvars_ty ( ) . into ( ) ] ) ;
141+ let adt_def = self . tcx . adt_def ( self . closure_model_id ) ;
142+ return mir_ty:: Ty :: new_adt ( self . tcx , adt_def, args) ;
143+ }
144+ ty. super_fold_with ( self )
145+ }
146+ }
147+
148+ ty. fold_with ( & mut ReplaceClosureModel {
149+ tcx : self . tcx ,
150+ closure_model_id,
151+ } )
152+ }
153+
154+ fn resolve_model_ty ( & self , orig_ty : mir_ty:: Ty < ' tcx > ) -> mir_ty:: Ty < ' tcx > {
155+ let ty = self . replace_closure_model ( orig_ty) ;
156+
114157 let Some ( model_ty_def_id) = self . def_ids . model_ty ( ) else {
115158 return ty;
116159 } ;
@@ -151,6 +194,11 @@ impl<'tcx> TypeBuilder<'tcx> {
151194 return Some ( rty:: ArrayType :: new ( idx_ty, elem_ty) . into ( ) ) ;
152195 }
153196
197+ if Some ( adt. did ( ) ) == self . def_ids . closure_model ( ) {
198+ let tupled_upvars_ty = args. type_at ( 0 ) ;
199+ return Some ( self . build ( tupled_upvars_ty) ) ;
200+ }
201+
154202 None
155203 }
156204
@@ -210,7 +258,6 @@ impl<'tcx> TypeBuilder<'tcx> {
210258 unimplemented ! ( "unsupported ADT: {:?}" , ty) ;
211259 }
212260 }
213- mir_ty:: TyKind :: Closure ( _, args) => self . build ( args. as_closure ( ) . tupled_upvars_ty ( ) ) ,
214261 kind => unimplemented ! ( "unrefined_ty: {:?}" , kind) ,
215262 }
216263 }
@@ -310,6 +357,11 @@ where
310357 return Some ( rty:: ArrayType :: new ( idx_ty, elem_ty) . into ( ) ) ;
311358 }
312359
360+ if Some ( adt. did ( ) ) == self . inner . def_ids . closure_model ( ) {
361+ let tupled_upvars_ty = args. type_at ( 0 ) ;
362+ return Some ( self . build ( tupled_upvars_ty) ) ;
363+ }
364+
313365 None
314366 }
315367
@@ -361,7 +413,6 @@ where
361413 unimplemented ! ( "unsupported ADT: {:?}" , ty) ;
362414 }
363415 }
364- mir_ty:: TyKind :: Closure ( _, args) => self . build ( args. as_closure ( ) . tupled_upvars_ty ( ) ) ,
365416 kind => unimplemented ! ( "ty: {:?}" , kind) ,
366417 }
367418 }
0 commit comments