Skip to content

Commit cf3e1c5

Browse files
Harmonize Datatype equality checks (#1854)
* Fix typing issues with load/store_chunk in Python * Datatype helpers: non-template variants * Unify Datatype equality semantics * replace operator==(Datatype, Datatype) by isSame
1 parent 30061c6 commit cf3e1c5

9 files changed

Lines changed: 124 additions & 80 deletions

File tree

include/openPMD/Datatype.hpp

Lines changed: 31 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,8 @@ template <typename T>
294294
inline constexpr Datatype determineDatatype(T &&val)
295295
{
296296
(void)val; // don't need this, it only has a name for Doxygen
297-
using T_stripped = std::remove_cv_t<std::remove_reference_t<T>>;
297+
using T_stripped =
298+
std::remove_extent_t<std::remove_cv_t<std::remove_reference_t<T>>>;
298299
if constexpr (auxiliary::IsPointer_v<T_stripped>)
299300
{
300301
return determineDatatype<auxiliary::IsPointer_t<T_stripped>>();
@@ -419,6 +420,8 @@ inline size_t toBits(Datatype d)
419420
return toBytes(d) * CHAR_BIT;
420421
}
421422

423+
constexpr bool isSigned(Datatype d);
424+
422425
/** Compare if a Datatype is a vector type
423426
*
424427
* @param d Datatype to test
@@ -595,14 +598,19 @@ inline std::tuple<bool, bool> isInteger()
595598
*/
596599
template <typename T_FP>
597600
inline bool isSameFloatingPoint(Datatype d)
601+
{
602+
return isSameFloatingPoint(d, determineDatatype<T_FP>());
603+
}
604+
605+
inline bool isSameFloatingPoint(Datatype d1, Datatype d2)
598606
{
599607
// template
600-
bool tt_is_fp = isFloatingPoint<T_FP>();
608+
bool tt_is_fp = isFloatingPoint(d1);
601609

602610
// Datatype
603-
bool dt_is_fp = isFloatingPoint(d);
611+
bool dt_is_fp = isFloatingPoint(d2);
604612

605-
if (tt_is_fp && dt_is_fp && toBits(d) == toBits(determineDatatype<T_FP>()))
613+
if (tt_is_fp && dt_is_fp && toBits(d1) == toBits(d2))
606614
return true;
607615
else
608616
return false;
@@ -617,15 +625,19 @@ inline bool isSameFloatingPoint(Datatype d)
617625
*/
618626
template <typename T_CFP>
619627
inline bool isSameComplexFloatingPoint(Datatype d)
628+
{
629+
return isSameComplexFloatingPoint(d, determineDatatype<T_CFP>());
630+
}
631+
632+
inline bool isSameComplexFloatingPoint(Datatype d1, Datatype d2)
620633
{
621634
// template
622-
bool tt_is_cfp = isComplexFloatingPoint<T_CFP>();
635+
bool tt_is_cfp = isComplexFloatingPoint(d1);
623636

624637
// Datatype
625-
bool dt_is_cfp = isComplexFloatingPoint(d);
638+
bool dt_is_cfp = isComplexFloatingPoint(d2);
626639

627-
if (tt_is_cfp && dt_is_cfp &&
628-
toBits(d) == toBits(determineDatatype<T_CFP>()))
640+
if (tt_is_cfp && dt_is_cfp && toBits(d1) == toBits(d2))
629641
return true;
630642
else
631643
return false;
@@ -640,17 +652,22 @@ inline bool isSameComplexFloatingPoint(Datatype d)
640652
*/
641653
template <typename T_Int>
642654
inline bool isSameInteger(Datatype d)
655+
{
656+
return isSameInteger(d, determineDatatype<T_Int>());
657+
}
658+
659+
inline bool isSameInteger(Datatype d1, Datatype d2)
643660
{
644661
// template
645662
bool tt_is_int, tt_is_sig;
646-
std::tie(tt_is_int, tt_is_sig) = isInteger<T_Int>();
663+
std::tie(tt_is_int, tt_is_sig) = isInteger(d1);
647664

648665
// Datatype
649666
bool dt_is_int, dt_is_sig;
650-
std::tie(dt_is_int, dt_is_sig) = isInteger(d);
667+
std::tie(dt_is_int, dt_is_sig) = isInteger(d2);
651668

652669
if (tt_is_int && dt_is_int && tt_is_sig == dt_is_sig &&
653-
toBits(d) == toBits(determineDatatype<T_Int>()))
670+
toBits(d1) == toBits(d2))
654671
return true;
655672
else
656673
return false;
@@ -691,46 +708,15 @@ constexpr bool isChar(Datatype d)
691708
template <typename T_Char>
692709
constexpr bool isSameChar(Datatype d);
693710

711+
constexpr bool isSameChar(Datatype d1, Datatype d2);
712+
694713
/** Comparison for two Datatypes
695714
*
696715
* Besides returning true for the same types, identical implementations on
697716
* some platforms, e.g. if long and long long are the same or double and
698717
* long double will also return true.
699718
*/
700-
inline bool isSame(openPMD::Datatype const d, openPMD::Datatype const e)
701-
{
702-
// exact same type
703-
if (static_cast<int>(d) == static_cast<int>(e))
704-
return true;
705-
706-
bool d_is_vec = isVector(d);
707-
bool e_is_vec = isVector(e);
708-
709-
// same int
710-
bool d_is_int, d_is_sig;
711-
std::tie(d_is_int, d_is_sig) = isInteger(d);
712-
bool e_is_int, e_is_sig;
713-
std::tie(e_is_int, e_is_sig) = isInteger(e);
714-
if (d_is_int && e_is_int && d_is_vec == e_is_vec && d_is_sig == e_is_sig &&
715-
toBits(d) == toBits(e))
716-
return true;
717-
718-
// same float
719-
bool d_is_fp = isFloatingPoint(d);
720-
bool e_is_fp = isFloatingPoint(e);
721-
722-
if (d_is_fp && e_is_fp && d_is_vec == e_is_vec && toBits(d) == toBits(e))
723-
return true;
724-
725-
// same complex floating point
726-
bool d_is_cfp = isComplexFloatingPoint(d);
727-
bool e_is_cfp = isComplexFloatingPoint(e);
728-
729-
if (d_is_cfp && e_is_cfp && d_is_vec == e_is_vec && toBits(d) == toBits(e))
730-
return true;
731-
732-
return false;
733-
}
719+
constexpr bool isSame(openPMD::Datatype d, openPMD::Datatype e);
734720

735721
/**
736722
* @brief basicDatatype Strip openPMD Datatype of std::vector, std::array et.

include/openPMD/Datatype.tpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
// comment to prevent clang-format from moving this #include up
2626
// datatype macros may be included and un-included in other headers
2727
#include "openPMD/DatatypeMacros.hpp"
28+
#include "openPMD/auxiliary/TypeTraits.hpp"
2829

2930
#include <string>
3031
#include <type_traits> // std::void_t
@@ -253,6 +254,56 @@ constexpr inline bool isSameChar(Datatype d)
253254
{
254255
return switchType<detail::IsSameChar<T_Char>>(d);
255256
}
257+
258+
namespace detail
259+
{
260+
struct IsSigned
261+
{
262+
template <typename T>
263+
static constexpr bool call()
264+
{
265+
if constexpr (auxiliary::IsVector_v<T> || auxiliary::IsArray_v<T>)
266+
{
267+
return call<typename T::value_type>();
268+
}
269+
else if constexpr (std::is_same_v<T, std::string>)
270+
{
271+
return call<char>();
272+
}
273+
else
274+
{
275+
return std::is_signed_v<T>;
276+
}
277+
}
278+
279+
static constexpr char const *errorMsg = "IsSigned";
280+
};
281+
} // namespace detail
282+
283+
constexpr inline bool isSigned(Datatype d)
284+
{
285+
return switchType<detail::IsSigned>(d);
286+
}
287+
288+
constexpr inline bool isSameChar(Datatype d, Datatype e)
289+
{
290+
return isChar(d) && isChar(e) && isSigned(d) == isSigned(e);
291+
}
292+
293+
constexpr bool isSame(openPMD::Datatype const d, openPMD::Datatype const e)
294+
{
295+
return
296+
// exact same type
297+
static_cast<int>(d) == static_cast<int>(e)
298+
// same int
299+
|| isSameInteger(d, e)
300+
// same float
301+
|| isSameFloatingPoint(d, e)
302+
// same complex floating point
303+
|| isSameComplexFloatingPoint(d, e)
304+
// same char
305+
|| isSameChar(d, e);
306+
}
256307
} // namespace openPMD
257308

258309
#include "openPMD/UndefDatatypeMacros.hpp"

include/openPMD/backend/PatchRecordComponent.hpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ template <typename T>
122122
inline void PatchRecordComponent::load(std::shared_ptr<T> data)
123123
{
124124
Datatype dtype = determineDatatype<T>();
125-
if (dtype != getDatatype())
125+
// Attention: Do NOT use operator==(), doesnt work properly on Windows!
126+
if (!isSame(dtype, getDatatype()))
126127
throw std::runtime_error(
127128
"Type conversion during particle patch loading not yet "
128129
"implemented");
@@ -160,10 +161,7 @@ template <typename T>
160161
inline void PatchRecordComponent::store(uint64_t idx, T data)
161162
{
162163
Datatype dtype = determineDatatype<T>();
163-
if (dtype != getDatatype() && !isSameInteger<T>(getDatatype()) &&
164-
!isSameFloatingPoint<T>(getDatatype()) &&
165-
!isSameComplexFloatingPoint<T>(getDatatype()) &&
166-
!isSameChar<T>(getDatatype()))
164+
if (!isSame(dtype, getDatatype()))
167165
{
168166
std::ostringstream oss;
169167
oss << "Datatypes of patch data (" << dtype << ") and dataset ("
@@ -190,10 +188,7 @@ template <typename T>
190188
inline void PatchRecordComponent::store(T data)
191189
{
192190
Datatype dtype = determineDatatype<T>();
193-
if (dtype != getDatatype() && !isSameInteger<T>(getDatatype()) &&
194-
!isSameFloatingPoint<T>(getDatatype()) &&
195-
!isSameComplexFloatingPoint<T>(getDatatype()) &&
196-
!isSameChar<T>(getDatatype()))
191+
if (!isSame(dtype, getDatatype()))
197192
{
198193
std::ostringstream oss;
199194
oss << "Datatypes of patch data (" << dtype << ") and dataset ("

src/IO/ADIOS/ADIOS2PreloadAttributes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ PreloadAdiosAttributes::getAttribute(std::string const &name) const
248248
}
249249
AttributeLocation const &location = it->second;
250250
Datatype determinedDatatype = determineDatatype<T>();
251-
if (location.dt != determinedDatatype)
251+
if (!isSame(location.dt, determinedDatatype))
252252
{
253253
std::stringstream errorMsg;
254254
errorMsg << "[ADIOS2] Wrong datatype for attribute: " << name

src/IO/JSON/JSONIOHandlerImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2332,7 +2332,7 @@ auto JSONIOHandlerImpl::verifyDataset(
23322332
}
23332333
Datatype dt = stringToDatatype(j["datatype"].get<std::string>());
23342334
VERIFY_ALWAYS(
2335-
dt == parameters.dtype,
2335+
isSame(dt, parameters.dtype),
23362336
"[JSON] Read/Write request does not fit the dataset's type");
23372337
}
23382338
catch (json::basic_json::type_error &)

src/RecordComponent.cpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ void RecordComponent::verifyChunk(
657657
if (empty())
658658
throw std::runtime_error(
659659
"Chunks cannot be written for an empty RecordComponent.");
660-
if (dtype != getDatatype())
660+
if (!isSame(dtype, getDatatype()))
661661
{
662662
std::ostringstream oss;
663663
oss << "Datatypes of chunk data (" << dtype
@@ -833,21 +833,19 @@ void RecordComponent::loadChunk(std::shared_ptr<T> data, Offset o, Extent e)
833833
* JSON/TOML backends as they might implicitly turn a LONG into an INT in a
834834
* constant component. The frontend needs to catch such edge cases.
835835
* Ref. `if (constant())` branch.
836+
*
837+
* Attention: Do NOT use operator==(), doesnt work properly on Windows!
836838
*/
837-
if (dtype != getDatatype() && !constant())
838-
if (!isSameInteger<T>(getDatatype()) &&
839-
!isSameFloatingPoint<T>(getDatatype()) &&
840-
!isSameComplexFloatingPoint<T>(getDatatype()) &&
841-
!isSameChar<T>(getDatatype()))
842-
{
843-
std::string const data_type_str = datatypeToString(getDatatype());
844-
std::string const requ_type_str =
845-
datatypeToString(determineDatatype<T>());
846-
std::string err_msg =
847-
"Type conversion during chunk loading not yet implemented! ";
848-
err_msg += "Data: " + data_type_str + "; Load as: " + requ_type_str;
849-
throw std::runtime_error(err_msg);
850-
}
839+
if (!isSame(dtype, getDatatype()) && !constant())
840+
{
841+
std::string const data_type_str = datatypeToString(getDatatype());
842+
std::string const requ_type_str =
843+
datatypeToString(determineDatatype<T>());
844+
std::string err_msg =
845+
"Type conversion during chunk loading not yet implemented! ";
846+
err_msg += "Data: " + data_type_str + "; Load as: " + requ_type_str;
847+
throw std::runtime_error(err_msg);
848+
}
851849

852850
uint8_t dim = getDimensionality();
853851

src/Series.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1950,12 +1950,11 @@ void Series::readOneIterationFileBased(std::string const &filePath)
19501950

19511951
readBase();
19521952

1953-
using DT = Datatype;
19541953
aRead.name = "iterationEncoding";
19551954
IOHandler()->enqueue(IOTask(this, aRead));
19561955
IOHandler()->flush(internal::defaultFlushParams);
19571956
IterationEncoding encoding_out;
1958-
if (*aRead.dtype == DT::STRING)
1957+
if (isSame(*aRead.dtype, Datatype::STRING))
19591958
{
19601959
std::string encoding = Attribute(Attribute::from_any, *aRead.m_resource)
19611960
.get<std::string>();
@@ -2010,7 +2009,7 @@ void Series::readOneIterationFileBased(std::string const &filePath)
20102009
aRead.name = "iterationFormat";
20112010
IOHandler()->enqueue(IOTask(this, aRead));
20122011
IOHandler()->flush(internal::defaultFlushParams);
2013-
if (*aRead.dtype == DT::STRING)
2012+
if (isSame(*aRead.dtype, Datatype::STRING))
20142013
{
20152014
setWritten(false, Attributable::EnqueueAsynchronously::No);
20162015
setIterationFormat(Attribute(Attribute::from_any, *aRead.m_resource)

src/binding/python/RecordComponent.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,14 @@ inline void store_chunk(
489489

490490
check_buffer_is_contiguous(a);
491491

492-
// dtype_from_numpy(a.dtype())
492+
if (!dtype_to_numpy(r.getDatatype()).is(a.dtype()))
493+
{
494+
std::stringstream err;
495+
err << "Attempting store from Python array of type '"
496+
<< dtype_from_numpy(a.dtype())
497+
<< "' into Record Component of type '" << r.getDatatype() << "'.";
498+
throw error::WrongAPIUsage(err.str());
499+
}
493500
switchDatasetType<StoreChunkFromPythonArray>(
494501
r.getDatatype(), r, a, offset, extent);
495502
}
@@ -770,6 +777,15 @@ inline void load_chunk(
770777

771778
check_buffer_is_contiguous(a);
772779

780+
if (!dtype_to_numpy(r.getDatatype()).is(a.dtype()))
781+
{
782+
std::stringstream err;
783+
err << "Attempting load into Python array of type '"
784+
<< dtype_from_numpy(a.dtype())
785+
<< "' from Record Component of type '" << r.getDatatype() << "'.";
786+
throw error::WrongAPIUsage(err.str());
787+
}
788+
773789
switchDatasetType<LoadChunkIntoPythonArray>(
774790
r.getDatatype(), r, a, offset, extent);
775791
}

test/python/unittest/API/APITest.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2209,33 +2209,32 @@ def testError(self):
22092209

22102210
def testCustomGeometries(self):
22112211
DS = io.Dataset
2212-
DT = io.Datatype
22132212
sample_data = np.ones([10], dtype=np.int_)
22142213

22152214
write = io.Series("../samples/custom_geometries_python.json",
22162215
io.Access.create)
22172216
E = write.iterations[0].meshes["E"]
22182217
E.set_attribute("geometry", "other:customGeometry")
22192218
E_x = E["x"]
2220-
E_x.reset_dataset(DS(DT.LONG, [10]))
2219+
E_x.reset_dataset(DS(np.dtype(np.int_), [10]))
22212220
E_x[:] = sample_data
22222221

22232222
B = write.iterations[0].meshes["B"]
22242223
B.set_geometry("customGeometry")
22252224
B_x = B["x"]
2226-
B_x.reset_dataset(DS(DT.LONG, [10]))
2225+
B_x.reset_dataset(DS(np.dtype(np.int_), [10]))
22272226
B_x[:] = sample_data
22282227

22292228
e_energyDensity = write.iterations[0].meshes["e_energyDensity"]
22302229
e_energyDensity.set_geometry("other:customGeometry")
22312230
e_energyDensity_x = e_energyDensity[io.Mesh_Record_Component.SCALAR]
2232-
e_energyDensity_x.reset_dataset(DS(DT.LONG, [10]))
2231+
e_energyDensity_x.reset_dataset(DS(np.dtype(np.int_), [10]))
22332232
e_energyDensity_x[:] = sample_data
22342233

22352234
e_chargeDensity = write.iterations[0].meshes["e_chargeDensity"]
22362235
e_chargeDensity.set_geometry(io.Geometry.other)
22372236
e_chargeDensity_x = e_chargeDensity[io.Mesh_Record_Component.SCALAR]
2238-
e_chargeDensity_x.reset_dataset(DS(DT.LONG, [10]))
2237+
e_chargeDensity_x.reset_dataset(DS(np.dtype(np.int_), [10]))
22392238
e_chargeDensity_x[:] = sample_data
22402239

22412240
self.assertTrue(write)

0 commit comments

Comments
 (0)