package model import ( "database/sql/driver" "encoding/json" "errors" "fmt" "math/rand" "one-api/common" "one-api/constant" "one-api/dto" "one-api/types" "strings" "sync" "gorm.io/gorm" ) type Channel struct { Id int `json:"id"` Type int `json:"type" gorm:"default:0"` Key string `json:"key" gorm:"not null"` OpenAIOrganization *string `json:"openai_organization"` TestModel *string `json:"test_model"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` Weight *uint `json:"weight" gorm:"default:0"` CreatedTime int64 `json:"created_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"` ResponseTime int `json:"response_time"` // in milliseconds BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"` Other string `json:"other"` Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` Models string `json:"models"` Group string `json:"group" gorm:"type:varchar(64);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:text"` //MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"` StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` AutoBan *int `json:"auto_ban" gorm:"default:1"` OtherInfo string `json:"other_info"` Tag *string `json:"tag" gorm:"index"` Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置 ParamOverride *string `json:"param_override" gorm:"type:text"` // add after v0.8.5 ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"` } type ChannelInfo struct { IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式 MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量 MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引 MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` } // Value implements driver.Valuer interface func (c ChannelInfo) Value() (driver.Value, error) { return common.Marshal(&c) } // Scan implements sql.Scanner interface func (c *ChannelInfo) Scan(value interface{}) error { bytesValue, _ := value.([]byte) return common.Unmarshal(bytesValue, c) } func (channel *Channel) getKeys() []string { if channel.Key == "" { return []string{} } trimmed := strings.TrimSpace(channel.Key) // If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios) if strings.HasPrefix(trimmed, "[") { var arr []json.RawMessage if err := json.Unmarshal([]byte(trimmed), &arr); err == nil { res := make([]string, len(arr)) for i, v := range arr { res[i] = string(v) } return res } } // Otherwise, fall back to splitting by newline keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n") return keys } func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) { // If not in multi-key mode, return the original key string directly. if !channel.ChannelInfo.IsMultiKey { return channel.Key, 0, nil } // Obtain all keys (split by \n) keys := channel.getKeys() if len(keys) == 0 { // No keys available, return error, should disable the channel return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey) } statusList := channel.ChannelInfo.MultiKeyStatusList // helper to get key status, default to enabled when missing getStatus := func(idx int) int { if statusList == nil { return common.ChannelStatusEnabled } if status, ok := statusList[idx]; ok { return status } return common.ChannelStatusEnabled } // Collect indexes of enabled keys enabledIdx := make([]int, 0, len(keys)) for i := range keys { if getStatus(i) == common.ChannelStatusEnabled { enabledIdx = append(enabledIdx, i) } } // If no specific status list or none enabled, fall back to first key if len(enabledIdx) == 0 { return keys[0], 0, nil } switch channel.ChannelInfo.MultiKeyMode { case constant.MultiKeyModeRandom: // Randomly pick one enabled key selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))] return keys[selectedIdx], selectedIdx, nil case constant.MultiKeyModePolling: // Use channel-specific lock to ensure thread-safe polling lock := getChannelPollingLock(channel.Id) lock.Lock() defer lock.Unlock() channelInfo, err := CacheGetChannelInfo(channel.Id) if err != nil { return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed) } //println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex) defer func() { if common.DebugEnabled { println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex)) } if !common.MemoryCacheEnabled { _ = channel.SaveChannelInfo() } else { // CacheUpdateChannel(channel) } }() // Start from the saved polling index and look for the next enabled key start := channelInfo.MultiKeyPollingIndex if start < 0 || start >= len(keys) { start = 0 } for i := 0; i < len(keys); i++ { idx := (start + i) % len(keys) if getStatus(idx) == common.ChannelStatusEnabled { // update polling index for next call (point to the next position) channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys) return keys[idx], idx, nil } } // Fallback – should not happen, but return first enabled key return keys[enabledIdx[0]], enabledIdx[0], nil default: // Unknown mode, default to first enabled key (or original key string) return keys[enabledIdx[0]], enabledIdx[0], nil } } func (channel *Channel) SaveChannelInfo() error { return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error } func (channel *Channel) GetModels() []string { if channel.Models == "" { return []string{} } return strings.Split(strings.Trim(channel.Models, ","), ",") } func (channel *Channel) GetGroups() []string { if channel.Group == "" { return []string{} } groups := strings.Split(strings.Trim(channel.Group, ","), ",") for i, group := range groups { groups[i] = strings.TrimSpace(group) } return groups } func (channel *Channel) GetOtherInfo() map[string]interface{} { otherInfo := make(map[string]interface{}) if channel.OtherInfo != "" { err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo) if err != nil { common.SysError("failed to unmarshal other info: " + err.Error()) } } return otherInfo } func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { otherInfoBytes, err := json.Marshal(otherInfo) if err != nil { common.SysError("failed to marshal other info: " + err.Error()) return } channel.OtherInfo = string(otherInfoBytes) } func (channel *Channel) GetTag() string { if channel.Tag == nil { return "" } return *channel.Tag } func (channel *Channel) SetTag(tag string) { channel.Tag = &tag } func (channel *Channel) GetAutoBan() bool { if channel.AutoBan == nil { return false } return *channel.AutoBan == 1 } func (channel *Channel) Save() error { return DB.Save(channel).Error } func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) { var channels []*Channel var err error order := "priority desc" if idSort { order = "id desc" } if selectAll { err = DB.Order(order).Find(&channels).Error } else { err = DB.Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error } return channels, err } func GetChannelsByTag(tag string, idSort bool) ([]*Channel, error) { var channels []*Channel order := "priority desc" if idSort { order = "id desc" } err := DB.Where("tag = ?", tag).Order(order).Find(&channels).Error return channels, err } func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) { var channels []*Channel modelsCol := "`models`" // 如果是 PostgreSQL,使用双引号 if common.UsingPostgreSQL { modelsCol = `"models"` } baseURLCol := "`base_url`" // 如果是 PostgreSQL,使用双引号 if common.UsingPostgreSQL { baseURLCol = `"base_url"` } order := "priority desc" if idSort { order = "id desc" } // 构造基础查询 baseQuery := DB.Model(&Channel{}).Omit("key") // 构造WHERE子句 var whereClause string var args []interface{} if group != "" && group != "null" { var groupCondition string if common.UsingMySQL { groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?` } else { // sqlite, PostgreSQL groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?` } whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%") } else { whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%") } // 执行查询 err := baseQuery.Where(whereClause, args...).Order(order).Find(&channels).Error if err != nil { return nil, err } return channels, nil } func GetChannelById(id int, selectAll bool) (*Channel, error) { channel := &Channel{Id: id} var err error = nil if selectAll { err = DB.First(channel, "id = ?", id).Error } else { err = DB.Omit("key").First(channel, "id = ?", id).Error } if err != nil { return nil, err } if channel == nil { return nil, errors.New("channel not found") } return channel, nil } func BatchInsertChannels(channels []Channel) error { var err error err = DB.Create(&channels).Error if err != nil { return err } for _, channel_ := range channels { err = channel_.AddAbilities() if err != nil { return err } } return nil } func BatchDeleteChannels(ids []int) error { //使用事务 删除channel表和channel_ability表 tx := DB.Begin() err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error if err != nil { // 回滚事务 tx.Rollback() return err } err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error if err != nil { // 回滚事务 tx.Rollback() return err } // 提交事务 tx.Commit() return err } func (channel *Channel) GetPriority() int64 { if channel.Priority == nil { return 0 } return *channel.Priority } func (channel *Channel) GetWeight() int { if channel.Weight == nil { return 0 } return int(*channel.Weight) } func (channel *Channel) GetBaseURL() string { if channel.BaseURL == nil { return "" } return *channel.BaseURL } func (channel *Channel) GetModelMapping() string { if channel.ModelMapping == nil { return "" } return *channel.ModelMapping } func (channel *Channel) GetStatusCodeMapping() string { if channel.StatusCodeMapping == nil { return "" } return *channel.StatusCodeMapping } func (channel *Channel) Insert() error { var err error err = DB.Create(channel).Error if err != nil { return err } err = channel.AddAbilities() return err } func (channel *Channel) Update() error { // If this is a multi-key channel, recalculate MultiKeySize based on the current key list to avoid inconsistency after editing keys if channel.ChannelInfo.IsMultiKey { var keyStr string if channel.Key != "" { keyStr = channel.Key } else { // If key is not provided, read the existing key from the database if existing, err := GetChannelById(channel.Id, true); err == nil { keyStr = existing.Key } } // Parse the key list (supports newline separation or JSON array) keys := []string{} if keyStr != "" { trimmed := strings.TrimSpace(keyStr) if strings.HasPrefix(trimmed, "[") { var arr []json.RawMessage if err := json.Unmarshal([]byte(trimmed), &arr); err == nil { keys = make([]string, len(arr)) for i, v := range arr { keys[i] = string(v) } } } if len(keys) == 0 { // fallback to newline split keys = strings.Split(strings.Trim(keyStr, "\n"), "\n") } } channel.ChannelInfo.MultiKeySize = len(keys) // Clean up status data that exceeds the new key count to prevent index out of range if channel.ChannelInfo.MultiKeyStatusList != nil { for idx := range channel.ChannelInfo.MultiKeyStatusList { if idx >= channel.ChannelInfo.MultiKeySize { delete(channel.ChannelInfo.MultiKeyStatusList, idx) } } } } var err error err = DB.Model(channel).Updates(channel).Error if err != nil { return err } DB.Model(channel).First(channel, "id = ?", channel.Id) err = channel.UpdateAbilities(nil) return err } func (channel *Channel) UpdateResponseTime(responseTime int64) { err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ TestTime: common.GetTimestamp(), ResponseTime: int(responseTime), }).Error if err != nil { common.SysError("failed to update response time: " + err.Error()) } } func (channel *Channel) UpdateBalance(balance float64) { err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ BalanceUpdatedTime: common.GetTimestamp(), Balance: balance, }).Error if err != nil { common.SysError("failed to update balance: " + err.Error()) } } func (channel *Channel) Delete() error { var err error err = DB.Delete(channel).Error if err != nil { return err } err = channel.DeleteAbilities() return err } var channelStatusLock sync.Mutex // channelPollingLocks stores locks for each channel.id to ensure thread-safe polling var channelPollingLocks sync.Map // getChannelPollingLock returns or creates a mutex for the given channel ID func getChannelPollingLock(channelId int) *sync.Mutex { if lock, exists := channelPollingLocks.Load(channelId); exists { return lock.(*sync.Mutex) } // Create new lock for this channel newLock := &sync.Mutex{} actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock) return actual.(*sync.Mutex) } // CleanupChannelPollingLocks removes locks for channels that no longer exist // This is optional and can be called periodically to prevent memory leaks func CleanupChannelPollingLocks() { var activeChannelIds []int DB.Model(&Channel{}).Pluck("id", &activeChannelIds) activeChannelSet := make(map[int]bool) for _, id := range activeChannelIds { activeChannelSet[id] = true } channelPollingLocks.Range(func(key, value interface{}) bool { channelId := key.(int) if !activeChannelSet[channelId] { channelPollingLocks.Delete(channelId) } return true }) } func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) { keys := channel.getKeys() if len(keys) == 0 { channel.Status = status } else { var keyIndex int for i, key := range keys { if key == usingKey { keyIndex = i break } } if channel.ChannelInfo.MultiKeyStatusList == nil { channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) } if status == common.ChannelStatusEnabled { delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex) } else { channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status } if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize { channel.Status = common.ChannelStatusAutoDisabled info := channel.GetOtherInfo() info["status_reason"] = "All keys are disabled" info["status_time"] = common.GetTimestamp() channel.SetOtherInfo(info) } } } func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool { if common.MemoryCacheEnabled { channelStatusLock.Lock() defer channelStatusLock.Unlock() channelCache, _ := CacheGetChannel(channelId) if channelCache == nil { return false } if channelCache.ChannelInfo.IsMultiKey { // 如果是多Key模式,更新缓存中的状态 handlerMultiKeyUpdate(channelCache, usingKey, status) //CacheUpdateChannel(channelCache) //return true } else { // 如果缓存渠道存在,且状态已是目标状态,直接返回 if channelCache.Status == status { return false } // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回 if status != common.ChannelStatusEnabled { return false } CacheUpdateChannelStatus(channelId, status) } } shouldUpdateAbilities := false defer func() { if shouldUpdateAbilities { err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) if err != nil { common.SysError("failed to update ability status: " + err.Error()) } } }() channel, err := GetChannelById(channelId, true) if err != nil { return false } else { if channel.Status == status { return false } if channel.ChannelInfo.IsMultiKey { beforeStatus := channel.Status handlerMultiKeyUpdate(channel, usingKey, status) if beforeStatus != channel.Status { shouldUpdateAbilities = true } } else { info := channel.GetOtherInfo() info["status_reason"] = reason info["status_time"] = common.GetTimestamp() channel.SetOtherInfo(info) channel.Status = status shouldUpdateAbilities = true } err = channel.Save() if err != nil { common.SysError("failed to update channel status: " + err.Error()) return false } } return true } func EnableChannelByTag(tag string) error { err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusEnabled).Error if err != nil { return err } err = UpdateAbilityStatusByTag(tag, true) return err } func DisableChannelByTag(tag string) error { err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusManuallyDisabled).Error if err != nil { return err } err = UpdateAbilityStatusByTag(tag, false) return err } func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint) error { updateData := Channel{} shouldReCreateAbilities := false updatedTag := tag // 如果 newTag 不为空且不等于 tag,则更新 tag if newTag != nil && *newTag != tag { updateData.Tag = newTag updatedTag = *newTag } if modelMapping != nil && *modelMapping != "" { updateData.ModelMapping = modelMapping } if models != nil && *models != "" { shouldReCreateAbilities = true updateData.Models = *models } if group != nil && *group != "" { shouldReCreateAbilities = true updateData.Group = *group } if priority != nil { updateData.Priority = priority } if weight != nil { updateData.Weight = weight } err := DB.Model(&Channel{}).Where("tag = ?", tag).Updates(updateData).Error if err != nil { return err } if shouldReCreateAbilities { channels, err := GetChannelsByTag(updatedTag, false) if err == nil { for _, channel := range channels { err = channel.UpdateAbilities(nil) if err != nil { common.SysError("failed to update abilities: " + err.Error()) } } } } else { err := UpdateAbilityByTag(tag, newTag, priority, weight) if err != nil { return err } } return nil } func UpdateChannelUsedQuota(id int, quota int) { if common.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) return } updateChannelUsedQuota(id, quota) } func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { common.SysError("failed to update channel used quota: " + err.Error()) } } func DeleteChannelByStatus(status int64) (int64, error) { result := DB.Where("status = ?", status).Delete(&Channel{}) return result.RowsAffected, result.Error } func DeleteDisabledChannel() (int64, error) { result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) return result.RowsAffected, result.Error } func GetPaginatedTags(offset int, limit int) ([]*string, error) { var tags []*string err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error return tags, err } func SearchTags(keyword string, group string, model string, idSort bool) ([]*string, error) { var tags []*string modelsCol := "`models`" // 如果是 PostgreSQL,使用双引号 if common.UsingPostgreSQL { modelsCol = `"models"` } baseURLCol := "`base_url`" // 如果是 PostgreSQL,使用双引号 if common.UsingPostgreSQL { baseURLCol = `"base_url"` } order := "priority desc" if idSort { order = "id desc" } // 构造基础查询 baseQuery := DB.Model(&Channel{}).Omit("key") // 构造WHERE子句 var whereClause string var args []interface{} if group != "" && group != "null" { var groupCondition string if common.UsingMySQL { groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?` } else { // sqlite, PostgreSQL groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?` } whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%") } else { whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%") } subQuery := baseQuery.Where(whereClause, args...). Select("tag"). Where("tag != ''"). Order(order) err := DB.Table("(?) as sub", subQuery). Select("DISTINCT tag"). Find(&tags).Error if err != nil { return nil, err } return tags, nil } func (channel *Channel) ValidateSettings() error { channelParams := &dto.ChannelSettings{} if channel.Setting != nil && *channel.Setting != "" { err := json.Unmarshal([]byte(*channel.Setting), channelParams) if err != nil { return err } } return nil } func (channel *Channel) GetSetting() dto.ChannelSettings { setting := dto.ChannelSettings{} if channel.Setting != nil && *channel.Setting != "" { err := json.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { common.SysError("failed to unmarshal setting: " + err.Error()) channel.Setting = nil // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } } return setting } func (channel *Channel) SetSetting(setting dto.ChannelSettings) { settingBytes, err := json.Marshal(setting) if err != nil { common.SysError("failed to marshal setting: " + err.Error()) return } channel.Setting = common.GetPointer[string](string(settingBytes)) } func (channel *Channel) GetParamOverride() map[string]interface{} { paramOverride := make(map[string]interface{}) if channel.ParamOverride != nil && *channel.ParamOverride != "" { err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) if err != nil { common.SysError("failed to unmarshal param override: " + err.Error()) } } return paramOverride } func GetChannelsByIds(ids []int) ([]*Channel, error) { var channels []*Channel err := DB.Where("id in (?)", ids).Find(&channels).Error return channels, err } func BatchSetChannelTag(ids []int, tag *string) error { // 开启事务 tx := DB.Begin() if tx.Error != nil { return tx.Error } // 更新标签 err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error if err != nil { tx.Rollback() return err } // update ability status channels, err := GetChannelsByIds(ids) if err != nil { tx.Rollback() return err } for _, channel := range channels { err = channel.UpdateAbilities(tx) if err != nil { tx.Rollback() return err } } // 提交事务 return tx.Commit().Error } // CountAllChannels returns total channels in DB func CountAllChannels() (int64, error) { var total int64 err := DB.Model(&Channel{}).Count(&total).Error return total, err } // CountAllTags returns number of non-empty distinct tags func CountAllTags() (int64, error) { var total int64 err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error return total, err } // Get channels of specified type with pagination func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) { var channels []*Channel order := "priority desc" if idSort { order = "id desc" } err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error return channels, err } // Count channels of specific type func CountChannelsByType(channelType int) (int64, error) { var count int64 err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error return count, err } // Return map[type]count for all channels func CountChannelsGroupByType() (map[int64]int64, error) { type result struct { Type int64 `gorm:"column:type"` Count int64 `gorm:"column:count"` } var results []result err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error if err != nil { return nil, err } counts := make(map[int64]int64) for _, r := range results { counts[r.Type] = r.Count } return counts, nil }