diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 372f420fad9..77cf7d95a3f 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1552,6 +1552,15 @@ func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage return true } +func openAIWSRawItemsHasFunctionCallOutput(items []json.RawMessage) bool { + for _, item := range items { + if gjson.GetBytes(item, "type").String() == "function_call_output" { + return true + } + } + return false +} + func buildOpenAIWSReplayInputSequence( previousFullInput []json.RawMessage, previousFullInputExists bool, @@ -3117,6 +3126,12 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( currentTurnReplayInput := []json.RawMessage(nil) currentTurnReplayInputExists := false skipBeforeTurn := false + hasCurrentOrReplayFunctionCallOutput := func(payload []byte) bool { + if gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() { + return true + } + return currentTurnReplayInputExists && openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput) + } resetSessionLease := func(markBroken bool) { if sessionLease == nil { return @@ -3139,7 +3154,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( // 携带 function_call_output 的请求不能丢弃 previous_response_id: // 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use, // 丢弃后会导致 "No tool call found for function call output" 400 错误。 - if gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() { + if hasCurrentOrReplayFunctionCallOutput(currentPayload) { return false } if isStrictAffinityTurn(currentPayload) { @@ -3298,6 +3313,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( currentTurnReplayInput = nextReplayInput currentTurnReplayInputExists = nextReplayInputExists } + replayHasFunctionCallOutput := currentTurnReplayInputExists && + openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput) + hasFunctionCallOutput = hasFunctionCallOutput || replayHasFunctionCallOutput if storeDisabled && turn > 1 && currentPreviousResponseID != "" { shouldKeepPreviousResponseID := false strictReason := "" @@ -3416,7 +3434,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( // 携带 function_call_output 的请求不能丢弃 previous_response_id: // 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use, // 丢弃后会导致 "No tool call found for function call output" 400 错误。 - hasFCOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() + hasFCOutput := hasFunctionCallOutput if !turnPrevRecoveryTried && currentPreviousResponseID != "" && !hasFCOutput { updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) if dropErr != nil || !removed { @@ -3464,6 +3482,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } } } + if hasFCOutput && currentPreviousResponseID != "" { + logOpenAIWSModeInfo( + "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=function_call_output action=fail_close previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) + } resetSessionLease(true) return NewOpenAIWSClientCloseError( coderws.StatusPolicyViolation, diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index 5246d37d74b..a4b39ddf83a 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -1918,6 +1918,298 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStr require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenFunctionCallOutputNeedsPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + prevPreflightPingIdle := openAIWSIngressPreflightPingIdle + openAIWSIngressPreflightPingIdle = 0 + defer func() { + openAIWSIngressPreflightPingIdle = prevPreflightPingIdle + }() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSPreflightFailConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_replay_fc_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"Previous response not found."}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 129, + Name: "openai-ingress-preflight-replay-function-output", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_other","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_ping_replay_fc_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_replay_fc_1","input":[{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`) + + select { + case serverErr := <-serverErrCh: + require.Error(t, serverErr) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Contains(t, closeErr.Reason(), "upstream continuation connection is unavailable") + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, dialer.DialCount(), "需要 previous_response_id 的 function_call_output 在原连接不可用时不应换新连接重试") + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) + secondConn.mu.Unlock() + require.Empty(t, secondWrites, "不能把旧连接的 previous_response_id 发送到新上游,否则会触发 previous_response_not_found") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenReplayHasFunctionCallOutput(t *testing.T) { + gin.SetMode(gin.TestMode) + prevPreflightPingIdle := openAIWSIngressPreflightPingIdle + openAIWSIngressPreflightPingIdle = 0 + defer func() { + openAIWSIngressPreflightPingIdle = prevPreflightPingIdle + }() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSPreflightFailConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_replay_only_fc_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found for function call output with call_id call_replay_1.","param":"input"}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 130, + Name: "openai-ingress-preflight-replay-only-function-output", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_other","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_ping_replay_only_fc_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_replay_only_fc_1","input":[{"type":"input_text","text":"after tool output"}]}`) + + select { + case serverErr := <-serverErrCh: + require.Error(t, serverErr) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Contains(t, closeErr.Reason(), "upstream continuation connection is unavailable") + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, dialer.DialCount(), "replay input 带 function_call_output 时不应换新连接重试") + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) + secondConn.mu.Unlock() + require.Empty(t, secondWrites, "不能把会触发 No tool call found 的重放请求发到新上游") +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) { gin.SetMode(gin.TestMode)