Skip to content

Commit 9bce62e

Browse files
authored
Merge pull request #64 from coord-e/coord-e/formula-fns
Support enum/existentials in formula_fn
2 parents bec3465 + 889b84e commit 9bce62e

8 files changed

Lines changed: 264 additions & 14 deletions

File tree

src/analyze/annot.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,14 @@ pub fn box_model_new_path() -> [Symbol; 3] {
125125
]
126126
}
127127

128+
pub fn exists_path() -> [Symbol; 3] {
129+
[
130+
Symbol::intern("thrust"),
131+
Symbol::intern("def"),
132+
Symbol::intern("exists"),
133+
]
134+
}
135+
128136
/// A [`annot::Resolver`] implementation for resolving function parameters.
129137
///
130138
/// The parameter names and their sorts needs to be configured via

src/analyze/annot_fn.rs

Lines changed: 98 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use rustc_middle::ty::{self as mir_ty, TyCtxt};
88
use crate::analyze::did_cache::DefIdCache;
99
use crate::annot::AnnotFormula;
1010
use crate::chc;
11+
use crate::refine::TypeBuilder;
1112
use crate::rty;
1213

1314
#[derive(Debug, Clone)]
@@ -118,6 +119,7 @@ impl<T> FormulaOrTerm<T> {
118119
}
119120
}
120121

122+
#[derive(Clone)]
121123
pub struct AnnotFnTranslator<'tcx> {
122124
tcx: TyCtxt<'tcx>,
123125
local_def_id: LocalDefId,
@@ -127,6 +129,7 @@ pub struct AnnotFnTranslator<'tcx> {
127129
generic_args: mir_ty::GenericArgsRef<'tcx>,
128130

129131
def_ids: DefIdCache<'tcx>,
132+
type_builder: TypeBuilder<'tcx>,
130133
env: HashMap<HirId, chc::Term<rty::FunctionParamIdx>>,
131134
}
132135

@@ -138,13 +141,15 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
138141
let generic_args = tcx.mk_args(&[]);
139142
let typeck = tcx.typeck(local_def_id);
140143
let def_ids = DefIdCache::new(tcx);
144+
let type_builder = TypeBuilder::new(tcx, def_ids.clone(), local_def_id.to_def_id());
141145
let mut translator = Self {
142146
tcx,
143147
local_def_id,
144148
typeck,
145149
body,
146150
generic_args,
147151
def_ids,
152+
type_builder,
148153
env: HashMap::default(),
149154
};
150155
translator.build_env_from_params();
@@ -158,6 +163,11 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
158163

159164
pub fn with_def_id_cache(mut self, def_ids: DefIdCache<'tcx>) -> Self {
160165
self.def_ids = def_ids;
166+
self.type_builder = TypeBuilder::new(
167+
self.tcx,
168+
self.def_ids.clone(),
169+
self.local_def_id.to_def_id(),
170+
);
161171
self
162172
}
163173

@@ -197,6 +207,13 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
197207
self.tcx.normalize_erasing_regions(param_env, instantiated)
198208
}
199209

210+
fn pat_ty(&self, pat: &'tcx rustc_hir::Pat<'tcx>) -> mir_ty::Ty<'tcx> {
211+
let ty = self.typeck.pat_ty(pat);
212+
let instantiated = mir_ty::EarlyBinder::bind(ty).instantiate(self.tcx, self.generic_args);
213+
let param_env = mir_ty::ParamEnv::reveal_all();
214+
self.tcx.normalize_erasing_regions(param_env, instantiated)
215+
}
216+
200217
pub fn to_formula_fn(&self) -> FormulaFn<'tcx> {
201218
let formula = self.to_formula(self.body.value);
202219
let params = self
@@ -224,6 +241,28 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
224241
.expect("expected a term")
225242
}
226243

