using Microsoft.EntityFrameworkCore;
using System;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
namespace Masuit.Tools.Core.AspNetCore
{
/// <summary>
/// "Add or update" (upsert-style) extension methods for <see cref="DbSet{TEntity}"/>.
/// </summary>
public static class DbSetExtensions
{
    /// <summary>
    /// Adds or updates every entity in <paramref name="entities"/>, matching stored rows by the selected key.
    /// </summary>
    /// <remarks>
    /// Each entity triggers one lookup query. This method never calls <c>SaveChanges</c>;
    /// persisting the changes is left to the caller.
    /// </remarks>
    /// <typeparam name="T">Entity type of the set.</typeparam>
    /// <typeparam name="TKey">Type of the value the entities are matched by.</typeparam>
    /// <param name="dbSet">The set to add to or update in.</param>
    /// <param name="keySelector">
    /// Selects the field to match by. Must be a member access (for example <c>x =&gt; x.Code</c>)
    /// or a <c>new</c> expression whose arguments are member accesses.
    /// </param>
    /// <param name="entities">The entities to add or update.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="dbSet"/>, <paramref name="keySelector"/>, <paramref name="entities"/> or one of the entities is null.
    /// </exception>
    /// <exception cref="NotSupportedException"><paramref name="keySelector"/> has an unsupported shape.</exception>
    public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, params T[] entities) where T : class
    {
        if (dbSet == null)
        {
            throw new ArgumentNullException(nameof(dbSet));
        }

        if (keySelector == null)
        {
            throw new ArgumentNullException(nameof(keySelector));
        }

        if (entities == null)
        {
            throw new ArgumentNullException(nameof(entities));
        }

        // Compiling an expression tree is expensive, so do it once for the whole batch.
        var readKey = keySelector.Compile();
        foreach (var entity in entities)
        {
            AddOrUpdateCore(dbSet, keySelector, readKey, entity);
        }
    }

    /// <summary>
    /// Adds <paramref name="entity"/> to the set, or updates the stored entity that has the same key value.
    /// </summary>
    /// <remarks>
    /// When a stored entity with the same key value is found, every public readable and writable
    /// non-key property is copied from <paramref name="entity"/> onto it, and the primary-key values
    /// of the stored entity are copied back onto <paramref name="entity"/>. Primary-key properties
    /// are those marked with <see cref="KeyAttribute"/>; if there are none, properties named
    /// <c>Id</c> or <c>{TypeName}Id</c> (case-insensitive) are used.
    /// This method never calls <c>SaveChanges</c>; persisting the changes is left to the caller.
    /// </remarks>
    /// <typeparam name="T">Entity type of the set.</typeparam>
    /// <typeparam name="TKey">Type of the value the entity is matched by.</typeparam>
    /// <param name="dbSet">The set to add to or update in.</param>
    /// <param name="keySelector">
    /// Selects the field to match by. Must be a member access (for example <c>x =&gt; x.Code</c>)
    /// or a <c>new</c> expression whose arguments are member accesses.
    /// </param>
    /// <param name="entity">The entity to add or update.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="dbSet"/>, <paramref name="keySelector"/> or <paramref name="entity"/> is null.
    /// </exception>
    /// <exception cref="NotSupportedException"><paramref name="keySelector"/> has an unsupported shape.</exception>
    public static void AddOrUpdate<T, TKey>(this DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, T entity) where T : class
    {
        if (dbSet == null)
        {
            throw new ArgumentNullException(nameof(dbSet));
        }

        if (keySelector == null)
        {
            throw new ArgumentNullException(nameof(keySelector));
        }

        AddOrUpdateCore(dbSet, keySelector, keySelector.Compile(), entity);
    }

