policy.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. package dbdata
  2. import (
  3. "errors"
  4. "fmt"
  5. "net"
  6. "strings"
  7. "time"
  8. )
  9. func GetPolicy(Username string) *Policy {
  10. policyData := &Policy{}
  11. err := One("Username", Username, policyData)
  12. if err != nil {
  13. return policyData
  14. }
  15. return policyData
  16. }
  17. func SetPolicy(p *Policy) error {
  18. var err error
  19. if p.Username == "" {
  20. return errors.New("用户名错误")
  21. }
  22. // 包含路由
  23. routeInclude := []ValData{}
  24. for _, v := range p.RouteInclude {
  25. if v.Val != "" {
  26. if v.Val == All {
  27. routeInclude = append(routeInclude, v)
  28. continue
  29. }
  30. ipMask, ipNet, err := parseIpNet(v.Val)
  31. if err != nil {
  32. return errors.New("RouteInclude 错误" + err.Error())
  33. }
  34. if strings.Split(ipMask, "/")[0] != ipNet.IP.String() {
  35. errMsg := fmt.Sprintf("RouteInclude 错误: 网络地址错误,建议: %s 改为 %s", v.Val, ipNet)
  36. return errors.New(errMsg)
  37. }
  38. v.IpMask = ipMask
  39. routeInclude = append(routeInclude, v)
  40. }
  41. }
  42. p.RouteInclude = routeInclude
  43. // 排除路由
  44. routeExclude := []ValData{}
  45. for _, v := range p.RouteExclude {
  46. if v.Val != "" {
  47. ipMask, ipNet, err := parseIpNet(v.Val)
  48. if err != nil {
  49. return errors.New("RouteExclude 错误" + err.Error())
  50. }
  51. if strings.Split(ipMask, "/")[0] != ipNet.IP.String() {
  52. errMsg := fmt.Sprintf("RouteInclude 错误: 网络地址错误,建议: %s 改为 %s", v.Val, ipNet)
  53. return errors.New(errMsg)
  54. }
  55. v.IpMask = ipMask
  56. routeExclude = append(routeExclude, v)
  57. }
  58. }
  59. p.RouteExclude = routeExclude
  60. // DNS 判断
  61. clientDns := []ValData{}
  62. for _, v := range p.ClientDns {
  63. if v.Val != "" {
  64. ip := net.ParseIP(v.Val)
  65. if ip.String() != v.Val {
  66. return errors.New("DNS IP 错误")
  67. }
  68. clientDns = append(clientDns, v)
  69. }
  70. }
  71. if len(routeInclude) == 0 || (len(routeInclude) == 1 && routeInclude[0].Val == "all") {
  72. if len(clientDns) == 0 {
  73. return errors.New("默认路由,必须设置一个DNS")
  74. }
  75. }
  76. p.ClientDns = clientDns
  77. // 域名拆分隧道,不能同时填写
  78. p.DsIncludeDomains = strings.TrimSpace(p.DsIncludeDomains)
  79. p.DsExcludeDomains = strings.TrimSpace(p.DsExcludeDomains)
  80. if p.DsIncludeDomains != "" && p.DsExcludeDomains != "" {
  81. return errors.New("包含/排除域名不能同时填写")
  82. }
  83. // 校验包含域名的格式
  84. err = CheckDomainNames(p.DsIncludeDomains)
  85. if err != nil {
  86. return errors.New("包含域名有误:" + err.Error())
  87. }
  88. // 校验排除域名的格式
  89. err = CheckDomainNames(p.DsExcludeDomains)
  90. if err != nil {
  91. return errors.New("排除域名有误:" + err.Error())
  92. }
  93. p.UpdatedAt = time.Now()
  94. if p.Id > 0 {
  95. err = Set(p)
  96. } else {
  97. err = Add(p)
  98. }
  99. return err
  100. }