Преглед изворни кода

AddOrUpdate增加忽略导航属性选项

懒得勤快 пре 3 година
родитељ
комит
9a0353a84e

+ 137 - 126
Masuit.Tools.Core/AspNetCore/DbSetExtensions.cs

@@ -3,130 +3,60 @@ using Microsoft.EntityFrameworkCore;
 using System.ComponentModel.DataAnnotations;
 using System.Linq.Expressions;
 using System.Reflection;
-namespace Masuit.Tools.Core.AspNetCore
+namespace Masuit.Tools.Core.AspNetCore;
+
+public static class DbSetExtensions
 {
-    public static class DbSetExtensions
+    /// <summary>
+    /// 添加或更新
+    /// </summary>
+    /// <typeparam name="T"></typeparam>
+    /// <typeparam name="TKey">按哪个字段更新</typeparam>
+    /// <param name="dbSet"></param>
+    /// <param name="keySelector">按哪个字段更新</param>
+    /// <param name="entities"></param>
+    public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, params T[] entities) where T : class
     {
-        /// <summary>
-        /// 添加或更新
-        /// </summary>
-        /// <typeparam name="T"></typeparam>
-        /// <typeparam name="TKey">按哪个字段更新</typeparam>
-        /// <param name="dbSet"></param>
-        /// <param name="keySelector">按哪个字段更新</param>
-        /// <param name="entities"></param>
-        public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, params T[] entities) where T : class
+        AddOrUpdate(dbSet, keySelector, entities.AsEnumerable());
+    }
+
+    /// <summary>
+    /// 添加或更新
+    /// </summary>
+    /// <typeparam name="T"></typeparam>
+    /// <typeparam name="TKey">按哪个字段更新</typeparam>
+    /// <param name="dbSet"></param>
+    /// <param name="keySelector">按哪个字段更新</param>
+    /// <param name="entities"></param>
+    /// <param name="ignoreNavigationProperty">是否忽略导航属性</param>
+    public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, IEnumerable<T> entities, bool ignoreNavigationProperty = false) where T : class
+    {
+        if (keySelector == null)
         {
-            AddOrUpdate(dbSet, keySelector, entities.AsEnumerable());
+            throw new ArgumentNullException(nameof(keySelector));
         }
 
-        /// <summary>
-        /// 添加或更新
-        /// </summary>
-        /// <typeparam name="T"></typeparam>
-        /// <typeparam name="TKey">按哪个字段更新</typeparam>
-        /// <param name="dbSet"></param>
-        /// <param name="keySelector">按哪个字段更新</param>
-        /// <param name="entities"></param>
-        public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, IEnumerable<T> entities) where T : class
+        if (entities == null)
         {
-            if (keySelector == null)
-            {
-                throw new ArgumentNullException(nameof(keySelector));
-            }
-
-            if (entities == null)
-            {
-                throw new ArgumentNullException(nameof(entities));
-            }
-
-            if (entities is not ICollection<T> collection)
-            {
-                collection = entities.ToList();
-            }
-
-            var func = keySelector.CompileFast();
-            var keyObjects = collection.Select(s => func(s)).ToList();
-            var parameter = keySelector.Parameters[0];
-            var array = Expression.Constant(keyObjects);
-            var call = Expression.Call(array, typeof(List<TKey>).GetMethod(nameof(List<TKey>.Contains)), keySelector.Body);
-            var lambda = Expression.Lambda<Func<T, bool>>(call, parameter);
-            var items = dbSet.Where(lambda).ToDictionary(t => func(t));
-            foreach (var entity in collection)
-            {
-                var key = func(entity);
-                if (items.ContainsKey(key))
-                {
-                    // 获取主键字段
-                    var dataType = typeof(T);
-                    var keyIgnoreFields = dataType.GetProperties().Where(p => p.GetCustomAttribute<KeyAttribute>() != null || p.GetCustomAttribute<UpdateIgnoreAttribute>() != null).ToList();
-                    if (!keyIgnoreFields.Any())
-                    {
-                        string idName = dataType.Name + "Id";
-                        keyIgnoreFields = dataType.GetProperties().Where(p => p.Name.Equals("Id", StringComparison.OrdinalIgnoreCase) || p.Name.Equals(idName, StringComparison.OrdinalIgnoreCase)).ToList();
-                    }
-
-                    // 更新所有非主键属性
-                    foreach (var p in typeof(T).GetProperties().Where(p => p.GetSetMethod() != null && p.GetGetMethod() != null))
-                    {
-                        // 忽略主键和被忽略的字段
-                        if (keyIgnoreFields.Any(x => x.Name == p.Name))
-                        {
-                            continue;
-                        }
-
-                        var existingValue = p.GetValue(entity);
-                        if (p.GetValue(items[key]) != existingValue)
-                        {
-                            p.SetValue(items[key], existingValue);
-                        }
-                    }
-
-                    foreach (var idField in keyIgnoreFields.Where(p => p.SetMethod != null && p.GetMethod != null))
-                    {
-                        var existingValue = idField.GetValue(items[key]);
-                        if (idField.GetValue(entity) != existingValue)
-                        {
-                            idField.SetValue(entity, existingValue);
-                        }
-                    }
-                }
-                else
-                {
-                    dbSet.Add(entity);
-                }
-            }
+            throw new ArgumentNullException(nameof(entities));
         }
 
