Bladeren bron

Nullkiller: rework prioritization, add hero roles, skills and other variables

Andrii Danylchenko 4 jaren geleden
bovenliggende
commit
5fe2630c64

+ 16 - 1
AI/Nullkiller/AIhelper.cpp

@@ -197,7 +197,22 @@ int AIhelper::selectBestSkill(const HeroPtr & hero, const std::vector<SecondaryS
 	return heroManager->selectBestSkill(hero, skills);
 }
 
-std::map<HeroPtr, HeroRole> AIhelper::getHeroRoles() const
+const std::map<HeroPtr, HeroRole> & AIhelper::getHeroRoles() const
 {
 	return heroManager->getHeroRoles();
+}
+
+HeroRole AIhelper::getHeroRole(const HeroPtr & hero) const
+{
+	return heroManager->getHeroRole(hero);
+}
+
+void AIhelper::updateHeroRoles()
+{
+	heroManager->updateHeroRoles();
+}
+
+float AIhelper::evaluateSecSkill(SecondarySkill skill, const CGHeroInstance * hero) const
+{
+	return heroManager->evaluateSecSkill(skill, hero);
 }

+ 4 - 1
AI/Nullkiller/AIhelper.h

@@ -80,8 +80,11 @@ public:
 	std::vector<SlotInfo>::iterator getWeakestCreature(std::vector<SlotInfo> & army) const override;
 	std::vector<SlotInfo> getSortedSlots(const CCreatureSet * target, const CCreatureSet * source) const override;
 
-	std::map<HeroPtr, HeroRole> getHeroRoles() const override;
+	const std::map<HeroPtr, HeroRole> & getHeroRoles() const override;
+	HeroRole getHeroRole(const HeroPtr & hero) const override;
 	int selectBestSkill(const HeroPtr & hero, const std::vector<SecondarySkill> & skills) const override;
+	void updateHeroRoles() override;
+	float evaluateSecSkill(SecondarySkill skill, const CGHeroInstance * hero) const override;
 
 private:
 	bool notifyGoalCompleted(Goals::TSubgoal goal) override;

+ 2 - 0
AI/Nullkiller/Behaviors/CaptureObjectsBehavior.cpp

@@ -38,6 +38,8 @@ Goals::TGoalVec CaptureObjectsBehavior::getTasks()
 			return;
 		}
 
+		logAi->trace("Scanning objects, count %d", objs.size());
+
 		for(auto objToVisit : objs)
 		{			
 #ifdef VCMI_TRACE_PATHFINDER

+ 4 - 0
AI/Nullkiller/Engine/Nullkiller.cpp

@@ -38,6 +38,8 @@ Goals::TSubgoal Nullkiller::choseBestTask(std::shared_ptr<Behavior> behavior) co
 		return Goals::sptr(Goals::Invalid());
 	}
 