244+
fn variant_ctor_term(
245+
&self,
246+
ctor_did: rustc_span::def_id::DefId,
247+
result_ty: mir_ty::Ty<'tcx>,
248+
field_terms: Vec<chc::Term<rty::FunctionParamIdx>>,
249+
) -> chc::Term<rty::FunctionParamIdx> {
250+
let variant_did = self.tcx.parent(ctor_did);
251+
let adt_did = self.tcx.parent(variant_did);
252+
let d_sym = crate::refine::datatype_symbol(self.tcx, adt_did);
253+
let variant_name = self.tcx.item_name(variant_did);
254+
let v_sym = chc::DatatypeSymbol::new(format!("{}.{}", d_sym, variant_name));
255+
let sort_args = if let mir_ty::TyKind::Adt(_, generic_args) = result_ty.kind() {
256+
generic_args
257+
.types()
258+
.map(|ty| self.type_builder.build(ty).to_sort())
259+
.collect()
260+
} else {
261+
panic!("expected an ADT type for variant constructor")
262+
};
263+
chc::Term::datatype_ctor(d_sym, sort_args, v_sym, field_terms)
264+
}
265+
227266
fn to_formula_or_term(
228267
&self,
229268
hir: &'tcx rustc_hir::Expr<'tcx>,
@@ -319,20 +358,21 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
319358
rustc_ast::LitKind::Bool(b) => FormulaOrTerm::Literal(b),
320359
_ => unimplemented!("unsupported literal in formula: {:?}", lit),
321360
},
322-
ExprKind::Path(qpath) => {
323-
if let rustc_hir::def::Res::Local(hir_id) =
324-
self.typeck.qpath_res(&qpath, hir.hir_id)
325-
{
326-
FormulaOrTerm::Term(
327-
self.env
328-
.get(&hir_id)
329-
.expect("unbound variable in formula")
330-
.clone(),
331-
)
332-
} else {
333-
unimplemented!("unsupported path in formula: {:?}", qpath);
361+
ExprKind::Path(qpath) => match self.typeck.qpath_res(&qpath, hir.hir_id) {
362+
rustc_hir::def::Res::Local(hir_id) => FormulaOrTerm::Term(
363+
self.env
364+
.get(&hir_id)
365+
.expect("unbound variable in formula")
366+
.clone(),
367+
),
368+
rustc_hir::def::Res::Def(
369+
rustc_hir::def::DefKind::Ctor(rustc_hir::def::CtorOf::Variant, _),
370+
ctor_did,
371+
) => {
372+
FormulaOrTerm::Term(self.variant_ctor_term(ctor_did, self.expr_ty(hir), vec![]))
334373
}
335-
}
374+
_ => unimplemented!("unsupported path in formula: {:?}", qpath),
375+
},
336376
ExprKind::Tup(exprs) => {
337377
let terms = exprs.iter().map(|e| self.to_term(e)).collect();
338378
FormulaOrTerm::Term(chc::Term::tuple(terms))
@@ -349,7 +389,40 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
349389
ExprKind::Call(func_expr, args) => {
350390
if let ExprKind::Path(qpath) = &func_expr.kind {
351391
let res = self.typeck.qpath_res(qpath, func_expr.hir_id);
352-
if let rustc_hir::def::Res::Def(_, def_id) = res {
392+
if let rustc_hir::def::Res::Def(def_kind, def_id) = res {
393+
if Some(def_id) == self.def_ids.exists() {
394+
assert_eq!(args.len(), 1, "exists takes exactly 1 argument");
395+
let ExprKind::Closure(closure) = args[0].kind else {
396+
panic!("exists argument must be a closure");
397+
};
398+
let closure_body = self.tcx.hir().body(closure.body);
399+
400+
let mut inner_translator = self.clone();
401+
let mut vars = Vec::new();
402+
for param in closure_body.params {
403+
let rustc_hir::PatKind::Binding(_, hir_id, ident, None) =
404+
param.pat.kind
405+
else {
406+
panic!(
407+
"exists closure parameter must be a simple binding: {:?}",
408+
param.pat
409+
);
410+
};
411+
let param_ty = self.pat_ty(param.pat);
412+
let sort = self.type_builder.build(param_ty).to_sort();
413+
let var_term = chc::Term::FormulaExistentialVar(
414+
sort.clone(),
415+
ident.name.to_string(),
416+
);
417+
inner_translator.env.insert(hir_id, var_term);
418+
vars.push((ident.name.to_string(), sort));
419+
}
420+
let body_formula = inner_translator.to_formula(closure_body.value);
421+
return FormulaOrTerm::Formula(chc::Formula::exists(
422+
vars,
423+
body_formula,
424+
));
425+
}
353426
if Some(def_id) == self.def_ids.mut_model_new() {
354427
assert_eq!(args.len(), 2, "Mut::new takes exactly 2 arguments");
355428
let t1 = self.to_term(&args[0]);
@@ -361,6 +434,17 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
361434
let t = self.to_term(&args[0]);
362435
return FormulaOrTerm::Term(chc::Term::box_(t));
363436
}
437+
if matches!(
438+
def_kind,
439+
rustc_hir::def::DefKind::Ctor(rustc_hir::def::CtorOf::Variant, _)
440+
) {
441+
let field_terms = args.iter().map(|arg| self.to_term(arg)).collect();
442+
return FormulaOrTerm::Term(self.variant_ctor_term(
443+
def_id,
444+
self.expr_ty(hir),
445+
field_terms,
446+
));
447+
}
364448
}
365449
}
366450
unimplemented!("unsupported call in formula: {:?}", func_expr)

src/analyze/did_cache.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ struct DefIds {
2222

2323
mut_model_new: OnceCell<Option<DefId>>,
2424
box_model_new: OnceCell<Option<DefId>>,
25+
26+
exists: OnceCell<Option<DefId>>,
2527
}
2628

2729
/// Retrieves and caches well-known [`DefId`]s.
@@ -160,4 +162,11 @@ impl<'tcx> DefIdCache<'tcx> {
160162
.box_model_new
161163
.get_or_init(|| self.annotated_def(&crate::analyze::annot::box_model_new_path()))
162164
}
165+
166+
pub fn exists(&self) -> Option<DefId> {
167+
*self
168+
.def_ids
169+
.exists
170+
.get_or_init(|| self.annotated_def(&crate::analyze::annot::exists_path()))
171+
}
163172
}

std.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,13 @@ mod thrust_models {
234234
impl<T, E> Model for Result<T, E> where T: Model, E: Model {
235235
type Ty = Result<<T as Model>::Ty, <E as Model>::Ty>;
236236
}
237+
238+
#[allow(dead_code)]
239+
#[thrust::def::exists]
240+
#[thrust::ignored]
241+
pub fn exists<T>(_x: T) -> bool {
242+
unimplemented!()
243+
}
237244
}
238245

239246
#[thrust::extern_spec_fn]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//@error-in-other-file: Unsat
2+
3+
#[derive(PartialEq)]
4+
pub enum X {
5+
A(i64),
6+
B(bool),
7+
}
8+
9+
impl thrust_models::Model for X {
10+
type Ty = X;
11+
}
12+
13+
#[thrust::formula_fn]
14+
fn _thrust_requires_test(x: X) -> bool {
15+
x == X::A(1)
16+
}
17+
18+
#[thrust::formula_fn]
19+
fn _thrust_ensures_test(_result: (), _x: X) -> bool {
20+
true
21+
}
22+
23+
#[allow(path_statements)]
24+
fn test(x: X) {
25+
#[thrust::requires_path]
26+
_thrust_requires_test;
27+
28+
#[thrust::ensures_path]
29+
_thrust_ensures_test;
30+
31+
if let X::A(i) = x {
32+
assert!(i == 2);
33+
} else {
34+
loop {}
35+
}
36+
}
37+
38+
fn main() {
39+
test(X::A(1));
40+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper
4+
5+
#[thrust::trusted]
6+
#[thrust::callable]
7+
fn rand() -> i32 { unimplemented!() }
8+
9+
#[thrust::formula_fn]
10+
fn _thrust_requires_f() -> bool {
11+
true
12+
}
13+
14+
#[thrust::formula_fn]
15+
fn _thrust_ensures_f(result: i32) -> bool {
16+
thrust_models::exists(|x: i32| result == 2 * x)
17+
}
18+
19+
#[allow(path_statements)]
20+
fn f() -> i32 {
21+
#[thrust::requires_path]
22+
_thrust_requires_f;
23+
24+
#[thrust::ensures_path]
25+
_thrust_ensures_f;
26+
27+
let x = rand();
28+
x + x + x
29+
}
30+
31+
fn main() {}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//@check-pass
2+
3+
#[derive(PartialEq)]
4+
pub enum X {
5+
A(i64),
6+
B(bool),
7+
}
8+
9+
impl thrust_models::Model for X {
10+
type Ty = X;
11+
}
12+
13+
#[thrust::formula_fn]
14+
fn _thrust_requires_test(x: X) -> bool {
15+
x == X::A(1)
16+
}
17+
18+
#[thrust::formula_fn]
19+
fn _thrust_ensures_test(_result: (), _x: X) -> bool {
20+
true
21+
}
22+
23+
#[allow(path_statements)]
24+
fn test(x: X) {
25+
#[thrust::requires_path]
26+
_thrust_requires_test;
27+
28+
#[thrust::ensures_path]
29+
_thrust_ensures_test;
30+
31+
if let X::A(i) = x {
32+
assert!(i == 1);
33+
} else {
34+
loop {}
35+
}
36+
}
37+
38+
fn main() {
39+
test(X::A(1));
40+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//@check-pass
2+
//@compile-flags: -C debug-assertions=off
3+
//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper
4+
5+
#[thrust::trusted]
6+
#[thrust::callable]
7+
fn rand() -> i32 { unimplemented!() }
8+
9+
#[thrust::formula_fn]
10+
fn _thrust_requires_f() -> bool {
11+
true
12+
}
13+
14+
#[thrust::formula_fn]
15+
fn _thrust_ensures_f(result: i32) -> bool {
16+
thrust_models::exists(|x: i32| result == 2 * x)
17+
}
18+
19+
#[allow(path_statements)]
20+
fn f() -> i32 {
21+
#[thrust::requires_path]
22+
_thrust_requires_f;
23+
24+
#[thrust::ensures_path]
25+
_thrust_ensures_f;
26+
27+
let x = rand();
28+
x + x
29+
}
30+
31+
fn main() {}

0 commit comments

Comments
 (0)