Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions backend/internal/service/openai_ws_forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,15 @@ func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage
return true
}

func openAIWSRawItemsHasFunctionCallOutput(items []json.RawMessage) bool {
for _, item := range items {
if gjson.GetBytes(item, "type").String() == "function_call_output" {
return true
}
}
return false
}

func buildOpenAIWSReplayInputSequence(
previousFullInput []json.RawMessage,
previousFullInputExists bool,
Expand Down Expand Up @@ -3117,6 +3126,12 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
currentTurnReplayInput := []json.RawMessage(nil)
currentTurnReplayInputExists := false
skipBeforeTurn := false
hasCurrentOrReplayFunctionCallOutput := func(payload []byte) bool {
if gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() {
return true
}
return currentTurnReplayInputExists && openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput)
}
resetSessionLease := func(markBroken bool) {
if sessionLease == nil {
return
Expand All @@ -3139,7 +3154,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
// 携带 function_call_output 的请求不能丢弃 previous_response_id:
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use,
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
if gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() {
if hasCurrentOrReplayFunctionCallOutput(currentPayload) {
return false
}
if isStrictAffinityTurn(currentPayload) {
Expand Down Expand Up @@ -3298,6 +3313,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
currentTurnReplayInput = nextReplayInput
currentTurnReplayInputExists = nextReplayInputExists
}
replayHasFunctionCallOutput := currentTurnReplayInputExists &&
openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput)
hasFunctionCallOutput = hasFunctionCallOutput || replayHasFunctionCallOutput
if storeDisabled && turn > 1 && currentPreviousResponseID != "" {
shouldKeepPreviousResponseID := false
strictReason := ""
Expand Down Expand Up @@ -3416,7 +3434,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
// 携带 function_call_output 的请求不能丢弃 previous_response_id:
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use,
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
hasFCOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
hasFCOutput := hasFunctionCallOutput
if !turnPrevRecoveryTried && currentPreviousResponseID != "" && !hasFCOutput {
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
if dropErr != nil || !removed {
Expand Down Expand Up @@ -3464,6 +3482,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
}
}
if hasFCOutput && currentPreviousResponseID != "" {
logOpenAIWSModeInfo(
"ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=function_call_output action=fail_close previous_response_id=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
)
}
resetSessionLease(true)
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
Expand Down
292 changes: 292 additions & 0 deletions backend/internal/service/openai_ws_forwarder_ingress_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,298 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStr
require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
}

func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenFunctionCallOutputNeedsPreviousResponseID(t *testing.T) {
gin.SetMode(gin.TestMode)
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
openAIWSIngressPreflightPingIdle = 0
defer func() {
openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
}()

cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3

firstConn := &openAIWSPreflightFailConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_replay_fc_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
secondConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"Previous response not found."}}`),
},
}
dialer := &openAIWSQueueDialer{
conns: []openAIWSClientConn{firstConn, secondConn},
}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(dialer)

svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}

account := &Account{
ID: 129,
Name: "openai-ingress-preflight-replay-function-output",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}

serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()

rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req

readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}

serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()

dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()

writeMessage := func(payload string) {
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
}
readMessage := func() []byte {
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
msgType, message, readErr := clientConn.Read(readCtx)
require.NoError(t, readErr)
require.Equal(t, coderws.MessageText, msgType)
return message
}

writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_other","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`)
firstTurn := readMessage()
require.Equal(t, "resp_turn_ping_replay_fc_1", gjson.GetBytes(firstTurn, "response.id").String())

writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_replay_fc_1","input":[{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`)

select {
case serverErr := <-serverErrCh:
require.Error(t, serverErr)
var closeErr *OpenAIWSClientCloseError
require.ErrorAs(t, serverErr, &closeErr)
require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
require.Contains(t, closeErr.Reason(), "upstream continuation connection is unavailable")
case <-time.After(5 * time.Second):
t.Fatal("等待 ingress websocket 结束超时")
}

require.Equal(t, 1, dialer.DialCount(), "需要 previous_response_id 的 function_call_output 在原连接不可用时不应换新连接重试")
secondConn.mu.Lock()
secondWrites := append([]map[string]any(nil), secondConn.writes...)
secondConn.mu.Unlock()
require.Empty(t, secondWrites, "不能把旧连接的 previous_response_id 发送到新上游,否则会触发 previous_response_not_found")
}

func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenReplayHasFunctionCallOutput(t *testing.T) {
gin.SetMode(gin.TestMode)
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
openAIWSIngressPreflightPingIdle = 0
defer func() {
openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
}()

cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3

firstConn := &openAIWSPreflightFailConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_replay_only_fc_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
secondConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found for function call output with call_id call_replay_1.","param":"input"}}`),
},
}
dialer := &openAIWSQueueDialer{
conns: []openAIWSClientConn{firstConn, secondConn},
}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(dialer)

svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}

account := &Account{
ID: 130,
Name: "openai-ingress-preflight-replay-only-function-output",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}

serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()

rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req

readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}

serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()

dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()

writeMessage := func(payload string) {
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
}
readMessage := func() []byte {
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
msgType, message, readErr := clientConn.Read(readCtx)
require.NoError(t, readErr)
require.Equal(t, coderws.MessageText, msgType)
return message
}

writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_other","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`)
firstTurn := readMessage()
require.Equal(t, "resp_turn_ping_replay_only_fc_1", gjson.GetBytes(firstTurn, "response.id").String())

writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_replay_only_fc_1","input":[{"type":"input_text","text":"after tool output"}]}`)

select {
case serverErr := <-serverErrCh:
require.Error(t, serverErr)
var closeErr *OpenAIWSClientCloseError
require.ErrorAs(t, serverErr, &closeErr)
require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
require.Contains(t, closeErr.Reason(), "upstream continuation connection is unavailable")
case <-time.After(5 * time.Second):
t.Fatal("等待 ingress websocket 结束超时")
}

require.Equal(t, 1, dialer.DialCount(), "replay input 带 function_call_output 时不应换新连接重试")
secondConn.mu.Lock()
secondWrites := append([]map[string]any(nil), secondConn.writes...)
secondConn.mu.Unlock()
require.Empty(t, secondWrites, "不能把会触发 No tool call found 的重放请求发到新上游")
}

func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) {
gin.SetMode(gin.TestMode)

Expand Down
Loading