| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596 |
- package model
- import (
- "errors"
- "fmt"
- "net/url"
- "regexp"
- "strings"
- "time"
- "github.com/bytedance/sonic"
- "github.com/labring/aiproxy/core/common"
- log "github.com/sirupsen/logrus"
- "gorm.io/gorm"
- )
// PublicMCPStatus is the enabled/disabled state of a public MCP entry.
type PublicMCPStatus int

// Valid statuses start at 1; the zero value is treated elsewhere in this file
// as "not set" (defaulted on create, ignored as a filter).
const (
	PublicMCPStatusEnabled  PublicMCPStatus = 1
	PublicMCPStatusDisabled PublicMCPStatus = 2
)
// ErrPublicMCPNotFound is the resource name handed to HandleNotFound and
// HandleUpdateResult when a public MCP row is missing.
const ErrPublicMCPNotFound = "public mcp"

// ErrMCPReusingParamNotFound is the resource name handed to HandleNotFound and
// HandleUpdateResult when a reusing-param row is missing.
const ErrMCPReusingParamNotFound = "mcp reusing param"
// PublicMCPType names the kind of backend behind a public MCP entry.
type PublicMCPType string

// Known public MCP types, stored as their string value.
const (
	// PublicMCPTypeProxySSE is the "mcp_proxy_sse" type.
	PublicMCPTypeProxySSE PublicMCPType = "mcp_proxy_sse"
	// PublicMCPTypeProxyStreamable is the "mcp_proxy_streamable" type.
	PublicMCPTypeProxyStreamable PublicMCPType = "mcp_proxy_streamable"
	// PublicMCPTypeDocs is the "mcp_docs" type (read only).
	PublicMCPTypeDocs PublicMCPType = "mcp_docs"
	// PublicMCPTypeOpenAPI is the "mcp_openapi" type.
	PublicMCPTypeOpenAPI PublicMCPType = "mcp_openapi"
	// PublicMCPTypeEmbed is the "mcp_embed" type.
	PublicMCPTypeEmbed PublicMCPType = "mcp_embed"
)
// ProxyParamType says which part of a proxied request a parameter belongs to.
type ProxyParamType string

