Skip to content

Commit f9cc755

Browse files
author
Haiping Chen
committed
RealtimeConversationHook
1 parent c8a565f commit f9cc755

6 files changed

Lines changed: 148 additions & 85 deletions

File tree

src/Infrastructure/BotSharp.Abstraction/Realtime/IRealtimeHub.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using BotSharp.Abstraction.MLTasks;
12
using BotSharp.Abstraction.Realtime.Models;
23
using System.Net.WebSockets;
34

@@ -8,5 +9,11 @@ namespace BotSharp.Abstraction.Realtime;
89
/// </summary>
910
public interface IRealtimeHub
{
    /// <summary>
    /// Connection state of the current realtime session, as created by
    /// <see cref="SetHubConnection"/>.
    /// </summary>
    RealtimeHubConnection HubConn { get; }

    /// <summary>
    /// Realtime completion service of the current session, as selected by
    /// <see cref="SetCompleter"/>.
    /// </summary>
    IRealTimeCompletion Completer { get; }

    /// <summary>
    /// Creates the connection state for the given conversation and exposes it
    /// through <see cref="HubConn"/>.
    /// </summary>
    /// <param name="conversationId">Id of the conversation the connection belongs to.</param>
    /// <returns>The newly created connection.</returns>
    RealtimeHubConnection SetHubConnection(string conversationId);

    /// <summary>
    /// Selects the realtime completion service for the given provider and exposes it
    /// through <see cref="Completer"/>.
    /// </summary>
    /// <param name="provider">Provider name used to pick the completion service.</param>
    /// <returns>The selected completion service.</returns>
    IRealTimeCompletion SetCompleter(string provider);

    /// <summary>
    /// Listens on the user's web socket and hands each received text message to
    /// <paramref name="onUserMessageReceived"/>.
    /// </summary>
    /// <param name="userWebSocket">Web socket connected to the user.</param>
    /// <param name="onUserMessageReceived">Callback invoked with the raw text of each received message.</param>
    Task Listen(WebSocket userWebSocket, Action<string> onUserMessageReceived);
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using BotSharp.Abstraction.Utilities;
2+
3+
namespace BotSharp.Core.Realtime.Hooks;
4+
5+
/// <summary>
/// Conversation hook that reacts to function calls made while a realtime hub
/// connection is active: it stores the call arguments as conversation state before
/// execution, and pushes the function result back to the realtime model afterwards.
/// Both callbacks are no-ops when the hub has no connection.
/// </summary>
public class RealtimeConversationHook : ConversationHookBase, IConversationHook
{
    private readonly IServiceProvider _services;

    public RealtimeConversationHook(IServiceProvider services)
    {
        _services = services;
    }

    /// <summary>
    /// Passes the function arguments (parsed as JSON, or an empty JSON object when
    /// there are none) to <c>IConversationStateService.SaveStateByArgs</c>.
    /// </summary>
    /// <param name="message">The function-call message about to be executed.</param>
    public async Task OnFunctionExecuting(RoleDialogModel message)
    {
        var hub = _services.GetRequiredService<IRealtimeHub>();
        if (hub.HubConn == null)
        {
            // No realtime connection on this hub: nothing to do.
            return;
        }

        // Save states
        // NOTE(review): JsonDocument is IDisposable and is not disposed here; whether it
        // can be disposed right after SaveStateByArgs depends on that method - confirm.
        var states = _services.GetRequiredService<IConversationStateService>();
        states.SaveStateByArgs(message.FunctionArgs?.JsonContent<JsonDocument>() ?? JsonDocument.Parse("{}"));
    }

