group.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. package dbdata
  2. import (
  3. "errors"
  4. "fmt"
  5. "net"
  6. "regexp"
  7. "strings"
  8. "time"
  9. "github.com/bjdgyc/anylink/base"
  10. "golang.org/x/text/language"
  11. "golang.org/x/text/message"
  12. )
  // Constants shared by the group model.
  const (
  	// Allow and Deny are the values accepted in GroupLinkAcl.Action.
  	Allow = "allow"
  	Deny  = "deny"
  	// All, used as a RouteInclude value, stands for the default (full) route.
  	All = "all"

  	// DsMaxLen caps the total number of characters (commas excluded) in a
  	// domain split-tunnelling list: 20,000.
  	DsMaxLen = 20000
  )
  type (
  	// GroupLinkAcl is one link access-control rule of a group. Rules are
  	// matched top to bottom; the default is "allow * *".
  	GroupLinkAcl struct {
  		Action string     `json:"action"` // "allow" or "deny"
  		Val    string     `json:"val"`    // CIDR text; SetGroup parses it into IpNet
  		Port   uint16     `json:"port"`
  		IpNet  *net.IPNet `json:"ip_net"`
  		Note   string     `json:"note"`
  	}

  	// ValData is a single configured value (a route or a DNS server) with an
  	// optional note. For routes, SetGroup fills IpMask with the
  	// "ip/dotted-mask" form produced by parseIpNet.
  	ValData struct {
  		Val    string `json:"val"`
  		IpMask string `json:"ip_mask"`
  		Note   string `json:"note"`
  	}

  	// GroupNameId pairs a group's id with its name.
  	GroupNameId struct {
  		Id   int    `json:"id"`
  		Name string `json:"name"`
  	}
  )

  // Reference layout of the Group model. Group itself is not declared here and
  // this commented copy may be out of date:
  //
  // type Group struct {
  // 	Id               int                    `json:"id" xorm:"pk autoincr not null"`
  // 	Name             string                 `json:"name" xorm:"varchar(60) not null unique"`
  // 	Note             string                 `json:"note" xorm:"varchar(255)"`
  // 	AllowLan         bool                   `json:"allow_lan" xorm:"Bool"`
  // 	ClientDns        []ValData              `json:"client_dns" xorm:"Text"`
  // 	RouteInclude     []ValData              `json:"route_include" xorm:"Text"`
  // 	RouteExclude     []ValData              `json:"route_exclude" xorm:"Text"`
  // 	DsExcludeDomains string                 `json:"ds_exclude_domains" xorm:"Text"`
  // 	DsIncludeDomains string                 `json:"ds_include_domains" xorm:"Text"`
  // 	LinkAcl          []GroupLinkAcl         `json:"link_acl" xorm:"Text"`
  // 	Bandwidth        int                    `json:"bandwidth" xorm:"Int"`                           // bandwidth limit
  // 	Auth             map[string]interface{} `json:"auth" xorm:"not null default '{}' varchar(255)"` // authentication method
  // 	Status           int8                   `json:"status" xorm:"Int"`                              // 1 = normal
  // 	CreatedAt        time.Time              `json:"created_at" xorm:"DateTime created"`
  // 	UpdatedAt        time.Time              `json:"updated_at" xorm:"DateTime updated"`
  // }
  54. func GetGroupNames() []string {
  55. var datas []Group
  56. err := Find(&datas, 0, 0)
  57. if err != nil {
  58. base.Error(err)
  59. return nil
  60. }
  61. var names []string
  62. for _, v := range datas {
  63. names = append(names, v.Name)
  64. }
  65. return names
  66. }
  67. func GetGroupNamesNormal() []string {
  68. var datas []Group
  69. err := FindWhere(&datas, 0, 0, "status=1")
  70. if err != nil {
  71. base.Error(err)
  72. return nil
  73. }
  74. var names []string
  75. for _, v := range datas {
  76. names = append(names, v.Name)
  77. }
  78. return names
  79. }
  80. func GetGroupNamesIds() []GroupNameId {
  81. var datas []Group
  82. err := Find(&datas, 0, 0)
  83. if err != nil {
  84. base.Error(err)
  85. return nil
  86. }
  87. var names []GroupNameId
  88. for _, v := range datas {
  89. names = append(names, GroupNameId{Id: v.Id, Name: v.Name})
  90. }
  91. return names
  92. }
  93. func SetGroup(g *Group) error {
  94. var err error
  95. if g.Name == "" {
  96. return errors.New("用户组名错误")
  97. }
  98. // 判断数据
  99. routeInclude := []ValData{}
  100. for _, v := range g.RouteInclude {
  101. if v.Val != "" {
  102. if v.Val == All {
  103. routeInclude = append(routeInclude, v)
  104. continue
  105. }
  106. ipMask, ipNet, err := parseIpNet(v.Val)
  107. if err != nil {
  108. return errors.New("RouteInclude 错误" + err.Error())
  109. }
  110. // 给Mac系统下发路由时,必须是标准的网络地址
  111. if strings.Split(ipMask, "/")[0] != ipNet.IP.String() {
  112. errMsg := fmt.Sprintf("RouteInclude 错误: 网络地址错误,建议: %s 改为 %s", v.Val, ipNet)
  113. return errors.New(errMsg)
  114. }
  115. v.IpMask = ipMask
  116. routeInclude = append(routeInclude, v)
  117. }
  118. }
  119. g.RouteInclude = routeInclude
  120. routeExclude := []ValData{}
  121. for _, v := range g.RouteExclude {
  122. if v.Val != "" {
  123. ipMask, ipNet, err := parseIpNet(v.Val)
  124. if err != nil {
  125. return errors.New("RouteExclude 错误" + err.Error())
  126. }
  127. if strings.Split(ipMask, "/")[0] != ipNet.IP.String() {
  128. errMsg := fmt.Sprintf("RouteInclude 错误: 网络地址错误,建议: %s 改为 %s", v.Val, ipNet)
  129. return errors.New(errMsg)
  130. }
  131. v.IpMask = ipMask
  132. routeExclude = append(routeExclude, v)
  133. }
  134. }
  135. g.RouteExclude = routeExclude
  136. // 转换数据
  137. linkAcl := []GroupLinkAcl{}
  138. for _, v := range g.LinkAcl {
  139. if v.Val != "" {
  140. _, ipNet, err := parseIpNet(v.Val)
  141. if err != nil {
  142. return errors.New("GroupLinkAcl 错误" + err.Error())
  143. }
  144. v.IpNet = ipNet
  145. linkAcl = append(linkAcl, v)
  146. }
  147. }
  148. g.LinkAcl = linkAcl
  149. // DNS 判断
  150. clientDns := []ValData{}
  151. for _, v := range g.ClientDns {
  152. if v.Val != "" {
  153. ip := net.ParseIP(v.Val)
  154. if ip.String() != v.Val {
  155. return errors.New("DNS IP 错误")
  156. }
  157. clientDns = append(clientDns, v)
  158. }
  159. }
  160. // 是否默认路由
  161. isDefRoute := len(routeInclude) == 0 || (len(routeInclude) == 1 && routeInclude[0].Val == "all")
  162. if isDefRoute && len(clientDns) == 0 {
  163. return errors.New("默认路由,必须设置一个DNS")
  164. }
  165. g.ClientDns = clientDns
  166. // 域名拆分隧道,不能同时填写
  167. g.DsIncludeDomains = strings.TrimSpace(g.DsIncludeDomains)
  168. g.DsExcludeDomains = strings.TrimSpace(g.DsExcludeDomains)
  169. if g.DsIncludeDomains != "" && g.DsExcludeDomains != "" {
  170. return errors.New("包含/排除域名不能同时填写")
  171. }
  172. // 校验包含域名的格式
  173. err = CheckDomainNames(g.DsIncludeDomains)
  174. if err != nil {
  175. return errors.New("包含域名有误:" + err.Error())
  176. }
  177. // 校验排除域名的格式
  178. err = CheckDomainNames(g.DsExcludeDomains)
  179. if err != nil {
  180. return errors.New("排除域名有误:" + err.Error())
  181. }
  182. if isDefRoute && g.DsIncludeDomains != "" {
  183. return errors.New("默认路由, 不允许设置\"包含域名\", 请重新配置")
  184. }
  185. // 处理登入方式的逻辑
  186. defAuth := map[string]interface{}{
  187. "type": "local",
  188. }
  189. if len(g.Auth) == 0 {
  190. g.Auth = defAuth
  191. }
  192. authType := g.Auth["type"].(string)
  193. if authType == "local" {
  194. g.Auth = defAuth
  195. } else {
  196. if _, ok := authRegistry[authType]; !ok {
  197. return errors.New("未知的认证方式: " + authType)
  198. }
  199. auth := makeInstance(authType).(IUserAuth)
  200. err = auth.checkData(g.Auth)
  201. if err != nil {
  202. return err
  203. }
  204. // 重置Auth, 删除多余的key
  205. g.Auth = map[string]interface{}{
  206. "type": authType,
  207. authType: g.Auth[authType],
  208. }
  209. }
  210. g.UpdatedAt = time.Now()
  211. if g.Id > 0 {
  212. err = Set(g)
  213. } else {
  214. err = Add(g)
  215. }
  216. return err
  217. }
  218. func GroupAuthLogin(name, pwd string, authData map[string]interface{}) error {
  219. g := &Group{Auth: authData}
  220. authType := g.Auth["type"].(string)
  221. if _, ok := authRegistry[authType]; !ok {
  222. return errors.New("未知的认证方式: " + authType)
  223. }
  224. auth := makeInstance(authType).(IUserAuth)
  225. err := auth.checkData(g.Auth)
  226. if err != nil {
  227. return err
  228. }
  229. err = auth.checkUser(name, pwd, g)
  230. return err
  231. }
  // parseIpNet parses the CIDR string s. On success it returns:
  //   - s's address and mask joined as "address/dotted-mask"
  //     (e.g. "10.1.2.3/8" -> "10.1.2.3/255.0.0.0"); the address is kept
  //     as written, not reduced to the network address;
  //   - the network as returned by net.ParseCIDR.
  // On failure it returns "", nil and the ParseCIDR error.
  func parseIpNet(s string) (string, *net.IPNet, error) {
  	addr, network, parseErr := net.ParseCIDR(s)
  	if parseErr != nil {
  		return "", nil, parseErr
  	}
  	dottedMask := net.IP(network.Mask)
  	return fmt.Sprintf("%v/%v", addr, dottedMask), network, nil
  }
  241. func CheckDomainNames(domains string) error {
  242. if domains == "" {
  243. return nil
  244. }
  245. strLen := 0
  246. str_slice := strings.Split(domains, ",")
  247. for _, val := range str_slice {
  248. if val == "" {
  249. return errors.New(val + " 请以逗号分隔域名")
  250. }
  251. if !ValidateDomainName(val) {
  252. return errors.New(val + " 域名有误")
  253. }
  254. strLen += len(val)
  255. }
  256. if strLen > DsMaxLen {
  257. p := message.NewPrinter(language.English)
  258. return fmt.Errorf("字符长度超出限制,最大%s个(不包含逗号), 请删减一些域名", p.Sprintf("%d", DsMaxLen))
  259. }
  260. return nil
  261. }
  // domainNameRegexp matches ASCII domain names with the following shape:
  //   - one or more dot-terminated labels;
  //   - each such label starts with a letter or digit, continues with
  //     letters, digits or hyphens, and is at most 63 characters long;
  //   - a final label of 2-18 letters.
  // It is compiled once at package initialization rather than on every
  // ValidateDomainName call.
  var domainNameRegexp = regexp.MustCompile(`^([a-zA-Z0-9][-a-zA-Z0-9]{0,62}\.)+[A-Za-z]{2,18}$`)

  // ValidateDomainName reports whether domain matches domainNameRegexp,
  // e.g. "example.com" or "a.b.example.org".
  // It rejects single-label names ("localhost"), trailing dots, numeric
  // top-level labels and non-ASCII characters (IDNs must be given in punycode).
  func ValidateDomainName(domain string) bool {
  	return domainNameRegexp.MatchString(domain)
  }