1010#include < popcntintrin.h>
1111#include < cmath>
1212#include < string.h>
13+
14+ inline __m512h floor_ph_512 (__m512h val_ph)
15+ {
16+ __m512h ret_ph;
17+ #ifndef USE_ATAN2_APPROX
18+ ret_ph = _mm512_floor_ph (val_ph); // svml instruction.
19+ #else
20+ ret_ph = _mm512_cvtepi16_ph (_mm512_cvt_roundph_epi16 (val_ph, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC));
21+ #endif
22+ return ret_ph;
23+ }
24+
25+ inline __m128h floor_ph_128 (__m128h val_ph)
26+ {
27+ __m128h ret_ph;
28+ ret_ph = _mm_cvtepi16_ph (_mm512_castph512_ph128 (_mm512_cvt_roundph_epi16 (_mm512_castph128_ph512 (val_ph), _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)));
29+ return ret_ph;
30+ }
31+
1332inline void load3x3_ph (_Float16 *img, unsigned int width, unsigned int height, unsigned int stride, __m128h *out_8neighbors_ph, __m128h *out_center_ph)
1433{
1534 int index = (height - 1 ) * stride + (width - 1 );
@@ -291,10 +310,7 @@ void CTCountOfBitsChangedSegment_AVX512FP16_16f(_Float16 *LRImage, _Float16 *HRI
291310 __m512h val_ph = _mm512_add_ph ( _mm512_mul_ph ( weight_ph, center_LR_ph),
292311 _mm512_mul_ph (weight2_ph, center_HR_ph));
293312 val_ph = _mm512_add_ph ( val_ph, _mm512_set1_ph (0.5 ));
294-
295- #ifndef USE_ATAN2_APPROX
296- val_ph = _mm512_floor_ph (val_ph); // svml instruction. perhaps rounding would be better?
297- #endif
313+ val_ph = floor_ph_512 (val_ph);
298314
299315 // convert (float)val to (epu8/16)val
300316 __m512i val_epu16 = _mm512_cvtph_epu16 (val_ph), val_epu8, perm_epu;
@@ -420,14 +436,10 @@ void GetHashValue_AVX512FP16_16h_8Elements(_Float16 GTWG[3][32], int passIdx, in
420436 _mm_add_ph ( _mm_add_ph (sqrtL1_ph, sqrtL2_ph), _mm_set1_ph (near_zero) ) );
421437 __m128h strength_ph = L1_ph;
422438
423- __m128i angleIdx_epi16 = _mm_cvtph_epi16 ( _mm_mul_ph (angle_ph, _mm_set1_ph (gQAngle )));
424-
439+ __m128i angleIdx_epi16 = floor_ph_128 ( _mm_mul_ph (angle_ph, _mm_set1_ph (gQAngle )));
425440 __m128i quantAngle_lessone_epi16 = _mm_sub_epi16 (_mm_set1_epi16 (gQuantizationAngle ), one_epi16);
426- angleIdx_epi16 = _mm_mask_blend_epi16 ( _mm_cmp_epi16_mask ( angleIdx_epi16, quantAngle_lessone_epi16, _MM_CMPINT_GT),
427- _mm_mask_blend_epi16 (_mm_cmp_epi16_mask ( angleIdx_epi16, zero_epi16, _MM_CMPINT_LT),
428- angleIdx_epi16,
429- zero_epi16),
430- quantAngle_lessone_epi16);
441+ angleIdx_epi16 = _mm_min_epi16 ( _mm_sub_epi16 (_mm_set1_epi16 (gQuantizationAngle ), _mm_set1_epi16 (1 )),
442+ _mm_max_epi16 (angleIdx_epi16, zero_epi16));
431443
432444 // AFAIK, today QStr & QCoh are vectors of size 2. I think searchsorted can return an index of 0,1, or 2
433445 _Float16 *gQStr_data , *gQCoh_data ;
@@ -438,17 +450,14 @@ void GetHashValue_AVX512FP16_16h_8Elements(_Float16 GTWG[3][32], int passIdx, in
438450 __m128h gQCoh1_ph = _mm_set1_ph (gQCoh_data [0 ]);
439451 __m128h gQCoh2_ph = _mm_set1_ph (gQCoh_data [1 ]);
440452
441-
442- __m128i strengthIdx_epi16 = _mm_mask_blend_epi16 (_mm_cmp_ph_mask (gQStr1_ph , strength_ph, _MM_CMPINT_LE),
443- zero_epi16,
444- _mm_mask_blend_epi16 (_mm_cmp_ph_mask (gQStr2_ph , strength_ph, _MM_CMPINT_LE),
445- two_epi16,
446- one_epi16));
447- __m128i coherenceIdx_epi16 = _mm_mask_blend_epi16 (_mm_cmp_ph_mask (gQCoh1_ph , coherence_ph, _MM_CMPINT_LE),
448- zero_epi16,
449- _mm_mask_blend_epi16 (_mm_cmp_ph_mask (gQCoh2_ph , coherence_ph, _MM_CMPINT_LE),
450- two_epi16,
451- one_epi16));
453+ __m128i strengthIdx_epi16 =
454+ _mm_add_epi16 (
455+ _mm_mask_blend_epi16 (_mm_cmp_ph_mask (gQStr1_ph , strength_ph, _MM_CMPINT_LE),zero_epi16, one_epi16),
456+ _mm_mask_blend_epi16 (_mm_cmp_ph_mask (gQStr2_ph , strength_ph, _MM_CMPINT_LE),zero_epi16, one_epi16));
457+ __m128i coherenceIdx_epi16 =
458+ _mm_add_epi16 (
459+ _mm_mask_blend_epi16 (_mm_cmp_ph_mask (gQCoh1_ph , coherence_ph, _MM_CMPINT_LE),zero_epi16, one_epi16),
460+ _mm_mask_blend_epi16 (_mm_cmp_ph_mask (gQCoh2_ph , coherence_ph, _MM_CMPINT_LE),zero_epi16, one_epi16));
452461
453462 const __m128i gQuantizationCoherence_epi16 = _mm_set1_epi16 (gQuantizationCoherence );
454463 __m128i idx_epi16 = _mm_mullo_epi16 (gQuantizationCoherence_epi16 ,
@@ -498,9 +507,9 @@ void GetHashValue_AVX512FP16_16h_32Elements(_Float16 GTWG[3][32], int passIdx, i
498507 const int cmp_le = _CMP_LE_OQ;
499508 const int cmp_gt = _CMP_GT_OQ;
500509
501- __m512h m_a_ph = _mm512_load_ph ( & GTWG[0 ]);
502- __m512h m_b_ph = _mm512_load_ph ( & GTWG[1 ]);
503- __m512h m_d_ph = _mm512_load_ph ( & GTWG[2 ]);
510+ __m512h m_a_ph = _mm512_load_ph ( GTWG[0 ]);
511+ __m512h m_b_ph = _mm512_load_ph ( GTWG[1 ]);
512+ __m512h m_d_ph = _mm512_load_ph ( GTWG[2 ]);
504513
505514 __m512h T_ph = _mm512_add_ph (m_a_ph, m_d_ph);
506515 __m512h D_ph = _mm512_sub_ph ( _mm512_mul_ph ( m_a_ph, m_d_ph),
@@ -535,14 +544,11 @@ void GetHashValue_AVX512FP16_16h_32Elements(_Float16 GTWG[3][32], int passIdx, i
535544 _mm512_add_ph ( _mm512_add_ph (sqrtL1_ph, sqrtL2_ph), _mm512_set1_ph (near_zero) ) );
536545 __m512h strength_ph = L1_ph;
537546
538- __m512i angleIdx_epi16 = _mm512_cvtph_epi16 ( _mm512_floor_ph ( _mm512_mul_ph (angle_ph, _mm512_set1_ph (gQAngle ) )));
547+ __m512i angleIdx_epi16 = floor_ph_512 ( _mm512_mul_ph (angle_ph, _mm512_set1_ph (gQAngle )));
539548
540549 __m512i quantAngle_lessone_epi16 = _mm512_sub_epi16 (_mm512_set1_epi16 (gQuantizationAngle ), one_epi16);
541- angleIdx_epi16 = _mm512_mask_blend_epi16 ( _mm512_cmp_epi16_mask ( angleIdx_epi16, quantAngle_lessone_epi16, _MM_CMPINT_GT),
542- _mm512_mask_blend_epi16 (_mm512_cmp_epi16_mask ( angleIdx_epi16, zero_epi16, _MM_CMPINT_LT),
543- angleIdx_epi16,
544- zero_epi16),
545- quantAngle_lessone_epi16);
550+ angleIdx_epi16 = _mm512_min_epi16 (_mm512_sub_epi16 (_mm512_set1_epi16 (gQuantizationAngle ),_mm512_set1_epi16 (1 )),
551+ _mm512_max_epi16 (angleIdx_epi16, zero_epi16));
546552
547553 // AFAIK, today QStr & QCoh are vectors of size 2. I think searchsorted can return an index of 0,1, or 2
548554 _Float16 *gQStr_data , *gQCoh_data ;
@@ -553,17 +559,14 @@ void GetHashValue_AVX512FP16_16h_32Elements(_Float16 GTWG[3][32], int passIdx, i
553559 __m512h gQCoh1_ph = _mm512_set1_ph (gQCoh_data [0 ]);
554560 __m512h gQCoh2_ph = _mm512_set1_ph (gQCoh_data [1 ]);
555561
556-
557- __m512i strengthIdx_epi16 = _mm512_mask_blend_epi16 (_mm512_cmp_ph_mask (gQStr1_ph , strength_ph, _MM_CMPINT_LE),
558- zero_epi16,
559- _mm512_mask_blend_epi16 (_mm512_cmp_ph_mask (gQStr2_ph , strength_ph, _MM_CMPINT_LE),
560- two_epi16,
561- one_epi16));
562- __m512i coherenceIdx_epi16 = _mm512_mask_blend_epi16 (_mm512_cmp_ph_mask (gQCoh1_ph , coherence_ph, _MM_CMPINT_LE),
563- zero_epi16,
564- _mm512_mask_blend_epi16 (_mm512_cmp_ph_mask (gQCoh2_ph , coherence_ph, _MM_CMPINT_LE),
565- two_epi16,
566- one_epi16));
562+ __m512i strengthIdx_epi16 =
563+ _mm512_add_epi16 (
564+ _mm512_mask_blend_epi16 (_mm512_cmp_ph_mask (gQStr1_ph , strength_ph, _MM_CMPINT_LE),zero_epi16, one_epi16),
565+ _mm512_mask_blend_epi16 (_mm512_cmp_ph_mask (gQStr2_ph , strength_ph, _MM_CMPINT_LE),zero_epi16, one_epi16));
566+ __m512i coherenceIdx_epi16 =
567+ _mm512_add_epi16 (
568+ _mm512_mask_blend_epi16 (_mm512_cmp_ph_mask (gQCoh1_ph , coherence_ph, _MM_CMPINT_LE),zero_epi16, one_epi16),
569+ _mm512_mask_blend_epi16 (_mm512_cmp_ph_mask (gQCoh2_ph , coherence_ph, _MM_CMPINT_LE),zero_epi16, one_epi16));
567570
568571 const __m512i gQuantizationCoherence_epi16 = _mm512_set1_epi16 (gQuantizationCoherence );
569572 __m512i idx_epi16 = _mm512_mullo_epi16 (gQuantizationCoherence_epi16 ,
0 commit comments