Skip to content

Commit 362cec2

Browse files
committed
Support enum construction in formula_fn
1 parent bec3465 commit 362cec2

3 files changed

Lines changed: 132 additions & 15 deletions

File tree

src/analyze/annot_fn.rs

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use rustc_middle::ty::{self as mir_ty, TyCtxt};
88
use crate::analyze::did_cache::DefIdCache;
99
use crate::annot::AnnotFormula;
1010
use crate::chc;
11+
use crate::refine::TypeBuilder;
1112
use crate::rty;
1213

1314
#[derive(Debug, Clone)]
@@ -127,6 +128,7 @@ pub struct AnnotFnTranslator<'tcx> {
127128
generic_args: mir_ty::GenericArgsRef<'tcx>,
128129

129130
def_ids: DefIdCache<'tcx>,
131+
type_builder: TypeBuilder<'tcx>,
130132
env: HashMap<HirId, chc::Term<rty::FunctionParamIdx>>,
131133
}
132134

@@ -138,13 +140,15 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
138140
let generic_args = tcx.mk_args(&[]);
139141
let typeck = tcx.typeck(local_def_id);
140142
let def_ids = DefIdCache::new(tcx);
143+
let type_builder = TypeBuilder::new(tcx, def_ids.clone(), local_def_id.to_def_id());
141144
let mut translator = Self {
142145
tcx,
143146
local_def_id,
144147
typeck,
145148
body,
146149
generic_args,
147150
def_ids,
151+
type_builder,
148152
env: HashMap::default(),
149153
};
150154
translator.build_env_from_params();
@@ -224,6 +228,29 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
224228
.expect("expected a term")
225229
}
226230

