openai.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. package render
  2. import (
  3. "encoding/base64"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "github.com/bytedance/sonic"
  8. "github.com/gin-gonic/gin"
  9. "github.com/labring/aiproxy/core/common/conv"
  10. "github.com/labring/aiproxy/core/relay/model"
  11. )
  12. type OpenaiSSE struct {
  13. Data []byte
  14. }
  15. func (r *OpenaiSSE) Render(w http.ResponseWriter) error {
  16. r.WriteContentType(w)
  17. for _, bytes := range [][]byte{
  18. dataBytes,
  19. r.Data,
  20. nnBytes,
  21. } {
  22. // nosemgrep:
  23. // go.lang.security.audit.xss.no-direct-write-to-responsewriter.no-direct-write-to-responsewriter
  24. if _, err := w.Write(bytes); err != nil {
  25. return err
  26. }
  27. }
  28. return nil
  29. }
  30. func (r *OpenaiSSE) WriteContentType(w http.ResponseWriter) {
  31. WriteSSEContentType(w)
  32. }
  33. func OpenaiStringData(c *gin.Context, str string) {
  34. OpenaiBytesData(c, conv.StringToBytes(str))
  35. }
  36. func OpenaiBytesData(c *gin.Context, data []byte) {
  37. if len(c.Errors) > 0 {
  38. return
  39. }
  40. if c.IsAborted() {
  41. return
  42. }
  43. c.Render(-1, &OpenaiSSE{Data: data})
  44. c.Writer.Flush()
  45. }
  46. func OpenaiObjectData(c *gin.Context, object any) error {
  47. if len(c.Errors) > 0 {
  48. return c.Errors.Last()
  49. }
  50. if c.IsAborted() {
  51. return errors.New("context aborted")
  52. }
  53. jsonData, err := sonic.Marshal(object)
  54. if err != nil {
  55. return fmt.Errorf("error marshalling object: %w", err)
  56. }
  57. c.Render(-1, &OpenaiSSE{Data: jsonData})
  58. c.Writer.Flush()
  59. return nil
  60. }
  61. func OpenaiDone(c *gin.Context) {
  62. OpenaiStringData(c, DONE)
  63. }
  64. type OpenaiTtsSSE struct {
  65. Audio string // base64 encode audio data
  66. Usage *model.TextToSpeechUsage
  67. }
  68. func (r *OpenaiTtsSSE) Render(w http.ResponseWriter) error {
  69. r.WriteContentType(w)
  70. payload := model.TextToSpeechSSEResponse{
  71. Audio: r.Audio,
  72. Usage: r.Usage,
  73. }
  74. if r.Usage != nil {
  75. payload.Type = model.TextToSpeechSSEResponseTypeDone
  76. } else {
  77. payload.Type = model.TextToSpeechSSEResponseTypeDelta
  78. }
  79. jsonData, err := sonic.Marshal(payload)
  80. if err != nil {
  81. return fmt.Errorf("error marshalling object: %w", err)
  82. }
  83. for _, bytes := range [][]byte{
  84. dataBytes,
  85. jsonData,
  86. nnBytes,
  87. } {
  88. // nosemgrep:
  89. // go.lang.security.audit.xss.no-direct-write-to-responsewriter.no-direct-write-to-responsewriter
  90. if _, err := w.Write(bytes); err != nil {
  91. return err
  92. }
  93. }
  94. return nil
  95. }
  96. func (r *OpenaiTtsSSE) WriteContentType(w http.ResponseWriter) {
  97. WriteSSEContentType(w)
  98. }
  99. func OpenaiAudioData(c *gin.Context, audio string) {
  100. if len(c.Errors) > 0 {
  101. return
  102. }
  103. if c.IsAborted() {
  104. return
  105. }
  106. c.Render(-1, &OpenaiTtsSSE{Audio: audio})
  107. c.Writer.Flush()
  108. }
  109. type OpenaiAudioDataWriter struct {
  110. c *gin.Context
  111. }
  112. func NewOpenaiAudioDataWriter(c *gin.Context) *OpenaiAudioDataWriter {
  113. return &OpenaiAudioDataWriter{c: c}
  114. }
  115. func (w *OpenaiAudioDataWriter) Write(p []byte) (n int, err error) {
  116. OpenaiAudioData(w.c, base64.StdEncoding.EncodeToString(p))
  117. return len(p), nil
  118. }
  119. func OpenaiAudioDone(c *gin.Context, usage model.TextToSpeechUsage) {
  120. if len(c.Errors) > 0 {
  121. return
  122. }
  123. if c.IsAborted() {
  124. return
  125. }
  126. c.Render(-1, &OpenaiTtsSSE{Usage: &usage})
  127. c.Writer.Flush()
  128. }