Skip to content

Commit b131cd2

Browse files
Add visit-like pattern for RecordComponent (#1544)
* Fixes for the datatype macros * Introduce std::variant for dataset, attribute and non-vector types Also use them to erase some repetition * Add switchDatasetType Also refactor the switchType functions to use the datatype macros * Main implementation: Add variant-based loadChunk API * Testing, examples
1 parent 2e89f87 commit b131cd2

13 files changed

Lines changed: 321 additions & 219 deletions

examples/10_streaming_read.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ using namespace openPMD;
1111
int main()
1212
{
1313
#if openPMD_HAVE_ADIOS2
14-
using position_t = double;
1514
auto backends = openPMD::getFileExtensions();
1615
if (std::find(backends.begin(), backends.end(), "sst") == backends.end())
1716
{
@@ -40,15 +39,15 @@ int main()
4039
std::cout << "Current iteration: " << iteration.iterationIndex
4140
<< std::endl;
4241
Record electronPositions = iteration.particles["e"]["position"];
43-
std::array<std::shared_ptr<position_t>, 3> loadedChunks;
42+
std::array<RecordComponent::shared_ptr_dataset_types, 3> loadedChunks;
4443
std::array<Extent, 3> extents;
4544
std::array<std::string, 3> const dimensions{{"x", "y", "z"}};
4645

4746
for (size_t i = 0; i < 3; ++i)
4847
{
4948
std::string const &dim = dimensions[i];
5049
RecordComponent rc = electronPositions[dim];
51-
loadedChunks[i] = rc.loadChunk<position_t>(
50+
loadedChunks[i] = rc.loadChunkVariant(
5251
Offset(rc.getDimensionality(), 0), rc.getExtent());
5352
extents[i] = rc.getExtent();
5453
}
@@ -64,10 +63,14 @@ int main()
6463
Extent const &extent = extents[i];
6564
std::cout << "\ndim: " << dim << "\n" << std::endl;
6665
auto chunk = loadedChunks[i];
67-
for (size_t j = 0; j < extent[0]; ++j)
68-
{
69-
std::cout << chunk.get()[j] << ", ";
70-
}
66+
std::visit(
67+
[&extent](auto &shared_ptr) {
68+
for (size_t j = 0; j < extent[0]; ++j)
69+
{
70+
std::cout << shared_ptr.get()[j] << ", ";
71+
}
72+
},
73+
chunk);
7174
std::cout << "\n----------\n" << std::endl;
7275
}
7376
}

examples/2_read_serial.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,20 +73,25 @@ int main()
7373

7474
Offset chunk_offset = {1, 1, 1};
7575
Extent chunk_extent = {2, 2, 1};
76-
auto chunk_data = E_x.loadChunk<double>(chunk_offset, chunk_extent);
76+
// Loading without explicit datatype here
77+
auto chunk_data = E_x.loadChunkVariant(chunk_offset, chunk_extent);
7778
cout << "Queued the loading of a single chunk from disk, "
7879
"ready to execute\n";
7980
series.flush();
8081
cout << "Chunk has been read from disk\n"
8182
<< "Read chunk contains:\n";
82-
for (size_t row = 0; row < chunk_extent[0]; ++row)
83-
{
84-
for (size_t col = 0; col < chunk_extent[1]; ++col)
85-
cout << "\t" << '(' << row + chunk_offset[0] << '|'
86-
<< col + chunk_offset[1] << '|' << 1 << ")\t"
87-
<< chunk_data.get()[row * chunk_extent[1] + col];
88-
cout << '\n';
89-
}
83+
std::visit(
84+
[&chunk_offset, &chunk_extent](auto &shared_ptr) {
85+
for (size_t row = 0; row < chunk_extent[0]; ++row)
86+
{
87+
for (size_t col = 0; col < chunk_extent[1]; ++col)
88+
cout << "\t" << '(' << row + chunk_offset[0] << '|'
89+
<< col + chunk_offset[1] << '|' << 1 << ")\t"
90+
<< shared_ptr.get()[row * chunk_extent[1] + col];
91+
cout << '\n';
92+
}
93+
},
94+
chunk_data);
9095

9196
auto all_data = E_x.loadChunk<double>();
9297

examples/4_read_parallel.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ int main(int argc, char *argv[])
4949
Offset chunk_offset = {static_cast<long unsigned int>(mpi_rank) + 1, 1, 1};
5050
Extent chunk_extent = {2, 2, 1};
5151

52-
auto chunk_data = E_x.loadChunk<double>(chunk_offset, chunk_extent);
52+
// If you know the datatype, use `loadChunk<double>(...)` instead.
53+
auto chunk_data = E_x.loadChunkVariant(chunk_offset, chunk_extent);
5354

5455
if (0 == mpi_rank)
5556
cout << "Queued the loading of a single chunk per MPI rank from "
@@ -72,9 +73,20 @@ int main(int argc, char *argv[])
7273
for (size_t row = 0; row < chunk_extent[0]; ++row)
7374
{
7475
for (size_t col = 0; col < chunk_extent[1]; ++col)
76+
{
7577
cout << "\t" << '(' << row + chunk_offset[0] << '|'
76-
<< col + chunk_offset[1] << '|' << 1 << ")\t"
77-
<< chunk_data.get()[row * chunk_extent[1] + col];
78+
<< col + chunk_offset[1] << '|' << 1 << ")\t";
79+
/*
80+
* For hot loops, the std::visit(...) call should be moved
81+
* further up.
82+
*/
83+
std::visit(
84+
[row, col, &chunk_extent](auto &shared_ptr) {
85+
cout << shared_ptr
86+
.get()[row * chunk_extent[1] + col];
87+
},
88+
chunk_data);
89+
}
7890
cout << std::endl;
7991
}
8092
}

include/openPMD/Datatype.hpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
#include "openPMD/auxiliary/TypeTraits.hpp"
2424
#include "openPMD/auxiliary/UniquePtr.hpp"
2525

26+
// comment to prevent clang-format from moving this #include up
27+
// datatype macros may be included and un-included in other headers
28+
#include "openPMD/DatatypeMacros.hpp"
29+
2630
#include <array>
2731
#include <climits>
2832
#include <complex>
@@ -35,6 +39,7 @@
3539
#include <tuple>
3640
#include <type_traits>
3741
#include <utility> // std::declval
42+
#include <variant>
3843
#include <vector>
3944

4045
namespace openPMD
@@ -94,6 +99,33 @@ enum class Datatype : int
9499
*/
95100
std::vector<Datatype> openPMD_Datatypes();
96101

102+
namespace detail
103+
{
104+
struct bottom
105+
{};
106+
107+
// std::variant, but ignore first template parameter
108+
// little trick to avoid trailing commas in the macro expansions below
109+
template <typename Arg, typename... Args>
110+
using variant_tail_t = std::variant<Args...>;
111+
} // namespace detail
112+
113+
#define OPENPMD_ENUMERATE_TYPES(type) , type
114+
115+
using dataset_types =
116+
detail::variant_tail_t<detail::bottom OPENPMD_FOREACH_DATASET_DATATYPE(
117+
OPENPMD_ENUMERATE_TYPES)>;
118+
119+
using non_vector_types =
120+
detail::variant_tail_t<detail::bottom OPENPMD_FOREACH_NONVECTOR_DATATYPE(
121+
OPENPMD_ENUMERATE_TYPES)>;
122+
123+
using attribute_types =
124+
detail::variant_tail_t<detail::bottom OPENPMD_FOREACH_DATATYPE(
125+
OPENPMD_ENUMERATE_TYPES)>;
126+
127+
#undef OPENPMD_ENUMERATE_TYPES
128+
97129
/** @brief Fundamental equivalence check for two given types T and U.
98130
*
99131
* This checks whether the fundamental datatype (i.e. that of a single value
@@ -782,6 +814,25 @@ template <typename Action, typename... Args>
782814
constexpr auto switchNonVectorType(Datatype dt, Args &&...args)
783815
-> decltype(Action::template call<char>(std::forward<Args>(args)...));
784816

817+
/**
818+
* Generalizes switching over an openPMD datatype.
819+
*
820+
* Will call the function template found at Action::call< T >(), instantiating T
821+
* with the C++ internal datatype corresponding to the openPMD datatype.
822+
* Specializes only on those types that can occur in a dataset.
823+
*
824+
* @tparam ReturnType The function template's return type.
825+
* @tparam Action The struct containing the function template.
826+
* @tparam Args The function template's argument types.
827+
* @param dt The openPMD datatype.
828+
* @param args The function template's arguments.
829+
* @return Passes on the result of invoking the function template with the given
830+
* arguments and with the template parameter specified by dt.
831+
*/
832+
template <typename Action, typename... Args>
833+
constexpr auto switchDatasetType(Datatype dt, Args &&...args)
834+
-> decltype(Action::template call<char>(std::forward<Args>(args)...));
835+
785836
} // namespace openPMD
786837

787838
#if !defined(_MSC_VER)
@@ -811,3 +862,4 @@ inline bool operator!=(openPMD::Datatype d, openPMD::Datatype e)
811862
#endif
812863

813864
#include "openPMD/Datatype.tpp"
865+
#include "openPMD/UndefDatatypeMacros.hpp"

0 commit comments

Comments
 (0)