    /// <summary>
    /// Marks the message as a function result, then updates the realtime session,
    /// inserts the message as a conversation item and triggers model inference with
    /// an instruction that depends on which function was executed.
    /// </summary>
    /// <param name="message">The executed function-call message; its role and, for routing functions, its content are overwritten.</param>
    public async Task OnFunctionExecuted(RoleDialogModel message)
    {
        var hub = _services.GetRequiredService<IRealtimeHub>();

        // The completer is assigned separately from the connection; without both there
        // is no realtime session to update, so skip instead of dereferencing null.
        if (hub.HubConn == null || hub.Completer == null)
        {
            return;
        }

        var routing = _services.GetRequiredService<IRoutingService>();

        message.Role = AgentRole.Function;

        string instruction;
        if (message.FunctionName == "route_to_agent")
        {
            var inst = DeserializeArgs<RoutingArgs>(message.FunctionArgs);
            message.Content = $"Connected to agent of {inst.AgentName}";
            hub.HubConn.CurrentAgentId = routing.Context.GetCurrentAgentId();
            instruction = $"Guide the user through the next steps of the process as this Agent ({inst.AgentName}), following its instructions and operational procedures.";
        }
        else if (message.FunctionName == "util-routing-fallback_to_router")
        {
            var inst = DeserializeArgs<FallbackArgs>(message.FunctionArgs);
            message.Content = $"Returned to Router due to {inst.Reason}";
            hub.HubConn.CurrentAgentId = routing.Context.GetCurrentAgentId();
            instruction = $"Check with user whether to proceed the new request: {inst.Reason}";
        }
        else
        {
            instruction = "Reply based on the function's output.";
        }

        // Same sequence for every function: update the session (agent or states may
        // have changed), add the function result, then ask the model to respond.
        await hub.Completer.UpdateSession(hub.HubConn);
        await hub.Completer.InsertConversationItem(message);
        await hub.Completer.TriggerModelInference(instruction);
    }

    /// <summary>
    /// Deserializes function arguments, treating null, empty or whitespace-only input
    /// as an empty JSON object so that missing arguments do not raise a JsonException.
    /// </summary>
    private static T DeserializeArgs<T>(string? args) where T : new()
    {
        var json = string.IsNullOrWhiteSpace(args) ? "{}" : args;
        return JsonSerializer.Deserialize<T>(json) ?? new T();
    }
}

src/Infrastructure/BotSharp.Core.Realtime/RealtimePlugin.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using BotSharp.Abstraction.Plugins;
2+
using BotSharp.Core.Realtime.Hooks;
23
using BotSharp.Core.Realtime.Services;
34
using Microsoft.Extensions.Configuration;
45

@@ -14,5 +15,6 @@ public class RealtimePlugin : IBotSharpPlugin
1415
/// <summary>
/// Registers the realtime hub and the realtime conversation hook as scoped services.
/// </summary>
/// <param name="services">Service collection that receives the registrations.</param>
/// <param name="config">Application configuration; not read by this method.</param>
public void RegisterDI(IServiceCollection services, IConfiguration config)
{
    services
        .AddScoped<IRealtimeHub, RealtimeHub>()
        .AddScoped<IConversationHook, RealtimeConversationHook>();
}
1820
}
Lines changed: 66 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,31 @@
1+
using BotSharp.Abstraction.Utilities;
2+
13
namespace BotSharp.Core.Realtime.Services;
24

