@@ -3,8 +3,10 @@ package mcp
33
44import (
55 "context"
6+ "crypto/sha256"
67 "fmt"
78 "log/slog"
9+ "sort"
810 "strings"
911 "sync"
1012 "time"
@@ -31,17 +33,23 @@ type ClientManager struct {
3133 mu sync.Mutex
3234 clients map [string ]* sdk_mcp.Client
3335 sessions map [string ]* sdk_mcp.ClientSession
36+ // headerHashes tracks the hash of headers used when creating each session,
37+ // so we can detect when DynamicHeaders change and invalidate stale sessions.
38+ headerHashes map [string ]string
3439}
3540
3641// NewClientManager creates a new ClientManager.
3742func NewClientManager () * ClientManager {
3843 return & ClientManager {
39- clients : make (map [string ]* sdk_mcp.Client ),
40- sessions : make (map [string ]* sdk_mcp.ClientSession ),
44+ clients : make (map [string ]* sdk_mcp.Client ),
45+ sessions : make (map [string ]* sdk_mcp.ClientSession ),
46+ headerHashes : make (map [string ]string ),
4147 }
4248}
4349
4450// GetSession returns or creates an MCP session for the given server.
51+ // Sessions are cached by server name. If DynamicHeaders change (e.g. a different
52+ // user or rotated credentials), the stale session is invalidated and recreated.
4553func (m * ClientManager ) GetSession (ctx context.Context , server * protocol.MCPServerConfig , logger * slog.Logger ) (* sdk_mcp.ClientSession , error ) {
4654 if logger == nil {
4755 logger = slog .With ("server" , server .Name )
@@ -51,34 +59,43 @@ func (m *ClientManager) GetSession(ctx context.Context, server *protocol.MCPServ
5159 defer m .mu .Unlock ()
5260
5361 serverName := server .Name
62+ currentHash := hashHeaders (server .Headers , server .DynamicHeaders )
5463
5564 logger .Debug ("mcp resolving session" ,
5665 "transport" , server .Transport ,
5766 "url" , server .URL ,
5867 "command" , server .Command ,
68+ "auth_key" , maskKey (server .DynamicHeaders ["Authorization" ]),
5969 )
6070
61- // Check if session exists and is still valid
71+ // Check if session exists and headers haven't changed
6272 if session , ok := m .sessions [serverName ]; ok {
63- logger .Debug ("mcp reusing session" )
64- return session , nil
73+ if m .headerHashes [serverName ] == currentHash {
74+ logger .Debug ("mcp reusing session" )
75+ return session , nil
76+ }
77+ logger .Info ("mcp headers changed, invalidating cached session" ,
78+ "server" , serverName ,
79+ )
80+ _ = session .Close ()
81+ delete (m .sessions , serverName )
82+ delete (m .clients , serverName )
83+ delete (m .headerHashes , serverName )
6584 }
6685
67- // Create client if not exists
68- client , ok := m .clients [serverName ]
69- if ! ok {
70- client = sdk_mcp .NewClient (& sdk_mcp.Implementation {
71- Name : "flashduty-runner" ,
72- Version : "1.0.0" ,
73- }, nil )
74- m .clients [serverName ] = client
75- }
86+ // Create client
87+ client := sdk_mcp .NewClient (& sdk_mcp.Implementation {
88+ Name : "flashduty-runner" ,
89+ Version : "1.0.0" ,
90+ }, nil )
91+ m .clients [serverName ] = client
7692
7793 // Create transport
7894 logger .Info ("mcp creating transport" ,
7995 "transport" , server .Transport ,
8096 "url" , server .URL ,
8197 "command" , server .Command ,
98+ "auth_key" , maskKey (server .DynamicHeaders ["Authorization" ]),
8299 )
83100
84101 transport , err := createTransport (server )
@@ -103,8 +120,11 @@ func (m *ClientManager) GetSession(ctx context.Context, server *protocol.MCPServ
103120 return nil , fmt .Errorf ("failed to connect to MCP server '%s': %w" , serverName , err )
104121 }
105122
106- logger .Info ("mcp connected" )
123+ logger .Info ("mcp connected" ,
124+ "auth_key" , maskKey (server .DynamicHeaders ["Authorization" ]),
125+ )
107126 m .sessions [serverName ] = session
127+ m .headerHashes [serverName ] = currentHash
108128 return session , nil
109129}
110130
@@ -174,9 +194,46 @@ func (m *ClientManager) ListTools(ctx context.Context, server *protocol.MCPServe
174194func (m * ClientManager ) invalidateSession (serverName string ) {
175195 m .mu .Lock ()
176196 delete (m .sessions , serverName )
197+ delete (m .headerHashes , serverName )
177198 m .mu .Unlock ()
178199}
179200
201+ // hashHeaders computes a stable hash of both static and dynamic headers.
202+ // Used to detect when credentials change so stale sessions are invalidated.
203+ func hashHeaders (headers , dynamicHeaders map [string ]string ) string {
204+ h := sha256 .New ()
205+ // Sort keys for deterministic hashing
206+ writeMap := func (m map [string ]string ) {
207+ keys := make ([]string , 0 , len (m ))
208+ for k := range m {
209+ keys = append (keys , k )
210+ }
211+ sort .Strings (keys )
212+ for _ , k := range keys {
213+ h .Write ([]byte (k ))
214+ h .Write ([]byte (m [k ]))
215+ }
216+ }
217+ h .Write ([]byte ("static:" ))
218+ writeMap (headers )
219+ h .Write ([]byte ("dynamic:" ))
220+ writeMap (dynamicHeaders )
221+ return fmt .Sprintf ("%x" , h .Sum (nil ))
222+ }
223+
224+ // maskKey returns the first 6 characters of a key for safe logging.
225+ // Returns empty string for empty/short keys.
226+ func maskKey (key string ) string {
227+ // Strip "Bearer " prefix if present
228+ if strings .HasPrefix (key , "Bearer " ) {
229+ key = key [7 :]
230+ }
231+ if len (key ) <= 6 {
232+ return key
233+ }
234+ return key [:6 ] + "***"
235+ }
236+
180237// Close closes all active sessions and clients.
181238func (m * ClientManager ) Close () {
182239 m .mu .Lock ()
0 commit comments