Skip to content

Commit 135f11e

Browse files
authored
Shared trie (#24)
## Changes * Use shared trie for a range evaluation. * This uses more memory but saves a lot of time because the same evaluation paths ("highways") don't need to be evaluated multiple times. * To accommodate this case, `MergedWithAnotherInit` was added as a `ConclusionReason`. ## Comments * There is currently no way to find out which other init this one was merged into. This potentially can be added, but it will require a change in the file format, so I did not add it now. * A better approach might be to make this optimization invisible by automatically figuring out the values of all evaluation result components based on the information about the merge. I can work on this approach now, but I'm not sure I can finish it today (so it might have to wait until next week). ## Examples * The difference in performance can be seen with a disabled test: ```c++ TEST(PostTagSearcher, DISABLED_rangePerformance) { // With separate tries: 0.36 GB, 1133 seconds // With shared trie: 1.1 GB of RAM, 93 seconds, 12x speedup, 3x more memory use PostTagSearcher().evaluateRange(30, 0, 1000000, PostTagSearcher::EvaluationParameters()); } ```
1 parent 950557a commit 135f11e

8 files changed

Lines changed: 105 additions & 66 deletions

libPostTagSystem/CheckpointsTrie.cpp

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,61 +18,66 @@ class CheckpointsTrie::Implementation {
1818
MetadataMap metadataMap_;
1919
TrieNodes trieNodes_;
2020
Suffixes reverseSuffixes_; // suffixes are written in reverse order to optimize trie nodes extension
21+
std::vector<int> values_; // each suffix corresponds to one value, even if empty
2122

2223
public:
23-
void insert(const ChunkedState& state) {
24-
const uint16_t lastChunkSizePhaseIndex = 256 * state.lastChunkSize + state.phase;
24+
bool insert(const ChunkedState& key, const int value) {
25+
const uint16_t lastChunkSizePhaseIndex = 256 * key.lastChunkSize + key.phase;
2526

26-
const auto fixedChunkCountIt = metadataMap_.find(state.chunks.size());
27+
const auto fixedChunkCountIt = metadataMap_.find(key.chunks.size());
2728
if (fixedChunkCountIt == metadataMap_.end()) {
28-
reverseSuffixes_.push_back(std::vector<uint8_t>(state.chunks.rbegin(), state.chunks.rend()));
29+
reverseSuffixes_.push_back(std::vector<uint8_t>(key.chunks.rbegin(), key.chunks.rend()));
30+
values_.push_back(value);
2931
metadataMap_.insert(
30-
{state.chunks.size(), {{lastChunkSizePhaseIndex, fromReverseSuffixesIndex(reverseSuffixes_.size() - 1)}}});
31-
return;
32+
{key.chunks.size(), {{lastChunkSizePhaseIndex, fromReverseSuffixesIndex(reverseSuffixes_.size() - 1)}}});
33+
return true;
3234
}
3335

3436
const auto fixedMetadataIt = fixedChunkCountIt->second.find(lastChunkSizePhaseIndex);
3537
if (fixedMetadataIt == fixedChunkCountIt->second.end()) {
36-
reverseSuffixes_.push_back(std::vector<uint8_t>(state.chunks.rbegin(), state.chunks.rend()));
38+
reverseSuffixes_.push_back(std::vector<uint8_t>(key.chunks.rbegin(), key.chunks.rend()));
39+
values_.push_back(value);
3740
fixedChunkCountIt->second.insert(
3841
{lastChunkSizePhaseIndex, fromReverseSuffixesIndex(reverseSuffixes_.size() - 1)});
39-
return;
42+
return true;
4043
}
4144

42-
insertChunks(&fixedMetadataIt->second, state.chunks.begin(), state.chunks.end());
45+
return insertChunks(&fixedMetadataIt->second, key.chunks.begin(), key.chunks.end(), value);
4346
}
4447

45-
bool contains(const ChunkedState& state) const {
48+
std::optional<int> findValue(const ChunkedState& state) const {
4649
const auto fixedChunkCountIt = metadataMap_.find(state.chunks.size());
47-
if (fixedChunkCountIt == metadataMap_.end()) return false;
50+
if (fixedChunkCountIt == metadataMap_.end()) return std::nullopt;
4851
const auto fixedMetadataIt = fixedChunkCountIt->second.find(256 * state.lastChunkSize + state.phase);
49-
if (fixedMetadataIt == fixedChunkCountIt->second.end()) return false;
50-
return containsChunks(fixedMetadataIt->second, state.chunks.begin(), state.chunks.end());
52+
if (fixedMetadataIt == fixedChunkCountIt->second.end()) return std::nullopt;
53+
return findValueInChunks(fixedMetadataIt->second, state.chunks.begin(), state.chunks.end());
5154
}
5255

5356
private:
5457
using ChunksIterator = std::deque<uint8_t>::const_iterator;
55-
void insertChunks(int64_t* index, ChunksIterator chunksBegin, ChunksIterator chunksEnd) {
56-
if (chunksBegin == chunksEnd) return;
58+
bool insertChunks(int64_t* index, ChunksIterator chunksBegin, ChunksIterator chunksEnd, const int value) {
59+
if (chunksBegin == chunksEnd) return false; // it's a total match, don't insert a value
5760

5861
if (*index >= 0) {
59-
const auto nextChunkIt = trieNodes_[*index].find(*chunksBegin);
60-
if (nextChunkIt == trieNodes_[*index].end()) {
62+
const auto nextChunkIt = trieNodes_.at(*index).find(*chunksBegin);
63+
if (nextChunkIt == trieNodes_.at(*index).end()) {
6164
reverseSuffixes_.push_back(std::vector<uint8_t>(std::reverse_iterator<ChunksIterator>(chunksEnd),
6265
std::reverse_iterator<ChunksIterator>(chunksBegin) - 1));
63-
trieNodes_[*index].insert({*chunksBegin, fromReverseSuffixesIndex(reverseSuffixes_.size() - 1)});
66+
values_.push_back(value);
67+
trieNodes_.at(*index).insert({*chunksBegin, fromReverseSuffixesIndex(reverseSuffixes_.size() - 1)});
6468
} else {
65-
insertChunks(&nextChunkIt->second, chunksBegin + 1, chunksEnd);
69+
insertChunks(&nextChunkIt->second, chunksBegin + 1, chunksEnd, value);
6670
}
6771
} else {
6872
*index = pushChunk(*index);
69-
insertChunks(index, chunksBegin, chunksEnd);
73+
insertChunks(index, chunksBegin, chunksEnd, value);
7074
}
75+
return true;
7176
}
7277

7378
int64_t pushChunk(int64_t negativeIndex) {
74-
const auto firstValue = reverseSuffixes_[toReverseSuffixesIndex(negativeIndex)].back();
75-
reverseSuffixes_[toReverseSuffixesIndex(negativeIndex)].pop_back();
79+
const auto firstValue = reverseSuffixes_.at(toReverseSuffixesIndex(negativeIndex)).back();
80+
reverseSuffixes_.at(toReverseSuffixesIndex(negativeIndex)).pop_back();
7681
trieNodes_.push_back({{firstValue, negativeIndex}});
7782
return trieNodes_.size() - 1;
7883
}
@@ -81,30 +86,38 @@ class CheckpointsTrie::Implementation {
8186

8287
static inline int64_t fromReverseSuffixesIndex(int64_t positiveIndex) { return -(positiveIndex + 1); }
8388

84-
bool containsChunks(int64_t index, ChunksIterator chunksBegin, ChunksIterator chunksEnd) const {
85-
if (chunksBegin == chunksEnd) return true;
89+
std::optional<int> findValueInChunks(int64_t index, ChunksIterator chunksBegin, ChunksIterator chunksEnd) const {
90+
if (chunksBegin == chunksEnd) {
91+
return values_.at(toReverseSuffixesIndex(index));
92+
}
8693

8794
if (index >= 0) {
88-
const auto nextChunkIt = trieNodes_[index].find(*chunksBegin);
89-
if (nextChunkIt == trieNodes_[index].end()) {
90-
return false;
95+
const auto nextChunkIt = trieNodes_.at(index).find(*chunksBegin);
96+
if (nextChunkIt == trieNodes_.at(index).end()) {
97+
return std::nullopt;
9198
} else {
92-
return containsChunks(nextChunkIt->second, chunksBegin + 1, chunksEnd);
99+
return findValueInChunks(nextChunkIt->second, chunksBegin + 1, chunksEnd);
93100
}
94101
} else {
95102
const auto firstMismatch = std::mismatch(chunksBegin,
96103
chunksEnd,
97-
reverseSuffixes_[toReverseSuffixesIndex(index)].rbegin(),
98-
reverseSuffixes_[toReverseSuffixesIndex(index)].rend());
99-
return firstMismatch.first == chunksEnd;
104+
reverseSuffixes_.at(toReverseSuffixesIndex(index)).rbegin(),
105+
reverseSuffixes_.at(toReverseSuffixesIndex(index)).rend());
106+
if (firstMismatch.first == chunksEnd) {
107+
return values_.at(fromReverseSuffixesIndex(index));
108+
} else {
109+
return std::nullopt;
110+
}
100111
}
101112
}
102113
};
103114

104115
CheckpointsTrie::CheckpointsTrie() { implementation_ = std::make_shared<Implementation>(); }
105116

106-
void CheckpointsTrie::insert(const ChunkedState& state) { implementation_->insert(state); }
117+
bool CheckpointsTrie::insert(const ChunkedState& key, const int value) { return implementation_->insert(key, value); }
107118

108-
bool CheckpointsTrie::contains(const ChunkedState& state) const { return implementation_->contains(state); }
119+
std::optional<int> CheckpointsTrie::findValue(const ChunkedState& state) const {
120+
return implementation_->findValue(state);
121+
}
109122

110123
} // namespace PostTagSystem

libPostTagSystem/CheckpointsTrie.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22
#define LIBPOSTTAGSYSTEM_CHECKPOINTSTRIE_HPP_
33

44
#include <memory>
5+
#include <optional>
56

67
#include "ChunkedState.hpp"
78

89
namespace PostTagSystem {
910
class CheckpointsTrie {
1011
public:
1112
CheckpointsTrie();
12-
void insert(const ChunkedState& state);
13-
bool contains(const ChunkedState& state) const;
13+
// returns false if the value already exists, in which case the old key will remain
14+
bool insert(const ChunkedState& key, int value);
15+
std::optional<int> findValue(const ChunkedState& state) const;
1416

1517
private:
1618
class Implementation;

libPostTagSystem/PostTagHistory.cpp

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ class PostTagHistory::Implementation {
4444

4545
std::unordered_map<NamedRule, ChunkEvaluationTable> evaluationTables_;
4646

47+
static constexpr int explicitCheckpoint = -1;
48+
4749
public:
4850
Implementation() {}
4951

@@ -72,24 +74,25 @@ class PostTagHistory::Implementation {
7274
return std::vector<EvaluationResult>(
7375
inits.size(), {ConclusionReason::InvalidInput, {{}, std::numeric_limits<uint8_t>::max()}, 0, 0});
7476
}
75-
CheckpointsTrie explicitCheckpointsTrie;
77+
CheckpointsTrie checkpointsTrie;
7678
for (const auto& checkpoint : checkpointSpec.states) {
77-
explicitCheckpointsTrie.insert(toChunkedState(checkpoint));
79+
checkpointsTrie.insert(toChunkedState(checkpoint), explicitCheckpoint);
7880
}
7981

8082
std::vector<EvaluationResult> results;
8183
results.reserve(inits.size());
82-
for (const auto& init : inits) {
83-
auto chunkedState = toChunkedState(init);
84+
for (size_t initIndex = 0; initIndex < inits.size(); ++initIndex) {
85+
auto chunkedState = toChunkedState(inits[initIndex]);
8486
uint64_t maxIntermediateTapeLength = tapeLength(chunkedState);
8587
ConclusionReason conclusionReason;
8688
const auto eventCount = evaluate(chunkEvaluationTable,
8789
&chunkedState,
90+
initIndex,
8891
&conclusionReason,
8992
&maxIntermediateTapeLength,
9093
limits,
9194
endClock,
92-
explicitCheckpointsTrie,
95+
&checkpointsTrie,
9396
checkpointSpec.flags);
9497
results.push_back(
9598
{conclusionReason, fromChunkedStateDestructively(&chunkedState), eventCount, maxIntermediateTapeLength});
@@ -175,17 +178,17 @@ class PostTagHistory::Implementation {
175178

176179
static uint64_t evaluate(const ChunkEvaluationTable& evaluationTable,
177180
ChunkedState* state,
181+
const size_t index,
178182
ConclusionReason* conclusionReason,
179183
uint64_t* maxIntermediateTapeLength,
180184
const EvaluationLimits& limits,
181185
std::chrono::time_point<std::chrono::steady_clock> endClock,
182-
const CheckpointsTrie& explicitCheckpoints,
186+
CheckpointsTrie* checkpoints,
183187
const CheckpointSpecFlags& checkpointFlags) {
184188
if (std::chrono::steady_clock::now() > endClock) {
185189
*conclusionReason = ConclusionReason::NotEvaluated;
186190
return 0;
187191
}
188-
CheckpointsTrie automaticCheckpoints;
189192
uint64_t eventCount;
190193
constexpr int eventsPerClockCheck = 1000;
191194
for (eventCount = 0; eventCount < limits.maxEventCount && state->chunks.size() > 1;
@@ -198,16 +201,21 @@ class PostTagHistory::Implementation {
198201
*conclusionReason = ConclusionReason::MaxTapeLengthExceeded;
199202
return eventCount;
200203
}
201-
if (explicitCheckpoints.contains(*state)) {
202-
*conclusionReason = ConclusionReason::ReachedExplicitCheckpoint;
203-
return eventCount;
204-
}
205-
if (automaticCheckpoints.contains(*state)) {
206-
*conclusionReason = ConclusionReason::ReachedAutomaticCheckpoint;
207-
return eventCount;
204+
const auto foundCheckpoint = checkpoints->findValue(*state);
205+
if (foundCheckpoint.has_value()) {
206+
if (foundCheckpoint.value() == explicitCheckpoint) {
207+
*conclusionReason = ConclusionReason::ReachedExplicitCheckpoint;
208+
return eventCount;
209+
} else if (foundCheckpoint.value() == static_cast<int>(index)) {
210+
*conclusionReason = ConclusionReason::ReachedAutomaticCheckpoint;
211+
return eventCount;
212+
} else {
213+
*conclusionReason = ConclusionReason::ReachedPreviousInitCheckpoint;
214+
return eventCount;
215+
}
208216
}
209217
if (checkpointFlags.powerOfTwoEventCounts && !isPowerOfTwo(eventCount)) {
210-
automaticCheckpoints.insert(*state);
218+
checkpoints->insert(*state, static_cast<int>(index));
211219
}
212220
evaluateOnce(evaluationTable, state);
213221
*maxIntermediateTapeLength = std::max(*maxIntermediateTapeLength, tapeLength(*state));

libPostTagSystem/PostTagHistory.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class PostTagHistory {
1515
Terminated,
1616
ReachedExplicitCheckpoint,
1717
ReachedAutomaticCheckpoint,
18+
ReachedPreviousInitCheckpoint,
1819
MaxEventCountExceeded,
1920
MaxTapeLengthExceeded,
2021
TimeConstraintExceeded,

libPostTagSystem/PostTagSearcher.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ class PostTagSearcher::Implementation {
7777
result.conclusionReason = ConclusionReason::NotEvaluated;
7878
break;
7979

80+
case PostTagHistory::ConclusionReason::ReachedPreviousInitCheckpoint:
81+
result.conclusionReason = ConclusionReason::MergedWithAnotherInit;
82+
break;
83+
8084
default:
8185
result.conclusionReason = ConclusionReason::InvalidInput;
8286
}

libPostTagSystem/PostTagSearcher.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ class PostTagSearcher {
2020
MaxTapeLengthExceeded,
2121
MaxEventCountExceeded,
2222
TimeConstraintExceeded,
23-
NotEvaluated
23+
NotEvaluated,
24+
MergedWithAnotherInit
2425
};
2526

2627
struct EvaluationResult {

libPostTagSystem/test/CheckpointsTrie_test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ void checkStateInsertion(const std::vector<ChunkedState>& insertedStates,
1010
const std::vector<ChunkedState>& missingStates) {
1111
CheckpointsTrie trie;
1212
for (auto insertionIt = insertedStates.begin(); insertionIt != insertedStates.end(); ++insertionIt) {
13-
trie.insert(*insertionIt);
13+
trie.insert(*insertionIt, static_cast<int>(insertionIt - insertedStates.begin()));
1414
for (auto checkIt = insertedStates.begin(); checkIt != insertionIt + 1; ++checkIt) {
15-
ASSERT_TRUE(trie.contains(*checkIt));
15+
ASSERT_EQ(trie.findValue(*checkIt).value(), static_cast<int>(checkIt - insertedStates.begin()));
1616
}
1717
}
1818

1919
for (const auto& state : missingStates) {
20-
ASSERT_FALSE(trie.contains(state));
20+
ASSERT_EQ(trie.findValue(state), std::nullopt);
2121
}
2222
}
2323
} // namespace

libPostTagSystem/test/PostTagSearcher_test.cpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,24 @@ void compareResults(const TagState& init,
3535

3636
if (singleResult.conclusionReason == PostTagHistory::ConclusionReason::InvalidInput) {
3737
ASSERT_EQ(result.conclusionReason, PostTagSearcher::ConclusionReason::InvalidInput);
38-
} else if (singleResult.conclusionReason == PostTagHistory::ConclusionReason::Terminated) {
39-
ASSERT_EQ(result.conclusionReason, PostTagSearcher::ConclusionReason::Terminated);
40-
} else if (singleResult.conclusionReason == PostTagHistory::ConclusionReason::ReachedAutomaticCheckpoint) {
41-
ASSERT_EQ(result.conclusionReason, PostTagSearcher::ConclusionReason::ReachedCycle);
42-
} else if (singleResult.conclusionReason == PostTagHistory::ConclusionReason::ReachedExplicitCheckpoint) {
43-
ASSERT_EQ(result.conclusionReason, PostTagSearcher::ConclusionReason::ReachedKnownCheckpoint);
44-
} else if (singleResult.conclusionReason == PostTagHistory::ConclusionReason::MaxEventCountExceeded) {
45-
ASSERT_EQ(result.conclusionReason, PostTagSearcher::ConclusionReason::MaxEventCountExceeded);
38+
} else if (result.conclusionReason != PostTagSearcher::ConclusionReason::MergedWithAnotherInit) {
39+
if (singleResult.conclusionReason == PostTagHistory::ConclusionReason::Terminated) {
40+
ASSERT_EQ(result.conclusionReason, PostTagSearcher::ConclusionReason::Terminated);
41+
} else if (singleResult.conclusionReason == PostTagHistory::ConclusionReason::ReachedAutomaticCheckpoint) {
42+
ASSERT_EQ(result.conclusionReason, PostTagSearcher::ConclusionReason::ReachedCycle);
43+
} else if (singleResult.conclusionReason == PostTagHistory::ConclusionReason::ReachedExplicitCheckpoint) {
44+
ASSERT_EQ(result.conclusionReason, PostTagSearcher::ConclusionReason::ReachedKnownCheckpoint);
45+
} else if (singleResult.conclusionReason == PostTagHistory::ConclusionReason::MaxEventCountExceeded) {
46+
ASSERT_EQ(result.conclusionReason, PostTagSearcher::ConclusionReason::MaxEventCountExceeded);
47+
} else if (singleResult.conclusionReason == PostTagHistory::ConclusionReason::MaxTapeLengthExceeded) {
48+
ASSERT_EQ(result.conclusionReason, PostTagSearcher::ConclusionReason::MaxTapeLengthExceeded);
49+
}
50+
ASSERT_EQ(result.finalState, singleResult.finalState);
51+
ASSERT_EQ(result.eventCount, singleResult.eventCount);
52+
ASSERT_EQ(result.maxTapeLength, singleResult.maxIntermediateTapeLength);
53+
ASSERT_EQ(result.finalTapeLength, singleResult.finalState.tape.size());
4654
}
4755

48-
ASSERT_EQ(result.finalState, singleResult.finalState);
49-
ASSERT_EQ(result.eventCount, singleResult.eventCount);
50-
ASSERT_EQ(result.maxTapeLength, singleResult.maxIntermediateTapeLength);
51-
ASSERT_EQ(result.finalTapeLength, singleResult.finalState.tape.size());
5256
ASSERT_EQ(result.initialState, init);
5357
}
5458

@@ -149,4 +153,10 @@ TEST(PostTagSearcher, smallTimeConstraint) {
149153
ASSERT_EQ(result.size(), 3);
150154
ASSERT_EQ(result[0].conclusionReason, PostTagSearcher::ConclusionReason::TimeConstraintExceeded);
151155
}
156+
157+
TEST(PostTagSearcher, DISABLED_rangePerformance) {
158+
// With separate tries: 0.36 GB, 1133 seconds
159+
// With shared trie: 1.1 GB of RAM, 93 seconds, 12x speedup, 3x more memory use
160+
PostTagSearcher().evaluateRange(30, 0, 1000000, PostTagSearcher::EvaluationParameters());
161+
}
152162
} // namespace PostTagSystem

0 commit comments

Comments
 (0)