Skip to content

Commit fa902ae

Browse files
cabirdmexiaoxial
authored andcommitted
added DotProdPatch AVX2
1 parent c62f762 commit fa902ae

3 files changed

Lines changed: 42 additions & 3 deletions

File tree

Library/Raisr.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1165,7 +1165,16 @@ RNLERRORTYPE processSegment(VideoDataType *srcY, VideoDataType *final_outY, Blen
11651165
{
11661166
if (likely(c + pix < cols - gLoopMargin))
11671167
{
1168-
float curPix = DotProdPatch_AVX512_32f(pixbuf[pix], fbase[pix]);
1168+
float curPix;
1169+
if (gAsmType == AVX2)
1170+
curPix = DotProdPatch_AVX256_32f(pixbuf[pix], fbase[pix]);
1171+
else if (gAsmType == AVX512)
1172+
curPix = DotProdPatch_AVX512_32f(pixbuf[pix], fbase[pix]);
1173+
else
1174+
{
1175+
std::cout << "expected avx512 or avx2, but got " << gAsmType << std::endl;
1176+
return RNLErrorBadParameter;
1177+
}
11691178
if ((gBitDepth == 8 && curPix > gMin8bit && curPix < gMax8bit) ||
11701179
(gBitDepth != 8 && curPix > gMin16bit && curPix < gMax16bit))
11711180
pRaisr32f[rOffset * cols + c + pix] = curPix;

Library/Raisr_AVX256.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,32 @@ inline float sumitup_ps_256(__m256 acc)
5757
const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
5858
return _mm_cvtss_f32(r1);
5959
}
60+
61+
// AVX2 version: for now, gPatchSize must be <= 16 because we can work with up to 16 float32s in two AVX256 registers.
62+
float inline DotProdPatch_AVX256_32f(const float *buf, const float *filter)
63+
{
64+
__m256 a1_ps = _mm256_load_ps(buf);
65+
__m256 b1_ps = _mm256_load_ps(filter);
66+
__m256 a2_ps = _mm256_load_ps(buf+8);
67+
__m256 b2_ps = _mm256_load_ps(filter+8);
68+
69+
__m256 sum1 = _mm256_mul_ps(a1_ps, b1_ps);
70+
__m256 sum2 = _mm256_mul_ps(a2_ps, b2_ps);
71+
72+
#pragma unroll
73+
for (int i = 1; i < 8; i++)
74+
{
75+
a1_ps = _mm256_load_ps(buf + i * 16);
76+
a2_ps = _mm256_load_ps(buf + i * 16 + 8);
77+
b1_ps = _mm256_load_ps(filter + i * 16);
78+
b2_ps = _mm256_load_ps(filter + i * 16 + 8);
79+
80+
// compute dot prod using fmadd
81+
sum1 = _mm256_fmadd_ps(a1_ps, b1_ps, sum1);
82+
sum2 = _mm256_fmadd_ps(a2_ps, b2_ps, sum2);
83+
}
84+
85+
// sumitup adds all 16 float values in sum(zmm) and returns a single float value
86+
return sumitup_ps_256(_mm256_add_ps(sum1, sum2));
87+
}
88+

Library/Raisr_AVX256.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ inline void load3x3_ps(float *img, unsigned int width, unsigned int height, unsi
2525
*out_8neighbors_ps = _mm256_insertf128_ps(_mm256_castps128_ps256(rowlo_f), rowhi_f, 1);
2626
}
2727

28-
2928
inline __m256i compare3x3_ps(__m256 a, __m256 b, __m256i highbit_epi32);
3029
inline int sumitup_256_epi32(__m256i acc);
31-
int inline CTRandomness_AVX256_32f(float *inYUpscaled32f, int cols, int r, int c, int pix);
3230
inline float sumitup_ps_256(__m256 acc);
31+
32+
int inline CTRandomness_AVX256_32f(float *inYUpscaled32f, int cols, int r, int c, int pix);
33+
float inline DotProdPatch_AVX256_32f(const float *buf, const float *filter);

0 commit comments

Comments
 (0)