Skip to content

Commit b74992f

Browse files
cabirdmexiaoxial
authored andcommitted
added computeGTWG_Segment AVX2
1 parent ea6e83b commit b74992f

5 files changed

Lines changed: 190 additions & 31 deletions

File tree

Library/Raisr.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1147,7 +1147,15 @@ RNLERRORTYPE processSegment(VideoDataType *srcY, VideoDataType *final_outY, Blen
11471147
#pragma unroll(unrollSizePatchBased / 2)
11481148
for (pix = 0; pix < unrollSizePatchBased / 2; pix++)
11491149
{
1150-
computeGTWG_Segment_AVX512_32f(pSeg32f, rows, cols, rOffset, c + 2 * pix, &GTWG[2 * pix], &pixbuf[2 * pix][0], &pixbuf[2 * pix + 1][0]);
1150+
if (gAsmType == AVX2)
1151+
computeGTWG_Segment_AVX256_32f(pSeg32f, rows, cols, rOffset, c + 2 * pix, &GTWG[2 * pix], &pixbuf[2 * pix][0], &pixbuf[2 * pix + 1][0]);
1152+
else if (gAsmType == AVX512)
1153+
computeGTWG_Segment_AVX512_32f(pSeg32f, rows, cols, rOffset, c + 2 * pix, &GTWG[2 * pix], &pixbuf[2 * pix][0], &pixbuf[2 * pix + 1][0]);
1154+
else
1155+
{
1156+
std::cout << "expected avx512 or avx2, but got " << gAsmType << std::endl;
1157+
return RNLErrorBadParameter;
1158+
}
11511159
}
11521160

11531161
GetHashValue_AVX256_32f(GTWG, passIdx, hashValue);

Library/Raisr_AVX256.cpp

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,171 @@ inline float sumitup_ps_256(__m256 acc)
5858
return _mm_cvtss_f32(r1);
5959
}
6060

