Skip to content

Commit 2d78295

Browse files
committed
Enable to annotate functions with formula_fn
1 parent 183f13c commit 2d78295

4 files changed

Lines changed: 229 additions & 15 deletions

File tree

src/analyze.rs

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -472,20 +472,71 @@ impl<'tcx> Analyzer<'tcx> {
472472
self.fn_sig_with_body(def_id, body)
473473
}
474474

475+
fn extract_path_with_attr(
476+
&self,
477+
local_def_id: LocalDefId,
478+
attr_path: &[Symbol],
479+
) -> Option<DefId> {
480+
let map = self.tcx.hir();
481+
let Some(body_id) = map.maybe_body_owned_by(local_def_id) else {
482+
return None;
483+
};
484+
let body = map.body(body_id);
485+
486+
let rustc_hir::ExprKind::Block(block, _) = body.value.kind else {
487+
return None;
488+
};
489+
for stmt in block.stmts {
490+
if map
491+
.attrs(stmt.hir_id)
492+
.iter()
493+
.all(|attr| !attr.path_matches(attr_path))
494+
{
495+
continue;
496+
}
497+
let rustc_hir::StmtKind::Semi(expr) = stmt.kind else {
498+
panic!(
499+
"annotated path is expected to be a semi statement, but found: {:?}",
500+
stmt.kind
501+
);
502+
};
503+
let rustc_hir::ExprKind::Path(qpath) = expr.kind else {
504+
panic!(
505+
"annotated path is expected to be a path expression, but found: {:?}",
506+
expr.kind
507+
);
508+
};
509+
let rustc_hir::QPath::Resolved(_, path) = qpath else {
510+
panic!(
511+
"qpath is unexpectedly not resolved in annotated path: {:?}",
512+
qpath
513+
);
514+
};
515+
let rustc_hir::def::Res::Def(_, def_id) = path.res else {
516+
panic!(
517+
"annotated path is expected to refer to a def, but found: {:?}",
518+
path.res
519+
);
520+
};
521+
return Some(def_id);
522+
}
523+
None
524+
}
525+
475526
fn extract_require_annot<T>(
476527
&self,
477-
def_id: DefId,
528+
local_def_id: LocalDefId,
478529
resolver: T,
479530
self_type_name: Option<String>,
480531
) -> Option<AnnotFormula<T::Output>>
481532
where
482-
T: Resolver,
533+
T: Resolver<Output = rty::FunctionParamIdx>,
483534
{
484535
let mut require_annot = None;
485536
let parser = AnnotParser::new(&resolver, self_type_name);
486537
for attrs in self
487538
.tcx
488-
.get_attrs_by_path(def_id, &analyze::annot::requires_path())
539+
.get_attrs_by_path(local_def_id.to_def_id(), &analyze::annot::requires_path())
489540
{
490541
if require_annot.is_some() {
491542
unimplemented!();
@@ -494,23 +545,47 @@ impl<'tcx> Analyzer<'tcx> {
494545
let require = parser.parse_formula(ts).unwrap();
495546
require_annot = Some(require);
496547
}
548+
549+
if let Some(formula_def_id) =
550+
self.extract_path_with_attr(local_def_id, &analyze::annot::requires_path())
551+
{
552+
if require_annot.is_some() {
553+
unimplemented!();
554+
}
555+
let Some(formula_fn) = self.formula_fns.get(&formula_def_id) else {
556+
panic!(
557+
"require annotation {:?} is not a formula function",
558+
formula_def_id
559+
);
560+
};
561+
let Some(annot) = formula_fn.to_require_annot(self.fn_sig(local_def_id.to_def_id()))
562+
else {
563+
panic!(
564+
"require annotation {:?} has incompatible signature with {:?}",
565+
formula_def_id, local_def_id
566+
);
567+
};
568+
require_annot = Some(annot);
569+
}
570+
497571
require_annot
498572
}
499573

500574
fn extract_ensure_annot<T>(
501575
&self,
502-
def_id: DefId,
576+
local_def_id: LocalDefId,
503577
resolver: T,
504578
self_type_name: Option<String>,
505579
) -> Option<AnnotFormula<T::Output>>
506580
where
507-
T: Resolver,
581+
T: Resolver<Output = rty::RefinedTypeVar<rty::FunctionParamIdx>>,
508582
{
509583
let mut ensure_annot = None;
584+
510585
let parser = AnnotParser::new(&resolver, self_type_name);
511586
for attrs in self
512587
.tcx
513-
.get_attrs_by_path(def_id, &analyze::annot::ensures_path())
588+
.get_attrs_by_path(local_def_id.to_def_id(), &analyze::annot::ensures_path())
514589
{
515590
if ensure_annot.is_some() {
516591
unimplemented!();
@@ -519,6 +594,29 @@ impl<'tcx> Analyzer<'tcx> {
519594
let ensure = parser.parse_formula(ts).unwrap();
520595
ensure_annot = Some(ensure);
521596
}
597+
598+
if let Some(formula_def_id) =
599+
self.extract_path_with_attr(local_def_id, &analyze::annot::ensures_path())
600+
{
601+
if ensure_annot.is_some() {
602+
unimplemented!();
603+
}
604+
let Some(formula_fn) = self.formula_fns.get(&formula_def_id) else {
605+
panic!(
606+
"ensure annotation {:?} is not a formula function",
607+
formula_def_id
608+
);
609+
};
610+
let Some(annot) = formula_fn.to_ensure_annot(self.fn_sig(local_def_id.to_def_id()))
611+
else {
612+
panic!(
613+
"ensure annotation {:?} has incompatible signature with {:?}",
614+
formula_def_id, local_def_id
615+
);
616+
};
617+
ensure_annot = Some(annot);
618+
}
619+
522620
ensure_annot
523621
}
524622

src/analyze/local_def.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -214,22 +214,30 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
214214

215215
// TODO: unify this logic with extraction functions above
216216
pub fn is_fully_annotated(&self) -> bool {
217-
let has_require = self
217+
let has_requires = self
218218
.tcx
219219
.get_attrs_by_path(
220220
self.local_def_id.to_def_id(),
221221
&analyze::annot::requires_path(),
222222
)
223223
.next()
224-
.is_some();
225-
let has_ensure = self
224+
.is_some()
225+
|| self
226+
.ctx
227+
.extract_path_with_attr(self.local_def_id, &analyze::annot::requires_path())
228+
.is_some();
229+
let has_ensures = self
226230
.tcx
227231
.get_attrs_by_path(
228232
self.local_def_id.to_def_id(),
229233
&analyze::annot::ensures_path(),
230234
)
231235
.next()
232-
.is_some();
236+
.is_some()
237+
|| self
238+
.ctx
239+
.extract_path_with_attr(self.local_def_id, &analyze::annot::ensures_path())
240+
.is_some();
233241
let annotated_params: Vec<_> = self
234242
.tcx
235243
.get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::param_path())
@@ -250,7 +258,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
250258
.iter()
251259
.all(|ident| annotated_params.contains(ident));
252260
self.is_annotated_as_callable()
253-
|| (has_require && has_ensure)
261+
|| (has_requires && has_ensures)
254262
|| (all_params_annotated && has_ret)
255263
}
256264

@@ -289,26 +297,26 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
289297
let self_type_name = self.impl_type().map(|ty| ty.to_string());
290298

291299
let mut require_annot = self.ctx.extract_require_annot(
292-
self.local_def_id.to_def_id(),
300+
self.local_def_id,
293301
&param_resolver,
294302
self_type_name.clone(),
295303
);
296304

297305
let mut ensure_annot = self.ctx.extract_ensure_annot(
298-
self.local_def_id.to_def_id(),
306+
self.local_def_id,
299307
&result_param_resolver,
300308
self_type_name.clone(),
301309
);
302310

303311
if let Some(trait_item_id) = self.trait_item_id() {
304312
tracing::info!("trait item fonud: {:?}", trait_item_id);
305313
let trait_require_annot = self.ctx.extract_require_annot(
306-
trait_item_id.into(),
314+
trait_item_id,
307315
&param_resolver,
308316
self_type_name.clone(),
309317
);
310318
let trait_ensure_annot = self.ctx.extract_ensure_annot(
311-
trait_item_id.into(),
319+
trait_item_id,
312320
&result_param_resolver,
313321
self_type_name.clone(),
314322
);

tests/ui/fail/annot_formula_fn.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//@error-in-other-file: Unsat
2+
3+
#[thrust::formula_fn]
4+
fn _thrust_requires_rand_except(x: i64) -> bool {
5+
true
6+
}
7+
8+
#[thrust::formula_fn]
9+
fn _thrust_ensures_rand_except(result: i64, x: i64) -> bool {
10+
result != x
11+
}
12+
13+
fn rand_except(x: i64) -> i64 {
14+
#[thrust::requires]
15+
_thrust_requires_rand_except;
16+
#[thrust::ensures]
17+
_thrust_ensures_rand_except;
18+
19+
if x == 0 {
20+
1
21+
} else {
22+
0
23+
}
24+
}
25+
26+
#[thrust::formula_fn]
27+
fn _thrust_requires_f(x: i64) -> bool {
28+
true
29+
}
30+
31+
#[thrust::formula_fn]
32+
fn _thrust_ensures_f(result: i64, x: i64) -> bool {
33+
(result == 1) || (result == -1) && result == 0
34+
}
35+
36+
fn f(x: i64) -> i64 {
37+
#[thrust::requires]
38+
_thrust_requires_f;
39+
#[thrust::ensures]
40+
_thrust_ensures_f;
41+
42+
let y = rand_except(x);
43+
if y > x {
44+
1
45+
} else if y < x {
46+
-1
47+
} else {
48+
0
49+
}
50+
}
51+
52+
fn main() {
53+
assert!(rand_except(1) == 0);
54+
}

tests/ui/pass/annot_formula_fn.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//@check-pass
2+
3+
#[thrust::formula_fn]
4+
fn _thrust_requires_rand_except(x: i64) -> bool {
5+
true
6+
}
7+
8+
#[thrust::formula_fn]
9+
fn _thrust_ensures_rand_except(result: i64, x: i64) -> bool {
10+
result != x
11+
}
12+
13+
fn rand_except(x: i64) -> i64 {
14+
#[thrust::requires]
15+
_thrust_requires_rand_except;
16+
#[thrust::ensures]
17+
_thrust_ensures_rand_except;
18+
19+
if x == 0 {
20+
1
21+
} else {
22+
0
23+
}
24+
}
25+
26+
#[thrust::formula_fn]
27+
fn _thrust_requires_f(x: i64) -> bool {
28+
true
29+
}
30+
31+
#[thrust::formula_fn]
32+
fn _thrust_ensures_f(result: i64, x: i64) -> bool {
33+
(result == 1) || (result == -1)
34+
}
35+
36+
fn f(x: i64) -> i64 {
37+
#[thrust::requires]
38+
_thrust_requires_f;
39+
#[thrust::ensures]
40+
_thrust_ensures_f;
41+
42+
let y = rand_except(x);
43+
if y > x {
44+
1
45+
} else if y < x {
46+
-1
47+
} else {
48+
0
49+
}
50+
}
51+
52+
fn main() {
53+
assert!(rand_except(1) != 1);
54+
}

0 commit comments

Comments
 (0)