FromBodyOrDefaultModelBinder.cs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. using System.Collections;
  2. using System.Net.Mime;
  3. using System.Reflection;
  4. using System.Xml.Linq;
  5. using Masuit.Tools.Systems;
  6. using Microsoft.AspNetCore.Mvc.ModelBinding;
  7. using Microsoft.Extensions.DependencyInjection;
  8. using Microsoft.Extensions.Logging;
  9. using Microsoft.Extensions.Primitives;
  10. using Newtonsoft.Json.Linq;
  11. namespace Masuit.Tools.AspNetCore.ModelBinder;
  /// <summary>
  /// <see cref="IModelBinder"/> for parameters marked with <c>FromBodyOrDefaultAttribute</c>.
  /// </summary>
  /// <remarks>
  /// Simple values (simple types, simple arrays and simple lists) are looked up by field name in the
  /// request sources selected by the attribute's <c>Type</c>; for <c>BindType.Default</c> a fixed probe
  /// order is used. Other types are deserialized from the JSON body object stored in
  /// <c>HttpContext.Items</c>; when that yields nothing, an instance is created and its writable
  /// properties are filled from whichever of query / headers / cookies / route values / form supplies
  /// the most property names. If no value was found, the attribute's <c>DefaultValue</c> is used.
  /// </remarks>
  public class FromBodyOrDefaultModelBinder : IModelBinder
  {
      // HttpContext.Items keys holding the pre-parsed request body. They are only read in this class.
      // NOTE(review): presumably populated by a middleware/filter elsewhere in the library — confirm.
      private const string JsonBodyItemKey = "BodyOrDefaultModelBinder@JsonBody";
      private const string XmlBodyItemKey = "BodyOrDefaultModelBinder@XmlBody";

      /// <summary>
      /// Probe order used when the attribute requests <c>BindType.Default</c>; the first non-null value wins.
      /// </summary>
      private static readonly List<BindType> BindTypes = new()
      {
          BindType.Query,
          BindType.Body,
          BindType.Header,
          BindType.Form,
          BindType.Cookie,
          BindType.Route
      };

      private readonly ILogger<FromBodyOrDefaultModelBinder> _logger;

      /// <summary>
      /// Creates the binder.
      /// </summary>
      /// <param name="logger">Receives binding warnings and errors.</param>
      public FromBodyOrDefaultModelBinder(ILogger<FromBodyOrDefaultModelBinder> logger)
      {
          _logger = logger;
      }

      /// <summary>
      /// Binds the current parameter. <see cref="ModelBindingContext.Result"/> is left unset when the
      /// parameter has no <c>FromBodyOrDefaultAttribute</c>, or when neither a value nor a default was found.
      /// </summary>
      /// <param name="bindingContext">The binding context supplied by MVC.</param>
      /// <returns>An already-completed task; all work below is done synchronously.</returns>
      public Task BindModelAsync(ModelBindingContext bindingContext)
      {
          var attr = bindingContext.GetAttribute<FromBodyOrDefaultAttribute>();
          if (attr == null)
          {
              return Task.CompletedTask;
          }

          var field = attr.FieldName ?? bindingContext.FieldName;
          var modelType = bindingContext.ModelType;
          var isSimple = modelType.IsSimpleType() || modelType.IsSimpleArrayType() || modelType.IsSimpleListType();
          var value = isSimple
              ? BindSimpleValue(bindingContext, attr, field, modelType)
              : BindComplexValue(bindingContext, field, modelType);

          if (value == null && attr.DefaultValue != null)
          {
              value = attr.DefaultValue.ChangeType(modelType);
          }

          if (value != null)
          {
              bindingContext.Result = ModelBindingResult.Success(value);
          }

          return Task.CompletedTask;
      }

      /// <summary>
      /// Looks a simple value up by name in each candidate source and returns the first non-null hit.
      /// </summary>
      private object BindSimpleValue(ModelBindingContext bindingContext, FromBodyOrDefaultAttribute attr, string field, Type modelType)
      {
          var sources = attr.Type == BindType.Default ? (IEnumerable<BindType>)BindTypes : attr.Type.Split();
          foreach (var source in sources)
          {
              var value = GetBindingValue(bindingContext, source, field, modelType);
              if (value != null)
              {
                  return value;
              }
          }

          return null;
      }

      /// <summary>
      /// Binds a non-simple model: first from the pre-parsed JSON body (if present), then from the
      /// other request sources.
      /// </summary>
      private object BindComplexValue(ModelBindingContext bindingContext, string field, Type modelType)
      {
          object value = null;
          if (bindingContext.HttpContext.Items.TryGetValue(JsonBodyItemKey, out var obj) && obj is JObject json)
          {
              value = BindFromJsonBody(bindingContext, json, field, modelType);
          }

          return value ?? BindFromRequestData(bindingContext, modelType);
      }

      /// <summary>
      /// Deserializes the model from the JSON body. Array types and generic types with exactly one type
      /// argument (e.g. <c>List&lt;T&gt;</c>) are read from the body property matching
      /// <paramref name="field"/> (case-insensitive). Everything else, such as entities and dictionaries,
      /// is read from the whole body object. Returns null, after logging, when the property is missing or
      /// deserialization fails.
      /// </summary>
      private object BindFromJsonBody(ModelBindingContext bindingContext, JObject json, string field, Type modelType)
      {
          var traceId = bindingContext.HttpContext.TraceIdentifier;
          if (modelType.IsArray || (modelType.IsGenericType && modelType.GenericTypeArguments.Length == 1))
          {
              if (!json.TryGetValue(field, StringComparison.OrdinalIgnoreCase, out var token))
              {
                  // The JSON is passed as a string so the message shows the serialized body.
                  _logger.LogWarning("TraceIdentifier:{TraceIdentifier},BodyOrDefaultModelBinder从{Json}中获取{Field}失败!", traceId, json.ToString(), field);
                  return null;
              }

              try
              {
                  return token.ToObject(modelType);
              }
              catch (Exception e)
              {
                  _logger.LogError(e, "TraceIdentifier:{TraceIdentifier},BodyOrDefaultModelBinder failed to convert {Field} to {ModelType}", traceId, field, modelType);
                  return null;
              }
          }

          // Possibly a dictionary or entity type: treat the whole request body as the model.
          try
          {
              return json.ToObject(modelType);
          }
          catch (Exception e)
          {
              _logger.LogError(e, "TraceIdentifier:{TraceIdentifier},BodyOrDefaultModelBinder failed to convert the request body {Json} to {ModelType}", traceId, json.ToString(), modelType);
              return null;
          }
      }

      /// <summary>
      /// Creates an instance of <paramref name="modelType"/> and fills its writable properties from the
      /// request source that supplies the most property names. Returns null when no source supplies any,
      /// or when the type cannot be created via its public parameterless constructor.
      /// </summary>
      private object BindFromRequestData(ModelBindingContext bindingContext, Type modelType)
      {
          var (requestData, keys) = GetRequestData(bindingContext, modelType);
          if (keys.Count == 0)
          {
              return null;
          }

          if (!CanCreateInstance(modelType))
          {
              _logger.LogWarning("TraceIdentifier:{TraceIdentifier},BodyOrDefaultModelBinder cannot create {ModelType}: no public parameterless constructor", bindingContext.HttpContext.TraceIdentifier, modelType);
              return null;
          }

          var instance = Activator.CreateInstance(modelType);
          if (instance == null)
          {
              // e.g. Nullable<T>, for which Activator returns null.
              return null;
          }

          var properties = GetBindableProperties(modelType);
          switch (requestData)
          {
              case IEnumerable<KeyValuePair<string, StringValues>> stringValues:
                  // Query string, headers and form.
                  PopulateProperties(bindingContext, instance, properties, stringValues, (raw, type) => raw.ConvertObject(type));
                  break;
              case IEnumerable<KeyValuePair<string, string>> strings:
                  // Cookies.
                  PopulateProperties(bindingContext, instance, properties, strings, (raw, type) => raw.ConvertObject(type));
                  break;
              case IEnumerable<KeyValuePair<string, object>> objects:
                  // Route values.
                  PopulateProperties(bindingContext, instance, properties, objects, (raw, type) => raw.ConvertObject(type));
                  break;
          }

          return instance;
      }

      /// <summary>
      /// Assigns each source entry to the bindable property with the same name (case-insensitive).
      /// Entries without a matching property are ignored. A failed assignment is logged and skipped,
      /// so a single bad value does not abort binding of the other properties.
      /// </summary>
      /// <param name="convert">Converts a raw source value to the property's type; each caller passes a
      /// lambda over its concrete value type so the same <c>ConvertObject</c> overload is used as before.</param>
      private void PopulateProperties<TValue>(ModelBindingContext bindingContext, object instance, IReadOnlyList<PropertyInfo> properties, IEnumerable<KeyValuePair<string, TValue>> source, Func<TValue, Type, object> convert)
      {
          foreach (var entry in source)
          {
              var property = properties.FirstOrDefault(p => string.Equals(p.Name, entry.Key, StringComparison.OrdinalIgnoreCase));
              if (property == null)
              {
                  continue;
              }

              try
              {
                  property.SetValue(instance, convert(entry.Value, property.PropertyType));
              }
              catch (Exception e)
              {
                  _logger.LogWarning(e, "TraceIdentifier:{TraceIdentifier},BodyOrDefaultModelBinder failed to bind {Property} of {ModelType}", bindingContext.HttpContext.TraceIdentifier, property.Name, instance.GetType());
              }
          }
      }

      /// <summary>
      /// Chooses the request source that contains the largest number of <paramref name="type"/>'s
      /// bindable property names. Candidates are query, headers, cookies and route values, plus the form
      /// for non-empty form posts. Ties go to the source listed first.
      /// </summary>
      /// <returns>The chosen source and the property names it supplies (empty when none match).</returns>
      private static (IEnumerable data, List<string> keys) GetRequestData(ModelBindingContext bindingContext, Type type)
      {
          var request = bindingContext.HttpContext.Request;
          var routeValues = bindingContext.ActionContext.RouteData.Values;
          var props = GetBindableProperties(type).Select(p => p.Name).ToList();
          var sources = new List<(IEnumerable Data, IEnumerable<string> Keys)>
          {
              (request.Query, request.Query.Keys),
              (request.Headers, request.Headers.Keys),
              (request.Cookies, request.Cookies.Keys),
              (routeValues, routeValues.Keys)
          };
          if (request.HasFormContentType && request.Form.Count > 0)
          {
              sources.Add((request.Form, request.Form.Keys));
          }

          // OrderByDescending is a stable sort, so ties keep the order of the list above.
          var best = sources
              .Select(s => (s.Data, Keys: props.Intersect(s.Keys, StringComparer.OrdinalIgnoreCase).ToList()))
              .OrderByDescending(s => s.Keys.Count)
              .First();
          return (best.Data, best.Keys);
      }

      /// <summary>
      /// Public instance properties that can be assigned by name: they have a setter and are not indexers.
      /// Read-only properties and indexers (e.g. <c>Count</c>/<c>Item</c> on collections) would make
      /// <see cref="PropertyInfo.SetValue(object, object)"/> throw.
      /// </summary>
      private static PropertyInfo[] GetBindableProperties(Type type)
      {
          return type.GetProperties(BindingFlags.Public | BindingFlags.Instance)
              .Where(p => p.CanWrite && p.GetIndexParameters().Length == 0)
              .ToArray();
      }

      /// <summary>
      /// True when <see cref="Activator.CreateInstance(Type)"/> can build <paramref name="type"/>: either a
      /// value type, or a concrete, closed class with a public parameterless constructor.
      /// </summary>
      private static bool CanCreateInstance(Type type)
      {
          return type.IsValueType
                 || (!type.IsAbstract && !type.ContainsGenericParameters && type.GetConstructor(Type.EmptyTypes) != null);
      }

      /// <summary>
      /// Reads the value named <paramref name="fieldName"/> from a single request source and converts it
      /// to <paramref name="modelType"/>. For <c>BindType.Services</c>, the type is resolved from the
      /// request's service provider instead.
      /// </summary>
      /// <param name="bindingContext">The current binding context.</param>
      /// <param name="bindType">The single source to read from.</param>
      /// <param name="fieldName">Key/name to look up.</param>
      /// <param name="modelType">Type the raw value is converted to.</param>
      /// <returns>The converted value, or null when the source does not contain the name.</returns>
      private object GetBindingValue(ModelBindingContext bindingContext, BindType bindType, string fieldName, Type modelType)
      {
          var request = bindingContext.HttpContext.Request;
          switch (bindType)
          {
              case BindType.Body:
                  return GetBodyValue(bindingContext, fieldName, modelType);

              case BindType.Query:
                  if (request.Query is { Count: > 0 } && request.Query.TryGetValue(fieldName, out var queryValue))
                  {
                      return queryValue.ConvertObject(modelType);
                  }
                  break;

              case BindType.Form:
                  if (request is { HasFormContentType: true, Form.Count: > 0 } && request.Form.TryGetValue(fieldName, out var formValue))
                  {
                      return formValue.ConvertObject(modelType);
                  }
                  break;

              case BindType.Header:
                  if (request.Headers is { Count: > 0 } && request.Headers.TryGetValue(fieldName, out var headerValue))
                  {
                      return headerValue.ConvertObject(modelType);
                  }
                  break;

              case BindType.Cookie:
                  if (request.Cookies is { Count: > 0 } && request.Cookies.TryGetValue(fieldName, out var cookieValue))
                  {
                      return cookieValue.ConvertObject(modelType);
                  }
                  break;

              case BindType.Route:
                  if (bindingContext.ActionContext.RouteData.Values is { Count: > 0 } && bindingContext.ActionContext.RouteData.Values.TryGetValue(fieldName, out var routeValue))
                  {
                      return routeValue.ConvertObject(modelType);
                  }
                  break;

              case BindType.Services:
                  return bindingContext.ActionContext.HttpContext.RequestServices.GetRequiredService(modelType);
          }

          return null;
      }

      /// <summary>
      /// Reads <paramref name="fieldName"/> from the pre-parsed request body. Only the
      /// <c>application/json</c> and <c>application/xml</c> media types are handled; others yield null.
      /// </summary>
      private object GetBodyValue(ModelBindingContext bindingContext, string fieldName, Type modelType)
      {
          var items = bindingContext.HttpContext.Items;
          switch (GetMediaType(bindingContext.HttpContext.Request.ContentType))
          {
              case "application/json":
                  if (items.TryGetValue(JsonBodyItemKey, out var jsonObj) && jsonObj is JObject json && json.TryGetValue(fieldName, StringComparison.OrdinalIgnoreCase, out var token))
                  {
                      return token.ConvertObject(modelType);
                  }
                  return null;

              case "application/xml":
                  if (items.TryGetValue(XmlBodyItemKey, out var xmlObj) && xmlObj is XDocument xml)
                  {
                      var element = FindXmlElement(xml, fieldName);
                      if (element != null)
                      {
                          return element.Value.ConvertObject(modelType);
                      }
                  }
                  return null;

              default:
                  return null;
          }
      }

      /// <summary>
      /// Finds the element for <paramref name="fieldName"/> in an XML body, matching local names
      /// case-insensitively (like the JSON lookup). A direct child of the root element is preferred. The
      /// root itself is used only as a fallback, which preserves support for bodies such as
      /// <c>&lt;id&gt;5&lt;/id&gt;</c>. <see cref="XContainer.Element(XName)"/> on an
      /// <see cref="XDocument"/> only ever inspects the root, so it cannot be used here.
      /// </summary>
      private static XElement FindXmlElement(XDocument xml, string fieldName)
      {
          var root = xml.Root;
          if (root == null)
          {
              return null;
          }

          var child = root.Elements().FirstOrDefault(e => string.Equals(e.Name.LocalName, fieldName, StringComparison.OrdinalIgnoreCase));
          if (child != null)
          {
              return child;
          }

          return string.Equals(root.Name.LocalName, fieldName, StringComparison.OrdinalIgnoreCase) ? root : null;
      }

      /// <summary>
      /// Extracts the lower-cased media type (e.g. "application/json") from a Content-Type header value.
      /// Returns an empty string when the header is absent or cannot be parsed; a parse failure is logged.
      /// </summary>
      private string GetMediaType(string contentType)
      {
          if (string.IsNullOrWhiteSpace(contentType))
          {
              return string.Empty;
          }

          try
          {
              // Invariant casing: culture-sensitive ToLower() breaks under e.g. tr-TR ("APPLICATION" -> "applıcatıon").
              return new ContentType(contentType).MediaType.ToLowerInvariant();
          }
          catch (Exception ex)
          {
              _logger.LogError(ex, "BodyOrDefaultModelBinder failed to parse Content-Type {ContentType}", contentType);
              return string.Empty;
          }
      }
  }