Skip to content

Commit 95064a2

Browse files
authored
Fix EqualFoldString (#92)
* add fixed default implementation and test * fix avx implementation
1 parent 3c246ad commit 95064a2

5 files changed

Lines changed: 231 additions & 123 deletions

File tree

ascii/ascii.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,22 @@ func hasMore32(x, n uint32) bool {
5454
func unsafeString(b []byte) string {
5555
return *(*string)(unsafe.Pointer(&b))
5656
}
57+
58+
var lower = [256]byte{
59+
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
60+
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
61+
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
62+
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
63+
0x40, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
64+
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
65+
0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
66+
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
67+
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f,
68+
0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f,
69+
0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
70+
0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
71+
0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
72+
0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
73+
0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
74+
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff,
75+
}

ascii/ascii_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package ascii
22

33
import (
4+
"bytes"
5+
"fmt"
46
"strings"
57
"testing"
68
"unicode/utf8"
@@ -160,6 +162,37 @@ func TestHasSuffixFold(t *testing.T) {
160162
}
161163
}
162164

165+
func TestEqualFoldASCII(t *testing.T) {
166+
pairs := [...][2]byte{
167+
{0, ' '},
168+
{'@', '`'},
169+
{'[', '{'},
170+
{'_', 127},
171+
}
172+
173+
for _, pair := range pairs {
174+
t.Run(fmt.Sprintf("0x%02x=0x%02x", pair[0], pair[1]), func(t *testing.T) {
175+
for i := 1; i <= 256; i++ {
176+
a := bytes.Repeat([]byte{'x'}, i)
177+
b := bytes.Repeat([]byte{'X'}, i)
178+
179+
if !EqualFold(a, b) {
180+
t.Errorf("%q does not match %q", a, b)
181+
break
182+
}
183+
184+
a[0] = pair[0]
185+
b[0] = pair[1]
186+
187+
if EqualFold(a, b) {
188+
t.Errorf("%q matches %q", a, b)
189+
break
190+
}
191+
}
192+
})
193+
}
194+
}
195+
163196
func TestEqualFold(t *testing.T) {
164197
// Only test valid UTF-8 otherwise ToUpper/ToLower will convert invalid
165198
// characters to UTF-8 placeholders, which breaks the case-insensitive

ascii/equal_fold.go

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
//go:generate go run equal_fold_asm.go -out equal_fold_amd64.s -stubs equal_fold_amd64.go
22
package ascii
33

4-
import (
5-
"unsafe"
6-
)
4+
import "unsafe"
75

86
// EqualFold is a version of bytes.EqualFold designed to work on ASCII input
97
// instead of UTF-8.
@@ -35,13 +33,15 @@ func EqualFoldString(a, b string) bool {
3533
n := uintptr(len(a))
3634
p := *(*unsafe.Pointer)(unsafe.Pointer(&a))
3735
q := *(*unsafe.Pointer)(unsafe.Pointer(&b))
36+
c := byte(0)
37+
3838
// Pre-check to avoid the other tests that would all evaluate to false.
3939
// For very small strings, this helps reduce the processing overhead.
4040
if n >= 8 {
4141
// If there is more than 32 bytes to copy, use the AVX optimized version,
4242
// otherwise the overhead of the function call tends to be greater than
4343
// looping 2 or 3 times over 8 bytes.
44-
if n > 32 && asm.equalFoldAVX2 != nil {
44+
if n >= 32 && asm.equalFoldAVX2 != nil {
4545
if asm.equalFoldAVX2((*byte)(p), (*byte)(q), n) == 0 {
4646
return false
4747
}
@@ -50,51 +50,51 @@ func EqualFoldString(a, b string) bool {
5050
q = unsafe.Pointer(uintptr(q) + k)
5151
n -= k
5252
}
53-
54-
for n > 8 {
55-
const mask = 0xDFDFDFDFDFDFDFDF
56-
57-
if (*(*uint64)(p) & mask) != (*(*uint64)(q) & mask) {
53+
for n >= 8 {
54+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 0))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 0))]
55+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 1))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 1))]
56+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 2))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 2))]
57+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 3))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 3))]
58+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 4))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 4))]
59+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 5))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 5))]
60+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 6))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 6))]
61+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 7))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 7))]
62+
63+
if c != 0 {
5864
return false
5965
}
6066

