Skip to content

Commit 8c85e03

Browse files
committed
Workaround Closure model via manually replacing type
1 parent 60ac68c commit 8c85e03

4 files changed

Lines changed: 77 additions & 3 deletions

File tree

src/analyze/annot.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ pub fn array_model_path() -> [Symbol; 3] {
8989
]
9090
}
9191

92+
pub fn closure_model_path() -> [Symbol; 3] {
93+
[
94+
Symbol::intern("thrust"),
95+
Symbol::intern("def"),
96+
Symbol::intern("closure_model"),
97+
]
98+
}
99+
92100
/// A [`annot::Resolver`] implementation for resolving function parameters.
93101
///
94102
/// The parameter names and their sorts needs to be configured via

src/analyze/did_cache.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ struct DefIds {
1818
mut_model: OnceCell<Option<DefId>>,
1919
box_model: OnceCell<Option<DefId>>,
2020
array_model: OnceCell<Option<DefId>>,
21+
closure_model: OnceCell<Option<DefId>>,
2122
}
2223

2324
/// Retrieves and caches well-known [`DefId`]s.
@@ -127,4 +128,11 @@ impl<'tcx> DefIdCache<'tcx> {
127128
.array_model
128129
.get_or_init(|| self.annotated_def(&crate::analyze::annot::array_model_path()))
129130
}
131+
132+
pub fn closure_model(&self) -> Option<DefId> {
133+
*self
134+
.def_ids
135+
.closure_model
136+
.get_or_init(|| self.annotated_def(&crate::analyze::annot::closure_model_path()))
137+
}
130138
}

src/refine/template.rs

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,50 @@ impl<'tcx> TypeBuilder<'tcx> {
110110
rty::ParamType::new(index).into()
111111
}
112112

113-
fn resolve_model_ty(&self, ty: mir_ty::Ty<'tcx>) -> mir_ty::Ty<'tcx> {
113+
/// Replaces {closure} types with thrust_models::Closure<{closure}>.
114+
///
115+
/// Ideally, we want to have `impl<F> Model for F where F: Fn` instead of this and let
116+
/// normalization do the work. However, it doesn't work:
117+
/// 1. it simply conflicts with other blanket impls
118+
/// 2. attempts to fake the impl via override_queries have failed so far
119+
fn replace_closure_model(&self, ty: mir_ty::Ty<'tcx>) -> mir_ty::Ty<'tcx> {
120+
let Some(closure_model_id) = self.def_ids.closure_model() else {
121+
return ty;
122+
};
123+
124+
struct ReplaceClosureModel<'tcx> {
125+
tcx: mir_ty::TyCtxt<'tcx>,
126+
closure_model_id: DefId,
127+
}
128+
129+
use mir_ty::fold::TypeFoldable;
130+
impl<'tcx> mir_ty::fold::TypeFolder<mir_ty::TyCtxt<'tcx>> for ReplaceClosureModel<'tcx> {
131+
fn interner(&self) -> mir_ty::TyCtxt<'tcx> {
132+
self.tcx
133+
}
134+
135+
fn fold_ty(&mut self, ty: mir_ty::Ty<'tcx>) -> mir_ty::Ty<'tcx> {
136+
use mir_ty::fold::TypeSuperFoldable;
137+
if let mir_ty::TyKind::Closure(_, args) = ty.kind() {
138+
let args = self
139+
.tcx
140+
.mk_args(&[args.as_closure().tupled_upvars_ty().into()]);
141+
let adt_def = self.tcx.adt_def(self.closure_model_id);
142+
return mir_ty::Ty::new_adt(self.tcx, adt_def, args);
143+
}
144+
ty.super_fold_with(self)
145+
}
146+
}
147+
148+
ty.fold_with(&mut ReplaceClosureModel {
149+
tcx: self.tcx,
150+
closure_model_id,
151+
})
152+
}
153+
154+
fn resolve_model_ty(&self, orig_ty: mir_ty::Ty<'tcx>) -> mir_ty::Ty<'tcx> {
155+
let ty = self.replace_closure_model(orig_ty);
156+
114157
let Some(model_ty_def_id) = self.def_ids.model_ty() else {
115158
return ty;
116159
};
@@ -151,6 +194,11 @@ impl<'tcx> TypeBuilder<'tcx> {
151194
return Some(rty::ArrayType::new(idx_ty, elem_ty).into());
152195
}
153196

197+
if Some(adt.did()) == self.def_ids.closure_model() {
198+
let tupled_upvars_ty = args.type_at(0);
199+
return Some(self.build(tupled_upvars_ty));
200+
}
201+
154202
None
155203
}
156204

@@ -210,7 +258,6 @@ impl<'tcx> TypeBuilder<'tcx> {
210258
unimplemented!("unsupported ADT: {:?}", ty);
211259
}
212260
}
213-
mir_ty::TyKind::Closure(_, args) => self.build(args.as_closure().tupled_upvars_ty()),
214261
kind => unimplemented!("unrefined_ty: {:?}", kind),
215262
}
216263
}
@@ -310,6 +357,11 @@ where
310357
return Some(rty::ArrayType::new(idx_ty, elem_ty).into());
311358
}
312359

360+
if Some(adt.did()) == self.inner.def_ids.closure_model() {
361+
let tupled_upvars_ty = args.type_at(0);
362+
return Some(self.build(tupled_upvars_ty));
363+
}
364+
313365
None
314366
}
315367

@@ -361,7 +413,6 @@ where
361413
unimplemented!("unsupported ADT: {:?}", ty);
362414
}
363415
}
364-
mir_ty::TyKind::Closure(_, args) => self.build(args.as_closure().tupled_upvars_ty()),
365416
kind => unimplemented!("ty: {:?}", kind),
366417
}
367418
}

std.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ mod thrust_models {
2121
#[thrust::def::array_model]
2222
pub struct Array<I, T>(PhantomData<I>, PhantomData<T>);
2323

24+
#[thrust::def::closure_model]
25+
pub struct Closure<T>(PhantomData<T>);
26+
2427
pub struct Vec<T>(pub Array<Int, T>, pub usize);
2528
}
2629

@@ -56,6 +59,10 @@ mod thrust_models {
5659
type Ty = ();
5760
}
5861

62+
impl<T> Model for model::Closure<T> {
63+
type Ty = model::Closure<T>;
64+
}
65+
5966
impl<T0> Model for (T0,) where T0: Model {
6067
type Ty = (<T0 as Model>::Ty,);
6168
}

0 commit comments

Comments
 (0)