relay.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. package controller
  import (
  	"encoding/json"
  	"fmt"
  	"log"
  	"net/http"
  	"net/url"
  	"one-api/common"
  	"strconv"
  	"strings"

  	"github.com/gin-gonic/gin"
  )
  // Message is a single chat message.
  //
  // Content is kept as raw JSON because it may be either a JSON string or an
  // array of typed parts; ParseContent normalizes both shapes.
  type Message struct {
  	Role    string          `json:"role"`
  	Content json.RawMessage `json:"content"`
  	Name    *string         `json:"name,omitempty"`
  }

  // MediaMessage is one normalized content part produced by ParseContent.
  // For image parts, ImageUrl holds a MessageImageUrl value.
  type MediaMessage struct {
  	Type     string `json:"type"`
  	Text     string `json:"text"`
  	ImageUrl any    `json:"image_url,omitempty"`
  }

  // MessageImageUrl is the payload of an image_url content part.
  type MessageImageUrl struct {
  	Url    string `json:"url"`
  	Detail string `json:"detail"`
  }

  // Values of the "type" field recognized by ParseContent.
  const (
  	ContentTypeText     = "text"
  	ContentTypeImageURL = "image_url"
  )

  // ParseContent decodes m.Content into a flat list of content parts.
  //
  // If Content is a JSON string (or JSON null, which decodes to the empty
  // string), a single text part is returned. If Content is a JSON array, each
  // element is inspected: "text" parts with a string "text" field and
  // "image_url" parts with an object "image_url" field carrying a string "url"
  // are kept; everything else is skipped. A missing or non-string "detail"
  // defaults to "auto". Any other Content shape yields nil.
  //
  // The method never panics on malformed parts: Content comes from the request
  // body, so every type assertion is checked.
  func (m Message) ParseContent() []MediaMessage {
  	var text string
  	if err := json.Unmarshal(m.Content, &text); err == nil {
  		return []MediaMessage{{Type: ContentTypeText, Text: text}}
  	}

  	var parts []json.RawMessage
  	if err := json.Unmarshal(m.Content, &parts); err != nil {
  		return nil
  	}

  	var contentList []MediaMessage
  	for _, part := range parts {
  		var fields map[string]any
  		if err := json.Unmarshal(part, &fields); err != nil {
  			// Not a JSON object; ignore this element.
  			continue
  		}
  		switch fields["type"] {
  		case ContentTypeText:
  			if s, ok := fields["text"].(string); ok {
  				contentList = append(contentList, MediaMessage{
  					Type: ContentTypeText,
  					Text: s,
  				})
  			}
  		case ContentTypeImageURL:
  			image, ok := fields["image_url"].(map[string]any)
  			if !ok {
  				continue
  			}
  			imageURL, ok := image["url"].(string)
  			if !ok {
  				// An image part without a usable URL cannot be relayed.
  				continue
  			}
  			detail, ok := image["detail"].(string)
  			if !ok {
  				detail = "auto"
  			}
  			contentList = append(contentList, MediaMessage{
  				Type: ContentTypeImageURL,
  				ImageUrl: MessageImageUrl{
  					Url:    imageURL,
  					Detail: detail,
  				},
  			})
  		}
  	}
  	return contentList
  }
  // Relay modes identify which upstream API a request is being relayed to.
  // The values are assigned sequentially from zero via iota, so the order of
  // the declarations below is significant and must not change.
  const (
  	// RelayModeUnknown is the zero value, used when no route matched.
  	RelayModeUnknown = iota

  	// Text, embedding, moderation and image endpoints.
  	RelayModeChatCompletions
  	RelayModeCompletions
  	RelayModeEmbeddings
  	RelayModeModerations
  	RelayModeImagesGenerations
  	RelayModeEdits

  	// Midjourney endpoints.
  	RelayModeMidjourneyImagine
  	RelayModeMidjourneyDescribe
  	RelayModeMidjourneyBlend
  	RelayModeMidjourneyChange
  	RelayModeMidjourneySimpleChange
  	RelayModeMidjourneyNotify
  	RelayModeMidjourneyTaskFetch
  	RelayModeMidjourneyTaskFetchByCondition

  	// Audio endpoints.
  	RelayModeAudioSpeech
  	RelayModeAudioTranscription
  	RelayModeAudioTranslation
  )
  // ResponseFormat is the "response_format" object of a chat request.
  // See https://platform.openai.com/docs/api-reference/chat
  type ResponseFormat struct {
  	// Type is omitted from the encoded JSON when empty.
  	Type string `json:"type,omitempty"`
  }
  101. type GeneralOpenAIRequest struct {
  102. Model string `json:"model,omitempty"`
  103. Messages []Message `json:"messages,omitempty"`
  104. Prompt any `json:"prompt,omitempty"`
  105. Stream bool `json:"stream,omitempty"`
  106. MaxTokens uint `json:"max_tokens,omitempty"`
  107. Temperature float64 `json:"temperature,omitempty"`
  108. TopP float64 `json:"top_p,omitempty"`
  109. N int `json:"n,omitempty"`
  110. Input any `json:"input,omitempty"`
  111. Instruction string `json:"instruction,omitempty"`
  112. Size string `json:"size,omitempty"`
  113. Functions any `json:"functions,omitempty"`
  114. FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
  115. PresencePenalty float64 `json:"presence_penalty,omitempty"`
  116. ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
  117. Seed float64 `json:"seed,omitempty"`
  118. Tools any `json:"tools,omitempty"`
  119. ToolChoice any `json:"tool_choice,omitempty"`
  120. User string `json:"user,omitempty"`
  121. }
  122. func (r GeneralOpenAIRequest) ParseInput() []string {
  123. if r.Input == nil {
  124. return nil
  125. }
  126. var input []string
  127. switch r.Input.(type) {
  128. case string:
  129. input = []string{r.Input.(string)}
  130. case []any:
  131. input = make([]string, 0, len(r.Input.([]any)))
  132. for _, item := range r.Input.([]any) {
  133. if str, ok := item.(string); ok {
  134. input = append(input, str)
  135. }
  136. }
  137. }
  138. return input
  139. }
  // AudioRequest is the JSON body of an audio request: the model to use, the
  // voice, and the input text.
  type AudioRequest struct {
  	Model string `json:"model"`
  	Voice string `json:"voice"`
  	Input string `json:"input"`
  }
  145. type ChatRequest struct {
  146. Model string `json:"model"`
  147. Messages []Message `json:"messages"`
  148. MaxTokens uint `json:"max_tokens"`
  149. }
  150. type TextRequest struct {
  151. Model string `json:"model"`
  152. Messages []Message `json:"messages"`
  153. Prompt string `json:"prompt"`
  154. MaxTokens uint `json:"max_tokens"`
  155. //Stream bool `json:"stream"`
  156. }
  // ImageRequest is the JSON body of an image generation request.
  // Quality, ResponseFormat and Style are omitted from the encoding when empty.
  type ImageRequest struct {
  	Model          string `json:"model"`
  	Prompt         string `json:"prompt"`
  	N              int    `json:"n"`
  	Size           string `json:"size"`
  	Quality        string `json:"quality,omitempty"`
  	ResponseFormat string `json:"response_format,omitempty"`
  	Style          string `json:"style,omitempty"`
  }
  // AudioResponse is the JSON body of an audio response that carries text.
  type AudioResponse struct {
  	// Text is omitted from the encoded JSON when empty.
  	Text string `json:"text,omitempty"`
  }
  // Usage holds the token counts reported for a request: prompt, completion
  // and their total.
  type Usage struct {
  	PromptTokens     int `json:"prompt_tokens"`
  	CompletionTokens int `json:"completion_tokens"`
  	TotalTokens      int `json:"total_tokens"`
  }
  // OpenAIError is the error object returned to clients under the "error" key.
  type OpenAIError struct {
  	Message string `json:"message"`
  	Type    string `json:"type"`
  	Param   string `json:"param"`
  	// Code is untyped, so it can hold whatever JSON value is assigned or decoded.
  	Code any `json:"code"`
  }
  180. type OpenAIErrorWithStatusCode struct {
  181. OpenAIError
  182. StatusCode int `json:"status_code"`
  183. }
  184. type TextResponse struct {
  185. Choices []OpenAITextResponseChoice `json:"choices"`
  186. Usage `json:"usage"`
  187. Error OpenAIError `json:"error"`
  188. }
  189. type OpenAITextResponseChoice struct {
  190. Index int `json:"index"`
  191. Message `json:"message"`
  192. FinishReason string `json:"finish_reason"`
  193. }
  194. type OpenAITextResponse struct {
  195. Id string `json:"id"`
  196. Object string `json:"object"`
  197. Created int64 `json:"created"`
  198. Choices []OpenAITextResponseChoice `json:"choices"`
  199. Usage `json:"usage"`
  200. }
  // OpenAIEmbeddingResponseItem is one embedding vector together with its
  // index within the response's data array.
  type OpenAIEmbeddingResponseItem struct {
  	Object    string    `json:"object"`
  	Index     int       `json:"index"`
  	Embedding []float64 `json:"embedding"`
  }
  206. type OpenAIEmbeddingResponse struct {
  207. Object string `json:"object"`
  208. Data []OpenAIEmbeddingResponseItem `json:"data"`
  209. Model string `json:"model"`
  210. Usage `json:"usage"`
  211. }
  // ImageResponse is the response body of an image generation request.
  type ImageResponse struct {
  	Created int `json:"created"`
  	// Data lists the generated images. The explicit tag makes the field encode
  	// as "data", consistent with every other field in this file; without it
  	// encoding/json would emit the Go field name "Data". Decoding is unaffected
  	// because encoding/json matches object keys case-insensitively.
  	Data []struct {
  		Url     string `json:"url"`
  		B64Json string `json:"b64_json"`
  	} `json:"data"`
  }
  // ChatCompletionsStreamResponseChoice is one choice of a streamed chat
  // chunk: the content delta and an optional finish reason.
  type ChatCompletionsStreamResponseChoice struct {
  	Delta struct {
  		Content string `json:"content"`
  	} `json:"delta"`
  	// FinishReason is a pointer so that a nil value is omitted when encoding.
  	FinishReason *string `json:"finish_reason,omitempty"`
  }
  225. type ChatCompletionsStreamResponse struct {
  226. Id string `json:"id"`
  227. Object string `json:"object"`
  228. Created int64 `json:"created"`
  229. Model string `json:"model"`
  230. Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
  231. }
  232. type ChatCompletionsStreamResponseSimple struct {
  233. Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
  234. }
  // CompletionsStreamResponse is one chunk of a streamed text completion; each
  // choice carries the text fragment and its finish reason.
  type CompletionsStreamResponse struct {
  	Choices []struct {
  		Text         string `json:"text"`
  		FinishReason string `json:"finish_reason"`
  	} `json:"choices"`
  }
  // MidjourneyRequest is the JSON body accepted by the Midjourney endpoints.
  // Its keys use camelCase, unlike the snake_case keys of the other request
  // types in this file.
  type MidjourneyRequest struct {
  	Prompt      string   `json:"prompt"`
  	NotifyHook  string   `json:"notifyHook"`
  	Action      string   `json:"action"`
  	Index       int      `json:"index"`
  	State       string   `json:"state"`
  	TaskId      string   `json:"taskId"`
  	Base64Array []string `json:"base64Array"`
  	Content     string   `json:"content"`
  }
  // MidjourneyResponse is the result object produced by the Midjourney relay
  // helpers: a numeric code, a description, free-form properties and a result
  // string.
  type MidjourneyResponse struct {
  	Code        int    `json:"code"`
  	Description string `json:"description"`
  	Properties  any    `json:"properties"`
  	Result      string `json:"result"`
  }
  257. func Relay(c *gin.Context) {
  258. relayMode := RelayModeUnknown
  259. if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
  260. relayMode = RelayModeChatCompletions
  261. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
  262. relayMode = RelayModeCompletions
  263. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
  264. relayMode = RelayModeEmbeddings
  265. } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
  266. relayMode = RelayModeEmbeddings
  267. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
  268. relayMode = RelayModeModerations
  269. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
  270. relayMode = RelayModeImagesGenerations
  271. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
  272. relayMode = RelayModeEdits
  273. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
  274. relayMode = RelayModeAudioSpeech
  275. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
  276. relayMode = RelayModeAudioTranscription
  277. } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
  278. relayMode = RelayModeAudioTranslation
  279. }
  280. var err *OpenAIErrorWithStatusCode
  281. switch relayMode {
  282. case RelayModeImagesGenerations:
  283. err = relayImageHelper(c, relayMode)
  284. case RelayModeAudioSpeech:
  285. fallthrough
  286. case RelayModeAudioTranslation:
  287. fallthrough
  288. case RelayModeAudioTranscription:
  289. err = relayAudioHelper(c, relayMode)
  290. default:
  291. err = relayTextHelper(c, relayMode)
  292. }
  293. if err != nil {
  294. requestId := c.GetString(common.RequestIdKey)
  295. retryTimesStr := c.Query("retry")
  296. retryTimes, _ := strconv.Atoi(retryTimesStr)
  297. if retryTimesStr == "" {
  298. retryTimes = common.RetryTimes
  299. }
  300. if retryTimes > 0 {
  301. c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Message))
  302. } else {
  303. if err.StatusCode == http.StatusTooManyRequests {
  304. //err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
  305. }
  306. err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
  307. c.JSON(err.StatusCode, gin.H{
  308. "error": err.OpenAIError,
  309. })
  310. }
  311. channelId := c.GetInt("channel_id")
  312. autoBan := c.GetBool("auto_ban")
  313. common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
  314. // https://platform.openai.com/docs/guides/error-codes/api-errors
  315. if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
  316. channelId := c.GetInt("channel_id")
  317. channelName := c.GetString("channel_name")
  318. disableChannel(channelId, channelName, err.Message)
  319. }
  320. }
  321. }
  322. func RelayMidjourney(c *gin.Context) {
  323. relayMode := RelayModeUnknown
  324. if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
  325. relayMode = RelayModeMidjourneyImagine
  326. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
  327. relayMode = RelayModeMidjourneyBlend
  328. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
  329. relayMode = RelayModeMidjourneyDescribe
  330. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
  331. relayMode = RelayModeMidjourneyNotify
  332. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
  333. relayMode = RelayModeMidjourneyChange
  334. } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
  335. relayMode = RelayModeMidjourneyChange
  336. } else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
  337. relayMode = RelayModeMidjourneyTaskFetch
  338. } else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
  339. relayMode = RelayModeMidjourneyTaskFetchByCondition
  340. }
  341. var err *MidjourneyResponse
  342. switch relayMode {
  343. case RelayModeMidjourneyNotify:
  344. err = relayMidjourneyNotify(c)
  345. case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
  346. err = relayMidjourneyTask(c, relayMode)
  347. default:
  348. err = relayMidjourneySubmit(c, relayMode)
  349. }
  350. //err = relayMidjourneySubmit(c, relayMode)
  351. log.Println(err)
  352. if err != nil {
  353. retryTimesStr := c.Query("retry")
  354. retryTimes, _ := strconv.Atoi(retryTimesStr)
  355. if retryTimesStr == "" {
  356. retryTimes = common.RetryTimes
  357. }
  358. if retryTimes > 0 {
  359. c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
  360. } else {
  361. if err.Code == 30 {
  362. err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
  363. }
  364. c.JSON(400, gin.H{
  365. "error": err.Description + " " + err.Result,
  366. })
  367. }
  368. channelId := c.GetInt("channel_id")
  369. common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Result))
  370. //if shouldDisableChannel(&err.OpenAIError) {
  371. // channelId := c.GetInt("channel_id")
  372. // channelName := c.GetString("channel_name")
  373. // disableChannel(channelId, channelName, err.Result)
  374. //};''''''''''''''''''''''''''''''''
  375. }
  376. }
  377. func RelayNotImplemented(c *gin.Context) {
  378. err := OpenAIError{
  379. Message: "API not implemented",
  380. Type: "new_api_error",
  381. Param: "",
  382. Code: "api_not_implemented",
  383. }
  384. c.JSON(http.StatusNotImplemented, gin.H{
  385. "error": err,
  386. })
  387. }
  388. func RelayNotFound(c *gin.Context) {
  389. err := OpenAIError{
  390. Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
  391. Type: "invalid_request_error",
  392. Param: "",
  393. Code: "",
  394. }
  395. c.JSON(http.StatusNotFound, gin.H{
  396. "error": err,
  397. })
  398. }