token_test.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "net/http/httptest"
  8. "strconv"
  9. "strings"
  10. "testing"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/model"
  13. "github.com/gin-gonic/gin"
  14. "github.com/glebarez/sqlite"
  15. "gorm.io/gorm"
  16. )
  17. type tokenAPIResponse struct {
  18. Success bool `json:"success"`
  19. Message string `json:"message"`
  20. Data json.RawMessage `json:"data"`
  21. }
  22. type tokenPageResponse struct {
  23. Items []tokenResponseItem `json:"items"`
  24. }
  25. type tokenResponseItem struct {
  26. ID int `json:"id"`
  27. Name string `json:"name"`
  28. Key string `json:"key"`
  29. Status int `json:"status"`
  30. }
  31. type tokenKeyResponse struct {
  32. Key string `json:"key"`
  33. }
  34. func setupTokenControllerTestDB(t *testing.T) *gorm.DB {
  35. t.Helper()
  36. gin.SetMode(gin.TestMode)
  37. common.UsingSQLite = true
  38. common.UsingMySQL = false
  39. common.UsingPostgreSQL = false
  40. common.RedisEnabled = false
  41. dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
  42. db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
  43. if err != nil {
  44. t.Fatalf("failed to open sqlite db: %v", err)
  45. }
  46. model.DB = db
  47. model.LOG_DB = db
  48. if err := db.AutoMigrate(&model.Token{}); err != nil {
  49. t.Fatalf("failed to migrate token table: %v", err)
  50. }
  51. t.Cleanup(func() {
  52. sqlDB, err := db.DB()
  53. if err == nil {
  54. _ = sqlDB.Close()
  55. }
  56. })
  57. return db
  58. }
  59. func seedToken(t *testing.T, db *gorm.DB, userID int, name string, rawKey string) *model.Token {
  60. t.Helper()
  61. token := &model.Token{
  62. UserId: userID,
  63. Name: name,
  64. Key: rawKey,
  65. Status: common.TokenStatusEnabled,
  66. CreatedTime: 1,
  67. AccessedTime: 1,
  68. ExpiredTime: -1,
  69. RemainQuota: 100,
  70. UnlimitedQuota: true,
  71. Group: "default",
  72. }
  73. if err := db.Create(token).Error; err != nil {
  74. t.Fatalf("failed to create token: %v", err)
  75. }
  76. return token
  77. }
  78. func newAuthenticatedContext(t *testing.T, method string, target string, body any, userID int) (*gin.Context, *httptest.ResponseRecorder) {
  79. t.Helper()
  80. var requestBody *bytes.Reader
  81. if body != nil {
  82. payload, err := common.Marshal(body)
  83. if err != nil {
  84. t.Fatalf("failed to marshal request body: %v", err)
  85. }
  86. requestBody = bytes.NewReader(payload)
  87. } else {
  88. requestBody = bytes.NewReader(nil)
  89. }
  90. recorder := httptest.NewRecorder()
  91. ctx, _ := gin.CreateTestContext(recorder)
  92. ctx.Request = httptest.NewRequest(method, target, requestBody)
  93. if body != nil {
  94. ctx.Request.Header.Set("Content-Type", "application/json")
  95. }
  96. ctx.Set("id", userID)
  97. return ctx, recorder
  98. }
  99. func decodeAPIResponse(t *testing.T, recorder *httptest.ResponseRecorder) tokenAPIResponse {
  100. t.Helper()
  101. var response tokenAPIResponse
  102. if err := common.Unmarshal(recorder.Body.Bytes(), &response); err != nil {
  103. t.Fatalf("failed to decode api response: %v", err)
  104. }
  105. return response
  106. }
  107. func TestGetAllTokensMasksKeyInResponse(t *testing.T) {
  108. db := setupTokenControllerTestDB(t)
  109. token := seedToken(t, db, 1, "list-token", "abcd1234efgh5678")
  110. seedToken(t, db, 2, "other-user-token", "zzzz1234yyyy5678")
  111. ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/?p=1&size=10", nil, 1)
  112. GetAllTokens(ctx)
  113. response := decodeAPIResponse(t, recorder)
  114. if !response.Success {
  115. t.Fatalf("expected success response, got message: %s", response.Message)
  116. }
  117. var page tokenPageResponse
  118. if err := common.Unmarshal(response.Data, &page); err != nil {
  119. t.Fatalf("failed to decode token page response: %v", err)
  120. }
  121. if len(page.Items) != 1 {
  122. t.Fatalf("expected exactly one token, got %d", len(page.Items))
  123. }
  124. if page.Items[0].Key != token.GetMaskedKey() {
  125. t.Fatalf("expected masked key %q, got %q", token.GetMaskedKey(), page.Items[0].Key)
  126. }
  127. if strings.Contains(recorder.Body.String(), token.Key) {
  128. t.Fatalf("list response leaked raw token key: %s", recorder.Body.String())
  129. }
  130. }
  131. func TestSearchTokensMasksKeyInResponse(t *testing.T) {
  132. db := setupTokenControllerTestDB(t)
  133. token := seedToken(t, db, 1, "searchable-token", "ijkl1234mnop5678")
  134. ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/search?keyword=searchable-token&p=1&size=10", nil, 1)
  135. SearchTokens(ctx)
  136. response := decodeAPIResponse(t, recorder)
  137. if !response.Success {
  138. t.Fatalf("expected success response, got message: %s", response.Message)
  139. }
  140. var page tokenPageResponse
  141. if err := common.Unmarshal(response.Data, &page); err != nil {
  142. t.Fatalf("failed to decode search response: %v", err)
  143. }
  144. if len(page.Items) != 1 {
  145. t.Fatalf("expected exactly one search result, got %d", len(page.Items))
  146. }
  147. if page.Items[0].Key != token.GetMaskedKey() {
  148. t.Fatalf("expected masked search key %q, got %q", token.GetMaskedKey(), page.Items[0].Key)
  149. }
  150. if strings.Contains(recorder.Body.String(), token.Key) {
  151. t.Fatalf("search response leaked raw token key: %s", recorder.Body.String())
  152. }
  153. }
  154. func TestGetTokenMasksKeyInResponse(t *testing.T) {
  155. db := setupTokenControllerTestDB(t)
  156. token := seedToken(t, db, 1, "detail-token", "qrst1234uvwx5678")
  157. ctx, recorder := newAuthenticatedContext(t, http.MethodGet, "/api/token/"+strconv.Itoa(token.Id), nil, 1)
  158. ctx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
  159. GetToken(ctx)
  160. response := decodeAPIResponse(t, recorder)
  161. if !response.Success {
  162. t.Fatalf("expected success response, got message: %s", response.Message)
  163. }
  164. var detail tokenResponseItem
  165. if err := common.Unmarshal(response.Data, &detail); err != nil {
  166. t.Fatalf("failed to decode token detail response: %v", err)
  167. }
  168. if detail.Key != token.GetMaskedKey() {
  169. t.Fatalf("expected masked detail key %q, got %q", token.GetMaskedKey(), detail.Key)
  170. }
  171. if strings.Contains(recorder.Body.String(), token.Key) {
  172. t.Fatalf("detail response leaked raw token key: %s", recorder.Body.String())
  173. }
  174. }
  175. func TestUpdateTokenMasksKeyInResponse(t *testing.T) {
  176. db := setupTokenControllerTestDB(t)
  177. token := seedToken(t, db, 1, "editable-token", "yzab1234cdef5678")
  178. body := map[string]any{
  179. "id": token.Id,
  180. "name": "updated-token",
  181. "expired_time": -1,
  182. "remain_quota": 100,
  183. "unlimited_quota": true,
  184. "model_limits_enabled": false,
  185. "model_limits": "",
  186. "group": "default",
  187. "cross_group_retry": false,
  188. }
  189. ctx, recorder := newAuthenticatedContext(t, http.MethodPut, "/api/token/", body, 1)
  190. UpdateToken(ctx)
  191. response := decodeAPIResponse(t, recorder)
  192. if !response.Success {
  193. t.Fatalf("expected success response, got message: %s", response.Message)
  194. }
  195. var detail tokenResponseItem
  196. if err := common.Unmarshal(response.Data, &detail); err != nil {
  197. t.Fatalf("failed to decode token update response: %v", err)
  198. }
  199. if detail.Key != token.GetMaskedKey() {
  200. t.Fatalf("expected masked update key %q, got %q", token.GetMaskedKey(), detail.Key)
  201. }
  202. if strings.Contains(recorder.Body.String(), token.Key) {
  203. t.Fatalf("update response leaked raw token key: %s", recorder.Body.String())
  204. }
  205. }
  206. func TestGetTokenKeyRequiresOwnershipAndReturnsFullKey(t *testing.T) {
  207. db := setupTokenControllerTestDB(t)
  208. token := seedToken(t, db, 1, "owned-token", "owner1234token5678")
  209. authorizedCtx, authorizedRecorder := newAuthenticatedContext(t, http.MethodPost, "/api/token/"+strconv.Itoa(token.Id)+"/key", nil, 1)
  210. authorizedCtx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
  211. GetTokenKey(authorizedCtx)
  212. authorizedResponse := decodeAPIResponse(t, authorizedRecorder)
  213. if !authorizedResponse.Success {
  214. t.Fatalf("expected authorized key fetch to succeed, got message: %s", authorizedResponse.Message)
  215. }
  216. var keyData tokenKeyResponse
  217. if err := common.Unmarshal(authorizedResponse.Data, &keyData); err != nil {
  218. t.Fatalf("failed to decode token key response: %v", err)
  219. }
  220. if keyData.Key != token.GetFullKey() {
  221. t.Fatalf("expected full key %q, got %q", token.GetFullKey(), keyData.Key)
  222. }
  223. unauthorizedCtx, unauthorizedRecorder := newAuthenticatedContext(t, http.MethodPost, "/api/token/"+strconv.Itoa(token.Id)+"/key", nil, 2)
  224. unauthorizedCtx.Params = gin.Params{{Key: "id", Value: strconv.Itoa(token.Id)}}
  225. GetTokenKey(unauthorizedCtx)
  226. unauthorizedResponse := decodeAPIResponse(t, unauthorizedRecorder)
  227. if unauthorizedResponse.Success {
  228. t.Fatalf("expected unauthorized key fetch to fail")
  229. }
  230. if strings.Contains(unauthorizedRecorder.Body.String(), token.Key) {
  231. t.Fatalf("unauthorized key response leaked raw token key: %s", unauthorizedRecorder.Body.String())
  232. }
  233. }