ソースを参照

feat: 支持设置worker访问请求中的图片地址

[email protected] 1 年間 前
コミット
099068f543

+ 0 - 1
common/constants.go

@@ -12,7 +12,6 @@ import (
 var StartTime = time.Now().Unix() // unit: second
 var Version = "v0.0.0"            // this hard coding will be replaced automatically when building, no need to manually change
 var SystemName = "New API"
-var ServerAddress = "http://localhost:3000"
 var Footer = ""
 var Logo = ""
 var TopUpLink = ""

+ 9 - 0
constant/system.go

@@ -0,0 +1,9 @@
+package constant
+
+var ServerAddress = "http://localhost:3000"
+var WorkerUrl = ""
+var WorkerValidKey = ""
+
+func EnableWorker() bool {
+	return WorkerUrl != ""
+}

+ 2 - 2
controller/midjourney.go

@@ -235,7 +235,7 @@ func GetAllMidjourney(c *gin.Context) {
 	}
 	if constant.MjForwardUrlEnabled {
 		for i, midjourney := range logs {
-			midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
+			midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
 			logs[i] = midjourney
 		}
 	}
@@ -267,7 +267,7 @@ func GetUserMidjourney(c *gin.Context) {
 	}
 	if constant.MjForwardUrlEnabled {
 		for i, midjourney := range logs {
-			midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
+			midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
 			logs[i] = midjourney
 		}
 	}

+ 2 - 2
controller/misc.go

@@ -45,7 +45,7 @@ func GetStatus(c *gin.Context) {
 			"footer_html":              common.Footer,
 			"wechat_qrcode":            common.WeChatAccountQRCodeImageURL,
 			"wechat_login":             common.WeChatAuthEnabled,
-			"server_address":           common.ServerAddress,
+			"server_address":           constant.ServerAddress,
 			"price":                    constant.Price,
 			"min_topup":                constant.MinTopUp,
 			"turnstile_check":          common.TurnstileCheckEnabled,
@@ -203,7 +203,7 @@ func SendPasswordResetEmail(c *gin.Context) {
 	}
 	code := common.GenerateVerificationCode(0)
 	common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
-	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
+	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code)
 	subject := fmt.Sprintf("%s密码重置", common.SystemName)
 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+

+ 1 - 1
controller/topup.go

@@ -92,7 +92,7 @@ func RequestEpay(c *gin.Context) {
 		payType = epay.WechatPay
 	}
 	callBackAddress := service.GetCallbackAddress()
-	returnUrl, _ := url.Parse(common.ServerAddress + "/log")
+	returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
 	notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
 	tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
 	client := GetEpayClient()

+ 7 - 1
model/option.go

@@ -59,6 +59,8 @@ func InitOptionMap() {
 	common.OptionMap["SystemName"] = common.SystemName
 	common.OptionMap["Logo"] = common.Logo
 	common.OptionMap["ServerAddress"] = ""
+	common.OptionMap["WorkerUrl"] = constant.WorkerUrl
+	common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey
 	common.OptionMap["PayAddress"] = ""
 	common.OptionMap["CustomCallbackAddress"] = ""
 	common.OptionMap["EpayId"] = ""
@@ -232,7 +234,11 @@ func updateOptionMap(key string, value string) (err error) {
 	case "SMTPToken":
 		common.SMTPToken = value
 	case "ServerAddress":
-		common.ServerAddress = value
+		constant.ServerAddress = value
+	case "WorkerUrl":
+		constant.WorkerUrl = value
+	case "WorkerValidKey":
+		constant.WorkerValidKey = value
 	case "PayAddress":
 		constant.PayAddress = value
 	case "CustomCallbackAddress":

+ 2 - 1
model/token.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"gorm.io/gorm"
 	"one-api/common"
+	"one-api/constant"
 	"strconv"
 	"strings"
 )
@@ -297,7 +298,7 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
 						prompt = "您的额度已用尽"
 					}
 					if email != "" {
-						topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
+						topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress)
 						err = common.SendEmail(prompt, email,
 							fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
 						if err != nil {

+ 2 - 2
relay/channel/claude/relay-claude.go

@@ -138,11 +138,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 						// 判断是否是url
 						if strings.HasPrefix(imageUrl.Url, "http") {
 							// 是url,获取图片的类型和base64编码的数据
-							mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
+							mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
 							claudeMediaMessage.Source.MediaType = mimeType
 							claudeMediaMessage.Source.Data = data
 						} else {
-							_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
+							_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
 							if err != nil {
 								return nil, err
 							}

+ 1 - 1
relay/channel/gemini/relay-gemini.go

@@ -74,7 +74,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
 				if imageNum > GeminiVisionMaxImageNum {
 					continue
 				}
-				mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
+				mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
 				parts = append(parts, GeminiPart{
 					InlineData: &GeminiInlineData{
 						MimeType: mimeType,

+ 1 - 1
relay/relay-mj.go

@@ -111,7 +111,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
 	midjourneyTask.FinishTime = originTask.FinishTime
 	midjourneyTask.ImageUrl = ""
 	if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
-		midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
+		midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId
 		if originTask.Status != "SUCCESS" {
 			midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
 		}

+ 1 - 2
service/epay.go

@@ -1,13 +1,12 @@
 package service
 
 import (
-	"one-api/common"
 	"one-api/constant"
 )
 
 func GetCallbackAddress() string {
 	if constant.CustomCallbackAddress == "" {
-		return common.ServerAddress
+		return constant.ServerAddress
 	}
 	return constant.CustomCallbackAddress
 }

+ 15 - 22
common/image.go → service/image.go

@@ -1,4 +1,4 @@
-package common
+package service
 
 import (
 	"bytes"
@@ -8,7 +8,7 @@ import (
 	"golang.org/x/image/webp"
 	"image"
 	"io"
-	"net/http"
+	"one-api/common"
 	"strings"
 )
 
@@ -31,25 +31,13 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
 	return config, format, base64String, err
 }
 
-func IsImageUrl(url string) (bool, error) {
-	resp, err := http.Head(url)
-	if err != nil {
-		return false, err
-	}
-	if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
-		return false, nil
-	}
-	return true, nil
-}
-
 // GetImageFromUrl 获取图片的类型和base64编码的数据
 func GetImageFromUrl(url string) (mimeType string, data string, err error) {
-	isImage, err := IsImageUrl(url)
-	if !isImage {
+	resp, err := DoImageRequest(url)
+	if err != nil {
 		return
 	}
-	resp, err := http.Get(url)
-	if err != nil {
+	if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
 		return
 	}
 	defer resp.Body.Close()
@@ -64,16 +52,21 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
 }
 
 func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
-	response, err := http.Get(imageUrl)
+	response, err := DoImageRequest(imageUrl)
 	if err != nil {
-		SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
+		common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
 		return image.Config{}, "", err
 	}
 	defer response.Body.Close()
 
+	if response.StatusCode != 200 {
+		err = errors.New(fmt.Sprintf("fail to get image from url: %s", response.Status))
+		return image.Config{}, "", err
+	}
+
 	var readData []byte
 	for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
-		SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
+		common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
 
 		// 从response.Body读取更多的数据直到达到当前的限制
 		additionalData := make([]byte, limit-int64(len(readData)))
@@ -99,11 +92,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) {
 	config, format, err := image.DecodeConfig(reader)
 	if err != nil {
 		err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
-		SysLog(err.Error())
+		common.SysLog(err.Error())
 		config, err = webp.DecodeConfig(reader)
 		if err != nil {
 			err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
-			SysLog(err.Error())
+			common.SysLog(err.Error())
 		}
 		format = "webp"
 	}

+ 2 - 3
service/token_counter.go

@@ -79,11 +79,10 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
 	var err error
 	var format string
 	if strings.HasPrefix(imageUrl.Url, "http") {
-		common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
-		config, format, err = common.DecodeUrlImageData(imageUrl.Url)
+		config, format, err = DecodeUrlImageData(imageUrl.Url)
 	} else {
 		common.SysLog(fmt.Sprintf("decoding image"))
-		config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url)
+		config, format, _, err = DecodeBase64ImageData(imageUrl.Url)
 	}
 	if err != nil {
 		return 0, err

+ 26 - 0
service/worker.go

@@ -0,0 +1,26 @@
+package service
+
+import (
+	"bytes"
+	"fmt"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"strings"
+)
+
+func DoImageRequest(originUrl string) (resp *http.Response, err error) {
+	if constant.EnableWorker() {
+		common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
+		workerUrl := constant.WorkerUrl
+		if !strings.HasSuffix(workerUrl, "/") {
+			workerUrl += "/"
+		}
+		// post request to worker
+		data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`)
+		return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data))
+	} else {
+		common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
+		return http.Get(originUrl)
+	}
+}

+ 34 - 0
web/src/components/SystemSetting.js

@@ -27,6 +27,8 @@ const SystemSetting = () => {
     SMTPFrom: '',
     SMTPToken: '',
     ServerAddress: '',
+    WorkerUrl: '',
+    WorkerValidKey: '',
     EpayId: '',
     EpayKey: '',
     Price: 7.3,
@@ -145,6 +147,8 @@ const SystemSetting = () => {
       name === 'Notice' ||
       (name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') ||
       name === 'ServerAddress' ||
+      name === 'WorkerUrl' ||
+      name === 'WorkerValidKey' ||
       name === 'EpayId' ||
       name === 'EpayKey' ||
       name === 'Price' ||
@@ -172,6 +176,14 @@ const SystemSetting = () => {
     await updateOption('ServerAddress', ServerAddress);
   };
 
+  const submitWorker = async () => {
+    let WorkerUrl = removeTrailingSlash(inputs.WorkerUrl);
+    await updateOption('WorkerUrl', WorkerUrl);
+    if (inputs.WorkerValidKey !== '') {
+      await updateOption('WorkerValidKey', inputs.WorkerValidKey);
+    }
+  }
+
   const submitPayAddress = async () => {
     if (inputs.ServerAddress === '') {
       showError('请先填写服务器地址');
@@ -327,6 +339,28 @@ const SystemSetting = () => {
           <Form.Button onClick={submitServerAddress}>
             更新服务器地址
           </Form.Button>
+          <Header as='h3' inverted={isDark}>
+            代理设置(支持 <a href='https://github.com/Calcium-Ion/new-api-worker' target='_blank' rel='noreferrer'>new-api-worker</a>)
+          </Header>
+          <Form.Group widths='equal'>
+            <Form.Input
+              label='Worker地址,不填写则不启用代理'
+              placeholder='例如:https://workername.yourdomain.workers.dev'
+              value={inputs.WorkerUrl}
+              name='WorkerUrl'
+              onChange={handleInputChange}
+            />
+            <Form.Input
+              label='Worker密钥,根据你部署的 Worker 填写'
+              placeholder='例如:your_secret_key'
+              value={inputs.WorkerValidKey}
+              name='WorkerValidKey'
+              onChange={handleInputChange}
+            />
+          </Form.Group>
+          <Form.Button onClick={submitWorker}>
+            更新Worker设置
+          </Form.Button>
           <Divider />
           <Header as='h3' inverted={isDark}>
             支付设置(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)