6167
p = unsafe.Pointer(uintptr(p) + 8)
6268
q = unsafe.Pointer(uintptr(q) + 8)
6369
n -= 8
6470
}
65-
66-
if n == 8 {
67-
const mask = 0xDFDFDFDFDFDFDFDF
68-
return (*(*uint64)(p) & mask) == (*(*uint64)(q) & mask)
69-
}
70-
}
71-
72-
if n > 4 {
73-
const mask = 0xDFDFDFDF
74-
75-
if (*(*uint32)(p) & mask) != (*(*uint32)(q) & mask) {
76-
return false
77-
}
78-
79-
p = unsafe.Pointer(uintptr(p) + 4)
80-
q = unsafe.Pointer(uintptr(q) + 4)
81-
n -= 4
8271
}
8372

8473
switch n {
74+
case 7:
75+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 6))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 6))]
76+
fallthrough
77+
case 6:
78+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 5))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 5))]
79+
fallthrough
80+
case 5:
81+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 4))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 4))]
82+
fallthrough
8583
case 4:
86-
return (*(*uint32)(p) & 0xDFDFDFDF) == (*(*uint32)(q) & 0xDFDFDFDF)
84+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 3))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 3))]
85+
fallthrough
8786
case 3:
88-
x := uint32(*(*uint16)(p)) | uint32(*(*uint8)(unsafe.Pointer(uintptr(p) + 2)))<<16
89-
y := uint32(*(*uint16)(q)) | uint32(*(*uint8)(unsafe.Pointer(uintptr(q) + 2)))<<16
90-
return (x & 0xDFDFDF) == (y & 0xDFDFDF)
87+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 2))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 2))]
88+
fallthrough
9189
case 2:
92-
return (*(*uint16)(p) & 0xDFDF) == (*(*uint16)(q) & 0xDFDF)
90+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 1))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 1))]
91+
fallthrough
9392
case 1:
94-
return (*(*uint8)(p) & 0xDF) == (*(*uint8)(q) & 0xDF)
95-
default:
96-
return true
93+
c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 0))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 0))]
9794
}
95+
96+
return c == 0
97+
9898
}
9999

