@@ -49,60 +49,68 @@ PartialValidity IncompleteSolver::computeValidity(const Query &query) {
4949/* **/
5050
5151StagedSolverImpl::StagedSolverImpl (std::unique_ptr<IncompleteSolver> primary,
52- std::unique_ptr<Solver> secondary)
53- : primary(std::move(primary)), secondary(std::move(secondary)) {}
52+ std::unique_ptr<Solver> secondary,
53+ QueryPredicate predicate)
54+ : primary(std::move(primary)), secondary(std::move(secondary)),
55+ predicate(predicate) {}
5456
5557bool StagedSolverImpl::computeTruth (const Query &query, bool &isValid) {
56- PartialValidity trueResult = primary->computeTruth (query);
58+ if (predicate (query)) {
59+ PartialValidity trueResult = primary->computeTruth (query);
5760
58- if (trueResult != PValidity::None) {
59- isValid = (trueResult == PValidity::MustBeTrue);
60- return true ;
61+ if (trueResult != PValidity::None) {
62+ isValid = (trueResult == PValidity::MustBeTrue);
63+ return true ;
64+ }
6165 }
6266
6367 return secondary->impl ->computeTruth (query, isValid);
6468}
6569
6670bool StagedSolverImpl::computeValidity (const Query &query,
6771 PartialValidity &result) {
68- bool tmp;
69-
70- switch (primary->computeValidity (query)) {
71- case PValidity::MustBeTrue:
72- result = PValidity::MustBeTrue;
73- break ;
74- case PValidity::MustBeFalse:
75- result = PValidity::MustBeFalse;
76- break ;
77- case PValidity::TrueOrFalse:
78- result = PValidity::TrueOrFalse;
79- break ;
80- case PValidity::MayBeTrue:
81- if (secondary->impl ->computeTruth (query, tmp)) {
82-
83- result = tmp ? PValidity::MustBeTrue : PValidity::TrueOrFalse;
84- } else {
85- result = PValidity::MayBeTrue;
86- }
87- break ;
88- case PValidity::MayBeFalse:
89- if (secondary->impl ->computeTruth (query.negateExpr (), tmp)) {
90- result = tmp ? PValidity::MustBeFalse : PValidity::TrueOrFalse;
91- } else {
92- result = PValidity::MayBeFalse;
72+ if (predicate (query)) {
73+ bool tmp;
74+
75+ switch (primary->computeValidity (query)) {
76+ case PValidity::MustBeTrue:
77+ result = PValidity::MustBeTrue;
78+ break ;
79+ case PValidity::MustBeFalse:
80+ result = PValidity::MustBeFalse;
81+ break ;
82+ case PValidity::TrueOrFalse:
83+ result = PValidity::TrueOrFalse;
84+ break ;
85+ case PValidity::MayBeTrue:
86+ if (secondary->impl ->computeTruth (query, tmp)) {
87+
88+ result = tmp ? PValidity::MustBeTrue : PValidity::TrueOrFalse;
89+ } else {
90+ result = PValidity::MayBeTrue;
91+ }
92+ break ;
93+ case PValidity::MayBeFalse:
94+ if (secondary->impl ->computeTruth (query.negateExpr (), tmp)) {
95+ result = tmp ? PValidity::MustBeFalse : PValidity::TrueOrFalse;
96+ } else {
97+ result = PValidity::MayBeFalse;
98+ }
99+ break ;
100+ default :
101+ if (!secondary->impl ->computeValidity (query, result))
102+ return false ;
103+ break ;
93104 }
94- break ;
95- default :
96- if (!secondary->impl ->computeValidity (query, result))
97- return false ;
98- break ;
105+ } else {
106+ return secondary->impl ->computeValidity (query, result);
99107 }
100108
101109 return true ;
102110}
103111
104112bool StagedSolverImpl::computeValue (const Query &query, ref<Expr> &result) {
105- if (primary->computeValue (query, result))
113+ if (predicate (query) && primary->computeValue (query, result))
106114 return true ;
107115
108116 return secondary->impl ->computeValue (query, result);
@@ -111,25 +119,28 @@ bool StagedSolverImpl::computeValue(const Query &query, ref<Expr> &result) {
111119bool StagedSolverImpl::computeInitialValues (
112120 const Query &query, const std::vector<const Array *> &objects,
113121 std::vector<SparseStorage<unsigned char >> &values, bool &hasSolution) {
114- if (primary->computeInitialValues (query, objects, values, hasSolution))
122+ if (predicate (query) &&
123+ primary->computeInitialValues (query, objects, values, hasSolution))
115124 return true ;
116125
117126 return secondary->impl ->computeInitialValues (query, objects, values,
118127 hasSolution);
119128}
120129
121130bool StagedSolverImpl::check (const Query &query, ref<SolverResponse> &result) {
122- std::vector<const Array *> objects;
123- findSymbolicObjects (query, objects);
124- std::vector<SparseStorage<unsigned char >> values;
125-
126- bool hasSolution;
127-
128- bool primaryResult =
129- primary->computeInitialValues (query, objects, values, hasSolution);
130- if (primaryResult && hasSolution) {
131- result = new InvalidResponse (objects, values);
132- return true ;
131+ if (predicate (query)) {
132+ std::vector<const Array *> objects;
133+ findSymbolicObjects (query, objects);
134+ std::vector<SparseStorage<unsigned char >> values;
135+
136+ bool hasSolution;
137+
138+ bool primaryResult =
139+ primary->computeInitialValues (query, objects, values, hasSolution);
140+ if (primaryResult && hasSolution) {
141+ result = new InvalidResponse (objects, values);
142+ return true ;
143+ }
133144 }
134145
135146 return secondary->impl ->check (query, result);
@@ -138,6 +149,14 @@ bool StagedSolverImpl::check(const Query &query, ref<SolverResponse> &result) {
138149bool StagedSolverImpl::computeValidityCore (const Query &query,
139150 ValidityCore &validityCore,
140151 bool &isValid) {
152+ if (predicate (query)) {
153+ PartialValidity trueResult = primary->computeTruth (query);
154+
155+ if (trueResult == PValidity::MayBeFalse) {
156+ isValid = false ;
157+ return true ;
158+ }
159+ }
141160 return secondary->impl ->computeValidityCore (query, validityCore, isValid);
142161}
143162
0 commit comments