Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/arbiterAI/chatClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,16 @@ ErrorCode ChatClient::completion(const CompletionRequest& request, CompletionRes
}

// Add assistant response to history
if (!response.text.empty())
// Must include tool_calls when present so tool results can reference them
if (!response.text.empty() || !response.toolCalls.empty())
{
Message assistantMsg;
assistantMsg.role = "assistant";
assistantMsg.content = response.text;
if (!response.toolCalls.empty())
{
assistantMsg.toolCalls = response.toolCalls;
}
m_history.push_back(assistantMsg);
}

Expand Down
195 changes: 189 additions & 6 deletions src/arbiterAI/providers/openai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,46 @@

namespace arbiterAI
{

// Sanitize a JSON Schema for llama.cpp server compatibility.
// The llama.cpp Jinja chat template uses `items` as a built-in filter,
// so parameter schemas with "type": "array" cause template errors even
// when "items" is stripped (template branches on type=="array" and expects items).
// Convert array-typed properties to string type with a description note.
static void sanitizeSchemaForLlamaCpp(nlohmann::json &schema)
{
if(!schema.is_object())
return;

// If this schema has "properties", recurse into each property
if(schema.contains("properties") && schema["properties"].is_object())
{
for(auto &[key, prop] : schema["properties"].items())
{
sanitizeSchemaForLlamaCpp(prop);
}
}

// Convert array-typed properties to string — the llama.cpp Jinja template
// can't handle "array" type with "items" sub-schema
if(schema.contains("type") && schema["type"] == "array")
{
std::string itemType = "string";
if(schema.contains("items") && schema["items"].is_object())
{
itemType = schema["items"].value("type", "any");
schema.erase("items");
}
schema["type"] = "string";
std::string desc = schema.value("description", "");
if(!desc.empty())
desc += " (JSON array of " + itemType + ", e.g. [\"a\",\"b\"])";
else
desc = "JSON array of " + itemType + ", e.g. [\"a\",\"b\"]";
schema["description"] = desc;
}
}

OpenAI::OpenAI()
: BaseProvider("openai")
{
Expand Down Expand Up @@ -54,10 +94,38 @@ nlohmann::json OpenAI::createRequestBody(const CompletionRequest &request, bool
nlohmann::json messages=nlohmann::json::array();
for(const auto &msg:request.messages)
{
messages.push_back({
nlohmann::json msgJson = {
{"role", msg.role},
{"content", msg.content}
});
};

// Include tool_call_id for tool-result messages
if(msg.toolCallId.has_value() && !msg.toolCallId->empty())
{
msgJson["tool_call_id"] = msg.toolCallId.value();
}

// Include tool_calls array for assistant messages that invoked tools
if(msg.role == "assistant" && msg.toolCalls.has_value() && !msg.toolCalls->empty())
{
nlohmann::json toolCallsJson = nlohmann::json::array();
for(const auto &tc : msg.toolCalls.value())
{
toolCallsJson.push_back({
{"id", tc.id},
{"type", "function"},
{"function", {
{"name", tc.name},
{"arguments", tc.arguments.is_string()
? tc.arguments.get<std::string>()
: tc.arguments.dump()}
}}
});
}
msgJson["tool_calls"] = toolCallsJson;
}

messages.push_back(msgJson);
}
body["messages"]=messages;

Expand Down Expand Up @@ -87,6 +155,66 @@ nlohmann::json OpenAI::createRequestBody(const CompletionRequest &request, bool
body["stop"]=request.stop.value();
}

// Serialize tools in OpenAI function-calling format
if(request.tools.has_value() && !request.tools->empty())
{
nlohmann::json toolsJson = nlohmann::json::array();
for(const auto &tool : request.tools.value())
{
nlohmann::json funcJson = {
{"name", tool.name},
{"description", tool.description}
};

// Use parametersSchema if available, otherwise build from parameters vector
if(!tool.parametersSchema.is_null())
{
nlohmann::json params = tool.parametersSchema;
sanitizeSchemaForLlamaCpp(params);
funcJson["parameters"] = params;
}
else if(!tool.parameters.empty())
{
nlohmann::json propsJson = nlohmann::json::object();
std::vector<std::string> requiredParams;
for(const auto &param : tool.parameters)
{
nlohmann::json paramJson = {{"type", param.type}};
if(!param.description.empty())
paramJson["description"] = param.description;
if(!param.schema.is_null())
paramJson.merge_patch(param.schema);
propsJson[param.name] = paramJson;
if(param.required)
requiredParams.push_back(param.name);
}
funcJson["parameters"] = {
{"type", "object"},
{"properties", propsJson}
};
if(!requiredParams.empty())
funcJson["parameters"]["required"] = requiredParams;
sanitizeSchemaForLlamaCpp(funcJson["parameters"]);
}
else
{
funcJson["parameters"] = {{"type", "object"}, {"properties", nlohmann::json::object()}};
}

toolsJson.push_back({
{"type", "function"},
{"function", funcJson}
});
}
body["tools"] = toolsJson;

// Add tool_choice if specified
if(request.tool_choice.has_value())
{
body["tool_choice"] = request.tool_choice.value();
}
}

return body;
}

Expand Down Expand Up @@ -118,23 +246,78 @@ ErrorCode OpenAI::parseResponse(const cpr::Response &rawResponse,
return ErrorCode::InvalidResponse;
}

// Extract the response text from the first choice
// Validate basic response structure
if(!jsonResponse.contains("choices")||
jsonResponse["choices"].empty()||
!jsonResponse["choices"][0].contains("message")||
!jsonResponse["choices"][0]["message"].contains("content"))
!jsonResponse["choices"][0].contains("message"))
{
return ErrorCode::InvalidResponse;
}

response.text=jsonResponse["choices"][0]["message"]["content"];
const auto &choice = jsonResponse["choices"][0];
const auto &message = choice["message"];

// Extract finish_reason
if(choice.contains("finish_reason") && !choice["finish_reason"].is_null())
{
response.finishReason = choice["finish_reason"].get<std::string>();
}

// Extract content (may be empty/null for tool_calls responses)
if(message.contains("content") && !message["content"].is_null())
{
response.text = message["content"].get<std::string>();
}

response.provider="openai";

if(jsonResponse.contains("model"))
{
response.model=jsonResponse["model"];
}

// Extract tool_calls if present
if(message.contains("tool_calls") && message["tool_calls"].is_array())
{
for(const auto &tc : message["tool_calls"])
{
ToolCall toolCall;

if(tc.contains("id"))
toolCall.id = tc["id"].get<std::string>();

if(tc.contains("function"))
{
const auto &func = tc["function"];
if(func.contains("name"))
toolCall.name = func["name"].get<std::string>();
if(func.contains("arguments"))
{
const auto &args = func["arguments"];
if(args.is_string())
{
// Arguments come as a JSON string — parse it
try
{
toolCall.arguments = nlohmann::json::parse(args.get<std::string>());
}
catch(const nlohmann::json::parse_error &)
{
// If parsing fails, store as raw string
toolCall.arguments = args;
}
}
else
{
toolCall.arguments = args;
}
}
}

response.toolCalls.push_back(toolCall);
}
}

// Extract usage information if available
if(jsonResponse.contains("usage")&&
jsonResponse["usage"].contains("total_tokens"))
Expand Down
Loading