// Supported parameter locations.
const (
	// ParamTypeURL marks a parameter of kind "url".
	ParamTypeURL ProxyParamType = "url"
	// ParamTypeHeader marks a parameter of kind "header".
	ParamTypeHeader ProxyParamType = "header"
	// ParamTypeQuery marks a parameter of kind "query".
	ParamTypeQuery ProxyParamType = "query"
)
// ReusingParam describes one named parameter of an MCP configuration.
type ReusingParam struct {
	// Name is the parameter's name.
	Name string `json:"name"`
	// Description is free-form text about the parameter.
	Description string `json:"description"`
	// Required flags the parameter as mandatory.
	Required bool `json:"required"`
}
// MCPPrice holds the tool-call pricing of an MCP entry.
type MCPPrice struct {
	// DefaultToolsCallPrice is presumably the price used for tools without an
	// entry in ToolsCallPrices — TODO confirm against the billing code.
	DefaultToolsCallPrice float64 `json:"default_tools_call_price"`
	// ToolsCallPrices maps a key (presumably the tool name) to its price; it is
	// persisted as JSON text through the fastjson serializer.
	ToolsCallPrices map[string]float64 `json:"tools_call_prices" gorm:"serializer:fastjson;type:text"`
}
- type PublicMCPProxyReusingParam struct {
- ReusingParam
- Type ProxyParamType `json:"type"`
- }
- type PublicMCPProxyConfig struct {
- URL string `json:"url"`
- Querys map[string]string `json:"querys"`
- Headers map[string]string `json:"headers"`
- Reusing map[string]PublicMCPProxyReusingParam `json:"reusing"`
- }
// Params is an alias for a string-to-string map of parameter values.
type Params = map[string]string
- type PublicMCPReusingParam struct {
- MCPID string `gorm:"primaryKey" json:"mcp_id"`
- GroupID string `gorm:"primaryKey" json:"group_id"`
- CreatedAt time.Time `gorm:"index" json:"created_at"`
- UpdateAt time.Time `gorm:"index" json:"update_at"`
- Group *Group `gorm:"foreignKey:GroupID" json:"-"`
- Params Params `gorm:"serializer:fastjson;type:text" json:"params"`
- }
- func (p *PublicMCPReusingParam) BeforeCreate(_ *gorm.DB) (err error) {
- if p.MCPID == "" {
- return errors.New("mcp id is empty")
- }
- if p.GroupID == "" {
- return errors.New("group is empty")
- }
- return err
- }
- func (p *PublicMCPReusingParam) MarshalJSON() ([]byte, error) {
- type Alias PublicMCPReusingParam
- a := &struct {
- *Alias
- CreatedAt int64 `json:"created_at"`
- UpdateAt int64 `json:"update_at"`
- }{
- Alias: (*Alias)(p),
- CreatedAt: p.CreatedAt.UnixMilli(),
- UpdateAt: p.UpdateAt.UnixMilli(),
- }
- return sonic.Marshal(a)
- }
// MCPOpenAPIConfig configures an MCP entry that is backed by an OpenAPI
// document.
type MCPOpenAPIConfig struct {
	// OpenAPISpec, when non-empty, must be an http/https URL (checked in
	// PublicMCP.BeforeSave).
	OpenAPISpec string `json:"openapi_spec"`
	// OpenAPIContent is the alternative to OpenAPISpec; one of the two must
	// be non-empty.
	OpenAPIContent string `json:"openapi_content,omitempty"`
	// V2 is presumably a flag for OpenAPI/Swagger v2 documents — TODO confirm.
	V2 bool `json:"v2"`
	// ServerAddr is an optional server address.
	ServerAddr string `json:"server_addr,omitempty"`
	// Authorization is an optional authorization value.
	Authorization string `json:"authorization,omitempty"`
}
- type MCPEmbeddingConfig struct {
- Init map[string]string `json:"init"`
- Reusing map[string]ReusingParam `json:"reusing"`
- }
// validateMCPIDRegex matches IDs made only of ASCII letters, digits, '_' and '-'.
var validateMCPIDRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)

