using FastExpressionCompiler;
using Microsoft.EntityFrameworkCore;
using System.ComponentModel.DataAnnotations;
using System.Linq.Expressions;
using System.Reflection;

namespace Masuit.Tools.Core.AspNetCore;

/// <summary>
/// Extension methods for <see cref="DbSet{TEntity}"/>.
/// </summary>
public static class DbSetExtensions
{
    /// <summary>
    /// Adds or updates the given entities, matching existing rows by <paramref name="keySelector"/>.
    /// </summary>
    /// <typeparam name="T">Entity type.</typeparam>
    /// <param name="dbSet">The set to add to / update in.</param>
    /// <param name="keySelector">Selects the field used to match an incoming entity against an existing row.</param>
    /// <param name="entities">Entities to add or update.</param>
    /// <exception cref="ArgumentNullException"><paramref name="keySelector"/> or <paramref name="entities"/> is null.</exception>
    public static void AddOrUpdate<T>(this DbSet<T> dbSet, Expression<Func<T, object>> keySelector, params T[] entities) where T : class
    {
        AddOrUpdate(dbSet, keySelector, entities.AsEnumerable());
    }

    /// <summary>
    /// Adds or updates the given entities, matching existing rows by <paramref name="keySelector"/>.
    /// </summary>
    /// <remarks>
    /// All candidate rows are fetched with a single query of the form
    /// <c>x =&gt; keys.Contains(keySelector(x))</c>. For each incoming entity:
    /// <list type="bullet">
    /// <item>If a row with the same key exists, every public read/write property that is not ignored
    /// (see <see cref="GetKeyIgnoreFields{T}"/>) is copied from the entity onto that row. The ignored
    /// properties are then copied back from the row onto the entity.</item>
    /// <item>Otherwise the entity is passed to <see cref="DbSet{TEntity}.Add(TEntity)"/>.</item>
    /// </list>
    /// This method does not call <c>SaveChanges</c>. Existing rows are indexed with <c>ToDictionary</c>,
    /// so the selected key is expected to be unique among them; <c>ToDictionary</c> throws otherwise.
    /// </remarks>
    /// <typeparam name="T">Entity type.</typeparam>
    /// <param name="dbSet">The set to add to / update in.</param>
    /// <param name="keySelector">Selects the field used to match an incoming entity against an existing row.</param>
    /// <param name="entities">Entities to add or update.</param>
    /// <param name="ignoreNavigationProperty">
    /// When true, properties whose type lives in <c>System.Collections.Generic</c> are not copied onto existing rows.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="keySelector"/> or <paramref name="entities"/> is null.</exception>
    public static void AddOrUpdate<T>(this DbSet<T> dbSet, Expression<Func<T, object>> keySelector, IEnumerable<T> entities, bool ignoreNavigationProperty = false) where T : class
    {
        if (keySelector == null)
        {
            throw new ArgumentNullException(nameof(keySelector));
        }

        if (entities == null)
        {
            throw new ArgumentNullException(nameof(entities));
        }

        if (entities is not ICollection<T> collection)
        {
            collection = entities.ToList();
        }

        if (collection.Count == 0)
        {
            return;
        }

        var func = keySelector.CompileFast();
        var keyObjects = collection.Select(s => func(s)).ToList();

        // Build `x => keyObjects.Contains(<keySelector body>)` so all matching rows load in one round trip.
        var parameter = keySelector.Parameters[0];
        var array = Expression.Constant(keyObjects);
        var call = Expression.Call(array, typeof(List<object>).GetMethod(nameof(List<object>.Contains)), keySelector.Body);
        var lambda = Expression.Lambda<Func<T, bool>>(call, parameter);
        var items = dbSet.Where(lambda).ToDictionary(t => func(t));

        // The ignored-property set depends only on T, so compute it once rather than once per entity.
        var keyIgnoreFields = GetKeyIgnoreFields<T>(ignoreNavigationProperty);
        foreach (var entity in collection)
        {
            if (items.TryGetValue(func(entity), out var existing))
            {
                CopyProperties(entity, existing, keyIgnoreFields);
            }
            else
            {
                dbSet.Add(entity);
            }
        }
    }

    /// <summary>
    /// Adds or updates a single entity, matching an existing row by <paramref name="keySelector"/>.
    /// </summary>
    /// <remarks>
    /// Looks up the first row whose selected key equals the entity's. If none is found the entity is
    /// passed to <see cref="DbSet{TEntity}.Add(TEntity)"/>. Otherwise non-ignored properties are copied
    /// onto that row, and ignored properties (see <see cref="GetKeyIgnoreFields{T}"/>) are copied back
    /// onto <paramref name="entity"/>. This method does not call <c>SaveChanges</c>.
    /// </remarks>
    /// <typeparam name="T">Entity type.</typeparam>
    /// <param name="dbSet">The set to add to / update in.</param>
    /// <param name="keySelector">Selects the field used to match the entity against an existing row.</param>
    /// <param name="entity">Entity to add or update.</param>
    /// <param name="ignoreNavigationProperty">
    /// When true, properties whose type lives in <c>System.Collections.Generic</c> are not copied onto the existing row.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="keySelector"/> or <paramref name="entity"/> is null.</exception>
    public static void AddOrUpdate<T>(this DbSet<T> dbSet, Expression<Func<T, object>> keySelector, T entity, bool ignoreNavigationProperty = false) where T : class
    {
        if (keySelector == null)
        {
            throw new ArgumentNullException(nameof(keySelector));
        }

        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }

