Skip to content

Commit fda3d96

Browse files
committed
Translate formula fns after type params are instantiated
1 parent f71f919 commit fda3d96

6 files changed

Lines changed: 154 additions & 16 deletions

File tree

src/analyze.rs

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ impl refine::EnumDefProvider for Rc<RefCell<EnumDefs>> {
190190

191191
pub type Env = refine::Env<Rc<RefCell<EnumDefs>>>;
192192

193+
#[derive(Debug, Clone)]
194+
struct DeferredFormulaFnDef<'tcx> {
195+
cache: Rc<RefCell<HashMap<mir_ty::GenericArgsRef<'tcx>, annot_fn::FormulaFn<'tcx>>>>,
196+
}
197+
193198
#[derive(Clone)]
194199
pub struct Analyzer<'tcx> {
195200
tcx: TyCtxt<'tcx>,
@@ -202,7 +207,7 @@ pub struct Analyzer<'tcx> {
202207
defs: HashMap<DefId, DefTy<'tcx>>,
203208

204209
/// Collection of functions with `#[thrust::formula_fn]` attribute.
205-
formula_fns: HashMap<DefId, annot_fn::FormulaFn<'tcx>>,
210+
formula_fns: HashMap<LocalDefId, DeferredFormulaFnDef<'tcx>>,
206211

207212
/// Resulting CHC system.
208213
system: Rc<RefCell<chc::System>>,
@@ -391,6 +396,30 @@ impl<'tcx> Analyzer<'tcx> {
391396
})
392397
}
393398

