DbSetExtensions.cs 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. using Microsoft.EntityFrameworkCore;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.ComponentModel.DataAnnotations;
  5. using System.Linq;
  6. using System.Linq.Expressions;
  7. using System.Reflection;
  namespace Masuit.Tools.Core.AspNetCore
  {
      /// <summary>
      /// "Add or update" (upsert) helpers for <see cref="DbSet{TEntity}"/>.
      /// </summary>
      /// <remarks>
      /// An entity is matched against existing rows by the value returned from a key selector.
      /// If no row matches, the entity is added to the set.
      /// If a row matches, every publicly readable/writable property of the existing (tracked) row is overwritten
      /// with the entity's value, except "preserved" properties.
      /// Preserved properties are:
      /// properties marked <see cref="KeyAttribute"/> (or, when there are none, properties named <c>Id</c> or
      /// <c>{TypeName}Id</c>, case-insensitive), plus properties marked <c>[UpdateIgnore]</c>.
      /// The stored values of preserved properties are copied back onto the passed-in entity.
      /// None of these methods call <c>SaveChanges</c>.
      /// NOTE(review): the lookup is a query against the set, so two entities with the same key in one batch are
      /// presumably both added (the first is not yet persisted when the second is looked up) — de-duplicate first.
      /// </remarks>
      public static class DbSetExtensions
      {
          /// <summary>
          /// Adds or updates each of <paramref name="entities"/>.
          /// </summary>
          /// <typeparam name="T">Entity type.</typeparam>
          /// <typeparam name="TKey">Type of the value used to match existing rows.</typeparam>
          /// <param name="dbSet">The set to query and add to.</param>
          /// <param name="keySelector">Selects the field (or an anonymous object of fields, e.g. <c>p => new { p.A, p.B }</c>) that identifies a row.</param>
          /// <param name="entities">The entities to add or update.</param>
          /// <exception cref="ArgumentNullException">An argument, or an element of <paramref name="entities"/>, is null.</exception>
          public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, params T[] entities) where T : class
          {
              AddOrUpdate(dbSet, keySelector, (IEnumerable<T>)entities);
          }

          /// <summary>
          /// Adds or updates each of <paramref name="entities"/>.
          /// </summary>
          /// <typeparam name="T">Entity type.</typeparam>
          /// <typeparam name="TKey">Type of the value used to match existing rows.</typeparam>
          /// <param name="dbSet">The set to query and add to.</param>
          /// <param name="keySelector">Selects the field (or an anonymous object of fields) that identifies a row.</param>
          /// <param name="entities">The entities to add or update.</param>
          /// <exception cref="ArgumentNullException">An argument, or an element of <paramref name="entities"/>, is null.</exception>
          public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, IEnumerable<T> entities) where T : class
          {
              if (dbSet == null)
              {
                  throw new ArgumentNullException(nameof(dbSet));
              }
              if (keySelector == null)
              {
                  throw new ArgumentNullException(nameof(keySelector));
              }
              if (entities == null)
              {
                  throw new ArgumentNullException(nameof(entities));
              }

              // Compile the selector and reflect over the entity type once for the whole batch, not once per entity.
              var getKey = keySelector.Compile();
              var preservedProperties = GetPreservedProperties(typeof(T));
              foreach (var entity in entities)
              {
                  AddOrUpdateCore(dbSet, keySelector, getKey, preservedProperties, entity);
              }
          }

          /// <summary>
          /// Adds <paramref name="entity"/>, or updates the existing row whose key matches it.
          /// </summary>
          /// <typeparam name="T">Entity type.</typeparam>
          /// <typeparam name="TKey">Type of the value used to match existing rows.</typeparam>
          /// <param name="dbSet">The set to query and add to.</param>
          /// <param name="keySelector">Selects the field (or an anonymous object of fields) that identifies a row.</param>
          /// <param name="entity">The entity to add or update.</param>
          /// <exception cref="ArgumentNullException">An argument is null.</exception>
          public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, T entity) where T : class
          {
              if (dbSet == null)
              {
                  throw new ArgumentNullException(nameof(dbSet));
              }
              if (keySelector == null)
              {
                  throw new ArgumentNullException(nameof(keySelector));
              }
              if (entity == null)
              {
                  throw new ArgumentNullException(nameof(entity));
              }

              AddOrUpdateCore(dbSet, keySelector, keySelector.Compile(), GetPreservedProperties(typeof(T)), entity);
          }

          /// <summary>
          /// Shared upsert logic. It looks up the row matching <paramref name="entity"/>'s key, then either adds the
          /// entity or copies its values onto the existing row.
          /// </summary>
          private static void AddOrUpdateCore<T, TKey>(DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, Func<T, TKey> getKey, IReadOnlyList<PropertyInfo> preservedProperties, T entity) where T : class
          {
              if (entity == null)
              {
                  throw new ArgumentNullException(nameof(entity));
              }

              var existing = dbSet.FirstOrDefault(BuildKeyPredicate(keySelector, getKey(entity)));
              if (existing == null)
              {
                  dbSet.Add(entity);
                  return;
              }

              // Compare by name so a property hidden with 'new' in a derived type is treated like the original did.
              var preservedNames = new HashSet<string>(preservedProperties.Select(p => p.Name));

              // Overwrite every public, readable, writable, non-indexer, non-preserved property of the existing row.
              foreach (var property in typeof(T).GetProperties())
              {
                  if (property.GetIndexParameters().Length > 0 || property.GetGetMethod() == null || property.GetSetMethod() == null || preservedNames.Contains(property.Name))
                  {
                      continue;
                  }

                  var newValue = property.GetValue(entity);
                  // Value comparison: '!=' on boxed values compares references and is always true for value types.
                  if (!object.Equals(property.GetValue(existing), newValue))
                  {
                      property.SetValue(existing, newValue);
                  }
              }

              // Copy the stored key / ignored values back onto the caller's entity (setter of any accessibility).
              foreach (var property in preservedProperties)
              {
                  if (property.GetIndexParameters().Length > 0 || property.GetMethod == null || property.SetMethod == null)
                  {
                      continue;
                  }

                  var storedValue = property.GetValue(existing);
                  if (!object.Equals(property.GetValue(entity), storedValue))
                  {
                      property.SetValue(entity, storedValue);
                  }
              }
          }

          /// <summary>
          /// Returns the properties that must not be overwritten on an existing row.
          /// These are the key properties plus any property marked <c>[UpdateIgnore]</c>.
          /// </summary>
          private static List<PropertyInfo> GetPreservedProperties(Type entityType)
          {
              var properties = entityType.GetProperties();
              var preserved = properties.Where(p => p.GetCustomAttribute<KeyAttribute>() != null).ToList();
              if (preserved.Count == 0)
              {
                  string idName = entityType.Name + "Id";
                  preserved = properties.Where(p => p.Name.Equals("Id", StringComparison.OrdinalIgnoreCase) || p.Name.Equals(idName, StringComparison.OrdinalIgnoreCase)).ToList();
              }

              // [UpdateIgnore] properties are preserved IN ADDITION to the key. Previously, their presence suppressed
              // the Id convention, so the key was overwritten with the incoming entity's value.
              var ignored = properties.Where(p => p.GetCustomAttribute<UpdateIgnoreAttribute>() != null && preserved.All(k => k.Name != p.Name)).ToList();
              preserved.AddRange(ignored);
              return preserved;
          }

          /// <summary>
          /// Builds <c>p => keySelector(p) == key</c> as an expression tree that the query provider can translate.
          /// </summary>
          private static Expression<Func<T, bool>> BuildKeyPredicate<T, TKey>(Expression<Func<T, TKey>> keySelector, TKey key)
          {
              var body = keySelector.Body;
              Expression condition = null;
              if (body is NewExpression newExpression && newExpression.Members != null && newExpression.Arguments.Count > 0)
              {
                  // Anonymous-type key (p => new { p.A, p.B }): compare member by member.
                  // Comparing the whole anonymous object would be reference equality.
                  for (var i = 0; i < newExpression.Arguments.Count; i++)
                  {
                      var argument = newExpression.Arguments[i];
                      var memberValue = GetMemberValue(newExpression.Members[i], key);
                      var equality = Expression.Equal(argument, AsQueryParameter(memberValue, argument.Type));
                      condition = condition == null ? equality : Expression.AndAlso(condition, equality);
                  }
              }
              else
              {
                  // Typed by the body, so a null key (e.g. a null string) still produces a well-typed comparison.
                  condition = Expression.Equal(body, AsQueryParameter(key, body.Type));
              }

              // Reuse the selector's own parameter instead of rebuilding the body. This keeps nested member access
              // (p => p.Owner.Id), conversions and any other body shape intact.
              return Expression.Lambda<Func<T, bool>>(condition, keySelector.Parameters);
          }

          /// <summary>
          /// Reads the value of an anonymous-type member from the key object.
          /// </summary>
          private static object GetMemberValue(MemberInfo member, object instance)
          {
              return member switch
              {
                  PropertyInfo property => property.GetValue(instance),
                  FieldInfo field => field.GetValue(instance),
                  MethodInfo getter => getter.Invoke(instance, null),
                  _ => throw new NotSupportedException("Unsupported key member type: " + member.MemberType)
              };
          }

          /// <summary>
          /// Wraps <paramref name="value"/> as a member access on a constant holder.
          /// </summary>
          /// <remarks>
          /// Query providers turn such a member access into a query parameter, like a captured local variable.
          /// A bare <see cref="Expression.Constant(object)"/> would instead be inlined, which yields a distinct
          /// query per key value.
          /// </remarks>
          private static Expression AsQueryParameter(object value, Type type)
          {
              var holderType = typeof(QueryParameter<>).MakeGenericType(type);
              var holder = Activator.CreateInstance(holderType);
              var valueProperty = holderType.GetProperty(nameof(QueryParameter<object>.Value));
              valueProperty.SetValue(holder, value);
              return Expression.Property(Expression.Constant(holder, holderType), valueProperty);
          }

          /// <summary>
          /// Holder whose <see cref="Value"/> is exposed to the query as a parameter.
          /// </summary>
          private sealed class QueryParameter<TValue>
          {
              public TValue Value { get; set; }
          }
      }
  }