publicmcp.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/url"
  6. "regexp"
  7. "strings"
  8. "time"
  9. "github.com/bytedance/sonic"
  10. "github.com/labring/aiproxy/core/common"
  11. log "github.com/sirupsen/logrus"
  12. "gorm.io/gorm"
  13. )
  14. type PublicMCPStatus int
  15. const (
  16. PublicMCPStatusEnabled PublicMCPStatus = iota + 1
  17. PublicMCPStatusDisabled
  18. )
  19. const (
  20. ErrPublicMCPNotFound = "public mcp"
  21. ErrMCPReusingParamNotFound = "mcp reusing param"
  22. )
  23. type PublicMCPType string
  24. const (
  25. PublicMCPTypeProxySSE PublicMCPType = "mcp_proxy_sse"
  26. PublicMCPTypeProxyStreamable PublicMCPType = "mcp_proxy_streamable"
  27. PublicMCPTypeDocs PublicMCPType = "mcp_docs" // read only
  28. PublicMCPTypeOpenAPI PublicMCPType = "mcp_openapi"
  29. PublicMCPTypeEmbed PublicMCPType = "mcp_embed"
  30. )
  31. type ProxyParamType string
  32. const (
  33. ParamTypeURL ProxyParamType = "url"
  34. ParamTypeHeader ProxyParamType = "header"
  35. ParamTypeQuery ProxyParamType = "query"
  36. )
  37. type ReusingParam struct {
  38. Name string `json:"name"`
  39. Description string `json:"description"`
  40. Required bool `json:"required"`
  41. }
  42. type MCPPrice struct {
  43. DefaultToolsCallPrice float64 `json:"default_tools_call_price"`
  44. ToolsCallPrices map[string]float64 `json:"tools_call_prices" gorm:"serializer:fastjson;type:text"`
  45. }
  46. type PublicMCPProxyReusingParam struct {
  47. ReusingParam
  48. Type ProxyParamType `json:"type"`
  49. }
  50. type PublicMCPProxyConfig struct {
  51. URL string `json:"url"`
  52. Querys map[string]string `json:"querys"`
  53. Headers map[string]string `json:"headers"`
  54. Reusing map[string]PublicMCPProxyReusingParam `json:"reusing"`
  55. }
  56. type Params = map[string]string
  57. type PublicMCPReusingParam struct {
  58. MCPID string `gorm:"primaryKey" json:"mcp_id"`
  59. GroupID string `gorm:"primaryKey" json:"group_id"`
  60. CreatedAt time.Time `gorm:"index" json:"created_at"`
  61. UpdateAt time.Time `gorm:"index" json:"update_at"`
  62. Group *Group `gorm:"foreignKey:GroupID" json:"-"`
  63. Params Params `gorm:"serializer:fastjson;type:text" json:"params"`
  64. }
  65. func (p *PublicMCPReusingParam) BeforeCreate(_ *gorm.DB) (err error) {
  66. if p.MCPID == "" {
  67. return errors.New("mcp id is empty")
  68. }
  69. if p.GroupID == "" {
  70. return errors.New("group is empty")
  71. }
  72. return err
  73. }
  74. func (p *PublicMCPReusingParam) MarshalJSON() ([]byte, error) {
  75. type Alias PublicMCPReusingParam
  76. a := &struct {
  77. *Alias
  78. CreatedAt int64 `json:"created_at"`
  79. UpdateAt int64 `json:"update_at"`
  80. }{
  81. Alias: (*Alias)(p),
  82. CreatedAt: p.CreatedAt.UnixMilli(),
  83. UpdateAt: p.UpdateAt.UnixMilli(),
  84. }
  85. return sonic.Marshal(a)
  86. }
  87. type MCPOpenAPIConfig struct {
  88. OpenAPISpec string `json:"openapi_spec"`
  89. OpenAPIContent string `json:"openapi_content,omitempty"`
  90. V2 bool `json:"v2"`
  91. ServerAddr string `json:"server_addr,omitempty"`
  92. Authorization string `json:"authorization,omitempty"`
  93. }
  94. type MCPEmbeddingConfig struct {
  95. Init map[string]string `json:"init"`
  96. Reusing map[string]ReusingParam `json:"reusing"`
  97. }
  98. var validateMCPIDRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
  99. func validateMCPID(id string) error {
  100. if id == "" {
  101. return errors.New("mcp id is empty")
  102. }
  103. if !validateMCPIDRegex.MatchString(id) {
  104. return errors.New("mcp id is invalid")
  105. }
  106. return nil
  107. }
  108. type TestConfig struct {
  109. Enabled bool `json:"enabled"`
  110. Params Params `json:"params"`
  111. }
  112. type PublicMCP struct {
  113. ID string `gorm:"primaryKey" json:"id"`
  114. CreatedAt time.Time `gorm:"index,autoCreateTime" json:"created_at"`
  115. UpdateAt time.Time `gorm:"index,autoUpdateTime" json:"update_at"`
  116. PublicMCPReusingParams []PublicMCPReusingParam `gorm:"foreignKey:MCPID" json:"-"`
  117. Name string `json:"name"`
  118. NameCN string `json:"name_cn,omitempty"`
  119. Status PublicMCPStatus `json:"status" gorm:"index;default:1"`
  120. Type PublicMCPType `json:"type,omitempty" gorm:"index"`
  121. Description string `json:"description"`
  122. DescriptionCN string `json:"description_cn,omitempty"`
  123. GitHubURL string `json:"github_url"`
  124. Readme string `json:"readme,omitempty" gorm:"type:text"`
  125. ReadmeCN string `json:"readme_cn,omitempty" gorm:"type:text"`
  126. ReadmeURL string `json:"readme_url,omitempty"`
  127. ReadmeCNURL string `json:"readme_cn_url,omitempty"`
  128. Tags []string `json:"tags,omitempty" gorm:"serializer:fastjson;type:text"`
  129. LogoURL string `json:"logo_url,omitempty"`
  130. Price MCPPrice `json:"price" gorm:"embedded"`
  131. ProxyConfig *PublicMCPProxyConfig `gorm:"serializer:fastjson;type:text" json:"proxy_config,omitempty"`
  132. OpenAPIConfig *MCPOpenAPIConfig `gorm:"serializer:fastjson;type:text" json:"openapi_config,omitempty"`
  133. EmbedConfig *MCPEmbeddingConfig `gorm:"serializer:fastjson;type:text" json:"embed_config,omitempty"`
  134. // only used by list tools
  135. TestConfig *TestConfig `gorm:"serializer:fastjson;type:text" json:"test_config,omitempty"`
  136. }
  137. func (p *PublicMCP) BeforeCreate(_ *gorm.DB) error {
  138. if err := validateMCPID(p.ID); err != nil {
  139. return err
  140. }
  141. if p.Status == 0 {
  142. p.Status = PublicMCPStatusEnabled
  143. }
  144. return nil
  145. }
  146. func (p *PublicMCP) BeforeSave(_ *gorm.DB) error {
  147. if p.OpenAPIConfig != nil {
  148. config := p.OpenAPIConfig
  149. if config.OpenAPISpec != "" {
  150. return validateHTTPURL(config.OpenAPISpec)
  151. }
  152. if config.OpenAPIContent != "" {
  153. return nil
  154. }
  155. return errors.New("openapi spec and content is empty")
  156. }
  157. if p.ProxyConfig != nil {
  158. config := p.ProxyConfig
  159. return validateHTTPURL(config.URL)
  160. }
  161. return nil
  162. }
  163. func validateHTTPURL(str string) error {
  164. if str == "" {
  165. return errors.New("url is empty")
  166. }
  167. u, err := url.Parse(str)
  168. if err != nil {
  169. return err
  170. }
  171. if u.Scheme != "http" && u.Scheme != "https" {
  172. return errors.New("url scheme not support")
  173. }
  174. return nil
  175. }
  176. func (p *PublicMCP) BeforeDelete(tx *gorm.DB) (err error) {
  177. return tx.Model(&PublicMCPReusingParam{}).
  178. Where("mcp_id = ?", p.ID).
  179. Delete(&PublicMCPReusingParam{}).
  180. Error
  181. }
  182. // CreatePublicMCP creates a new MCP
  183. func CreatePublicMCP(mcp *PublicMCP) error {
  184. err := DB.Create(mcp).Error
  185. if err != nil && errors.Is(err, gorm.ErrDuplicatedKey) {
  186. return errors.New("mcp server already exist")
  187. }
  188. return err
  189. }
  190. func SavePublicMCP(mcp *PublicMCP) (err error) {
  191. defer func() {
  192. if err == nil {
  193. if err := CacheDeletePublicMCP(mcp.ID); err != nil {
  194. log.Error("cache delete public mcp error: " + err.Error())
  195. }
  196. }
  197. }()
  198. return DB.
  199. Omit(
  200. "created_at",
  201. "update_at",
  202. ).
  203. Save(mcp).Error
  204. }
  205. func SavePublicMCPs(msps []PublicMCP) (err error) {
  206. defer func() {
  207. if err == nil {
  208. for _, mcp := range msps {
  209. if err := CacheDeletePublicMCP(mcp.ID); err != nil {
  210. log.Error("cache delete public mcp error: " + err.Error())
  211. }
  212. }
  213. }
  214. }()
  215. return DB.
  216. Omit(
  217. "created_at",
  218. "update_at",
  219. ).
  220. Save(msps).Error
  221. }
  222. // UpdatePublicMCP updates an existing MCP
  223. func UpdatePublicMCP(mcp *PublicMCP) (err error) {
  224. defer func() {
  225. if err == nil {
  226. if err := CacheDeletePublicMCP(mcp.ID); err != nil {
  227. log.Error("cache delete public mcp error: " + err.Error())
  228. }
  229. }
  230. }()
  231. selects := []string{
  232. "github_url",
  233. "description",
  234. "description_cn",
  235. "readme",
  236. "readme_cn",
  237. "readme_url",
  238. "readme_cn_url",
  239. "tags",
  240. "logo_url",
  241. "proxy_config",
  242. "openapi_config",
  243. "embed_config",
  244. "test_config",
  245. }
  246. if mcp.Status != 0 {
  247. selects = append(selects, "status")
  248. }
  249. if mcp.Name != "" {
  250. selects = append(selects, "name")
  251. }
  252. if mcp.NameCN != "" {
  253. selects = append(selects, "name_cn")
  254. }
  255. if mcp.Type != "" {
  256. selects = append(selects, "type")
  257. }
  258. if mcp.Price.DefaultToolsCallPrice != 0 ||
  259. len(mcp.Price.ToolsCallPrices) != 0 {
  260. selects = append(selects, "price")
  261. }
  262. result := DB.
  263. Select(selects).
  264. Where("id = ?", mcp.ID).
  265. Updates(mcp)
  266. return HandleUpdateResult(result, ErrPublicMCPNotFound)
  267. }
  268. func UpdatePublicMCPStatus(id string, status PublicMCPStatus) (err error) {
  269. defer func() {
  270. if err == nil {
  271. if err := CacheUpdatePublicMCPStatus(id, status); err != nil {
  272. log.Error("cache update public mcp status error: " + err.Error())
  273. }
  274. }
  275. }()
  276. result := DB.Model(&PublicMCP{}).Where("id = ?", id).Update("status", status)
  277. return HandleUpdateResult(result, ErrPublicMCPNotFound)
  278. }
  279. // DeletePublicMCP deletes an MCP by ID
  280. func DeletePublicMCP(id string) (err error) {
  281. defer func() {
  282. if err == nil {
  283. if err := CacheDeletePublicMCP(id); err != nil {
  284. log.Error("cache delete public mcp error: " + err.Error())
  285. }
  286. }
  287. }()
  288. if id == "" {
  289. return errors.New("MCP id is empty")
  290. }
  291. result := DB.Delete(&PublicMCP{ID: id})
  292. return HandleUpdateResult(result, ErrPublicMCPNotFound)
  293. }
  294. // GetPublicMCPByID retrieves an MCP by ID
  295. func GetPublicMCPByID(id string) (PublicMCP, error) {
  296. var mcp PublicMCP
  297. if id == "" {
  298. return mcp, errors.New("MCP id is empty")
  299. }
  300. err := DB.Where("id = ?", id).First(&mcp).Error
  301. return mcp, HandleNotFound(err, ErrPublicMCPNotFound)
  302. }
  303. // GetPublicMCPs retrieves MCPs with pagination and filtering
  304. func GetPublicMCPs(
  305. page, perPage int,
  306. id string,
  307. mcpType []PublicMCPType,
  308. keyword string,
  309. status PublicMCPStatus,
  310. ) (mcps []PublicMCP, total int64, err error) {
  311. tx := DB.Model(&PublicMCP{})
  312. if id != "" {
  313. tx = tx.Where("id = ?", id)
  314. }
  315. if status != 0 {
  316. tx = tx.Where("status = ?", status)
  317. }
  318. if len(mcpType) > 0 {
  319. tx = tx.Where("type IN (?)", mcpType)
  320. }
  321. if keyword != "" {
  322. var (
  323. conditions []string
  324. values []any
  325. )
  326. if id == "" {
  327. if !common.UsingSQLite {
  328. conditions = append(conditions, "id ILIKE ?")
  329. values = append(values, "%"+keyword+"%")
  330. } else {
  331. conditions = append(conditions, "id LIKE ?")
  332. values = append(values, "%"+keyword+"%")
  333. }
  334. }
  335. if !common.UsingSQLite {
  336. conditions = append(conditions, "name ILIKE ?")
  337. values = append(values, "%"+keyword+"%")
  338. } else {
  339. conditions = append(conditions, "name LIKE ?")
  340. values = append(values, "%"+keyword+"%")
  341. }
  342. if !common.UsingSQLite {
  343. conditions = append(conditions, "name_cn ILIKE ?")
  344. values = append(values, "%"+keyword+"%")
  345. } else {
  346. conditions = append(conditions, "name_cn LIKE ?")
  347. values = append(values, "%"+keyword+"%")
  348. }
  349. if !common.UsingSQLite {
  350. conditions = append(conditions, "description ILIKE ?")
  351. values = append(values, "%"+keyword+"%")
  352. } else {
  353. conditions = append(conditions, "description LIKE ?")
  354. values = append(values, "%"+keyword+"%")
  355. }
  356. if !common.UsingSQLite {
  357. conditions = append(conditions, "description_cn ILIKE ?")
  358. values = append(values, "%"+keyword+"%")
  359. } else {
  360. conditions = append(conditions, "description_cn LIKE ?")
  361. values = append(values, "%"+keyword+"%")
  362. }
  363. if !common.UsingSQLite {
  364. conditions = append(conditions, "readme ILIKE ?")
  365. values = append(values, "%"+keyword+"%")
  366. } else {
  367. conditions = append(conditions, "readme LIKE ?")
  368. values = append(values, "%"+keyword+"%")
  369. }
  370. if !common.UsingSQLite {
  371. conditions = append(conditions, "readme_cn ILIKE ?")
  372. values = append(values, "%"+keyword+"%")
  373. } else {
  374. conditions = append(conditions, "readme_cn LIKE ?")
  375. values = append(values, "%"+keyword+"%")
  376. }
  377. if len(conditions) > 0 {
  378. tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
  379. }
  380. }
  381. err = tx.Count(&total).Error
  382. if err != nil {
  383. return nil, 0, err
  384. }
  385. if total <= 0 {
  386. return nil, 0, nil
  387. }
  388. limit, offset := toLimitOffset(page, perPage)
  389. err = tx.
  390. Limit(limit).
  391. Offset(offset).
  392. Find(&mcps).
  393. Error
  394. return mcps, total, err
  395. }
  396. func GetAllPublicMCPs(status PublicMCPStatus) ([]PublicMCP, error) {
  397. var mcps []PublicMCP
  398. tx := DB.Model(&PublicMCP{})
  399. if status != 0 {
  400. tx = tx.Where("status = ?", status)
  401. }
  402. err := tx.Find(&mcps).Error
  403. return mcps, err
  404. }
  405. func GetPublicMCPsEnabled(ids []string) ([]string, error) {
  406. var mcpIDs []string
  407. err := DB.Model(&PublicMCP{}).
  408. Select("id").
  409. Where("id IN (?) AND status = ?", ids, PublicMCPStatusEnabled).
  410. Pluck("id", &mcpIDs).
  411. Error
  412. if err != nil {
  413. return nil, err
  414. }
  415. return mcpIDs, nil
  416. }
  417. func GetPublicMCPsEmbedConfig(ids []string) (map[string]MCPEmbeddingConfig, error) {
  418. var configs []struct {
  419. ID string
  420. EmbedConfig MCPEmbeddingConfig `gorm:"serializer:fastjson;type:text"`
  421. }
  422. err := DB.Model(&PublicMCP{}).
  423. Select("id, embed_config").
  424. Where("id IN (?)", ids).
  425. Find(&configs).Error
  426. if err != nil {
  427. return nil, err
  428. }
  429. configsMap := make(map[string]MCPEmbeddingConfig)
  430. for _, config := range configs {
  431. configsMap[config.ID] = config.EmbedConfig
  432. }
  433. return configsMap, nil
  434. }
  435. func SavePublicMCPReusingParam(param *PublicMCPReusingParam) (err error) {
  436. defer func() {
  437. if err == nil {
  438. if err := CacheDeletePublicMCPReusingParam(param.MCPID, param.GroupID); err != nil {
  439. log.Error("cache delete public mcp reusing param error: " + err.Error())
  440. }
  441. }
  442. }()
  443. return DB.Save(param).Error
  444. }
  445. // UpdatePublicMCPReusingParam updates an existing GroupMCPReusingParam
  446. func UpdatePublicMCPReusingParam(param *PublicMCPReusingParam) (err error) {
  447. defer func() {
  448. if err == nil {
  449. if err := CacheDeletePublicMCPReusingParam(param.MCPID, param.GroupID); err != nil {
  450. log.Error("cache delete public mcp reusing param error: " + err.Error())
  451. }
  452. }
  453. }()
  454. result := DB.
  455. Select([]string{
  456. "params",
  457. }).
  458. Where("mcp_id = ? AND group_id = ?", param.MCPID, param.GroupID).
  459. Updates(param)
  460. return HandleUpdateResult(result, ErrMCPReusingParamNotFound)
  461. }
  462. // DeletePublicMCPReusingParam deletes a GroupMCPReusingParam
  463. func DeletePublicMCPReusingParam(mcpID, groupID string) (err error) {
  464. defer func() {
  465. if err == nil {
  466. if err := CacheDeletePublicMCPReusingParam(mcpID, groupID); err != nil {
  467. log.Error("cache delete public mcp reusing param error: " + err.Error())
  468. }
  469. }
  470. }()
  471. if mcpID == "" || groupID == "" {
  472. return errors.New("MCP ID or Group ID is empty")
  473. }
  474. result := DB.
  475. Where("mcp_id = ? AND group_id = ?", mcpID, groupID).
  476. Delete(&PublicMCPReusingParam{})
  477. return HandleUpdateResult(result, ErrMCPReusingParamNotFound)
  478. }
  479. // GetPublicMCPReusingParam retrieves a GroupMCPReusingParam by MCP ID and Group ID
  480. func GetPublicMCPReusingParam(mcpID, groupID string) (PublicMCPReusingParam, error) {
  481. if mcpID == "" || groupID == "" {
  482. return PublicMCPReusingParam{}, errors.New("MCP ID or Group ID is empty")
  483. }
  484. var param PublicMCPReusingParam
  485. err := DB.Where("mcp_id = ? AND group_id = ?", mcpID, groupID).First(&param).Error
  486. return param, HandleNotFound(err, ErrMCPReusingParamNotFound)
  487. }