Skip to content

Commit f861516

Browse files
authored
fix: reconnect expired MCP backend sessions transparently, extend server session timeout (#2597)
Long-running agent workflows (~30+ min) fail when the gateway's HTTP session to a backend (e.g. safeoutputs) expires mid-run. The agent receives `session not found` from the backend and enters a futile retry loop rather than recovering. ## Changes ### Client-side session reconnect (`internal/mcp/connection.go`, `http_transport.go`) - **Plain JSON-RPC transport**: `sendHTTPRequest` detects HTTP 404 + `"session not found"` body, calls `reconnectPlainJSON()` to re-initialize the session, then retries the original request once. - **SDK transports (streamable / SSE)**: `callSDKMethodWithReconnect()` wraps all SDK method calls — on a `"session not found"` error it calls `reconnectSDKTransport()` (closes the dead session, dials a new one) and retries once. `SendRequestWithServerID` now routes through this wrapper for SDK transports. - **Thread safety**: `sessionMu sync.RWMutex` protects `httpSessionID`, `session`, and `client`. All reads go through `getSDKSession()` / `getHTTPSessionID()` (under `RLock`); reconnect functions hold the full `Lock`. ```go // Plain JSON-RPC: sendHTTPRequest now does this automatically result, err := c.executeHTTPRequest(...) if isSessionNotFoundHTTPResponse(result.StatusCode, result.ResponseBody) { if reconnErr := c.reconnectPlainJSON(); reconnErr == nil { result, err = c.executeHTTPRequest(...) // retry once } } // SDK transports (streamable / SSE): via callSDKMethodWithReconnect result, err := c.callSDKMethod(method, params) if err != nil && isSessionNotFoundError(err) { if reconnErr := c.reconnectSDKTransport(); reconnErr == nil { result, err = c.callSDKMethod(method, params) } } ``` ### Server-side idle timeout (`internal/server/transport.go`) `SessionTimeout` increased from `30m → 2h` so the agent→gateway inbound session doesn't expire during extended periods of no MCP activity (e.g. a long `lake build`). ### Tests (`internal/mcp/http_transport_test.go`) Added `TestSendHTTPRequest_ReconnectsOnSessionNotFound`, `TestSendHTTPRequest_ReconnectFailure`, `TestSendHTTPRequest_NoReconnectOnOtherErrors`, plus unit tests for `isSessionNotFoundError` and `isSessionNotFoundHTTPResponse`. > [!WARNING] > > <details> > <summary>Firewall rules blocked me from connecting to one or more addresses (expand for details)</summary> > > #### I tried to connect to the following addresses, but was blocked by firewall rules: > > - `example.com` > - Triggering command: `/tmp/go-build2012957444/b330/launcher.test /tmp/go-build2012957444/b330/launcher.test -test.testlogfile=/tmp/go-build2012957444/b330/testlog.txt -test.paniconexit0 -test.timeout=10m0s -test.v=true g_.a --local es committer.email go` (dns block) > - Triggering command: `/tmp/go-build916828748/b330/launcher.test /tmp/go-build916828748/b330/launcher.test -test.testlogfile=/tmp/go-build916828748/b330/testlog.txt -test.paniconexit0 -test.timeout=10m0s -test.v=true aw-mcpg/internal/config/rules/rules.go aw-mcpg/internal/config/rules/rules_test.go ache/go/1.25.8/x64/pkg/tool/linux_amd64/vet . contextprotocol/checkout --64 ache/go/1.25.8/x64/pkg/tool/linux_amd64/vet -I Scom/odrQfnn_R7r_as4UScom -I /opt/hostedtoolcache/go/1.25.8/x64/pkg/tool/linu-lang=go1.25 --gdwarf-5 --64 -o 2957444/b291/importcfg` (dns block) > - Triggering command: `/tmp/go-build1482963661/b334/launcher.test /tmp/go-build1482963661/b334/launcher.test -test.testlogfile=/tmp/go-build1482963661/b334/testlog.txt -test.paniconexit0 -test.timeout=10m0s --ve�� /home/REDACTED/work/gh-aw-mcpg/gh-aw-mcpg/internal/guard/context.go /home/REDACTED/work/gh-aw-mcpg/gh-aw-mcpg/internal/guard/guard.go ash b128.go s2rJ-u63W ache/go/1.25.8/x/tmp/go-build3142434740/b172/vet.cfg nIQTNFClwl8R --ve��` (dns block) > - `invalid-host-that-does-not-exist-12345.com` > - Triggering command: `/tmp/go-build2012957444/b315/config.test /tmp/go-build2012957444/b315/config.test -test.testlogfile=/tmp/go-build2012957444/b315/testlog.txt -test.paniconexit0 -test.timeout=10m0s -test.v=true -c=4 -nolocalimports -importcfg /tmp/go-build2012957444/b289/importcfg -pack /home/REDACTED/go/pkg/mod/golang.org/x/oauth2@v0.34.0/deviceauth.go /home/REDACTED/go/pkg/mod/golang.org/x/oauth2@v0.34.0/oauth2.go ortc�� rLdqHevS3 64/src/internal/byteorder/byteor--gdwarf2 x_amd64/vet pull.rebase abis` (dns block) > - Triggering command: `/tmp/go-build588503108/b315/config.test /tmp/go-build588503108/b315/config.test -test.testlogfile=/tmp/go-build588503108/b315/testlog.txt -test.paniconexit0 -test.timeout=10m0s -test.v=true ache/go/1.25.8/x64/src/runtime/c-I` (dns block) > - Triggering command: `/tmp/go-build1482963661/b319/config.test /tmp/go-build1482963661/b319/config.test -test.testlogfile=/tmp/go-build1482963661/b319/testlog.txt -test.paniconexit0 -test.timeout=10m0s -o d -n 10 -importcfg 2957444/b370=&gt; -s -w -buildmode=exe /usr/bin/runc.original --ve�� it tests...&#34; -extld=gcc x_amd64/vet 64/src/runtime/c/opt/hostedtoolcache/go/1.25.8/x64/pkg/tool/linux_amd64/vet credential.helpe/tmp/go-build3142434740/b239/vet.cfg ache/go/1.25.8/x-lang=go1.17 x_amd64/vet` (dns block) > - `nonexistent.local` > - Triggering command: `/tmp/go-build2012957444/b330/launcher.test /tmp/go-build2012957444/b330/launcher.test -test.testlogfile=/tmp/go-build2012957444/b330/testlog.txt -test.paniconexit0 -test.timeout=10m0s -test.v=true g_.a --local es committer.email go` (dns block) > - Triggering command: `/tmp/go-build916828748/b330/launcher.test /tmp/go-build916828748/b330/launcher.test -test.testlogfile=/tmp/go-build916828748/b330/testlog.txt -test.paniconexit0 -test.timeout=10m0s -test.v=true aw-mcpg/internal/config/rules/rules.go aw-mcpg/internal/config/rules/rules_test.go ache/go/1.25.8/x64/pkg/tool/linux_amd64/vet . contextprotocol/checkout --64 ache/go/1.25.8/x64/pkg/tool/linux_amd64/vet -I Scom/odrQfnn_R7r_as4UScom -I /opt/hostedtoolcache/go/1.25.8/x64/pkg/tool/linu-lang=go1.25 --gdwarf-5 --64 -o 2957444/b291/importcfg` (dns block) > - Triggering command: `/tmp/go-build1482963661/b334/launcher.test /tmp/go-build1482963661/b334/launcher.test -test.testlogfile=/tmp/go-build1482963661/b334/testlog.txt -test.paniconexit0 -test.timeout=10m0s --ve�� /home/REDACTED/work/gh-aw-mcpg/gh-aw-mcpg/internal/guard/context.go /home/REDACTED/work/gh-aw-mcpg/gh-aw-mcpg/internal/guard/guard.go ash b128.go s2rJ-u63W ache/go/1.25.8/x/tmp/go-build3142434740/b172/vet.cfg nIQTNFClwl8R --ve��` (dns block) > - `slow.example.com` > - Triggering command: `/tmp/go-build2012957444/b330/launcher.test /tmp/go-build2012957444/b330/launcher.test -test.testlogfile=/tmp/go-build2012957444/b330/testlog.txt -test.paniconexit0 -test.timeout=10m0s -test.v=true g_.a --local es committer.email go` (dns block) > - Triggering command: `/tmp/go-build916828748/b330/launcher.test /tmp/go-build916828748/b330/launcher.test -test.testlogfile=/tmp/go-build916828748/b330/testlog.txt -test.paniconexit0 -test.timeout=10m0s -test.v=true aw-mcpg/internal/config/rules/rules.go aw-mcpg/internal/config/rules/rules_test.go ache/go/1.25.8/x64/pkg/tool/linux_amd64/vet . contextprotocol/checkout --64 ache/go/1.25.8/x64/pkg/tool/linux_amd64/vet -I Scom/odrQfnn_R7r_as4UScom -I /opt/hostedtoolcache/go/1.25.8/x64/pkg/tool/linu-lang=go1.25 --gdwarf-5 --64 -o 2957444/b291/importcfg` (dns block) > - Triggering command: `/tmp/go-build1482963661/b334/launcher.test /tmp/go-build1482963661/b334/launcher.test -test.testlogfile=/tmp/go-build1482963661/b334/testlog.txt -test.paniconexit0 -test.timeout=10m0s --ve�� /home/REDACTED/work/gh-aw-mcpg/gh-aw-mcpg/internal/guard/context.go /home/REDACTED/work/gh-aw-mcpg/gh-aw-mcpg/internal/guard/guard.go ash b128.go s2rJ-u63W ache/go/1.25.8/x/tmp/go-build3142434740/b172/vet.cfg nIQTNFClwl8R --ve��` (dns block) > - `this-host-does-not-exist-12345.com` > - Triggering command: `/tmp/go-build2012957444/b339/mcp.test /tmp/go-build2012957444/b339/mcp.test -test.testlogfile=/tmp/go-build2012957444/b339/testlog.txt -test.paniconexit0 -test.timeout=10m0s -test.v=true 64/src/runtime/cgo o x_amd64/vet` (dns block) > - Triggering command: `/tmp/go-build916828748/b339/mcp.test /tmp/go-build916828748/b339/mcp.test -test.testlogfile=/tmp/go-build916828748/b339/testlog.txt -test.paniconexit0 -test.timeout=10m0s -test.v=true 2957444/b308/_pkg_.a /tmp/go-build2012957444/b165/ ache/go/1.25.8/x64/pkg/tool/linux_amd64/compile . contextprotocol/-c --64 ache/go/1.25.8/x64/pkg/tool/linux_amd64/compile -I 2957444/b320/_pkg_.a -I docker-buildx --gdwarf-5 b/gh-aw-mcpg/int--norc -o docker-buildx` (dns block) > - Triggering command: `/tmp/go-build1482963661/b343/mcp.test /tmp/go-build1482963661/b343/mcp.test -test.testlogfile=/tmp/go-build1482963661/b343/testlog.txt -test.paniconexit0 -test.timeout=10m0s -o ssue -importcfg x_amd64/vet -s -w -buildmode=exe x_amd64/vet --no�� --noprofile -extld=gcc x_amd64/vet ternal/engine/wabash ternal/engine/wa/usr/bin/runc x_amd64/compile x_amd64/vet` (dns block) > > If you need me to access, download, or install something from one of these locations, you can either: > > - Configure [Actions setup steps](https://gh.io/copilot/actions-setup-steps) to set up my environment, which run before the firewall is enabled > - Add the appropriate URLs or hosts to the custom allowlist in this repository's [Copilot coding agent settings](https://github.com/github/gh-aw-mcpg/settings/copilot/coding_agent) (admins only) > > </details> <!-- START COPILOT CODING AGENT TIPS --> --- ⚡ Quickly spin up Copilot coding agent tasks from anywhere on your macOS or Windows machine with [Raycast](https://gh.io/cca-raycast-docs).
2 parents 6470017 + 1d95d26 commit f861516

4 files changed

Lines changed: 427 additions & 43 deletions

File tree

internal/mcp/connection.go

Lines changed: 121 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"net/http"
1212
"os/exec"
1313
"strings"
14+
"sync"
1415
"time"
1516

1617
"github.com/github/gh-aw-mcpg/internal/difc"
@@ -71,6 +72,27 @@ type Connection struct {
7172
httpClient *http.Client
7273
httpSessionID string // Session ID returned by the HTTP backend
7374
httpTransportType HTTPTransportType // Type of HTTP transport in use
75+
// sessionMu protects the mutable session fields: httpSessionID, session, and client.
76+
// Always use getHTTPSessionID() or getSDKSession() to read these fields; the
77+
// reconnect functions (reconnectPlainJSON, reconnectSDKTransport) hold the full Lock.
78+
sessionMu sync.RWMutex
79+
}
80+
81+
// getSDKSession returns a snapshot of the current SDK session under a read lock.
82+
// Returns nil if no session is available (e.g. plain JSON-RPC transport).
83+
func (c *Connection) getSDKSession() *sdk.ClientSession {
84+
c.sessionMu.RLock()
85+
s := c.session
86+
c.sessionMu.RUnlock()
87+
return s
88+
}
89+
90+
// getHTTPSessionID returns a snapshot of the current HTTP session ID under a read lock.
91+
func (c *Connection) getHTTPSessionID() string {
92+
c.sessionMu.RLock()
93+
id := c.httpSessionID
94+
c.sessionMu.RUnlock()
95+
return id
7496
}
7597

7698
// NewConnection creates a new MCP connection using the official SDK
@@ -255,6 +277,95 @@ func (c *Connection) GetHTTPHeaders() map[string]string {
255277
return c.headers
256278
}
257279

280+
// reconnectPlainJSON re-initialises the plain JSON-RPC session with the HTTP backend.
281+
// It is safe for concurrent callers: only one reconnect runs at a time, and the updated
282+
// session ID is available to all callers once the lock is released.
283+
func (c *Connection) reconnectPlainJSON() error {
284+
c.sessionMu.Lock()
285+
defer c.sessionMu.Unlock()
286+
287+
logConn.Printf("Session expired, reconnecting plain JSON-RPC for serverID=%s", c.serverID)
288+
logger.LogWarn("backend", "MCP session expired for %s, attempting to reconnect...", c.serverID)
289+
290+
sessionID, err := c.initializeHTTPSession()
291+
if err != nil {
292+
logger.LogError("backend", "Session reconnect failed for %s: %v", c.serverID, err)
293+
return fmt.Errorf("session reconnect failed: %w", err)
294+
}
295+
296+
c.httpSessionID = sessionID
297+
logConn.Printf("Reconnected plain JSON-RPC session for serverID=%s, new sessionID=%s", c.serverID, sessionID)
298+
logger.LogInfo("backend", "Session successfully reconnected for %s", c.serverID)
299+
return nil
300+
}
301+
302+
// reconnectSDKTransport re-establishes the SDK session for streamable or SSE transports.
303+
// It is safe for concurrent callers: only one reconnect runs at a time.
304+
func (c *Connection) reconnectSDKTransport() error {
305+
c.sessionMu.Lock()
306+
defer c.sessionMu.Unlock()
307+
308+
logConn.Printf("Session expired, reconnecting SDK transport for serverID=%s, type=%s", c.serverID, c.httpTransportType)
309+
logger.LogWarn("backend", "MCP session expired for %s, attempting to reconnect...", c.serverID)
310+
311+
// Close the existing session gracefully (ignore error – it's already dead).
312+
if c.session != nil {
313+
_ = c.session.Close()
314+
}
315+
316+
// Build the appropriate transport.
317+
client := newMCPClient(logConn)
318+
var transport sdk.Transport
319+
switch c.httpTransportType {
320+
case HTTPTransportStreamable:
321+
transport = &sdk.StreamableClientTransport{
322+
Endpoint: c.httpURL,
323+
HTTPClient: c.httpClient,
324+
MaxRetries: 0,
325+
}
326+
case HTTPTransportSSE:
327+
transport = &sdk.SSEClientTransport{
328+
Endpoint: c.httpURL,
329+
HTTPClient: c.httpClient,
330+
}
331+
default:
332+
return fmt.Errorf("cannot reconnect: unsupported transport type %s", c.httpTransportType)
333+
}
334+
335+
connectCtx, cancel := context.WithTimeout(c.ctx, 10*time.Second)
336+
defer cancel()
337+
338+
session, err := client.Connect(connectCtx, transport, nil)
339+
if err != nil {
340+
logger.LogError("backend", "Session reconnect failed for %s: %v", c.serverID, err)
341+
return fmt.Errorf("session reconnect failed: %w", err)
342+
}
343+
344+
c.client = client
345+
c.session = session
346+
347+
logConn.Printf("Reconnected SDK session for serverID=%s", c.serverID)
348+
logger.LogInfo("backend", "Session successfully reconnected for %s", c.serverID)
349+
return nil
350+
}
351+
352+
// callSDKMethodWithReconnect calls the SDK method and, if the session has expired,
353+
// reconnects and retries exactly once before propagating the error.
354+
func (c *Connection) callSDKMethodWithReconnect(method string, params interface{}) (*Response, error) {
355+
result, err := c.callSDKMethod(method, params)
356+
if err != nil && isSessionNotFoundError(err) {
357+
logConn.Printf("Session not found error from SDK (serverID=%s), attempting reconnect", c.serverID)
358+
if reconnErr := c.reconnectSDKTransport(); reconnErr != nil {
359+
logConn.Printf("SDK session reconnect failed for serverID=%s: %v; returning original error", c.serverID, reconnErr)
360+
logger.LogError("backend", "SDK session reconnect failed for %s: %v", c.serverID, reconnErr)
361+
// Return the original session-not-found error so the caller sees a meaningful message.
362+
return result, err
363+
}
364+
result, err = c.callSDKMethod(method, params)
365+
}
366+
return result, err
367+
}
368+
258369
// SendRequest sends a JSON-RPC request and waits for the response
259370
// The serverID parameter is used for logging to associate the request with a backend server
260371
func (c *Connection) SendRequest(method string, params interface{}) (*Response, error) {
@@ -301,7 +412,7 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string,
301412
}
302413

303414
// For streamable and SSE transports, use SDK session methods
304-
result, err = c.callSDKMethod(method, params)
415+
result, err = c.callSDKMethodWithReconnect(method, params)
305416
// Log the response from backend server
306417
var responsePayload []byte
307418
if result != nil {
@@ -374,7 +485,7 @@ func marshalToResponse(result interface{}) (*Response, error) {
374485
// This helper centralizes session validation logic across all MCP method wrappers.
375486
// Returns an error if the session is nil (e.g., for plain JSON-RPC transport).
376487
func (c *Connection) requireSession() error {
377-
if c.session == nil {
488+
if c.getSDKSession() == nil {
378489
return fmt.Errorf("SDK session not available for plain JSON-RPC transport")
379490
}
380491
return nil
@@ -429,7 +540,7 @@ func callParamMethod[P any](c *Connection, rawParams interface{}, fn func(P) (in
429540
func (c *Connection) listTools() (*Response, error) {
430541
logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID)
431542
return c.callListMethod(func() (interface{}, error) {
432-
result, err := c.session.ListTools(c.ctx, &sdk.ListToolsParams{})
543+
result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{})
433544
if err == nil {
434545
logConn.Printf("listTools: received %d tools from serverID=%s", len(result.Tools), c.serverID)
435546
}
@@ -445,7 +556,7 @@ func (c *Connection) callTool(params interface{}) (*Response, error) {
445556
p.Arguments = make(map[string]interface{})
446557
}
447558
logConn.Printf("callTool: parsed name=%s, arguments=%+v", p.Name, p.Arguments)
448-
return c.session.CallTool(c.ctx, &sdk.CallToolParams{
559+
return c.getSDKSession().CallTool(c.ctx, &sdk.CallToolParams{
449560
Name: p.Name,
450561
Arguments: p.Arguments,
451562
})
@@ -455,7 +566,7 @@ func (c *Connection) callTool(params interface{}) (*Response, error) {
455566
func (c *Connection) listResources() (*Response, error) {
456567
logConn.Printf("listResources: requesting resource list from backend serverID=%s", c.serverID)
457568
return c.callListMethod(func() (interface{}, error) {
458-
result, err := c.session.ListResources(c.ctx, &sdk.ListResourcesParams{})
569+
result, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{})
459570
if err == nil {
460571
logConn.Printf("listResources: received %d resources from serverID=%s", len(result.Resources), c.serverID)
461572
}
@@ -469,7 +580,7 @@ func (c *Connection) readResource(params interface{}) (*Response, error) {
469580
}
470581
return callParamMethod(c, params, func(p readResourceParams) (interface{}, error) {
471582
logConn.Printf("readResource: reading resource uri=%s from serverID=%s", p.URI, c.serverID)
472-
return c.session.ReadResource(c.ctx, &sdk.ReadResourceParams{
583+
return c.getSDKSession().ReadResource(c.ctx, &sdk.ReadResourceParams{
473584
URI: p.URI,
474585
})
475586
})
@@ -478,7 +589,7 @@ func (c *Connection) readResource(params interface{}) (*Response, error) {
478589
func (c *Connection) listPrompts() (*Response, error) {
479590
logConn.Printf("listPrompts: requesting prompt list from backend serverID=%s", c.serverID)
480591
return c.callListMethod(func() (interface{}, error) {
481-
result, err := c.session.ListPrompts(c.ctx, &sdk.ListPromptsParams{})
592+
result, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{})
482593
if err == nil {
483594
logConn.Printf("listPrompts: received %d prompts from serverID=%s", len(result.Prompts), c.serverID)
484595
}
@@ -493,7 +604,7 @@ func (c *Connection) getPrompt(params interface{}) (*Response, error) {
493604
}
494605
return callParamMethod(c, params, func(p getPromptParams) (interface{}, error) {
495606
logConn.Printf("getPrompt: getting prompt name=%s from serverID=%s", p.Name, c.serverID)
496-
return c.session.GetPrompt(c.ctx, &sdk.GetPromptParams{
607+
return c.getSDKSession().GetPrompt(c.ctx, &sdk.GetPromptParams{
497608
Name: p.Name,
498609
Arguments: p.Arguments,
499610
})
@@ -504,8 +615,8 @@ func (c *Connection) getPrompt(params interface{}) (*Response, error) {
504615
func (c *Connection) Close() error {
505616
logConn.Printf("Closing connection: serverID=%s, isHTTP=%v", c.serverID, c.isHTTP)
506617
c.cancel()
507-
if c.session != nil {
508-
return c.session.Close()
618+
if session := c.getSDKSession(); session != nil {
619+
return session.Close()
509620
}
510621
return nil
511622
}

internal/mcp/http_transport.go

Lines changed: 80 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,24 @@ func isHTTPConnectionError(err error) bool {
6262
return false
6363
}
6464

65+
// isSessionNotFoundError checks if an error message indicates a backend MCP session has expired
66+
// or is not found. This is used to detect when automatic reconnection to the backend is needed.
67+
func isSessionNotFoundError(err error) bool {
68+
if err == nil {
69+
return false
70+
}
71+
return strings.Contains(strings.ToLower(err.Error()), "session not found")
72+
}
73+
74+
// isSessionNotFoundHTTPResponse checks if an HTTP response indicates the backend session was not found.
75+
// MCP backends return HTTP 404 with a "session not found" body when a session has expired.
76+
func isSessionNotFoundHTTPResponse(statusCode int, body []byte) bool {
77+
if statusCode != http.StatusNotFound {
78+
return false
79+
}
80+
return strings.Contains(strings.ToLower(string(body)), "session not found")
81+
}
82+
6583
// parseSSEResponse extracts JSON data from SSE-formatted response
6684
// SSE format: "event: message\ndata: {json}\n\n"
6785
func parseSSEResponse(body []byte) ([]byte, error) {
@@ -436,58 +454,47 @@ func (c *Connection) initializeHTTPSession() (string, error) {
436454
return sessionID, nil
437455
}
438456

439-
// sendHTTPRequest sends a JSON-RPC request to an HTTP MCP server
440-
// The ctx parameter is used to extract session ID for the Mcp-Session-Id header
441-
func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params interface{}) (*Response, error) {
442-
// Generate unique request ID using atomic counter
443-
requestID := atomic.AddUint64(&requestIDCounter, 1)
444-
445-
// For tools/call, ensure arguments field always exists (MCP protocol requirement)
446-
if method == "tools/call" {
447-
params = ensureToolCallArguments(params)
448-
}
449-
450-
logConn.Printf("Sending HTTP request to %s: method=%s, id=%d", c.httpURL, method, requestID)
451-
452-
// Execute HTTP request with custom header modification for session ID
453-
result, err := c.executeHTTPRequest(ctx, method, params, requestID, func(httpReq *http.Request) {
454-
// Add Mcp-Session-Id header with priority:
455-
// 1) Context session ID (if explicitly provided for this request)
456-
// 2) Stored httpSessionID from initialization
457+
// buildSessionHeaderModifier returns a header modifier function that adds the Mcp-Session-Id header.
458+
// Priority: context session ID > stored connection session ID.
459+
// Context session IDs are static for the lifetime of a single request and are captured once at
460+
// construction time. Connection session IDs can change during a reconnect, so getHTTPSessionID()
461+
// is called at request time to always pick up the current value.
462+
func (c *Connection) buildSessionHeaderModifier(ctx context.Context) func(*http.Request) {
463+
// Capture any context-provided session ID once (it never changes for this request).
464+
ctxSessionID, _ := ctx.Value(SessionIDContextKey).(string)
465+
return func(httpReq *http.Request) {
457466
var sessionID string
458-
if ctxSessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && ctxSessionID != "" {
467+
if ctxSessionID != "" {
459468
sessionID = ctxSessionID
460469
logConn.Printf("Using session ID from context: %s", sessionID)
461-
} else if c.httpSessionID != "" {
462-
sessionID = c.httpSessionID
470+
} else if id := c.getHTTPSessionID(); id != "" {
471+
sessionID = id
463472
logConn.Printf("Using stored session ID from initialization: %s", sessionID)
464473
}
465-
466474
if sessionID != "" {
467475
httpReq.Header.Set("Mcp-Session-Id", sessionID)
468476
} else {
469477
logConn.Printf("No session ID available (backend may not require session management)")
470478
}
471-
})
472-
if err != nil {
473-
return nil, err
474479
}
480+
}
475481

476-
logConn.Printf("Received HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody))
477-
478-
// Parse JSON-RPC response
479-
// The response might be in SSE format (event: message\ndata: {...})
482+
// parseHTTPResult converts a raw httpRequestResult into a JSON-RPC Response, handling non-OK
483+
// HTTP status codes by synthesising a JSON-RPC error when the server did not provide one.
484+
func parseHTTPResult(result *httpRequestResult) (*Response, error) {
485+
// Parse JSON-RPC response.
486+
// The response might be in SSE format (event: message\ndata: {...}).
480487
rpcResponse, err := parseJSONRPCResponseWithSSE(result.ResponseBody, result.StatusCode, "JSON-RPC response")
481488
if err != nil {
482489
return nil, err
483490
}
484491

485-
// Check for HTTP errors after parsing
492+
// Check for HTTP errors after parsing.
486493
// If we have a non-OK status but successfully parsed a JSON-RPC response,
487-
// pass it through (it may already contain an error field)
494+
// pass it through (it may already contain an error field).
488495
if result.StatusCode != http.StatusOK {
489496
logConn.Printf("HTTP error status=%d with valid JSON-RPC response, passing through", result.StatusCode)
490-
// If the response doesn't already have an error, construct one
497+
// If the response doesn't already have an error, construct one.
491498
if rpcResponse.Error == nil {
492499
rpcResponse.Error = &ResponseError{
493500
Code: -32603, // Internal error
@@ -499,3 +506,44 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params
499506

500507
return rpcResponse, nil
501508
}
509+
510+
// sendHTTPRequest sends a JSON-RPC request to an HTTP MCP server.
511+
// The ctx parameter is used to extract session ID for the Mcp-Session-Id header.
512+
// If the backend returns a "session not found" (HTTP 404) response, it attempts a one-time
513+
// session reconnect and retries the request transparently.
514+
func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params interface{}) (*Response, error) {
515+
// For tools/call, ensure arguments field always exists (MCP protocol requirement)
516+
if method == "tools/call" {
517+
params = ensureToolCallArguments(params)
518+
}
519+
520+
headerModifier := c.buildSessionHeaderModifier(ctx)
521+
522+
requestID := atomic.AddUint64(&requestIDCounter, 1)
523+
logConn.Printf("Sending HTTP request to %s: method=%s, id=%d", c.httpURL, method, requestID)
524+
525+
result, err := c.executeHTTPRequest(ctx, method, params, requestID, headerModifier)
526+
if err != nil {
527+
return nil, err
528+
}
529+
530+
logConn.Printf("Received HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody))
531+
532+
// If the backend reported that the session has expired, reconnect and retry once.
533+
if isSessionNotFoundHTTPResponse(result.StatusCode, result.ResponseBody) {
534+
logConn.Printf("Session not found from %s (serverID=%s), attempting reconnect", c.httpURL, c.serverID)
535+
if reconnErr := c.reconnectPlainJSON(); reconnErr == nil {
536+
requestID = atomic.AddUint64(&requestIDCounter, 1)
537+
logConn.Printf("Retrying HTTP request after reconnect: method=%s, id=%d", method, requestID)
538+
result, err = c.executeHTTPRequest(ctx, method, params, requestID, headerModifier)
539+
if err != nil {
540+
return nil, err
541+
}
542+
logConn.Printf("Retry HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody))
543+
} else {
544+
logConn.Printf("Session reconnect failed (%v), returning original session-not-found error", reconnErr)
545+
}
546+
}
547+
548+
return parseHTTPResult(result)
549+
}

0 commit comments

Comments
 (0)