FromBodyOrDefaultModelBinder.cs 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. using System.Collections;
  2. using System.Net.Mime;
  3. using System.Reflection;
  4. using System.Xml.Linq;
  5. using Masuit.Tools.Systems;
  6. using Microsoft.AspNetCore.Mvc.ModelBinding;
  7. using Microsoft.Extensions.Primitives;
  8. using Newtonsoft.Json.Linq;
  9. namespace Masuit.Tools.AspNetCore.ModelBinder;
  10. public class FromBodyOrDefaultModelBinder : IModelBinder
  11. {
  12. private static readonly List<BindType> BindTypes = new()
  13. {
  14. BindType.Query,
  15. BindType.Body,
  16. BindType.Header,
  17. BindType.Form,
  18. BindType.Cookie,
  19. BindType.Route
  20. };
  21. private readonly ILogger<FromBodyOrDefaultModelBinder> _logger;
  22. public FromBodyOrDefaultModelBinder(ILogger<FromBodyOrDefaultModelBinder> logger)
  23. {
  24. _logger = logger;
  25. }
  26. public Task BindModelAsync(ModelBindingContext bindingContext)
  27. {
  28. var context = bindingContext.HttpContext;
  29. var attr = bindingContext.GetAttribute<FromBodyOrDefaultAttribute>();
  30. var field = attr?.FieldName ?? bindingContext.FieldName;
  31. var modelType = bindingContext.ModelType;
  32. object value = null;
  33. if (attr != null)
  34. {
  35. if (modelType.IsSimpleType() || modelType.IsSimpleArrayType() || modelType.IsSimpleListType())
  36. {
  37. if (attr.Type == BindType.Default)
  38. {
  39. foreach (var type in BindTypes)
  40. {
  41. value = GetBindingValue(bindingContext, type, field, modelType);
  42. if (value != null)
  43. {
  44. break;
  45. }
  46. }
  47. }
  48. else
  49. {
  50. foreach (var type in attr.Type.Split())
  51. {
  52. value = GetBindingValue(bindingContext, type, field, modelType);
  53. if (value != null)
  54. {
  55. break;
  56. }
  57. }
  58. }
  59. }
  60. else
  61. {
  62. if (bindingContext.HttpContext.Items.TryGetValue("BodyOrDefaultModelBinder@JsonBody", out var obj) && obj is JObject json)
  63. {
  64. if (modelType.IsArray || modelType.IsGenericType && modelType.GenericTypeArguments.Length == 1)
  65. {
  66. if (json.TryGetValue(field, StringComparison.OrdinalIgnoreCase, out var jtoken))
  67. {
  68. value = jtoken.ToObject(modelType);
  69. }
  70. else
  71. {
  72. _logger.LogWarning($"TraceIdentifier:{context.TraceIdentifier},BodyOrDefaultModelBinder从{json}中获取{field}失败!");
  73. }
  74. }
  75. else
  76. {
  77. // 可能是 字典或者实体 类型,尝试将modeltype 当初整个请求参数对象
  78. try
  79. {
  80. value = json.ToObject(modelType);
  81. }
  82. catch (Exception e)
  83. {
  84. _logger.LogError(e, e.Message, json.ToString());
  85. }
  86. }
  87. }
  88. if (value == null)
  89. {
  90. var (requestData, keys) = GetRequestData(bindingContext, modelType);
  91. if (keys.Any())
  92. {
  93. var instance = Activator.CreateInstance(modelType);
  94. switch (requestData)
  95. {
  96. case IEnumerable<KeyValuePair<string, StringValues>> stringValues:
  97. {
  98. foreach (var item in stringValues)
  99. {
  100. var property = modelType.GetProperty(item.Key, BindingFlags.IgnoreCase | BindingFlags.Public | BindingFlags.Instance);
  101. if (property != null)
  102. {
  103. property.SetValue(instance, item.Value.ConvertObject(property.PropertyType));
  104. }
  105. }
  106. break;
  107. }
  108. case IEnumerable<KeyValuePair<string, string>> strs:
  109. {
  110. //处理Cookie
  111. foreach (var item in strs)
  112. {
  113. var property = modelType.GetProperty(item.Key, BindingFlags.IgnoreCase | BindingFlags.Public | BindingFlags.Instance);
  114. if (property != null)
  115. {
  116. property.SetValue(instance, item.Value.ConvertObject(property.PropertyType));
  117. }
  118. }
  119. break;
  120. }
  121. case IEnumerable<KeyValuePair<string, object>> objects:
  122. {
  123. //处理路由
  124. foreach (var item in objects)
  125. {
  126. var property = modelType.GetProperty(item.Key, BindingFlags.IgnoreCase | BindingFlags.Public | BindingFlags.Instance);
  127. if (property != null)
  128. {
  129. property.SetValue(instance, item.Value.ConvertObject(property.PropertyType));
  130. }
  131. }
  132. break;
  133. }
  134. }
  135. value = instance;
  136. }
  137. }
  138. }
  139. if (value == null && attr.DefaultValue != null)
  140. {
  141. value = attr.DefaultValue.ChangeType(modelType);
  142. }
  143. }
  144. if (value != null)
  145. {
  146. bindingContext.Result = ModelBindingResult.Success(value);
  147. }
  148. return Task.CompletedTask;
  149. }
  150. private static (IEnumerable data, List<string> keys) GetRequestData(ModelBindingContext bindingContext, Type type)
  151. {
  152. var request = bindingContext.HttpContext.Request;
  153. var props = type.GetProperties().Select(t => t.Name).ToList();
  154. var query = props.Except(request.Query.Keys, StringComparer.OrdinalIgnoreCase).ToList();
  155. var headers = props.Except(request.Headers.Keys, StringComparer.OrdinalIgnoreCase).ToList();
  156. var cookies = props.Except(request.Cookies.Keys, StringComparer.OrdinalIgnoreCase).ToList();
  157. var routes = props.Except(bindingContext.ActionContext.RouteData.Values.Keys, StringComparer.OrdinalIgnoreCase).ToList();
  158. var list = new List<KeyValuePair<List<string>, IEnumerable>>()
  159. {
  160. new(query, request.Query),
  161. new(headers, request.Headers),
  162. new(cookies, request.Cookies),
  163. new(routes, bindingContext.ActionContext.RouteData.Values),
  164. };
  165. if (request.HasFormContentType && request.Form.Count > 0)
  166. {
  167. var forms = props.Except(request.Form.Keys, StringComparer.OrdinalIgnoreCase).ToList();
  168. list.Add(new KeyValuePair<List<string>, IEnumerable>(forms, request.Form));
  169. }
  170. var kv = list.OrderBy(t => t.Key.Count).FirstOrDefault();
  171. return (kv.Value, props.Except(kv.Key).ToList());
  172. }
  173. /// <summary>
  174. /// 获取要绑定的值
  175. /// </summary>
  176. /// <param name="bindingContext"></param>
  177. /// <param name="bindType"></param>
  178. /// <param name="fieldName"></param>
  179. /// <param name="modelType"></param>
  180. private object GetBindingValue(ModelBindingContext bindingContext, BindType bindType, string fieldName, Type modelType)
  181. {
  182. var context = bindingContext.HttpContext;
  183. var mediaType = string.Empty;
  184. if (!string.IsNullOrWhiteSpace(context.Request.ContentType))
  185. {
  186. try
  187. {
  188. var contentType = new ContentType(context.Request.ContentType);
  189. mediaType = contentType.MediaType.ToLower();
  190. }
  191. catch (Exception ex)
  192. {
  193. _logger.LogError(ex, ex.Message, context.Request.ContentType);
  194. }
  195. }
  196. object targetValue = null;
  197. switch (bindType)
  198. {
  199. case BindType.Body:
  200. switch (mediaType)
  201. {
  202. case "application/json":
  203. {
  204. if (bindingContext.HttpContext.Items.TryGetValue("BodyOrDefaultModelBinder@JsonBody", out var obj) && obj is JObject json && json.TryGetValue(fieldName, StringComparison.OrdinalIgnoreCase, out var values))
  205. {
  206. targetValue = values.ConvertObject(modelType);
  207. }
  208. }
  209. break;
  210. case "application/xml":
  211. {
  212. if (bindingContext.HttpContext.Items.TryGetValue("BodyOrDefaultModelBinder@XmlBody", out var obj) && obj is XDocument xml)
  213. {
  214. var xmlElt = xml.Element(fieldName);
  215. if (xmlElt != null)
  216. {
  217. targetValue = xmlElt.Value.ConvertObject(modelType);
  218. }
  219. }
  220. break;
  221. }
  222. }
  223. break;
  224. case BindType.Query:
  225. {
  226. if (context.Request.Query is { Count: > 0 } && context.Request.Query.TryGetValue(fieldName, out var values))
  227. {
  228. targetValue = values.ConvertObject(modelType);
  229. }
  230. }
  231. break;
  232. case BindType.Form:
  233. {
  234. if (context.Request is { HasFormContentType: true, Form.Count: > 0 } && context.Request.Form.TryGetValue(fieldName, out var values))
  235. {
  236. targetValue = values.ConvertObject(modelType);
  237. }
  238. }
  239. break;
  240. case BindType.Header:
  241. {
  242. if (context.Request.Headers is { Count: > 0 } && context.Request.Headers.TryGetValue(fieldName, out var values))
  243. {
  244. targetValue = values.ConvertObject(modelType);
  245. }
  246. }
  247. break;
  248. case BindType.Cookie:
  249. {
  250. if (context.Request.Cookies is { Count: > 0 } && context.Request.Cookies.TryGetValue(fieldName, out var values))
  251. {
  252. targetValue = values.ConvertObject(modelType);
  253. }
  254. }
  255. break;
  256. case BindType.Route:
  257. {
  258. if (bindingContext.ActionContext.RouteData.Values is { Count: > 0 } && bindingContext.ActionContext.RouteData.Values.TryGetValue(fieldName, out var values))
  259. {
  260. targetValue = values.ConvertObject(modelType);
  261. }
  262. }
  263. break;
  264. case BindType.Services:
  265. targetValue = bindingContext.ActionContext.HttpContext.RequestServices.GetRequiredService(modelType);
  266. break;
  267. }
  268. return targetValue;
  269. }
  270. }