|
1 | 1 | package tlsutil |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "context" |
4 | 5 | "crypto/tls" |
5 | 6 | "errors" |
6 | 7 | "net" |
7 | 8 | ) |
8 | 9 |
|
9 | 10 | type preservedKeyKey struct{} |
10 | | - |
11 | 11 | type nonDefaultKeyUsedKey struct{} |
| 12 | +type connKey struct{} |
12 | 13 |
|
13 | 14 | func saveConnKey(conn ConnTagger, key [32]byte) { |
14 | 15 | conn.SetTag(preservedKeyKey{}, key) |
@@ -48,6 +49,19 @@ func WasNonDefaultKeyUsed(conn net.Conn) bool { |
48 | 49 | return val |
49 | 50 | } |
50 | 51 |
|
| 52 | +func NonDefaultKeyUsedToContext(ctx context.Context, conn net.Conn) context.Context { |
| 53 | + return context.WithValue(ctx, connKey{}, conn) |
| 54 | +} |
| 55 | + |
| 56 | +func NonDefaultKeyUsedFromContext(ctx context.Context) bool { |
| 57 | + val := ctx.Value(connKey{}) |
| 58 | + conn, ok := val.(net.Conn) |
| 59 | + if !ok { |
| 60 | + return false |
| 61 | + } |
| 62 | + return WasNonDefaultKeyUsed(conn) |
| 63 | +} |
| 64 | + |
51 | 65 | func PreserveSessionKeys(cfg *tls.Config, keys [][32]byte) *tls.Config { |
52 | 66 | if len(keys) < 2 { |
53 | 67 | // there's just one key defined, nothing to do |
@@ -88,13 +102,14 @@ func PreserveSessionKeys(cfg *tls.Config, keys [][32]byte) *tls.Config { |
88 | 102 | return nil, nil |
89 | 103 | } |
90 | 104 | cfg.WrapSession = func(cs tls.ConnectionState, ss *tls.SessionState) ([]byte, error) { |
| 105 | + skCfg := cfg.Clone() |
| 106 | + skCfg.SessionTicketKey = [32]byte{} |
| 107 | + key := keys[0] |
91 | 108 | // is there previous key? if so, use it |
92 | | - if key, ok := getConnKey(conn); ok { |
93 | | - skCfg := cfg.Clone() |
94 | | - skCfg.SessionTicketKey = [32]byte{} |
95 | | - skCfg.SetSessionTicketKeys([][32]byte{key}) |
96 | | - return skCfg.EncryptTicket(cs, ss) |
| 109 | + if k, ok := getConnKey(conn); ok { |
| 110 | + key = k |
97 | 111 | } |
| 112 | + skCfg.SetSessionTicketKeys([][32]byte{key}) |
98 | 113 | return cfg.EncryptTicket(cs, ss) |
99 | 114 | } |
100 | 115 | return cfg, nil |
|
0 commit comments