Przeglądaj źródła

support base64 image

CaIon 2 lat temu
rodzic
commit
57d0fc3021

+ 64 - 0
common/image.go

@@ -0,0 +1,64 @@
+package common
+
+import (
+	"bytes"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"github.com/chai2010/webp"
+	"image"
+	"io"
+	"net/http"
+	"strings"
+)
+
+func DecodeBase64ImageData(base64String string) (image.Config, error) {
+	// 去除base64数据的URL前缀(如果有)
+	if idx := strings.Index(base64String, ","); idx != -1 {
+		base64String = base64String[idx+1:]
+	}
+
+	// 将base64字符串解码为字节切片
+	decodedData, err := base64.StdEncoding.DecodeString(base64String)
+	if err != nil {
+		fmt.Println("Error: Failed to decode base64 string")
+		return image.Config{}, err
+	}
+
+	// 创建一个bytes.Buffer用于存储解码后的数据
+	reader := bytes.NewReader(decodedData)
+	config, err := getImageConfig(reader)
+	return config, err
+}
+
+func DecodeUrlImageData(imageUrl string) (image.Config, error) {
+	response, err := http.Get(imageUrl)
+	if err != nil {
+		SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
+		return image.Config{}, err
+	}
+
+	// 限制读取的字节数,防止下载整个图片
+	limitReader := io.LimitReader(response.Body, 8192)
+	config, err := getImageConfig(limitReader)
+	response.Body.Close()
+	return config, err
+}
+
+func getImageConfig(reader io.Reader) (image.Config, error) {
+	// 读取图片的头部信息来获取图片尺寸
+	config, _, err := image.DecodeConfig(reader)
+	if err != nil {
+		err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
+		SysLog(err.Error())
+		config, err = webp.DecodeConfig(reader)
+		if err != nil {
+			err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
+			SysLog(err.Error())
+		}
+	}
+	if err != nil {
+		return image.Config{}, err
+	}
+	return config, nil
+}

+ 9 - 9
controller/midjourney.go

