DbSetExtensions.cs 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. using Microsoft.EntityFrameworkCore;
  2. using System;
  3. using System.ComponentModel.DataAnnotations;
  4. using System.Linq;
  5. using System.Linq.Expressions;
  6. using System.Reflection;
  7. namespace Masuit.Tools.Core.AspNetCore
  8. {
  9. public static class DbSetExtensions
  10. {
  11. /// <summary>
  12. /// 添加或更新
  13. /// </summary>
  14. /// <typeparam name="T"></typeparam>
  15. /// <typeparam name="TKey">按哪个字段更新</typeparam>
  16. /// <param name="dbSet"></param>
  17. /// <param name="keySelector">按哪个字段更新</param>
  18. /// <param name="entities"></param>
  19. public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, params T[] entities) where T : class
  20. {
  21. foreach (var entity in entities)
  22. {
  23. AddOrUpdate(dbSet, keySelector, entity);
  24. }
  25. }
  26. /// <summary>
  27. /// 添加或更新
  28. /// </summary>
  29. /// <typeparam name="T"></typeparam>
  30. /// <typeparam name="TKey">按哪个字段更新</typeparam>
  31. /// <param name="dbSet"></param>
  32. /// <param name="keySelector">按哪个字段更新</param>
  33. /// <param name="entity"></param>
  34. public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, T entity) where T : class
  35. {
  36. if (keySelector == null)
  37. {
  38. throw new ArgumentNullException(nameof(keySelector));
  39. }
  40. if (entity == null)
  41. {
  42. throw new ArgumentNullException(nameof(entity));
  43. }
  44. var keyObject = keySelector.Compile()(entity);
  45. var parameter = Expression.Parameter(typeof(T), "p");
  46. var lambda = Expression.Lambda<Func<T, bool>>(Expression.Equal(ReplaceParameter(keySelector.Body, parameter), Expression.Constant(keyObject)), parameter);
  47. var item = dbSet.FirstOrDefault(lambda);
  48. if (item == null)
  49. {
  50. dbSet.Add(entity);
  51. }
  52. else
  53. {
  54. // 获取主键字段
  55. var dataType = typeof(T);
  56. var keyFields = dataType.GetProperties().Where(p => p.GetCustomAttribute<KeyAttribute>() != null).ToList();
  57. if (!keyFields.Any())
  58. {
  59. string idName = dataType.Name + "Id";
  60. keyFields = dataType.GetProperties().Where(p =>
  61. string.Equals(p.Name, "Id", StringComparison.OrdinalIgnoreCase) ||
  62. string.Equals(p.Name, idName, StringComparison.OrdinalIgnoreCase)).ToList();
  63. }
  64. // 更新所有非主键和非集合属性
  65. foreach (var p in typeof(T).GetProperties().Where(p => p.GetSetMethod() != null && p.GetGetMethod() != null))
  66. {
  67. // 忽略主键
  68. if (keyFields.Any(x => x.Name == p.Name))
  69. {
  70. continue;
  71. }
  72. var existingValue = p.GetValue(entity);
  73. if (!Equals(p.GetValue(item), existingValue))
  74. {
  75. p.SetValue(item, existingValue);
  76. }
  77. }
  78. foreach (var idField in keyFields.Where(p => p.GetSetMethod() != null && p.GetGetMethod() != null))
  79. {
  80. var existingValue = idField.GetValue(item);
  81. if (!Equals(idField.GetValue(entity), existingValue))
  82. {
  83. idField.SetValue(entity, existingValue);
  84. }
  85. }
  86. }
  87. }
  88. private static Expression ReplaceParameter(Expression oldExpression, ParameterExpression newParameter)
  89. {
  90. return oldExpression.NodeType switch
  91. {
  92. ExpressionType.MemberAccess => (Expression)Expression.MakeMemberAccess(newParameter, ((MemberExpression)oldExpression).Member),
  93. ExpressionType.New => Expression.New(((NewExpression)oldExpression).Constructor, ((NewExpression)oldExpression).Arguments.Select(a => ReplaceParameter(a, newParameter)).ToArray()),
  94. _ => throw new NotSupportedException("不支持的表达式类型:" + oldExpression.NodeType)
  95. };
  96. }
  97. }
  98. }