100100
func HasPrefixFoldString(s, prefix string) bool {

ascii/equal_fold_amd64.s

Lines changed: 82 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,60 +5,99 @@
55
// func equalFoldAVX2(a *byte, b *byte, n uintptr) int
66
// Requires: AVX, AVX2, SSE4.1
77
TEXT ·equalFoldAVX2(SB), NOSPLIT, $0-32
8-
MOVQ a+0(FP), AX
9-
MOVQ b+8(FP), CX
10-
MOVQ n+16(FP), DX
11-
SHRQ $0x04, DX
12-
MOVQ $0x0000000000000000, BX
13-
MOVQ $0xdfdfdfdfdfdfdfdf, BP
14-
PINSRQ $0x00, BP, X0
15-
PINSRQ $0x01, BP, X0
16-
VPBROADCASTQ X0, Y1
8+
MOVQ a+0(FP), CX
9+
MOVQ b+8(FP), DX
10+
MOVQ n+16(FP), BX
11+
XORQ AX, AX
12+
SHRQ $0x04, BX
13+
XORQ SI, SI
14+
MOVB $0x20, DI
15+
PINSRB $0x00, DI, X6
16+
VPBROADCASTB X6, Y6
17+
MOVB $0x1f, DI
18+
PINSRB $0x00, DI, X7
19+
VPBROADCASTB X7, Y7
20+
MOVB $0x9a, DI
21+
PINSRB $0x00, DI, X8
22+
VPBROADCASTB X8, Y8
23+
MOVB $0x01, DI
24+
PINSRB $0x00, DI, X9
25+
VPBROADCASTB X9, Y9
1726

1827
loop64:
19-
CMPQ DX, $0x04
20-
JL loop32
21-
VPAND (AX), Y1, Y2
22-
VPAND (CX), Y1, Y3
23-
VPCMPEQB Y3, Y2, Y2
24-
VPAND 32(AX), Y1, Y3
25-
VPAND 32(CX), Y1, Y4
28+
CMPQ BX, $0x04
29+
JB cmp32
30+
VMOVDQU (CX)(AX*1), Y0
31+
VMOVDQU 32(CX)(AX*1), Y3
32+
VMOVDQU (DX)(AX*1), Y1
33+
VMOVDQU 32(DX)(AX*1), Y4
34+
VXORPD Y0, Y1, Y1
35+
VPCMPEQB Y6, Y1, Y2
36+
VORPD Y6, Y0, Y0
37+
VPADDB Y7, Y0, Y0
38+
VPCMPGTB Y0, Y8, Y0
39+
VPAND Y2, Y0, Y0
40+
VPAND Y9, Y0, Y0
41+
VPSLLW $0x05, Y0, Y0
42+
VPCMPEQB Y1, Y0, Y0
43+
VXORPD Y3, Y4, Y4
44+
VPCMPEQB Y6, Y4, Y5
45+
VORPD Y6, Y3, Y3
46+
VPADDB Y7, Y3, Y3
47+
VPCMPGTB Y3, Y8, Y3
48+
VPAND Y5, Y3, Y3
49+
VPAND Y9, Y3, Y3
50+
VPSLLW $0x05, Y3, Y3
2651
VPCMPEQB Y4, Y3, Y3
27-
VPAND Y3, Y2, Y2
28-
VPMOVMSKB Y2, BP
29-
CMPL BP, $0xffffffff
30-
JNE done
52+
VPAND Y3, Y0, Y0
3153
ADDQ $0x40, AX
32-
ADDQ $0x40, CX
33-
SUBQ $0x04, DX
54+
SUBQ $0x04, BX
55+
VPMOVMSKB Y0, DI
56+
CMPL DI, $0xffffffff
57+
JNE done
3458
JMP loop64
3559

36-
loop32:
37-
CMPQ DX, $0x02
38-
JL loop16
39-
VPAND (AX), Y1, Y2
40-
VPAND (CX), Y1, Y3
41-
VPCMPEQB Y3, Y2, Y2
42-
VPMOVMSKB Y2, BP
43-
CMPL BP, $0xffffffff
44-
JNE done
60+
cmp32:
61+
CMPQ BX, $0x02
62+
JB cmp16
63+
VMOVDQU (CX)(AX*1), Y0
64+
VMOVDQU (DX)(AX*1), Y1
65+
VXORPD Y0, Y1, Y1
66+
VPCMPEQB Y6, Y1, Y2
67+
VORPD Y6, Y0, Y0
68+
VPADDB Y7, Y0, Y0
69+
VPCMPGTB Y0, Y8, Y0
70+
VPAND Y2, Y0, Y0
71+
VPAND Y9, Y0, Y0
72+
VPSLLW $0x05, Y0, Y0
73+
VPCMPEQB Y1, Y0, Y0
4574
ADDQ $0x20, AX
46-
ADDQ $0x20, CX
47-
SUBQ $0x02, DX
75+
SUBQ $0x02, BX
76+
VPMOVMSKB Y0, DI
77+
CMPL DI, $0xffffffff
78+
JNE done
4879

49-
loop16:
50-
CMPQ DX, $0x00
51-
JE equal
52-
VPAND (AX), X0, X1
53-
VPAND (CX), X0, X0
54-
VPCMPEQB X0, X1, X1
55-
VPMOVMSKB X1, BP
56-
CMPL BP, $0x0000ffff
80+
cmp16:
81+
CMPQ BX, $0x01
82+
JB equal
83+
VMOVDQU (CX)(AX*1), X0
84+
VMOVDQU (DX)(AX*1), X1
85+
VXORPD X0, X1, X1
86+
VPCMPEQB X6, X1, X2
87+
VORPD X6, X0, X0
88+
VPADDB X7, X0, X0
89+
VPCMPGTB X0, X8, X0
90+
VPAND X2, X0, X0
91+
VPAND X9, X0, X0
92+
VPSLLW $0x05, X0, X0
93+
VPCMPEQB X1, X0, X0
94+
VPMOVMSKB X0, DI
95+
CMPL DI, $0x0000ffff
5796
JNE done
5897

5998
equal:
60-
MOVQ $0x0000000000000001, BX
99+
MOVQ $0x0000000000000001, SI
61100

62101
done:
63-
MOVQ BX, ret+24(FP)
102+
MOVQ SI, ret+24(FP)
64103
RET

0 commit comments

Comments
 (0)