BodyOrDefaultModelBinder.cs 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. using System.Collections;
  2. using System.Net.Mime;
  3. using System.Reflection;
  4. using Microsoft.AspNetCore.Mvc.ModelBinding;
  5. using Microsoft.Extensions.Primitives;
  6. namespace Masuit.Tools.AspNetCore.ModelBinder;
  /// <summary>
  /// Model binder for parameters marked with <c>FromBodyOrDefaultAttribute</c>.
  /// <para>
  /// Simple values (and arrays/lists of simple values) are read from the request source selected by the attribute's
  /// <c>Type</c>, or, when that is <c>BindType.None</c>, from the first source in <see cref="AutoTypes"/> that yields a value.
  /// Complex values are deserialized from the JSON body exposed by <c>BodyOrDefaultModelBinderMiddleware.JsonObject</c>,
  /// falling back to populating a new instance from whichever of query string, headers, cookies, route values or form
  /// matches most of its properties. If nothing is found, the attribute's <c>DefaultValue</c> is converted and used.
  /// </para>
  /// </summary>
  public class BodyOrDefaultModelBinder : IModelBinder
  {
      private const BindingFlags PropertyLookupFlags = BindingFlags.IgnoreCase | BindingFlags.Public | BindingFlags.Instance;

      /// <summary>
      /// Request sources probed, in ascending enum-value order, when the attribute does not pin a source
      /// (<c>BindType.None</c>). <c>None</c> and <c>Services</c> are never probed.
      /// </summary>
      private static readonly List<BindType> AutoTypes = Enum.GetValues(typeof(BindType))
          .Cast<BindType>()
          .Where(t => t != BindType.None && t != BindType.Services)
          .ToList();

      private readonly ILogger<BodyOrDefaultModelBinder> _logger;

      public BodyOrDefaultModelBinder(ILogger<BodyOrDefaultModelBinder> logger)
      {
          _logger = logger;
      }

      /// <summary>
      /// Binds the current parameter. Leaves <c>bindingContext.Result</c> unset when the parameter has no
      /// <c>FromBodyOrDefaultAttribute</c>, or when neither a value nor a default value could be produced.
      /// </summary>
      /// <param name="bindingContext">The MVC binding context of the parameter being bound.</param>
      /// <returns>An already-completed task; the binder does its work synchronously.</returns>
      public Task BindModelAsync(ModelBindingContext bindingContext)
      {
          var attr = bindingContext.GetAttribute<FromBodyOrDefaultAttribute>();
          if (attr == null)
          {
              return Task.CompletedTask;
          }

          // The attribute may override the key to look up; otherwise use MVC's field name.
          var fieldName = attr.FieldName ?? bindingContext.FieldName;
          var modelType = bindingContext.ModelType;
          var isSimple = modelType.IsSimpleType() || modelType.IsSimpleArrayType() || modelType.IsSimpleListType();

          object targetValue = isSimple
              ? BindSimpleValue(bindingContext, attr.Type, fieldName, modelType)
              : BindComplexValue(bindingContext, fieldName, modelType);

          if (targetValue == null && attr.DefaultValue != null)
          {
              targetValue = attr.DefaultValue.ChangeType(modelType);
          }

          if (targetValue != null)
          {
              bindingContext.Result = ModelBindingResult.Success(targetValue);
          }

          return Task.CompletedTask;
      }

      /// <summary>
      /// Binds a simple value (or array/list of simple values) from one request source, or from the first
      /// auto-probed source that yields a non-null value when <paramref name="bindType"/> is <c>BindType.None</c>.
      /// </summary>
      private object BindSimpleValue(ModelBindingContext bindingContext, BindType bindType, string fieldName, Type modelType)
      {
          if (bindType != BindType.None)
          {
              return GetBindingValue(bindingContext, bindType, fieldName, modelType);
          }

          foreach (var type in AutoTypes)
          {
              var value = GetBindingValue(bindingContext, type, fieldName, modelType);
              if (value != null)
              {
                  return value;
              }
          }

          return null;
      }

      /// <summary>
      /// Binds a complex type: first from the JSON body exposed by the middleware, then (if that produced nothing)
      /// by populating a new instance from the other request sources.
      /// </summary>
      private object BindComplexValue(ModelBindingContext bindingContext, string fieldName, Type modelType)
      {
          var traceIdentifier = bindingContext.HttpContext.TraceIdentifier;

          // NOTE(review): JsonObject is a static member of the middleware. Unless it is backed by per-request storage
          // (e.g. AsyncLocal or HttpContext.Items), concurrent requests could observe each other's bodies — confirm.
          // Read it once so every check below sees the same instance.
          var jsonObject = BodyOrDefaultModelBinderMiddleware.JsonObject;
          object targetValue = null;

          if (jsonObject != null)
          {
              if (modelType.IsArray || (modelType.IsGenericType && modelType.GenericTypeArguments.Length == 1))
              {
                  // Collection-like types (arrays, List<T>, ...) are read from the single JSON property named fieldName.
                  if (jsonObject.TryGetValue(fieldName, StringComparison.OrdinalIgnoreCase, out var jtoken))
                  {
                      try
                      {
                          targetValue = jtoken.ToObject(modelType);
                      }
                      catch (Exception e)
                      {
                          // Same policy as the whole-body branch below: log and fall back instead of failing the request.
                          _logger.LogError(e, "TraceIdentifier:{TraceIdentifier},AutoBinderMiddleware将字段{FieldName}转换为{ModelType}失败!", traceIdentifier, fieldName, modelType.FullName);
                      }
                  }
                  else
                  {
                      // Pass the JSON as a string: the logging formatter would otherwise enumerate the object.
                      _logger.LogWarning("TraceIdentifier:{TraceIdentifier},AutoBinderMiddleware从{JsonObject}中获取{FieldName}失败!", traceIdentifier, jsonObject.ToString(), fieldName);
                  }
              }
              else
              {
                  // Possibly a dictionary or entity type: treat the whole JSON body as the model.
                  try
                  {
                      targetValue = jsonObject.ToObject(modelType);
                  }
                  catch (Exception e)
                  {
                      _logger.LogError(e, "TraceIdentifier:{TraceIdentifier},AutoBinderMiddleware将{JsonObject}转换为{ModelType}失败!", traceIdentifier, jsonObject.ToString(), modelType.FullName);
                  }
              }
          }

          return targetValue ?? BindFromRequestData(bindingContext, modelType);
      }

      /// <summary>
      /// Creates a new <paramref name="modelType"/> instance and fills its writable properties from the request source
      /// (query, headers, cookies, route values, form) whose keys match the most property names.
      /// </summary>
      /// <returns>The populated instance, or null when the type cannot be instantiated or no source matches any property.</returns>
      private static object BindFromRequestData(ModelBindingContext bindingContext, Type modelType)
      {
          // Arrays, interfaces, abstract types and classes without a public parameterless constructor would make
          // Activator.CreateInstance throw.
          if (!CanCreateInstance(modelType))
          {
              return null;
          }

          var (requestData, keys) = GetRequestData(bindingContext, modelType);
          if (keys.Count == 0)
          {
              return null;
          }

          var instance = Activator.CreateInstance(modelType);
          switch (requestData)
          {
              // Query string, headers and form fields.
              case IEnumerable<KeyValuePair<string, StringValues>> stringValues:
              {
                  foreach (var item in stringValues)
                  {
                      var prop = FindWritableProperty(modelType, item.Key);
                      if (prop != null)
                      {
                          prop.SetValue(instance, item.Value.ConvertObject(prop.PropertyType));
                      }
                  }

                  break;
              }

              // Cookies.
              case IEnumerable<KeyValuePair<string, string>> strs:
              {
                  foreach (var item in strs)
                  {
                      var prop = FindWritableProperty(modelType, item.Key);
                      if (prop != null)
                      {
                          prop.SetValue(instance, item.Value.ConvertObject(prop.PropertyType));
                      }
                  }

                  break;
              }

              // Route values.
              case IEnumerable<KeyValuePair<string, object>> objects:
              {
                  foreach (var item in objects)
                  {
                      var prop = FindWritableProperty(modelType, item.Key);
                      if (prop != null)
                      {
                          prop.SetValue(instance, item.Value.ConvertObject(prop.PropertyType));
                      }
                  }

                  break;
              }
          }

          return instance;
      }

      /// <summary>
      /// Picks the request source whose keys cover the most writable properties of <paramref name="type"/>.
      /// Ties go to the earlier source in the order: query, headers, cookies, route values, form.
      /// </summary>
      /// <returns>The chosen source and the names of the properties it supplies (empty when none match).</returns>
      private static (IEnumerable data, List<string> keys) GetRequestData(ModelBindingContext bindingContext, Type type)
      {
          var request = bindingContext.HttpContext.Request;
          var routeValues = bindingContext.ActionContext.RouteData.Values;
          var props = GetWritableProperties(type).Select(p => p.Name).ToList();

          // Property names a source does NOT supply; the source with the fewest misses wins.
          List<string> Missing(IEnumerable<string> sourceKeys) => props.Except(sourceKeys, StringComparer.OrdinalIgnoreCase).ToList();

          var candidates = new List<KeyValuePair<List<string>, IEnumerable>>
          {
              new(Missing(request.Query.Keys), request.Query),
              new(Missing(request.Headers.Keys), request.Headers),
              new(Missing(request.Cookies.Keys), request.Cookies),
              new(Missing(routeValues.Keys), routeValues),
          };
          if (request.HasFormContentType && request.Form.Count > 0)
          {
              candidates.Add(new(Missing(request.Form.Keys), request.Form));
          }

          // OrderBy is a stable sort, so ties keep the list order above.
          var best = candidates.OrderBy(t => t.Key.Count).First();
          return (best.Value, props.Except(best.Key).ToList());
      }

      /// <summary>
      /// True when <c>Activator.CreateInstance(type)</c> can succeed: value types, or non-abstract types with a
      /// public parameterless constructor.
      /// </summary>
      private static bool CanCreateInstance(Type type) =>
          type.IsValueType || (!type.IsAbstract && type.GetConstructor(Type.EmptyTypes) != null);

      /// <summary>
      /// Public instance properties of <paramref name="type"/> that <c>PropertyInfo.SetValue(obj, value)</c> can assign.
      /// </summary>
      private static IEnumerable<PropertyInfo> GetWritableProperties(Type type) =>
          type.GetProperties(BindingFlags.Public | BindingFlags.Instance).Where(IsWritable);

      /// <summary>
      /// Case-insensitive lookup of a writable, non-indexer public instance property; null when there is none.
      /// </summary>
      private static PropertyInfo FindWritableProperty(Type type, string name)
      {
          var prop = type.GetProperty(name, PropertyLookupFlags);
          return prop != null && IsWritable(prop) ? prop : null;
      }

      /// <summary>True when the property has a setter and is not an indexer.</summary>
      private static bool IsWritable(PropertyInfo prop) => prop.CanWrite && prop.GetIndexParameters().Length == 0;

      /// <summary>
      /// Media type of the request's Content-Type header (parameters such as charset excluded), or an empty string
      /// when the header is missing or cannot be parsed (parse failures are logged).
      /// </summary>
      private string GetMediaType(ModelBindingContext bindingContext)
      {
          var context = bindingContext.HttpContext;
          var contentType = context.Request.ContentType;
          if (string.IsNullOrWhiteSpace(contentType))
          {
              return string.Empty;
          }

          try
          {
              return new ContentType(contentType).MediaType;
          }
          catch (Exception ex)
          {
              _logger.LogError(ex, "TraceIdentifier:{TraceIdentifier},AutoBinderMiddleware无法解析Content-Type:{ContentType}", context.TraceIdentifier, contentType);
              return string.Empty;
          }
      }

      /// <summary>
      /// Reads <paramref name="fieldName"/> from a single request source and converts it to <paramref name="modelType"/>.
      /// </summary>
      /// <param name="bindingContext">Binding context of the current request.</param>
      /// <param name="bindType">Request source to read from.</param>
      /// <param name="fieldName">Key to look up in that source.</param>
      /// <param name="modelType">Type to convert the raw value to.</param>
      /// <returns>The converted value, or null when the key was not found or the source is not supported.</returns>
      private object GetBindingValue(ModelBindingContext bindingContext, BindType bindType, string fieldName, Type modelType)
      {
          var request = bindingContext.HttpContext.Request;
          switch (bindType)
          {
              case BindType.Body:
              {
                  // Only JSON bodies exposed by BodyOrDefaultModelBinderMiddleware are supported; XML is not implemented.
                  // Content-Type is parsed only here, so probing the other sources never pays for (or logs) it.
                  if (!string.Equals(GetMediaType(bindingContext), "application/json", StringComparison.OrdinalIgnoreCase))
                  {
                      return null;
                  }

                  var jsonObj = BodyOrDefaultModelBinderMiddleware.JsonObject;
                  if (jsonObj != null && jsonObj.TryGetValue(fieldName, StringComparison.OrdinalIgnoreCase, out var token))
                  {
                      return token.ConvertObject(modelType);
                  }

                  return null;
              }
              case BindType.Query:
              {
                  if (request.Query != null && request.Query.Count > 0 && request.Query.TryGetValue(fieldName, out var queryValues))
                  {
                      return queryValues.ConvertObject(modelType);
                  }

                  return null;
              }
              case BindType.Form:
              {
                  if (request.HasFormContentType && request.Form.Count > 0 && request.Form.TryGetValue(fieldName, out var formValues))
                  {
                      return formValues.ConvertObject(modelType);
                  }

                  return null;
              }
              case BindType.Header:
              {
                  if (request.Headers != null && request.Headers.Count > 0 && request.Headers.TryGetValue(fieldName, out var headerValues))
                  {
                      return headerValues.ConvertObject(modelType);
                  }

                  return null;
              }
              case BindType.Cookie:
              {
                  if (request.Cookies != null && request.Cookies.Count > 0 && request.Cookies.TryGetValue(fieldName, out var cookieValue))
                  {
                      return cookieValue.ConvertObject(modelType);
                  }

                  return null;
              }
              case BindType.Route:
              {
                  var routeValues = bindingContext.ActionContext.RouteData.Values;
                  if (routeValues != null && routeValues.Count > 0 && routeValues.TryGetValue(fieldName, out var routeValue))
                  {
                      return routeValue.ConvertObject(modelType);
                  }

                  return null;
              }
              default:
                  return null;
          }
      }
  }