WeightedSelector.cs 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. namespace Masuit.Tools.RandomSelector
  5. {
  /// <summary>
  /// Weighted random selector: holds a list of weighted items and delegates weighted draws
  /// to <see cref="SingleSelector{T}"/> / <see cref="MultipleSelector{T}"/>.
  /// </summary>
  /// <typeparam name="T">Type of the element wrapped by each <see cref="WeightedItem{T}"/>.</typeparam>
  public class WeightedSelector<T> : IEnumerable<T>
  {
      /// <summary>
      /// The weighted items currently held by the selector, in insertion order.
      /// </summary>
      internal readonly List<WeightedItem<T>> Items = new List<WeightedItem<T>>();

      /// <summary>
      /// Selector options (e.g. whether non-positive weights are silently dropped by <see cref="Add(WeightedItem{T})"/>).
      /// </summary>
      internal readonly SelectorOption Option;

      /// <summary>
      /// Cumulative weights of <see cref="Items"/>. Recomputed lazily by the Select methods
      /// only when the item list has changed since the last computation.
      /// </summary>
      internal int[] CumulativeWeights;

      /// <summary>
      /// Dirty flag: true when <see cref="Items"/> was modified through Add/Remove and
      /// <see cref="CumulativeWeights"/> must be rebuilt before the next selection.
      /// </summary>
      private bool _isAddedCumulativeWeights;

      /// <summary>
      /// Creates an empty selector.
      /// </summary>
      /// <param name="option">Selector options; a default <see cref="SelectorOption"/> is used when null.</param>
      public WeightedSelector(SelectorOption option = null)
      {
          Option = option ?? new SelectorOption();
      }

      /// <summary>
      /// Creates a selector pre-populated with <paramref name="items"/>.
      /// </summary>
      /// <param name="items">Initial weighted items; each is passed through <see cref="Add(WeightedItem{T})"/>.</param>
      /// <param name="option">Selector options; a default <see cref="SelectorOption"/> is used when null.</param>
      /// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
      public WeightedSelector(List<WeightedItem<T>> items, SelectorOption option = null) : this(option)
      {
          Add(items);
      }

      /// <summary>
      /// Creates a selector pre-populated with <paramref name="items"/>.
      /// </summary>
      /// <param name="items">Initial weighted items; each is passed through <see cref="Add(WeightedItem{T})"/>.</param>
      /// <param name="option">Selector options; a default <see cref="SelectorOption"/> is used when null.</param>
      /// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
      public WeightedSelector(IEnumerable<WeightedItem<T>> items, SelectorOption option = null) : this(option)
      {
          Add(items);
      }

      /// <summary>
      /// Adds a single weighted item.
      /// </summary>
      /// <param name="item">The item to add.</param>
      /// <exception cref="InvalidOperationException">
      /// The item's weight is zero or negative and <c>Option.RemoveZeroWeightItems</c> is false.
      /// When that option is true such items are silently ignored instead.
      /// </exception>
      public void Add(WeightedItem<T> item)
      {
          if (item.Weight <= 0)
          {
              if (Option.RemoveZeroWeightItems)
              {
                  return;
              }

              throw new InvalidOperationException("权重值不能为0");
          }

          _isAddedCumulativeWeights = true;
          Items.Add(item);
      }

      /// <summary>
      /// Adds several weighted items, one at a time, via <see cref="Add(WeightedItem{T})"/>.
      /// </summary>
      /// <param name="items">The items to add.</param>
      /// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
      public void Add(IEnumerable<WeightedItem<T>> items)
      {
          if (items == null)
          {
              throw new ArgumentNullException(nameof(items));
          }

          foreach (var item in items)
          {
              Add(item);
          }
      }

      /// <summary>
      /// Wraps <paramref name="item"/> with <paramref name="weight"/> and adds it.
      /// </summary>
      /// <param name="item">The element to add.</param>
      /// <param name="weight">Its weight; non-positive weights are handled as in <see cref="Add(WeightedItem{T})"/>.</param>
      public void Add(T item, int weight)
      {
          Add(new WeightedItem<T>(item, weight));
      }

      /// <summary>
      /// Removes an item (via <see cref="List{T}.Remove"/>) and marks the cumulative weights as stale.
      /// </summary>
      /// <param name="item">The item to remove.</param>
      public void Remove(WeightedItem<T> item)
      {
          _isAddedCumulativeWeights = true;
          Items.Remove(item);
      }

      /// <summary>
      /// Performs a single weighted draw: refreshes the cumulative weights if needed,
      /// then delegates to <see cref="SingleSelector{T}"/>.
      /// </summary>
      /// <returns>The element returned by the single selector.</returns>
      public T Select()
      {
          CalculateCumulativeWeights();
          var selector = new SingleSelector<T>(this);
          return selector.Select();
      }

      /// <summary>
      /// Performs a multi-element weighted draw: refreshes the cumulative weights if needed,
      /// then delegates to <see cref="MultipleSelector{T}"/>.
      /// </summary>
      /// <param name="count">Number of elements requested.</param>
      /// <returns>The elements returned by the multiple selector.</returns>
      public List<T> SelectMultiple(int count)
      {
          CalculateCumulativeWeights();
          var selector = new MultipleSelector<T>(this);
          return selector.Select(count);
      }

      /// <summary>
      /// Rebuilds <see cref="CumulativeWeights"/> from <see cref="Items"/> if Add/Remove
      /// changed the list since the last rebuild; otherwise does nothing.
      /// </summary>
      private void CalculateCumulativeWeights()
      {
          // Nothing was added or removed since the last computation: the cached array is still valid.
          if (!_isAddedCumulativeWeights)
          {
              return;
          }

          _isAddedCumulativeWeights = false;
          CumulativeWeights = GetCumulativeWeights(Items);
      }

      /// <summary>
      /// Computes running totals of the items' weights: entry <c>i</c> is the sum of the
      /// weights of <c>items[0..i]</c>.
      /// </summary>
      /// <param name="items">The weighted items, in selection order.</param>
      /// <returns>
      /// An array of length <c>items.Count + 1</c>; the final slot is never written and stays 0.
      /// NOTE(review): confirm whether the selectors rely on that extra slot before shrinking the array.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
      /// <exception cref="OverflowException">The total weight exceeds <see cref="int.MaxValue"/>.</exception>
      public static int[] GetCumulativeWeights(List<WeightedItem<T>> items)
      {
          if (items == null)
          {
              throw new ArgumentNullException(nameof(items));
          }

          var results = new int[items.Count + 1];
          var totalWeight = 0;
          for (var index = 0; index < items.Count; index++)
          {
              // checked: a silently wrapped running total would corrupt every later cumulative weight.
              totalWeight = checked(totalWeight + items[index].Weight);
              results[index] = totalWeight;
          }

          return results;
      }

      /// <summary>
      /// Enumerates the wrapped elements of all items, in insertion order.
      /// </summary>
      /// <returns>An enumerator over the elements (not the <see cref="WeightedItem{T}"/> wrappers).</returns>
      public IEnumerator<T> GetEnumerator()
      {
          // The previous `Items.GetEnumerator() as IEnumerator<T>` always produced null, because a
          // List<WeightedItem<T>> enumerator is never an IEnumerator<T>; unwrap each item instead.
          // NOTE(review): assumes WeightedItem<T> exposes its wrapped element as `Value` — confirm against WeightedItem.cs.
          foreach (var item in Items)
          {
              yield return item.Value;
          }
      }

      /// <summary>
      /// Non-generic enumeration; forwards to <see cref="GetEnumerator"/>.
      /// </summary>
      IEnumerator IEnumerable.GetEnumerator()
      {
          return GetEnumerator();
      }
  }
  138. }