-        /// <summary>
-        /// 添加或更新
-        /// </summary>
-        /// <typeparam name="T"></typeparam>
-        /// <typeparam name="TKey">按哪个字段更新</typeparam>
-        /// <param name="dbSet"></param>
-        /// <param name="keySelector">按哪个字段更新</param>
-        /// <param name="entity"></param>
-        public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, T entity) where T : class
+        if (entities is not ICollection<T> collection)
         {
-            if (keySelector == null)
-            {
-                throw new ArgumentNullException(nameof(keySelector));
-            }
-
-            if (entity == null)
-            {
-                throw new ArgumentNullException(nameof(entity));
-            }
+            collection = entities.ToList();
+        }
 
-            var keyObject = keySelector.CompileFast()(entity);
-            var parameter = keySelector.Parameters[0];
-            var lambda = Expression.Lambda<Func<T, bool>>(Expression.Equal(ReplaceParameter(keySelector.Body, parameter), Expression.Constant(keyObject)), parameter);
-            var item = dbSet.FirstOrDefault(lambda);
-            if (item == null)
-            {
-                dbSet.Add(entity);
-            }
-            else
+        var func = keySelector.CompileFast();
+        var keyObjects = collection.Select(s => func(s)).ToList();
+        var parameter = keySelector.Parameters[0];
+        var array = Expression.Constant(keyObjects);
+        var call = Expression.Call(array, typeof(List<TKey>).GetMethod(nameof(List<TKey>.Contains)), keySelector.Body);
+        var lambda = Expression.Lambda<Func<T, bool>>(call, parameter);
+        var items = dbSet.Where(lambda).ToDictionary(t => func(t));
+        foreach (var entity in collection)
+        {
+            var key = func(entity);
+            if (items.ContainsKey(key))
             {
                 // 获取主键字段
                 var dataType = typeof(T);
@@ -134,7 +64,12 @@ namespace Masuit.Tools.Core.AspNetCore
                 if (!keyIgnoreFields.Any())
                 {
                     string idName = dataType.Name + "Id";
-                    keyIgnoreFields = dataType.GetProperties().Where(p => p.Name.Equals("Id", StringComparison.OrdinalIgnoreCase) || p.Name.Equals(idName, StringComparison.OrdinalIgnoreCase)).ToList();
+                    keyIgnoreFields.AddRange(dataType.GetProperties().Where(p => p.Name.Equals("Id", StringComparison.OrdinalIgnoreCase) || p.Name.Equals(idName, StringComparison.OrdinalIgnoreCase)));
+                }
+
+                if (ignoreNavigationProperty)
+                {
+                    keyIgnoreFields.AddRange(dataType.GetProperties().Where(p => p.PropertyType.Namespace == "System.Collections.Generic"));
                 }
 
                 // 更新所有非主键属性
@@ -147,32 +82,109 @@ namespace Masuit.Tools.Core.AspNetCore
                     }
 
                     var existingValue = p.GetValue(entity);
-                    if (p.GetValue(item) != existingValue)
+                    if (p.GetValue(items[key]) != existingValue)
                     {
-                        p.SetValue(item, existingValue);
+                        p.SetValue(items[key], existingValue);
                     }
                 }
 
                 foreach (var idField in keyIgnoreFields.Where(p => p.SetMethod != null && p.GetMethod != null))
                 {
-                    var existingValue = idField.GetValue(item);
+                    var existingValue = idField.GetValue(items[key]);
                     if (idField.GetValue(entity) != existingValue)
                     {
                         idField.SetValue(entity, existingValue);
                     }
                 }
             }
