| | | 1 | | using System.Text; |
| | | 2 | | using System.Text.Json; |
| | | 3 | | using Elsa.AI.Abstractions.Contracts; |
| | | 4 | | using Elsa.AI.Abstractions.Models; |
| | | 5 | | using Elsa.AI.Host.Context; |
| | | 6 | | using Elsa.AI.Host.Options; |
| | | 7 | | using Elsa.AI.Host.Streaming; |
| | | 8 | | using Microsoft.Extensions.Logging; |
| | | 9 | | using Microsoft.Extensions.Options; |
| | | 10 | | |
| | | 11 | | namespace Elsa.AI.Host.Services; |
| | | 12 | | |
| | 42 | 13 | | public class AIOrchestrator( |
| | 42 | 14 | | IEnumerable<IAIProvider> providers, |
| | 42 | 15 | | IAIToolRegistry toolRegistry, |
| | 42 | 16 | | IAIConversationStore conversationStore, |
| | 42 | 17 | | AIContextResolver contextResolver, |
| | 42 | 18 | | AIStreamEventMapper streamEventMapper, |
| | 42 | 19 | | IAIAuditSink auditSink, |
| | 42 | 20 | | ILogger<AIOrchestrator> logger, |
| | 42 | 21 | | IOptions<AIHostOptions> options) : IAIOrchestrator |
| | | 22 | | { |
| | | 23 | | private const int MaxProviderTurns = 8; |
| | | 24 | | |
| | | 25 | | public async IAsyncEnumerable<AIStreamEvent> ExecuteChatAsync(AIChatRequest request, [System.Runtime.CompilerService |
| | | 26 | | { |
| | 43 | 27 | | var conversationId = request.ConversationId ?? Guid.NewGuid().ToString("N"); |
| | 43 | 28 | | var sequence = 0L; |
| | 43 | 29 | | var providerSelection = SelectProvider(request); |
| | 43 | 30 | | var provider = providerSelection.Provider; |
| | 43 | 31 | | var conversationPersistenceEnabled = options.Value.ConversationPersistenceEnabled; |
| | 43 | 32 | | AIConversation? conversation = null; |
| | 43 | 33 | | Exception? preparationError = null; |
| | 43 | 34 | | if (conversationPersistenceEnabled) |
| | | 35 | | { |
| | | 36 | | try |
| | | 37 | | { |
| | 42 | 38 | | conversation = await conversationStore.FindAsync(conversationId, cancellationToken); |
| | 41 | 39 | | } |
| | 1 | 40 | | catch (Exception e) when (e is not OperationCanceledException) |
| | | 41 | | { |
| | 1 | 42 | | preparationError = e; |
| | 1 | 43 | | } |
| | | 44 | | } |
| | | 45 | | |
| | 43 | 46 | | if (conversation != null && (!BelongsToTenant(conversation, request.TenantId) || !BelongsToUser(conversation, re |
| | | 47 | | { |
| | 2 | 48 | | conversation = null; |
| | 2 | 49 | | conversationId = Guid.NewGuid().ToString("N"); |
| | | 50 | | } |
| | 43 | 51 | | var messages = conversation?.Messages.ToList() ?? []; |
| | 43 | 52 | | if (request.IsReconnect && messages.Count == 0) |
| | 1 | 53 | | conversationId = Guid.NewGuid().ToString("N"); |
| | | 54 | | |
| | 43 | 55 | | if (request.IsReconnect && IsCompletedReconnect(conversation, request.Message)) |
| | | 56 | | { |
| | 2 | 57 | | var nextSequence = GetNextSequence(messages); |
| | 2 | 58 | | if (conversation!.Status == AIConversationStatus.Failed) |
| | | 59 | | { |
| | | 60 | | var lastAssistantContent = conversation.Messages.LastOrDefault(x => x.Role == AIMessageRole.Assistant)?. |
| | 1 | 61 | | if (!string.IsNullOrEmpty(lastAssistantContent)) |
| | 1 | 62 | | yield return CreateEvent("conversation.error", conversationId, nextSequence++, new JsonObject |
| | 1 | 63 | | { |
| | 1 | 64 | | ["content"] = lastAssistantContent |
| | 1 | 65 | | }); |
| | | 66 | | } |
| | | 67 | | |
| | 2 | 68 | | yield return CreateEvent("conversation.completed", conversationId, nextSequence); |
| | 2 | 69 | | yield break; |
| | | 70 | | } |
| | | 71 | | |
| | 41 | 72 | | var providerSessionId = conversation?.ProviderSessionId; |
| | | 73 | | |
| | 41 | 74 | | if (preparationError == null && provider != null && string.IsNullOrWhiteSpace(providerSessionId)) |
| | | 75 | | { |
| | | 76 | | try |
| | | 77 | | { |
| | 25 | 78 | | var session = await provider.CreateSessionAsync(new CreateAISessionRequest |
| | 25 | 79 | | { |
| | 25 | 80 | | ConversationId = conversationId, |
| | 25 | 81 | | Agent = request.Agent, |
| | 25 | 82 | | TenantId = request.TenantId, |
| | 25 | 83 | | ProviderConfiguration = providerSelection.Configuration |
| | 25 | 84 | | }, cancellationToken); |
| | 24 | 85 | | providerSessionId = session.ProviderSessionId ?? session.Id; |
| | 24 | 86 | | } |
| | 1 | 87 | | catch (Exception e) when (e is not OperationCanceledException) |
| | | 88 | | { |
| | 1 | 89 | | preparationError = e; |
| | 1 | 90 | | } |
| | | 91 | | } |
| | | 92 | | |
| | 41 | 93 | | var isDuplicateReconnectMessage = request.IsReconnect && HasReconnectUserMessage(conversation, request.Message); |
| | 41 | 94 | | var providerHistory = messages.ToList(); |
| | 41 | 95 | | if (request.IsReconnect && messages.Count > 0) |
| | 3 | 96 | | sequence = GetNextSequence(messages); |
| | | 97 | | |
| | 41 | 98 | | yield return CreateEvent("conversation.started", conversationId, sequence++); |
| | | 99 | | |
| | 41 | 100 | | var userMessage = isDuplicateReconnectMessage |
| | 5 | 101 | | ? messages.Last(x => x.Role == AIMessageRole.User && string.Equals(NormalizeMessage(x.Content), NormalizeMes |
| | 41 | 102 | | : CreateMessage(conversationId, AIMessageRole.User, request.Message, sequence++); |
| | | 103 | | |
| | 41 | 104 | | if (!isDuplicateReconnectMessage) |
| | 38 | 105 | | messages.Add(userMessage); |
| | | 106 | | |
| | | 107 | | var knownToolCallIds = RestoreToolResults(messages).Select(x => x.ToolCallId).ToHashSet(StringComparer.OrdinalIg |
| | 41 | 108 | | var pendingToolResults = isDuplicateReconnectMessage ? RestorePendingToolResults(messages) : new List<AIToolTurn |
| | | 109 | | |
| | 41 | 110 | | await TrySaveConversationAsync(conversationId, request, AIConversationStatus.Active, messages, conversation, pro |
| | 41 | 111 | | await RecordChatAuditAsync("chat.started", request, conversationId, provider?.Name, cancellationToken); |
| | | 112 | | |
| | 41 | 113 | | IReadOnlyCollection<AIResolvedContext> context = []; |
| | 41 | 114 | | IReadOnlyCollection<AIToolDefinition> tools = []; |
| | | 115 | | |
| | 41 | 116 | | if (preparationError == null) |
| | | 117 | | { |
| | | 118 | | try |
| | | 119 | | { |
| | 39 | 120 | | context = LimitResolvedContext(await contextResolver.ResolveAsync(request, cancellationToken)); |
| | 38 | 121 | | tools = await toolRegistry.ListAsync(new AIToolQuery |
| | 38 | 122 | | { |
| | 38 | 123 | | Agent = request.Agent, |
| | 38 | 124 | | ActorId = request.UserId, |
| | 38 | 125 | | TenantId = request.TenantId, |
| | 38 | 126 | | UserPermissions = request.UserPermissions |
| | 38 | 127 | | }, cancellationToken); |
| | 38 | 128 | | } |
| | 1 | 129 | | catch (Exception e) when (e is not OperationCanceledException) |
| | | 130 | | { |
| | 1 | 131 | | preparationError = e; |
| | 1 | 132 | | } |
| | | 133 | | } |
| | | 134 | | |
| | 41 | 135 | | if (preparationError != null) |
| | | 136 | | { |
| | | 137 | | const string content = "Weaver could not prepare AI context or tools for this request."; |
| | 3 | 138 | | logger.LogWarning(preparationError, "Failed to prepare AI chat context or tools for conversation {Conversati |
| | 3 | 139 | | yield return CreateEvent("conversation.error", conversationId, sequence++, new JsonObject |
| | 3 | 140 | | { |
| | 3 | 141 | | ["content"] = content |
| | 3 | 142 | | }); |
| | 3 | 143 | | messages.Add(CreateMessage(conversationId, AIMessageRole.Assistant, content, sequence - 1)); |
| | 3 | 144 | | await TrySaveConversationAsync(conversationId, request, AIConversationStatus.Failed, messages, conversation, |
| | 3 | 145 | | await RecordChatAuditAsync("chat.failed", request, conversationId, provider?.Name, cancellationToken); |
| | 3 | 146 | | yield return CreateEvent("conversation.completed", conversationId, sequence); |
| | 3 | 147 | | yield break; |
| | | 148 | | } |
| | | 149 | | |
| | 38 | 150 | | if (provider == null) |
| | | 151 | | { |
| | | 152 | | const string content = "Weaver is ready, but no AI provider is configured."; |
| | 11 | 153 | | yield return CreateEvent("assistant.delta", conversationId, sequence++, new JsonObject |
| | 11 | 154 | | { |
| | 11 | 155 | | ["content"] = content |
| | 11 | 156 | | }); |
| | 11 | 157 | | messages.Add(CreateMessage(conversationId, AIMessageRole.Assistant, content, sequence - 1)); |
| | | 158 | | } |
| | | 159 | | else |
| | | 160 | | { |
| | 27 | 161 | | var assistantContent = new StringBuilder(); |
| | | 162 | | |
| | 80 | 163 | | for (var turn = 0; turn < MaxProviderTurns; turn++) |
| | | 164 | | { |
| | 40 | 165 | | var currentTurnToolResults = new List<AIToolTurnResult>(); |
| | 40 | 166 | | var currentTurnMessages = new List<AIMessage>(); |
| | 40 | 167 | | var currentTurnToolMessages = new List<AIMessage>(); |
| | 40 | 168 | | assistantContent.Clear(); |
| | | 169 | | |
| | 40 | 170 | | var turnRequest = new AITurnRequest |
| | 40 | 171 | | { |
| | 40 | 172 | | ConversationId = conversationId, |
| | 40 | 173 | | ProviderSessionId = providerSessionId, |
| | 40 | 174 | | Message = turn == 0 && !isDuplicateReconnectMessage ? request.Message : "", |
| | 40 | 175 | | Messages = providerHistory.ToList(), |
| | 40 | 176 | | Context = context, |
| | | 177 | | Tools = tools.Where(x => x.IsEnabled).ToList(), |
| | 40 | 178 | | ToolResults = GetUnrepresentedToolResults(pendingToolResults, providerHistory), |
| | 40 | 179 | | Agent = request.Agent, |
| | 40 | 180 | | ProviderConfiguration = providerSelection.Configuration |
| | 40 | 181 | | }; |
| | 40 | 182 | | Exception? providerTurnError = null; |
| | 161 | 183 | | await foreach (var providerRead in ReadProviderEventsAsync(provider.ExecuteTurnAsync(turnRequest, cancel |
| | | 184 | | { |
| | 41 | 185 | | if (providerRead.Error != null) |
| | | 186 | | { |
| | 1 | 187 | | providerTurnError = providerRead.Error; |
| | 1 | 188 | | break; |
| | | 189 | | } |
| | | 190 | | |
| | 40 | 191 | | var providerEvent = providerRead.Event!; |
| | 40 | 192 | | var streamEvent = streamEventMapper.Map(conversationId, providerEvent) with { Sequence = sequence++ |
| | 40 | 193 | | yield return streamEvent; |
| | | 194 | | |
| | 40 | 195 | | if (TryReadAssistantContent(providerEvent, out var content)) |
| | 17 | 196 | | assistantContent.Append(content); |
| | | 197 | | |
| | 40 | 198 | | if (!TryReadToolCall(providerEvent, out var toolCall)) |
| | | 199 | | continue; |
| | | 200 | | |
| | 19 | 201 | | if (knownToolCallIds.Contains(toolCall.Id) || |
| | 19 | 202 | | currentTurnToolResults.Any(x => string.Equals(x.ToolCallId, toolCall.Id, StringComparison.Ordina |
| | | 203 | | continue; |
| | | 204 | | |
| | 14 | 205 | | var toolExecution = await ExecuteToolCallAsync(toolCall, request, conversationId, sequence++, cancel |
| | 14 | 206 | | yield return toolExecution.StreamEvent; |
| | | 207 | | |
| | 14 | 208 | | currentTurnToolResults.Add(toolExecution.TurnResult); |
| | 14 | 209 | | var toolMessage = CreateMessage(conversationId, AIMessageRole.Tool, toolExecution.TurnResult.Result. |
| | 14 | 210 | | { |
| | 14 | 211 | | ["toolCallId"] = toolExecution.TurnResult.ToolCallId, |
| | 14 | 212 | | ["toolName"] = toolExecution.TurnResult.ToolName, |
| | 14 | 213 | | ["status"] = toolExecution.TurnResult.Result.Status.ToString() |
| | 14 | 214 | | }); |
| | 14 | 215 | | currentTurnToolMessages.Add(toolMessage); |
| | 14 | 216 | | } |
| | | 217 | | |
| | 40 | 218 | | if (providerTurnError != null) |
| | | 219 | | { |
| | | 220 | | const string content = "Weaver could not complete the AI provider turn for this request."; |
| | 1 | 221 | | logger.LogWarning(providerTurnError, "Failed to execute AI provider turn for conversation {Conversat |
| | 1 | 222 | | yield return CreateEvent("conversation.error", conversationId, sequence++, new JsonObject |
| | 1 | 223 | | { |
| | 1 | 224 | | ["content"] = content |
| | 1 | 225 | | }); |
| | 1 | 226 | | messages.Add(CreateMessage(conversationId, AIMessageRole.Assistant, content, sequence - 1)); |
| | 1 | 227 | | await TrySaveConversationAsync(conversationId, request, AIConversationStatus.Failed, messages, conve |
| | 1 | 228 | | await RecordChatAuditAsync("chat.failed", request, conversationId, provider.Name, cancellationToken) |
| | 1 | 229 | | yield return CreateEvent("conversation.completed", conversationId, sequence); |
| | 1 | 230 | | yield break; |
| | | 231 | | } |
| | | 232 | | |
| | 39 | 233 | | if (assistantContent.Length > 0 || currentTurnToolMessages.Count > 0) |
| | | 234 | | { |
| | 31 | 235 | | var assistantSequence = currentTurnToolMessages.Count > 0 |
| | | 236 | | ? currentTurnToolMessages.Min(x => x.StreamSequence) - 1 |
| | 31 | 237 | | : sequence - 1; |
| | 31 | 238 | | var assistantMessage = CreateMessage(conversationId, AIMessageRole.Assistant, assistantContent.ToStr |
| | 31 | 239 | | messages.Add(assistantMessage); |
| | 31 | 240 | | currentTurnMessages.Add(assistantMessage); |
| | | 241 | | } |
| | | 242 | | |
| | 39 | 243 | | messages.AddRange(currentTurnToolMessages); |
| | 39 | 244 | | currentTurnMessages.AddRange(currentTurnToolMessages); |
| | | 245 | | |
| | 39 | 246 | | if (currentTurnToolResults.Count == 0) |
| | | 247 | | break; |
| | | 248 | | |
| | 56 | 249 | | foreach (var toolResult in currentTurnToolResults) |
| | 14 | 250 | | knownToolCallIds.Add(toolResult.ToolCallId); |
| | | 251 | | |
| | 14 | 252 | | pendingToolResults = currentTurnToolResults; |
| | 21 | 253 | | if (providerHistory.All(x => x.Id != userMessage.Id)) |
| | 7 | 254 | | providerHistory.Add(userMessage); |
| | | 255 | | |
| | 14 | 256 | | providerHistory.AddRange(currentTurnMessages); |
| | 14 | 257 | | await TrySaveConversationAsync(conversationId, request, AIConversationStatus.Active, messages, conversat |
| | | 258 | | |
| | 14 | 259 | | if (turn == MaxProviderTurns - 1) |
| | | 260 | | { |
| | | 261 | | const string content = "Tool execution stopped because the provider requested too many continuation |
| | 1 | 262 | | yield return CreateEvent("assistant.delta", conversationId, sequence++, new JsonObject |
| | 1 | 263 | | { |
| | 1 | 264 | | ["content"] = content |
| | 1 | 265 | | }); |
| | 1 | 266 | | messages.Add(CreateMessage(conversationId, AIMessageRole.Assistant, content, sequence - 1)); |
| | 1 | 267 | | break; |
| | | 268 | | } |
| | 13 | 269 | | } |
| | 26 | 270 | | } |
| | | 271 | | |
| | 37 | 272 | | await TrySaveConversationAsync(conversationId, request, AIConversationStatus.Completed, messages, conversation, |
| | 37 | 273 | | await RecordChatAuditAsync("chat.completed", request, conversationId, provider?.Name, cancellationToken); |
| | 37 | 274 | | yield return CreateEvent("conversation.completed", conversationId, sequence); |
| | 43 | 275 | | } |
| | | 276 | | |
| | | 277 | | private static AIStreamEvent CreateEvent(string type, string conversationId, long sequence, JsonObject? data = null) |
| | 115 | 278 | | new() |
| | 115 | 279 | | { |
| | 115 | 280 | | Type = type, |
| | 115 | 281 | | ConversationId = conversationId, |
| | 115 | 282 | | Sequence = sequence, |
| | 115 | 283 | | Timestamp = DateTimeOffset.UtcNow, |
| | 115 | 284 | | Data = data ?? [] |
| | 115 | 285 | | }; |
| | | 286 | | |
| | | 287 | | private static async IAsyncEnumerable<ProviderReadResult> ReadProviderEventsAsync( |
| | | 288 | | IAsyncEnumerable<AIProviderEvent> providerEvents, |
| | | 289 | | [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken) |
| | | 290 | | { |
| | 40 | 291 | | var enumerator = providerEvents.GetAsyncEnumerator(cancellationToken); |
| | | 292 | | try |
| | | 293 | | { |
| | 40 | 294 | | while (true) |
| | | 295 | | { |
| | 80 | 296 | | AIProviderEvent? providerEvent = null; |
| | 80 | 297 | | Exception? error = null; |
| | 80 | 298 | | var hasEvent = false; |
| | | 299 | | |
| | | 300 | | try |
| | | 301 | | { |
| | 80 | 302 | | hasEvent = await enumerator.MoveNextAsync(); |
| | 79 | 303 | | if (hasEvent) |
| | 40 | 304 | | providerEvent = enumerator.Current; |
| | 79 | 305 | | } |
| | 1 | 306 | | catch (Exception e) when (e is not OperationCanceledException) |
| | | 307 | | { |
| | 1 | 308 | | error = e; |
| | 1 | 309 | | } |
| | | 310 | | |
| | 80 | 311 | | if (error != null) |
| | | 312 | | { |
| | 1 | 313 | | yield return new ProviderReadResult(null, error); |
| | 0 | 314 | | yield break; |
| | | 315 | | } |
| | | 316 | | |
| | 79 | 317 | | if (!hasEvent) |
| | 39 | 318 | | yield break; |
| | | 319 | | |
| | 40 | 320 | | yield return new ProviderReadResult(providerEvent, null); |
| | 40 | 321 | | } |
| | | 322 | | } |
| | | 323 | | finally |
| | | 324 | | { |
| | 40 | 325 | | await enumerator.DisposeAsync(); |
| | | 326 | | } |
| | 40 | 327 | | } |
| | | 328 | | |
| | | 329 | | private ProviderSelection SelectProvider(AIChatRequest request) |
| | | 330 | | { |
| | 43 | 331 | | var providerOptions = options.Value.Providers.ToList(); |
| | 45 | 332 | | var configuredProviders = providerOptions.Where(x => x.Enabled).ToList(); |
| | 43 | 333 | | var availableProviders = providers |
| | 32 | 334 | | .Where(x => providerOptions.IsProviderEnabled(x.Name)) |
| | 43 | 335 | | .ToList(); |
| | 43 | 336 | | var providerName = request.ProviderName ?? FindAgentProviderName(request.Agent) ?? options.Value.DefaultProvider |
| | | 337 | | |
| | 43 | 338 | | if (!string.IsNullOrWhiteSpace(providerName)) |
| | | 339 | | { |
| | 5 | 340 | | var configuredProvider = configuredProviders.FirstOrDefault(x => string.Equals(x.Name, providerName, StringC |
| | 4 | 341 | | var provider = configuredProvider != null |
| | 1 | 342 | | ? availableProviders.FirstOrDefault(x => string.Equals(x.Name, configuredProvider.Name, StringComparison |
| | 1 | 343 | | availableProviders.FirstOrDefault(x => string.Equals(x.Name, configuredProvider.Provider, StringCompar |
| | 8 | 344 | | : availableProviders.FirstOrDefault(x => string.Equals(x.Name, providerName, StringComparison.OrdinalIgn |
| | | 345 | | |
| | 4 | 346 | | return new ProviderSelection(provider, configuredProvider?.ToProviderConfiguration()); |
| | | 347 | | } |
| | | 348 | | |
| | 39 | 349 | | if (availableProviders.Count != 1) |
| | | 350 | | { |
| | 13 | 351 | | if (availableProviders.Count > 1) |
| | 0 | 352 | | logger.LogWarning( |
| | 0 | 353 | | "Multiple AI providers are available ({ProviderNames}) but no default provider name is configured. S |
| | 0 | 354 | | string.Join(", ", availableProviders.Select(x => x.Name))); |
| | | 355 | | |
| | 13 | 356 | | return new ProviderSelection(null, null); |
| | | 357 | | } |
| | | 358 | | |
| | 26 | 359 | | var selectedProvider = availableProviders[0]; |
| | 26 | 360 | | var selectedConfiguration = configuredProviders.FirstOrDefault(x => string.Equals(x.Name, selectedProvider.Name, |
| | 26 | 361 | | string.Equals(x.Provider, selectedProvider.N |
| | | 362 | | |
| | 26 | 363 | | return new ProviderSelection(selectedProvider, selectedConfiguration?.ToProviderConfiguration()); |
| | | 364 | | } |
| | | 365 | | |
| | | 366 | | private string? FindAgentProviderName(string? agent) => |
| | 40 | 367 | | string.IsNullOrWhiteSpace(agent) |
| | 40 | 368 | | ? null |
| | 41 | 369 | | : options.Value.Agents.FirstOrDefault(x => string.Equals(x.Name, agent, StringComparison.OrdinalIgnoreCase)) |
| | | 370 | | |
| | | 371 | | private async ValueTask RecordChatAuditAsync(string type, AIChatRequest request, string conversationId, string? prov |
| | | 372 | | { |
| | | 373 | | try |
| | | 374 | | { |
| | 82 | 375 | | await auditSink.RecordAsync(new AIAuditEvent |
| | 82 | 376 | | { |
| | 82 | 377 | | Type = type, |
| | 82 | 378 | | TenantId = request.TenantId, |
| | 82 | 379 | | ActorId = request.UserId, |
| | 82 | 380 | | ConversationId = conversationId, |
| | 82 | 381 | | Timestamp = DateTimeOffset.UtcNow, |
| | 82 | 382 | | Summary = type switch |
| | 82 | 383 | | { |
| | 41 | 384 | | "chat.started" => "Chat started", |
| | 4 | 385 | | "chat.failed" => "Chat failed", |
| | 37 | 386 | | _ => "Chat completed" |
| | 82 | 387 | | }, |
| | 82 | 388 | | Data = new JsonObject |
| | 82 | 389 | | { |
| | 82 | 390 | | ["agent"] = request.Agent, |
| | 82 | 391 | | ["provider"] = providerName, |
| | 82 | 392 | | ["attachmentCount"] = request.Attachments.Count |
| | 82 | 393 | | } |
| | 82 | 394 | | }, cancellationToken); |
| | 80 | 395 | | } |
| | 2 | 396 | | catch (Exception e) when (e is not OperationCanceledException) |
| | | 397 | | { |
| | 2 | 398 | | logger.LogWarning(e, "Failed to record AI chat audit event {AuditEventType} for conversation {ConversationId |
| | 2 | 399 | | } |
| | 82 | 400 | | } |
| | | 401 | | |
| | | 402 | | private async ValueTask<ToolExecutionResult> ExecuteToolCallAsync(ToolCall toolCall, AIChatRequest request, string c |
| | | 403 | | { |
| | 14 | 404 | | var tool = await toolRegistry.FindAsync(toolCall.Name, new AIToolQuery |
| | 14 | 405 | | { |
| | 14 | 406 | | Agent = request.Agent, |
| | 14 | 407 | | ActorId = request.UserId, |
| | 14 | 408 | | TenantId = request.TenantId, |
| | 14 | 409 | | UserPermissions = request.UserPermissions |
| | 14 | 410 | | }, cancellationToken); |
| | 14 | 411 | | if (tool == null) |
| | | 412 | | { |
| | 1 | 413 | | var result = new AIToolResult { Status = AIToolInvocationStatus.Failed, Error = $"Tool '{toolCall.Name}' was |
| | 1 | 414 | | await RecordToolAuditEventsAsync(request, conversationId, toolCall, ["tool.failed"], cancellationToken); |
| | 1 | 415 | | return CreateToolExecutionResult(conversationId, sequence, toolCall, result); |
| | | 416 | | } |
| | | 417 | | |
| | 13 | 418 | | using var toolScope = tool; |
| | | 419 | | try |
| | | 420 | | { |
| | 13 | 421 | | await RecordToolAuditEventsAsync(request, conversationId, toolCall, ["tool.invoked"], cancellationToken); |
| | 13 | 422 | | var result = await tool.ExecuteAsync(new AIToolExecutionContext |
| | 13 | 423 | | { |
| | 13 | 424 | | ConversationId = conversationId, |
| | 13 | 425 | | TenantId = request.TenantId, |
| | 13 | 426 | | ActorId = request.UserId, |
| | 13 | 427 | | Agent = request.Agent, |
| | 13 | 428 | | Arguments = toolCall.Arguments |
| | 13 | 429 | | }, cancellationToken); |
| | 12 | 430 | | await RecordToolAuditEventsAsync(request, conversationId, toolCall, ["tool.completed"], cancellationToken); |
| | | 431 | | |
| | 12 | 432 | | return CreateToolExecutionResult(conversationId, sequence, toolCall, LimitToolResult(result)); |
| | | 433 | | } |
| | 1 | 434 | | catch (Exception e) when (e is not OperationCanceledException) |
| | | 435 | | { |
| | 1 | 436 | | logger.LogWarning(e, "AI tool {ToolName} failed for conversation {ConversationId}.", toolCall.Name, conversa |
| | 1 | 437 | | await RecordToolAuditEventsAsync(request, conversationId, toolCall, ["tool.failed"], cancellationToken); |
| | 1 | 438 | | return CreateToolExecutionResult(conversationId, sequence, toolCall, new AIToolResult { Status = AIToolInvoc |
| | | 439 | | } |
| | 14 | 440 | | } |
| | | 441 | | |
| | | 442 | | private async ValueTask RecordToolAuditEventsAsync(AIChatRequest request, string conversationId, ToolCall toolCall, |
| | | 443 | | { |
| | | 444 | | try |
| | | 445 | | { |
| | 54 | 446 | | await auditSink.RecordManyAsync(types.Select(type => CreateToolAuditEvent(type, request, conversationId, too |
| | 27 | 447 | | } |
| | 0 | 448 | | catch (Exception e) when (e is not OperationCanceledException) |
| | | 449 | | { |
| | 0 | 450 | | logger.LogWarning(e, "Failed to record AI tool audit events for tool {ToolName}.", toolCall.Name); |
| | 0 | 451 | | } |
| | 27 | 452 | | } |
| | | 453 | | |
| | | 454 | | private static AIAuditEvent CreateToolAuditEvent(string type, AIChatRequest request, string conversationId, ToolCall |
| | 27 | 455 | | new() |
| | 27 | 456 | | { |
| | 27 | 457 | | Type = type, |
| | 27 | 458 | | TenantId = request.TenantId, |
| | 27 | 459 | | ActorId = request.UserId, |
| | 27 | 460 | | ConversationId = conversationId, |
| | 27 | 461 | | ToolInvocationId = toolCall.Id, |
| | 27 | 462 | | Timestamp = DateTimeOffset.UtcNow, |
| | 27 | 463 | | Summary = $"{toolCall.Name} {type}", |
| | 27 | 464 | | Data = new JsonObject |
| | 27 | 465 | | { |
| | 27 | 466 | | ["toolName"] = toolCall.Name |
| | 27 | 467 | | } |
| | 27 | 468 | | }; |
| | | 469 | | |
| | | 470 | | private static AIStreamEvent CreateToolResultEvent(string conversationId, long sequence, ToolCall toolCall, AIToolRe |
| | 14 | 471 | | CreateEvent("tool.result", conversationId, sequence, new JsonObject |
| | 14 | 472 | | { |
| | 14 | 473 | | ["toolCallId"] = toolCall.Id, |
| | 14 | 474 | | ["toolName"] = toolCall.Name, |
| | 14 | 475 | | ["status"] = result.Status.ToString(), |
| | 14 | 476 | | ["summary"] = result.Summary, |
| | 14 | 477 | | ["error"] = result.Error, |
| | 14 | 478 | | ["data"] = result.Data.DeepClone() |
| | 14 | 479 | | }); |
| | | 480 | | |
| | | 481 | | private static ToolExecutionResult CreateToolExecutionResult(string conversationId, long sequence, ToolCall toolCall |
| | 14 | 482 | | new(CreateToolResultEvent(conversationId, sequence, toolCall, result), new AIToolTurnResult |
| | 14 | 483 | | { |
| | 14 | 484 | | ToolCallId = toolCall.Id, |
| | 14 | 485 | | ToolName = toolCall.Name, |
| | 14 | 486 | | Result = result |
| | 14 | 487 | | }); |
| | | 488 | | |
| | | 489 | | private IReadOnlyCollection<AIResolvedContext> LimitResolvedContext(IReadOnlyCollection<AIResolvedContext> contexts) |
| | | 490 | | { |
| | 38 | 491 | | var maxBytes = options.Value.MaxResolvedContextBytes; |
| | 38 | 492 | | if (maxBytes <= 0) |
| | 1 | 493 | | return contexts; |
| | | 494 | | |
| | 37 | 495 | | var limited = new List<AIResolvedContext>(); |
| | 37 | 496 | | var usedBytes = 0; |
| | | 497 | | |
| | 84 | 498 | | foreach (var context in contexts) |
| | | 499 | | { |
| | 7 | 500 | | var contextSize = GetUtf8Size(context); |
| | 7 | 501 | | if (usedBytes + contextSize <= maxBytes) |
| | | 502 | | { |
| | 2 | 503 | | usedBytes += contextSize; |
| | 2 | 504 | | limited.Add(context); |
| | 2 | 505 | | continue; |
| | | 506 | | } |
| | | 507 | | |
| | 5 | 508 | | if (limited.Count > 0) |
| | | 509 | | { |
| | 1 | 510 | | logger.LogDebug( |
| | 1 | 511 | | "Dropping AI resolved context {ContextKind}/{ReferenceId} because resolved context exceeds the confi |
| | 1 | 512 | | context.Kind, |
| | 1 | 513 | | context.ReferenceId, |
| | 1 | 514 | | maxBytes); |
| | 1 | 515 | | continue; |
| | | 516 | | } |
| | | 517 | | |
| | 4 | 518 | | limited.Add(TruncateContext(context, maxBytes)); |
| | 4 | 519 | | break; |
| | | 520 | | } |
| | | 521 | | |
| | 37 | 522 | | return limited; |
| | | 523 | | } |
| | | 524 | | |
| | | 525 | | private static AIResolvedContext TruncateContext(AIResolvedContext context, int maxBytes) => |
| | 4 | 526 | | context with |
| | 4 | 527 | | { |
| | 4 | 528 | | Summary = Truncate(context.Summary, maxBytes), |
| | 4 | 529 | | Data = CreateTruncatedPayload(maxBytes), |
| | 4 | 530 | | Metadata = CreateTruncatedPayload(maxBytes) |
| | 4 | 531 | | }; |
| | | 532 | | |
| | | 533 | | private AIToolResult LimitToolResult(AIToolResult result) |
| | | 534 | | { |
| | 12 | 535 | | if (GetUtf8Size(result) <= options.Value.MaxToolResultBytes) |
| | 11 | 536 | | return result; |
| | | 537 | | |
| | 1 | 538 | | return result with |
| | 1 | 539 | | { |
| | 1 | 540 | | Summary = Truncate(result.Summary, options.Value.MaxToolResultBytes), |
| | 1 | 541 | | Data = CreateTruncatedPayload(options.Value.MaxToolResultBytes) |
| | 1 | 542 | | }; |
| | | 543 | | } |
| | | 544 | | |
| | | 545 | | private static JsonObject CreateTruncatedPayload(int maxBytes) => |
| | 9 | 546 | | new() |
| | 9 | 547 | | { |
| | 9 | 548 | | ["truncated"] = true, |
| | 9 | 549 | | ["maxBytes"] = maxBytes |
| | 9 | 550 | | }; |
| | | 551 | | |
| | | 552 | | private static int GetUtf8Size<T>(T value) => |
| | 19 | 553 | | JsonSerializer.SerializeToUtf8Bytes(value).Length; |
| | | 554 | | |
| | | 555 | | private static string Truncate(string value, int maxBytes) |
| | | 556 | | { |
| | 5 | 557 | | if (string.IsNullOrEmpty(value)) |
| | 0 | 558 | | return value; |
| | | 559 | | |
| | 5 | 560 | | if (maxBytes <= 0) |
| | 0 | 561 | | return ""; |
| | | 562 | | |
| | 5 | 563 | | if (Encoding.UTF8.GetByteCount(value) <= maxBytes) |
| | 0 | 564 | | return value; |
| | | 565 | | |
| | 5 | 566 | | var low = 0; |
| | 5 | 567 | | var high = Math.Min(value.Length, maxBytes); |
| | 34 | 568 | | while (low < high) |
| | | 569 | | { |
| | 29 | 570 | | var candidate = (low + high + 1) / 2; |
| | 29 | 571 | | if (Encoding.UTF8.GetByteCount(value.AsSpan(0, candidate)) <= maxBytes) |
| | 25 | 572 | | low = candidate; |
| | | 573 | | else |
| | 4 | 574 | | high = candidate - 1; |
| | | 575 | | } |
| | | 576 | | |
| | 5 | 577 | | if (low > 0 && char.IsHighSurrogate(value[low - 1])) |
| | 1 | 578 | | low--; |
| | | 579 | | |
| | 5 | 580 | | return value[..low]; |
| | | 581 | | } |
| | | 582 | | |
| | | 583 | | private static bool TryReadToolCall(AIProviderEvent providerEvent, out ToolCall toolCall) |
| | | 584 | | { |
| | 40 | 585 | | toolCall = default; |
| | 40 | 586 | | if (!string.Equals(providerEvent.Type, "tool.call", StringComparison.OrdinalIgnoreCase)) |
| | 21 | 587 | | return false; |
| | | 588 | | |
| | 19 | 589 | | var name = providerEvent.Data["toolName"]?.GetValue<string>() ?? providerEvent.Data["name"]?.GetValue<string>(); |
| | 19 | 590 | | if (string.IsNullOrWhiteSpace(name)) |
| | 0 | 591 | | return false; |
| | | 592 | | |
| | 19 | 593 | | var id = providerEvent.Data["id"]?.GetValue<string>() ?? Guid.NewGuid().ToString("N"); |
| | 19 | 594 | | var arguments = providerEvent.Data["arguments"]?.DeepClone() as JsonObject ?? []; |
| | 19 | 595 | | toolCall = new ToolCall(id, name, arguments); |
| | 19 | 596 | | return true; |
| | | 597 | | } |
| | | 598 | | |
| | | 599 | | private static bool TryReadAssistantContent(AIProviderEvent providerEvent, out string content) |
| | | 600 | | { |
| | 40 | 601 | | content = ""; |
| | 40 | 602 | | if (!string.Equals(providerEvent.Type, "assistant.delta", StringComparison.OrdinalIgnoreCase)) |
| | 19 | 603 | | return false; |
| | | 604 | | |
| | 21 | 605 | | content = providerEvent.Data["content"]?.GetValue<string>() ?? ""; |
| | 21 | 606 | | return !string.IsNullOrEmpty(content); |
| | | 607 | | } |
| | | 608 | | |
| | | 609 | | private static AIMessage CreateMessage(string conversationId, AIMessageRole role, string content, long streamSequenc |
| | 99 | 610 | | new() |
| | 99 | 611 | | { |
| | 99 | 612 | | Id = Guid.NewGuid().ToString("N"), |
| | 99 | 613 | | ConversationId = conversationId, |
| | 99 | 614 | | Role = role, |
| | 99 | 615 | | Content = content, |
| | 99 | 616 | | CreatedAt = DateTimeOffset.UtcNow, |
| | 99 | 617 | | StreamSequence = streamSequence, |
| | 99 | 618 | | Metadata = metadata ?? [] |
| | 99 | 619 | | }; |
| | | 620 | | |
| | | 621 | | private static JsonObject? CreateAssistantToolCallMetadata(IReadOnlyCollection<AIToolTurnResult> toolResults) |
| | | 622 | | { |
| | 31 | 623 | | if (toolResults.Count == 0) |
| | 17 | 624 | | return null; |
| | | 625 | | |
| | 14 | 626 | | var toolCallIds = new JsonArray(); |
| | 56 | 627 | | foreach (var toolResult in toolResults) |
| | 14 | 628 | | toolCallIds.Add(toolResult.ToolCallId); |
| | | 629 | | |
| | 14 | 630 | | return new JsonObject |
| | 14 | 631 | | { |
| | 14 | 632 | | ["toolCallIds"] = toolCallIds |
| | 14 | 633 | | }; |
| | | 634 | | } |
| | | 635 | | |
| | | 636 | | private static bool HasReconnectUserMessage(AIConversation? conversation, string message) |
| | | 637 | | { |
| | 4 | 638 | | return conversation is { Status: AIConversationStatus.Active } && |
| | 4 | 639 | | HasUserMessage(conversation, message); |
| | | 640 | | } |
| | | 641 | | |
| | | 642 | | private static bool IsCompletedReconnect(AIConversation? conversation, string message) => |
| | 6 | 643 | | conversation is { Status: AIConversationStatus.Completed or AIConversationStatus.Failed } && HasUserMessage(conv |
| | | 644 | | |
| | | 645 | | private static bool HasUserMessage(AIConversation conversation, string message) => |
| | 10 | 646 | | conversation.Messages.Any(x => x.Role == AIMessageRole.User && string.Equals(NormalizeMessage(x.Content), Normal |
| | | 647 | | |
| | | 648 | | private static long GetNextSequence(IReadOnlyCollection<AIMessage> messages) => |
| | 14 | 649 | | messages.Count == 0 ? 0 : messages.Max(x => x.StreamSequence) + 1; |
| | | 650 | | |
| | | 651 | | private static List<AIToolTurnResult> RestoreToolResults(IEnumerable<AIMessage> messages) |
| | | 652 | | { |
| | 41 | 653 | | return messages |
| | 47 | 654 | | .Where(x => x.Role == AIMessageRole.Tool) |
| | 41 | 655 | | .Select(CreateToolTurnResult) |
| | 41 | 656 | | .OfType<AIToolTurnResult>() |
| | 41 | 657 | | .ToList(); |
| | | 658 | | } |
| | | 659 | | |
| | | 660 | | private static List<AIToolTurnResult> RestorePendingToolResults(IReadOnlyCollection<AIMessage> messages) |
| | | 661 | | { |
| | 3 | 662 | | return messages |
| | 3 | 663 | | .Reverse() |
| | 4 | 664 | | .TakeWhile(x => x.Role == AIMessageRole.Tool) |
| | 3 | 665 | | .Reverse() |
| | 3 | 666 | | .Select(CreateToolTurnResult) |
| | 3 | 667 | | .OfType<AIToolTurnResult>() |
| | 3 | 668 | | .ToList(); |
| | | 669 | | } |
| | | 670 | | |
| | | 671 | | private static IReadOnlyCollection<AIToolTurnResult> GetUnrepresentedToolResults(IReadOnlyCollection<AIToolTurnResul |
| | | 672 | | { |
| | 40 | 673 | | if (toolResults.Count == 0) |
| | 26 | 674 | | return []; |
| | | 675 | | |
| | 14 | 676 | | var representedToolCallIds = messages |
| | 84 | 677 | | .Where(x => x.Role == AIMessageRole.Tool) |
| | 35 | 678 | | .Select(x => x.Metadata["toolCallId"]?.GetValue<string>()) |
| | 35 | 679 | | .Where(x => !string.IsNullOrWhiteSpace(x)) |
| | 14 | 680 | | .ToHashSet(StringComparer.OrdinalIgnoreCase); |
| | | 681 | | |
| | 14 | 682 | | return toolResults |
| | 14 | 683 | | .Where(x => !representedToolCallIds.Contains(x.ToolCallId)) |
| | 14 | 684 | | .ToList(); |
| | | 685 | | } |
| | | 686 | | |
| | | 687 | | private static AIToolTurnResult? CreateToolTurnResult(AIMessage message) |
| | | 688 | | { |
| | 2 | 689 | | var toolCallId = message.Metadata["toolCallId"]?.GetValue<string>(); |
| | 2 | 690 | | var toolName = message.Metadata["toolName"]?.GetValue<string>(); |
| | | 691 | | |
| | 2 | 692 | | if (string.IsNullOrWhiteSpace(toolCallId) || string.IsNullOrWhiteSpace(toolName)) |
| | 0 | 693 | | return null; |
| | | 694 | | |
| | 2 | 695 | | var status = Enum.TryParse<AIToolInvocationStatus>(message.Metadata["status"]?.GetValue<string>(), out var parse |
| | 2 | 696 | | ? parsedStatus |
| | 2 | 697 | | : AIToolInvocationStatus.Completed; |
| | | 698 | | |
| | 2 | 699 | | return new AIToolTurnResult |
| | 2 | 700 | | { |
| | 2 | 701 | | ToolCallId = toolCallId, |
| | 2 | 702 | | ToolName = toolName, |
| | 2 | 703 | | Result = new AIToolResult |
| | 2 | 704 | | { |
| | 2 | 705 | | Status = status, |
| | 2 | 706 | | Summary = message.Content |
| | 2 | 707 | | } |
| | 2 | 708 | | }; |
| | | 709 | | } |
| | | 710 | | |
| | | 711 | | private static string NormalizeMessage(string message) => |
| | 16 | 712 | | message.ReplaceLineEndings("\n").Trim(); |
| | | 713 | | |
| | | 714 | | private static bool BelongsToTenant(AIConversation conversation, string? tenantId) => |
| | 10 | 715 | | string.Equals(NormalizeTenantId(conversation.TenantId), NormalizeTenantId(tenantId), StringComparison.Ordinal); |
| | | 716 | | |
| | | 717 | | private static bool BelongsToUser(AIConversation conversation, string userId) => |
| | 9 | 718 | | string.IsNullOrWhiteSpace(conversation.UserId) || string.Equals(conversation.UserId, userId, StringComparison.Or |
| | | 719 | | |
| | | 720 | | private static string NormalizeTenantId(string? tenantId) => |
| | 20 | 721 | | string.IsNullOrWhiteSpace(tenantId) ? "" : tenantId; |
| | | 722 | | |
| | | 723 | | private async ValueTask TrySaveConversationAsync(string conversationId, AIChatRequest request, AIConversationStatus |
| | | 724 | | { |
| | 96 | 725 | | if (!options.Value.ConversationPersistenceEnabled) |
| | 2 | 726 | | return; |
| | | 727 | | |
| | | 728 | | try |
| | | 729 | | { |
| | 94 | 730 | | var now = DateTimeOffset.UtcNow; |
| | 94 | 731 | | var retentionMode = conversation?.RetentionMode ?? AIRetentionMode.Configured; |
| | 94 | 732 | | DateTimeOffset? retentionExpiresAt = retentionMode == AIRetentionMode.Configured |
| | 94 | 733 | | ? now.Add(options.Value.ConversationRetention) |
| | 94 | 734 | | : null; |
| | | 735 | | |
| | 94 | 736 | | await conversationStore.SaveAsync(new AIConversation |
| | 94 | 737 | | { |
| | 94 | 738 | | Id = conversationId, |
| | 94 | 739 | | TenantId = request.TenantId, |
| | 94 | 740 | | UserId = request.UserId, |
| | 94 | 741 | | Title = conversation?.Title, |
| | 94 | 742 | | Status = status, |
| | 94 | 743 | | CreatedAt = conversation is null || conversation.CreatedAt == default ? now : conversation.CreatedAt, |
| | 94 | 744 | | UpdatedAt = now, |
| | 94 | 745 | | ProviderSessionId = providerSessionId ?? conversation?.ProviderSessionId, |
| | 94 | 746 | | RetentionMode = retentionMode, |
| | 94 | 747 | | RetentionExpiresAt = retentionExpiresAt, |
| | 94 | 748 | | Messages = messages.ToList() |
| | 94 | 749 | | }, cancellationToken); |
| | 92 | 750 | | } |
| | 2 | 751 | | catch (Exception e) when (e is not OperationCanceledException) |
| | | 752 | | { |
| | 2 | 753 | | logger.LogWarning(e, "Failed to persist AI conversation {ConversationId} with status {ConversationStatus}.", |
| | 2 | 754 | | } |
| | 96 | 755 | | } |
| | | 756 | | |
| | 185 | 757 | | private readonly record struct ToolCall(string Id, string Name, JsonObject Arguments); |
| | 98 | 758 | | private readonly record struct ToolExecutionResult(AIStreamEvent StreamEvent, AIToolTurnResult TurnResult); |
| | 82 | 759 | | private readonly record struct ProviderReadResult(AIProviderEvent? Event, Exception? Error); |
| | 108 | 760 | | private readonly record struct ProviderSelection(IAIProvider? Provider, AIProviderConfiguration? Configuration); |
| | | 761 | | } |