| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598 |
- package controller
- import (
- "context"
- "errors"
- "fmt"
- "io"
- "math/rand/v2"
- "net/http"
- "net/http/httptest"
- "net/url"
- "slices"
- "strconv"
- "strings"
- "sync"
- "sync/atomic"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/labring/aiproxy/core/common/conv"
- "github.com/labring/aiproxy/core/common/notify"
- "github.com/labring/aiproxy/core/common/trylock"
- "github.com/labring/aiproxy/core/middleware"
- "github.com/labring/aiproxy/core/model"
- "github.com/labring/aiproxy/core/monitor"
- "github.com/labring/aiproxy/core/relay/adaptors"
- "github.com/labring/aiproxy/core/relay/meta"
- "github.com/labring/aiproxy/core/relay/mode"
- "github.com/labring/aiproxy/core/relay/render"
- "github.com/labring/aiproxy/core/relay/utils"
- log "github.com/sirupsen/logrus"
- )
- const channelTestRequestID = "channel-test"
- var (
- modelConfigCache map[string]model.ModelConfig = make(map[string]model.ModelConfig)
- modelConfigCacheOnce sync.Once
- )
- func guessModelConfig(modelName string) model.ModelConfig {
- modelConfigCacheOnce.Do(func() {
- for _, c := range adaptors.ChannelAdaptor {
- for _, m := range c.Metadata().Models {
- if _, ok := modelConfigCache[m.Model]; !ok {
- modelConfigCache[m.Model] = m
- }
- }
- }
- })
- if cachedConfig, ok := modelConfigCache[modelName]; ok {
- return cachedConfig
- }
- return model.ModelConfig{}
- }
- // testSingleModel tests a single model in the channel
- func testSingleModel(
- mc *model.ModelCaches,
- channel *model.Channel,
- modelName string,
- ) (*model.ChannelTest, error) {
- modelConfig, ok := mc.ModelConfig.GetModelConfig(modelName)
- if !ok {
- return nil, errors.New(modelName + " model config not found")
- }
- if modelConfig.Type == mode.Unknown {
- newModelConfig := guessModelConfig(modelName)
- if newModelConfig.Type != mode.Unknown {
- modelConfig = newModelConfig
- }
- }
- if modelConfig.Type != mode.Unknown {
- a, ok := adaptors.GetAdaptor(channel.Type)
- if !ok {
- return nil, errors.New("adaptor not found")
- }
- if !a.SupportMode(modelConfig.Type) {
- return nil, fmt.Errorf("%s not supported by adaptor", modelConfig.Type)
- }
- }
- if modelConfig.ExcludeFromTests {
- return &model.ChannelTest{
- TestAt: time.Now(),
- Model: modelName,
- ActualModel: modelName,
- Success: true,
- Code: http.StatusOK,
- Mode: modelConfig.Type,
- ChannelName: channel.Name,
- ChannelType: channel.Type,
- ChannelID: channel.ID,
- }, nil
- }
- body, m, err := utils.BuildRequest(modelConfig)
- if err != nil {
- return nil, err
- }
- w := httptest.NewRecorder()
- newc, _ := gin.CreateTestContext(w)
- newc.Request = &http.Request{
- URL: &url.URL{},
- Body: io.NopCloser(body),
- Header: make(http.Header),
- }
- middleware.SetRequestID(newc, channelTestRequestID)
- meta := meta.NewMeta(
- channel,
- m,
- modelName,
- modelConfig,
- meta.WithRequestID(channelTestRequestID),
- )
- result := relayHandler(newc, meta, mc)
- success := result.Error == nil
- var (
- respStr string
- code int
- )
- if success {
- switch meta.Mode {
- case mode.AudioSpeech,
- mode.ImagesGenerations:
- respStr = ""
- default:
- respStr = w.Body.String()
- }
- code = w.Code
- } else {
- respBody, _ := result.Error.MarshalJSON()
- respStr = conv.BytesToString(respBody)
- code = result.Error.StatusCode()
- }
- return channel.UpdateModelTest(
- meta.RequestAt,
- meta.OriginModel,
- meta.ActualModel,
- meta.Mode,
- time.Since(meta.RequestAt).Seconds(),
- success,
- respStr,
- code,
- )
- }
- // TestChannel godoc
- //
- // @Summary Test channel model
- // @Description Tests a single model in the channel
- // @Tags channel
- // @Produce json
- // @Security ApiKeyAuth
- // @Param id path int true "Channel ID"
- // @Param model path string true "Model name"
- // @Success 200 {object} middleware.APIResponse{data=model.ChannelTest}
- // @Router /api/channel/{id}/{model} [get]
- //
- //nolint:goconst
- func TestChannel(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- c.JSON(http.StatusOK, middleware.APIResponse{
- Success: false,
- Message: err.Error(),
- })
- return
- }
- modelName := strings.TrimPrefix(c.Param("model"), "/")
- if modelName == "" {
- c.JSON(http.StatusOK, middleware.APIResponse{
- Success: false,
- Message: "model is required",
- })
- return
- }
- channel, err := model.LoadChannelByID(id)
- if err != nil {
- c.JSON(http.StatusOK, middleware.APIResponse{
- Success: false,
- Message: "channel not found",
- })
- return
- }
- if !slices.Contains(channel.Models, modelName) {
- c.JSON(http.StatusOK, middleware.APIResponse{
- Success: false,
- Message: "model not supported by channel",
- })
- return
- }
- ct, err := testSingleModel(model.LoadModelCaches(), channel, modelName)
- if err != nil {
- log.Errorf(
- "failed to test channel %s(%d) model %s: %s",
- channel.Name,
- channel.ID,
- modelName,
- err.Error(),
- )
- c.JSON(http.StatusOK, middleware.APIResponse{
- Success: false,
- Message: fmt.Sprintf(
- "failed to test channel %s(%d) model %s: %s",
- channel.Name,
- channel.ID,
- modelName,
- err.Error(),
- ),
- })
- return
- }
- if c.Query("success_body") != "true" && ct.Success {
- ct.Response = ""
- }
- c.JSON(http.StatusOK, middleware.APIResponse{
- Success: true,
- Data: ct,
- })
- }
- type TestResult struct {
- Data *model.ChannelTest `json:"data,omitempty"`
- Message string `json:"message,omitempty"`
- Success bool `json:"success"`
- }
- func processTestResult(
- mc *model.ModelCaches,
- channel *model.Channel,
- modelName string,
- returnSuccess, successResponseBody bool,
- ) *TestResult {
- ct, err := testSingleModel(mc, channel, modelName)
- e := &utils.UnsupportedModelTypeError{}
- if errors.As(err, &e) {
- log.Errorf("model %s not supported test: %s", modelName, err.Error())
- return nil
- }
- result := &TestResult{
- Success: err == nil,
- }
- if err != nil {
- result.Message = fmt.Sprintf(
- "failed to test channel %s(%d) model %s: %s",
- channel.Name,
- channel.ID,
- modelName,
- err.Error(),
- )
- return result
- }
- if !ct.Success {
- result.Data = ct
- return result
- }
- if !returnSuccess {
- return nil
- }
- if !successResponseBody {
- ct.Response = ""
- }
- result.Data = ct
- return result
- }
- // TestChannelModels godoc
- //
- // @Summary Test channel models
- // @Description Tests all models in the channel
- // @Tags channel
- // @Produce json
- // @Security ApiKeyAuth
- // @Param id path int true "Channel ID"
- // @Param return_success query bool false "Return success"
- // @Param success_body query bool false "Success body"
- // @Param stream query bool false "Stream"
- // @Success 200 {object} middleware.APIResponse{data=[]TestResult}
- // @Router /api/channel/{id}/test [get]
- func TestChannelModels(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- c.JSON(http.StatusOK, middleware.APIResponse{
- Success: false,
- Message: err.Error(),
- })
- return
- }
- channel, err := model.LoadChannelByID(id)
- if err != nil {
- c.JSON(http.StatusOK, middleware.APIResponse{
- Success: false,
- Message: "channel not found",
- })
- return
- }
- returnSuccess := c.Query("return_success") == "true"
- successResponseBody := c.Query("success_body") == "true"
- isStream := c.Query("stream") == "true"
- results := make([]*TestResult, 0)
- resultsMutex := sync.Mutex{}
- hasError := atomic.Bool{}
- var wg sync.WaitGroup
- semaphore := make(chan struct{}, 5)
- models := slices.Clone(channel.Models)
- rand.Shuffle(len(models), func(i, j int) {
- models[i], models[j] = models[j], models[i]
- })
- mc := model.LoadModelCaches()
- for _, modelName := range models {
- wg.Add(1)
- semaphore <- struct{}{}
- go func(model string) {
- defer wg.Done()
- defer func() { <-semaphore }()
- result := processTestResult(mc, channel, model, returnSuccess, successResponseBody)
- if result == nil {
- return
- }
- if !result.Success || (result.Data != nil && !result.Data.Success) {
- hasError.Store(true)
- }
- resultsMutex.Lock()
- if isStream {
- err := render.OpenaiObjectData(c, result)
- if err != nil {
- log.Errorf("failed to render result: %s", err.Error())
- }
- } else {
- results = append(results, result)
- }
- resultsMutex.Unlock()
- }(modelName)
- }
- wg.Wait()
- if !hasError.Load() {
- err := model.ClearLastTestErrorAt(channel.ID)
- if err != nil {
- log.Errorf(
- "failed to clear last test error at for channel %s(%d): %s",
- channel.Name,
- channel.ID,
- err.Error(),
- )
- }
- }
- if !isStream {
- c.JSON(http.StatusOK, middleware.APIResponse{
- Success: true,
- Data: results,
- })
- }
- }
- // TestAllChannels godoc
- //
- // @Summary Test all channels
- // @Description Tests all channels
- // @Tags channel
- // @Produce json
- // @Security ApiKeyAuth
- // @Param test_disabled query bool false "Test disabled"
- // @Param return_success query bool false "Return success"
- // @Param success_body query bool false "Success body"
- // @Param stream query bool false "Stream"
- // @Success 200 {object} middleware.APIResponse{data=[]TestResult}
- //
- // @Router /api/channels/test [get]
- func TestAllChannels(c *gin.Context) {
- testDisabled := c.Query("test_disabled") == "true"
- var (
- channels []*model.Channel
- err error
- )
- if testDisabled {
- channels, err = model.LoadChannels()
- } else {
- channels, err = model.LoadEnabledChannels()
- }
- if err != nil {
- c.JSON(http.StatusOK, middleware.APIResponse{
- Success: false,
- Message: err.Error(),
- })
- return
- }
- returnSuccess := c.Query("return_success") == "true"
- successResponseBody := c.Query("success_body") == "true"
- isStream := c.Query("stream") == "true"
- results := make([]*TestResult, 0)
- resultsMutex := sync.Mutex{}
- hasErrorMap := make(map[int]*atomic.Bool)
- var wg sync.WaitGroup
- semaphore := make(chan struct{}, 5)
- newChannels := slices.Clone(channels)
- rand.Shuffle(len(newChannels), func(i, j int) {
- newChannels[i], newChannels[j] = newChannels[j], newChannels[i]
- })
- mc := model.LoadModelCaches()
- for _, channel := range newChannels {
- channelHasError := &atomic.Bool{}
- hasErrorMap[channel.ID] = channelHasError
- models := slices.Clone(channel.Models)
- rand.Shuffle(len(models), func(i, j int) {
- models[i], models[j] = models[j], models[i]
- })
- for _, modelName := range models {
- wg.Add(1)
- semaphore <- struct{}{}
- go func(model string, ch *model.Channel, hasError *atomic.Bool) {
- defer wg.Done()
- defer func() { <-semaphore }()
- result := processTestResult(mc, ch, model, returnSuccess, successResponseBody)
- if result == nil {
- return
- }
- if !result.Success || (result.Data != nil && !result.Data.Success) {
- hasError.Store(true)
- }
- resultsMutex.Lock()
- if isStream {
- err := render.OpenaiObjectData(c, result)
- if err != nil {
- log.Errorf("failed to render result: %s", err.Error())
- }
- } else {
- results = append(results, result)
- }
- resultsMutex.Unlock()
- }(modelName, channel, channelHasError)
- }
- }
- wg.Wait()
- for id, hasError := range hasErrorMap {
- if !hasError.Load() {
- err := model.ClearLastTestErrorAt(id)
- if err != nil {
- log.Errorf("failed to clear last test error at for channel %d: %s", id, err.Error())
- }
- }
- }
- if !isStream {
- c.JSON(http.StatusOK, middleware.APIResponse{
- Success: true,
- Data: results,
- })
- }
- }
- func tryTestChannel(channelID int, modelName string) bool {
- return trylock.Lock(
- fmt.Sprintf("channel_test_lock:%d:%s", channelID, modelName),
- 30*time.Second,
- )
- }
- func AutoTestBannedModels() {
- log := log.WithFields(log.Fields{
- "auto_test_banned_models": "true",
- })
- channels, err := monitor.GetAllBannedModelChannels(context.Background())
- if err != nil {
- log.Errorf("failed to get banned channels: %s", err.Error())
- return
- }
- if len(channels) == 0 {
- return
- }
- mc := model.LoadModelCaches()
- for modelName, ids := range channels {
- for _, id := range ids {
- if !tryTestChannel(int(id), modelName) {
- continue
- }
- channel, err := model.LoadChannelByID(int(id))
- if err != nil {
- log.Errorf("failed to get channel by model %s: %s", modelName, err.Error())
- continue
- }
- result, err := testSingleModel(mc, channel, modelName)
- if err != nil {
- notify.Error(
- fmt.Sprintf(
- "channel %s (type: %d, id: %d) model %s test failed",
- channel.Name,
- channel.Type,
- channel.ID,
- modelName,
- ),
- err.Error(),
- )
- continue
- }
- if result.Success {
- notify.Info(
- fmt.Sprintf(
- "channel %s (type: %d, id: %d) model %s test success",
- channel.Name,
- channel.Type,
- channel.ID,
- modelName,
- ),
- "unban it",
- )
- err = monitor.ClearChannelModelErrors(context.Background(), modelName, channel.ID)
- if err != nil {
- log.Errorf("clear channel errors failed: %+v", err)
- }
- } else {
- notify.Error(fmt.Sprintf("channel %s (type: %d, id: %d) model %s test failed", channel.Name, channel.Type, channel.ID, modelName),
- fmt.Sprintf("code: %d, response: %s", result.Code, result.Response))
- }
- }
- }
- }
|