embed.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package doubao
  2. import (
  3. "github.com/bytedance/sonic/ast"
  4. "github.com/labring/aiproxy/core/relay/meta"
  5. )
  6. func patchEmbeddingsVisionInput(node *ast.Node) error {
  7. inputNode := node.Get("input")
  8. if !inputNode.Exists() {
  9. return nil
  10. }
  11. switch inputNode.TypeSafe() {
  12. case ast.V_ARRAY:
  13. return inputNode.ForEach(func(_ ast.Sequence, item *ast.Node) bool {
  14. switch item.TypeSafe() {
  15. case ast.V_STRING:
  16. text, err := item.String()
  17. if err != nil {
  18. return false
  19. }
  20. *item = ast.NewObject([]ast.Pair{
  21. ast.NewPair("type", ast.NewString("text")),
  22. ast.NewPair("text", ast.NewString(text)),
  23. })
  24. return true
  25. case ast.V_OBJECT:
  26. textNode := item.Get("text")
  27. if textNode.Exists() && textNode.TypeSafe() == ast.V_STRING {
  28. _, err := item.Set("type", ast.NewString("text"))
  29. return err == nil
  30. }
  31. imageNode := item.Get("image")
  32. if imageNode.Exists() && imageNode.TypeSafe() == ast.V_STRING {
  33. imageURL, err := imageNode.String()
  34. if err != nil {
  35. return false
  36. }
  37. _, err = item.Unset("image")
  38. if err != nil {
  39. return false
  40. }
  41. _, err = item.Set("type", ast.NewString("image_url"))
  42. if err != nil {
  43. return false
  44. }
  45. _, err = item.SetAny("image_url", map[string]string{
  46. "url": imageURL,
  47. })
  48. if err != nil {
  49. return false
  50. }
  51. }
  52. return true
  53. default:
  54. return false
  55. }
  56. })
  57. case ast.V_STRING:
  58. inputText, err := inputNode.String()
  59. if err != nil {
  60. return err
  61. }
  62. _, err = node.SetAny("input", []map[string]string{
  63. {
  64. "type": "text",
  65. "text": inputText,
  66. },
  67. })
  68. return err
  69. default:
  70. return nil
  71. }
  72. }
  73. func embeddingPreHandler(_ *meta.Meta, node *ast.Node) error {
  74. return patchEmbeddingsVisionResponse(node)
  75. }
  76. func patchEmbeddingsVisionResponse(node *ast.Node) error {
  77. dataNode := node.Get("data")
  78. if !dataNode.Exists() {
  79. return nil
  80. }
  81. switch dataNode.TypeSafe() {
  82. case ast.V_ARRAY:
  83. return nil
  84. case ast.V_OBJECT:
  85. embeddingNode := dataNode.Get("embedding")
  86. if !embeddingNode.Exists() {
  87. return nil
  88. }
  89. _, err := node.Unset("data")
  90. if err != nil {
  91. return err
  92. }
  93. _, err = node.SetAny("data", []map[string]any{
  94. {
  95. "embedding": embeddingNode,
  96. "object": "embedding",
  97. "index": 0,
  98. },
  99. })
  100. return err
  101. default:
  102. return nil
  103. }
  104. }