+	logAi->trace("Evaluating priorities, tasks count %d", tasks.size());
+
 	for(auto task : tasks)
 	{
 		task->setpriority(priorityEvaluator->evaluate(task));
@@ -59,6 +61,7 @@ void Nullkiller::resetAiState()
 
 void Nullkiller::updateAiState()
 {
+	// TODO: move to hero manager
 	auto activeHeroes = ai->getMyHeroes();
 
 	vstd::erase_if(activeHeroes, [&](const HeroPtr & hero) -> bool{
@@ -66,6 +69,7 @@ void Nullkiller::updateAiState()
 	});
 
 	ai->ah->updatePaths(activeHeroes, true);
+	ai->ah->updateHeroRoles();
 }
 
 void Nullkiller::makeTurn()

+ 81 - 129
AI/Nullkiller/Engine/PriorityEvaluator.cpp

@@ -18,7 +18,9 @@
 #include "../../../lib/CGameStateFwd.h"
 #include "../../../lib/VCMI_Lib.h"
 #include "../../../CCallback.h"
+#include "../../../lib/filesystem/Filesystem.h"
 #include "../VCAI.h"
+#include "../AIhelper.h"
 
 #define MIN_AI_STRENGHT (0.5f) //lower when combat AI gets smarter
 #define UNGUARDED_OBJECT (100.0f) //we consider unguarded objects 100 times weaker than us
@@ -35,131 +37,28 @@ extern boost::thread_specific_ptr<VCAI> ai;
 PriorityEvaluator::PriorityEvaluator()
 {
 	initVisitTile();
-	configure();
 }
 
 PriorityEvaluator::~PriorityEvaluator()
 {
-	engine.removeRuleBlock(0);
-	engine.~Engine();
+	delete engine;
 }
 
 void PriorityEvaluator::initVisitTile()
 {
-	try
-	{
-		armyLossPersentageVariable = new fl::InputVariable("armyLoss");
-		armyStrengthVariable = new fl::InputVariable("armyStrength");
-		dangerVariable = new fl::InputVariable("danger");
-		turnDistanceVariable = new fl::InputVariable("turnDistance");
-		goldRewardVariable = new fl::InputVariable("goldReward");
-		armyRewardVariable = new fl::InputVariable("armyReward");
-
-		value = new fl::OutputVariable("Value");
-		value->setMinimum(0);
-		value->setMaximum(1);
-		value->setAggregation(new fl::AlgebraicSum());
-		value->setDefuzzifier(new fl::Centroid(100));
-		value->setDefaultValue(0.500);
-
-		rules.setConjunction(new fl::AlgebraicProduct());
-		rules.setDisjunction(new fl::AlgebraicSum());
-		rules.setImplication(new fl::AlgebraicProduct());
-		rules.setActivation(new fl::General());
-
-		std::vector<fl::InputVariable *> helper = {
-			armyLossPersentageVariable,
-			armyStrengthVariable,
-			turnDistanceVariable,
-			goldRewardVariable,
-			armyRewardVariable,
-			dangerVariable };
-
-		for(auto val : helper)
-		{
-			engine.addInputVariable(val);
-		}
-		engine.addOutputVariable(value);
-
-		armyLossPersentageVariable->addTerm(new fl::Ramp("LOW", 0.200, 0.000));
-		armyLossPersentageVariable->addTerm(new fl::Ramp("HIGH", 0.200, 0.500));
-		armyLossPersentageVariable->setRange(0, 1);
-
-		//strength compared to our main hero
-		armyStrengthVariable->addTerm(new fl::Ramp("LOW", 0.2, 0));
-		armyStrengthVariable->addTerm(new fl::Triangle("MEDIUM", 0.2, 0.8));
-		armyStrengthVariable->addTerm(new fl::Ramp("HIGH", 0.5, 1));
-		armyStrengthVariable->setRange(0.0, 1.0);
-
-		turnDistanceVariable->addTerm(new fl::Ramp("SMALL", 1.000, 0.000));
-		turnDistanceVariable->addTerm(new fl::Triangle("MEDIUM", 0.000, 1.000, 2.000));
-		turnDistanceVariable->addTerm(new fl::Ramp("LONG", 1.000, 3.000));
-		turnDistanceVariable->setRange(0, 3);
-
-		goldRewardVariable->addTerm(new fl::Ramp("LOW", 2000.000, 0.000));
-		goldRewardVariable->addTerm(new fl::Triangle("MEDIUM", 0.000, 2000.000, 3500.000));
-		goldRewardVariable->addTerm(new fl::Ramp("HIGH", 2000.000, 5000.000));
-		goldRewardVariable->setRange(0.0, 5000.0);
-
-		armyRewardVariable->addTerm(new fl::Ramp("LOW", 0.300, 0.000));
-		armyRewardVariable->addTerm(new fl::Triangle("MEDIUM", 0.100, 0.400, 0.800));
-		armyRewardVariable->addTerm(new fl::Ramp("HIGH", 0.400, 1.000));
-		armyRewardVariable->setRange(0.0, 1.0);
-
-		dangerVariable->addTerm(new fl::Ramp("NONE", 50, 0));
-		dangerVariable->addTerm(new fl::Ramp("HIGH", 50, 10000));
-
-		value->addTerm(new fl::Ramp("LOWEST", 0.150, 0.000));
-		value->addTerm(new fl::Triangle("LOW", 0.100, 0.100, 0.250, 0.500));
-		value->addTerm(new fl::Triangle("BITLOW", 0.200, 0.200, 0.350, 0.250));
-		value->addTerm(new fl::Triangle("MEDIUM", 0.300, 0.500, 0.700, 0.050));
-		value->addTerm(new fl::Triangle("BITHIGH", 0.650, 0.800, 0.800, 0.250));
-		value->addTerm(new fl::Triangle("HIGH", 0.750, 0.900, 0.900, 0.500));
-		value->addTerm(new fl::Ramp("HIGHEST", 0.850, 1.000));
-		value->setRange(0.0, 1.0);
-
-		//we may want to use secondary hero(es) rather than main hero
-
-		//do not cancel important goals
-		//addRule("if lockedMissionImportance is HIGH then Value is very LOW");
-		//addRule("if lockedMissionImportance is MEDIUM then Value is somewhat LOW");
-		//addRule("if lockedMissionImportance is LOW then Value is HIGH");
-
-		//pick nearby objects if it's easy, avoid long walks
-		/*addRule("if turnDistance is SMALL then Value is somewhat HIGH");
-		addRule("if turnDistance is MEDIUM then Value is MEDIUM");
-		addRule("if turnDistance is LONG then Value is LOW");*/
-
-		//some goals are more rewarding by definition f.e. capturing town is more important than collecting resource - experimental
-		addRule("if turnDistance is MEDIUM then Value is MEDIUM");
-		addRule("if turnDistance is SMALL then Value is BITHIGH");
-		addRule("if turnDistance is LONG then Value is BITLOW");
-		addRule("if turnDistance is very LONG then Value is LOW");
-		addRule("if goldReward is LOW then Value is MEDIUM");
-		addRule("if goldReward is MEDIUM and armyLoss is LOW then Value is BITHIGH");
-		addRule("if goldReward is HIGH and armyLoss is LOW then Value is HIGH");
-		addRule("if armyReward is LOW then Value is MEDIUM");
-		addRule("if armyReward is MEDIUM and armyLoss is LOW then Value is BITHIGH");
-		addRule("if armyReward is HIGH and armyLoss is LOW then Value is HIGHEST with 0.5");
-		addRule("if armyReward is HIGH and goldReward is HIGH and armyLoss is LOW then Value is HIGHEST");
-		addRule("if armyReward is HIGH and goldReward is MEDIUM and armyLoss is LOW then Value is HIGHEST with 0.8");
-		addRule("if armyReward is MEDIUM and goldReward is HIGH and armyLoss is LOW then Value is HIGHEST with 0.5");
-		addRule("if armyReward is MEDIUM and goldReward is MEDIUM and armyLoss is LOW then Value is HIGH");
-		addRule("if armyReward is HIGH and turnDistance is SMALL and armyLoss is LOW then Value is HIGHEST");
-		addRule("if goldReward is HIGH and turnDistance is SMALL and armyLoss is LOW then Value is HIGHEST");
-		addRule("if armyReward is MEDIUM and turnDistance is SMALL and armyLoss is LOW then Value is HIGH");
-		addRule("if goldReward is MEDIUM and turnDistance is SMALL and armyLoss is LOW then Value is BITHIGH");
-		addRule("if goldReward is LOW and armyReward is LOW and turnDistance is not SMALL then Value is LOWEST");
-		addRule("if armyLoss is HIGH then Value is LOWEST");
-		addRule("if armyLoss is LOW then Value is MEDIUM");
-		addRule("if armyReward is LOW and turnDistance is LONG then Value is LOWEST");
-		addRule("if danger is NONE and armyStrength is HIGH and armyReward is LOW then Value is LOW");
-		addRule("if danger is NONE and armyStrength is HIGH and armyReward is MEDIUM then Value is BITLOW");
-	}
-	catch(fl::Exception & fe)
-	{
-		logAi->error("visitTile: %s", fe.getWhat());
-	}
+	auto file = CResourceHandler::get("initial")->load(ResourceID("config/ai-priorities.txt"))->readAll();
+	std::string str = std::string((char *)file.first.get(), file.second);
+	engine = fl::FllImporter().fromString(str);
+	armyLossPersentageVariable = engine->getInputVariable("armyLoss");
+	heroRoleVariable = engine->getInputVariable("heroRole");
+	dangerVariable = engine->getInputVariable("danger");
+	turnDistanceVariable = engine->getInputVariable("turnDistance");
+	goldRewardVariable = engine->getInputVariable("goldReward");
+	armyRewardVariable = engine->getInputVariable("armyReward");
+	skillRewardVariable = engine->getInputVariable("skillReward");
+	rewardTypeVariable = engine->getInputVariable("rewardType");
+	closestHeroRatioVariable = engine->getInputVariable("closestHeroRatio");
+	value = engine->getOutputVariable("Value");
 }
 
 int32_t estimateTownIncome(const CGObjectInstance * target, const CGHeroInstance * hero)
@@ -225,8 +124,6 @@ uint64_t getArmyReward(const CGObjectInstance * target, const CGHeroInstance * h
 	if(!target)
 		return 0;
 
-	const int dailyIncomeMultiplier = 5;
-
 	switch(target->ID)
 	{
 	case Obj::TOWN:
@@ -234,7 +131,7 @@ uint64_t getArmyReward(const CGObjectInstance * target, const CGHeroInstance * h
 	case Obj::CREATURE_BANK:
 		return getCreatureBankArmyReward(target, hero);
 	case Obj::CREATURE_GENERATOR1:
-		return getDwellingScore(target) * dailyIncomeMultiplier;
+		return getDwellingScore(target);
 	case Obj::CRYPT:
 	case Obj::SHIPWRECK:
 	case Obj::SHIPWRECK_SURVIVOR:
@@ -248,6 +145,51 @@ uint64_t getArmyReward(const CGObjectInstance * target, const CGHeroInstance * h
 	}
 }
 
+float evaluateWitchHutSkillScore(const CGWitchHut * hut, const CGHeroInstance * hero, HeroRole role)
+{
+	if(!hut->wasVisited(hero->tempOwner))
+		return role == HeroRole::SCOUT ? 2 : 0;
+
+	auto skill = SecondarySkill(hut->ability);
+
+	if(hero->getSecSkillLevel(skill) != SecSkillLevel::NONE
+		|| hero->secSkills.size() >= GameConstants::SKILL_PER_HERO)
+		return 0;
+
+	auto score = ai->ah->evaluateSecSkill(skill, hero);
+
+	return score >= 2 ? (role == HeroRole::MAIN ? 10 : 4) : score;
+}
+
+float getSkillReward(const CGObjectInstance * target, const CGHeroInstance * hero, HeroRole role)
+{
+	if(!target)
+		return 0;
+
+	switch(target->ID)
+	{
+	case Obj::STAR_AXIS:
+	case Obj::SCHOLAR:
+	case Obj::SCHOOL_OF_MAGIC:
+	case Obj::SCHOOL_OF_WAR:
+	case Obj::GARDEN_OF_REVELATION:
+	case Obj::MARLETTO_TOWER:
+	case Obj::MERCENARY_CAMP:
+	case Obj::SHRINE_OF_MAGIC_GESTURE:
+		return 1;
+	case Obj::ARENA:
+	case Obj::SHRINE_OF_MAGIC_INCANTATION:
+	case Obj::SHRINE_OF_MAGIC_THOUGHT:
+		return 2;
+	case Obj::LIBRARY_OF_ENLIGHTENMENT:
+		return 8;
+	case Obj::WITCH_HUT:
+		return evaluateWitchHutSkillScore(dynamic_cast<const CGWitchHut *>(target), hero, role);
+	default:
+		return 0;
+	}
+}
+
 /// Gets aproximated reward in gold. Daily income is multiplied by 5
 int32_t getGoldReward(const CGObjectInstance * target, const CGHeroInstance * hero)
 {
@@ -260,7 +202,7 @@ int32_t getGoldReward(const CGObjectInstance * target, const CGHeroInstance * he
 	switch(target->ID)
 	{
 	case Obj::RESOURCE:
-		return isGold ? 800 : 100;
+		return isGold ? 600 : 100;
 	case Obj::TREASURE_CHEST:
 		return 1500;
 	case Obj::WATER_WHEEL:
@@ -274,7 +216,7 @@ int32_t getGoldReward(const CGObjectInstance * target, const CGHeroInstance * he
 	case Obj::WINDMILL:
 		return 100;
 	case Obj::CAMPFIRE:
-		return 900;
+		return 800;
 	case Obj::CREATURE_BANK:
 		return getCreatureBankResources(target, hero)[Res::GOLD];
 	case Obj::CRYPT:
@@ -310,27 +252,36 @@ float PriorityEvaluator::evaluate(Goals::TSubgoal task)
 
 	const CGObjectInstance * target = cb->getObj((ObjectInstanceID)objId, false);
 	
+	auto day = cb->getDate(Date::DAY);
 	auto hero = heroPtr.get();
 	auto armyTotal = task->evaluationContext.heroStrength;
 	double armyLossPersentage = task->evaluationContext.armyLoss / (double)armyTotal;
 	int32_t goldReward = getGoldReward(target, hero);
 	uint64_t armyReward = getArmyReward(target, hero);
+	HeroRole heroRole = ai->ah->getHeroRole(heroPtr);
+	float skillReward = getSkillReward(target, hero, heroRole);
 	uint64_t danger = task->evaluationContext.danger;
-	float armyStrength = (fl::scalar)hero->getArmyStrength() / ai->primaryHero()->getArmyStrength();
 	double result = 0;
+	int rewardType = (goldReward > 0 ? 1 : 0) + (armyReward > 0 ? 1 : 0) + (skillReward > 0 ? 1 : 0);
+
+	if(day == 1)
+		goldReward *= 2;
 
 	try
 	{
 		armyLossPersentageVariable->setValue(armyLossPersentage);
-		armyStrengthVariable->setValue(armyStrength);
+		heroRoleVariable->setValue(heroRole);
 		turnDistanceVariable->setValue(task->evaluationContext.movementCost);
 		goldRewardVariable->setValue(goldReward);
-		armyRewardVariable->setValue(armyReward / 10000.0);
+		armyRewardVariable->setValue(armyReward);
+		skillRewardVariable->setValue(skillReward);
 		dangerVariable->setValue(danger);
+		rewardTypeVariable->setValue(rewardType);
+		closestHeroRatioVariable->setValue(task->evaluationContext.closestWayRatio);
 
-		engine.process();
+		engine->process();
 		//engine.process(VISIT_TILE); //TODO: Process only Visit_Tile
-		result = value->getValue() / task->evaluationContext.closestWayRatio;
+		result = value->getValue();
 	}
 	catch(fl::Exception & fe)
 	{
@@ -339,14 +290,15 @@ float PriorityEvaluator::evaluate(Goals::TSubgoal task)
 	assert(result >= 0);
 
 #ifdef VCMI_TRACE_PATHFINDER
-	logAi->trace("Evaluated %s, loss: %f, turns: %f, gold: %d, army gain: %d, danger: %d, army strength: %f%%, result %f",
+	logAi->trace("Evaluated %s, hero %s, loss: %f, turns: %f, gold: %d, army gain: %d, danger: %d, role: %s, result %f",
 		task->name(),
+		hero->name,
 		armyLossPersentage,
 		task->evaluationContext.movementCost,
 		goldReward,
 		armyReward,
 		danger,
-		(int)(armyStrength * 100),
+		heroRole ? "scout" : "main",
 		result);
 #endif
 

+ 6 - 2
AI/Nullkiller/Engine/PriorityEvaluator.h

@@ -17,7 +17,7 @@ class CArmedInstance;
 class CBank;
 struct SectorMap;
 
-class PriorityEvaluator : public engineBase
+class PriorityEvaluator
 {
 public:
 	PriorityEvaluator();
@@ -27,11 +27,15 @@ public:
 	float evaluate(Goals::TSubgoal task);
 
 private:
+	fl::Engine * engine;
 	fl::InputVariable * armyLossPersentageVariable;
-	fl::InputVariable * armyStrengthVariable;
+	fl::InputVariable * heroRoleVariable;
 	fl::InputVariable * turnDistanceVariable;
 	fl::InputVariable * goldRewardVariable;
 	fl::InputVariable * armyRewardVariable;
 	fl::InputVariable * dangerVariable;
+	fl::InputVariable * skillRewardVariable;
+	fl::InputVariable * rewardTypeVariable;
+	fl::InputVariable * closestHeroRatioVariable;
 	fl::OutputVariable * value;
 };

+ 5 - 2
AI/Nullkiller/Goals/ExecuteHeroChain.cpp

@@ -78,9 +78,12 @@ void ExecuteHeroChain::accept(VCAI * ai)
 
 		try
 		{
-			ai->nullkiller->setActive(hero);
+			if(hero->movement)
+			{
+				ai->nullkiller->setActive(hero);
 
-			Goals::VisitTile(node.coord).sethero(hero).accept(ai);
+				Goals::VisitTile(node.coord).sethero(hero).accept(ai);
+			}
 
 			// no exception means we were not able to rich the tile
 			ai->nullkiller->lockHero(hero);

+ 25 - 8
AI/Nullkiller/HeroManager.cpp

@@ -67,6 +67,16 @@ void HeroManager::setAI(VCAI * AI)
 	ai = AI;
 }
 
+float HeroManager::evaluateSecSkill(SecondarySkill skill, const CGHeroInstance * hero) const
+{
+	auto role = getHeroRole(hero);
+
+	if(role == HeroRole::MAIN)
+		return wariorSkillsScores.evaluateSecSkill(hero, skill);
+
+	return scountSkillsScores.evaluateSecSkill(hero, skill);
+}
+
 float HeroManager::evaluateSpeciality(const CGHeroInstance * hero) const
 {
 	auto heroSpecial = Selector::source(Bonus::HERO_SPECIAL, hero->type->ID.getNum());
@@ -91,10 +101,9 @@ float HeroManager::evaluateFightingStrength(const CGHeroInstance * hero) const
 	return evaluateSpeciality(hero) + wariorSkillsScores.evaluateSecSkills(hero) + hero->level * 1.5f;
 }
 
-std::map<HeroPtr, HeroRole> HeroManager::getHeroRoles() const
+void HeroManager::updateHeroRoles()
 {
 	std::map<HeroPtr, float> scores;
-	std::map<HeroPtr, HeroRole> result;
 	auto myHeroes = ai->getMyHeroes();
 
 	for(auto & hero : myHeroes)
@@ -104,23 +113,31 @@ std::map<HeroPtr, HeroRole> HeroManager::getHeroRoles() const
 
 	std::sort(myHeroes.begin(), myHeroes.end(), [&](const HeroPtr & h1, const HeroPtr & h2) -> bool
 	{
-		return scores.at(h1) < scores.at(h2);
+		return scores.at(h1) > scores.at(h2);
 	});
 
-	int mainHeroCount = 4;
+	int mainHeroCount = (myHeroes.size() + 2) / 3;
 
 	for(auto & hero : myHeroes)
 	{
-		result[hero] = (mainHeroCount--) > 0 ? HeroRole::MAIN : HeroRole::SCOUT;
+		heroRoles[hero] = (mainHeroCount--) > 0 ? HeroRole::MAIN : HeroRole::SCOUT;
+		logAi->trace("Hero %s has role %s", hero.name, heroRoles[hero] == HeroRole::MAIN ? "main" : "scout");
 	}
+}
 
-	return result;
+HeroRole HeroManager::getHeroRole(const HeroPtr & hero) const
+{
+	return heroRoles.at(hero);
+}
+
+const std::map<HeroPtr, HeroRole> & HeroManager::getHeroRoles() const
+{
+	return heroRoles;
 }
 
 int HeroManager::selectBestSkill(const HeroPtr & hero, const std::vector<SecondarySkill> & skills) const
 {
-	auto roles = getHeroRoles();
-	auto role = roles[hero];
+	auto role = getHeroRole(hero);
 	auto & evaluator = role == HeroRole::MAIN ? wariorSkillsScores : scountSkillsScores;
 
 	int result = 0;

+ 9 - 2
AI/Nullkiller/HeroManager.h

@@ -30,8 +30,11 @@ class DLL_EXPORT IHeroManager //: public: IAbstractManager
 public:
 	virtual void init(CPlayerSpecificInfoCallback * CB) = 0;
 	virtual void setAI(VCAI * AI) = 0;
-	virtual std::map<HeroPtr, HeroRole> getHeroRoles() const = 0;
+	virtual const std::map<HeroPtr, HeroRole> & getHeroRoles() const = 0;
 	virtual int selectBestSkill(const HeroPtr & hero, const std::vector<SecondarySkill> & skills) const = 0;
+	virtual HeroRole getHeroRole(const HeroPtr & hero) const = 0;
+	virtual void updateHeroRoles() = 0;
+	virtual float evaluateSecSkill(SecondarySkill skill, const CGHeroInstance * hero) const = 0;
 };
 
 class DLL_EXPORT ISecondarySkillRule
@@ -59,12 +62,16 @@ private:
 
 	CPlayerSpecificInfoCallback * cb; //this is enough, but we downcast from CCallback
 	VCAI * ai;
+	std::map<HeroPtr, HeroRole> heroRoles;
 
 public:
 	void init(CPlayerSpecificInfoCallback * CB) override;
 	void setAI(VCAI * AI) override;
-	std::map<HeroPtr, HeroRole> getHeroRoles() const override;
+	const std::map<HeroPtr, HeroRole> & getHeroRoles() const override;
+	HeroRole getHeroRole(const HeroPtr & hero) const override;
 	int selectBestSkill(const HeroPtr & hero, const std::vector<SecondarySkill> & skills) const override;
+	void updateHeroRoles() override;
+	float evaluateSecSkill(SecondarySkill skill, const CGHeroInstance * hero) const override;
 
 private:
 	float evaluateFightingStrength(const CGHeroInstance * hero) const;

+ 4 - 1
AI/Nullkiller/Pathfinding/AINodeStorage.cpp

@@ -330,7 +330,10 @@ void AINodeStorage::calculateHeroChain(
 	AIPathNode * other, 
 	std::vector<ExchangeCandidate> & result) const
 {	
-	if(carrier->actor->canExchange(other->actor))
+	if(carrier->armyLoss < carrier->actor->armyValue
+		&& (carrier->action != CGPathNode::BATTLE || carrier->actor->allowBattle && carrier->specialAction)
+		&& other->armyLoss < other->actor->armyValue
+		&& carrier->actor->canExchange(other->actor))
 	{
 #if VCMI_TRACE_PATHFINDER >= 2
 		logAi->trace(