| | | 1 | | using System.Collections.Concurrent; |
| | | 2 | | using Elsa.AI.Abstractions.Contracts; |
| | | 3 | | using Elsa.AI.Abstractions.Models; |
| | | 4 | | |
| | | 5 | | namespace Elsa.AI.Host.Services; |
| | | 6 | | |
| | | 7 | | public class InMemoryAIConversationStore : IAITransientConversationStore |
| | | 8 | | { |
| | 48 | 9 | | private readonly ConcurrentDictionary<string, AIConversation> _conversations = new(StringComparer.OrdinalIgnoreCase) |
| | | 10 | | |
| | | 11 | | public ValueTask<AIConversation?> FindAsync(string id, CancellationToken cancellationToken = default) |
| | | 12 | | { |
| | 58 | 13 | | _conversations.TryGetValue(id, out var conversation); |
| | 58 | 14 | | if (conversation == null || !IsExpired(conversation)) |
| | 56 | 15 | | return ValueTask.FromResult(conversation); |
| | | 16 | | |
| | 2 | 17 | | _conversations.TryRemove(id, out _); |
| | 2 | 18 | | conversation = null; |
| | | 19 | | |
| | 2 | 20 | | return ValueTask.FromResult(conversation); |
| | | 21 | | } |
| | | 22 | | |
| | | 23 | | public ValueTask SaveAsync(AIConversation conversation, CancellationToken cancellationToken = default) |
| | | 24 | | { |
| | 109 | 25 | | Validate(conversation); |
| | 108 | 26 | | PruneExpired(); |
| | 108 | 27 | | _conversations.AddOrUpdate( |
| | 108 | 28 | | conversation.Id, |
| | 108 | 29 | | conversation, |
| | 108 | 30 | | (_, existing) => |
| | 108 | 31 | | { |
| | 61 | 32 | | ValidateOwnership(existing, conversation); |
| | 59 | 33 | | return conversation; |
| | 108 | 34 | | }); |
| | | 35 | | |
| | 106 | 36 | | return ValueTask.CompletedTask; |
| | | 37 | | } |
| | | 38 | | |
| | | 39 | | private void PruneExpired() |
| | | 40 | | { |
| | 216 | 41 | | foreach (var conversation in _conversations.Values.Where(IsExpired)) |
| | 0 | 42 | | _conversations.TryRemove(conversation.Id, out _); |
| | 108 | 43 | | } |
| | | 44 | | |
| | | 45 | | private bool IsExpired(AIConversation conversation) |
| | | 46 | | { |
| | 93 | 47 | | if (conversation.RetentionMode == AIRetentionMode.Ephemeral) |
| | 2 | 48 | | return conversation.Status is AIConversationStatus.Completed or AIConversationStatus.Failed; |
| | | 49 | | |
| | 91 | 50 | | if (conversation.RetentionMode == AIRetentionMode.Durable) |
| | 0 | 51 | | return false; |
| | | 52 | | |
| | 91 | 53 | | var expiresAt = conversation.RetentionExpiresAt; |
| | 91 | 54 | | return expiresAt.HasValue && expiresAt <= DateTimeOffset.UtcNow; |
| | | 55 | | } |
| | | 56 | | |
| | | 57 | | private static void ValidateOwnership(AIConversation existing, AIConversation conversation) |
| | | 58 | | { |
| | 61 | 59 | | if (!string.Equals(NormalizeTenantId(existing.TenantId), NormalizeTenantId(conversation.TenantId), StringCompari |
| | 1 | 60 | | throw new InvalidOperationException("Cannot overwrite an AI conversation that belongs to another tenant."); |
| | | 61 | | |
| | 60 | 62 | | if (!string.IsNullOrWhiteSpace(existing.UserId) && !string.Equals(existing.UserId, conversation.UserId, StringCo |
| | 1 | 63 | | throw new InvalidOperationException("Cannot overwrite an AI conversation that belongs to another user."); |
| | 59 | 64 | | } |
| | | 65 | | |
| | 122 | 66 | | private static string NormalizeTenantId(string? tenantId) => tenantId ?? ""; |
| | | 67 | | |
| | | 68 | | private static void Validate(AIConversation conversation) |
| | | 69 | | { |
| | 109 | 70 | | if (string.IsNullOrWhiteSpace(conversation.Id)) |
| | 0 | 71 | | throw new ArgumentException("A conversation ID is required.", nameof(conversation)); |
| | | 72 | | |
| | 109 | 73 | | if (string.IsNullOrWhiteSpace(conversation.UserId)) |
| | 1 | 74 | | throw new ArgumentException("A conversation user ID is required.", nameof(conversation)); |
| | 108 | 75 | | } |
| | | 76 | | } |