399+
pub fn formula_fn_with_args(
400+
&self,
401+
local_def_id: LocalDefId,
402+
generic_args: mir_ty::GenericArgsRef<'tcx>,
403+
) -> Option<annot_fn::FormulaFn<'tcx>> {
404+
let deferred_formula_fn = self.formula_fns.get(&local_def_id)?;
405+
406+
let deferred_formula_fn_cache = Rc::clone(&deferred_formula_fn.cache);
407+
if let Some(formula_fn) = deferred_formula_fn_cache.borrow().get(&generic_args) {
408+
return Some(formula_fn.clone());
409+
}
410+
411+
let translator = annot_fn::AnnotFnTranslator::new(self.tcx, local_def_id)
412+
.with_generic_args(generic_args)
413+
.with_def_id_cache(self.def_ids());
414+
let formula_fn = translator.to_formula_fn();
415+
deferred_formula_fn_cache
416+
.borrow_mut()
417+
.insert(generic_args, formula_fn.clone());
418+
419+
tracing::info!(?local_def_id, formula_fn = %formula_fn.display(), ?generic_args, "formula_fn_with_args");
420+
Some(formula_fn)
421+
}
422+
394423
pub fn def_ty_with_args(
395424
&mut self,
396425
def_id: DefId,
@@ -443,9 +472,14 @@ impl<'tcx> Analyzer<'tcx> {
443472
Some(expected)
444473
}
445474

446-
pub fn register_formula_fn(&mut self, def_id: DefId, formula_fn: annot_fn::FormulaFn<'tcx>) {
447-
tracing::info!(def_id = ?def_id, formula_fn = %formula_fn.display(), "register_formula_fn");
448-
self.formula_fns.insert(def_id, formula_fn);
475+
pub fn register_formula_fn(&mut self, local_def_id: LocalDefId) {
476+
tracing::info!(?local_def_id, "register_formula_fn");
477+
self.formula_fns.insert(
478+
local_def_id,
479+
DeferredFormulaFnDef {
480+
cache: Rc::new(RefCell::new(HashMap::new())),
481+
},
482+
);
449483
}
450484

451485
pub fn register_basic_block_ty(
@@ -585,11 +619,13 @@ impl<'tcx> Analyzer<'tcx> {
585619
None
586620
}
587621

622+
// TODO: reduce number of args
588623
fn extract_require_annot<T>(
589624
&self,
590625
local_def_id: LocalDefId,
591626
resolver: T,
592627
self_type_name: Option<String>,
628+
generic_args: mir_ty::GenericArgsRef<'tcx>,
593629
) -> Option<AnnotFormula<T::Output>>
594630
where
595631
T: Resolver<Output = rty::FunctionParamIdx>,
@@ -611,10 +647,16 @@ impl<'tcx> Analyzer<'tcx> {
611647
if let Some(formula_def_id) =
612648
self.extract_path_with_attr(local_def_id, &analyze::annot::requires_path_path())
613649
{
650+
let Some(formula_def_id) = formula_def_id.as_local() else {
651+
panic!(
652+
"require annotation with path is expected to refer to a local def, but found: {:?}",
653+
formula_def_id
654+
);
655+
};
614656
if require_annot.is_some() {
615657
unimplemented!();
616658
}
617-
let Some(formula_fn) = self.formula_fns.get(&formula_def_id) else {
659+
let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else {
618660
panic!(
619661
"require annotation {:?} is not a formula function",
620662
formula_def_id
@@ -626,11 +668,13 @@ impl<'tcx> Analyzer<'tcx> {
626668
require_annot
627669
}
628670

671+
// TODO: reduce number of args
629672
fn extract_ensure_annot<T>(
630673
&self,
631674
local_def_id: LocalDefId,
632675
resolver: T,
633676
self_type_name: Option<String>,
677+
generic_args: mir_ty::GenericArgsRef<'tcx>,
634678
) -> Option<AnnotFormula<T::Output>>
635679
where
636680
T: Resolver<Output = rty::RefinedTypeVar<rty::FunctionParamIdx>>,
@@ -653,10 +697,16 @@ impl<'tcx> Analyzer<'tcx> {
653697
if let Some(formula_def_id) =
654698
self.extract_path_with_attr(local_def_id, &analyze::annot::ensures_path_path())
655699
{
700+
let Some(formula_def_id) = formula_def_id.as_local() else {
701+
panic!(
702+
"require annotation with path is expected to refer to a local def, but found: {:?}",
703+
formula_def_id
704+
);
705+
};
656706
if ensure_annot.is_some() {
657707
unimplemented!();
658708
}
659-
let Some(formula_fn) = self.formula_fns.get(&formula_def_id) else {
709+
let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else {
660710
panic!(
661711
"ensure annotation {:?} is not a formula function",
662712
formula_def_id

src/analyze/annot_fn.rs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,30 +124,43 @@ pub struct AnnotFnTranslator<'tcx> {
124124

125125
typeck: &'tcx mir_ty::TypeckResults<'tcx>,
126126
body: &'tcx rustc_hir::Body<'tcx>,
127+
generic_args: mir_ty::GenericArgsRef<'tcx>,
127128

128129
def_ids: DefIdCache<'tcx>,
129130
env: HashMap<HirId, chc::Term<rty::FunctionParamIdx>>,
130131
}
131132

132133
impl<'tcx> AnnotFnTranslator<'tcx> {
133-
pub fn new(tcx: TyCtxt<'tcx>, def_ids: DefIdCache<'tcx>, local_def_id: LocalDefId) -> Self {
134+
pub fn new(tcx: TyCtxt<'tcx>, local_def_id: LocalDefId) -> Self {
134135
let map = tcx.hir();
135136
let body_id = map.body_owned_by(local_def_id);
136137
let body = map.body(body_id);
137-
138+
let generic_args = tcx.mk_args(&[]);
138139
let typeck = tcx.typeck(local_def_id);
140+
let def_ids = DefIdCache::new(tcx);
139141
let mut translator = Self {
140142
tcx,
141143
local_def_id,
142144
typeck,
143145
body,
146+
generic_args,
144147
def_ids,
145148
env: HashMap::default(),
146149
};
147150
translator.build_env_from_params();
148151
translator
149152
}
150153

154+
pub fn with_generic_args(mut self, generic_args: mir_ty::GenericArgsRef<'tcx>) -> Self {
155+
self.generic_args = generic_args;
156+
self
157+
}
158+
159+
pub fn with_def_id_cache(mut self, def_ids: DefIdCache<'tcx>) -> Self {
160+
self.def_ids = def_ids;
161+
self
162+
}
163+
151164
fn build_env_from_params(&mut self) {
152165
for (idx, param) in self.body.params.iter().enumerate() {
153166
let param_idx = rty::FunctionParamIdx::from(idx);
@@ -177,12 +190,19 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
177190
}
178191
}
179192

193+
fn expr_ty(&self, expr: &'tcx rustc_hir::Expr<'tcx>) -> mir_ty::Ty<'tcx> {
194+
let ty = self.typeck.expr_ty(expr);
195+
let instantiated = mir_ty::EarlyBinder::bind(ty).instantiate(self.tcx, self.generic_args);
196+
let param_env = mir_ty::ParamEnv::reveal_all();
197+
self.tcx.normalize_erasing_regions(param_env, instantiated)
198+
}
199+
180200
pub fn to_formula_fn(&self) -> FormulaFn<'tcx> {
181201
let formula = self.to_formula(self.body.value);
182202
let params = self
183203
.tcx
184204
.fn_sig(self.local_def_id.to_def_id())
185-
.instantiate_identity()
205+
.instantiate(self.tcx, self.generic_args)
186206
.skip_binder()
187207
.inputs()
188208
.to_vec();
@@ -260,7 +280,7 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
260280
FormulaOrTerm::Term(operand.neg())
261281
}
262282
rustc_hir::UnOp::Not => {
263-
let operand_ty = self.typeck.expr_ty(operand);
283+
let operand_ty = self.expr_ty(operand);
264284
match operand_ty.ty_adt_def() {
265285
Some(adt) if Some(adt.did()) == self.def_ids.mut_model() => {
266286
let operand = self.to_term(operand);
@@ -273,7 +293,7 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
273293
}
274294
}
275295
rustc_hir::UnOp::Deref => {
276-
let operand_ty = self.typeck.expr_ty(operand);
296+
let operand_ty = self.expr_ty(operand);
277297
let adt = operand_ty
278298
.ty_adt_def()
279299
.expect("deref operand must be a model type");

src/analyze/crate_.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_hir::def_id::CRATE_DEF_ID;
66
use rustc_middle::ty::{self as mir_ty, TyCtxt};
77
use rustc_span::def_id::LocalDefId;
88

9-
use crate::analyze::{self, annot_fn::AnnotFnTranslator};
9+
use crate::analyze;
1010
use crate::chc;
1111
use crate::rty::{self, ClauseBuilderExt as _};
1212

@@ -95,10 +95,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
9595
}
9696

9797
if analyzer.is_annotated_as_formula_fn() {
98-
let formula_fn =
99-
AnnotFnTranslator::new(self.tcx, self.ctx.def_ids(), local_def_id).to_formula_fn();
100-
self.ctx
101-
.register_formula_fn(local_def_id.to_def_id(), formula_fn);
98+
self.ctx.register_formula_fn(local_def_id);
10299
self.skip_analysis.insert(local_def_id);
103100
return;
104101
}

src/analyze/local_def.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ pub struct Analyzer<'tcx, 'ctx> {
4848
local_def_id: LocalDefId,
4949

5050
body: Body<'tcx>,
51+
/// to substitute HIR types during translation in [`crate::analyze::annot_fn`]
52+
generic_args: mir_ty::GenericArgsRef<'tcx>,
5153
drop_points: HashMap<BasicBlock, analyze::basic_block::DropPoints>,
5254
type_builder: TypeBuilder<'tcx>,
5355
}
@@ -310,12 +312,14 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
310312
self.local_def_id,
311313
&param_resolver,
312314
self_type_name.clone(),
315+
self.generic_args,
313316
);
314317

315318
let mut ensure_annot = self.ctx.extract_ensure_annot(
316319
self.local_def_id,
317320
&result_param_resolver,
318321
self_type_name.clone(),
322+
self.generic_args,
319323
);
320324

321325
if let Some(trait_item_id) = self.trait_item_id() {
@@ -324,11 +328,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
324328
trait_item_id,
325329
&param_resolver,
326330
self_type_name.clone(),
331+
self.generic_args,
327332
);
328333
let trait_ensure_annot = self.ctx.extract_ensure_annot(
329334
trait_item_id,
330335
&result_param_resolver,
331336
self_type_name.clone(),
337+
self.generic_args,
332338
);
333339

334340
assert!(require_annot.is_none() || trait_require_annot.is_none());
@@ -851,17 +857,20 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
851857
let body = tcx.optimized_mir(local_def_id.to_def_id()).clone();
852858
let drop_points = Default::default();
853859
let type_builder = TypeBuilder::new(tcx, ctx.def_ids(), local_def_id.to_def_id());
860+
let generic_args = tcx.mk_args(&[]);
854861
Self {
855862
ctx,
856863
tcx,
857864
local_def_id,
858865
body,
866+
generic_args,
859867
drop_points,
860868
type_builder,
861869
}
862870
}
863871

864872
pub fn generic_args(&mut self, generic_args: mir_ty::GenericArgsRef<'tcx>) -> &mut Self {
873+
self.generic_args = generic_args;
865874
self.body =
866875
mir_ty::EarlyBinder::bind(self.body.clone()).instantiate(self.tcx, generic_args);
867876
self
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//@error-in-other-file: Unsat
2+
3+
#[allow(unused_variables)]
4+
#[thrust::formula_fn]
5+
fn _thrust_requires_swap(x: thrust_models::model::Mut<i64>, y: i64) -> bool {
6+
true
7+
}
8+
9+
#[allow(unused_variables)]
10+
#[thrust::formula_fn]
11+
fn _thrust_ensures_swap(result: (), x: thrust_models::model::Mut<i64>, y: thrust_models::model::Mut<i64>) -> bool {
12+
*x == *y && *y == *x
13+
}
14+
15+
#[allow(path_statements)]
16+
fn swap<T>(x: &mut T, y: &mut T) {
17+
#[thrust::requires_path]
18+
_thrust_requires_swap;
19+
20+
#[thrust::ensures_path]
21+
_thrust_ensures_swap;
22+
23+
std::mem::swap(x, y)
24+
}
25+
26+
fn main() {
27+
let mut a = 1;
28+
let mut b = 2;
29+
swap(&mut a, &mut b);
30+
assert!(a == 2 && b == 1);
31+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//@check-pass
2+
3+
#[allow(unused_variables)]
4+
#[thrust::formula_fn]
5+
fn _thrust_requires_swap(x: thrust_models::model::Mut<i64>, y: i64) -> bool {
6+
true
7+
}
8+
9+
#[allow(unused_variables)]
10+
#[thrust::formula_fn]
11+
fn _thrust_ensures_swap(result: (), x: thrust_models::model::Mut<i64>, y: thrust_models::model::Mut<i64>) -> bool {
12+
!x == *y && !y == *x
13+
}
14+
15+
#[allow(path_statements)]
16+
fn swap<T>(x: &mut T, y: &mut T) {
17+
#[thrust::requires_path]
18+
_thrust_requires_swap;
19+
20+
#[thrust::ensures_path]
21+
_thrust_ensures_swap;
22+
23+
std::mem::swap(x, y)
24+
}
25+
26+
fn main() {
27+
let mut a = 1;
28+
let mut b = 2;
29+
swap(&mut a, &mut b);
30+
assert!(a == 2 && b == 1);
31+
}

0 commit comments

Comments
 (0)