< Summary

Information
Class: Elsa.Persistence.EFCore.Extensions.BulkUpsertExtensions
Assembly: Elsa.Persistence.EFCore.Common
File(s): /home/runner/work/elsa-core/elsa-core/src/modules/Elsa.Persistence.EFCore.Common/Extensions/BulkUpsertExtensions.cs
Line coverage
25%
Covered lines: 54
Uncovered lines: 158
Coverable lines: 212
Total lines: 411
Line coverage: 25.4%
Branch coverage
23%
Covered branches: 28
Total branches: 118
Branch coverage: 23.7%
Method coverage

Feature is only available for sponsors

Upgrade to PRO version

Metrics

MethodBranch coverage Crap Score Cyclomatic complexity Line coverage
BulkUpsertAsync()100%11100%
BulkUpsertAsync()35.71%402875%
GenerateSqlServerUpsert(...)0%420200%
GenerateSqliteUpsert(...)0%272160%
GeneratePostgresUpsert(...)90%202097.56%
GenerateMySqlUpsert(...)0%272160%
GenerateOracleUpsert(...)0%342180%

File(s)

/home/runner/work/elsa-core/elsa-core/src/modules/Elsa.Persistence.EFCore.Common/Extensions/BulkUpsertExtensions.cs

#LineLine coverage
 1using System.Text;
 2using Microsoft.EntityFrameworkCore;
 3using Microsoft.EntityFrameworkCore.Infrastructure;
 4using Microsoft.EntityFrameworkCore.Metadata;
 5using System.Linq.Expressions;
 6
 7// ReSharper disable once CheckNamespace
 8namespace Elsa.Persistence.EFCore.Extensions;
 9
 10/// <summary>
 11/// Provides extension methods to perform bulk upsert operations for entities
 12/// in an Entity Framework Core context, supporting multiple database providers.
 13/// </summary>
 14public static class BulkUpsertExtensions
 15{
 16    /// <summary>
 17    /// Performs a bulk upsert operation on a list of entities in the specified database context using a key selector.
 18    /// </summary>
 19    /// <typeparam name="TDbContext">The type of the database context.</typeparam>
 20    /// <typeparam name="TEntity">The type of the entity being upserted.</typeparam>
 21    /// <param name="dbContext">The database context where the bulk upsert operation will be executed.</param>
 22    /// <param name="entities">The list of entities to be upserted.</param>
 23    /// <param name="keySelector">An expression used to determine the key for upsert operations.</param>
 24    /// <param name="cancellationToken">A token to observe while waiting for the operation to complete.</param>
 25    public static async Task BulkUpsertAsync<TDbContext, TEntity>(
 26        this TDbContext dbContext,
 27        IList<TEntity> entities,
 28        Expression<Func<TEntity, string>> keySelector,
 29        CancellationToken cancellationToken = default)
 30        where TDbContext : DbContext
 31        where TEntity : class, new()
 32    {
 91033        await BulkUpsertAsync(dbContext, entities, keySelector, 50, cancellationToken);
 91034    }
 35
 36    /// <summary>
 37    /// Performs a bulk upsert operation on a list of entities in the specified database context using a key selector an
 38    /// </summary>
 39    /// <typeparam name="TDbContext">The type of the database context.</typeparam>
 40    /// <typeparam name="TEntity">The type of the entity being upserted.</typeparam>
 41    /// <param name="dbContext">The database context where the bulk upsert operation will be executed.</param>
 42    /// <param name="entities">The list of entities to be upserted.</param>
 43    /// <param name="keySelector">An expression used to determine the key for upsert operations.</param>
 44    /// <param name="batchSize">The size of each batch for processing the upsert operation. Defaults to 50.</param>
 45    /// <param name="cancellationToken">A token to observe while waiting for the operation to complete.</param>
 46    /// <exception cref="NotSupportedException">Thrown if the database provider for the context is not supported.</excep
 47    public static async Task BulkUpsertAsync<TDbContext, TEntity>(
 48        this TDbContext dbContext,
 49        IList<TEntity> entities,
 50        Expression<Func<TEntity, string>> keySelector,
 51        int batchSize = 50,
 52        CancellationToken cancellationToken = default)
 53        where TDbContext : DbContext
 54        where TEntity : class, new()
 55    {
 91056        if (entities.Count == 0)
 057            return;
 58
 59        // Identify the current provider (e.g., "Microsoft.EntityFrameworkCore.SqlServer")
 91060        var providerName = dbContext.Database.ProviderName?.ToLowerInvariant() ?? string.Empty;
 61
 62        // Determine the method for generating SQL based on the provider
 91063        Func<DbContext, IList<TEntity>, Expression<Func<TEntity, string>>, (string, object[])> generateSql = providerNam
 91064        {
 91065            var pn when pn.Contains("sqlserver") => GenerateSqlServerUpsert,
 91066            var pn when pn.Contains("sqlite") => GenerateSqliteUpsert,
 182067            var pn when pn.Contains("postgres") => GeneratePostgresUpsert,
 068            var pn when pn.Contains("mysql") => GenerateMySqlUpsert,
 069            var pn when pn.Contains("oracle") => GenerateOracleUpsert,
 070            _ => throw new NotSupportedException($"Provider '{providerName}' is not supported.")
 91071        };
 72
 73        // Loop through batched entities
 364074        foreach (var batch in entities.Chunk(batchSize))
 75        {
 76            // Generate SQL and parameters
 91077            var (sql, parameters) = generateSql(dbContext, batch, keySelector);
 78
 91079            await dbContext.Database.ExecuteSqlRawAsync(sql, parameters, cancellationToken);
 80        }
 91081    }
 82
 83    private static (string, object[]) GenerateSqlServerUpsert<TEntity>(
 84        DbContext dbContext,
 85        IList<TEntity> entities,
 86        Expression<Func<TEntity, string>> keySelector)
 87        where TEntity : class
 88    {
 089        var entityType = dbContext.Model.FindEntityType(typeof(TEntity))!;
 090        var tableName = $"[{entityType.GetSchema()}].[{entityType.GetTableName()}]";
 091        var storeObject = StoreObjectIdentifier.Table(entityType.GetTableName()!, entityType.GetSchema());
 092        var props = entityType.GetProperties().ToList();
 093        var keyProp = entityType.FindProperty(keySelector.GetMemberAccess().Name)!;
 094        var keyColumnName = $"[{keyProp.GetColumnName(storeObject)}]";
 095        var columnNames = props
 096            .Select(p => $"[{p.GetColumnName(storeObject)}]")
 097            .ToList();
 98
 099        var mergeSql = new StringBuilder();
 0100        mergeSql.AppendLine($"MERGE {tableName} AS Target");
 0101        mergeSql.AppendLine("USING (VALUES");
 102
 0103        var parameters = new List<object>();
 0104        var parameterCount = 0;
 105
 0106        for (var i = 0; i < entities.Count; i++)
 107        {
 0108            var entity = entities[i];
 0109            var values = new List<string>();
 110
 0111            foreach (var property in props)
 112            {
 0113                var paramName = $"{{{parameterCount++}}}";
 114
 115                // If it's a shadow property, retrieve value via Entry(..).Property(..)
 0116                var value = property.IsShadowProperty()
 0117                    ? dbContext.Entry(entity).Property(property.Name).CurrentValue
 0118                    : property.PropertyInfo?.GetValue(entity);
 119
 0120                var converter = property.GetTypeMapping().Converter;
 0121                if (converter != null)
 0122                    value = converter.ConvertToProvider(value)!;
 123
 124                // Explicitly cast null values for varbinary columns
 0125                if (property.GetColumnType().StartsWith("varbinary", StringComparison.OrdinalIgnoreCase) && value is nul
 0126                    values.Add("CAST(NULL AS varbinary(max))"); // Explicitly cast null
 127                else
 0128                    values.Add(paramName);
 129
 0130                parameters.Add(value!);
 131            }
 132
 0133            var line = $"({string.Join(", ", values)}){(i < entities.Count - 1 ? "," : string.Empty)}";
 0134            mergeSql.AppendLine(line);
 135        }
 136
 0137        mergeSql.AppendLine($") AS Source ({string.Join(", ", columnNames)})");
 0138        mergeSql.AppendLine($"ON Target.{keyColumnName} = Source.{keyColumnName}");
 0139        mergeSql.AppendLine("WHEN MATCHED THEN");
 0140        mergeSql.AppendLine($"    UPDATE SET {string.Join(", ", columnNames.Where(c => c != keyColumnName).Select(c => $
 0141        mergeSql.AppendLine("WHEN NOT MATCHED THEN");
 0142        mergeSql.AppendLine($"    INSERT ({string.Join(", ", columnNames)})");
 0143        mergeSql.AppendLine($"    VALUES ({string.Join(", ", columnNames.Select(c => $"Source.{c}"))});");
 144
 0145        return (mergeSql.ToString(), parameters.ToArray());
 146    }
 147
 148    private static (string, object[]) GenerateSqliteUpsert<TEntity>(
 149        DbContext dbContext,
 150        IList<TEntity> entities,
 151        Expression<Func<TEntity, string>> keySelector)
 152        where TEntity : class
 153    {
 0154        var entityType = dbContext.Model.FindEntityType(typeof(TEntity))!;
 0155        var tableName = entityType.GetTableName();
 0156        var storeObject = StoreObjectIdentifier.Table(tableName!, entityType.GetSchema());
 0157        var props = entityType.GetProperties().ToList();
 0158        var keyProp = entityType.FindProperty(keySelector.GetMemberAccess().Name)!;
 0159        var keyColumnName = keyProp.GetColumnName(storeObject);
 0160        var columnNames = props
 0161            .Select(p => p.GetColumnName(storeObject)!)
 0162            .ToList();
 163
 0164        var sb = new StringBuilder();
 0165        var parameters = new List<object>();
 0166        var parameterCount = 0;
 167
 0168        sb.Append($"INSERT INTO \"{tableName}\" ({string.Join(", ", columnNames.Select(c => $"\"{c}\""))}) VALUES ");
 169
 0170        for (var i = 0; i < entities.Count; i++)
 171        {
 0172            var entity = entities[i];
 0173            var placeholders = new List<string>();
 174
 0175            foreach (var property in props)
 176            {
 0177                var paramName = $"{{{parameterCount++}}}";
 178
 0179                var value = property.IsShadowProperty()
 0180                    ? dbContext.Entry(entity).Property(property.Name).CurrentValue
 0181                    : property.PropertyInfo?.GetValue(entity);
 182
 0183                var converter = property.GetTypeMapping().Converter;
 0184                if (converter != null)
 0185                    value = converter.ConvertToProvider(value);
 186
 0187                placeholders.Add(paramName);
 0188                parameters.Add(value!);
 189            }
 190
 0191            sb.Append($"({string.Join(", ", placeholders)})");
 0192            if (i < entities.Count - 1)
 0193                sb.Append(", ");
 194        }
 195
 0196        sb.AppendLine();
 0197        sb.AppendLine($"ON CONFLICT(\"{keyColumnName}\") DO UPDATE SET");
 198
 0199        var updateAssignments = columnNames
 0200            .Where(c => c != keyColumnName)
 0201            .Select(c => $"\"{c}\"=excluded.\"{c}\"");
 202
 0203        sb.AppendLine(string.Join(", ", updateAssignments) + ";");
 204
 0205        return (sb.ToString(), parameters.ToArray());
 206    }
 207
 208    private static (string, object[]) GeneratePostgresUpsert<TEntity>(
 209        DbContext dbContext,
 210        IList<TEntity> entities,
 211        Expression<Func<TEntity, string>> keySelector)
 212        where TEntity : class
 213    {
 910214        var entityType = dbContext.Model.FindEntityType(typeof(TEntity))!;
 910215        var tableName = entityType.GetTableName();
 910216        var storeObject = StoreObjectIdentifier.Table(tableName!, entityType.GetSchema());
 217
 910218        var props = entityType.GetProperties().ToList();
 219
 910220        var keyProp = entityType.FindProperty(keySelector.GetMemberAccess().Name)!;
 910221        var keyColumnName = keyProp.GetColumnName(storeObject);
 910222        var columnNames = props
 18044223            .Select(p => p.GetColumnName(storeObject)!)
 910224            .ToList();
 225
 910226        var sb = new StringBuilder();
 910227        var parameters = new List<object>();
 910228        var parameterCount = 0;
 229
 18954230        sb.Append($"INSERT INTO \"{storeObject.Schema}\".\"{storeObject.Name}\" ({string.Join(", ", columnNames.Select(c
 231
 7962232        for (var i = 0; i < entities.Count; i++)
 233        {
 3071234            var entity = entities[i];
 3071235            var placeholders = new List<string>();
 236
 128670237            foreach (var property in props)
 238            {
 61264239                var paramName = $"{{{parameterCount++}}}";
 240
 61264241                var value = property.IsShadowProperty()
 61264242                    ? dbContext.Entry(entity).Property(property.Name).CurrentValue
 61264243                    : property.PropertyInfo?.GetValue(entity);
 244
 61264245                var converter = property.GetTypeMapping().Converter;
 61264246                if (converter != null)
 3058247                    value = converter.ConvertToProvider(value);
 248
 249                // Detect json/jsonb column types and cast the parameter so PostgreSQL accepts it.
 61264250                var columnType = property.GetColumnType();
 61264251                if (columnType.StartsWith("jsonb", StringComparison.OrdinalIgnoreCase))
 577252                    placeholders.Add($"CAST({paramName} AS jsonb)");
 60687253                else if (columnType.StartsWith("json", StringComparison.OrdinalIgnoreCase))
 0254                    placeholders.Add($"CAST({paramName} AS json)");
 255                else
 60687256                    placeholders.Add(paramName);
 257
 61264258                parameters.Add(value!);
 259            }
 260
 3071261            sb.Append($"({string.Join(", ", placeholders)})");
 3071262            if (i < entities.Count - 1)
 2161263                sb.Append(", ");
 264        }
 265
 910266        sb.AppendLine();
 910267        sb.AppendLine($"ON CONFLICT (\"{keyColumnName}\") DO UPDATE SET");
 268
 910269        var updateAssignments = columnNames
 18044270            .Where(c => c != keyColumnName)
 18044271            .Select(c => $"\"{c}\" = EXCLUDED.\"{c}\"");
 272
 910273        sb.AppendLine(string.Join(", ", updateAssignments) + ";");
 274
 910275        return (sb.ToString(), parameters.ToArray());
 276    }
 277
 278    private static (string, object[]) GenerateMySqlUpsert<TEntity>(
 279        DbContext dbContext,
 280        IList<TEntity> entities,
 281        Expression<Func<TEntity, string>> keySelector)
 282        where TEntity : class
 283    {
 0284        var entityType = dbContext.Model.FindEntityType(typeof(TEntity))!;
 0285        var tableName = entityType.GetTableName();
 0286        var storeObject = StoreObjectIdentifier.Table(tableName!, entityType.GetSchema());
 287
 0288        var props = entityType.GetProperties().ToList();
 289
 0290        var keyProp = entityType.FindProperty(keySelector.GetMemberAccess().Name)!;
 0291        var keyColumnName = keyProp.GetColumnName(storeObject);
 0292        var columnNames = props
 0293            .Select(p => p.GetColumnName(storeObject)!)
 0294            .ToList();
 295
 0296        var sb = new StringBuilder();
 0297        var parameters = new List<object>();
 0298        var parameterCount = 0;
 299
 0300        sb.Append($"INSERT INTO `{tableName}` ({string.Join(", ", columnNames.Select(c => $"`{c}`"))}) VALUES ");
 301
 0302        for (var i = 0; i < entities.Count; i++)
 303        {
 0304            var entity = entities[i];
 0305            var placeholders = new List<string>();
 306
 0307            foreach (var property in props)
 308            {
 0309                var paramName = $"{{{parameterCount++}}}";
 310
 0311                var value = property.IsShadowProperty()
 0312                    ? dbContext.Entry(entity).Property(property.Name).CurrentValue
 0313                    : property.PropertyInfo?.GetValue(entity);
 314
 0315                var converter = property.GetTypeMapping().Converter;
 0316                if (converter != null)
 0317                    value = converter.ConvertToProvider(value);
 318
 0319                placeholders.Add(paramName);
 0320                parameters.Add(value!);
 321            }
 322
 0323            sb.Append($"({string.Join(", ", placeholders)})");
 0324            if (i < entities.Count - 1)
 0325                sb.Append(", ");
 326        }
 327
 0328        sb.AppendLine();
 0329        sb.AppendLine("ON DUPLICATE KEY UPDATE");
 330
 0331        var updateAssignments = columnNames
 0332            .Where(c => c != keyColumnName)
 0333            .Select(c => $"`{c}` = VALUES(`{c}`)");
 334
 0335        sb.AppendLine(string.Join(", ", updateAssignments) + ";");
 336
 0337        return (sb.ToString(), parameters.ToArray());
 338    }
 339
 340    private static (string, object[]) GenerateOracleUpsert<TEntity>(
 341        DbContext dbContext,
 342        IList<TEntity> entities,
 343        Expression<Func<TEntity, string>> keySelector)
 344        where TEntity : class
 345    {
 0346        var entityType = dbContext.Model.FindEntityType(typeof(TEntity))!;
 0347        var schema = entityType.GetSchema();
 0348        var tableName = entityType.GetTableName();
 0349        var storeObject = StoreObjectIdentifier.Table(tableName!, schema);
 0350        var fullName = !string.IsNullOrEmpty(schema) ? $"{schema}.{tableName}" : tableName;
 351
 0352        var props = entityType.GetProperties().ToList();
 353
 0354        var keyProp = entityType.FindProperty(keySelector.GetMemberAccess().Name)!;
 0355        var keyColumnName = keyProp.GetColumnName(storeObject);
 356
 0357        var columnNames = props
 0358            .Select(p => p.GetColumnName(storeObject)!)
 0359            .ToList();
 360
 0361        var sb = new StringBuilder();
 0362        var parameters = new List<object>();
 0363        var parameterCount = 0;
 364
 0365        sb.AppendLine($"MERGE INTO {fullName} Target");
 0366        sb.AppendLine("USING (SELECT");
 367
 0368        for (var i = 0; i < entities.Count; i++)
 369        {
 0370            var entity = entities[i];
 0371            var lineParts = new List<string>();
 372
 0373            foreach (var property in props)
 374            {
 0375                var paramName = $"{{{parameterCount++}}}";
 376
 0377                var value = property.IsShadowProperty()
 0378                    ? dbContext.Entry(entity).Property(property.Name).CurrentValue
 0379                    : property.PropertyInfo?.GetValue(entity);
 380
 0381                var converter = property.GetTypeMapping().Converter;
 0382                if (converter != null)
 0383                    value = converter.ConvertToProvider(value);
 384
 0385                parameters.Add(value!);
 386
 387                // Oracle aliases must match the column name
 0388                var alias = property.GetColumnName(storeObject);
 0389                lineParts.Add($"{paramName} AS {alias}");
 390            }
 391
 392            // Comma if not last
 0393            var suffix = (i < entities.Count - 1) ? " FROM DUAL UNION ALL SELECT" : " FROM DUAL";
 0394            sb.AppendLine(string.Join(", ", lineParts) + suffix);
 395        }
 396
 0397        sb.AppendLine($") Source ON (Target.{keyColumnName} = Source.{keyColumnName})");
 0398        sb.AppendLine("WHEN MATCHED THEN UPDATE SET");
 399
 0400        var updateSetClauses = columnNames
 0401            .Where(c => c != keyColumnName)
 0402            .Select(c => $"Target.{c} = Source.{c}");
 403
 0404        sb.AppendLine(string.Join(", ", updateSetClauses));
 0405        sb.AppendLine("WHEN NOT MATCHED THEN");
 0406        sb.AppendLine($"INSERT ({string.Join(", ", columnNames)})");
 0407        sb.AppendLine($"VALUES ({string.Join(", ", columnNames.Select(c => $"Source.{c}"))});");
 408
 0409        return (sb.ToString(), parameters.ToArray());
 410    }
 411}