generic.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673
  1. package oauth
  2. import (
  3. "context"
  4. "encoding/base64"
  5. stdjson "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "net/url"
  11. "regexp"
  12. "strconv"
  13. "strings"
  14. "time"
  15. "github.com/QuantumNous/new-api/common"
  16. "github.com/QuantumNous/new-api/i18n"
  17. "github.com/QuantumNous/new-api/logger"
  18. "github.com/QuantumNous/new-api/model"
  19. "github.com/QuantumNous/new-api/setting/system_setting"
  20. "github.com/gin-gonic/gin"
  21. "github.com/samber/lo"
  22. "github.com/tidwall/gjson"
  23. )
  24. // AuthStyle defines how to send client credentials
  25. const (
  26. AuthStyleAutoDetect = 0 // Auto-detect based on server response
  27. AuthStyleInParams = 1 // Send client_id and client_secret as POST parameters
  28. AuthStyleInHeader = 2 // Send as Basic Auth header
  29. )
  30. // GenericOAuthProvider implements OAuth for custom/generic OAuth providers
  31. type GenericOAuthProvider struct {
  32. config *model.CustomOAuthProvider
  33. }
  34. type accessPolicy struct {
  35. Logic string `json:"logic"`
  36. Conditions []accessCondition `json:"conditions"`
  37. Groups []accessPolicy `json:"groups"`
  38. }
  39. type accessCondition struct {
  40. Field string `json:"field"`
  41. Op string `json:"op"`
  42. Value any `json:"value"`
  43. }
  44. type accessPolicyFailure struct {
  45. Field string
  46. Op string
  47. Expected any
  48. Current any
  49. }
  50. var supportedAccessPolicyOps = []string{
  51. "eq",
  52. "ne",
  53. "gt",
  54. "gte",
  55. "lt",
  56. "lte",
  57. "in",
  58. "not_in",
  59. "contains",
  60. "not_contains",
  61. "exists",
  62. "not_exists",
  63. }
  64. // NewGenericOAuthProvider creates a new generic OAuth provider from config
  65. func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
  66. return &GenericOAuthProvider{config: config}
  67. }
  68. func (p *GenericOAuthProvider) GetName() string {
  69. return p.config.Name
  70. }
  71. func (p *GenericOAuthProvider) IsEnabled() bool {
  72. return p.config.Enabled
  73. }
  74. func (p *GenericOAuthProvider) GetConfig() *model.CustomOAuthProvider {
  75. return p.config
  76. }
  77. func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
  78. if code == "" {
  79. return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
  80. }
  81. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: code=%s...", p.config.Slug, code[:min(len(code), 10)])
  82. redirectUri := fmt.Sprintf("%s/oauth/%s", system_setting.ServerAddress, p.config.Slug)
  83. values := url.Values{}
  84. values.Set("grant_type", "authorization_code")
  85. values.Set("code", code)
  86. values.Set("redirect_uri", redirectUri)
  87. // Determine auth style
  88. authStyle := p.config.AuthStyle
  89. if authStyle == AuthStyleAutoDetect {
  90. // Default to params style for most OAuth servers
  91. authStyle = AuthStyleInParams
  92. }
  93. var req *http.Request
  94. var err error
  95. if authStyle == AuthStyleInParams {
  96. values.Set("client_id", p.config.ClientId)
  97. values.Set("client_secret", p.config.ClientSecret)
  98. }
  99. req, err = http.NewRequestWithContext(ctx, "POST", p.config.TokenEndpoint, strings.NewReader(values.Encode()))
  100. if err != nil {
  101. return nil, err
  102. }
  103. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  104. req.Header.Set("Accept", "application/json")
  105. if authStyle == AuthStyleInHeader {
  106. // Basic Auth
  107. credentials := base64.StdEncoding.EncodeToString([]byte(p.config.ClientId + ":" + p.config.ClientSecret))
  108. req.Header.Set("Authorization", "Basic "+credentials)
  109. }
  110. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: token_endpoint=%s, redirect_uri=%s, auth_style=%d",
  111. p.config.Slug, p.config.TokenEndpoint, redirectUri, authStyle)
  112. client := http.Client{
  113. Timeout: 20 * time.Second,
  114. }
  115. res, err := client.Do(req)
  116. if err != nil {
  117. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken error: %s", p.config.Slug, err.Error()))
  118. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
  119. }
  120. defer res.Body.Close()
  121. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response status: %d", p.config.Slug, res.StatusCode)
  122. body, err := io.ReadAll(res.Body)
  123. if err != nil {
  124. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken read body error: %s", p.config.Slug, err.Error()))
  125. return nil, err
  126. }
  127. bodyStr := string(body)
  128. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
  129. // Try to parse as JSON first
  130. var tokenResponse struct {
  131. AccessToken string `json:"access_token"`
  132. TokenType string `json:"token_type"`
  133. RefreshToken string `json:"refresh_token"`
  134. ExpiresIn int `json:"expires_in"`
  135. Scope string `json:"scope"`
  136. IDToken string `json:"id_token"`
  137. Error string `json:"error"`
  138. ErrorDesc string `json:"error_description"`
  139. }
  140. if err := common.Unmarshal(body, &tokenResponse); err != nil {
  141. // Try to parse as URL-encoded (some OAuth servers like GitHub return this format)
  142. parsedValues, parseErr := url.ParseQuery(bodyStr)
  143. if parseErr != nil {
  144. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken parse error: %s", p.config.Slug, err.Error()))
  145. return nil, err
  146. }
  147. tokenResponse.AccessToken = parsedValues.Get("access_token")
  148. tokenResponse.TokenType = parsedValues.Get("token_type")
  149. tokenResponse.Scope = parsedValues.Get("scope")
  150. }
  151. if tokenResponse.Error != "" {
  152. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken OAuth error: %s - %s",
  153. p.config.Slug, tokenResponse.Error, tokenResponse.ErrorDesc))
  154. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name}, tokenResponse.ErrorDesc)
  155. }
  156. if tokenResponse.AccessToken == "" {
  157. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken failed: empty access token", p.config.Slug))
  158. return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name})
  159. }
  160. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken success: scope=%s", p.config.Slug, tokenResponse.Scope)
  161. return &OAuthToken{
  162. AccessToken: tokenResponse.AccessToken,
  163. TokenType: tokenResponse.TokenType,
  164. RefreshToken: tokenResponse.RefreshToken,
  165. ExpiresIn: tokenResponse.ExpiresIn,
  166. Scope: tokenResponse.Scope,
  167. IDToken: tokenResponse.IDToken,
  168. }, nil
  169. }
  170. func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
  171. logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo: fetching user info from %s", p.config.Slug, p.config.UserInfoEndpoint)
  172. req, err := http.NewRequestWithContext(ctx, "GET", p.config.UserInfoEndpoint, nil)
  173. if err != nil {
  174. return nil, err
  175. }
  176. // Set authorization header
  177. tokenType := normalizeAuthorizationTokenType(token.TokenType)
  178. req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken))
  179. req.Header.Set("Accept", "application/json")
  180. client := http.Client{
  181. Timeout: 20 * time.Second,
  182. }
  183. res, err := client.Do(req)
  184. if err != nil {
  185. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo error: %s", p.config.Slug, err.Error()))
  186. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
  187. }
  188. defer res.Body.Close()
  189. logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response status: %d", p.config.Slug, res.StatusCode)
  190. if res.StatusCode != http.StatusOK {
  191. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: status=%d", p.config.Slug, res.StatusCode))
  192. return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
  193. }
  194. body, err := io.ReadAll(res.Body)
  195. if err != nil {
  196. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo read body error: %s", p.config.Slug, err.Error()))
  197. return nil, err
  198. }
  199. bodyStr := string(body)
  200. logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
  201. // Extract fields using gjson (supports JSONPath-like syntax)
  202. userId := gjson.Get(bodyStr, p.config.UserIdField).String()
  203. username := gjson.Get(bodyStr, p.config.UsernameField).String()
  204. displayName := gjson.Get(bodyStr, p.config.DisplayNameField).String()
  205. email := gjson.Get(bodyStr, p.config.EmailField).String()
  206. // If user ID field returns a number, convert it
  207. if userId == "" {
  208. // Try to get as number
  209. userIdNum := gjson.Get(bodyStr, p.config.UserIdField)
  210. if userIdNum.Exists() {
  211. userId = userIdNum.Raw
  212. // Remove quotes if present
  213. userId = strings.Trim(userId, "\"")
  214. }
  215. }
  216. if userId == "" {
  217. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: empty user ID (field: %s)", p.config.Slug, p.config.UserIdField))
  218. return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": p.config.Name})
  219. }
  220. logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
  221. p.config.Slug, userId, username, displayName, email)
  222. policyRaw := strings.TrimSpace(p.config.AccessPolicy)
  223. if policyRaw != "" {
  224. policy, err := parseAccessPolicy(policyRaw)
  225. if err != nil {
  226. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] invalid access policy: %s", p.config.Slug, err.Error()))
  227. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, nil, "invalid access policy configuration")
  228. }
  229. allowed, failure := evaluateAccessPolicy(bodyStr, policy)
  230. if !allowed {
  231. message := renderAccessDeniedMessage(p.config.AccessDeniedMessage, p.config.Name, bodyStr, failure)
  232. logger.LogWarn(ctx, fmt.Sprintf("[OAuth-Generic-%s] access denied by policy: field=%s op=%s expected=%v current=%v",
  233. p.config.Slug, failure.Field, failure.Op, failure.Expected, failure.Current))
  234. return nil, &AccessDeniedError{Message: message}
  235. }
  236. }
  237. return &OAuthUser{
  238. ProviderUserID: userId,
  239. Username: username,
  240. DisplayName: displayName,
  241. Email: email,
  242. Extra: map[string]any{
  243. "provider": p.config.Slug,
  244. },
  245. }, nil
  246. }
  247. func (p *GenericOAuthProvider) IsUserIDTaken(providerUserID string) bool {
  248. return model.IsProviderUserIdTaken(p.config.Id, providerUserID)
  249. }
  250. func (p *GenericOAuthProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
  251. foundUser, err := model.GetUserByOAuthBinding(p.config.Id, providerUserID)
  252. if err != nil {
  253. return err
  254. }
  255. *user = *foundUser
  256. return nil
  257. }
  258. func (p *GenericOAuthProvider) SetProviderUserID(user *model.User, providerUserID string) {
  259. // For generic providers, we store the binding in user_oauth_bindings table
  260. // This is handled separately in the OAuth controller
  261. }
  262. func (p *GenericOAuthProvider) GetProviderPrefix() string {
  263. return p.config.Slug + "_"
  264. }
  265. // GetProviderId returns the provider ID for binding purposes
  266. func (p *GenericOAuthProvider) GetProviderId() int {
  267. return p.config.Id
  268. }
  269. func normalizeAuthorizationTokenType(tokenType string) string {
  270. tokenType = strings.TrimSpace(tokenType)
  271. if tokenType == "" || strings.EqualFold(tokenType, "Bearer") {
  272. return "Bearer"
  273. }
  274. return tokenType
  275. }
  276. // IsGenericProvider returns true for generic providers
  277. func (p *GenericOAuthProvider) IsGenericProvider() bool {
  278. return true
  279. }
  280. func parseAccessPolicy(raw string) (*accessPolicy, error) {
  281. var policy accessPolicy
  282. if err := common.UnmarshalJsonStr(raw, &policy); err != nil {
  283. return nil, err
  284. }
  285. if err := validateAccessPolicy(&policy); err != nil {
  286. return nil, err
  287. }
  288. return &policy, nil
  289. }
  290. func validateAccessPolicy(policy *accessPolicy) error {
  291. if policy == nil {
  292. return errors.New("policy is nil")
  293. }
  294. logic := strings.ToLower(strings.TrimSpace(policy.Logic))
  295. if logic == "" {
  296. logic = "and"
  297. }
  298. if !lo.Contains([]string{"and", "or"}, logic) {
  299. return fmt.Errorf("unsupported policy logic: %s", logic)
  300. }
  301. policy.Logic = logic
  302. if len(policy.Conditions) == 0 && len(policy.Groups) == 0 {
  303. return errors.New("policy requires at least one condition or group")
  304. }
  305. for index := range policy.Conditions {
  306. if err := validateAccessCondition(&policy.Conditions[index], index); err != nil {
  307. return err
  308. }
  309. }
  310. for index := range policy.Groups {
  311. if err := validateAccessPolicy(&policy.Groups[index]); err != nil {
  312. return fmt.Errorf("invalid policy group[%d]: %w", index, err)
  313. }
  314. }
  315. return nil
  316. }
  317. func validateAccessCondition(condition *accessCondition, index int) error {
  318. if condition == nil {
  319. return fmt.Errorf("condition[%d] is nil", index)
  320. }
  321. condition.Field = strings.TrimSpace(condition.Field)
  322. if condition.Field == "" {
  323. return fmt.Errorf("condition[%d].field is required", index)
  324. }
  325. condition.Op = normalizePolicyOp(condition.Op)
  326. if !lo.Contains(supportedAccessPolicyOps, condition.Op) {
  327. return fmt.Errorf("condition[%d].op is unsupported: %s", index, condition.Op)
  328. }
  329. if lo.Contains([]string{"in", "not_in"}, condition.Op) {
  330. if _, ok := condition.Value.([]any); !ok {
  331. return fmt.Errorf("condition[%d].value must be an array for op %s", index, condition.Op)
  332. }
  333. }
  334. return nil
  335. }
  336. func evaluateAccessPolicy(body string, policy *accessPolicy) (bool, *accessPolicyFailure) {
  337. if policy == nil {
  338. return true, nil
  339. }
  340. logic := strings.ToLower(strings.TrimSpace(policy.Logic))
  341. if logic == "" {
  342. logic = "and"
  343. }
  344. hasAny := len(policy.Conditions) > 0 || len(policy.Groups) > 0
  345. if !hasAny {
  346. return true, nil
  347. }
  348. if logic == "or" {
  349. var firstFailure *accessPolicyFailure
  350. for _, cond := range policy.Conditions {
  351. ok, failure := evaluateAccessCondition(body, cond)
  352. if ok {
  353. return true, nil
  354. }
  355. if firstFailure == nil {
  356. firstFailure = failure
  357. }
  358. }
  359. for _, group := range policy.Groups {
  360. ok, failure := evaluateAccessPolicy(body, &group)
  361. if ok {
  362. return true, nil
  363. }
  364. if firstFailure == nil {
  365. firstFailure = failure
  366. }
  367. }
  368. return false, firstFailure
  369. }
  370. for _, cond := range policy.Conditions {
  371. ok, failure := evaluateAccessCondition(body, cond)
  372. if !ok {
  373. return false, failure
  374. }
  375. }
  376. for _, group := range policy.Groups {
  377. ok, failure := evaluateAccessPolicy(body, &group)
  378. if !ok {
  379. return false, failure
  380. }
  381. }
  382. return true, nil
  383. }
  384. func evaluateAccessCondition(body string, cond accessCondition) (bool, *accessPolicyFailure) {
  385. path := cond.Field
  386. op := cond.Op
  387. result := gjson.Get(body, path)
  388. current := gjsonResultToValue(result)
  389. failure := &accessPolicyFailure{
  390. Field: path,
  391. Op: op,
  392. Expected: cond.Value,
  393. Current: current,
  394. }
  395. switch op {
  396. case "exists":
  397. return result.Exists(), failure
  398. case "not_exists":
  399. return !result.Exists(), failure
  400. case "eq":
  401. return compareAny(current, cond.Value) == 0, failure
  402. case "ne":
  403. return compareAny(current, cond.Value) != 0, failure
  404. case "gt":
  405. return compareAny(current, cond.Value) > 0, failure
  406. case "gte":
  407. return compareAny(current, cond.Value) >= 0, failure
  408. case "lt":
  409. return compareAny(current, cond.Value) < 0, failure
  410. case "lte":
  411. return compareAny(current, cond.Value) <= 0, failure
  412. case "in":
  413. return valueInSlice(current, cond.Value), failure
  414. case "not_in":
  415. return !valueInSlice(current, cond.Value), failure
  416. case "contains":
  417. return containsValue(current, cond.Value), failure
  418. case "not_contains":
  419. return !containsValue(current, cond.Value), failure
  420. default:
  421. return false, failure
  422. }
  423. }
  424. func normalizePolicyOp(op string) string {
  425. return strings.ToLower(strings.TrimSpace(op))
  426. }
  427. func gjsonResultToValue(result gjson.Result) any {
  428. if !result.Exists() {
  429. return nil
  430. }
  431. if result.IsArray() {
  432. arr := result.Array()
  433. values := make([]any, 0, len(arr))
  434. for _, item := range arr {
  435. values = append(values, gjsonResultToValue(item))
  436. }
  437. return values
  438. }
  439. switch result.Type {
  440. case gjson.Null:
  441. return nil
  442. case gjson.True:
  443. return true
  444. case gjson.False:
  445. return false
  446. case gjson.Number:
  447. return result.Num
  448. case gjson.String:
  449. return result.String()
  450. case gjson.JSON:
  451. var data any
  452. if err := common.UnmarshalJsonStr(result.Raw, &data); err == nil {
  453. return data
  454. }
  455. return result.Raw
  456. default:
  457. return result.Value()
  458. }
  459. }
  460. func compareAny(left any, right any) int {
  461. if lf, ok := toFloat(left); ok {
  462. if rf, ok2 := toFloat(right); ok2 {
  463. switch {
  464. case lf < rf:
  465. return -1
  466. case lf > rf:
  467. return 1
  468. default:
  469. return 0
  470. }
  471. }
  472. }
  473. ls := strings.TrimSpace(fmt.Sprint(left))
  474. rs := strings.TrimSpace(fmt.Sprint(right))
  475. switch {
  476. case ls < rs:
  477. return -1
  478. case ls > rs:
  479. return 1
  480. default:
  481. return 0
  482. }
  483. }
  484. func toFloat(v any) (float64, bool) {
  485. switch value := v.(type) {
  486. case float64:
  487. return value, true
  488. case float32:
  489. return float64(value), true
  490. case int:
  491. return float64(value), true
  492. case int8:
  493. return float64(value), true
  494. case int16:
  495. return float64(value), true
  496. case int32:
  497. return float64(value), true
  498. case int64:
  499. return float64(value), true
  500. case uint:
  501. return float64(value), true
  502. case uint8:
  503. return float64(value), true
  504. case uint16:
  505. return float64(value), true
  506. case uint32:
  507. return float64(value), true
  508. case uint64:
  509. return float64(value), true
  510. case stdjson.Number:
  511. n, err := value.Float64()
  512. if err == nil {
  513. return n, true
  514. }
  515. case string:
  516. n, err := strconv.ParseFloat(strings.TrimSpace(value), 64)
  517. if err == nil {
  518. return n, true
  519. }
  520. }
  521. return 0, false
  522. }
  523. func valueInSlice(current any, expected any) bool {
  524. list, ok := expected.([]any)
  525. if !ok {
  526. return false
  527. }
  528. return lo.ContainsBy(list, func(item any) bool {
  529. return compareAny(current, item) == 0
  530. })
  531. }
  532. func containsValue(current any, expected any) bool {
  533. switch value := current.(type) {
  534. case string:
  535. target := strings.TrimSpace(fmt.Sprint(expected))
  536. return strings.Contains(value, target)
  537. case []any:
  538. return lo.ContainsBy(value, func(item any) bool {
  539. return compareAny(item, expected) == 0
  540. })
  541. }
  542. return false
  543. }
  544. func renderAccessDeniedMessage(template string, providerName string, body string, failure *accessPolicyFailure) string {
  545. defaultMessage := "Access denied: your account does not meet this provider's access requirements."
  546. message := strings.TrimSpace(template)
  547. if message == "" {
  548. return defaultMessage
  549. }
  550. if failure == nil {
  551. failure = &accessPolicyFailure{}
  552. }
  553. replacements := map[string]string{
  554. "{{provider}}": providerName,
  555. "{{field}}": failure.Field,
  556. "{{op}}": failure.Op,
  557. "{{required}}": fmt.Sprint(failure.Expected),
  558. "{{current}}": fmt.Sprint(failure.Current),
  559. }
  560. for key, value := range replacements {
  561. message = strings.ReplaceAll(message, key, value)
  562. }
  563. currentPattern := regexp.MustCompile(`\{\{current\.([^}]+)\}\}`)
  564. message = currentPattern.ReplaceAllStringFunc(message, func(token string) string {
  565. match := currentPattern.FindStringSubmatch(token)
  566. if len(match) != 2 {
  567. return ""
  568. }
  569. path := strings.TrimSpace(match[1])
  570. if path == "" {
  571. return ""
  572. }
  573. return strings.TrimSpace(gjson.Get(body, path).String())
  574. })
  575. requiredPattern := regexp.MustCompile(`\{\{required\.([^}]+)\}\}`)
  576. message = requiredPattern.ReplaceAllStringFunc(message, func(token string) string {
  577. match := requiredPattern.FindStringSubmatch(token)
  578. if len(match) != 2 {
  579. return ""
  580. }
  581. path := strings.TrimSpace(match[1])
  582. if failure.Field == path {
  583. return fmt.Sprint(failure.Expected)
  584. }
  585. return ""
  586. })
  587. return strings.TrimSpace(message)
  588. }