| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- package dbdata
- import (
- "errors"
- "fmt"
- "net"
- "strings"
- "time"
- )
- func GetPolicy(Username string) *Policy {
- policyData := &Policy{}
- err := One("Username", Username, policyData)
- if err != nil {
- return policyData
- }
- return policyData
- }
- func SetPolicy(p *Policy) error {
- var err error
- if p.Username == "" {
- return errors.New("用户名错误")
- }
- // 包含路由
- routeInclude := []ValData{}
- for _, v := range p.RouteInclude {
- if v.Val != "" {
- if v.Val == All {
- routeInclude = append(routeInclude, v)
- continue
- }
- ipMask, ipNet, err := parseIpNet(v.Val)
- if err != nil {
- return errors.New("RouteInclude 错误" + err.Error())
- }
- if strings.Split(ipMask, "/")[0] != ipNet.IP.String() {
- errMsg := fmt.Sprintf("RouteInclude 错误: 网络地址错误,建议: %s 改为 %s", v.Val, ipNet)
- return errors.New(errMsg)
- }
- v.IpMask = ipMask
- routeInclude = append(routeInclude, v)
- }
- }
- p.RouteInclude = routeInclude
- // 排除路由
- routeExclude := []ValData{}
- for _, v := range p.RouteExclude {
- if v.Val != "" {
- ipMask, ipNet, err := parseIpNet(v.Val)
- if err != nil {
- return errors.New("RouteExclude 错误" + err.Error())
- }
- if strings.Split(ipMask, "/")[0] != ipNet.IP.String() {
- errMsg := fmt.Sprintf("RouteInclude 错误: 网络地址错误,建议: %s 改为 %s", v.Val, ipNet)
- return errors.New(errMsg)
- }
- v.IpMask = ipMask
- routeExclude = append(routeExclude, v)
- }
- }
- p.RouteExclude = routeExclude
- // DNS 判断
- clientDns := []ValData{}
- for _, v := range p.ClientDns {
- if v.Val != "" {
- ip := net.ParseIP(v.Val)
- if ip.String() != v.Val {
- return errors.New("DNS IP 错误")
- }
- clientDns = append(clientDns, v)
- }
- }
- if len(routeInclude) == 0 || (len(routeInclude) == 1 && routeInclude[0].Val == "all") {
- if len(clientDns) == 0 {
- return errors.New("默认路由,必须设置一个DNS")
- }
- }
- p.ClientDns = clientDns
- // 域名拆分隧道,不能同时填写
- p.DsIncludeDomains = strings.TrimSpace(p.DsIncludeDomains)
- p.DsExcludeDomains = strings.TrimSpace(p.DsExcludeDomains)
- if p.DsIncludeDomains != "" && p.DsExcludeDomains != "" {
- return errors.New("包含/排除域名不能同时填写")
- }
- // 校验包含域名的格式
- err = CheckDomainNames(p.DsIncludeDomains)
- if err != nil {
- return errors.New("包含域名有误:" + err.Error())
- }
- // 校验排除域名的格式
- err = CheckDomainNames(p.DsExcludeDomains)
- if err != nil {
- return errors.New("排除域名有误:" + err.Error())
- }
- p.UpdatedAt = time.Now()
- if p.Id > 0 {
- err = Set(p)
- } else {
- err = Add(p)
- }
- return err
- }
|