Explorar o código

Merge pull request #2086 from feitianbubu/pr/openai-tts-stream

feat: openai tts support streaming realtime audio
IcedTangerine hai 2 meses
pai
achega
e80b442dd6
Modificáronse 1 ficheiros con 43 adicións e 4 borrados
  1. 43 4
      relay/channel/openai/relay-openai.go

+ 43 - 4
relay/channel/openai/relay-openai.go

@@ -273,6 +273,39 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 	return &simpleResponse.Usage, nil
 	return &simpleResponse.Usage, nil
 }
 }
 
 
+func streamTTSResponse(c *gin.Context, resp *http.Response) {
+	c.Writer.WriteHeaderNow()
+
+	flusher, ok := c.Writer.(http.Flusher)
+	if !ok {
+		logger.LogWarn(c, "streaming not supported")
+		_, err := io.Copy(c.Writer, resp.Body)
+		if err != nil {
+			logger.LogWarn(c, err.Error())
+		}
+		return
+	}
+
+	buffer := make([]byte, 4096)
+	for {
+		n, err := resp.Body.Read(buffer)
+		//logger.LogInfo(c, fmt.Sprintf("streamTTSResponse read %d bytes", n))
+		if n > 0 {
+			if _, writeErr := c.Writer.Write(buffer[:n]); writeErr != nil {
+				logger.LogError(c, writeErr.Error())
+				break
+			}
+			flusher.Flush()
+		}
+		if err != nil {
+			if err != io.EOF {
+				logger.LogError(c, err.Error())
+			}
+			break
+		}
+	}
+}
+
 func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
 func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
 	// the status code has been judged before, if there is a body reading failure,
 	// the status code has been judged before, if there is a body reading failure,
 	// it should be regarded as a non-recoverable error, so it should not return err for external retry.
 	// it should be regarded as a non-recoverable error, so it should not return err for external retry.
@@ -288,10 +321,16 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		c.Writer.Header().Set(k, v[0])
 		c.Writer.Header().Set(k, v[0])
 	}
 	}
 	c.Writer.WriteHeader(resp.StatusCode)
 	c.Writer.WriteHeader(resp.StatusCode)
-	c.Writer.WriteHeaderNow()
-	_, err := io.Copy(c.Writer, resp.Body)
-	if err != nil {
-		logger.LogError(c, err.Error())
+
+	isStreaming := resp.ContentLength == -1 || resp.Header.Get("Content-Length") == ""
+	if isStreaming {
+		streamTTSResponse(c, resp)
+	} else {
+		c.Writer.WriteHeaderNow()
+		_, err := io.Copy(c.Writer, resp.Body)
+		if err != nil {
+			logger.LogError(c, err.Error())
+		}
 	}
 	}
 	return usage
 	return usage
 }
 }