From 8bfa1de3a55f932f4270fb26e1c1044442c9446b Mon Sep 17 00:00:00 2001 From: krazer Date: Fri, 3 Apr 2026 00:00:10 -0400 Subject: [PATCH] feat(openai): add tool calling support and llama.cpp compatibility - Serialize tools array in OpenAI-format request body - Parse tool_calls from assistant responses (function name, arguments, id) - Include tool_call_id and tool_calls on messages in conversation history - Store assistant messages with tool_calls in chatClient history - Add sanitizeSchemaForLlamaCpp() to convert array-typed parameters to string type, avoiding Jinja template conflict with 'items' filter - Handle null content in responses when finish_reason is tool_calls --- src/arbiterAI/chatClient.cpp | 7 +- src/arbiterAI/providers/openai.cpp | 195 ++++++++++++++++++++++++++++- 2 files changed, 195 insertions(+), 7 deletions(-) diff --git a/src/arbiterAI/chatClient.cpp b/src/arbiterAI/chatClient.cpp index aa4a0a5..3357f24 100644 --- a/src/arbiterAI/chatClient.cpp +++ b/src/arbiterAI/chatClient.cpp @@ -104,11 +104,16 @@ ErrorCode ChatClient::completion(const CompletionRequest& request, CompletionRes } // Add assistant response to history - if (!response.text.empty()) + // Must include tool_calls when present so tool results can reference them + if (!response.text.empty() || !response.toolCalls.empty()) { Message assistantMsg; assistantMsg.role = "assistant"; assistantMsg.content = response.text; + if (!response.toolCalls.empty()) + { + assistantMsg.toolCalls = response.toolCalls; + } m_history.push_back(assistantMsg); } diff --git a/src/arbiterAI/providers/openai.cpp b/src/arbiterAI/providers/openai.cpp index 705379f..43f34fb 100644 --- a/src/arbiterAI/providers/openai.cpp +++ b/src/arbiterAI/providers/openai.cpp @@ -2,6 +2,46 @@ namespace arbiterAI { + +// Sanitize a JSON Schema for llama.cpp server compatibility. 
+// The llama.cpp Jinja chat template uses `items` as a built-in filter, +// so parameter schemas with "type": "array" cause template errors even +// when "items" is stripped (template branches on type=="array" and expects items). +// Convert array-typed properties to string type with a description note. +static void sanitizeSchemaForLlamaCpp(nlohmann::json &schema) +{ + if(!schema.is_object()) + return; + + // If this schema has "properties", recurse into each property + if(schema.contains("properties") && schema["properties"].is_object()) + { + for(auto &[key, prop] : schema["properties"].items()) + { + sanitizeSchemaForLlamaCpp(prop); + } + } + + // Convert array-typed properties to string — the llama.cpp Jinja template + // can't handle "array" type with "items" sub-schema + if(schema.contains("type") && schema["type"] == "array") + { + std::string itemType = "string"; + if(schema.contains("items") && schema["items"].is_object()) + { + itemType = schema["items"].value("type", "any"); + schema.erase("items"); + } + schema["type"] = "string"; + std::string desc = schema.value("description", ""); + if(!desc.empty()) + desc += " (JSON array of " + itemType + ", e.g. [\"a\",\"b\"])"; + else + desc = "JSON array of " + itemType + ", e.g. 
[\"a\",\"b\"]"; + schema["description"] = desc; + } +} + OpenAI::OpenAI() : BaseProvider("openai") { @@ -54,10 +94,38 @@ nlohmann::json OpenAI::createRequestBody(const CompletionRequest &request, bool nlohmann::json messages=nlohmann::json::array(); for(const auto &msg:request.messages) { - messages.push_back({ + nlohmann::json msgJson = { {"role", msg.role}, {"content", msg.content} - }); + }; + + // Include tool_call_id for tool-result messages + if(msg.toolCallId.has_value() && !msg.toolCallId->empty()) + { + msgJson["tool_call_id"] = msg.toolCallId.value(); + } + + // Include tool_calls array for assistant messages that invoked tools + if(msg.role == "assistant" && msg.toolCalls.has_value() && !msg.toolCalls->empty()) + { + nlohmann::json toolCallsJson = nlohmann::json::array(); + for(const auto &tc : msg.toolCalls.value()) + { + toolCallsJson.push_back({ + {"id", tc.id}, + {"type", "function"}, + {"function", { + {"name", tc.name}, + {"arguments", tc.arguments.is_string() + ? tc.arguments.get<std::string>() + : tc.arguments.dump()} + }} + }); + } + msgJson["tool_calls"] = toolCallsJson; + } + + messages.push_back(msgJson); } body["messages"]=messages; @@ -87,6 +155,66 @@ nlohmann::json OpenAI::createRequestBody(const CompletionRequest &request, bool body["stop"]=request.stop.value(); } + // Serialize tools in OpenAI function-calling format + if(request.tools.has_value() && !request.tools->empty()) + { + nlohmann::json toolsJson = nlohmann::json::array(); + for(const auto &tool : request.tools.value()) + { + nlohmann::json funcJson = { + {"name", tool.name}, + {"description", tool.description} + }; + + // Use parametersSchema if available, otherwise build from parameters vector + if(!tool.parametersSchema.is_null()) + { + nlohmann::json params = tool.parametersSchema; + sanitizeSchemaForLlamaCpp(params); + funcJson["parameters"] = params; + } + else if(!tool.parameters.empty()) + { + nlohmann::json propsJson = nlohmann::json::object(); + std::vector<std::string> requiredParams;
+ for(const auto &param : tool.parameters) + { + nlohmann::json paramJson = {{"type", param.type}}; + if(!param.description.empty()) + paramJson["description"] = param.description; + if(!param.schema.is_null()) + paramJson.merge_patch(param.schema); + propsJson[param.name] = paramJson; + if(param.required) + requiredParams.push_back(param.name); + } + funcJson["parameters"] = { + {"type", "object"}, + {"properties", propsJson} + }; + if(!requiredParams.empty()) + funcJson["parameters"]["required"] = requiredParams; + sanitizeSchemaForLlamaCpp(funcJson["parameters"]); + } + else + { + funcJson["parameters"] = {{"type", "object"}, {"properties", nlohmann::json::object()}}; + } + + toolsJson.push_back({ + {"type", "function"}, + {"function", funcJson} + }); + } + body["tools"] = toolsJson; + + // Add tool_choice if specified + if(request.tool_choice.has_value()) + { + body["tool_choice"] = request.tool_choice.value(); + } + } + return body; } @@ -118,16 +246,29 @@ ErrorCode OpenAI::parseResponse(const cpr::Response &rawResponse, return ErrorCode::InvalidResponse; } - // Extract the response text from the first choice + // Validate basic response structure if(!jsonResponse.contains("choices")|| jsonResponse["choices"].empty()|| - !jsonResponse["choices"][0].contains("message")|| - !jsonResponse["choices"][0]["message"].contains("content")) + !jsonResponse["choices"][0].contains("message")) { return ErrorCode::InvalidResponse; } - response.text=jsonResponse["choices"][0]["message"]["content"]; + const auto &choice = jsonResponse["choices"][0]; + const auto &message = choice["message"]; + + // Extract finish_reason + if(choice.contains("finish_reason") && !choice["finish_reason"].is_null()) + { + response.finishReason = choice["finish_reason"].get<std::string>(); + } + + // Extract content (may be empty/null for tool_calls responses) + if(message.contains("content") && !message["content"].is_null()) + { + response.text = message["content"].get<std::string>(); + } + response.provider="openai";
if(jsonResponse.contains("model")) @@ -135,6 +276,48 @@ ErrorCode OpenAI::parseResponse(const cpr::Response &rawResponse, response.model=jsonResponse["model"]; } + // Extract tool_calls if present + if(message.contains("tool_calls") && message["tool_calls"].is_array()) + { + for(const auto &tc : message["tool_calls"]) + { + ToolCall toolCall; + + if(tc.contains("id")) + toolCall.id = tc["id"].get<std::string>(); + + if(tc.contains("function")) + { + const auto &func = tc["function"]; + if(func.contains("name")) + toolCall.name = func["name"].get<std::string>(); + if(func.contains("arguments")) + { + const auto &args = func["arguments"]; + if(args.is_string()) + { + // Arguments come as a JSON string — parse it + try + { + toolCall.arguments = nlohmann::json::parse(args.get<std::string>()); + } + catch(const nlohmann::json::parse_error &) + { + // If parsing fails, store as raw string + toolCall.arguments = args; + } + } + else + { + toolCall.arguments = args; + } + } + } + + response.toolCalls.push_back(toolCall); + } + } + // Extract usage information if available if(jsonResponse.contains("usage")&& jsonResponse["usage"].contains("total_tokens"))