@@ -204,11 +204,75 @@ kat_is_kind(KeysArrayType kat, char kind) {
204204 case KAT_STRING :
205205 return kind == 'S' ;
206206
207+ case KAT_DTY :
208+ case KAT_DTM :
209+ case KAT_DTW :
210+ case KAT_DTD :
211+ case KAT_DTh :
212+ case KAT_DTm :
213+ case KAT_DTs :
214+ case KAT_DTms :
215+ case KAT_DTus :
216+ case KAT_DTns :
217+ case KAT_DTps :
218+ case KAT_DTfs :
219+ case KAT_DTas :
220+ return kind == 'M' ;
221+
207222 default :
208223 return 0 ;
209224 }
210225}
211226
227+ // Given a KAT, determine if it matches a NumPy dt64 unit.
228+ int
229+ kat_is_datetime_unit (KeysArrayType kat , NPY_DATETIMEUNIT unit ) {
230+ switch (kat ) {
231+ case KAT_DTY :
232+ if (unit == NPY_FR_Y ) {return 1 ;}
233+ break ;
234+ case KAT_DTM :
235+ if (unit == NPY_FR_M ) {return 1 ;}
236+ break ;
237+ case KAT_DTW :
238+ if (unit == NPY_FR_W ) {return 1 ;}
239+ break ;
240+ case KAT_DTD :
241+ if (unit == NPY_FR_D ) {return 1 ;}
242+ break ;
243+ case KAT_DTh :
244+ if (unit == NPY_FR_h ) {return 1 ;}
245+ break ;
246+ case KAT_DTm :
247+ if (unit == NPY_FR_m ) {return 1 ;}
248+ break ;
249+ case KAT_DTs :
250+ if (unit == NPY_FR_s ) {return 1 ;}
251+ break ;
252+ case KAT_DTms :
253+ if (unit == NPY_FR_ms ) {return 1 ;}
254+ break ;
255+ case KAT_DTus :
256+ if (unit == NPY_FR_us ) {return 1 ;}
257+ break ;
258+ case KAT_DTns :
259+ if (unit == NPY_FR_ns ) {return 1 ;}
260+ break ;
261+ case KAT_DTps :
262+ if (unit == NPY_FR_ps ) {return 1 ;}
263+ break ;
264+ case KAT_DTfs :
265+ if (unit == NPY_FR_fs ) {return 1 ;}
266+ break ;
267+ case KAT_DTas :
268+ if (unit == NPY_FR_as ) {return 1 ;}
269+ break ;
270+ default : // non dt64 KATs
271+ return 0 ;
272+ }
273+ return 0 ;
274+ }
275+
212276typedef struct FAMObject {
213277 PyObject_VAR_HEAD
214278 Py_ssize_t table_size ;
@@ -1059,53 +1123,13 @@ static Py_ssize_t
10591123lookup_datetime (FAMObject * self , PyObject * key ) {
10601124 npy_int64 v = 0 ; // int64
10611125 if (PyArray_IsScalar (key , Datetime )) {
1062- NPY_DATETIMEUNIT key_unit = dt_unit_from_scalar ((PyDatetimeScalarObject * )key );
10631126 v = (npy_int64 )PyArrayScalar_VAL (key , Datetime );
10641127 // if we observe a NAT, we skip unit checks
10651128 if (v != NPY_DATETIME_NAT ) {
1066- // DEBUG_MSG_OBJ("scalar unit", PyLong_FromLongLong(key_unit));
1067- switch (self -> keys_array_type ) {
1068- case KAT_DTY :
1069- if (key_unit != NPY_FR_Y ) {return -1 ;}
1070- break ;
1071- case KAT_DTM :
1072- if (key_unit != NPY_FR_M ) {return -1 ;}
1073- break ;
1074- case KAT_DTW :
1075- if (key_unit != NPY_FR_W ) {return -1 ;}
1076- break ;
1077- case KAT_DTD :
1078- if (key_unit != NPY_FR_D ) {return -1 ;}
1079- break ;
1080- case KAT_DTh :
1081- if (key_unit != NPY_FR_h ) {return -1 ;}
1082- break ;
1083- case KAT_DTm :
1084- if (key_unit != NPY_FR_m ) {return -1 ;}
1085- break ;
1086- case KAT_DTs :
1087- if (key_unit != NPY_FR_s ) {return -1 ;}
1088- break ;
1089- case KAT_DTms :
1090- if (key_unit != NPY_FR_ms ) {return -1 ;}
1091- break ;
1092- case KAT_DTus :
1093- if (key_unit != NPY_FR_us ) {return -1 ;}
1094- break ;
1095- case KAT_DTns :
1096- if (key_unit != NPY_FR_ns ) {return -1 ;}
1097- break ;
1098- case KAT_DTps :
1099- if (key_unit != NPY_FR_ps ) {return -1 ;}
1100- break ;
1101- case KAT_DTfs :
1102- if (key_unit != NPY_FR_fs ) {return -1 ;}
1103- break ;
1104- case KAT_DTas :
1105- if (key_unit != NPY_FR_as ) {return -1 ;}
1106- break ;
1107- default :
1108- return -1 ;
1129+ NPY_DATETIMEUNIT key_unit = dt_unit_from_scalar (
1130+ (PyDatetimeScalarObject * )key );
1131+ if (!kat_is_datetime_unit (self -> keys_array_type , key_unit )) {
1132+ return -1 ;
11091133 }
11101134 }
11111135 // DEBUG_MSG_OBJ("dt64 value", PyLong_FromLongLong(v));
@@ -1856,7 +1880,7 @@ fam_get_all(FAMObject *self, PyObject *key) {
18561880 return NULL ;
18571881 }
18581882
1859- // construct array to be returned
1883+ // construct array to be returned; this is a little expensive if we do not yet know if we can use it
18601884 npy_intp dims [] = {key_size };
18611885 array = PyArray_EMPTY (1 , dims , NPY_INT64 , 0 );
18621886 if (array == NULL ) {
@@ -1926,6 +1950,14 @@ fam_get_all(FAMObject *self, PyObject *key) {
19261950 case NPY_STRING :
19271951 GET_ALL_FLEXIBLE (char , char_get_end_p , lookup_hash_string , string_to_hash , PyBytes_FromStringAndSize );
19281952 break ;
1953+ case NPY_DATETIME :
1954+ NPY_DATETIMEUNIT key_unit = dt_unit_from_array (key_array );
1955+ if (!kat_is_datetime_unit (self -> keys_array_type , key_unit )) {
1956+ Py_DECREF (array );
1957+ return NULL ;
1958+ }
1959+ GET_ALL_SCALARS (npy_int64 , npy_int64 , KAT_INT64 , lookup_hash_int , int_to_hash , PyLong_FromLongLong , );
1960+ break ;
19291961 }
19301962 }
19311963 else {
0 commit comments