Skip to content

Commit f0629e5

Browse files
authored
Merge pull request #201 from SenseUnit/persistent_tls_ticket_cache
Persistent TLS ticket cache
2 parents 6e83d58 + c54c306 commit f0629e5

6 files changed

Lines changed: 302 additions & 3 deletions

File tree

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,8 @@ Usage of /home/user/go/bin/dumbproxy:
648648
mark TLS sessions with cookie-like unique session IDs (default true)
649649
-tls-session-key value
650650
override TLS server session keys. Key must be provided as hex-encoded 32-byte string. This option can be repeated multiple times, first key will be used to create session tickets. Empty value resets the list.
651+
-tls-session-cache-db string
652+
location of TLS client session cache DB
651653
-trusttunnel
652654
enable TrustTunnel protocol extensions (default true)
653655
-unix-sock-mode value

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ require (
1515
github.com/refraction-networking/utls v1.8.2
1616
github.com/tg123/go-htpasswd v1.2.4
1717
github.com/things-go/go-socks5 v0.1.0
18+
go.etcd.io/bbolt v1.4.3
1819
golang.org/x/crypto v0.48.0
1920
golang.org/x/crypto/x509roots/fallback v0.0.0-20260213171211-a408498e5541
2021
golang.org/x/net v0.51.0

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ
5959
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
6060
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
6161
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
62+
go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo=
63+
go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E=
6264
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
6365
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
6466
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=

main.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ type CLIArgs struct {
315315
tlsALPNEnabled bool
316316
tlsSessionKeys [][32]byte
317317
tlsCookies bool
318+
tlsSessionCacheDB string
318319
bwLimit forward.LimitSpec
319320
bwBurst int64
320321
bwSeparate bool
@@ -507,6 +508,7 @@ func parse_args() *CLIArgs {
507508
return nil
508509
})
509510
flag.BoolVar(&args.tlsCookies, "tls-cookies", true, "mark TLS sessions with cookie-like unique session IDs")
511+
flag.StringVar(&args.tlsSessionCacheDB, "tls-session-cache-db", "", "location of TLS client session cache DB")
510512
flag.Func("config", "read configuration from file with space-separated keys and values", readConfig)
511513
flag.Parse()
512514
// pull up remaining parameters from other BW-related arguments
@@ -617,6 +619,19 @@ func run() int {
617619
tlsSessionLogger := clog.NewCondLogger(log.New(logWriter, "TLSSESS :",
618620
log.LstdFlags|log.Lshortfile),
619621
args.verbosity)
622+
tlsCacheLogger := clog.NewCondLogger(log.New(logWriter, "TLSCACHE:",
623+
log.LstdFlags|log.Lshortfile),
624+
args.verbosity)
625+
626+
// setup TLS session cache
627+
if args.tlsSessionCacheDB != "" {
628+
cache, err := tlsutil.NewPersistentClientSessionCache(args.tlsSessionCacheDB, tlsCacheLogger)
629+
if err != nil {
630+
mainLogger.Critical("Failed to instantiate TLS session cache: %v", err)
631+
return 3
632+
}
633+
tlsutil.SessionCache = cache
634+
}
620635

621636
// setup auth provider
622637
authProvider, err := auth.NewAuth(args.auth, authLogger)

tlsutil/ccache.go

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
package tlsutil
2+
3+
import (
4+
"bytes"
5+
"crypto/tls"
6+
"encoding/binary"
7+
"errors"
8+
"fmt"
9+
"io"
10+
"os"
11+
"path/filepath"
12+
"runtime/debug"
13+
"time"
14+
15+
"go.etcd.io/bbolt"
16+
17+
clog "github.com/SenseUnit/dumbproxy/log"
18+
)
19+
20+
var (
21+
bucketName = []byte("tickets")
22+
currentFormatVersion = getSessionFormatVer()
23+
ErrFormatVersionMismatch = errors.New("format version mismatch")
24+
)
25+
26+
func getSessionFormatVer() string {
27+
bi, ok := debug.ReadBuildInfo()
28+
if !ok {
29+
return "unknown"
30+
}
31+
return bi.GoVersion
32+
}
33+
34+
type sessionCacheEntry struct {
35+
formatVersion string
36+
ticket []byte
37+
state []byte
38+
}
39+
40+
func uintToVarintBytes(x uint) []byte {
41+
buf := make([]byte, binary.MaxVarintLen64)
42+
n := binary.PutUvarint(buf, uint64(x))
43+
return buf[:n]
44+
}
45+
46+
func (e *sessionCacheEntry) MarshalBinary() (data []byte, err error) {
47+
buf := new(bytes.Buffer)
48+
_, _ = buf.Write(uintToVarintBytes(uint(len(e.formatVersion))))
49+
_, _ = buf.WriteString(e.formatVersion)
50+
_, _ = buf.Write(uintToVarintBytes(uint(len(e.ticket))))
51+
_, _ = buf.Write(e.ticket)
52+
_, _ = buf.Write(uintToVarintBytes(uint(len(e.state))))
53+
_, _ = buf.Write(e.state)
54+
return append([]byte(nil), buf.Bytes()...), nil
55+
}
56+
57+
func (e *sessionCacheEntry) UnmarshalBinary(data []byte) error {
58+
r := bytes.NewReader(data)
59+
60+
formatVerLen, err := binary.ReadUvarint(r)
61+
if err != nil {
62+
return fmt.Errorf("unable to read length of format version field: %w", err)
63+
}
64+
formatVerBytes := make([]byte, formatVerLen)
65+
_, err = io.ReadFull(r, formatVerBytes)
66+
if err != nil {
67+
return fmt.Errorf("unable to read format version field: %w", err)
68+
}
69+
70+
ticketLen, err := binary.ReadUvarint(r)
71+
if err != nil {
72+
return fmt.Errorf("unable to read length of ticket field: %w", err)
73+
}
74+
ticketBytes := make([]byte, ticketLen)
75+
_, err = io.ReadFull(r, ticketBytes)
76+
if err != nil {
77+
return fmt.Errorf("unable to read ticket field: %w", err)
78+
}
79+
80+
stateLen, err := binary.ReadUvarint(r)
81+
if err != nil {
82+
return fmt.Errorf("unable to read length of state field: %w", err)
83+
}
84+
stateBytes := make([]byte, stateLen)
85+
_, err = io.ReadFull(r, stateBytes)
86+
if err != nil {
87+
return fmt.Errorf("unable to read state field: %w", err)
88+
}
89+
90+
e.formatVersion = string(formatVerBytes)
91+
e.ticket = ticketBytes
92+
e.state = stateBytes
93+
94+
return nil
95+
}
96+
97+
func clientSessionStateToBytes(cs *tls.ClientSessionState) ([]byte, error) {
98+
ticket, state, err := cs.ResumptionState()
99+
if err != nil {
100+
return nil, err
101+
}
102+
stateBytes, err := state.Bytes()
103+
if err != nil {
104+
return nil, err
105+
}
106+
return (&sessionCacheEntry{
107+
formatVersion: currentFormatVersion,
108+
ticket: ticket,
109+
state: stateBytes,
110+
}).MarshalBinary()
111+
}
112+
113+
func clientSessionStateFromBytes(data []byte) (*tls.ClientSessionState, error) {
114+
sce := new(sessionCacheEntry)
115+
err := sce.UnmarshalBinary(data)
116+
if err != nil {
117+
return nil, fmt.Errorf("TLS session state unmarshaling failed: %w", err)
118+
}
119+
if sce.formatVersion != currentFormatVersion {
120+
return nil, ErrFormatVersionMismatch
121+
}
122+
ss, err := tls.ParseSessionState(sce.state)
123+
if err != nil {
124+
return nil, fmt.Errorf("unable to parse TLS client session state: %w", err)
125+
}
126+
cs, err := tls.NewResumptionState(sce.ticket, ss)
127+
if err != nil {
128+
return nil, fmt.Errorf("unable to construct new resumption state: %w", err)
129+
}
130+
return cs, nil
131+
}
132+
133+
var SessionCache tls.ClientSessionCache = tls.NewLRUClientSessionCache(0)
134+
135+
type PersistentClientSessionCache struct {
136+
db *bbolt.DB
137+
logger *clog.CondLogger
138+
}
139+
140+
func NewPersistentClientSessionCache(path string, logger *clog.CondLogger) (*PersistentClientSessionCache, error) {
141+
dir := filepath.Dir(path)
142+
if err := os.MkdirAll(dir, 0700); err != nil {
143+
return nil, err
144+
}
145+
db, err := bbolt.Open(path, 0600, &bbolt.Options{
146+
Timeout: 5 * time.Second,
147+
Logger: bboltLogger{logger},
148+
})
149+
if err != nil {
150+
return nil, err
151+
}
152+
return &PersistentClientSessionCache{
153+
db: db,
154+
logger: logger,
155+
}, nil
156+
}
157+
158+
func (cache *PersistentClientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
159+
var data []byte
160+
err := cache.db.View(func(tx *bbolt.Tx) error {
161+
bucket := tx.Bucket(bucketName)
162+
if bucket == nil {
163+
return nil
164+
}
165+
data = bucket.Get([]byte(sessionKey))
166+
return nil
167+
})
168+
if err != nil {
169+
cache.logger.Error("cache db: key %q read failed: %v", err)
170+
return nil, false
171+
}
172+
if data == nil {
173+
return nil, false
174+
}
175+
cs, err := clientSessionStateFromBytes(data)
176+
if err != nil {
177+
if err == ErrFormatVersionMismatch {
178+
cache.logger.Debug("rejected cached ticket for key %q due to version mismatch", sessionKey)
179+
} else {
180+
cache.logger.Error("cached session recovery failed: %v", err)
181+
}
182+
}
183+
return cs, true
184+
}
185+
186+
func (cache *PersistentClientSessionCache) delete(sessionKey string) {
187+
err := cache.db.Update(func(tx *bbolt.Tx) error {
188+
if bucket := tx.Bucket(bucketName); bucket != nil {
189+
return bucket.Delete([]byte(sessionKey))
190+
}
191+
return nil
192+
})
193+
if err != nil {
194+
cache.logger.Error("cache db: key %q delete failed: %v", sessionKey, err)
195+
}
196+
}
197+
198+
func (cache *PersistentClientSessionCache) put(sessionKey string, cs *tls.ClientSessionState) {
199+
csBytes, err := clientSessionStateToBytes(cs)
200+
if err != nil {
201+
cache.logger.Error("dropping client session state with key %q: unable to marshal client session state: %v", sessionKey, err)
202+
return
203+
}
204+
err = cache.db.Update(func(tx *bbolt.Tx) error {
205+
if bucket, err := tx.CreateBucketIfNotExists(bucketName); err == nil {
206+
return bucket.Put([]byte(sessionKey), csBytes)
207+
} else {
208+
return err
209+
}
210+
})
211+
if err != nil {
212+
cache.logger.Error("cache db: key %q write failed: %v", sessionKey, err)
213+
}
214+
}
215+
216+
func (cache *PersistentClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
217+
if cs == nil {
218+
cache.delete(sessionKey)
219+
return
220+
} else {
221+
cache.put(sessionKey, cs)
222+
}
223+
}
224+
225+
type bboltLogger struct {
226+
l *clog.CondLogger
227+
}
228+
229+
func (l bboltLogger) Debug(v ...interface{}) {
230+
l.l.Debug("%s", fmt.Sprint(v...))
231+
}
232+
233+
func (l bboltLogger) Debugf(format string, v ...interface{}) {
234+
l.l.Debug(format, v...)
235+
}
236+
237+
func (l bboltLogger) Error(v ...interface{}) {
238+
l.l.Error("%s", fmt.Sprint(v...))
239+
}
240+
241+
func (l bboltLogger) Errorf(format string, v ...interface{}) {
242+
l.l.Error(format, v...)
243+
}
244+
245+
func (l bboltLogger) Info(v ...interface{}) {
246+
l.l.Info("%s", fmt.Sprint(v...))
247+
}
248+
249+
func (l bboltLogger) Infof(format string, v ...interface{}) {
250+
l.l.Info(format, v...)
251+
}
252+
253+
func (l bboltLogger) Warning(v ...interface{}) {
254+
l.l.Warning("%s", fmt.Sprint(v...))
255+
}
256+
257+
func (l bboltLogger) Warningf(format string, v ...interface{}) {
258+
l.l.Warning(format, v...)
259+
}
260+
261+
func (l bboltLogger) Fatal(v ...interface{}) {
262+
l.l.Critical("%s", fmt.Sprint(v...))
263+
os.Exit(1)
264+
}
265+
266+
func (l bboltLogger) Fatalf(format string, v ...interface{}) {
267+
l.l.Critical(format, v...)
268+
os.Exit(1)
269+
}
270+
271+
func (l bboltLogger) Panic(v ...interface{}) {
272+
s := fmt.Sprint(v...)
273+
l.l.Critical("%s", s)
274+
panic(s)
275+
}
276+
277+
func (l bboltLogger) Panicf(format string, v ...interface{}) {
278+
s := fmt.Sprintf(format, v...)
279+
l.l.Critical("%s", s)
280+
panic(s)
281+
}

tlsutil/util.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ import (
1212
utls "github.com/refraction-networking/utls"
1313
)
1414

15-
var sessionCache = tls.NewLRUClientSessionCache(0)
16-
1715
func ExpectPeerName(name string, roots *x509.CertPool) func(cs tls.ConnectionState) error {
1816
return func(cs tls.ConnectionState) error {
1917
opts := x509.VerifyOptions{
@@ -173,7 +171,7 @@ func TLSConfigFromURL(u *url.URL) (*tls.Config, error) {
173171
}
174172
tlsConfig := &tls.Config{
175173
ServerName: host,
176-
ClientSessionCache: sessionCache,
174+
ClientSessionCache: SessionCache,
177175
}
178176
if params.Has("cafile") {
179177
roots, err := LoadCAfile(params.Get("cafile"))

0 commit comments

Comments
 (0)