| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258 |
- package service
- import (
- "context"
- "encoding/json"
- "io"
- "log"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relayconstant "one-api/relay/constant"
- "one-api/setting"
- "strconv"
- "strings"
- "time"
- "github.com/gin-gonic/gin"
- )
- func CoverActionToModelName(mjAction string) string {
- modelName := "mj_" + strings.ToLower(mjAction)
- if mjAction == constant.MjActionSwapFace {
- modelName = "swap_face"
- }
- return modelName
- }
- func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
- action := ""
- if relayMode == relayconstant.RelayModeMidjourneyAction {
- // plus request
- err := CoverPlusActionToNormalAction(midjRequest)
- if err != nil {
- return "", err, false
- }
- action = midjRequest.Action
- } else {
- switch relayMode {
- case relayconstant.RelayModeMidjourneyImagine:
- action = constant.MjActionImagine
- case relayconstant.RelayModeMidjourneyVideo:
- action = constant.MjActionVideo
- case relayconstant.RelayModeMidjourneyEdits:
- action = constant.MjActionEdits
- case relayconstant.RelayModeMidjourneyDescribe:
- action = constant.MjActionDescribe
- case relayconstant.RelayModeMidjourneyBlend:
- action = constant.MjActionBlend
- case relayconstant.RelayModeMidjourneyShorten:
- action = constant.MjActionShorten
- case relayconstant.RelayModeMidjourneyChange:
- action = midjRequest.Action
- case relayconstant.RelayModeMidjourneyModal:
- action = constant.MjActionModal
- case relayconstant.RelayModeSwapFace:
- action = constant.MjActionSwapFace
- case relayconstant.RelayModeMidjourneyUpload:
- action = constant.MjActionUpload
- case relayconstant.RelayModeMidjourneySimpleChange:
- params := ConvertSimpleChangeParams(midjRequest.Content)
- if params == nil {
- return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
- }
- action = params.Action
- case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
- return "", nil, true
- default:
- return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
- }
- }
- modelName := CoverActionToModelName(action)
- return modelName, nil, true
- }
- func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
- // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
- customId := midjRequest.CustomId
- if customId == "" {
- return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
- }
- splits := strings.Split(customId, "::")
- var action string
- if splits[1] == "JOB" {
- action = splits[2]
- } else {
- action = splits[1]
- }
- if action == "" {
- return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
- }
- if strings.Contains(action, "upsample") {
- index, err := strconv.Atoi(splits[3])
- if err != nil {
- return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
- }
- midjRequest.Index = index
- midjRequest.Action = constant.MjActionUpscale
- } else if strings.Contains(action, "variation") {
- midjRequest.Index = 1
- if action == "variation" {
- index, err := strconv.Atoi(splits[3])
- if err != nil {
- return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
- }
- midjRequest.Index = index
- midjRequest.Action = constant.MjActionVariation
- } else if action == "low_variation" {
- midjRequest.Action = constant.MjActionLowVariation
- } else if action == "high_variation" {
- midjRequest.Action = constant.MjActionHighVariation
- }
- } else if strings.Contains(action, "pan") {
- midjRequest.Action = constant.MjActionPan
- midjRequest.Index = 1
- } else if strings.Contains(action, "reroll") {
- midjRequest.Action = constant.MjActionReRoll
- midjRequest.Index = 1
- } else if action == "Outpaint" {
- midjRequest.Action = constant.MjActionZoom
- midjRequest.Index = 1
- } else if action == "CustomZoom" {
- midjRequest.Action = constant.MjActionCustomZoom
- midjRequest.Index = 1
- } else if action == "Inpaint" {
- midjRequest.Action = constant.MjActionInPaint
- midjRequest.Index = 1
- } else {
- return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
- }
- return nil
- }
- func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
- split := strings.Split(content, " ")
- if len(split) != 2 {
- return nil
- }
- action := strings.ToLower(split[1])
- changeParams := &dto.MidjourneyRequest{}
- changeParams.TaskId = split[0]
- if action[0] == 'u' {
- changeParams.Action = "UPSCALE"
- } else if action[0] == 'v' {
- changeParams.Action = "VARIATION"
- } else if action == "r" {
- changeParams.Action = "REROLL"
- return changeParams
- } else {
- return nil
- }
- index, err := strconv.Atoi(action[1:2])
- if err != nil || index < 1 || index > 4 {
- return nil
- }
- changeParams.Index = index
- return changeParams
- }
- func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
- var nullBytes []byte
- //var requestBody io.Reader
- //requestBody = c.Request.Body
- // read request body to json, delete accountFilter and notifyHook
- var mapResult map[string]interface{}
- // if get request, no need to read request body
- if c.Request.Method != "GET" {
- err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
- }
- if !setting.MjAccountFilterEnabled {
- delete(mapResult, "accountFilter")
- }
- if !setting.MjNotifyEnabled {
- delete(mapResult, "notifyHook")
- }
- //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
- // make new request with mapResult
- }
- if setting.MjModeClearEnabled {
- if prompt, ok := mapResult["prompt"].(string); ok {
- prompt = strings.Replace(prompt, "--fast", "", -1)
- prompt = strings.Replace(prompt, "--relax", "", -1)
- prompt = strings.Replace(prompt, "--turbo", "", -1)
- mapResult["prompt"] = prompt
- }
- }
- reqBody, err := json.Marshal(mapResult)
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
- }
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
- }
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- // 使用带有超时的 context 创建新的请求
- req = req.WithContext(ctx)
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- auth := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
- if auth != "" {
- auth = strings.TrimPrefix(auth, "Bearer ")
- req.Header.Set("mj-api-secret", auth)
- }
- defer cancel()
- resp, err := GetHttpClient().Do(req)
- if err != nil {
- common.SysError("do request failed: " + err.Error())
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
- }
- statusCode := resp.StatusCode
- //if statusCode != 200 {
- // return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
- //}
- err = req.Body.Close()
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
- }
- err = c.Request.Body.Close()
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
- }
- var midjResponse dto.MidjourneyResponse
- var midjourneyUploadsResponse dto.MidjourneyUploadResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
- }
- common.CloseResponseBodyGracefully(resp)
- respStr := string(responseBody)
- log.Printf("respStr: %s", respStr)
- if respStr == "" {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
- } else {
- err = json.Unmarshal(responseBody, &midjResponse)
- if err != nil {
- err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse)
- if err2 != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
- }
- }
- }
- //log.Printf("midjResponse: %v", midjResponse)
- //for k, v := range resp.Header {
- // c.Writer.Header().Set(k, v[0])
- //}
- return &dto.MidjourneyResponseWithStatusCode{
- StatusCode: statusCode,
- Response: midjResponse,
- }, responseBody, nil
- }
|