testreq.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. package utils
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "github.com/bytedance/sonic"
  7. "github.com/labring/aiproxy/core/model"
  8. "github.com/labring/aiproxy/core/relay/mode"
  9. relaymodel "github.com/labring/aiproxy/core/relay/model"
  10. )
  // UnsupportedModelTypeError is the error BuildRequest returns when it has no
  // test request body for a model type.
  type UnsupportedModelTypeError struct {
  	// ModelType is the label of the rejected model type, as embedded in the
  	// error message.
  	ModelType string
  }

  // Error implements the error interface, quoting the rejected model type.
  func (u *UnsupportedModelTypeError) Error() string {
  	return fmt.Sprintf("model type '%s' not supported", u.ModelType)
  }

  // NewErrUnsupportedModelType returns a new *UnsupportedModelTypeError carrying
  // modelType.
  func NewErrUnsupportedModelType(modelType string) *UnsupportedModelTypeError {
  	unsupported := new(UnsupportedModelTypeError)
  	unsupported.ModelType = modelType
  	return unsupported
  }
  20. func BuildRequest(modelConfig model.ModelConfig) (io.Reader, mode.Mode, error) {
  21. switch modelConfig.Type {
  22. case mode.ChatCompletions:
  23. body, err := BuildChatCompletionRequest(modelConfig.Model)
  24. if err != nil {
  25. return nil, mode.Unknown, err
  26. }
  27. return body, mode.ChatCompletions, nil
  28. case mode.Completions:
  29. body, err := BuildCompletionsRequest(modelConfig.Model)
  30. if err != nil {
  31. return nil, mode.Unknown, err
  32. }
  33. return body, mode.Completions, nil
  34. case mode.Embeddings:
  35. body, err := BuildEmbeddingsRequest(modelConfig.Model)
  36. if err != nil {
  37. return nil, mode.Unknown, err
  38. }
  39. return body, mode.Embeddings, nil
  40. case mode.Moderations:
  41. body, err := BuildModerationsRequest(modelConfig.Model)
  42. if err != nil {
  43. return nil, mode.Unknown, err
  44. }
  45. return body, mode.Moderations, nil
  46. case mode.ImagesGenerations:
  47. body, err := BuildImagesGenerationsRequest(modelConfig)
  48. if err != nil {
  49. return nil, mode.Unknown, err
  50. }
  51. return body, mode.ImagesGenerations, nil
  52. case mode.ImagesEdits:
  53. return nil, mode.Unknown, NewErrUnsupportedModelType("edits")
  54. case mode.AudioSpeech:
  55. body, err := BuildAudioSpeechRequest(modelConfig.Model)
  56. if err != nil {
  57. return nil, mode.Unknown, err
  58. }
  59. return body, mode.AudioSpeech, nil
  60. case mode.AudioTranscription:
  61. return nil, mode.Unknown, NewErrUnsupportedModelType("audio transcription")
  62. case mode.AudioTranslation:
  63. return nil, mode.Unknown, NewErrUnsupportedModelType("audio translation")
  64. case mode.Rerank:
  65. body, err := BuildRerankRequest(modelConfig.Model)
  66. if err != nil {
  67. return nil, mode.Unknown, err
  68. }
  69. return body, mode.Rerank, nil
  70. case mode.ParsePdf:
  71. return nil, mode.Unknown, NewErrUnsupportedModelType("parse pdf")
  72. default:
  73. return nil, mode.Unknown, NewErrUnsupportedModelType(modelConfig.Type.String())
  74. }
  75. }
  76. func BuildChatCompletionRequest(model string) (io.Reader, error) {
  77. testRequest := &relaymodel.GeneralOpenAIRequest{
  78. Model: model,
  79. Messages: []relaymodel.Message{
  80. {
  81. Role: "user",
  82. Content: "hi",
  83. },
  84. },
  85. }
  86. jsonBytes, err := sonic.Marshal(testRequest)
  87. if err != nil {
  88. return nil, err
  89. }
  90. return bytes.NewReader(jsonBytes), nil
  91. }
  92. func BuildCompletionsRequest(model string) (io.Reader, error) {
  93. completionsRequest := &relaymodel.GeneralOpenAIRequest{
  94. Model: model,
  95. Prompt: "hi",
  96. }
  97. jsonBytes, err := sonic.Marshal(completionsRequest)
  98. if err != nil {
  99. return nil, err
  100. }
  101. return bytes.NewReader(jsonBytes), nil
  102. }
  103. func BuildEmbeddingsRequest(model string) (io.Reader, error) {
  104. embeddingsRequest := &relaymodel.GeneralOpenAIRequest{
  105. Model: model,
  106. Input: "hi",
  107. }
  108. jsonBytes, err := sonic.Marshal(embeddingsRequest)
  109. if err != nil {
  110. return nil, err
  111. }
  112. return bytes.NewReader(jsonBytes), nil
  113. }
  114. func BuildModerationsRequest(model string) (io.Reader, error) {
  115. moderationsRequest := &relaymodel.GeneralOpenAIRequest{
  116. Model: model,
  117. Input: "hi",
  118. }
  119. jsonBytes, err := sonic.Marshal(moderationsRequest)
  120. if err != nil {
  121. return nil, err
  122. }
  123. return bytes.NewReader(jsonBytes), nil
  124. }
  125. func BuildImagesGenerationsRequest(modelConfig model.ModelConfig) (io.Reader, error) {
  126. imagesGenerationsRequest := &relaymodel.GeneralOpenAIRequest{
  127. Model: modelConfig.Model,
  128. Prompt: "hi",
  129. Size: "1024x1024",
  130. }
  131. for size := range modelConfig.ImagePrices {
  132. imagesGenerationsRequest.Size = size
  133. break
  134. }
  135. jsonBytes, err := sonic.Marshal(imagesGenerationsRequest)
  136. if err != nil {
  137. return nil, err
  138. }
  139. return bytes.NewReader(jsonBytes), nil
  140. }
  141. func BuildAudioSpeechRequest(model string) (io.Reader, error) {
  142. audioSpeechRequest := &relaymodel.GeneralOpenAIRequest{
  143. Model: model,
  144. Input: "hi",
  145. }
  146. jsonBytes, err := sonic.Marshal(audioSpeechRequest)
  147. if err != nil {
  148. return nil, err
  149. }
  150. return bytes.NewReader(jsonBytes), nil
  151. }
  152. func BuildRerankRequest(model string) (io.Reader, error) {
  153. rerankRequest := &relaymodel.RerankRequest{
  154. Model: model,
  155. Query: "hi",
  156. Documents: []string{"hi"},
  157. }
  158. jsonBytes, err := sonic.Marshal(rerankRequest)
  159. if err != nil {
  160. return nil, err
  161. }
  162. return bytes.NewReader(jsonBytes), nil
  163. }