|
| 1 | +package tlsutil |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + crand "crypto/rand" |
| 7 | + "crypto/tls" |
| 8 | + "errors" |
| 9 | + "net" |
| 10 | + |
| 11 | + clog "github.com/SenseUnit/dumbproxy/log" |
| 12 | +) |
| 13 | + |
| 14 | +const tlsCookiePrefix = "dpSessionCookieV1=" |
| 15 | + |
| 16 | +type TLSSessionID = [16]byte |
| 17 | + |
| 18 | +func NewTLSSessionID() (res TLSSessionID) { |
| 19 | + crand.Read(res[:]) |
| 20 | + return |
| 21 | +} |
| 22 | + |
| 23 | +func TLSSessionIDFromState(ss *tls.SessionState) (TLSSessionID, bool) { |
| 24 | + for _, tag := range ss.Extra { |
| 25 | + if !bytes.HasPrefix(tag, []byte(tlsCookiePrefix)) { |
| 26 | + continue |
| 27 | + } |
| 28 | + tag = tag[len(tlsCookiePrefix):] |
| 29 | + if len(tag) != len(TLSSessionID{}) { |
| 30 | + continue |
| 31 | + } |
| 32 | + return TLSSessionID(tag), true |
| 33 | + } |
| 34 | + return TLSSessionID{}, false |
| 35 | +} |
| 36 | + |
| 37 | +type tlsSessionIDKey struct{} |
| 38 | +type connKey struct{} |
| 39 | + |
| 40 | +func getTLSSessionID(conn ConnTagger) (TLSSessionID, bool) { |
| 41 | + saved, ok := conn.GetTag(tlsSessionIDKey{}) |
| 42 | + if !ok { |
| 43 | + return TLSSessionID{}, false |
| 44 | + } |
| 45 | + val, ok := saved.(TLSSessionID) |
| 46 | + return val, ok |
| 47 | +} |
| 48 | + |
| 49 | +func setTLSSessionID(conn ConnTagger, sessionID TLSSessionID) { |
| 50 | + conn.SetTag(tlsSessionIDKey{}, sessionID) |
| 51 | +} |
| 52 | + |
| 53 | +func GetTLSSessionID(conn net.Conn) (TLSSessionID, bool) { |
| 54 | + tagger, ok := conn.(ConnTagger) |
| 55 | + if !ok { |
| 56 | + if netconner, ok := conn.(interface { |
| 57 | + NetConn() net.Conn |
| 58 | + }); ok { |
| 59 | + return GetTLSSessionID(netconner.NetConn()) |
| 60 | + } |
| 61 | + return TLSSessionID{}, false |
| 62 | + } |
| 63 | + return getTLSSessionID(tagger) |
| 64 | +} |
| 65 | + |
| 66 | +func TLSSessionIDToContext(ctx context.Context, conn net.Conn) context.Context { |
| 67 | + return context.WithValue(ctx, connKey{}, conn) |
| 68 | +} |
| 69 | + |
| 70 | +func TLSSessionIDFromContext(ctx context.Context) (TLSSessionID, bool) { |
| 71 | + val := ctx.Value(connKey{}) |
| 72 | + conn, ok := val.(net.Conn) |
| 73 | + if !ok { |
| 74 | + return TLSSessionID{}, false |
| 75 | + } |
| 76 | + return GetTLSSessionID(conn) |
| 77 | +} |
| 78 | + |
| 79 | +func EnableTLSCookies(cfg *tls.Config, logger *clog.CondLogger) *tls.Config { |
| 80 | + getConfig := func(chi *tls.ClientHelloInfo) (*tls.Config, error) { |
| 81 | + return cfg.Clone(), nil |
| 82 | + } |
| 83 | + if cfg.GetConfigForClient != nil { |
| 84 | + getConfig = cfg.GetConfigForClient |
| 85 | + } |
| 86 | + // this one will be returned as updated TLS config to outer function caller |
| 87 | + cfg = cfg.Clone() |
| 88 | + cfg.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { |
| 89 | + conn, ok := chi.Conn.(ConnTagger) |
| 90 | + remoteAddr := chi.Conn.RemoteAddr().String() |
| 91 | + if !ok { |
| 92 | + return nil, errors.New("tlsCfg.GetConfigForClient: connection does is not a ConnTagger") |
| 93 | + } |
| 94 | + // this one holds closures which capture conn |
| 95 | + cfg, err := getConfig(chi) |
| 96 | + if err != nil { |
| 97 | + return nil, err |
| 98 | + } |
| 99 | + cfg.UnwrapSession = func(identity []byte, cs tls.ConnectionState) (*tls.SessionState, error) { |
| 100 | + ss, err := cfg.DecryptTicket(identity, cs) |
| 101 | + if err != nil { |
| 102 | + logger.Error("got error from TLS session ticket decryption: %v", err) |
| 103 | + return nil, err |
| 104 | + } |
| 105 | + if ss == nil { |
| 106 | + // nothing was decrypted, issue a new session |
| 107 | + sessionID := NewTLSSessionID() |
| 108 | + logger.Debug("assigning NEW session ID %x to connection from %s", sessionID, remoteAddr) |
| 109 | + setTLSSessionID(conn, sessionID) |
| 110 | + return nil, nil |
| 111 | + } |
| 112 | + if sessionID, ok := TLSSessionIDFromState(ss); ok { |
| 113 | + // valid session ID in ticket |
| 114 | + logger.Debug("recovered session ID = %x from %s", sessionID, remoteAddr) |
| 115 | + setTLSSessionID(conn, sessionID) |
| 116 | + } else { |
| 117 | + // no valid session ID in ticket (migrating outdated ticket?) |
| 118 | + sessionID = NewTLSSessionID() |
| 119 | + logger.Debug("session ID was NOT recovered from ticket from %s. assigning NEW session ID %x", remoteAddr, sessionID) |
| 120 | + setTLSSessionID(conn, sessionID) |
| 121 | + } |
| 122 | + return ss, nil |
| 123 | + } |
| 124 | + cfg.WrapSession = func(cs tls.ConnectionState, ss *tls.SessionState) ([]byte, error) { |
| 125 | + // is there session in TLS session state already? |
| 126 | + if sessionID, found := TLSSessionIDFromState(ss); found { |
| 127 | + logger.Warning("sessionState from %s already has sessionID %x", remoteAddr, sessionID) |
| 128 | + setTLSSessionID(conn, sessionID) |
| 129 | + return cfg.EncryptTicket(cs, ss) |
| 130 | + } |
| 131 | + // did we had a chance to assign a session ID to this connection? |
| 132 | + sessionID, ok := getTLSSessionID(conn) |
| 133 | + if ok { |
| 134 | + logger.Debug("sending new TLS ticket with old session ID %x to remote %s", sessionID, remoteAddr) |
| 135 | + } else { |
| 136 | + sessionID = NewTLSSessionID() |
| 137 | + setTLSSessionID(conn, sessionID) |
| 138 | + logger.Debug("sending new TLS ticket with NEW session ID %x to remote %s", sessionID, remoteAddr) |
| 139 | + } |
| 140 | + cookie := append([]byte(tlsCookiePrefix), sessionID[:]...) |
| 141 | + ss.Extra = append(ss.Extra, cookie) |
| 142 | + return cfg.EncryptTicket(cs, ss) |
| 143 | + } |
| 144 | + return cfg, nil |
| 145 | + } |
| 146 | + return cfg |
| 147 | +} |
0 commit comments