231+
fn variant_ctor_term(
232+
&self,
233+
ctor_did: rustc_span::def_id::DefId,
234+
result_expr: &'tcx rustc_hir::Expr<'tcx>,
235+
field_terms: Vec<chc::Term<rty::FunctionParamIdx>>,
236+
) -> chc::Term<rty::FunctionParamIdx> {
237+
let variant_did = self.tcx.parent(ctor_did);
238+
let adt_did = self.tcx.parent(variant_did);
239+
let d_sym = crate::refine::datatype_symbol(self.tcx, adt_did);
240+
let variant_name = self.tcx.item_name(variant_did);
241+
let v_sym = chc::DatatypeSymbol::new(format!("{}.{}", d_sym, variant_name));
242+
let result_ty = self.expr_ty(result_expr);
243+
let sort_args = if let mir_ty::TyKind::Adt(_, generic_args) = result_ty.kind() {
244+
generic_args
245+
.types()
246+
.map(|ty| self.type_builder.build(ty).to_sort())
247+
.collect()
248+
} else {
249+
vec![]
250+
};
251+
chc::Term::datatype_ctor(d_sym, sort_args, v_sym, field_terms)
252+
}
253+
227254
fn to_formula_or_term(
228255
&self,
229256
hir: &'tcx rustc_hir::Expr<'tcx>,
@@ -319,20 +346,19 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
319346
rustc_ast::LitKind::Bool(b) => FormulaOrTerm::Literal(b),
320347
_ => unimplemented!("unsupported literal in formula: {:?}", lit),
321348
},
322-
ExprKind::Path(qpath) => {
323-
if let rustc_hir::def::Res::Local(hir_id) =
324-
self.typeck.qpath_res(&qpath, hir.hir_id)
325-
{
326-
FormulaOrTerm::Term(
327-
self.env
328-
.get(&hir_id)
329-
.expect("unbound variable in formula")
330-
.clone(),
331-
)
332-
} else {
333-
unimplemented!("unsupported path in formula: {:?}", qpath);
334-
}
335-
}
349+
ExprKind::Path(qpath) => match self.typeck.qpath_res(&qpath, hir.hir_id) {
350+
rustc_hir::def::Res::Local(hir_id) => FormulaOrTerm::Term(
351+
self.env
352+
.get(&hir_id)
353+
.expect("unbound variable in formula")
354+
.clone(),
355+
),
356+
rustc_hir::def::Res::Def(
357+
rustc_hir::def::DefKind::Ctor(rustc_hir::def::CtorOf::Variant, _),
358+
ctor_did,
359+
) => FormulaOrTerm::Term(self.variant_ctor_term(ctor_did, hir, vec![])),
360+
_ => unimplemented!("unsupported path in formula: {:?}", qpath),
361+
},
336362
ExprKind::Tup(exprs) => {
337363
let terms = exprs.iter().map(|e| self.to_term(e)).collect();
338364
FormulaOrTerm::Term(chc::Term::tuple(terms))
@@ -349,7 +375,7 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
349375
ExprKind::Call(func_expr, args) => {
350376
if let ExprKind::Path(qpath) = &func_expr.kind {
351377
let res = self.typeck.qpath_res(qpath, func_expr.hir_id);
352-
if let rustc_hir::def::Res::Def(_, def_id) = res {
378+
if let rustc_hir::def::Res::Def(def_kind, def_id) = res {
353379
if Some(def_id) == self.def_ids.mut_model_new() {
354380
assert_eq!(args.len(), 2, "Mut::new takes exactly 2 arguments");
355381
let t1 = self.to_term(&args[0]);
@@ -361,6 +387,17 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
361387
let t = self.to_term(&args[0]);
362388
return FormulaOrTerm::Term(chc::Term::box_(t));
363389
}
390+
if matches!(
391+
def_kind,
392+
rustc_hir::def::DefKind::Ctor(rustc_hir::def::CtorOf::Variant, _)
393+
) {
394+
let field_terms = args.iter().map(|arg| self.to_term(arg)).collect();
395+
return FormulaOrTerm::Term(self.variant_ctor_term(
396+
def_id,
397+
hir,
398+
field_terms,
399+
));
400+
}
364401
}
365402
}
366403
unimplemented!("unsupported call in formula: {:?}", func_expr)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//@error-in-other-file: Unsat
2+
3+
#[derive(PartialEq)]
4+
pub enum X {
5+
A(i64),
6+
B(bool),
7+
}
8+
9+
impl thrust_models::Model for X {
10+
type Ty = X;
11+
}
12+
13+
#[thrust::formula_fn]
14+
fn _thrust_requires_test(x: X) -> bool {
15+
x == X::A(1)
16+
}
17+
18+
#[thrust::formula_fn]
19+
fn _thrust_ensures_test(_result: (), _x: X) -> bool {
20+
true
21+
}
22+
23+
#[allow(path_statements)]
24+
fn test(x: X) {
25+
#[thrust::requires_path]
26+
_thrust_requires_test;
27+
28+
#[thrust::ensures_path]
29+
_thrust_ensures_test;
30+
31+
if let X::A(i) = x {
32+
assert!(i == 2);
33+
} else {
34+
loop {}
35+
}
36+
}
37+
38+
fn main() {
39+
test(X::A(1));
40+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//@check-pass
2+
3+
#[derive(PartialEq)]
4+
pub enum X {
5+
A(i64),
6+
B(bool),
7+
}
8+
9+
impl thrust_models::Model for X {
10+
type Ty = X;
11+
}
12+
13+
#[thrust::formula_fn]
14+
fn _thrust_requires_test(x: X) -> bool {
15+
x == X::A(1)
16+
}
17+
18+
#[thrust::formula_fn]
19+
fn _thrust_ensures_test(_result: (), _x: X) -> bool {
20+
true
21+
}
22+
23+
#[allow(path_statements)]
24+
fn test(x: X) {
25+
#[thrust::requires_path]
26+
_thrust_requires_test;
27+
28+
#[thrust::ensures_path]
29+
_thrust_ensures_test;
30+
31+
if let X::A(i) = x {
32+
assert!(i == 1);
33+
} else {
34+
loop {}
35+
}
36+
}
37+
38+
fn main() {
39+
test(X::A(1));
40+
}

0 commit comments

Comments
 (0)