reusing.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. package controller
  2. import (
  3. "fmt"
  4. "net/url"
  5. "github.com/labring/aiproxy/core/model"
  6. )
  7. type ParamsFunc interface {
  8. GetParams() (map[string]string, error)
  9. }
  10. type groupParams struct {
  11. mcpID string
  12. groupID string
  13. }
  14. func (g *groupParams) GetParams() (map[string]string, error) {
  15. param, err := model.CacheGetPublicMCPReusingParam(g.mcpID, g.groupID)
  16. if err != nil {
  17. return nil, fmt.Errorf("failed to get reusing params: %w", err)
  18. }
  19. return param.Params, nil
  20. }
  21. func newGroupParams(mcpID, groupID string) ParamsFunc {
  22. return &groupParams{
  23. mcpID: mcpID,
  24. groupID: groupID,
  25. }
  26. }
  27. type staticParams map[string]string
  28. func (s staticParams) GetParams() (map[string]string, error) {
  29. return s, nil
  30. }
  31. // ReusingParamProcessor 统一处理reusing参数
  32. type ReusingParamProcessor struct {
  33. mcpID string
  34. paramsFunc ParamsFunc
  35. }
  36. func NewReusingParamProcessor(mcpID string, paramsFunc ParamsFunc) *ReusingParamProcessor {
  37. return &ReusingParamProcessor{
  38. mcpID: mcpID,
  39. paramsFunc: paramsFunc,
  40. }
  41. }
  42. // ProcessProxyReusingParams 处理代理类型的reusing参数
  43. func (p *ReusingParamProcessor) ProcessProxyReusingParams(
  44. reusingParams map[string]model.PublicMCPProxyReusingParam,
  45. headers map[string]string,
  46. backendQuery *url.Values,
  47. ) error {
  48. if len(reusingParams) == 0 {
  49. return nil
  50. }
  51. param, err := p.paramsFunc.GetParams()
  52. if err != nil {
  53. return err
  54. }
  55. for key, config := range reusingParams {
  56. value, exists := param[key]
  57. if !exists {
  58. if config.Required {
  59. return fmt.Errorf("required reusing parameter %s is missing", key)
  60. }
  61. continue
  62. }
  63. if err := p.applyProxyParam(key, value, config.Type, headers, backendQuery); err != nil {
  64. return err
  65. }
  66. }
  67. return nil
  68. }
  69. // ProcessEmbedReusingParams 处理嵌入类型的reusing参数
  70. func (p *ReusingParamProcessor) ProcessEmbedReusingParams(
  71. reusingParams map[string]model.ReusingParam,
  72. ) (map[string]string, error) {
  73. if len(reusingParams) == 0 {
  74. return nil, nil
  75. }
  76. param, err := p.paramsFunc.GetParams()
  77. if err != nil {
  78. return nil, fmt.Errorf("failed to get reusing params: %w", err)
  79. }
  80. reusingConfig := make(map[string]string)
  81. for key, config := range reusingParams {
  82. value, exists := param[key]
  83. if !exists {
  84. if config.Required {
  85. return nil, fmt.Errorf("required reusing parameter %s is missing", key)
  86. }
  87. continue
  88. }
  89. reusingConfig[key] = value
  90. }
  91. return reusingConfig, nil
  92. }
  93. // applyProxyParam 应用代理参数到相应位置
  94. func (p *ReusingParamProcessor) applyProxyParam(
  95. key, value string,
  96. paramType model.ProxyParamType,
  97. headers map[string]string,
  98. backendQuery *url.Values,
  99. ) error {
  100. switch paramType {
  101. case model.ParamTypeHeader:
  102. headers[key] = value
  103. case model.ParamTypeQuery:
  104. backendQuery.Set(key, value)
  105. case model.ParamTypeURL:
  106. return fmt.Errorf("URL parameter %s cannot be set via reusing", key)
  107. default:
  108. return fmt.Errorf("unknown param type: %s", paramType)
  109. }
  110. return nil
  111. }