35
public class RealtimeHub : IRealtimeHub
46
{
57
private readonly IServiceProvider _services;
68
private readonly ILogger _logger;
79

10+
private RealtimeHubConnection _conn;
11+
public RealtimeHubConnection HubConn => _conn;
12+
13+
private IRealTimeCompletion _completer;
14+
public IRealTimeCompletion Completer => _completer;
15+
816
/// <summary>
/// Creates the hub, keeping the service provider and logger for later use.
/// </summary>
/// <param name="services">Service provider used to resolve services on demand.</param>
/// <param name="logger">Logger for this hub.</param>
public RealtimeHub(IServiceProvider services, ILogger<RealtimeHub> logger)
{
    (_services, _logger) = (services, logger);
}
1321

1422
public async Task Listen(WebSocket userWebSocket,
15-
Func<string, RealtimeHubConnection> onUserMessageReceived)
23+
Action<string> onUserMessageReceived)
1624
{
1725
var buffer = new byte[1024 * 16];
1826
WebSocketReceiveResult result;
1927

20-
var completer = _services.GetServices<IRealTimeCompletion>().First(x => x.Provider == "openai");
28+
2129

2230
do
2331
{
@@ -29,40 +37,40 @@ public async Task Listen(WebSocket userWebSocket,
2937
continue;
3038
}
3139

32-
var conn = onUserMessageReceived(receivedText);
40+
onUserMessageReceived(receivedText);
3341

34-
if (conn.Event == "user_connected")
42+
if (_conn.Event == "user_connected")
3543
{
36-
await ConnectToModel(completer, userWebSocket, conn);
44+
await ConnectToModel(userWebSocket);
3745
}
38-
else if (conn.Event == "user_data_received")
46+
else if (_conn.Event == "user_data_received")
3947
{
40-
await completer.AppenAudioBuffer(conn.Data);
48+
await _completer.AppenAudioBuffer(_conn.Data);
4149
}
42-
else if (conn.Event == "user_dtmf_received")
50+
else if (_conn.Event == "user_dtmf_received")
4351
{
44-
await HandleUserDtmfReceived(completer, conn);
52+
await HandleUserDtmfReceived();
4553
}
46-
else if (conn.Event == "user_disconnected")
54+
else if (_conn.Event == "user_disconnected")
4755
{
48-
await completer.Disconnect();
49-
await HandleUserDisconnected(conn);
56+
await _completer.Disconnect();
57+
await HandleUserDisconnected();
5058
}
5159
} while (!result.CloseStatus.HasValue);
5260

5361
await userWebSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None);
5462
}
5563

