Skip to content

Commit 05becc0

Browse files
Copilotlpcox
andauthored
fix: address code review - proper RWMutex sync for session fields, log reconnect failures
- Use sync.RWMutex (sessionMu) instead of sync.Mutex so concurrent readers don't block - Add getSDKSession() and getHTTPSessionID() helpers that snapshot fields under RLock - Update all c.session reads in method bodies to use getSDKSession() - Update buildSessionHeaderModifier to use getHTTPSessionID() (eliminates data race) - Log reconnect failure in callSDKMethodWithReconnect when reconnErr != nil - Improve SessionTimeout comment in transport.go Agent-Logs-Url: https://github.com/github/gh-aw-mcpg/sessions/a43c25fb-e9bc-4fc8-865e-19acf449fa21 Co-authored-by: lpcox <15877973+lpcox@users.noreply.github.com>
1 parent 4d26ddd commit 05becc0

3 files changed

Lines changed: 45 additions & 23 deletions

File tree

internal/mcp/connection.go

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,26 @@ type Connection struct {
7272
httpClient *http.Client
7373
httpSessionID string // Session ID returned by the HTTP backend
7474
httpTransportType HTTPTransportType // Type of HTTP transport in use
75-
// reconnectMu serialises session-reconnect operations so that only one
76-
// goroutine performs the reconnect while others wait for it to finish.
77-
reconnectMu sync.Mutex
75+
// sessionMu protects mutable session state: httpSessionID, session, and client.
76+
// Readers take RLock; the reconnect functions take the full Lock.
77+
sessionMu sync.RWMutex
78+
}
79+
80+
// getSDKSession returns a snapshot of the current SDK session under a read lock.
81+
// Returns nil if no session is available (e.g. plain JSON-RPC transport).
82+
func (c *Connection) getSDKSession() *sdk.ClientSession {
83+
c.sessionMu.RLock()
84+
s := c.session
85+
c.sessionMu.RUnlock()
86+
return s
87+
}
88+
89+
// getHTTPSessionID returns a snapshot of the current HTTP session ID under a read lock.
90+
func (c *Connection) getHTTPSessionID() string {
91+
c.sessionMu.RLock()
92+
id := c.httpSessionID
93+
c.sessionMu.RUnlock()
94+
return id
7895
}
7996

