url_validator_test.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. package common
import (
	"strings"
	"testing"

	"github.com/QuantumNous/new-api/constant"
)
  6. func TestValidateRedirectURL(t *testing.T) {
  7. // Save original trusted domains and restore after test
  8. originalDomains := constant.TrustedRedirectDomains
  9. defer func() {
  10. constant.TrustedRedirectDomains = originalDomains
  11. }()
  12. tests := []struct {
  13. name string
  14. url string
  15. trustedDomains []string
  16. wantErr bool
  17. errContains string
  18. }{
  19. // Valid cases
  20. {
  21. name: "exact domain match with https",
  22. url: "https://example.com/success",
  23. trustedDomains: []string{"example.com"},
  24. wantErr: false,
  25. },
  26. {
  27. name: "exact domain match with http",
  28. url: "http://example.com/callback",
  29. trustedDomains: []string{"example.com"},
  30. wantErr: false,
  31. },
  32. {
  33. name: "subdomain match",
  34. url: "https://sub.example.com/success",
  35. trustedDomains: []string{"example.com"},
  36. wantErr: false,
  37. },
  38. {
  39. name: "case insensitive domain",
  40. url: "https://EXAMPLE.COM/success",
  41. trustedDomains: []string{"example.com"},
  42. wantErr: false,
  43. },
  44. // Invalid cases - untrusted domain
  45. {
  46. name: "untrusted domain",
  47. url: "https://evil.com/phishing",
  48. trustedDomains: []string{"example.com"},
  49. wantErr: true,
  50. errContains: "not in the trusted domains list",
  51. },
  52. {
  53. name: "suffix attack - fakeexample.com",
  54. url: "https://fakeexample.com/success",
  55. trustedDomains: []string{"example.com"},
  56. wantErr: true,
  57. errContains: "not in the trusted domains list",
  58. },
  59. {
  60. name: "empty trusted domains list",
  61. url: "https://example.com/success",
  62. trustedDomains: []string{},
  63. wantErr: true,
  64. errContains: "not in the trusted domains list",
  65. },
  66. // Invalid cases - scheme
  67. {
  68. name: "javascript scheme",
  69. url: "javascript:alert('xss')",
  70. trustedDomains: []string{"example.com"},
  71. wantErr: true,
  72. errContains: "invalid URL scheme",
  73. },
  74. {
  75. name: "data scheme",
  76. url: "data:text/html,<script>alert('xss')</script>",
  77. trustedDomains: []string{"example.com"},
  78. wantErr: true,
  79. errContains: "invalid URL scheme",
  80. },
  81. // Edge cases
  82. {
  83. name: "empty URL",
  84. url: "",
  85. trustedDomains: []string{"example.com"},
  86. wantErr: true,
  87. errContains: "invalid URL scheme",
  88. },
  89. }
  90. for _, tt := range tests {
  91. t.Run(tt.name, func(t *testing.T) {
  92. // Set up trusted domains for this test case
  93. constant.TrustedRedirectDomains = tt.trustedDomains
  94. err := ValidateRedirectURL(tt.url)
  95. if tt.wantErr {
  96. if err == nil {
  97. t.Errorf("ValidateRedirectURL(%q) expected error containing %q, got nil", tt.url, tt.errContains)
  98. return
  99. }
  100. if tt.errContains != "" && !contains(err.Error(), tt.errContains) {
  101. t.Errorf("ValidateRedirectURL(%q) error = %q, want error containing %q", tt.url, err.Error(), tt.errContains)
  102. }
  103. } else {
  104. if err != nil {
  105. t.Errorf("ValidateRedirectURL(%q) unexpected error: %v", tt.url, err)
  106. }
  107. }
  108. })
  109. }
  110. }
  // contains reports whether substr occurs within s.
// An empty substr always matches, and a substr longer than s never does.
// It delegates to strings.Contains. The helper is kept so existing callers in
// this package's tests continue to compile.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
  // findSubstring reports whether substr occurs anywhere in s.
// An empty substr matches any s, including "". A substr longer than s never
// matches. The search is delegated to strings.Contains rather than done with
// a manual byte-window scan.
func findSubstring(s, substr string) bool {
	return strings.Contains(s, substr)
}