56-
private async Task ConnectToModel(IRealTimeCompletion completer, WebSocket userWebSocket, RealtimeHubConnection conn)
64+
private async Task ConnectToModel(WebSocket userWebSocket)
5765
{
5866
var hookProvider = _services.GetRequiredService<ConversationHookProvider>();
5967
var convService = _services.GetRequiredService<IConversationService>();
60-
convService.SetConversationId(conn.ConversationId, []);
61-
var conversation = await convService.GetConversation(conn.ConversationId);
68+
convService.SetConversationId(_conn.ConversationId, []);
69+
var conversation = await convService.GetConversation(_conn.ConversationId);
6270

6371
var agentService = _services.GetRequiredService<IAgentService>();
6472
var agent = await agentService.LoadAgent(conversation.AgentId);
65-
conn.CurrentAgentId = agent.Id;
73+
_conn.CurrentAgentId = agent.Id;
6674

6775
// Set model
6876
var model = agent.LlmConfig.Model;
@@ -72,8 +80,8 @@ private async Task ConnectToModel(IRealTimeCompletion completer, WebSocket userW
7280
model = llmProviderService.GetProviderModel("openai", "gpt-4", realTime: true).Name;
7381
}
7482

75-
completer.SetModelName(model);
76-
conn.Model = model;
83+
_completer.SetModelName(model);
84+
_conn.Model = model;
7785

7886
var routing = _services.GetRequiredService<IRoutingService>();
7987
routing.Context.Push(agent.Id);
@@ -85,54 +93,48 @@ private async Task ConnectToModel(IRealTimeCompletion completer, WebSocket userW
8593
}
8694
routing.Context.SetDialogs(dialogs);
8795

88-
await completer.Connect(conn,
96+
await _completer.Connect(_conn,
8997
onModelReady: async () =>
9098
{
9199
// Control initial session, prevent initial response interruption
92-
await completer.UpdateSession(conn, turnDetection: false);
93-
94-
// Add dialog history
95-
//foreach (var item in dialogs)
96-
//{
97-
// await completer.InsertConversationItem(item);
98-
//}
100+
await _completer.UpdateSession(_conn, turnDetection: false);
99101

100102
if (dialogs.LastOrDefault()?.Role == AgentRole.Assistant)
101103
{
102-
await completer.TriggerModelInference($"Rephase your last response:\r\n{dialogs.LastOrDefault()?.Content}");
104+
await _completer.TriggerModelInference($"Rephase your last response:\r\n{dialogs.LastOrDefault()?.Content}");
103105
}
104106
else
105107
{
106-
await completer.TriggerModelInference("Reply based on the conversation context.");
108+
await _completer.TriggerModelInference("Reply based on the conversation context.");
107109
}
108110

109111
// Start turn detection
110112
await Task.Delay(1000 * 8);
111-
await completer.UpdateSession(conn, turnDetection: true);
113+
await _completer.UpdateSession(_conn, turnDetection: true);
112114
},
113115
onModelAudioDeltaReceived: async (audioDeltaData, itemId) =>
114116
{
115-
var data = conn.OnModelMessageReceived(audioDeltaData);
117+
var data = _conn.OnModelMessageReceived(audioDeltaData);
116118
await SendEventToUser(userWebSocket, data);
117119

118120
// If this is the first delta of a new response, set the start timestamp
119-
if (!conn.ResponseStartTimestamp.HasValue)
121+
if (!_conn.ResponseStartTimestamp.HasValue)
120122
{
121-
conn.ResponseStartTimestamp = conn.LatestMediaTimestamp;
122-
_logger.LogDebug($"Setting start timestamp for new response: {conn.ResponseStartTimestamp}ms");
123+
_conn.ResponseStartTimestamp = _conn.LatestMediaTimestamp;
124+
_logger.LogDebug($"Setting start timestamp for new response: {_conn.ResponseStartTimestamp}ms");
123125
}
124126
// Record last assistant item ID for interruption handling
125127
if (!string.IsNullOrEmpty(itemId))
126128
{
127-
conn.LastAssistantItemId = itemId;
129+
_conn.LastAssistantItemId = itemId;
128130
}
129131

130132
// Send mark messages to Media Streams so we know if and when AI response playback is finished
131-
await SendMark(userWebSocket, conn);
133+
await SendMark(userWebSocket, _conn);
132134
},
133135
onModelAudioResponseDone: async () =>
134136
{
135-
var data = conn.OnModelAudioResponseDone();
137+
var data = _conn.OnModelAudioResponseDone();
136138
await SendEventToUser(userWebSocket, data);
137139
},
138140
onAudioTranscriptDone: async transcript =>
@@ -144,36 +146,10 @@ await completer.Connect(conn,
144146
foreach (var message in messages)
145147
{
146148
// Invoke function
147-
if (message.MessageType == MessageTypeName.FunctionCall)
149+
if (message.MessageType == MessageTypeName.FunctionCall &&
150+
!string.IsNullOrEmpty(message.FunctionName))
148151
{
149152
await routing.InvokeFunction(message.FunctionName, message);
150-
message.Role = AgentRole.Function;
151-
152-
if (message.FunctionName == "route_to_agent")
153-
{
154-
var inst = JsonSerializer.Deserialize<RoutingArgs>(message.FunctionArgs ?? "{}");
155-
message.Content = $"Connected to agent of {inst.AgentName}";
156-
conn.CurrentAgentId = routing.Context.GetCurrentAgentId();
157-
158-
await completer.UpdateSession(conn);
159-
await completer.InsertConversationItem(message);
160-
await completer.TriggerModelInference($"Guide the user through the next steps of the process as this Agent ({inst.AgentName}), following its instructions and operational procedures.");
161-
}
162-
else if (message.FunctionName == "util-routing-fallback_to_router")
163-
{
164-
var inst = JsonSerializer.Deserialize<FallbackArgs>(message.FunctionArgs ?? "{}");
165-
message.Content = $"Returned to Router due to {inst.Reason}";
166-
conn.CurrentAgentId = routing.Context.GetCurrentAgentId();
167-
168-
await completer.UpdateSession(conn);
169-
await completer.InsertConversationItem(message);
170-
await completer.TriggerModelInference($"Check with user whether to proceed the new request: {inst.Reason}");
171-
}
172-
else
173-
{
174-
await completer.InsertConversationItem(message);
175-
await completer.TriggerModelInference("Reply based on the function's output.");
176-
}
177153
}
178154
else
179155
{
@@ -210,9 +186,9 @@ await completer.Connect(conn,
210186
onUserInterrupted: async () =>
211187
{
212188
// Reset states
213-
conn.ResetResponseState();
189+
_conn.ResetResponseState();
214190

215-
var data = conn.OnModelUserInterrupted();
191+
var data = _conn.OnModelUserInterrupted();
216192
await SendEventToUser(userWebSocket, data);
217193
});
218194
}
@@ -232,17 +208,17 @@ private async Task SendMark(WebSocket userWebSocket, RealtimeHubConnection conn)
232208
}
233209
}
234210

235-
private async Task HandleUserDtmfReceived(IRealTimeCompletion completer, RealtimeHubConnection conn)
211+
private async Task HandleUserDtmfReceived()
236212
{
237213
var routing = _services.GetRequiredService<IRoutingService>();
238214
var hookProvider = _services.GetRequiredService<ConversationHookProvider>();
239215
var agentService = _services.GetRequiredService<IAgentService>();
240-
var agent = await agentService.LoadAgent(conn.CurrentAgentId);
216+
var agent = await agentService.LoadAgent(_conn.CurrentAgentId);
241217
var dialogs = routing.Context.GetDialogs();
242218
var convService = _services.GetRequiredService<IConversationService>();
243-
var conversation = await convService.GetConversation(conn.ConversationId);
219+
var conversation = await convService.GetConversation(_conn.ConversationId);
244220

245-
var message = new RoleDialogModel(AgentRole.User, conn.Data)
221+
var message = new RoleDialogModel(AgentRole.User, _conn.Data)
246222
{
247223
CurrentAgentId = routing.Context.GetCurrentAgentId()
248224
};
@@ -256,19 +232,19 @@ private async Task HandleUserDtmfReceived(IRealTimeCompletion completer, Realtim
256232
await hook.OnMessageReceived(message);
257233
}
258234

259-
await completer.InsertConversationItem(message);
260-
await completer.TriggerModelInference("Reply based on the user input");
235+
await _completer.InsertConversationItem(message);
236+
await _completer.TriggerModelInference("Reply based on the user input");
261237
}
262238

263-
private async Task HandleUserDisconnected(RealtimeHubConnection conn)
239+
/// <summary>
/// Appends every dialog held by the routing context to conversation storage,
/// under the current connection's conversation id.
/// </summary>
private async Task HandleUserDisconnected()
{
    var routingService = _services.GetRequiredService<IRoutingService>();
    var conversationStorage = _services.GetRequiredService<IConversationStorage>();

    // Save dialog history
    foreach (var dialog in routingService.Context.GetDialogs())
    {
        conversationStorage.Append(_conn.ConversationId, dialog);
    }
}
274250

@@ -278,4 +254,20 @@ private async Task SendEventToUser(WebSocket webSocket, object message)
278254
var buffer = Encoding.UTF8.GetBytes(data);
279255
await webSocket.SendAsync(new ArraySegment<byte>(buffer), WebSocketMessageType.Text, true, CancellationToken.None);
280256
}
257+
258+
/// <summary>
/// Replaces the hub's connection with a new one for the given conversation.
/// </summary>
/// <param name="conversationId">Conversation id stored on the new connection.</param>
/// <returns>The newly created connection, which is also exposed through <c>HubConn</c>.</returns>
public RealtimeHubConnection SetHubConnection(string conversationId)
    => _conn = new RealtimeHubConnection { ConversationId = conversationId };
267+
268+
/// <summary>
/// Selects the registered realtime completion service whose <c>Provider</c> equals
/// <paramref name="provider"/> and stores it as the hub's completer.
/// </summary>
/// <param name="provider">Provider name to match against the registered services.</param>
/// <returns>The selected completion service, which is also exposed through <c>Completer</c>.</returns>
/// <exception cref="InvalidOperationException">
/// No registered realtime completion service matches <paramref name="provider"/>.
/// </exception>
public IRealTimeCompletion SetCompleter(string provider)
{
    var completer = _services.GetServices<IRealTimeCompletion>()
        .FirstOrDefault(x => x.Provider == provider);

    if (completer == null)
    {
        // Same exception type First() would raise, but naming the provider that was
        // requested; the previously selected completer is left untouched.
        throw new InvalidOperationException(
            $"No realtime completion service is registered for provider '{provider}'.");
    }

    _completer = completer;
    return _completer;
}
281273
}

0 commit comments

Comments
 (0)