8097
// NewConnection creates a new MCP connection using the official SDK
@@ -261,10 +278,10 @@ func (c *Connection) GetHTTPHeaders() map[string]string {
261278

262279
// reconnectPlainJSON re-initialises the plain JSON-RPC session with the HTTP backend.
263280
// It is safe for concurrent callers: only one reconnect runs at a time, and the updated
264-
// session ID is available to all callers once the mutex is released.
281+
// session ID is available to all callers once the lock is released.
265282
func (c *Connection) reconnectPlainJSON() error {
266-
c.reconnectMu.Lock()
267-
defer c.reconnectMu.Unlock()
283+
c.sessionMu.Lock()
284+
defer c.sessionMu.Unlock()
268285

269286
logConn.Printf("Session expired, reconnecting plain JSON-RPC for serverID=%s", c.serverID)
270287
logger.LogWarn("backend", "MCP session expired for %s, attempting to reconnect...", c.serverID)
@@ -284,8 +301,8 @@ func (c *Connection) reconnectPlainJSON() error {
284301
// reconnectSDKTransport re-establishes the SDK session for streamable or SSE transports.
285302
// It is safe for concurrent callers: only one reconnect runs at a time.
286303
func (c *Connection) reconnectSDKTransport() error {
287-
c.reconnectMu.Lock()
288-
defer c.reconnectMu.Unlock()
304+
c.sessionMu.Lock()
305+
defer c.sessionMu.Unlock()
289306

290307
logConn.Printf("Session expired, reconnecting SDK transport for serverID=%s, type=%s", c.serverID, c.httpTransportType)
291308
logger.LogWarn("backend", "MCP session expired for %s, attempting to reconnect...", c.serverID)
@@ -337,9 +354,13 @@ func (c *Connection) callSDKMethodWithReconnect(method string, params interface{
337354
result, err := c.callSDKMethod(method, params)
338355
if err != nil && isSessionNotFoundError(err) {
339356
logConn.Printf("Session not found error from SDK (serverID=%s), attempting reconnect", c.serverID)
340-
if reconnErr := c.reconnectSDKTransport(); reconnErr == nil {
341-
result, err = c.callSDKMethod(method, params)
357+
if reconnErr := c.reconnectSDKTransport(); reconnErr != nil {
358+
logConn.Printf("SDK session reconnect failed for serverID=%s: %v; returning original error", c.serverID, reconnErr)
359+
logger.LogError("backend", "SDK session reconnect failed for %s: %v", c.serverID, reconnErr)
360+
// Return the original session-not-found error so the caller sees a meaningful message.
361+
return result, err
342362
}
363+
result, err = c.callSDKMethod(method, params)
343364
}
344365
return result, err
345366
}
@@ -463,7 +484,7 @@ func marshalToResponse(result interface{}) (*Response, error) {
463484
// This helper centralizes session validation logic across all MCP method wrappers.
464485
// Returns an error if the session is nil (e.g., for plain JSON-RPC transport).
465486
func (c *Connection) requireSession() error {
466-
if c.session == nil {
487+
if c.getSDKSession() == nil {
467488
return fmt.Errorf("SDK session not available for plain JSON-RPC transport")
468489
}
469490
return nil
@@ -518,7 +539,7 @@ func callParamMethod[P any](c *Connection, rawParams interface{}, fn func(P) (in
518539
func (c *Connection) listTools() (*Response, error) {
519540
logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID)
520541
return c.callListMethod(func() (interface{}, error) {
521-
result, err := c.session.ListTools(c.ctx, &sdk.ListToolsParams{})
542+
result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{})
522543
if err == nil {
523544
logConn.Printf("listTools: received %d tools from serverID=%s", len(result.Tools), c.serverID)
524545
}
@@ -534,7 +555,7 @@ func (c *Connection) callTool(params interface{}) (*Response, error) {
534555
p.Arguments = make(map[string]interface{})
535556
}
536557
logConn.Printf("callTool: parsed name=%s, arguments=%+v", p.Name, p.Arguments)
537-
return c.session.CallTool(c.ctx, &sdk.CallToolParams{
558+
return c.getSDKSession().CallTool(c.ctx, &sdk.CallToolParams{
538559
Name: p.Name,
539560
Arguments: p.Arguments,
540561
})
@@ -544,7 +565,7 @@ func (c *Connection) callTool(params interface{}) (*Response, error) {
544565
func (c *Connection) listResources() (*Response, error) {
545566
logConn.Printf("listResources: requesting resource list from backend serverID=%s", c.serverID)
546567
return c.callListMethod(func() (interface{}, error) {
547-
result, err := c.session.ListResources(c.ctx, &sdk.ListResourcesParams{})
568+
result, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{})
548569
if err == nil {
549570
logConn.Printf("listResources: received %d resources from serverID=%s", len(result.Resources), c.serverID)
550571
}
@@ -558,7 +579,7 @@ func (c *Connection) readResource(params interface{}) (*Response, error) {
558579
}
559580
return callParamMethod(c, params, func(p readResourceParams) (interface{}, error) {
560581
logConn.Printf("readResource: reading resource uri=%s from serverID=%s", p.URI, c.serverID)
561-
return c.session.ReadResource(c.ctx, &sdk.ReadResourceParams{
582+
return c.getSDKSession().ReadResource(c.ctx, &sdk.ReadResourceParams{
562583
URI: p.URI,
563584
})
564585
})
@@ -567,7 +588,7 @@ func (c *Connection) readResource(params interface{}) (*Response, error) {
567588
func (c *Connection) listPrompts() (*Response, error) {
568589
logConn.Printf("listPrompts: requesting prompt list from backend serverID=%s", c.serverID)
569590
return c.callListMethod(func() (interface{}, error) {
570-
result, err := c.session.ListPrompts(c.ctx, &sdk.ListPromptsParams{})
591+
result, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{})
571592
if err == nil {
572593
logConn.Printf("listPrompts: received %d prompts from serverID=%s", len(result.Prompts), c.serverID)
573594
}
@@ -582,7 +603,7 @@ func (c *Connection) getPrompt(params interface{}) (*Response, error) {
582603
}
583604
return callParamMethod(c, params, func(p getPromptParams) (interface{}, error) {
584605
logConn.Printf("getPrompt: getting prompt name=%s from serverID=%s", p.Name, c.serverID)
585-
return c.session.GetPrompt(c.ctx, &sdk.GetPromptParams{
606+
return c.getSDKSession().GetPrompt(c.ctx, &sdk.GetPromptParams{
586607
Name: p.Name,
587608
Arguments: p.Arguments,
588609
})
@@ -593,8 +614,8 @@ func (c *Connection) getPrompt(params interface{}) (*Response, error) {
593614
func (c *Connection) Close() error {
594615
logConn.Printf("Closing connection: serverID=%s, isHTTP=%v", c.serverID, c.isHTTP)
595616
c.cancel()
596-
if c.session != nil {
597-
return c.session.Close()
617+
if session := c.getSDKSession(); session != nil {
618+
return session.Close()
598619
}
599620
return nil
600621
}

internal/mcp/http_transport.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,8 @@ func (c *Connection) initializeHTTPSession() (string, error) {
456456

457457
// buildSessionHeaderModifier returns a header modifier function that adds the Mcp-Session-Id header.
458458
// Priority: context session ID > stored connection session ID.
459-
// The returned function reads c.httpSessionID at call time, so it picks up any reconnected session.
459+
// The returned function calls getHTTPSessionID() at request time, so it always picks up
460+
// any session ID updated by a reconnect.
460461
func (c *Connection) buildSessionHeaderModifier(ctx context.Context) func(*http.Request) {
461462
// Capture any context-provided session ID once (it never changes for this request).
462463
ctxSessionID, _ := ctx.Value(SessionIDContextKey).(string)
@@ -465,8 +466,8 @@ func (c *Connection) buildSessionHeaderModifier(ctx context.Context) func(*http.
465466
if ctxSessionID != "" {
466467
sessionID = ctxSessionID
467468
logConn.Printf("Using session ID from context: %s", sessionID)
468-
} else if c.httpSessionID != "" {
469-
sessionID = c.httpSessionID
469+
} else if id := c.getHTTPSessionID(); id != "" {
470+
sessionID = id
470471
logConn.Printf("Using stored session ID from initialization: %s", sessionID)
471472
}
472473
if sessionID != "" {

internal/server/transport.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func CreateHTTPServerForMCP(addr string, unifiedServer *UnifiedServer, apiKey st
3737
}, &sdk.StreamableHTTPOptions{
3838
Stateless: false, // Support stateful sessions
3939
Logger: logger.NewSlogLoggerWithHandler(logTransport), // Integrate SDK logging with project logger
40-
SessionTimeout: 2 * time.Hour, // Allow long-running agent workflows (issue: sessions expiring mid-run)
40+
SessionTimeout: 2 * time.Hour, // Long-running agent workflows can exceed 30 min without MCP activity; 2 h reduces forced reconnects
4141
})
4242

4343
// Apply standard middleware stack (SDK logging → shutdown check → auth)

0 commit comments

Comments
 (0)