Skip to content

Commit 2488483

Browse files
cabirdme and xiaoxial
authored and committed
added floor operation to fp16 path. bug fix in GetHashFP16 funcs.
1 parent 6a5a941 commit 2488483

1 file changed

Lines changed: 45 additions & 42 deletions

File tree

Library/Raisr_AVX512FP16.cpp

Lines changed: 45 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,25 @@
1010
#include <popcntintrin.h>
1111
#include <cmath>
1212
#include<string.h>
13+
14+
// Rounds each of the 32 half-precision lanes of val_ph toward negative
// infinity and returns the result as a half-precision vector.
//
// The return value is still floating point: a caller that needs integer
// lanes must convert it explicitly (e.g. _mm512_cvtph_epi16 / _mm512_cvtph_epu16).
//
// Implemented with the native AVX512-FP16 round-scale instruction so that
// builds with and without USE_ATAN2_APPROX produce identical results and no
// SVML routine is required. Unlike a round trip through 16-bit integers,
// this is exact over the whole _Float16 range (no int16 overflow for
// |x| >= 32768) and passes NaN and signed zero through unchanged.
inline __m512h floor_ph_512(__m512h val_ph)
{
    // imm8[7:4] = 0 -> keep zero fraction bits (round to integer),
    // imm8[3]   = 1 -> suppress the precision exception,
    // imm8[1:0] = 1 -> round toward negative infinity.
    return _mm512_roundscale_ph(val_ph, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
}
24+
25+
// Rounds each of the 8 half-precision lanes of val_ph toward negative
// infinity and returns the result as a half-precision vector.
//
// The return value is still floating point: a caller that needs integer
// lanes must convert it explicitly (e.g. _mm_cvtph_epi16).
//
// Uses the 128-bit round-scale instruction directly instead of widening to
// 512 bits, so no lanes with undefined contents are processed, and the
// result is exact over the whole _Float16 range (no int16 overflow for
// |x| >= 32768; NaN and signed zero pass through unchanged).
inline __m128h floor_ph_128(__m128h val_ph)
{
    // imm8[7:4] = 0 -> keep zero fraction bits (round to integer),
    // imm8[3]   = 1 -> suppress the precision exception,
    // imm8[1:0] = 1 -> round toward negative infinity.
    return _mm_roundscale_ph(val_ph, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
}
31+
1332
inline void load3x3_ph(_Float16 *img, unsigned int width, unsigned int height, unsigned int stride, __m128h *out_8neighbors_ph, __m128h *out_center_ph)
1433
{
1534
int index = (height - 1) * stride + (width - 1);
@@ -291,10 +310,7 @@ void CTCountOfBitsChangedSegment_AVX512FP16_16f(_Float16 *LRImage, _Float16 *HRI
291310
__m512h val_ph = _mm512_add_ph( _mm512_mul_ph( weight_ph, center_LR_ph),
292311
_mm512_mul_ph(weight2_ph, center_HR_ph));
293312
val_ph = _mm512_add_ph( val_ph, _mm512_set1_ph(0.5));
294-
295-
#ifndef USE_ATAN2_APPROX
296-
val_ph = _mm512_floor_ph(val_ph); // svml instruction. perhaps rounding would be better?
297-
#endif
313+
val_ph = floor_ph_512(val_ph);
298314

299315
// convert (float)val to (epu8/16)val
300316
__m512i val_epu16 = _mm512_cvtph_epu16(val_ph), val_epu8, perm_epu;
@@ -420,14 +436,10 @@ void GetHashValue_AVX512FP16_16h_8Elements(_Float16 GTWG[3][32], int passIdx, in
420436
_mm_add_ph( _mm_add_ph(sqrtL1_ph, sqrtL2_ph), _mm_set1_ph(near_zero) ) );
421437
__m128h strength_ph = L1_ph;
422438

423-
__m128i angleIdx_epi16 = _mm_cvtph_epi16( _mm_mul_ph (angle_ph, _mm_set1_ph(gQAngle)));
424-
439+
__m128i angleIdx_epi16 = floor_ph_128( _mm_mul_ph (angle_ph, _mm_set1_ph(gQAngle)));
425440
__m128i quantAngle_lessone_epi16 = _mm_sub_epi16(_mm_set1_epi16(gQuantizationAngle), one_epi16);
426-
angleIdx_epi16 = _mm_mask_blend_epi16( _mm_cmp_epi16_mask( angleIdx_epi16, quantAngle_lessone_epi16, _MM_CMPINT_GT),
427-
_mm_mask_blend_epi16(_mm_cmp_epi16_mask( angleIdx_epi16, zero_epi16, _MM_CMPINT_LT),
428-
angleIdx_epi16,
429-
zero_epi16),
430-
quantAngle_lessone_epi16);
441+
angleIdx_epi16 = _mm_min_epi16( _mm_sub_epi16(_mm_set1_epi16(gQuantizationAngle), _mm_set1_epi16(1)),
442+
_mm_max_epi16(angleIdx_epi16, zero_epi16));
431443

432444
// AFAIK, today QStr & QCoh are vectors of size 2. I think searchsorted can return an index of 0,1, or 2
433445
_Float16 *gQStr_data, *gQCoh_data;
@@ -438,17 +450,14 @@ void GetHashValue_AVX512FP16_16h_8Elements(_Float16 GTWG[3][32], int passIdx, in
438450
__m128h gQCoh1_ph = _mm_set1_ph(gQCoh_data[0]);
439451
__m128h gQCoh2_ph = _mm_set1_ph(gQCoh_data[1]);
440452

441-
442-
__m128i strengthIdx_epi16 = _mm_mask_blend_epi16(_mm_cmp_ph_mask(gQStr1_ph, strength_ph, _MM_CMPINT_LE),
443-
zero_epi16,
444-
_mm_mask_blend_epi16(_mm_cmp_ph_mask(gQStr2_ph, strength_ph, _MM_CMPINT_LE),
445-
two_epi16,
446-
one_epi16));
447-
__m128i coherenceIdx_epi16 = _mm_mask_blend_epi16(_mm_cmp_ph_mask(gQCoh1_ph, coherence_ph, _MM_CMPINT_LE),
448-
zero_epi16,
449-
_mm_mask_blend_epi16(_mm_cmp_ph_mask(gQCoh2_ph, coherence_ph, _MM_CMPINT_LE),
450-
two_epi16,
451-
one_epi16));
453+
__m128i strengthIdx_epi16 =
454+
_mm_add_epi16(
455+
_mm_mask_blend_epi16(_mm_cmp_ph_mask(gQStr1_ph, strength_ph, _MM_CMPINT_LE),zero_epi16, one_epi16),
456+
_mm_mask_blend_epi16(_mm_cmp_ph_mask(gQStr2_ph, strength_ph, _MM_CMPINT_LE),zero_epi16, one_epi16));
457+
__m128i coherenceIdx_epi16 =
458+
_mm_add_epi16(
459+
_mm_mask_blend_epi16(_mm_cmp_ph_mask(gQCoh1_ph, coherence_ph, _MM_CMPINT_LE),zero_epi16, one_epi16),
460+
_mm_mask_blend_epi16(_mm_cmp_ph_mask(gQCoh2_ph, coherence_ph, _MM_CMPINT_LE),zero_epi16, one_epi16));
452461

453462
const __m128i gQuantizationCoherence_epi16 = _mm_set1_epi16(gQuantizationCoherence);
454463
__m128i idx_epi16 = _mm_mullo_epi16(gQuantizationCoherence_epi16,
@@ -498,9 +507,9 @@ void GetHashValue_AVX512FP16_16h_32Elements(_Float16 GTWG[3][32], int passIdx, i
498507
const int cmp_le = _CMP_LE_OQ;
499508
const int cmp_gt = _CMP_GT_OQ;
500509

501-
__m512h m_a_ph = _mm512_load_ph( &GTWG[0]);
502-
__m512h m_b_ph = _mm512_load_ph( &GTWG[1]);
503-
__m512h m_d_ph = _mm512_load_ph( &GTWG[2]);
510+
__m512h m_a_ph = _mm512_load_ph( GTWG[0]);
511+
__m512h m_b_ph = _mm512_load_ph( GTWG[1]);
512+
__m512h m_d_ph = _mm512_load_ph( GTWG[2]);
504513

505514
__m512h T_ph = _mm512_add_ph(m_a_ph, m_d_ph);
506515
__m512h D_ph = _mm512_sub_ph( _mm512_mul_ph( m_a_ph, m_d_ph),
@@ -535,14 +544,11 @@ void GetHashValue_AVX512FP16_16h_32Elements(_Float16 GTWG[3][32], int passIdx, i
535544
_mm512_add_ph( _mm512_add_ph(sqrtL1_ph, sqrtL2_ph), _mm512_set1_ph(near_zero) ) );
536545
__m512h strength_ph = L1_ph;
537546

538-
__m512i angleIdx_epi16 = _mm512_cvtph_epi16( _mm512_floor_ph(_mm512_mul_ph (angle_ph, _mm512_set1_ph(gQAngle))));
547+
__m512i angleIdx_epi16 = floor_ph_512(_mm512_mul_ph (angle_ph, _mm512_set1_ph(gQAngle)));
539548

540549
__m512i quantAngle_lessone_epi16 = _mm512_sub_epi16(_mm512_set1_epi16(gQuantizationAngle), one_epi16);
541-
angleIdx_epi16 = _mm512_mask_blend_epi16( _mm512_cmp_epi16_mask( angleIdx_epi16, quantAngle_lessone_epi16, _MM_CMPINT_GT),
542-
_mm512_mask_blend_epi16(_mm512_cmp_epi16_mask( angleIdx_epi16, zero_epi16, _MM_CMPINT_LT),
543-
angleIdx_epi16,
544-
zero_epi16),
545-
quantAngle_lessone_epi16);
550+
angleIdx_epi16 = _mm512_min_epi16(_mm512_sub_epi16(_mm512_set1_epi16(gQuantizationAngle),_mm512_set1_epi16(1)),
551+
_mm512_max_epi16(angleIdx_epi16, zero_epi16));
546552

547553
// AFAIK, today QStr & QCoh are vectors of size 2. I think searchsorted can return an index of 0,1, or 2
548554
_Float16 *gQStr_data, *gQCoh_data;
@@ -553,17 +559,14 @@ void GetHashValue_AVX512FP16_16h_32Elements(_Float16 GTWG[3][32], int passIdx, i
553559
__m512h gQCoh1_ph = _mm512_set1_ph(gQCoh_data[0]);
554560
__m512h gQCoh2_ph = _mm512_set1_ph(gQCoh_data[1]);
555561

556-
557-
__m512i strengthIdx_epi16 = _mm512_mask_blend_epi16(_mm512_cmp_ph_mask(gQStr1_ph, strength_ph, _MM_CMPINT_LE),
558-
zero_epi16,
559-
_mm512_mask_blend_epi16(_mm512_cmp_ph_mask(gQStr2_ph, strength_ph, _MM_CMPINT_LE),
560-
two_epi16,
561-
one_epi16));
562-
__m512i coherenceIdx_epi16 = _mm512_mask_blend_epi16(_mm512_cmp_ph_mask(gQCoh1_ph, coherence_ph, _MM_CMPINT_LE),
563-
zero_epi16,
564-
_mm512_mask_blend_epi16(_mm512_cmp_ph_mask(gQCoh2_ph, coherence_ph, _MM_CMPINT_LE),
565-
two_epi16,
566-
one_epi16));
562+
__m512i strengthIdx_epi16 =
563+
_mm512_add_epi16(
564+
_mm512_mask_blend_epi16(_mm512_cmp_ph_mask(gQStr1_ph, strength_ph, _MM_CMPINT_LE),zero_epi16, one_epi16),
565+
_mm512_mask_blend_epi16(_mm512_cmp_ph_mask(gQStr2_ph, strength_ph, _MM_CMPINT_LE),zero_epi16, one_epi16));
566+
__m512i coherenceIdx_epi16 =
567+
_mm512_add_epi16(
568+
_mm512_mask_blend_epi16(_mm512_cmp_ph_mask(gQCoh1_ph, coherence_ph, _MM_CMPINT_LE),zero_epi16, one_epi16),
569+
_mm512_mask_blend_epi16(_mm512_cmp_ph_mask(gQCoh2_ph, coherence_ph, _MM_CMPINT_LE),zero_epi16, one_epi16));
567570

568571
const __m512i gQuantizationCoherence_epi16 = _mm512_set1_epi16(gQuantizationCoherence);
569572
__m512i idx_epi16 = _mm512_mullo_epi16(gQuantizationCoherence_epi16,

0 commit comments

Comments
 (0)