// validateMCPID returns an error when id is empty or contains characters
// other than ASCII letters, digits, '_' and '-'; otherwise it returns nil.
func validateMCPID(id string) error {
	switch {
	case id == "":
		return errors.New("mcp id is empty")
	case !validateMCPIDRegex.MatchString(id):
		return errors.New("mcp id is invalid")
	}

	return nil
}
- type TestConfig struct {
- Enabled bool `json:"enabled"`
- Params Params `json:"params"`
- }
- type PublicMCP struct {
- ID string `gorm:"primaryKey" json:"id"`
- CreatedAt time.Time `gorm:"index,autoCreateTime" json:"created_at"`
- UpdateAt time.Time `gorm:"index,autoUpdateTime" json:"update_at"`
- PublicMCPReusingParams []PublicMCPReusingParam `gorm:"foreignKey:MCPID" json:"-"`
- Name string `json:"name"`
- NameCN string `json:"name_cn,omitempty"`
- Status PublicMCPStatus `json:"status" gorm:"index;default:1"`
- Type PublicMCPType `json:"type,omitempty" gorm:"index"`
- Description string `json:"description"`
- DescriptionCN string `json:"description_cn,omitempty"`
- GitHubURL string `json:"github_url"`
- Readme string `json:"readme,omitempty" gorm:"type:text"`
- ReadmeCN string `json:"readme_cn,omitempty" gorm:"type:text"`
- ReadmeURL string `json:"readme_url,omitempty"`
- ReadmeCNURL string `json:"readme_cn_url,omitempty"`
- Tags []string `json:"tags,omitempty" gorm:"serializer:fastjson;type:text"`
- LogoURL string `json:"logo_url,omitempty"`
- Price MCPPrice `json:"price" gorm:"embedded"`
- ProxyConfig *PublicMCPProxyConfig `gorm:"serializer:fastjson;type:text" json:"proxy_config,omitempty"`
- OpenAPIConfig *MCPOpenAPIConfig `gorm:"serializer:fastjson;type:text" json:"openapi_config,omitempty"`
- EmbedConfig *MCPEmbeddingConfig `gorm:"serializer:fastjson;type:text" json:"embed_config,omitempty"`
- // only used by list tools
- TestConfig *TestConfig `gorm:"serializer:fastjson;type:text" json:"test_config,omitempty"`
- }
- func (p *PublicMCP) BeforeCreate(_ *gorm.DB) error {
- if err := validateMCPID(p.ID); err != nil {
- return err
- }
- if p.Status == 0 {
- p.Status = PublicMCPStatusEnabled
- }
- return nil
- }
- func (p *PublicMCP) BeforeSave(_ *gorm.DB) error {
- if p.OpenAPIConfig != nil {
- config := p.OpenAPIConfig
- if config.OpenAPISpec != "" {
- return validateHTTPURL(config.OpenAPISpec)
- }
- if config.OpenAPIContent != "" {
- return nil
- }
- return errors.New("openapi spec and content is empty")
- }
- if p.ProxyConfig != nil {
- config := p.ProxyConfig
- return validateHTTPURL(config.URL)
- }
- return nil
- }
// validateHTTPURL checks that str is a usable HTTP(S) URL.
//
// It returns an error when str is empty, cannot be parsed, has a scheme other
// than "http" or "https", or has no host (for example "http://" or
// "https:/path", which parse successfully but cannot be dialed).
func validateHTTPURL(str string) error {
	if str == "" {
		return errors.New("url is empty")
	}

	u, err := url.Parse(str)
	if err != nil {
		return err
	}

	if u.Scheme != "http" && u.Scheme != "https" {
		return errors.New("url scheme not support")
	}

	// Hostname strips any port, so "http://:8080" is rejected as well.
	if u.Hostname() == "" {
		return errors.New("url host is empty")
	}

	return nil
}
- func (p *PublicMCP) BeforeDelete(tx *gorm.DB) (err error) {
- return tx.Model(&PublicMCPReusingParam{}).
- Where("mcp_id = ?", p.ID).
- Delete(&PublicMCPReusingParam{}).
- Error
- }
- // CreatePublicMCP creates a new MCP
- func CreatePublicMCP(mcp *PublicMCP) error {
- err := DB.Create(mcp).Error
- if err != nil && errors.Is(err, gorm.ErrDuplicatedKey) {
- return errors.New("mcp server already exist")
- }
- return err
- }
- func SavePublicMCP(mcp *PublicMCP) (err error) {
- defer func() {
- if err == nil {
- if err := CacheDeletePublicMCP(mcp.ID); err != nil {
- log.Error("cache delete public mcp error: " + err.Error())
- }
- }
- }()
- return DB.
- Omit(
- "created_at",
- "update_at",
- ).
- Save(mcp).Error
- }
- func SavePublicMCPs(msps []PublicMCP) (err error) {
- defer func() {
- if err == nil {
- for _, mcp := range msps {
- if err := CacheDeletePublicMCP(mcp.ID); err != nil {
- log.Error("cache delete public mcp error: " + err.Error())
- }
- }
- }
- }()
- return DB.
- Omit(
- "created_at",
- "update_at",
- ).
- Save(msps).Error
- }
- // UpdatePublicMCP updates an existing MCP
- func UpdatePublicMCP(mcp *PublicMCP) (err error) {
- defer func() {
- if err == nil {
- if err := CacheDeletePublicMCP(mcp.ID); err != nil {
- log.Error("cache delete public mcp error: " + err.Error())
- }
- }
- }()
- selects := []string{
- "github_url",
- "description",
- "description_cn",
- "readme",
- "readme_cn",
- "readme_url",
- "readme_cn_url",
- "tags",
- "logo_url",
- "proxy_config",
- "openapi_config",
- "embed_config",
- "test_config",
- }
- if mcp.Status != 0 {
- selects = append(selects, "status")
- }
- if mcp.Name != "" {
- selects = append(selects, "name")
- }
- if mcp.NameCN != "" {
- selects = append(selects, "name_cn")
- }
- if mcp.Type != "" {
- selects = append(selects, "type")
- }
- if mcp.Price.DefaultToolsCallPrice != 0 ||
- len(mcp.Price.ToolsCallPrices) != 0 {
- selects = append(selects, "price")
- }
- result := DB.
- Select(selects).
- Where("id = ?", mcp.ID).
- Updates(mcp)
- return HandleUpdateResult(result, ErrPublicMCPNotFound)
- }
- func UpdatePublicMCPStatus(id string, status PublicMCPStatus) (err error) {
- defer func() {
- if err == nil {
- if err := CacheUpdatePublicMCPStatus(id, status); err != nil {
- log.Error("cache update public mcp status error: " + err.Error())
- }
- }
- }()
- result := DB.Model(&PublicMCP{}).Where("id = ?", id).Update("status", status)
- return HandleUpdateResult(result, ErrPublicMCPNotFound)
- }
- // DeletePublicMCP deletes an MCP by ID
- func DeletePublicMCP(id string) (err error) {
- defer func() {
- if err == nil {
- if err := CacheDeletePublicMCP(id); err != nil {
- log.Error("cache delete public mcp error: " + err.Error())
- }
- }
- }()
- if id == "" {
- return errors.New("MCP id is empty")
- }
- result := DB.Delete(&PublicMCP{ID: id})
- return HandleUpdateResult(result, ErrPublicMCPNotFound)
- }
- // GetPublicMCPByID retrieves an MCP by ID
- func GetPublicMCPByID(id string) (PublicMCP, error) {
- var mcp PublicMCP
- if id == "" {
- return mcp, errors.New("MCP id is empty")
- }
- err := DB.Where("id = ?", id).First(&mcp).Error
- return mcp, HandleNotFound(err, ErrPublicMCPNotFound)
- }
- // GetPublicMCPs retrieves MCPs with pagination and filtering
- func GetPublicMCPs(
- page, perPage int,
- id string,
- mcpType []PublicMCPType,
- keyword string,
- status PublicMCPStatus,
- ) (mcps []PublicMCP, total int64, err error) {
- tx := DB.Model(&PublicMCP{})
- if id != "" {
- tx = tx.Where("id = ?", id)
- }
- if status != 0 {
- tx = tx.Where("status = ?", status)
- }
- if len(mcpType) > 0 {
- tx = tx.Where("type IN (?)", mcpType)
- }
- if keyword != "" {
- var (
- conditions []string
- values []any
- )
- if id == "" {
- if !common.UsingSQLite {
- conditions = append(conditions, "id ILIKE ?")
- values = append(values, "%"+keyword+"%")
- } else {
- conditions = append(conditions, "id LIKE ?")
- values = append(values, "%"+keyword+"%")
- }
- }
- if !common.UsingSQLite {
- conditions = append(conditions, "name ILIKE ?")
- values = append(values, "%"+keyword+"%")
- } else {
- conditions = append(conditions, "name LIKE ?")
- values = append(values, "%"+keyword+"%")
- }
- if !common.UsingSQLite {
- conditions = append(conditions, "name_cn ILIKE ?")
- values = append(values, "%"+keyword+"%")
- } else {
- conditions = append(conditions, "name_cn LIKE ?")
- values = append(values, "%"+keyword+"%")
- }
- if !common.UsingSQLite {
- conditions = append(conditions, "description ILIKE ?")
- values = append(values, "%"+keyword+"%")
- } else {
- conditions = append(conditions, "description LIKE ?")
- values = append(values, "%"+keyword+"%")
- }
- if !common.UsingSQLite {
- conditions = append(conditions, "description_cn ILIKE ?")
- values = append(values, "%"+keyword+"%")
- } else {
- conditions = append(conditions, "description_cn LIKE ?")
- values = append(values, "%"+keyword+"%")
- }
- if !common.UsingSQLite {
- conditions = append(conditions, "readme ILIKE ?")
- values = append(values, "%"+keyword+"%")
- } else {
- conditions = append(conditions, "readme LIKE ?")
- values = append(values, "%"+keyword+"%")
- }
- if !common.UsingSQLite {
- conditions = append(conditions, "readme_cn ILIKE ?")
- values = append(values, "%"+keyword+"%")
- } else {
- conditions = append(conditions, "readme_cn LIKE ?")
- values = append(values, "%"+keyword+"%")
- }
- if len(conditions) > 0 {
- tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
- }
- }
- err = tx.Count(&total).Error
- if err != nil {
- return nil, 0, err
- }
- if total <= 0 {
- return nil, 0, nil
- }
- limit, offset := toLimitOffset(page, perPage)
- err = tx.
- Limit(limit).
- Offset(offset).
- Find(&mcps).
- Error
- return mcps, total, err
- }
- func GetAllPublicMCPs(status PublicMCPStatus) ([]PublicMCP, error) {
- var mcps []PublicMCP
- tx := DB.Model(&PublicMCP{})
- if status != 0 {
- tx = tx.Where("status = ?", status)
- }
- err := tx.Find(&mcps).Error
- return mcps, err
- }
- func GetPublicMCPsEnabled(ids []string) ([]string, error) {
- var mcpIDs []string
- err := DB.Model(&PublicMCP{}).
- Select("id").
- Where("id IN (?) AND status = ?", ids, PublicMCPStatusEnabled).
- Pluck("id", &mcpIDs).
- Error
- if err != nil {
- return nil, err
- }
- return mcpIDs, nil
- }
- func GetPublicMCPsEmbedConfig(ids []string) (map[string]MCPEmbeddingConfig, error) {
- var configs []struct {
- ID string
- EmbedConfig MCPEmbeddingConfig `gorm:"serializer:fastjson;type:text"`
- }
- err := DB.Model(&PublicMCP{}).
- Select("id, embed_config").
- Where("id IN (?)", ids).
- Find(&configs).Error
- if err != nil {
- return nil, err
- }
- configsMap := make(map[string]MCPEmbeddingConfig)
- for _, config := range configs {
- configsMap[config.ID] = config.EmbedConfig
- }
- return configsMap, nil
- }
- func SavePublicMCPReusingParam(param *PublicMCPReusingParam) (err error) {
- defer func() {
- if err == nil {
- if err := CacheDeletePublicMCPReusingParam(param.MCPID, param.GroupID); err != nil {
- log.Error("cache delete public mcp reusing param error: " + err.Error())
- }
- }
- }()
- return DB.Save(param).Error
- }
- // UpdatePublicMCPReusingParam updates an existing GroupMCPReusingParam
- func UpdatePublicMCPReusingParam(param *PublicMCPReusingParam) (err error) {
- defer func() {
- if err == nil {
- if err := CacheDeletePublicMCPReusingParam(param.MCPID, param.GroupID); err != nil {
- log.Error("cache delete public mcp reusing param error: " + err.Error())
- }
- }
- }()
- result := DB.
- Select([]string{
- "params",
- }).
- Where("mcp_id = ? AND group_id = ?", param.MCPID, param.GroupID).
- Updates(param)
- return HandleUpdateResult(result, ErrMCPReusingParamNotFound)
- }
- // DeletePublicMCPReusingParam deletes a GroupMCPReusingParam
- func DeletePublicMCPReusingParam(mcpID, groupID string) (err error) {
- defer func() {
- if err == nil {
- if err := CacheDeletePublicMCPReusingParam(mcpID, groupID); err != nil {
- log.Error("cache delete public mcp reusing param error: " + err.Error())
- }
- }
- }()
- if mcpID == "" || groupID == "" {
- return errors.New("MCP ID or Group ID is empty")
- }
- result := DB.
- Where("mcp_id = ? AND group_id = ?", mcpID, groupID).
- Delete(&PublicMCPReusingParam{})
- return HandleUpdateResult(result, ErrMCPReusingParamNotFound)
- }
- // GetPublicMCPReusingParam retrieves a GroupMCPReusingParam by MCP ID and Group ID
- func GetPublicMCPReusingParam(mcpID, groupID string) (PublicMCPReusingParam, error) {
- if mcpID == "" || groupID == "" {
- return PublicMCPReusingParam{}, errors.New("MCP ID or Group ID is empty")
- }
- var param PublicMCPReusingParam
- err := DB.Where("mcp_id = ? AND group_id = ?", mcpID, groupID).First(¶m).Error
- return param, HandleNotFound(err, ErrMCPReusingParamNotFound)
- }
|