group.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package dbdata
  2. import (
  3. "errors"
  4. "fmt"
  5. "net"
  6. "time"
  7. "github.com/bjdgyc/anylink/base"
  8. )
  // String constants shared by the group configuration code.
const (
	// Allow is one of the two values GroupLinkAcl.Action takes.
	Allow = "allow"
	// Deny is the other value GroupLinkAcl.Action takes.
	Deny = "deny"
	// All is the wildcard RouteInclude value. SetGroup keeps an entry with
	// this Val without CIDR parsing. It treats an include list that is empty
	// or holds only All as the default route.
	All = "all"
)
  14. type GroupLinkAcl struct {
  15. // 自上而下匹配 默认 allow * *
  16. Action string `json:"action"` // allow、deny
  17. Val string `json:"val"`
  18. Port uint16 `json:"port"`
  19. IpNet *net.IPNet `json:"ip_net"`
  20. Note string `json:"note"`
  21. }
  // ValData is one configured value with an optional note. SetGroup uses it
// for a group's ClientDns, RouteInclude and RouteExclude entries.
type ValData struct {
	Val    string `json:"val"`     // raw value: an IP, a CIDR, or All for routes
	IpMask string `json:"ip_mask"` // "ip/netmask" text SetGroup derives from Val for routes
	Note   string `json:"note"`
}
  27. // type Group struct {
  28. // Id int `json:"id" xorm:"pk autoincr not null"`
  29. // Name string `json:"name" xorm:"not null unique"`
  30. // Note string `json:"note"`
  31. // AllowLan bool `json:"allow_lan"`
  32. // ClientDns []ValData `json:"client_dns"`
  33. // RouteInclude []ValData `json:"route_include"`
  34. // RouteExclude []ValData `json:"route_exclude"`
  35. // LinkAcl []GroupLinkAcl `json:"link_acl"`
  36. // Bandwidth int `json:"bandwidth"` // 带宽限制
  37. // Status int8 `json:"status"` // 1正常
  38. // CreatedAt time.Time `json:"created_at"`
  39. // UpdatedAt time.Time `json:"updated_at"`
  40. // }
  41. func GetGroupNames() []string {
  42. var datas []Group
  43. err := Find(&datas, 0, 0)
  44. if err != nil {
  45. base.Error(err)
  46. return nil
  47. }
  48. var names []string
  49. for _, v := range datas {
  50. names = append(names, v.Name)
  51. }
  52. return names
  53. }
  54. func SetGroup(g *Group) error {
  55. var err error
  56. if g.Name == "" {
  57. return errors.New("用户组名错误")
  58. }
  59. // 判断数据
  60. routeInclude := []ValData{}
  61. for _, v := range g.RouteInclude {
  62. if v.Val != "" {
  63. if v.Val == All {
  64. routeInclude = append(routeInclude, v)
  65. continue
  66. }
  67. ipMask, _, err := parseIpNet(v.Val)
  68. if err != nil {
  69. return errors.New("RouteInclude 错误" + err.Error())
  70. }
  71. v.IpMask = ipMask
  72. routeInclude = append(routeInclude, v)
  73. }
  74. }
  75. g.RouteInclude = routeInclude
  76. routeExclude := []ValData{}
  77. for _, v := range g.RouteExclude {
  78. if v.Val != "" {
  79. ipMask, _, err := parseIpNet(v.Val)
  80. if err != nil {
  81. return errors.New("RouteExclude 错误" + err.Error())
  82. }
  83. v.IpMask = ipMask
  84. routeExclude = append(routeExclude, v)
  85. }
  86. }
  87. g.RouteExclude = routeExclude
  88. // 转换数据
  89. linkAcl := []GroupLinkAcl{}
  90. for _, v := range g.LinkAcl {
  91. if v.Val != "" {
  92. _, ipNet, err := parseIpNet(v.Val)
  93. if err != nil {
  94. return errors.New("GroupLinkAcl 错误" + err.Error())
  95. }
  96. v.IpNet = ipNet
  97. linkAcl = append(linkAcl, v)
  98. }
  99. }
  100. g.LinkAcl = linkAcl
  101. // DNS 判断
  102. clientDns := []ValData{}
  103. for _, v := range g.ClientDns {
  104. if v.Val != "" {
  105. ip := net.ParseIP(v.Val)
  106. if ip.String() != v.Val {
  107. return errors.New("DNS IP 错误")
  108. }
  109. clientDns = append(clientDns, v)
  110. }
  111. }
  112. if len(routeInclude) == 0 || (len(routeInclude) == 1 && routeInclude[0].Val == "all") {
  113. if len(clientDns) == 0 {
  114. return errors.New("默认路由,必须设置一个DNS")
  115. }
  116. }
  117. g.ClientDns = clientDns
  118. g.UpdatedAt = time.Now()
  119. if g.Id > 0 {
  120. err = Set(g)
  121. } else {
  122. err = Add(g)
  123. }
  124. return err
  125. }
  // parseIpNet parses the CIDR string s with net.ParseCIDR and returns three
// values:
//   - the address from s (formatted by net.IP.String), then "/", then the
//     prefix rendered as a mask address. For example, "192.168.1.5/24" gives
//     "192.168.1.5/255.255.255.0". The host bits are NOT cleared.
//   - the parsed network, e.g. 192.168.1.0/24.
//   - the parse error, if any. On error the string is "" and the network is
//     nil.
func parseIpNet(s string) (string, *net.IPNet, error) {
	host, network, err := net.ParseCIDR(s)
	if err != nil {
		return "", nil, err
	}
	// Viewing the mask bytes as an IP makes String() print it as an address
	// (dotted quad for IPv4 prefixes).
	netmask := net.IP(network.Mask)
	return fmt.Sprintf("%v/%v", host, netmask), network, nil
}