using System;
using System.Collections;
using System.Collections.Generic;
namespace Masuit.Tools.RandomSelector
{
/// <summary>
/// Weighted selector: holds items that carry a positive integer weight and picks one or
/// several of them through <see cref="SingleSelector{T}"/> / <see cref="MultipleSelector{T}"/>,
/// which consume the cumulative weights maintained here.
/// </summary>
/// <typeparam name="T">Type of the value carried by each weighted item.</typeparam>
public class WeightedSelector<T> : IEnumerable<WeightedItem<T>>
{
    /// <summary>
    /// The items that can be selected, in insertion order.
    /// </summary>
    internal readonly List<WeightedItem<T>> Items = new List<WeightedItem<T>>();

    /// <summary>
    /// Selector options; never null (a default <see cref="SelectorOption"/> is used when none is given).
    /// </summary>
    internal readonly SelectorOption Option;

    /// <summary>
    /// Cumulative (running-total) weights of <see cref="Items"/>. Rebuilt lazily before each
    /// selection when the item set has changed; null until the first rebuild.
    /// </summary>
    internal int[] CumulativeWeights;

    /// <summary>
    /// True when <see cref="Items"/> has changed since <see cref="CumulativeWeights"/> was last computed.
    /// </summary>
    private bool _cumulativeWeightsStale;

    /// <summary>
    /// Creates an empty selector.
    /// </summary>
    /// <param name="option">Selector options; a default instance is used when null.</param>
    public WeightedSelector(SelectorOption option = null)
    {
        Option = option ?? new SelectorOption();
    }

    /// <summary>
    /// Creates a selector pre-populated with <paramref name="items"/>.
    /// </summary>
    /// <param name="items">Initial items; each goes through <see cref="Add(WeightedItem{T})"/>.</param>
    /// <param name="option">Selector options; a default instance is used when null.</param>
    public WeightedSelector(List<WeightedItem<T>> items, SelectorOption option = null) : this(option)
    {
        Add(items);
    }

    /// <summary>
    /// Creates a selector pre-populated with <paramref name="items"/>.
    /// </summary>
    /// <param name="items">Initial items; each goes through <see cref="Add(WeightedItem{T})"/>.</param>
    /// <param name="option">Selector options; a default instance is used when null.</param>
    public WeightedSelector(IEnumerable<WeightedItem<T>> items, SelectorOption option = null) : this(option)
    {
        Add(items);
    }

    /// <summary>
    /// Adds a single item.
    /// </summary>
    /// <param name="item">The item to add.</param>
    /// <exception cref="InvalidOperationException">
    /// The item's weight is &lt;= 0 and <c>Option.RemoveZeroWeightItems</c> is false.
    /// When that option is true, such items are silently skipped instead.
    /// </exception>
    public void Add(WeightedItem<T> item)
    {
        if (item.Weight <= 0)
        {
            if (Option.RemoveZeroWeightItems)
            {
                return;
            }

            throw new InvalidOperationException("权重值不能为0");
        }

        _cumulativeWeightsStale = true;
        Items.Add(item);
    }

    /// <summary>
    /// Adds a batch of items, applying the same rules as <see cref="Add(WeightedItem{T})"/> to each.
    /// </summary>
    /// <param name="items">The items to add.</param>
    /// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
    public void Add(IEnumerable<WeightedItem<T>> items)
    {
        if (items == null)
        {
            throw new ArgumentNullException(nameof(items));
        }

        foreach (var item in items)
        {
            Add(item);
        }
    }

    /// <summary>
    /// Adds a value with the given weight.
    /// </summary>
    /// <param name="item">The value to add.</param>
    /// <param name="weight">Its weight; must be &gt; 0 unless zero-weight items are being skipped.</param>
    public void Add(T item, int weight)
    {
        Add(new WeightedItem<T>(item, weight));
    }

    /// <summary>
    /// Removes an item, if present.
    /// </summary>
    /// <param name="item">The item to remove.</param>
    public void Remove(WeightedItem<T> item)
    {
        // Only invalidate the cached cumulative weights when the item set actually changed.
        if (Items.Remove(item))
        {
            _cumulativeWeightsStale = true;
        }
    }

    /// <summary>
    /// Runs a weighted selection and returns one element.
    /// </summary>
    public T Select()
    {
        CalculateCumulativeWeights();
        var selector = new SingleSelector<T>(this);
        return selector.Select();
    }

    /// <summary>
    /// Runs a weighted selection and returns <paramref name="count"/> elements.
    /// </summary>
    /// <param name="count">Number of elements to select.</param>
    public List<T> SelectMultiple(int count)
    {
        CalculateCumulativeWeights();
        var selector = new MultipleSelector<T>(this);
        return selector.Select(count);
    }

    /// <summary>
    /// Recomputes <see cref="CumulativeWeights"/> if the item set changed since the last computation.
    /// </summary>
    private void CalculateCumulativeWeights()
    {
        if (!_cumulativeWeightsStale)
        {
            return;
        }

        _cumulativeWeightsStale = false;
        CumulativeWeights = GetCumulativeWeights(Items);
    }

    /// <summary>
    /// Computes the running totals of the items' weights.
    /// </summary>
    /// <param name="items">The weighted items.</param>
    /// <returns>
    /// An array of length <c>items.Count + 1</c> where slot <c>i</c> holds the sum of the weights of
    /// items <c>0..i</c>; the trailing slot is never written and stays 0.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
    /// <exception cref="OverflowException">The total weight exceeds <see cref="int.MaxValue"/>.</exception>
    public static int[] GetCumulativeWeights(List<WeightedItem<T>> items)
    {
        if (items == null)
        {
            throw new ArgumentNullException(nameof(items));
        }

        var results = new int[items.Count + 1];
        var totalWeight = 0;
        for (var i = 0; i < items.Count; i++)
        {
            // checked: an unchecked wrap-around would yield negative / non-monotonic totals
            // and silently corrupt every subsequent selection.
            totalWeight = checked(totalWeight + items[i].Weight);
            results[i] = totalWeight;
        }

        return results;
    }

    /// <summary>
    /// Enumerates the contained items in insertion order.
    /// </summary>
    public IEnumerator<WeightedItem<T>> GetEnumerator()
    {
        // List<T>.Enumerator implements IEnumerator<T>; the implicit conversion replaces a null-unchecked 'as' cast.
        return Items.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}
}