61+
inline __m256 shiftL_AVX256(__m256 r)
62+
{
63+
return _mm256_permutevar8x32_ps(r, _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1));
64+
}
65+
66+
inline __m256 shiftR_AVX256(__m256 r)
67+
{
68+
return _mm256_permutevar8x32_ps(r, _mm256_set_epi32(6, 5, 4, 3, 2, 1, 0, 7));
69+
}
70+
71+
inline __m256 GetGx_AVX256(__m256 r1, __m256 r3)
72+
{
73+
return _mm256_sub_ps(r3, r1);
74+
}
75+
76+
inline __m256 GetGy_AVX256(__m256 r2)
77+
{
78+
return _mm256_sub_ps(shiftL_AVX256(r2), shiftR_AVX256(r2));
79+
}
80+
81+
inline __m128 GetFirstHalf(__m256 n)
82+
{
83+
return _mm256_extractf128_ps(n, 0);
84+
}
85+
86+
inline __m128 GetLastHalf(__m256 n)
87+
{
88+
return _mm256_extractf128_ps(n, 1);
89+
}
90+
91+
template <int halfIndex>
92+
inline __m256 SetFirstVal(__m256 n, __m128 halfWithValue) {
93+
__m128 newHalf = _mm_insert_ps(_mm256_extractf128_ps(n, 0), halfWithValue, halfIndex);
94+
return _mm256_insertf128_ps(n, newHalf, 0);
95+
}
96+
97+
template <int halfIndex>
98+
inline __m256 SetLastVal(__m256 n, __m128 halfWithValue) {
99+
__m128 newHalf = _mm_insert_ps(_mm256_extractf128_ps(n, 1), halfWithValue, halfIndex);
100+
return _mm256_insertf128_ps(n, newHalf, 1);
101+
}
102+
103+
inline __m256 GetGy_AVX256Hi(__m256 xlo, __m256 xhi)
104+
{
105+
// ideally we do some cross lane permute, but one doesnt seem to exist. Our approach instead is to save the original values,
106+
// do our in-lane permutes, then insert additional values on the ends to achieve correct behavior
107+
__m128 xlohi = GetLastHalf(xlo);
108+
__m128 xlolo = GetFirstHalf(xlo);
109+
110+
__m256 newloLeft = SetLastVal<0x30>(shiftL_AVX256(xhi), xlolo);
111+
__m256 newloRight = SetFirstVal<0xC0>(shiftR_AVX256(xhi), xlohi);
112+
__m256 ret = _mm256_sub_ps(newloLeft, newloRight);
113+
return ret;
114+
}
115+
116+
inline __m256 GetGy_AVX256Lo(__m256 xlo, __m256 xhi)
117+
{
118+
// ideally we do some cross lane permute, but one doesnt seem to exist. Our approach instead is to save the original values,
119+
// do our in-lane permutes, then insert additional values on the ends to achieve correct behavior
120+
__m128 xhilo = GetFirstHalf(xhi);
121+
__m128 xhihi = GetLastHalf(xhi);
122+
__m256 newloLeft = SetLastVal <0x30>(shiftL_AVX256(xlo), xhilo);
123+
__m256 newloRight = SetFirstVal<0xC0>(shiftR_AVX256(xlo), xhihi);
124+
125+
__m256 ret = _mm256_sub_ps(newloLeft, newloRight);
126+
return ret;
127+
}
128+
129+
inline __m256 GetGTWG_AVX256(__m256 acc, __m256 a, __m256 w, __m256 b)
130+
{
131+
return _mm256_fmadd_ps(_mm256_mul_ps(a, w), b, acc);
132+
}
133+
134+
void inline computeGTWG_Segment_AVX256_32f(const float *img, const int nrows, const int ncols, const int r, const int col, float GTWG[][4], float *buf1, float *buf2)
135+
{
136+
// offset is the starting position(top left) of the block which centered by (r, c)
137+
int offset = (r - gLoopMargin) * ncols + col - gLoopMargin;
138+
const float *p1 = img + offset;
139+
140+
__m256 gtwg0A1 = _mm256_setzero_ps(), gtwg0A2 = _mm256_setzero_ps();
141+
__m256 gtwg0B1 = _mm256_setzero_ps(), gtwg0B2 = _mm256_setzero_ps();
142+
__m256 gtwg1A1 = _mm256_setzero_ps(), gtwg1A2 = _mm256_setzero_ps();
143+
__m256 gtwg1B1 = _mm256_setzero_ps(), gtwg1B2 = _mm256_setzero_ps();
144+
__m256 gtwg3A1 = _mm256_setzero_ps(), gtwg3A2 = _mm256_setzero_ps();
145+
__m256 gtwg3B1 = _mm256_setzero_ps(), gtwg3B2 = _mm256_setzero_ps();
146+
147+
// load 2 rows
148+
__m256 a1 = _mm256_loadu_ps(p1);
149+
__m256 a2 = _mm256_loadu_ps(p1+8);
150+
p1 += ncols;
151+
__m256 b1 = _mm256_loadu_ps(p1);
152+
__m256 b2 = _mm256_loadu_ps(p1+8);
153+
#pragma unroll
154+
for (int i = 0; i < gPatchSize; i++)
155+
{
156+
// process patchSize rows
157+
// load next row
158+
p1 += ncols;
159+
__m256 c1 = _mm256_loadu_ps(p1);
160+
__m256 c2 = _mm256_loadu_ps(p1+8);
161+
__m256 w1, w2;
162+
if(gBitDepth == 8) {
163+
w1 = _mm256_loadu_ps(gGaussian2D8bit[i]);
164+
w2 = _mm256_loadu_ps(gGaussian2D8bit[i]+8);
165+
} else if (gBitDepth == 10) {
166+
w1 = _mm256_loadu_ps(gGaussian2D10bit[i]);
167+
w2 = _mm256_loadu_ps(gGaussian2D10bit[i]+8);
168+
} else {
169+
w1 = _mm256_loadu_ps(gGaussian2D16bit[i]);
170+
w2 = _mm256_loadu_ps(gGaussian2D16bit[i]+8);
171+
}
172+
173+
const __m256 gxi1 = GetGx_AVX256(a1, c1);
174+
const __m256 gxi2 = GetGx_AVX256(a2, c2);
175+
176+
const __m256 gyi1 = GetGy_AVX256Lo(b1,b2);
177+
const __m256 gyi2 = GetGy_AVX256Hi(b1,b2);
178+
179+
gtwg0A1 = GetGTWG_AVX256(gtwg0A1, gxi1, w1, gxi1);
180+
gtwg0A2 = GetGTWG_AVX256(gtwg0A2, gxi2, w2, gxi2);
181+
gtwg1A1 = GetGTWG_AVX256(gtwg1A1, gxi1, w1, gyi1);
182+
gtwg1A2 = GetGTWG_AVX256(gtwg1A2, gxi2, w2, gyi2);
183+
gtwg3A1 = GetGTWG_AVX256(gtwg3A1, gyi1, w1, gyi1);
184+
gtwg3A2 = GetGTWG_AVX256(gtwg3A2, gyi2, w2, gyi2);
185+
186+
// Store last bit for shiftR and mask
187+
__m128 xlohi = GetLastHalf(w1);
188+
__m128 xhihi = GetLastHalf(w2);
189+
w1 = SetFirstVal<0xC0>(shiftR_AVX256(w1), xhihi);
190+
w2 = SetFirstVal<0xC0>(shiftR_AVX256(w2), xlohi);
191+
192+
gtwg0B1 = GetGTWG_AVX256(gtwg0B1, gxi1, w1, gxi1);
193+
gtwg0B2 = GetGTWG_AVX256(gtwg0B2, gxi2, w2, gxi2);
194+
gtwg1B1 = GetGTWG_AVX256(gtwg1B1, gxi1, w1, gyi1);
195+
gtwg1B2 = GetGTWG_AVX256(gtwg1B2, gxi2, w2, gyi2);
196+
gtwg3B1 = GetGTWG_AVX256(gtwg3B1, gyi1, w1, gyi1);
197+
gtwg3B2 = GetGTWG_AVX256(gtwg3B2, gyi2, w2, gyi2);
198+
199+
// skip one, store next 11 bits. The two masks are 0xfe, 0x0f
200+
int lastbit = 0x80000000;
201+
_mm256_maskstore_ps(buf1 + gPatchSize * i - 1, _mm256_setr_epi32(0, lastbit, lastbit, lastbit, lastbit, lastbit, lastbit, lastbit), b1);
202+
_mm256_maskstore_ps(buf1 + gPatchSize * i - 1 + 8, _mm256_setr_epi32(lastbit, lastbit, lastbit, lastbit, 0,0,0,0), b2);
203+
// skip two, store next 11 bits. The two masks are 0xfc, 0x1f
204+
_mm256_maskstore_ps(buf2 + gPatchSize * i - 2, _mm256_setr_epi32(0,0,lastbit,lastbit,lastbit,lastbit,lastbit,lastbit), b1);
205+
_mm256_maskstore_ps(buf2 + gPatchSize * i - 2 + 8, _mm256_setr_epi32(lastbit,lastbit,lastbit,lastbit,lastbit,0,0,0), b2);
206+
a1 = b1;
207+
a2 = b2;
208+
b1 = c1;
209+
b2 = c2;
210+
}
211+
212+
GTWG[0][0] = sumitup_ps_256(_mm256_add_ps(gtwg0A1, gtwg0A2));
213+
GTWG[0][1] = sumitup_ps_256(_mm256_add_ps(gtwg1A1, gtwg1A2));
214+
GTWG[0][3] = sumitup_ps_256(_mm256_add_ps(gtwg3A1, gtwg3A2));
215+
GTWG[0][2] = GTWG[0][1];
216+
217+
GTWG[1][0] = sumitup_ps_256(_mm256_add_ps(gtwg0B1, gtwg0B2));
218+
GTWG[1][1] = sumitup_ps_256(_mm256_add_ps(gtwg1B1, gtwg1B2));
219+
GTWG[1][3] = sumitup_ps_256(_mm256_add_ps(gtwg3B1, gtwg3B2));
220+
GTWG[1][2] = GTWG[1][1];
221+
222+
return;
223+
}
224+
225+
61226
// AVX2 version: for now, gPatchSize must be <= 16 because we can work with up to 16 float32s in two AVX256 registers.
62227
float inline DotProdPatch_AVX256_32f(const float *buf, const float *filter)
63228
{

Library/Raisr_AVX256.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@ inline void load3x3_ps(float *img, unsigned int width, unsigned int height, unsi
2525
*out_8neighbors_ps = _mm256_insertf128_ps(_mm256_castps128_ps256(rowlo_f), rowhi_f, 1);
2626
}
2727

28-
inline __m256i compare3x3_AVX256_32f(__m256 a, __m256 b, __m256i highbit_epi32);
29-
inline int sumitup_256_epi32(__m256i acc);
30-
inline float sumitup_ps_256(__m256 acc);
31-
3228
int inline CTRandomness_AVX256_32f(float *inYUpscaled32f, int cols, int r, int c, int pix);
3329
float inline DotProdPatch_AVX256_32f(const float *buf, const float *filter);
30+
void inline computeGTWG_Segment_AVX256_32f(const float *img, const int nrows, const int ncols, const int r, const int col, float GTWG[][4], float *buf1, float *buf2);

Library/Raisr_AVX512.cpp

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,26 @@ inline float sumitup_ps_512(__m512 acc)
4040
const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
4141
return _mm_cvtss_f32(r1);
4242
}
43-
inline __m512 shiftL(__m512 r)
43+
inline __m512 shiftL_AVX512(__m512 r)
4444
{
4545
return _mm512_permutexvar_ps(_mm512_set_epi32(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), r);
4646
}
47-
inline __m512 shiftR(__m512 r)
47+
inline __m512 shiftR_AVX512(__m512 r)
4848
{
4949
return _mm512_permutexvar_ps(_mm512_set_epi32(14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 15), r);
5050
}
5151

