Skip to content

Commit 471e6ed

Browse files
committed
allow base conn of TLS conn to be tagged
1 parent 871d35d commit 471e6ed

2 files changed

Lines changed: 67 additions & 0 deletions

File tree

main.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,7 @@ func run() int {
785785
mainLogger.Critical("TLS config construction failed: %v", err1)
786786
return 3
787787
}
788+
listener = tlsutil.NewTaggedConnListener(listener) // attach DTO container
788789
listener = tls.NewListener(listener, cfg)
789790
} else if args.autocert {
790791
// cert caching chain
@@ -850,6 +851,7 @@ func run() int {
850851
if len(cfg.NextProtos) > 0 {
851852
cfg.NextProtos = append(cfg.NextProtos, acme.ALPNProto)
852853
}
854+
listener = tlsutil.NewTaggedConnListener(listener) // attach DTO container
853855
listener = tls.NewListener(listener, cfg)
854856
}
855857
defer listener.Close()

tlsutil/taggedconn.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package tlsutil
2+
3+
import (
4+
"net"
5+
"sync"
6+
)
7+
8+
type ConnTagger interface {
9+
GetTag(any) (any, bool)
10+
SetTag(any, any)
11+
}
12+
13+
type TaggedConn struct {
14+
net.Conn
15+
mux sync.RWMutex
16+
tags map[any]any
17+
}
18+
19+
func (c *TaggedConn) SetTag(key, value any) {
20+
c.mux.Lock()
21+
defer c.mux.Unlock()
22+
if c.tags == nil {
23+
c.tags = make(map[any]any)
24+
}
25+
c.tags[key] = value
26+
}
27+
28+
func (c *TaggedConn) GetTag(key any) (any, bool) {
29+
c.mux.RLock()
30+
defer c.mux.RUnlock()
31+
value, ok := c.tags[key]
32+
return value, ok
33+
}
34+
35+
func NewTaggedConn(conn net.Conn) *TaggedConn {
36+
return &TaggedConn{
37+
Conn: conn,
38+
}
39+
}
40+
41+
type TaggedConnListener struct {
42+
net.Listener
43+
}
44+
45+
func (l TaggedConnListener) Accept() (net.Conn, error) {
46+
conn, err := l.Listener.Accept()
47+
if err != nil {
48+
return nil, err
49+
}
50+
_, ok := conn.(ConnTagger)
51+
if ok {
52+
return conn, nil
53+
}
54+
return NewTaggedConn(conn), nil
55+
}
56+
57+
func NewTaggedConnListener(l net.Listener) TaggedConnListener {
58+
return TaggedConnListener{
59+
Listener: l,
60+
}
61+
}
62+
63+
var _ net.Conn = new(TaggedConn)
64+
var _ net.Listener = TaggedConnListener{}
65+
var _ ConnTagger = new(TaggedConn)

0 commit comments

Comments
 (0)