123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- package service
- import (
- "fmt"
- "gorm.io/gorm"
- "x-ui/database"
- "x-ui/database/model"
- "x-ui/util/common"
- "x-ui/xray"
- )
// InboundService groups the database operations on model.Inbound records.
// It carries no state of its own; every method obtains the database handle
// through database.GetDB.
type InboundService struct{}
- func (s *InboundService) GetInbounds(userId int) ([]*model.Inbound, error) {
- db := database.GetDB()
- var inbounds []*model.Inbound
- err := db.Model(model.Inbound{}).Where("user_id = ?", userId).Find(&inbounds).Error
- if err != nil && err != gorm.ErrRecordNotFound {
- return nil, err
- }
- return inbounds, nil
- }
- func (s *InboundService) GetAllInbounds() ([]*model.Inbound, error) {
- db := database.GetDB()
- var inbounds []*model.Inbound
- err := db.Model(model.Inbound{}).Find(&inbounds).Error
- if err != nil && err != gorm.ErrRecordNotFound {
- return nil, err
- }
- return inbounds, nil
- }
- func (s *InboundService) checkPortExist(port int, ignoreId int) (bool, error) {
- db := database.GetDB()
- db = db.Model(model.Inbound{}).Where("port = ?", port)
- if ignoreId > 0 {
- db = db.Where("id != ?", ignoreId)
- }
- var count int64
- err := db.Count(&count).Error
- if err != nil {
- return false, err
- }
- return count > 0, nil
- }
- func (s *InboundService) AddInbound(inbound *model.Inbound) error {
- exist, err := s.checkPortExist(inbound.Port, 0)
- if err != nil {
- return err
- }
- if exist {
- return common.NewError("端口已存在:", inbound.Port)
- }
- db := database.GetDB()
- return db.Save(inbound).Error
- }
- func (s *InboundService) AddInbounds(inbounds []*model.Inbound) error {
- for _, inbound := range inbounds {
- exist, err := s.checkPortExist(inbound.Port, 0)
- if err != nil {
- return err
- }
- if exist {
- return common.NewError("端口已存在:", inbound.Port)
- }
- }
- db := database.GetDB()
- tx := db.Begin()
- var err error
- defer func() {
- if err == nil {
- tx.Commit()
- } else {
- tx.Rollback()
- }
- }()
- for _, inbound := range inbounds {
- err = tx.Save(inbound).Error
- if err != nil {
- return err
- }
- }
- return nil
- }
- func (s *InboundService) DelInbound(id int) error {
- db := database.GetDB()
- return db.Delete(model.Inbound{}, id).Error
- }
- func (s *InboundService) GetInbound(id int) (*model.Inbound, error) {
- db := database.GetDB()
- inbound := &model.Inbound{}
- err := db.Model(model.Inbound{}).First(inbound, id).Error
- if err != nil {
- return nil, err
- }
- return inbound, nil
- }
- func (s *InboundService) UpdateInbound(inbound *model.Inbound) error {
- exist, err := s.checkPortExist(inbound.Port, inbound.Id)
- if err != nil {
- return err
- }
- if exist {
- return common.NewError("端口已存在:", inbound.Port)
- }
- oldInbound, err := s.GetInbound(inbound.Id)
- if err != nil {
- return err
- }
- oldInbound.Up = inbound.Up
- oldInbound.Down = inbound.Down
- oldInbound.Total = inbound.Total
- oldInbound.Remark = inbound.Remark
- oldInbound.Enable = inbound.Enable
- oldInbound.ExpiryTime = inbound.ExpiryTime
- oldInbound.Listen = inbound.Listen
- oldInbound.Port = inbound.Port
- oldInbound.Protocol = inbound.Protocol
- oldInbound.Settings = inbound.Settings
- oldInbound.StreamSettings = inbound.StreamSettings
- oldInbound.Sniffing = inbound.Sniffing
- oldInbound.Tag = fmt.Sprintf("inbound-%v", inbound.Port)
- db := database.GetDB()
- return db.Save(oldInbound).Error
- }
- func (s *InboundService) AddTraffic(traffics []*xray.Traffic) (err error) {
- if len(traffics) == 0 {
- return nil
- }
- db := database.GetDB()
- db = db.Model(model.Inbound{})
- tx := db.Begin()
- defer func() {
- if err != nil {
- tx.Rollback()
- } else {
- tx.Commit()
- }
- }()
- for _, traffic := range traffics {
- if traffic.IsInbound {
- err = tx.Where("tag = ?", traffic.Tag).
- UpdateColumn("up", gorm.Expr("up + ?", traffic.Up)).
- UpdateColumn("down", gorm.Expr("down + ?", traffic.Down)).
- Error
- if err != nil {
- return
- }
- }
- }
- return
- }
- func (s *InboundService) DisableInvalidInbounds() (int64, error) {
- db := database.GetDB()
- result := db.Model(model.Inbound{}).
- Where("up + down >= total and total > 0 and enable = ?", true).
- Update("enable", false)
- err := result.Error
- count := result.RowsAffected
- return count, err
- }
|