@@ -72,9 +72,26 @@ type Connection struct {
7272 httpClient * http.Client
7373 httpSessionID string // Session ID returned by the HTTP backend
7474 httpTransportType HTTPTransportType // Type of HTTP transport in use
75- // reconnectMu serialises session-reconnect operations so that only one
76- // goroutine performs the reconnect while others wait for it to finish.
77- reconnectMu sync.Mutex
75+ // sessionMu protects mutable session state: httpSessionID, session, and client.
76+ // Readers take RLock; the reconnect functions take the full Lock.
77+ sessionMu sync.RWMutex
78+ }
79+
80+ // getSDKSession returns a snapshot of the current SDK session under a read lock.
81+ // Returns nil if no session is available (e.g. plain JSON-RPC transport).
82+ func (c * Connection ) getSDKSession () * sdk.ClientSession {
83+ c .sessionMu .RLock ()
84+ s := c .session
85+ c .sessionMu .RUnlock ()
86+ return s
87+ }
88+
89+ // getHTTPSessionID returns a snapshot of the current HTTP session ID under a read lock.
90+ func (c * Connection ) getHTTPSessionID () string {
91+ c .sessionMu .RLock ()
92+ id := c .httpSessionID
93+ c .sessionMu .RUnlock ()
94+ return id
7895}
7996
8097// NewConnection creates a new MCP connection using the official SDK
@@ -261,10 +278,10 @@ func (c *Connection) GetHTTPHeaders() map[string]string {
261278
262279// reconnectPlainJSON re-initialises the plain JSON-RPC session with the HTTP backend.
263280// It is safe for concurrent callers: only one reconnect runs at a time, and the updated
264- // session ID is available to all callers once the mutex is released.
281+ // session ID is available to all callers once the lock is released.
265282func (c * Connection ) reconnectPlainJSON () error {
266- c .reconnectMu .Lock ()
267- defer c .reconnectMu .Unlock ()
283+ c .sessionMu .Lock ()
284+ defer c .sessionMu .Unlock ()
268285
269286 logConn .Printf ("Session expired, reconnecting plain JSON-RPC for serverID=%s" , c .serverID )
270287 logger .LogWarn ("backend" , "MCP session expired for %s, attempting to reconnect..." , c .serverID )
@@ -284,8 +301,8 @@ func (c *Connection) reconnectPlainJSON() error {
284301// reconnectSDKTransport re-establishes the SDK session for streamable or SSE transports.
285302// It is safe for concurrent callers: only one reconnect runs at a time.
286303func (c * Connection ) reconnectSDKTransport () error {
287- c .reconnectMu .Lock ()
288- defer c .reconnectMu .Unlock ()
304+ c .sessionMu .Lock ()
305+ defer c .sessionMu .Unlock ()
289306
290307 logConn .Printf ("Session expired, reconnecting SDK transport for serverID=%s, type=%s" , c .serverID , c .httpTransportType )
291308 logger .LogWarn ("backend" , "MCP session expired for %s, attempting to reconnect..." , c .serverID )
@@ -337,9 +354,13 @@ func (c *Connection) callSDKMethodWithReconnect(method string, params interface{
337354 result , err := c .callSDKMethod (method , params )
338355 if err != nil && isSessionNotFoundError (err ) {
339356 logConn .Printf ("Session not found error from SDK (serverID=%s), attempting reconnect" , c .serverID )
340- if reconnErr := c .reconnectSDKTransport (); reconnErr == nil {
341- result , err = c .callSDKMethod (method , params )
357+ if reconnErr := c .reconnectSDKTransport (); reconnErr != nil {
358+ logConn .Printf ("SDK session reconnect failed for serverID=%s: %v; returning original error" , c .serverID , reconnErr )
359+ logger .LogError ("backend" , "SDK session reconnect failed for %s: %v" , c .serverID , reconnErr )
360+ // Return the original session-not-found error so the caller sees a meaningful message.
361+ return result , err
342362 }
363+ result , err = c .callSDKMethod (method , params )
343364 }
344365 return result , err
345366}
@@ -463,7 +484,7 @@ func marshalToResponse(result interface{}) (*Response, error) {
463484// This helper centralizes session validation logic across all MCP method wrappers.
464485// Returns an error if the session is nil (e.g., for plain JSON-RPC transport).
465486func (c * Connection ) requireSession () error {
466- if c .session == nil {
487+ if c .getSDKSession () == nil {
467488 return fmt .Errorf ("SDK session not available for plain JSON-RPC transport" )
468489 }
469490 return nil
@@ -518,7 +539,7 @@ func callParamMethod[P any](c *Connection, rawParams interface{}, fn func(P) (in
518539func (c * Connection ) listTools () (* Response , error ) {
519540 logConn .Printf ("listTools: requesting tool list from backend serverID=%s" , c .serverID )
520541 return c .callListMethod (func () (interface {}, error ) {
521- result , err := c .session .ListTools (c .ctx , & sdk.ListToolsParams {})
542+ result , err := c .getSDKSession () .ListTools (c .ctx , & sdk.ListToolsParams {})
522543 if err == nil {
523544 logConn .Printf ("listTools: received %d tools from serverID=%s" , len (result .Tools ), c .serverID )
524545 }
@@ -534,7 +555,7 @@ func (c *Connection) callTool(params interface{}) (*Response, error) {
534555 p .Arguments = make (map [string ]interface {})
535556 }
536557 logConn .Printf ("callTool: parsed name=%s, arguments=%+v" , p .Name , p .Arguments )
537- return c .session .CallTool (c .ctx , & sdk.CallToolParams {
558+ return c .getSDKSession () .CallTool (c .ctx , & sdk.CallToolParams {
538559 Name : p .Name ,
539560 Arguments : p .Arguments ,
540561 })
@@ -544,7 +565,7 @@ func (c *Connection) callTool(params interface{}) (*Response, error) {
544565func (c * Connection ) listResources () (* Response , error ) {
545566 logConn .Printf ("listResources: requesting resource list from backend serverID=%s" , c .serverID )
546567 return c .callListMethod (func () (interface {}, error ) {
547- result , err := c .session .ListResources (c .ctx , & sdk.ListResourcesParams {})
568+ result , err := c .getSDKSession () .ListResources (c .ctx , & sdk.ListResourcesParams {})
548569 if err == nil {
549570 logConn .Printf ("listResources: received %d resources from serverID=%s" , len (result .Resources ), c .serverID )
550571 }
@@ -558,7 +579,7 @@ func (c *Connection) readResource(params interface{}) (*Response, error) {
558579 }
559580 return callParamMethod (c , params , func (p readResourceParams ) (interface {}, error ) {
560581 logConn .Printf ("readResource: reading resource uri=%s from serverID=%s" , p .URI , c .serverID )
561- return c .session .ReadResource (c .ctx , & sdk.ReadResourceParams {
582+ return c .getSDKSession () .ReadResource (c .ctx , & sdk.ReadResourceParams {
562583 URI : p .URI ,
563584 })
564585 })
@@ -567,7 +588,7 @@ func (c *Connection) readResource(params interface{}) (*Response, error) {
567588func (c * Connection ) listPrompts () (* Response , error ) {
568589 logConn .Printf ("listPrompts: requesting prompt list from backend serverID=%s" , c .serverID )
569590 return c .callListMethod (func () (interface {}, error ) {
570- result , err := c .session .ListPrompts (c .ctx , & sdk.ListPromptsParams {})
591+ result , err := c .getSDKSession () .ListPrompts (c .ctx , & sdk.ListPromptsParams {})
571592 if err == nil {
572593 logConn .Printf ("listPrompts: received %d prompts from serverID=%s" , len (result .Prompts ), c .serverID )
573594 }
@@ -582,7 +603,7 @@ func (c *Connection) getPrompt(params interface{}) (*Response, error) {
582603 }
583604 return callParamMethod (c , params , func (p getPromptParams ) (interface {}, error ) {
584605 logConn .Printf ("getPrompt: getting prompt name=%s from serverID=%s" , p .Name , c .serverID )
585- return c .session .GetPrompt (c .ctx , & sdk.GetPromptParams {
606+ return c .getSDKSession () .GetPrompt (c .ctx , & sdk.GetPromptParams {
586607 Name : p .Name ,
587608 Arguments : p .Arguments ,
588609 })
@@ -593,8 +614,8 @@ func (c *Connection) getPrompt(params interface{}) (*Response, error) {
593614func (c * Connection ) Close () error {
594615 logConn .Printf ("Closing connection: serverID=%s, isHTTP=%v" , c .serverID , c .isHTTP )
595616 c .cancel ()
596- if c . session != nil {
597- return c . session .Close ()
617+ if session := c . getSDKSession (); session != nil {
618+ return session .Close ()
598619 }
599620 return nil
600621}
0 commit comments