inbound.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. package service
  2. import (
  3. "fmt"
  4. "gorm.io/gorm"
  5. "x-ui/database"
  6. "x-ui/database/model"
  7. "x-ui/util/common"
  8. "x-ui/xray"
  9. )
  10. type InboundService struct {
  11. }
  12. func (s *InboundService) GetInbounds(userId int) ([]*model.Inbound, error) {
  13. db := database.GetDB()
  14. var inbounds []*model.Inbound
  15. err := db.Model(model.Inbound{}).Where("user_id = ?", userId).Find(&inbounds).Error
  16. if err != nil && err != gorm.ErrRecordNotFound {
  17. return nil, err
  18. }
  19. return inbounds, nil
  20. }
  21. func (s *InboundService) GetAllInbounds() ([]*model.Inbound, error) {
  22. db := database.GetDB()
  23. var inbounds []*model.Inbound
  24. err := db.Model(model.Inbound{}).Find(&inbounds).Error
  25. if err != nil && err != gorm.ErrRecordNotFound {
  26. return nil, err
  27. }
  28. return inbounds, nil
  29. }
  30. func (s *InboundService) checkPortExist(port int, ignoreId int) (bool, error) {
  31. db := database.GetDB()
  32. db = db.Model(model.Inbound{}).Where("port = ?", port)
  33. if ignoreId > 0 {
  34. db = db.Where("id != ?", ignoreId)
  35. }
  36. var count int64
  37. err := db.Count(&count).Error
  38. if err != nil {
  39. return false, err
  40. }
  41. return count > 0, nil
  42. }
  43. func (s *InboundService) AddInbound(inbound *model.Inbound) error {
  44. exist, err := s.checkPortExist(inbound.Port, 0)
  45. if err != nil {
  46. return err
  47. }
  48. if exist {
  49. return common.NewError("端口已存在:", inbound.Port)
  50. }
  51. db := database.GetDB()
  52. return db.Save(inbound).Error
  53. }
  54. func (s *InboundService) AddInbounds(inbounds []*model.Inbound) error {
  55. for _, inbound := range inbounds {
  56. exist, err := s.checkPortExist(inbound.Port, 0)
  57. if err != nil {
  58. return err
  59. }
  60. if exist {
  61. return common.NewError("端口已存在:", inbound.Port)
  62. }
  63. }
  64. db := database.GetDB()
  65. tx := db.Begin()
  66. var err error
  67. defer func() {
  68. if err == nil {
  69. tx.Commit()
  70. } else {
  71. tx.Rollback()
  72. }
  73. }()
  74. for _, inbound := range inbounds {
  75. err = tx.Save(inbound).Error
  76. if err != nil {
  77. return err
  78. }
  79. }
  80. return nil
  81. }
  82. func (s *InboundService) DelInbound(id int) error {
  83. db := database.GetDB()
  84. return db.Delete(model.Inbound{}, id).Error
  85. }
  86. func (s *InboundService) GetInbound(id int) (*model.Inbound, error) {
  87. db := database.GetDB()
  88. inbound := &model.Inbound{}
  89. err := db.Model(model.Inbound{}).First(inbound, id).Error
  90. if err != nil {
  91. return nil, err
  92. }
  93. return inbound, nil
  94. }
  95. func (s *InboundService) UpdateInbound(inbound *model.Inbound) error {
  96. exist, err := s.checkPortExist(inbound.Port, inbound.Id)
  97. if err != nil {
  98. return err
  99. }
  100. if exist {
  101. return common.NewError("端口已存在:", inbound.Port)
  102. }
  103. oldInbound, err := s.GetInbound(inbound.Id)
  104. if err != nil {
  105. return err
  106. }
  107. oldInbound.Up = inbound.Up
  108. oldInbound.Down = inbound.Down
  109. oldInbound.Total = inbound.Total
  110. oldInbound.Remark = inbound.Remark
  111. oldInbound.Enable = inbound.Enable
  112. oldInbound.ExpiryTime = inbound.ExpiryTime
  113. oldInbound.Listen = inbound.Listen
  114. oldInbound.Port = inbound.Port
  115. oldInbound.Protocol = inbound.Protocol
  116. oldInbound.Settings = inbound.Settings
  117. oldInbound.StreamSettings = inbound.StreamSettings
  118. oldInbound.Sniffing = inbound.Sniffing
  119. oldInbound.Tag = fmt.Sprintf("inbound-%v", inbound.Port)
  120. db := database.GetDB()
  121. return db.Save(oldInbound).Error
  122. }
  123. func (s *InboundService) AddTraffic(traffics []*xray.Traffic) (err error) {
  124. if len(traffics) == 0 {
  125. return nil
  126. }
  127. db := database.GetDB()
  128. db = db.Model(model.Inbound{})
  129. tx := db.Begin()
  130. defer func() {
  131. if err != nil {
  132. tx.Rollback()
  133. } else {
  134. tx.Commit()
  135. }
  136. }()
  137. for _, traffic := range traffics {
  138. if traffic.IsInbound {
  139. err = tx.Where("tag = ?", traffic.Tag).
  140. UpdateColumn("up", gorm.Expr("up + ?", traffic.Up)).
  141. UpdateColumn("down", gorm.Expr("down + ?", traffic.Down)).
  142. Error
  143. if err != nil {
  144. return
  145. }
  146. }
  147. }
  148. return
  149. }
  150. func (s *InboundService) DisableInvalidInbounds() (int64, error) {
  151. db := database.GetDB()
  152. result := db.Model(model.Inbound{}).
  153. Where("up + down >= total and total > 0 and enable = ?", true).
  154. Update("enable", false)
  155. err := result.Error
  156. count := result.RowsAffected
  157. return count, err
  158. }