@@ -19,12 +19,12 @@ import (
 func UpdateMidjourneyTask() {
 	//revocer
 	imageModel := "midjourney"
+	defer func() {
+		if err := recover(); err != nil {
+			log.Printf("UpdateMidjourneyTask panic: %v", err)
+		}
+	}()
 	for {
-		defer func() {
-			if err := recover(); err != nil {
-				log.Printf("UpdateMidjourneyTask panic: %v", err)
-			}
-		}()
 		time.Sleep(time.Duration(15) * time.Second)
 		tasks := model.GetAllUnFinishTasks()
 		if len(tasks) != 0 {
@@ -55,7 +55,6 @@ func UpdateMidjourneyTask() {
 				// 设置超时时间
 				timeout := time.Second * 5
 				ctx, cancel := context.WithTimeout(context.Background(), timeout)
-				defer cancel()
 
 				// 使用带有超时的 context 创建新的请求
 				req = req.WithContext(ctx)
@@ -68,8 +67,8 @@ func UpdateMidjourneyTask() {
 					log.Printf("UpdateMidjourneyTask error: %v", err)
 					continue
 				}
-				defer resp.Body.Close()
 				responseBody, err := io.ReadAll(resp.Body)
+				resp.Body.Close()
 				log.Printf("responseBody: %s", string(responseBody))
 				var responseItem Midjourney
 				// err = json.NewDecoder(resp.Body).Decode(&responseItem)
@@ -83,12 +82,12 @@ func UpdateMidjourneyTask() {
 						if err1 == nil && err2 == nil {
 							jsonData, err3 := json.Marshal(responseWithoutStatus)
 							if err3 != nil {
-								log.Fatalf("UpdateMidjourneyTask error1: %v", err3)
+								log.Printf("UpdateMidjourneyTask error1: %v", err3)
 								continue
 							}
 							err4 := json.Unmarshal(jsonData, &responseStatus)
 							if err4 != nil {
-								log.Fatalf("UpdateMidjourneyTask error2: %v", err4)
+								log.Printf("UpdateMidjourneyTask error2: %v", err4)
 								continue
 							}
 							responseItem.Status = strconv.Itoa(responseStatus.Status)
@@ -138,6 +137,7 @@ func UpdateMidjourneyTask() {
 					log.Printf("UpdateMidjourneyTask error5: %v", err)
 				}
 				log.Printf("UpdateMidjourneyTask success: %v", task)
+				cancel()
 			}
 		}
 	}

+ 10 - 19
controller/relay-utils.go

@@ -4,7 +4,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/chai2010/webp"
 	"github.com/gin-gonic/gin"
 	"github.com/pkoukk/tiktoken-go"
 	"image"
@@ -75,29 +74,21 @@ func getImageToken(imageUrl MessageImageUrl) (int, error) {
 	if imageUrl.Detail == "low" {
 		return 85, nil
 	}
-
-	response, err := http.Get(imageUrl.Url)
+	var config image.Config
+	var err error
+	if strings.HasPrefix(imageUrl.Url, "http") {
+		common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
+		config, err = common.DecodeUrlImageData(imageUrl.Url)
+	} else {
+		common.SysLog(fmt.Sprintf("decoding image"))
+		config, err = common.DecodeBase64ImageData(imageUrl.Url)
+	}
 	if err != nil {
-		fmt.Println("Error: Failed to get the URL")
 		return 0, err
 	}
 
-	// 限制读取的字节数,防止下载整个图片
-	limitReader := io.LimitReader(response.Body, 8192)
-
-	response.Body.Close()
-
-	// 读取图片的头部信息来获取图片尺寸
-	config, _, err := image.DecodeConfig(limitReader)
-	if err != nil {
-		common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
-		config, err = webp.DecodeConfig(limitReader)
-		if err != nil {
-			common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
-		}
-	}
 	if config.Width == 0 || config.Height == 0 {
-		return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", err.Error()))
+		return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url))
 	}
 	if config.Width < 512 && config.Height < 512 {
 		if imageUrl.Detail == "auto" || imageUrl.Detail == "" {

+ 7 - 4
web/src/components/LogsTable.js

@@ -106,7 +106,7 @@ const LogsTable = () => {
                 return (
                     record.type === 0 || record.type === 2 ?
                         <div>
-                            <Tag color='grey' size='large' onClick={()=>{
+                            <Tag color='grey' size='large' onClick={() => {
                                 copyText(text)
                             }}> {text} </Tag>
                         </div>
@@ -133,7 +133,7 @@ const LogsTable = () => {
                 return (
                     record.type === 0 || record.type === 2 ?
                         <div>
-                            <Tag color={stringToColor(text)} size='large' onClick={()=>{
+                            <Tag color={stringToColor(text)} size='large' onClick={() => {
                                 copyText(text)
                             }}> {text} </Tag>
                         </div>
@@ -202,11 +202,12 @@ const LogsTable = () => {
     const [logType, setLogType] = useState(0);
     const isAdminUser = isAdmin();
     let now = new Date();
+    // 初始化start_timestamp为前一天
     const [inputs, setInputs] = useState({
         username: '',
         token_name: '',
         model_name: '',
-        start_timestamp: timestamp2string(0),
+        start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
         end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
         channel: ''
     });
@@ -338,7 +339,7 @@ const LogsTable = () => {
             showSuccess('已复制:' + text);
         } else {
             // setSearchKeyword(text);
-            Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text });
+            Modal.error({title: '无法复制到剪贴板,请手动复制', content: text});
         }
     }
 
@@ -412,10 +413,12 @@ const LogsTable = () => {
                                     name='model_name'
                                     onChange={value => handleInputChange(value, 'model_name')}/>
                         <Form.DatePicker field="start_timestamp" label='起始时间' style={{width: 272}}
+                                         initValue={start_timestamp}
                                          value={start_timestamp} type='dateTime'
                                          name='start_timestamp'
                                          onChange={value => handleInputChange(value, 'start_timestamp')}/>
                         <Form.DatePicker field="end_timestamp" fluid label='结束时间' style={{width: 272}}
+                                         initValue={end_timestamp}
                                          value={end_timestamp} type='dateTime'
                                          name='end_timestamp'
                                          onChange={value => handleInputChange(value, 'end_timestamp')}/>