    /// <summary>
    /// Shared implementation of both overloads; <paramref name="readKey"/> is the compiled form of <paramref name="keySelector"/>.
    /// </summary>
    private static void AddOrUpdateCore<T, TKey>(DbSet<T> dbSet, Expression<Func<T, TKey>> keySelector, Func<T, TKey> readKey, T entity) where T : class
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }

        // Build the predicate "p => <key of p> == <key of entity>".
        object keyObject = readKey(entity);
        var parameter = Expression.Parameter(typeof(T), "p");
        var keyAccess = ReplaceParameter(keySelector.Body, parameter);

        // Type the constant like the key expression: an untyped constant takes the runtime type
        // of the value (or object for null), which makes Expression.Equal throw for nullable keys.
        var keyConstant = Expression.Constant(keyObject, keyAccess.Type);
        var predicate = Expression.Lambda<Func<T, bool>>(Expression.Equal(keyAccess, keyConstant), parameter);

        var stored = dbSet.FirstOrDefault(predicate);
        if (stored == null)
        {
            dbSet.Add(entity);
            return;
        }

        // Determine the primary-key properties: [Key] first, naming convention as the fallback.
        var entityType = typeof(T);
        var properties = entityType.GetProperties();
        var keyProperties = properties.Where(p => p.GetCustomAttribute<KeyAttribute>() != null).ToArray();
        if (keyProperties.Length == 0)
        {
            string idName = entityType.Name + "Id";
            keyProperties = properties.Where(p =>
                string.Equals(p.Name, "Id", StringComparison.OrdinalIgnoreCase) ||
                string.Equals(p.Name, idName, StringComparison.OrdinalIgnoreCase)).ToArray();
        }

        foreach (var property in properties)
        {
            // Only properties with a public getter and setter can be copied; indexers need
            // arguments and would make GetValue throw.
            if (property.GetSetMethod() == null || property.GetGetMethod() == null || property.GetIndexParameters().Length != 0)
            {
                continue;
            }

            if (keyProperties.Any(k => k.Name == property.Name))
            {
                // Primary key: the stored value wins and is written back to the incoming entity.
                CopyIfDifferent(property, stored, entity);
            }
            else
            {
                // Everything else (navigation and collection properties included) is taken
                // from the incoming entity.
                CopyIfDifferent(property, entity, stored);
            }
        }
    }

    /// <summary>
    /// Copies the value of <paramref name="property"/> from <paramref name="source"/> to
    /// <paramref name="target"/> unless both already hold equal values.
    /// </summary>
    private static void CopyIfDifferent(PropertyInfo property, object source, object target)
    {
        var value = property.GetValue(source);
        if (!Equals(property.GetValue(target), value))
        {
            property.SetValue(target, value);
        }
    }

    /// <summary>
    /// Rebuilds a key-selector body so that it reads from <paramref name="newParameter"/>
    /// instead of the selector's own lambda parameter.
    /// </summary>
    /// <param name="oldExpression">Body (or sub-expression) of the key selector.</param>
    /// <param name="newParameter">The parameter the rebuilt expression should be bound to.</param>
    /// <returns>The rebuilt expression.</returns>
    /// <exception cref="NotSupportedException">
    /// The expression is neither an instance member access nor a <c>new</c> expression.
    /// </exception>
    private static Expression ReplaceParameter(Expression oldExpression, ParameterExpression newParameter)
    {
        switch (oldExpression)
        {
            case MemberExpression member when member.Expression != null:
                // Walk nested accesses (x.Owner.Code) down to the parameter instead of binding
                // the outermost member directly to the new parameter.
                var instance = member.Expression is ParameterExpression
                    ? newParameter
                    : ReplaceParameter(member.Expression, newParameter);
                return Expression.MakeMemberAccess(instance, member.Member);

            case NewExpression creation:
                return Expression.New(creation.Constructor, creation.Arguments.Select(a => ReplaceParameter(a, newParameter)).ToArray());

            default:
                throw new NotSupportedException("不支持的表达式类型:" + oldExpression.NodeType);
        }
    }
}
}