        var keyObject = keySelector.CompileFast()(entity);
        var parameter = keySelector.Parameters[0];
        var lambda = Expression.Lambda<Func<T, bool>>(Expression.Equal(keySelector.Body, Expression.Constant(keyObject)), parameter);
        var item = dbSet.FirstOrDefault(lambda);
        if (item == null)
        {
            dbSet.Add(entity);
            return;
        }

        CopyProperties(entity, item, GetKeyIgnoreFields<T>(ignoreNavigationProperty));
    }

    /// <summary>
    /// Returns the properties of <typeparamref name="T"/> that must not be copied onto an existing row.
    /// </summary>
    /// <remarks>
    /// The set contains, in order and without duplicates:
    /// <list type="bullet">
    /// <item>properties marked <see cref="KeyAttribute"/>;</item>
    /// <item>if no property is marked <see cref="KeyAttribute"/>, properties named <c>Id</c> or
    /// <c>{TypeName}Id</c> (case-insensitive);</item>
    /// <item>properties marked <c>UpdateIgnoreAttribute</c>;</item>
    /// <item>when <paramref name="ignoreNavigationProperty"/> is true, properties whose type's namespace is
    /// <c>System.Collections.Generic</c>.</item>
    /// </list>
    /// </remarks>
    /// <param name="ignoreNavigationProperty">Whether to include collection-typed properties.</param>
    private static List<PropertyInfo> GetKeyIgnoreFields<T>(bool ignoreNavigationProperty)
    {
        var properties = typeof(T).GetProperties();
        var fields = properties.Where(p => p.GetCustomAttribute<KeyAttribute>() != null).ToList();

        // The naming-convention fallback must depend on [Key] alone. An [UpdateIgnore] property elsewhere
        // does not identify the key, so it must not suppress this fallback. If it did, the Id would be
        // copied onto the tracked row.
        if (fields.Count == 0)
        {
            string idName = typeof(T).Name + "Id";
            fields.AddRange(properties.Where(p => p.Name.Equals("Id", StringComparison.OrdinalIgnoreCase) || p.Name.Equals(idName, StringComparison.OrdinalIgnoreCase)));
        }

        fields.AddRange(properties.Where(p => p.GetCustomAttribute<UpdateIgnoreAttribute>() != null));
        if (ignoreNavigationProperty)
        {
            fields.AddRange(properties.Where(p => p.PropertyType.Namespace == "System.Collections.Generic"));
        }

        return fields.Distinct().ToList();
    }

    /// <summary>
    /// Copies values between an incoming entity and the matching existing row.
    /// </summary>
    /// <remarks>
    /// Each public read/write property of <typeparamref name="T"/> whose name is not in
    /// <paramref name="keyIgnoreFields"/> is copied from <paramref name="entity"/> to <paramref name="existing"/>.
    /// Each readable and writable property in <paramref name="keyIgnoreFields"/> is copied the other way,
    /// from <paramref name="existing"/> to <paramref name="entity"/>. A property is assigned only when the
    /// values differ by <see cref="object.Equals(object, object)"/>.
    /// </remarks>
    /// <param name="entity">The incoming entity supplied by the caller.</param>
    /// <param name="existing">The row loaded from <see cref="DbSet{TEntity}"/>.</param>
    /// <param name="keyIgnoreFields">Properties excluded from the entity-to-row copy.</param>
    private static void CopyProperties<T>(T entity, T existing, List<PropertyInfo> keyIgnoreFields)
    {
        var ignoredNames = new HashSet<string>(keyIgnoreFields.Select(f => f.Name));
        foreach (var property in typeof(T).GetProperties().Where(p => p.GetSetMethod() != null && p.GetGetMethod() != null))
        {
            if (ignoredNames.Contains(property.Name))
            {
                continue;
            }

            var newValue = property.GetValue(entity);

            // Value comparison: `!=` on two `object`s compares boxed references, so it always reported a change
            // for value types.
            if (!Equals(property.GetValue(existing), newValue))
            {
                property.SetValue(existing, newValue);
            }
        }

        foreach (var idField in keyIgnoreFields.Where(p => p.SetMethod != null && p.GetMethod != null))
        {
            var existingValue = idField.GetValue(existing);
            if (!Equals(idField.GetValue(entity), existingValue))
            {
                idField.SetValue(entity, existingValue);
            }
        }
    }

#if NET6_0_OR_GREATER
    /// <summary>
    /// Orders the query randomly via <c>EF.Functions.Random()</c>.
    /// </summary>
    /// <typeparam name="T">Element type.</typeparam>
    /// <param name="query">The query to order.</param>
    /// <returns>The query ordered by a random value per row.</returns>
    public static IOrderedQueryable<T> OrderByRandom<T>(this IQueryable<T> query)
    {
        return query.OrderBy(_ => EF.Functions.Random());
    }
#endif
}