+            else
+            {
+                dbSet.Add(entity);
+            }
+        }
+    }
+
+    /// <summary>
+    /// 添加或更新
+    /// </summary>
+    /// <typeparam name="T"></typeparam>
+    /// <typeparam name="TKey">按哪个字段更新</typeparam>
+    /// <param name="dbSet"></param>
+    /// <param name="keySelector">按哪个字段更新</param>
+    /// <param name="entity"></param>
+    /// <param name="ignoreNavigationProperty">是否忽略导航属性</param>
+    public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, T entity, bool ignoreNavigationProperty = false) where T : class
+    {
+        if (keySelector == null)
+        {
+            throw new ArgumentNullException(nameof(keySelector));
         }
 
-        private static Expression ReplaceParameter(Expression oldExpression, ParameterExpression newParameter)
+        if (entity == null)
         {
-            return oldExpression.NodeType switch
+            throw new ArgumentNullException(nameof(entity));
+        }
+
+        var keyObject = keySelector.CompileFast()(entity);
+        var parameter = keySelector.Parameters[0];
+        var lambda = Expression.Lambda<Func<T, bool>>(Expression.Equal(ReplaceParameter(keySelector.Body, parameter), Expression.Constant(keyObject)), parameter);
+        var item = dbSet.FirstOrDefault(lambda);
+        if (item == null)
+        {
+            dbSet.Add(entity);
+        }
+        else
+        {
+            // 获取主键字段
+            var dataType = typeof(T);
+            var keyIgnoreFields = dataType.GetProperties().Where(p => p.GetCustomAttribute<KeyAttribute>() != null || p.GetCustomAttribute<UpdateIgnoreAttribute>() != null).ToList();
+            if (!keyIgnoreFields.Any())
             {
-                ExpressionType.MemberAccess => Expression.MakeMemberAccess(newParameter, ((MemberExpression)oldExpression).Member),
-                ExpressionType.New => Expression.New(((NewExpression)oldExpression).Constructor, ((NewExpression)oldExpression).Arguments.Select(a => ReplaceParameter(a, newParameter)).ToArray()),
-                _ => throw new NotSupportedException("不支持的表达式类型:" + oldExpression.NodeType)
-            };
+                string idName = dataType.Name + "Id";
+                keyIgnoreFields.AddRange(dataType.GetProperties().Where(p => p.Name.Equals("Id", StringComparison.OrdinalIgnoreCase) || p.Name.Equals(idName, StringComparison.OrdinalIgnoreCase)));
+            }
+
+            if (ignoreNavigationProperty)
+            {
+                keyIgnoreFields.AddRange(dataType.GetProperties().Where(p => p.PropertyType.Namespace == "System.Collections.Generic"));
+            }
+
+            // 更新所有非主键属性
+            foreach (var p in typeof(T).GetProperties().Where(p => p.GetSetMethod() != null && p.GetGetMethod() != null))
+            {
+                // 忽略主键和被忽略的字段
+                if (keyIgnoreFields.Any(x => x.Name == p.Name))
+                {
+                    continue;
+                }
+
+                var existingValue = p.GetValue(entity);
+                if (p.GetValue(item) != existingValue)
+                {
+                    p.SetValue(item, existingValue);
+                }
+            }
+
+            foreach (var idField in keyIgnoreFields.Where(p => p.SetMethod != null && p.GetMethod != null))
+            {
+                var existingValue = idField.GetValue(item);
+                if (idField.GetValue(entity) != existingValue)
+                {
+                    idField.SetValue(entity, existingValue);
+                }
+            }
         }
+    }
+
+    private static Expression ReplaceParameter(Expression oldExpression, ParameterExpression newParameter)
+    {
+        return oldExpression.NodeType switch
+        {
+            ExpressionType.MemberAccess => Expression.MakeMemberAccess(newParameter, ((MemberExpression)oldExpression).Member),
+            ExpressionType.New => Expression.New(((NewExpression)oldExpression).Constructor, ((NewExpression)oldExpression).Arguments.Select(a => ReplaceParameter(a, newParameter)).ToArray()),
+            _ => throw new NotSupportedException("不支持的表达式类型:" + oldExpression.NodeType)
+        };
+    }
 
 #if NET6_0_OR_GREATER
 
@@ -187,5 +199,4 @@ namespace Masuit.Tools.Core.AspNetCore
             return query.OrderBy(_ => EF.Functions.Random());
         }
 #endif
-    }
-}
+}

+ 7 - 10
Masuit.Tools.Core/AspNetCore/UpdateIgnoreAttribute.cs

@@ -1,12 +1,9 @@
-using System;
+namespace Masuit.Tools.Core.AspNetCore;
 
-namespace Masuit.Tools.Core.AspNetCore
+/// <summary>
+/// 更新时忽略的字段,检测到这个attribute时AddOrUpdate将忽略这个字段
+/// </summary>
+[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)]
+public sealed class UpdateIgnoreAttribute : Attribute
 {
-    /// <summary>
-    /// 更新时忽略的字段,检测到这个attribute时AddOrUpdate将忽略改字段
-    /// </summary>
-    [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)]
-    public sealed class UpdateIgnoreAttribute : Attribute
-    {
-    }
-}
+}

+ 1 - 1
Test/Masuit.Tools.Abstractions.Test/Masuit.Tools.Abstractions.Test.csproj

@@ -19,7 +19,7 @@
       <PrivateAssets>all</PrivateAssets>
       <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
     </PackageReference>
-    <PackageReference Include="coverlet.collector" Version="3.1.2">
+    <PackageReference Include="coverlet.collector" Version="3.2.0">
       <PrivateAssets>all</PrivateAssets>
       <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
     </PackageReference>