52-
inline __m512 GetGx(__m512 r1, __m512 r3)
52+
inline __m512 GetGx_AVX512(__m512 r1, __m512 r3)
5353
{
5454
return _mm512_sub_ps(r3, r1);
5555
}
5656

57-
inline __m512 GetGy(__m512 r2)
57+
inline __m512 GetGy_AVX512(__m512 r2)
5858
{
59-
return _mm512_sub_ps(shiftL(r2), shiftR(r2));
59+
return _mm512_sub_ps(shiftL_AVX512(r2), shiftR_AVX512(r2));
6060
}
6161

62-
inline __m512 GetGTWG(__m512 acc, __m512 a, __m512 w, __m512 b)
62+
inline __m512 GetGTWG_AVX512(__m512 acc, __m512 a, __m512 w, __m512 b)
6363
{
6464
return _mm512_fmadd_ps(_mm512_mul_ps(a, w), b, acc);
6565
}
@@ -80,9 +80,6 @@ void inline computeGTWG_Segment_AVX512_32f(const float *img, const int nrows, co
8080
#pragma unroll
8181
for (int i = 0; i < gPatchSize; i++)
8282
{
83-
// memcpy(buf1+gPatchSize*i, p1+1, sizeof(float)*gPatchSize);
84-
// memcpy(buf2+gPatchSize*i, p1+2, sizeof(float)*gPatchSize);
85-
8683
// process patchSize rows
8784
// load next row
8885
p1 += ncols;
@@ -101,17 +98,17 @@ void inline computeGTWG_Segment_AVX512_32f(const float *img, const int nrows, co
10198
w = _mm512_loadu_ps(gGaussian2D16bit[i]);
10299
}
103100

104-
const __m512 gxi = GetGx(a, c);
105-
const __m512 gyi = GetGy(b);
101+
const __m512 gxi = GetGx_AVX512(a, c);
102+
const __m512 gyi = GetGy_AVX512(b);
106103

107-
gtwg0A = GetGTWG(gtwg0A, gxi, w, gxi);
108-
gtwg1A = GetGTWG(gtwg1A, gxi, w, gyi);
109-
gtwg3A = GetGTWG(gtwg3A, gyi, w, gyi);
104+
gtwg0A = GetGTWG_AVX512(gtwg0A, gxi, w, gxi);
105+
gtwg1A = GetGTWG_AVX512(gtwg1A, gxi, w, gyi);
106+
gtwg3A = GetGTWG_AVX512(gtwg3A, gyi, w, gyi);
110107

111-
w = shiftR(w);
112-
gtwg0B = GetGTWG(gtwg0B, gxi, w, gxi);
113-
gtwg1B = GetGTWG(gtwg1B, gxi, w, gyi);
114-
gtwg3B = GetGTWG(gtwg3B, gyi, w, gyi);
108+
w = shiftR_AVX512(w);
109+
gtwg0B = GetGTWG_AVX512(gtwg0B, gxi, w, gxi);
110+
gtwg1B = GetGTWG_AVX512(gtwg1B, gxi, w, gyi);
111+
gtwg3B = GetGTWG_AVX512(gtwg3B, gyi, w, gyi);
115112

116113
_mm512_mask_storeu_ps(buf1 + gPatchSize * i - 1, 0x0ffe, b);
117114
_mm512_mask_storeu_ps(buf2 + gPatchSize * i - 2, 0x1ffc, b);

Library/Raisr_AVX512.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,6 @@
77
#pragma once
88
#include <immintrin.h>
99

10-
inline __mmask8 compare3x3_ps_AVX512(__m256 a, __m256 b);
11-
inline float sumitup_ps_512(__m512 acc);
12-
inline __m512 shiftL(__m512 r);
13-
inline __m512 shiftR(__m512 r);
14-
inline __m512 GetGx(__m512 r1, __m512 r3);
15-
inline __m512 GetGy(__m512 r2);
16-
inline __m512 GetGTWG(__m512 acc, __m512 a, __m512 w, __m512 b);
17-
1810
void inline computeGTWG_Segment_AVX512_32f(const float *img, const int nrows, const int ncols, const int r, const int col, float GTWG[][4], float *buf1, float *buf2);
1911
int inline CTRandomness_AVX512_32f(float *inYUpscaled32f, int cols, int r, int c, int pix);
2012
float inline DotProdPatch_AVX512_32f(const float *buf, const float *filter);

0 commit comments

Comments
 (0)