|
|
@@ -1,6 +1,8 @@
|
|
|
package controller
|
|
|
|
|
|
import (
|
|
|
+ "encoding/json"
|
|
|
+ "fmt"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
@@ -9,6 +11,34 @@ import (
|
|
|
"strings"
|
|
|
)
|
|
|
|
|
|
+type OpenAIModel struct {
|
|
|
+ ID string `json:"id"`
|
|
|
+ Object string `json:"object"`
|
|
|
+ Created int64 `json:"created"`
|
|
|
+ OwnedBy string `json:"owned_by"`
|
|
|
+ Permission []struct {
|
|
|
+ ID string `json:"id"`
|
|
|
+ Object string `json:"object"`
|
|
|
+ Created int64 `json:"created"`
|
|
|
+ AllowCreateEngine bool `json:"allow_create_engine"`
|
|
|
+ AllowSampling bool `json:"allow_sampling"`
|
|
|
+ AllowLogprobs bool `json:"allow_logprobs"`
|
|
|
+ AllowSearchIndices bool `json:"allow_search_indices"`
|
|
|
+ AllowView bool `json:"allow_view"`
|
|
|
+ AllowFineTuning bool `json:"allow_fine_tuning"`
|
|
|
+ Organization string `json:"organization"`
|
|
|
+ Group string `json:"group"`
|
|
|
+ IsBlocking bool `json:"is_blocking"`
|
|
|
+ } `json:"permission"`
|
|
|
+ Root string `json:"root"`
|
|
|
+ Parent string `json:"parent"`
|
|
|
+}
|
|
|
+
|
|
|
+type OpenAIModelsResponse struct {
|
|
|
+ Data []OpenAIModel `json:"data"`
|
|
|
+ Success bool `json:"success"`
|
|
|
+}
|
|
|
+
|
|
|
func GetAllChannels(c *gin.Context) {
|
|
|
p, _ := strconv.Atoi(c.Query("p"))
|
|
|
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
|
|
@@ -35,6 +65,65 @@ func GetAllChannels(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+func FetchUpstreamModels(c *gin.Context) {
|
|
|
+ id, err := strconv.Atoi(c.Param("id"))
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": err.Error(),
|
|
|
+ })
|
|
|
+ return
|
|
|
+ }
|
|
|
+ channel, err := model.GetChannelById(id, true)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": err.Error(),
|
|
|
+ })
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if channel.Type != common.ChannelTypeOpenAI {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": "仅支持 OpenAI 类型渠道",
|
|
|
+ })
|
|
|
+ return
|
|
|
+ }
|
|
|
+ url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
|
|
|
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": err.Error(),
|
|
|
+ })
|
|
|
+ }
|
|
|
+ result := OpenAIModelsResponse{}
|
|
|
+ err = json.Unmarshal(body, &result)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": err.Error(),
|
|
|
+ })
|
|
|
+ }
|
|
|
+ if !result.Success {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": false,
|
|
|
+ "message": "上游返回错误",
|
|
|
+ })
|
|
|
+ }
|
|
|
+
|
|
|
+ var ids []string
|
|
|
+ for _, model := range result.Data {
|
|
|
+ ids = append(ids, model.ID)
|
|
|
+ }
|
|
|
+
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "success": true,
|
|
|
+ "message": "",
|
|
|
+ "data": ids,
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
func FixChannelsAbilities(c *gin.Context) {
|
|
|
count, err := model.FixAbility()
|
|
|
if err != nil {
|