| | | 1 | | using System.Text; |
| | | 2 | | using System.Text.Json; |
| | | 3 | | using System.Text.Json.Nodes; |
| | | 4 | | using Elsa.AI.Abstractions.Contracts; |
| | | 5 | | using Elsa.AI.Abstractions.Models; |
| | | 6 | | using Elsa.AI.Persistence.EFCore.Entities; |
| | | 7 | | using Microsoft.EntityFrameworkCore; |
| | | 8 | | |
| | | 9 | | namespace Elsa.AI.Persistence.EFCore.Stores; |
| | | 10 | | |
| | 19 | 11 | | public class EFCoreAIConversationStore(AIDbContext dbContext) : IAIConversationStore |
| | | 12 | | { |
| | | 13 | | private const int MaxStoredMessages = 256; |
| | | 14 | | private const int MaxMessagesJsonBytes = 1024 * 1024; |
| | | 15 | | |
/// <summary>
/// Loads the conversation with the given ID without change tracking and returns it unless it has expired.
/// </summary>
/// <param name="id">The conversation ID to look up.</param>
/// <param name="cancellationToken">Token used to cancel the database query.</param>
/// <returns>
/// The mapped conversation, or <c>null</c> when no record exists or the stored conversation is expired
/// according to <c>IsExpired</c>.
/// </returns>
public async ValueTask<AIConversation?> FindAsync(string id, CancellationToken cancellationToken = default)
{
    var stored = await dbContext.Conversations
        .AsNoTracking()
        .FirstOrDefaultAsync(x => x.Id == id, cancellationToken);

    if (stored == null)
        return null;

    var mapped = Map(stored);

    // Expired conversations are hidden from callers; the record itself is not deleted here.
    return IsExpired(mapped) ? null : mapped;
}
| | | 28 | | |
/// <summary>
/// Inserts a new record for the conversation, or overwrites the existing one with the same ID.
/// </summary>
/// <param name="conversation">The conversation to persist. Its ID and user ID must be non-blank.</param>
/// <param name="cancellationToken">Token used to cancel the database operations.</param>
/// <exception cref="ArgumentException">The conversation ID or user ID is blank.</exception>
/// <exception cref="InvalidOperationException">
/// A record with the same ID exists but belongs to another tenant or another user.
/// </exception>
public async ValueTask SaveAsync(AIConversation conversation, CancellationToken cancellationToken = default)
{
    Validate(conversation);

    var record = await dbContext.Conversations.FindAsync([conversation.Id], cancellationToken);
    var isInsert = record == null;

    if (record == null)
    {
        record = new AIConversationRecord { Id = conversation.Id };
        dbContext.Conversations.Add(record);
    }
    else
    {
        // Existing records may only be overwritten by the same tenant and the same user.
        if (!BelongsToTenant(record.TenantId, conversation.TenantId))
            throw new InvalidOperationException("Cannot overwrite an AI conversation that belongs to another tenant.");

        ValidateUserOwnership(record, conversation);
    }

    Map(conversation, record);

    try
    {
        await dbContext.SaveChangesAsync(cancellationToken);
    }
    catch (DbUpdateException e) when (isInsert)
    {
        // Only a failed insert is retried: the retry reloads whatever record is stored now and updates it.
        // NOTE(review): presumably this covers a concurrent insert of the same ID - confirm.
        await RetryAsUpdateAsync(conversation, e, cancellationToken);
    }
}
| | | 60 | | |
| | | 61 | | private async ValueTask RetryAsUpdateAsync(AIConversation conversation, DbUpdateException originalException, Cancell |
| | | 62 | | { |
| | 0 | 63 | | dbContext.ChangeTracker.Clear(); |
| | 0 | 64 | | var record = await dbContext.Conversations.FindAsync([conversation.Id], cancellationToken); |
| | 0 | 65 | | if (record == null) |
| | 0 | 66 | | throw new DbUpdateException($"Failed to insert AI conversation {conversation.Id}, and no existing record was |
| | | 67 | | |
| | 0 | 68 | | if (!BelongsToTenant(record.TenantId, conversation.TenantId)) |
| | 0 | 69 | | throw new InvalidOperationException("Cannot overwrite an AI conversation that belongs to another tenant."); |
| | | 70 | | |
| | 0 | 71 | | ValidateUserOwnership(record, conversation); |
| | | 72 | | |
| | 0 | 73 | | Map(conversation, record); |
| | 0 | 74 | | await dbContext.SaveChangesAsync(cancellationToken); |
| | 0 | 75 | | } |
| | | 76 | | |
/// <summary>
/// Converts a persisted record into its <see cref="AIConversation"/> model.
/// </summary>
/// <param name="record">The stored record to convert.</param>
/// <returns>
/// A conversation carrying the record's values. Unparseable status and retention-mode strings fall back to
/// <c>Active</c> and <c>Configured</c>; a missing, blank, or JSON-null messages payload becomes an empty collection.
/// </returns>
private static AIConversation Map(AIConversationRecord record) =>
    new()
    {
        Id = record.Id,
        TenantId = record.TenantId,
        UserId = record.UserId,
        Title = record.Title,
        Status = ParseEnum(record.Status, AIConversationStatus.Active),
        CreatedAt = record.CreatedAt,
        UpdatedAt = record.UpdatedAt,
        ProviderSessionId = record.ProviderSessionId,
        RetentionMode = ParseEnum(record.RetentionMode, AIRetentionMode.Configured),
        RetentionExpiresAt = record.RetentionExpiresAt,
        // A null or blank payload is not valid JSON and would make Deserialize throw,
        // so it is treated the same way as a JSON null: no messages.
        Messages = string.IsNullOrWhiteSpace(record.Messages)
            ? []
            : JsonSerializer.Deserialize<IReadOnlyCollection<AIMessage>>(record.Messages) ?? []
    };
| | | 92 | | |
/// <summary>
/// Copies the conversation's values onto the record that will be persisted.
/// </summary>
/// <param name="conversation">The source conversation.</param>
/// <param name="record">The record to update in place. Its ID is not touched.</param>
private static void Map(AIConversation conversation, AIConversationRecord record)
{
    // The creation timestamp is only written while the record still has the default value,
    // so later saves do not replace it.
    if (record.CreatedAt == default)
        record.CreatedAt = conversation.CreatedAt;

    record.UpdatedAt = conversation.UpdatedAt;

    // A null tenant ID is stored as an empty string.
    record.TenantId = NormalizeTenantId(conversation.TenantId);
    record.UserId = conversation.UserId;
    record.Title = conversation.Title;
    record.ProviderSessionId = conversation.ProviderSessionId;

    // Enums are stored by name.
    record.Status = conversation.Status.ToString();
    record.RetentionMode = conversation.RetentionMode.ToString();
    record.RetentionExpiresAt = conversation.RetentionExpiresAt;

    // Serialized last, after all scalar values have been copied.
    record.Messages = SerializeMessages(conversation.Messages);
}
| | | 109 | | |
/// <summary>
/// Serializes the messages to JSON, bounded by <c>MaxStoredMessages</c> and <c>MaxMessagesJsonBytes</c>.
/// </summary>
/// <remarks>
/// Messages are ordered by <c>CreatedAt</c>, then <c>StreamSequence</c>. Only the last
/// <c>MaxStoredMessages</c> are kept. If the UTF-8 size of the JSON still exceeds the byte limit, leading
/// messages are dropped; if a single remaining message is still too large, it is serialized in truncated form.
/// </remarks>
/// <param name="messages">The messages to serialize.</param>
/// <returns>The JSON array text to store.</returns>
private static string SerializeMessages(IReadOnlyCollection<AIMessage> messages)
{
    static bool ExceedsByteLimit(string text) => Encoding.UTF8.GetByteCount(text) > MaxMessagesJsonBytes;

    var ordered = messages
        .OrderBy(x => x.CreatedAt)
        .ThenBy(x => x.StreamSequence)
        .ToList();

    // Keep only the tail of the ordered list when there are more messages than may be stored.
    var surplus = ordered.Count - MaxStoredMessages;
    var retained = surplus > 0 ? ordered.GetRange(surplus, MaxStoredMessages) : ordered;
    var json = JsonSerializer.Serialize(retained);

    if (retained.Count > 1 && ExceedsByteLimit(json))
        (retained, json) = ShrinkMessagesToByteLimit(retained);

    // Shrinking can end with one message that is still over the limit on its own.
    return retained.Count == 1 && ExceedsByteLimit(json)
        ? SerializeSingleTruncatedMessage(retained[0])
        : json;
}
| | | 129 | | |
/// <summary>
/// Serializes one message as a single-element JSON array, cutting its content until the UTF-8 size of the
/// JSON is within <c>MaxMessagesJsonBytes</c>.
/// </summary>
/// <param name="message">The message whose serialized form is too large.</param>
/// <returns>
/// JSON for the truncated message. If even empty content does not fit, JSON for the message with empty content
/// and its metadata replaced by the truncation markers only.
/// </returns>
private static string SerializeSingleTruncatedMessage(AIMessage message)
{
    var content = message.Content;

    // Start from at most a quarter of the byte budget, counted in chars.
    var length = Math.Min(content.Length, MaxMessagesJsonBytes / 4);

    for (;;)
    {
        // Never cut between the two halves of a surrogate pair.
        length = NormalizeSliceLength(content, length);

        var json = JsonSerializer.Serialize(new[] { CreateTruncatedMessage(message, content[..length]) });
        if (Encoding.UTF8.GetByteCount(json) <= MaxMessagesJsonBytes)
            return json;

        // Nothing left to cut from the content: drop the original metadata as a last resort.
        if (length == 0)
            return JsonSerializer.Serialize(new[] { CreateTruncatedMessage(message, "", preserveMetadata: false) });

        // Shrink by roughly 10% per attempt, and by at least one char.
        length -= Math.Max(1, length / 10);
    }
}
| | | 149 | | |
/// <summary>
/// Creates a copy of the message with replaced content and metadata that marks it as truncated.
/// </summary>
/// <param name="message">The original message; it is not modified.</param>
/// <param name="content">The content for the copy.</param>
/// <param name="preserveMetadata">
/// When <c>true</c>, the copy's metadata is a deep clone of the original metadata plus the markers;
/// when <c>false</c>, it contains only the markers.
/// </param>
/// <returns>The copy, with <c>truncated</c> and <c>maxBytes</c> entries set in its metadata.</returns>
private static AIMessage CreateTruncatedMessage(AIMessage message, string content, bool preserveMetadata = true)
{
    JsonObject markedMetadata;
    if (preserveMetadata)
        markedMetadata = message.Metadata.DeepClone().AsObject();
    else
        markedMetadata = new JsonObject();

    markedMetadata["truncated"] = true;
    markedMetadata["maxBytes"] = MaxMessagesJsonBytes;

    return message with { Content = content, Metadata = markedMetadata };
}
| | | 162 | | |
/// <summary>
/// Adjusts a prefix length so that the prefix does not end on a high surrogate.
/// </summary>
/// <param name="value">The string that will be sliced.</param>
/// <param name="length">The requested prefix length in chars.</param>
/// <returns>
/// <paramref name="length"/> minus one when the last char of the prefix is a high surrogate;
/// otherwise <paramref name="length"/> unchanged.
/// </returns>
private static int NormalizeSliceLength(string value, int length) =>
    length > 0 && char.IsHighSurrogate(value[length - 1])
        ? length - 1
        : length;
| | | 170 | | |
/// <summary>
/// Binary-searches for the largest number of trailing messages whose JSON fits within
/// <c>MaxMessagesJsonBytes</c> (measured as UTF-8 bytes).
/// </summary>
/// <param name="messages">The messages; the search always keeps a suffix of this list.</param>
/// <returns>
/// The kept messages and their JSON. When no tested suffix fits, this is the last message alone,
/// which may itself still exceed the limit.
/// </returns>
private static (List<AIMessage> Messages, string Json) ShrinkMessagesToByteLimit(List<AIMessage> messages)
{
    var total = messages.Count;
    List<AIMessage> Tail(int count) => messages.Skip(total - count).ToList();

    // Fallback result: only the last message, not yet checked against the limit.
    var keptMessages = Tail(1);
    var keptJson = JsonSerializer.Serialize(keptMessages);

    var lower = 1;
    var upper = total;

    while (lower < upper)
    {
        // Upper midpoint, so the search makes progress when a probe fits and lower moves up to it.
        var probe = lower + (upper - lower + 1) / 2;
        var probeMessages = Tail(probe);
        var probeJson = JsonSerializer.Serialize(probeMessages);

        if (Encoding.UTF8.GetByteCount(probeJson) > MaxMessagesJsonBytes)
        {
            upper = probe - 1;
            continue;
        }

        lower = probe;
        keptMessages = probeMessages;
        keptJson = probeJson;
    }

    return (keptMessages, keptJson);
}
| | | 198 | | |
/// <summary>
/// Decides whether a conversation counts as expired based on its retention mode.
/// </summary>
/// <param name="conversation">The conversation to check.</param>
/// <returns>
/// Ephemeral: <c>true</c> once the status is <c>Completed</c> or <c>Failed</c>.
/// Durable: always <c>false</c>.
/// Any other mode: <c>true</c> when an expiry time is set and it is at or before the current UTC time.
/// </returns>
private static bool IsExpired(AIConversation conversation) =>
    conversation.RetentionMode switch
    {
        AIRetentionMode.Ephemeral =>
            conversation.Status is AIConversationStatus.Completed or AIConversationStatus.Failed,
        AIRetentionMode.Durable => false,
        // No expiry time means the conversation never expires.
        _ => conversation.RetentionExpiresAt is { } expiresAt && expiresAt <= DateTimeOffset.UtcNow
    };
| | | 213 | | |
/// <summary>
/// Parses an enum member from its stored text, ignoring case, with a fallback for unusable values.
/// </summary>
/// <typeparam name="TEnum">The enum type to parse.</typeparam>
/// <param name="value">The stored text.</param>
/// <param name="defaultValue">The value returned when <paramref name="value"/> cannot be used.</param>
/// <returns>
/// The parsed member, or <paramref name="defaultValue"/> when the text does not parse or parses to a value
/// that is not a defined member of <typeparamref name="TEnum"/>.
/// </returns>
// Enum.TryParse succeeds for any numeric string (e.g. "99"), even when no member has that value,
// so the result is additionally checked with Enum.IsDefined.
// NOTE(review): IsDefined also rejects combined [Flags] values; this assumes the enums parsed here
// are not flags enums - confirm.
private static TEnum ParseEnum<TEnum>(string value, TEnum defaultValue) where TEnum : struct =>
    Enum.TryParse<TEnum>(value, ignoreCase: true, out var result) && Enum.IsDefined(typeof(TEnum), result)
        ? result
        : defaultValue;
| | | 216 | | |
/// <summary>
/// Checks whether a stored tenant ID and a requested tenant ID refer to the same tenant.
/// </summary>
/// <param name="storedTenantId">The tenant ID on the stored record.</param>
/// <param name="requestedTenantId">The tenant ID on the incoming conversation.</param>
/// <returns>
/// <c>true</c> when both IDs are ordinally equal after normalization, where a null ID is treated as
/// an empty string.
/// </returns>
private static bool BelongsToTenant(string? storedTenantId, string? requestedTenantId)
{
    var stored = NormalizeTenantId(storedTenantId);
    var requested = NormalizeTenantId(requestedTenantId);

    return string.Equals(stored, requested, StringComparison.Ordinal);
}
| | | 219 | | |
/// <summary>
/// Maps a null tenant ID to an empty string and returns any other value unchanged.
/// </summary>
/// <param name="tenantId">The tenant ID, possibly null.</param>
/// <returns>The tenant ID, or an empty string when it is null.</returns>
private static string NormalizeTenantId(string? tenantId) => tenantId is null ? string.Empty : tenantId;
| | | 221 | | |
| | | 222 | | private static void ValidateUserOwnership(AIConversationRecord record, AIConversation conversation) |
| | | 223 | | { |
| | 4 | 224 | | if (!string.IsNullOrWhiteSpace(record.UserId) && !string.Equals(record.UserId, conversation.UserId, StringCompar |
| | 1 | 225 | | throw new InvalidOperationException("Cannot overwrite an AI conversation that belongs to another user."); |
| | 3 | 226 | | } |
| | | 227 | | |
/// <summary>
/// Ensures the conversation carries the identifiers required for saving.
/// </summary>
/// <param name="conversation">The conversation to check.</param>
/// <exception cref="ArgumentException">
/// The conversation ID or the user ID is null, empty, or whitespace. The ID is checked first.
/// </exception>
private static void Validate(AIConversation conversation)
{
    void Require(string? value, string message)
    {
        if (string.IsNullOrWhiteSpace(value))
            throw new ArgumentException(message, nameof(conversation));
    }

    Require(conversation.Id, "A conversation ID is required.");
    Require(conversation.UserId, "A conversation user ID is required.");
}
| | | 236 | | } |