Jelajahi Sumber

project init

project init 如readme.md所写
aiprodcoder 3 bulan lalu
induk
melakukan
02c1fe7ae3
100 mengubah file dengan 10676 tambahan dan 0 penghapusan
  1. 7 0
      .dockerignore
  2. 75 0
      .env.example
  3. 12 0
      .github/FUNDING.yml
  4. 26 0
      .github/ISSUE_TEMPLATE/bug_report.md
  5. 5 0
      .github/ISSUE_TEMPLATE/config.yml
  6. 21 0
      .github/ISSUE_TEMPLATE/feature_request.md
  7. 19 0
      .github/PULL_REQUEST_TEMPLATE/pull_request_template.md
  8. 62 0
      .github/workflows/docker-image-alpha.yml
  9. 56 0
      .github/workflows/docker-image-arm64.yml
  10. 59 0
      .github/workflows/linux-release.yml
  11. 51 0
      .github/workflows/macos-release.yml
  12. 21 0
      .github/workflows/pr-target-branch-check.yml
  13. 53 0
      .github/workflows/windows-release.yml
  14. 13 0
      .gitignore
  15. 35 0
      Dockerfile.backup
  16. 210 0
      README.md
  17. 146 0
      bin/README_migration.md
  18. 149 0
      bin/add_token_rate_limit_fields.sql
  19. 27 0
      bin/add_token_usage_count_fields.sql
  20. 11 0
      bin/add_token_usage_fields_mysql.sql
  21. 15 0
      bin/add_token_usage_fields_postgresql.sql
  22. 10 0
      bin/add_token_usage_fields_sqlite.sql
  23. 23 0
      bin/add_user_input_field.sql
  24. 90 0
      bin/create_token_usage_logs_table.sql
  25. 93 0
      bin/create_usage_statistics_table.sql
  26. 143 0
      bin/migration_token_usage_count.sql
  27. 6 0
      bin/migration_v0.2-v0.3.sql
  28. 17 0
      bin/migration_v0.3-v0.4.sql
  29. 51 0
      bin/test_usage_statistics_data.sql
  30. 40 0
      bin/time_test.sh
  31. 73 0
      common/api_type.go
  32. 201 0
      common/constants.go
  33. 31 0
      common/crypto.go
  34. 82 0
      common/custom-event.go
  35. 15 0
      common/database.go
  36. 40 0
      common/email-outlook-auth.go
  37. 90 0
      common/email.go
  38. 32 0
      common/embed-file-system.go
  39. 41 0
      common/endpoint_type.go
  40. 38 0
      common/env.go
  41. 111 0
      common/gin.go
  42. 53 0
      common/go-channel.go
  43. 24 0
      common/gopool.go
  44. 34 0
      common/hash.go
  45. 57 0
      common/http.go
  46. 120 0
      common/init.go
  47. 22 0
      common/json.go
  48. 89 0
      common/limiter/limiter.go
  49. 44 0
      common/limiter/lua/rate_limit.lua
  50. 123 0
      common/logger.go
  51. 42 0
      common/model.go
  52. 82 0
      common/page_info.go
  53. 44 0
      common/pprof.go
  54. 70 0
      common/rate-limit.go
  55. 327 0
      common/redis.go
  56. 97 0
      common/str.go
  57. 33 0
      common/topup-ratio.go
  58. 304 0
      common/utils.go
  59. 9 0
      common/validate.go
  60. 77 0
      common/verification.go
  61. 26 0
      constant/README.md
  62. 35 0
      constant/api_type.go
  63. 5 0
      constant/azure.go
  64. 14 0
      constant/cache_key.go
  65. 109 0
      constant/channel.go
  66. 44 0
      constant/context_key.go
  67. 16 0
      constant/endpoint_type.go
  68. 15 0
      constant/env.go
  69. 9 0
      constant/finish_reason.go
  70. 48 0
      constant/midjourney.go
  71. 8 0
      constant/multi_key_mode.go
  72. 3 0
      constant/setup.go
  73. 23 0
      constant/task.go
  74. 92 0
      controller/billing.go
  75. 492 0
      controller/channel-billing.go
  76. 465 0
      controller/channel-test.go
  77. 916 0
      controller/channel.go
  78. 103 0
      controller/console_migrate.go
  79. 239 0
      controller/github.go
  80. 50 0
      controller/group.go
  81. 9 0
      controller/image.go
  82. 259 0
      controller/linuxdo.go
  83. 168 0
      controller/log.go
  84. 263 0
      controller/midjourney.go
  85. 302 0
      controller/misc.go
  86. 216 0
      controller/model.go
  87. 228 0
      controller/oidc.go
  88. 171 0
      controller/option.go
  89. 84 0
      controller/playground.go
  90. 71 0
      controller/pricing.go
  91. 24 0
      controller/ratio_config.go
  92. 474 0
      controller/ratio_sync.go
  93. 193 0
      controller/redemption.go
  94. 478 0
      controller/relay.go
  95. 181 0
      controller/setup.go
  96. 116 0
      controller/swag_video.go
  97. 273 0
      controller/task.go
  98. 138 0
      controller/task_video.go
  99. 124 0
      controller/telegram.go
  100. 241 0
      controller/token.go

+ 7 - 0
.dockerignore

@@ -0,0 +1,7 @@
+.github
+.git
+*.md
+.vscode
+.gitignore
+Makefile
+docs

+ 75 - 0
.env.example

@@ -0,0 +1,75 @@
+# 端口号
+# PORT=3000
+# 前端基础URL
+# FRONTEND_BASE_URL=https://your-frontend-url.com
+
+
+# 调试相关配置
+# 启用pprof
+# ENABLE_PPROF=true
+# 启用调试模式
+# DEBUG=true
+
+# 数据库相关配置
+# 数据库连接字符串
+# SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
+# 日志数据库连接字符串
+# LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
+# SQLite数据库路径
+# SQLITE_PATH=/path/to/sqlite.db
+# 数据库最大空闲连接数
+# SQL_MAX_IDLE_CONNS=100
+# 数据库最大打开连接数
+# SQL_MAX_OPEN_CONNS=1000
+# 数据库连接最大生命周期(秒)
+# SQL_MAX_LIFETIME=60
+
+
+# 缓存相关配置
+# Redis连接字符串
+# REDIS_CONN_STRING=redis://user:password@localhost:6379/0
+# 同步频率(单位:秒)
+# SYNC_FREQUENCY=60
+# 内存缓存启用
+# MEMORY_CACHE_ENABLED=true
+# 渠道更新频率(单位:秒)
+# CHANNEL_UPDATE_FREQUENCY=30
+# 批量更新启用
+# BATCH_UPDATE_ENABLED=true
+# 批量更新间隔(单位:秒)
+# BATCH_UPDATE_INTERVAL=5
+
+# 任务和功能配置
+# 更新任务启用
+# UPDATE_TASK=true
+
+# 对话超时设置
+# 所有请求超时时间,单位秒,默认为0,表示不限制
+# RELAY_TIMEOUT=0
+# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
+# STREAMING_TIMEOUT=120
+
+# Gemini 识别图片 最大图片数量
+# GEMINI_VISION_MAX_IMAGE_NUM=16
+
+# 会话密钥
+# SESSION_SECRET=random_string
+
+# 其他配置
+# 渠道测试频率(单位:秒)
+# CHANNEL_TEST_FREQUENCY=10
+# 生成默认token
+# GENERATE_DEFAULT_TOKEN=false
+# Cohere 安全设置
+# COHERE_SAFETY_SETTING=NONE
+# 是否统计图片token
+# GET_MEDIA_TOKEN=true
+# 是否在非流(stream=false)情况下统计图片token
+# GET_MEDIA_TOKEN_NOT_STREAM=true
+# 设置 Dify 渠道是否输出工作流和节点信息到客户端
+# DIFY_DEBUG=true
+
+
+# 节点类型
+# 如果是主节点则为master
+# NODE_TYPE=master

+ 12 - 0
.github/FUNDING.yml

@@ -0,0 +1,12 @@
+# These are supported funding model platforms
+
+github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
+patreon: # Replace with a single Patreon username
+open_collective: # Replace with a single Open Collective username
+ko_fi: # Replace with a single Ko-fi username
+tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
+community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
+liberapay: # Replace with a single Liberapay username
+issuehunt: # Replace with a single IssueHunt username
+otechie: # Replace with a single Otechie username
+custom: ['https://afdian.com/a/new-api'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

+ 26 - 0
.github/ISSUE_TEMPLATE/bug_report.md

@@ -0,0 +1,26 @@
+---
+name: 报告问题
+about: 使用简练详细的语言描述你遇到的问题
+title: ''
+labels: bug
+assignees: ''
+
+---
+
+**例行检查**
+
+[//]: # (方框内删除已有的空格,填 x 号)
++ [ ] 我已确认目前没有类似 issue
++ [ ] 我已确认我已升级到最新版本
++ [ ] 我已完整查看过项目 README,尤其是常见问题部分
++ [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈 
++ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭**
+
+**问题描述**
+
+**复现步骤**
+
+**预期结果**
+
+**相关截图**
+如果没有的话,请删除此节。

+ 5 - 0
.github/ISSUE_TEMPLATE/config.yml

@@ -0,0 +1,5 @@
+blank_issues_enabled: false
+contact_links:
+  - name: 项目群聊
+    url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
+    about: QQ 群:629454374

+ 21 - 0
.github/ISSUE_TEMPLATE/feature_request.md

@@ -0,0 +1,21 @@
+---
+name: 功能请求
+about: 使用简练详细的语言描述希望加入的新功能
+title: ''
+labels: enhancement
+assignees: ''
+
+---
+
+**例行检查**
+
+[//]: # (方框内删除已有的空格,填 x 号)
++ [ ] 我已确认目前没有类似 issue
++ [ ] 我已确认我已升级到最新版本
++ [ ] 我已完整查看过项目 README,已确定现有版本无法满足需求
++ [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈
++ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭**
+
+**功能描述**
+
+**应用场景**

+ 19 - 0
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md

@@ -0,0 +1,19 @@
+### PR 类型
+
+- [ ] Bug 修复
+- [ ] 新功能
+- [ ] 文档更新
+- [ ] 其他
+
+### PR 是否包含破坏性更新?
+
+- [ ] 是
+- [ ] 否
+
+### PR 描述
+
+**请在下方详细描述您的 PR,包括目的、实现细节等。**
+
+### **重要提示**
+
+**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。**

+ 62 - 0
.github/workflows/docker-image-alpha.yml

@@ -0,0 +1,62 @@
+name: Publish Docker image (alpha)
+
+on:
+  push:
+    branches:
+      - alpha
+  workflow_dispatch:
+    inputs:
+      name:
+        description: "reason"
+        required: false
+
+jobs:
+  push_to_registries:
+    name: Push Docker image to multiple registries
+    runs-on: ubuntu-latest
+    permissions:
+      packages: write
+      contents: read
+    steps:
+      - name: Check out the repo
+        uses: actions/checkout@v4
+
+      - name: Save version info
+        run: |
+          echo "alpha-$(date +'%Y%m%d')-$(git rev-parse --short HEAD)" > VERSION
+
+      - name: Log in to Docker Hub
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+      - name: Log in to the Container registry
+        uses: docker/login-action@v3
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Extract metadata (tags, labels) for Docker
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: |
+            calciumion/new-api
+            ghcr.io/${{ github.repository }}
+          tags: |
+            type=raw,value=alpha
+            type=raw,value=alpha-{{date 'YYYYMMDD'}}-{{sha}}
+
+      - name: Build and push Docker images
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          platforms: linux/amd64,linux/arm64
+          push: true
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}

+ 56 - 0
.github/workflows/docker-image-arm64.yml

@@ -0,0 +1,56 @@
+name: Publish Docker image (Multi Registries)
+
+on:
+  push:
+    tags:
+      - '*'
+jobs:
+  push_to_registries:
+    name: Push Docker image to multiple registries
+    runs-on: ubuntu-latest
+    permissions:
+      packages: write
+      contents: read
+    steps:
+      - name: Check out the repo
+        uses: actions/checkout@v4
+
+      - name: Save version info
+        run: |
+          git describe --tags > VERSION 
+
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v3
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Log in to Docker Hub
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+      - name: Log in to the Container registry
+        uses: docker/login-action@v3
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Extract metadata (tags, labels) for Docker
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: |
+            calciumion/new-api
+            ghcr.io/${{ github.repository }}
+
+      - name: Build and push Docker images
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          platforms: linux/amd64,linux/arm64
+          push: true
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}

+ 59 - 0
.github/workflows/linux-release.yml

@@ -0,0 +1,59 @@
+name: Linux Release
+permissions:
+  contents: write
+
+on:
+  workflow_dispatch:
+    inputs:
+      name:
+        description: 'reason'
+        required: false
+  push:
+    tags:
+      - '*'
+      - '!*-alpha*'
+jobs:
+  release:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+      - uses: oven-sh/setup-bun@v2
+        with:
+          bun-version: latest
+      - name: Build Frontend
+        env:
+          CI: ""
+        run: |
+          cd web
+          bun install
+          DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
+          cd ..
+      - name: Set up Go
+        uses: actions/setup-go@v3
+        with:
+          go-version: '>=1.18.0'
+      - name: Build Backend (amd64)
+        run: |
+          go mod download
+          go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
+
+      - name: Build Backend (arm64)
+        run: |
+          sudo apt-get update
+          DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
+          CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64
+
+      - name: Release
+        uses: softprops/action-gh-release@v1
+        if: startsWith(github.ref, 'refs/tags/')
+        with:
+          files: |
+            one-api
+            one-api-arm64
+          draft: true
+          generate_release_notes: true
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

+ 51 - 0
.github/workflows/macos-release.yml

@@ -0,0 +1,51 @@
+name: macOS Release
+permissions:
+  contents: write
+
+on:
+  workflow_dispatch:
+    inputs:
+      name:
+        description: 'reason'
+        required: false
+  push:
+    tags:
+      - '*'
+      - '!*-alpha*'
+jobs:
+  release:
+    runs-on: macos-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+      - uses: oven-sh/setup-bun@v2
+        with:
+          bun-version: latest
+      - name: Build Frontend
+        env:
+          CI: ""
+          NODE_OPTIONS: "--max-old-space-size=4096"
+        run: |
+          cd web
+          bun install
+          DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
+          cd ..
+      - name: Set up Go
+        uses: actions/setup-go@v3
+        with:
+          go-version: '>=1.18.0'
+      - name: Build Backend
+        run: |
+          go mod download
+          go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos
+      - name: Release
+        uses: softprops/action-gh-release@v1
+        if: startsWith(github.ref, 'refs/tags/')
+        with:
+          files: one-api-macos
+          draft: true
+          generate_release_notes: true
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

+ 21 - 0
.github/workflows/pr-target-branch-check.yml

@@ -0,0 +1,21 @@
+name: Check PR Branching Strategy
+on:
+  pull_request:
+    types: [opened, synchronize, reopened, edited]
+
+jobs:
+  check-branching-strategy:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Enforce branching strategy
+        run: |
+          if [[ "${{ github.base_ref }}" == "main" ]]; then
+            if [[ "${{ github.head_ref }}" != "alpha" ]]; then
+              echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
+              exit 1
+            fi
+          elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
+            echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
+            exit 1
+          fi
+          echo "Branching strategy check passed."

+ 53 - 0
.github/workflows/windows-release.yml

@@ -0,0 +1,53 @@
+name: Windows Release
+permissions:
+  contents: write
+
+on:
+  workflow_dispatch:
+    inputs:
+      name:
+        description: 'reason'
+        required: false
+  push:
+    tags:
+      - '*'
+      - '!*-alpha*'
+jobs:
+  release:
+    runs-on: windows-latest
+    defaults:
+      run:
+        shell: bash
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+      - uses: oven-sh/setup-bun@v2
+        with:
+          bun-version: latest
+      - name: Build Frontend
+        env:
+          CI: ""
+        run: |
+          cd web
+          bun install
+          DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
+          cd ..
+      - name: Set up Go
+        uses: actions/setup-go@v3
+        with:
+          go-version: '>=1.18.0'
+      - name: Build Backend
+        run: |
+          go mod download
+          go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe
+      - name: Release
+        uses: softprops/action-gh-release@v1
+        if: startsWith(github.ref, 'refs/tags/')
+        with:
+          files: one-api.exe
+          draft: true
+          generate_release_notes: true
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

+ 13 - 0
.gitignore

@@ -0,0 +1,13 @@
+.idea
+.vscode
+upload
+*.exe
+*.db
+build
+*.db-journal
+logs
+web/dist
+.env
+one-api
+.DS_Store
+tiktoken_cache

+ 35 - 0
Dockerfile.backup

@@ -0,0 +1,35 @@
+FROM oven/bun:latest AS builder
+
+WORKDIR /build
+COPY web/package.json .
+COPY web/bun.lock .
+RUN bun install
+COPY ./web .
+COPY ./VERSION .
+RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
+
+FROM golang:alpine AS builder2
+
+ENV GO111MODULE=on \
+    CGO_ENABLED=0 \
+    GOOS=linux
+
+WORKDIR /build
+
+ADD go.mod go.sum ./
+RUN go mod download
+
+COPY . .
+COPY --from=builder /build/dist ./web/dist
+RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-api
+
+FROM alpine
+
+RUN apk upgrade --no-cache \
+    && apk add --no-cache ca-certificates tzdata ffmpeg \
+    && update-ca-certificates
+
+COPY --from=builder2 /build/one-api /
+EXPOSE 3000
+WORKDIR /data
+ENTRYPOINT ["/one-api"]

+ 210 - 0
README.md

@@ -0,0 +1,210 @@
+<p align="right">
+   <strong>中文</strong> | <a href="./README.en.md">English</a>
+</p>
+<div align="center">
+
+![mixapi](/web/public/logo.png)
+
+# MIXAPI
+
+🍥新一代大模型网关与AI资产管理系统
+
+<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
+
+<p align="center">
+  <a href="https://raw.githubusercontent.com/Calcium-Ion/new-api/main/LICENSE">
+    <img src="https://img.shields.io/github/license/Calcium-Ion/new-api?color=brightgreen" alt="license">
+  </a>
+  <a href="https://github.com/Calcium-Ion/new-api/releases/latest">
+    <img src="https://img.shields.io/github/v/release/Calcium-Ion/new-api?color=brightgreen&include_prereleases" alt="release">
+  </a>
+  <a href="https://github.com/users/Calcium-Ion/packages/container/package/new-api">
+    <img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker">
+  </a>
+  <a href="https://hub.docker.com/r/CalciumIon/new-api">
+    <img src="https://img.shields.io/badge/docker-dockerHub-blue" alt="docker">
+  </a>
+  <a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
+    <img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
+  </a>
+</p>
+</div>
+
+## 📝 项目说明
+
+> [!NOTE]  
+> 本项目为开源项目,由[New API](https://github.com/Calcium-Ion/new-api)二次开发而来
+
+> [!IMPORTANT]  
+> - 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。
+> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
+> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
+
+<h2>🤝 我们信任的合作伙伴</h2>
+<p id="premium-sponsors">&nbsp;</p>
+<p align="center"><strong>排名不分先后</strong></p>
+<p align="center">
+  <a href="https://www.cherry-ai.com/" target=_blank><img
+    src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="120"
+  /></a>
+  <a href="https://bda.pku.edu.cn/" target=_blank><img
+    src="./docs/images/pku.png" alt="北京大学" height="120"
+  /></a>
+  <a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target=_blank><img
+    src="./docs/images/ucloud.png" alt="UCloud 优刻得" height="120"
+  /></a>
+  <a href="https://www.aliyun.com/" target=_blank><img
+    src="./docs/images/aliyun.png" alt="阿里云" height="120"
+  /></a>
+  <a href="https://io.net/" target=_blank><img
+    src="./docs/images/io-net.png" alt="IO.NET" height="120"
+  /></a>
+</p>
+<p>&nbsp;</p>
+
+## 📚 文档
+
+详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
+
+也可访问AI生成的DeepWiki:
+[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/QuantumNous/new-api)
+
+## ✨ 主要特性
+
+MIXAPI提供了丰富的功能,详细特性请参考[特性说明](https://docs.newapi.pro/wiki/features-introduction):
+
+1. 🎨 全新的UI界面
+2. 🌍 多语言支持
+3. 💰 支持在线充值功能(易支付)
+4. 🔍 支持用key查询使用额度(配合[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
+5. 🔄 兼容原版One API的数据库
+6. 💵 支持模型按次数收费
+7. ⚖️ 支持渠道加权随机
+8. 📈 数据看板(控制台)
+9. 🔒 令牌分组、模型限制
+10. 🤖 支持更多授权登陆方式(LinuxDO,Telegram、OIDC)
+11. 🔄 支持Rerank模型(Cohere和Jina),[接口文档](https://docs.newapi.pro/api/jinaai-rerank)
+12. ⚡ 支持OpenAI Realtime API(包括Azure渠道),[接口文档](https://docs.newapi.pro/api/openai-realtime)
+13. ⚡ 支持Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
+14. 支持使用路由/chat2link进入聊天界面
+15. 🧠 支持通过模型名称后缀设置 reasoning effort:
+    1. OpenAI o系列模型
+        - 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
+        - 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
+        - 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
+    2. Claude 思考模型
+        - 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
+16. 🔄 思考转内容功能
+17. 🔄 针对用户的模型限流功能
+18. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
+    1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项
+    2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费
+    3. 支持的渠道:
+        - [x] OpenAI
+        - [x] Azure
+        - [x] DeepSeek
+        - [x] Claude
+19. 🔄 新增对token令牌的控制,可控制分钟请求次数限制和日请求次数限制
+20. 📊 新增用量日统计
+21. 📊 新增用量月统计
+22. 📋 新增令牌管理显示该令牌的今日次数和总次数
+23. 📝 新增通过令牌请求的内容记录显示
+24. 📝 支持通过令牌查询余额
+
+## 模型支持
+
+此版本支持多种模型,详情请参考[接口文档-中继接口](https://docs.newapi.pro/api):
+
+1. 第三方模型 **gpts** (gpt-4-gizmo-*)
+2. 第三方渠道[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[接口文档](https://docs.newapi.pro/api/midjourney-proxy-image)
+3. 第三方渠道[Suno API](https://github.com/Suno-API/Suno-API)接口,[接口文档](https://docs.newapi.pro/api/suno-music)
+4. 自定义渠道,支持填入完整调用地址
+5. Rerank模型([Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)),[接口文档](https://docs.newapi.pro/api/jinaai-rerank)
+6. Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
+7. Dify,当前仅支持chatflow
+
+## 环境变量配置
+
+详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
+
+- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
+- `STREAMING_TIMEOUT`:流式回复超时时间,默认300秒
+- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
+- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
+- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
+- `GET_MEDIA_TOKEN_NOT_STREAM`:非流情况下是否统计图片token,默认 `true`
+- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认 `true`
+- `COHERE_SAFETY_SETTING`:Cohere模型安全设置,可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认 `NONE`
+- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认 `16`
+- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位MB,默认 `20`
+- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容
+- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2025-04-01-preview`
+- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
+- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
+- `ERROR_LOG_ENABLED=true`: 是否记录并显示错误日志,默认`false`
+
+## 部署
+
+详细部署指南请参考[安装指南-部署方式](https://docs.newapi.pro/installation):
+
+ 
+
+### 多机部署注意事项
+- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致
+- 如果公用Redis,必须设置 `CRYPTO_SECRET`,否则会导致多机部署时Redis内容无法获取
+
+### 部署要求
+- 本地数据库(默认):SQLite(Docker部署必须挂载`/data`目录)
+- 远程数据库:MySQL版本 >= 5.7.8,PgSQL版本 >= 9.6
+
+### 部署方式
+ 
+#### 本地运行方式
+```shell
+go run main.go
+```
+#### 构造并使用Docker镜像
+```shell
+# 使用SQLite
+docker run --name mixapi -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/mixapi:/data 打包好的镜像名称:latest
+
+# 使用MySQL
+docker run --name mixapi -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/mixapi:/data 打包好的镜像名称:latest
+```
+
+## 渠道重试与缓存
+渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
+
+### 缓存设置方法
+1. `REDIS_CONN_STRING`:设置Redis作为缓存
+2. `MEMORY_CACHE_ENABLED`:启用内存缓存(设置了Redis则无需手动设置)
+
+## 接口文档
+
+详细接口文档请参考[接口文档](https://docs.newapi.pro/api):
+
+- [聊天接口(Chat)](https://docs.newapi.pro/api/openai-chat)
+- [图像接口(Image)](https://docs.newapi.pro/api/openai-image)
+- [重排序接口(Rerank)](https://docs.newapi.pro/api/jinaai-rerank)
+- [实时对话接口(Realtime)](https://docs.newapi.pro/api/openai-realtime)
+- [Claude聊天接口(messages)](https://docs.newapi.pro/api/anthropic-chat)
+
+## 相关项目
+- [One API](https://github.com/songquanpeng/one-api):原版项目
+- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy):Midjourney接口支持
+- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代AI一站式B/C端解决方案
+- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool):用key查询使用额度
+
+其他基于New API的项目:
+- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
+
+## 帮助支持
+
+如有问题,请参考[帮助支持](https://docs.newapi.pro/support):
+- [社区交流](https://docs.newapi.pro/support/community-interaction)
+- [反馈问题](https://docs.newapi.pro/support/feedback-issues)
+- [常见问题](https://docs.newapi.pro/support/faq)
+
+## 🌟 Star History
+
+[![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date)

+ 146 - 0
bin/README_migration.md

@@ -0,0 +1,146 @@
+# Token使用次数统计字段 - SQL迁移脚本使用说明
+
+## 文件列表
+
+本目录包含了以下SQL迁移脚本文件:
+
+### 1. 完整迁移脚本
+- **`migration_token_usage_count.sql`** - 完整的迁移脚本,包含所有数据库类型的语法和回滚操作
+
+### 2. 数据库专用脚本
+- **`add_token_usage_fields_mysql.sql`** - MySQL专用简化脚本
+- **`add_token_usage_fields_sqlite.sql`** - SQLite专用简化脚本  
+- **`add_token_usage_fields_postgresql.sql`** - PostgreSQL专用简化脚本
+
+### 3. 通用脚本
+- **`add_token_usage_count_fields.sql`** - 包含所有数据库语法的通用脚本
+
+## 新增字段说明
+
+| 字段名 | 数据类型 | 默认值 | 说明 |
+|--------|----------|--------|------|
+| `daily_usage_count` | INT/INTEGER | 0 | 今日使用次数,每日自动重置 |
+| `total_usage_count` | INT/INTEGER | 0 | 总使用次数,只增不减 |
+| `last_usage_date` | VARCHAR(10)/TEXT | '' | 最后使用日期(YYYY-MM-DD格式) |
+
+## 使用方法
+
+### 自动迁移(推荐)
+如果您的项目使用了GORM自动迁移,则无需手动执行SQL脚本:
+```bash
+# 启动程序时会自动迁移
+go run main.go
+```
+
+### 手动迁移
+如果需要手动执行SQL迁移,请根据您的数据库类型选择对应脚本:
+
+#### MySQL数据库
+```bash
+mysql -u root -p your_database < add_token_usage_fields_mysql.sql
+```
+
+#### SQLite数据库  
+```bash
+sqlite3 your_database.db < add_token_usage_fields_sqlite.sql
+```
+
+#### PostgreSQL数据库
+```bash
+psql -U username -d your_database -f add_token_usage_fields_postgresql.sql
+```
+
+## 迁移前注意事项
+
+1. **备份数据库**: 执行迁移前请务必备份您的数据库
+2. **检查权限**: 确保数据库用户有ALTER TABLE权限
+3. **停止服务**: 建议在迁移期间停止API服务
+4. **测试环境**: 建议先在测试环境验证迁移脚本
+
+## 验证迁移结果
+
+### MySQL验证
+```sql
+DESCRIBE tokens;
+-- 或
+SHOW COLUMNS FROM tokens LIKE '%usage%';
+```
+
+### SQLite验证  
+```sql
+PRAGMA table_info(tokens);
+```
+
+### PostgreSQL验证
+```sql
+\d tokens
+-- 或
+SELECT column_name, data_type, is_nullable, column_default 
+FROM information_schema.columns 
+WHERE table_name = 'tokens' 
+  AND column_name LIKE '%usage%';
+```
+
+## 回滚操作
+
+如果需要回滚迁移,请参考 `migration_token_usage_count.sql` 文件中的回滚部分,或执行以下SQL:
+
+### MySQL/PostgreSQL回滚
+```sql
+ALTER TABLE tokens DROP COLUMN daily_usage_count;
+ALTER TABLE tokens DROP COLUMN total_usage_count;  
+ALTER TABLE tokens DROP COLUMN last_usage_date;
+```
+
+### SQLite回滚
+SQLite不支持DROP COLUMN,需要重建表:
+```sql
+-- 1. 创建临时表包含原有字段
+CREATE TABLE tokens_temp AS 
+SELECT id, user_id, key, status, name, created_time, accessed_time, 
+       expired_time, remain_quota, unlimited_quota, model_limits_enabled, 
+       model_limits, allow_ips, used_quota, `group`, deleted_at 
+FROM tokens;
+
+-- 2. 删除原表
+DROP TABLE tokens;
+
+-- 3. 重命名临时表
+ALTER TABLE tokens_temp RENAME TO tokens;
+
+-- 4. 重建索引和约束 (根据实际情况调整)
+```
+
+## 故障排除
+
+### 常见错误
+
+1. **字段已存在错误**
+   - 错误信息: `Duplicate column name` 或 `column already exists`
+   - 解决方案: 字段已存在,无需重复添加
+
+2. **权限不足错误**
+   - 错误信息: `Access denied` 或 `permission denied`
+   - 解决方案: 检查数据库用户权限
+
+3. **表不存在错误**
+   - 错误信息: `Table doesn't exist`
+   - 解决方案: 确认表名正确,检查数据库选择
+
+### 检查脚本执行状态
+```sql
+-- 检查字段是否添加成功
+SELECT COUNT(*) as field_count
+FROM information_schema.columns 
+WHERE table_name = 'tokens' 
+  AND column_name IN ('daily_usage_count', 'total_usage_count', 'last_usage_date');
+-- 应该返回 3
+```
+
+## 技术支持
+
+如果在迁移过程中遇到问题:
+1. 查看数据库错误日志
+2. 确认SQL语法与数据库版本兼容
+3. 验证数据库连接和权限
+4. 参考项目文档或提交issue

+ 149 - 0
bin/add_token_rate_limit_fields.sql

@@ -0,0 +1,149 @@
+-- ====================================================================
+-- Token访问频率限制功能数据库迁移脚本
+-- ====================================================================
+-- 版本: v1.0
+-- 创建日期: 2025-08-25
+-- 描述: 为tokens表添加访问频率限制功能的相关字段
+-- 
+-- 新增字段:
+-- - rate_limit_per_minute: 每分钟访问次数限制,0表示不限制
+-- - rate_limit_per_day: 每日访问次数限制,0表示不限制
+-- - last_rate_limit_reset: 最后重置时间,用于重置计数器
+-- ====================================================================
+
+-- ====================================================================
+-- 向前迁移 (UP) - 添加字段
+-- ====================================================================
+
+-- MySQL 数据库
+-- 检查字段是否已存在,避免重复添加
+SET @sql_rate_minute = 'ALTER TABLE tokens ADD COLUMN rate_limit_per_minute INT NOT NULL DEFAULT 0 COMMENT ''每分钟访问次数限制,0表示不限制''';
+SET @sql_rate_day = 'ALTER TABLE tokens ADD COLUMN rate_limit_per_day INT NOT NULL DEFAULT 0 COMMENT ''每日访问次数限制,0表示不限制''';
+SET @sql_reset_time = 'ALTER TABLE tokens ADD COLUMN last_rate_limit_reset BIGINT NOT NULL DEFAULT 0 COMMENT ''最后重置时间戳''';
+
+-- 执行添加字段的SQL (仅在字段不存在时执行)
+PREPARE stmt FROM @sql_rate_minute;
+EXECUTE stmt;
+DEALLOCATE PREPARE stmt;
+
+PREPARE stmt FROM @sql_rate_day;
+EXECUTE stmt;
+DEALLOCATE PREPARE stmt;
+
+PREPARE stmt FROM @sql_reset_time;
+EXECUTE stmt;
+DEALLOCATE PREPARE stmt;
+
+-- 为新字段添加索引 (可选,提升查询性能)
+-- CREATE INDEX idx_tokens_rate_limit_reset ON tokens(last_rate_limit_reset);
+
+-- ====================================================================
+-- PostgreSQL 数据库 (如果使用PostgreSQL,取消下面注释)
+-- ====================================================================
+/*
+-- 添加字段
+ALTER TABLE tokens ADD COLUMN IF NOT EXISTS rate_limit_per_minute INTEGER NOT NULL DEFAULT 0;
+ALTER TABLE tokens ADD COLUMN IF NOT EXISTS rate_limit_per_day INTEGER NOT NULL DEFAULT 0;
+ALTER TABLE tokens ADD COLUMN IF NOT EXISTS last_rate_limit_reset BIGINT NOT NULL DEFAULT 0;
+
+-- 添加注释
+COMMENT ON COLUMN tokens.rate_limit_per_minute IS '每分钟访问次数限制,0表示不限制';
+COMMENT ON COLUMN tokens.rate_limit_per_day IS '每日访问次数限制,0表示不限制';
+COMMENT ON COLUMN tokens.last_rate_limit_reset IS '最后重置时间戳';
+
+-- 添加索引 (可选)
+-- CREATE INDEX IF NOT EXISTS idx_tokens_rate_limit_reset ON tokens(last_rate_limit_reset);
+*/
+
+-- ====================================================================
+-- SQLite 数据库 (如果使用SQLite,取消下面注释)
+-- ====================================================================
+/*
+-- SQLite不支持IF NOT EXISTS语法,需要手动检查
+-- 添加字段
+ALTER TABLE tokens ADD COLUMN rate_limit_per_minute INTEGER NOT NULL DEFAULT 0;
+ALTER TABLE tokens ADD COLUMN rate_limit_per_day INTEGER NOT NULL DEFAULT 0;
+ALTER TABLE tokens ADD COLUMN last_rate_limit_reset INTEGER NOT NULL DEFAULT 0;
+
+-- 添加索引 (可选)
+-- CREATE INDEX IF NOT EXISTS idx_tokens_rate_limit_reset ON tokens(last_rate_limit_reset);
+*/
+
+-- ====================================================================
+-- 向后迁移 (DOWN) - 删除字段 (回滚操作)
+-- ====================================================================
+/*
+-- 如果需要回滚,请执行以下SQL语句
+
+-- MySQL 回滚
+ALTER TABLE tokens DROP COLUMN IF EXISTS rate_limit_per_minute;
+ALTER TABLE tokens DROP COLUMN IF EXISTS rate_limit_per_day;
+ALTER TABLE tokens DROP COLUMN IF EXISTS last_rate_limit_reset;
+
+-- PostgreSQL 回滚
+-- ALTER TABLE tokens DROP COLUMN IF EXISTS rate_limit_per_minute;
+-- ALTER TABLE tokens DROP COLUMN IF EXISTS rate_limit_per_day;
+-- ALTER TABLE tokens DROP COLUMN IF EXISTS last_rate_limit_reset;
+
+-- SQLite 回滚
+-- SQLite不支持DROP COLUMN,需要重建表:
+-- 1. 创建临时表包含原有字段
+-- CREATE TABLE tokens_temp AS 
+-- SELECT id, user_id, key, status, name, created_time, accessed_time, 
+--        expired_time, remain_quota, unlimited_quota, model_limits_enabled, 
+--        model_limits, allow_ips, used_quota, `group`, deleted_at,
+--        daily_usage_count, total_usage_count, last_usage_date
+-- FROM tokens;
+
+-- 2. 删除原表
+-- DROP TABLE tokens;
+
+-- 3. 重命名临时表
+-- ALTER TABLE tokens_temp RENAME TO tokens;
+
+-- 4. 重建索引和约束 (根据实际情况调整)
+*/
+
+-- ====================================================================
+-- 验证脚本 - 检查字段是否成功添加
+-- ====================================================================
+/*
+-- 检查字段是否存在 (MySQL)
+SELECT 
+    COLUMN_NAME,
+    DATA_TYPE,
+    IS_NULLABLE,
+    COLUMN_DEFAULT,
+    COLUMN_COMMENT
+FROM INFORMATION_SCHEMA.COLUMNS 
+WHERE TABLE_SCHEMA = DATABASE() 
+  AND TABLE_NAME = 'tokens' 
+  AND COLUMN_NAME IN ('rate_limit_per_minute', 'rate_limit_per_day', 'last_rate_limit_reset');
+
+-- 检查字段是否存在 (PostgreSQL)
+-- SELECT column_name, data_type, is_nullable, column_default 
+-- FROM information_schema.columns 
+-- WHERE table_name = 'tokens' 
+--   AND column_name IN ('rate_limit_per_minute', 'rate_limit_per_day', 'last_rate_limit_reset');
+
+-- 检查字段是否存在 (SQLite)
+-- PRAGMA table_info(tokens);
+*/
+
+-- ====================================================================
+-- 使用说明
+-- ====================================================================
+/*
+1. 根据您的数据库类型,选择对应的SQL语句执行
+2. 默认启用MySQL语法,如使用其他数据库请取消相应注释
+3. 建议在执行前备份数据库
+4. 执行后可运行验证脚本确认字段添加成功
+5. 如需回滚,请执行"向后迁移"部分的SQL语句
+
+新字段说明:
+- rate_limit_per_minute: 每分钟访问次数限制,0表示不限制
+- rate_limit_per_day: 每日访问次数限制,0表示不限制
+- last_rate_limit_reset: 最后重置时间戳,用于计算是否需要重置计数器
+
+这些字段会在Token验证时被检查,实现访问频率限制功能。
+*/

+ 27 - 0
bin/add_token_usage_count_fields.sql

@@ -0,0 +1,27 @@
+-- Token使用次数统计字段迁移脚本
+-- 添加日期: 2025-08-25
+-- 功能: 为tokens表添加今日次数、总次数和最后使用日期字段
+
+-- MySQL 语法
+-- 添加今日使用次数字段
+ALTER TABLE tokens ADD COLUMN daily_usage_count INT NOT NULL DEFAULT 0 COMMENT '今日使用次数';
+
+-- 添加总使用次数字段  
+ALTER TABLE tokens ADD COLUMN total_usage_count INT NOT NULL DEFAULT 0 COMMENT '总使用次数';
+
+-- 添加最后使用日期字段
+ALTER TABLE tokens ADD COLUMN last_usage_date VARCHAR(10) NOT NULL DEFAULT '' COMMENT '最后使用日期(YYYY-MM-DD)';
+
+-- PostgreSQL 语法 (如果使用PostgreSQL数据库)
+-- ALTER TABLE tokens ADD COLUMN daily_usage_count INTEGER NOT NULL DEFAULT 0;
+-- ALTER TABLE tokens ADD COLUMN total_usage_count INTEGER NOT NULL DEFAULT 0; 
+-- ALTER TABLE tokens ADD COLUMN last_usage_date VARCHAR(10) NOT NULL DEFAULT '';
+
+-- COMMENT ON COLUMN tokens.daily_usage_count IS '今日使用次数';
+-- COMMENT ON COLUMN tokens.total_usage_count IS '总使用次数';
+-- COMMENT ON COLUMN tokens.last_usage_date IS '最后使用日期(YYYY-MM-DD)';
+
+-- SQLite 语法 (如果使用SQLite数据库)
+-- ALTER TABLE tokens ADD COLUMN daily_usage_count INTEGER NOT NULL DEFAULT 0;
+-- ALTER TABLE tokens ADD COLUMN total_usage_count INTEGER NOT NULL DEFAULT 0;
+-- ALTER TABLE tokens ADD COLUMN last_usage_date TEXT NOT NULL DEFAULT '';

+ 11 - 0
bin/add_token_usage_fields_mysql.sql

@@ -0,0 +1,11 @@
+-- 简化版Token使用次数统计字段迁移脚本
+-- 适用于MySQL数据库,可直接执行
+
+-- 添加今日使用次数字段
+ALTER TABLE tokens ADD COLUMN daily_usage_count INT NOT NULL DEFAULT 0 COMMENT '今日使用次数';
+
+-- 添加总使用次数字段  
+ALTER TABLE tokens ADD COLUMN total_usage_count INT NOT NULL DEFAULT 0 COMMENT '总使用次数';
+
+-- 添加最后使用日期字段
+ALTER TABLE tokens ADD COLUMN last_usage_date VARCHAR(10) NOT NULL DEFAULT '' COMMENT '最后使用日期(YYYY-MM-DD)';

+ 15 - 0
bin/add_token_usage_fields_postgresql.sql

@@ -0,0 +1,15 @@
+-- PostgreSQL数据库Token使用次数统计字段迁移脚本
+
+-- 添加今日使用次数字段
+ALTER TABLE tokens ADD COLUMN IF NOT EXISTS daily_usage_count INTEGER NOT NULL DEFAULT 0;
+
+-- 添加总使用次数字段
+ALTER TABLE tokens ADD COLUMN IF NOT EXISTS total_usage_count INTEGER NOT NULL DEFAULT 0;
+
+-- 添加最后使用日期字段
+ALTER TABLE tokens ADD COLUMN IF NOT EXISTS last_usage_date VARCHAR(10) NOT NULL DEFAULT '';
+
+-- 添加字段注释
+COMMENT ON COLUMN tokens.daily_usage_count IS '今日使用次数';
+COMMENT ON COLUMN tokens.total_usage_count IS '总使用次数';
+COMMENT ON COLUMN tokens.last_usage_date IS '最后使用日期(YYYY-MM-DD)';

+ 10 - 0
bin/add_token_usage_fields_sqlite.sql

@@ -0,0 +1,10 @@
+-- SQLite数据库Token使用次数统计字段迁移脚本
+
+-- 添加今日使用次数字段
+ALTER TABLE tokens ADD COLUMN daily_usage_count INTEGER NOT NULL DEFAULT 0;
+
+-- 添加总使用次数字段
+ALTER TABLE tokens ADD COLUMN total_usage_count INTEGER NOT NULL DEFAULT 0;
+
+-- 添加最后使用日期字段
+ALTER TABLE tokens ADD COLUMN last_usage_date TEXT NOT NULL DEFAULT '';

+ 23 - 0
bin/add_user_input_field.sql

@@ -0,0 +1,23 @@
+-- 为logs表添加user_input字段来记录用户输入内容
+-- 适用于 MySQL、SQLite、PostgreSQL 数据库
+
+-- MySQL 语法
+-- ALTER TABLE logs ADD COLUMN user_input TEXT COMMENT '用户输入内容';
+
+-- SQLite 语法
+-- ALTER TABLE logs ADD COLUMN user_input TEXT;
+
+-- PostgreSQL 语法
+-- ALTER TABLE logs ADD COLUMN user_input TEXT;
+
+-- 通用语法(兼容多数据库)
+ALTER TABLE logs ADD COLUMN user_input TEXT;
+
+-- 为新字段添加索引(可选,如果需要搜索用户输入内容的话)
+-- CREATE INDEX idx_logs_user_input ON logs(user_input(100));
+
+-- 说明:
+-- 1. user_input 字段用于存储用户通过API发送的实际输入内容(如messages内容)
+-- 2. 原有的 content 字段继续用于存储计费相关信息
+-- 3. 该字段主要记录消费类型(type=2)和错误类型(type=5)日志的用户输入
+-- 4. 字段类型使用TEXT以支持长文本内容

+ 90 - 0
bin/create_token_usage_logs_table.sql

@@ -0,0 +1,90 @@
+-- ====================================================================
+-- Token使用日志表迁移脚本 (用于访问频率限制)
+-- ====================================================================
+-- 版本: v1.0
+-- 创建日期: 2025-08-25
+-- 描述: 创建token_usage_logs表,用于记录令牌使用情况以支持访问频率限制功能
+-- ====================================================================
+
+-- ====================================================================
+-- MySQL 语法
+-- ====================================================================
+CREATE TABLE IF NOT EXISTS token_usage_logs (
+    id INT AUTO_INCREMENT PRIMARY KEY,
+    token_id INT NOT NULL COMMENT '令牌ID',
+    created_at BIGINT NOT NULL COMMENT '创建时间戳',
+    INDEX idx_token_created (token_id, created_at),
+    INDEX idx_created (created_at)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='Token使用日志表(用于访问频率限制)';
+
+-- ====================================================================
+-- PostgreSQL 语法 (如果使用PostgreSQL,取消下面注释)
+-- ====================================================================
+/*
+CREATE TABLE IF NOT EXISTS token_usage_logs (
+    id SERIAL PRIMARY KEY,
+    token_id INTEGER NOT NULL,
+    created_at BIGINT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS idx_token_usage_logs_token_created ON token_usage_logs(token_id, created_at);
+CREATE INDEX IF NOT EXISTS idx_token_usage_logs_created ON token_usage_logs(created_at);
+
+COMMENT ON TABLE token_usage_logs IS 'Token使用日志表(用于访问频率限制)';
+COMMENT ON COLUMN token_usage_logs.token_id IS '令牌ID';
+COMMENT ON COLUMN token_usage_logs.created_at IS '创建时间戳';
+*/
+
+-- ====================================================================
+-- SQLite 语法 (如果使用SQLite,取消下面注释)
+-- ====================================================================
+/*
+CREATE TABLE IF NOT EXISTS token_usage_logs (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    token_id INTEGER NOT NULL,
+    created_at INTEGER NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS idx_token_usage_logs_token_created ON token_usage_logs(token_id, created_at);
+CREATE INDEX IF NOT EXISTS idx_token_usage_logs_created ON token_usage_logs(created_at);
+*/
+
+-- ====================================================================
+-- 清理旧数据的脚本 (可选,定期执行以减少存储空间)
+-- ====================================================================
+/*
+-- 删除30天前的记录 (建议设置定时任务执行)
+DELETE FROM token_usage_logs WHERE created_at < UNIX_TIMESTAMP(DATE_SUB(NOW(), INTERVAL 30 DAY));
+
+-- PostgreSQL版本
+-- DELETE FROM token_usage_logs WHERE created_at < EXTRACT(EPOCH FROM NOW() - INTERVAL '30 days');
+
+-- SQLite版本
+-- DELETE FROM token_usage_logs WHERE created_at < strftime('%s', 'now', '-30 days');
+*/
+
+-- ====================================================================
+-- 使用说明
+-- ====================================================================
+/*
+此表用于记录令牌的每次使用,支持访问频率限制功能:
+
+1. 表结构:
+   - id: 主键,自增
+   - token_id: 令牌ID,关联tokens表
+   - created_at: 创建时间戳
+
+2. 索引:
+   - idx_token_created: 复合索引,用于快速查询某个令牌在特定时间范围内的使用次数
+   - idx_created: 时间索引,用于定期清理旧数据
+
+3. 使用场景:
+   - 分钟级限制:查询当前分钟内的记录数
+   - 日级限制:查询当天内的记录数
+   - 定期清理:删除过期记录以节省存储空间
+
+4. 性能考虑:
+   - 该表会频繁插入,但查询相对较少
+   - 索引设计优化了查询性能
+   - 建议定期清理旧数据
+*/

+ 93 - 0
bin/create_usage_statistics_table.sql

@@ -0,0 +1,93 @@
+-- 用量日统计汇总表迁移脚本
+-- 创建日期: 2025-08-25
+-- 功能: 创建用量统计汇总表,用于存储按日期、令牌、模型分组的统计数据
+
+-- MySQL 语法
+CREATE TABLE IF NOT EXISTS usage_statistics (
+    id INT AUTO_INCREMENT PRIMARY KEY,
+    date VARCHAR(10) NOT NULL COMMENT '统计日期(YYYY-MM-DD)',
+    token_id INT NOT NULL COMMENT '令牌ID',
+    token_name VARCHAR(255) NOT NULL DEFAULT '' COMMENT '令牌名称',
+    model_name VARCHAR(255) NOT NULL COMMENT '模型名称',
+    total_requests INT NOT NULL DEFAULT 0 COMMENT '总请求次数',
+    successful_requests INT NOT NULL DEFAULT 0 COMMENT '成功请求次数',
+    failed_requests INT NOT NULL DEFAULT 0 COMMENT '失败请求次数',
+    total_tokens INT NOT NULL DEFAULT 0 COMMENT '总Token消耗',
+    prompt_tokens INT NOT NULL DEFAULT 0 COMMENT '提示Token数',
+    completion_tokens INT NOT NULL DEFAULT 0 COMMENT '完成Token数',
+    total_quota INT NOT NULL DEFAULT 0 COMMENT '总额度消耗',
+    created_time BIGINT NOT NULL COMMENT '创建时间戳',
+    updated_time BIGINT NOT NULL COMMENT '更新时间戳',
+    INDEX idx_date (date),
+    INDEX idx_token_id (token_id),
+    INDEX idx_model_name (model_name),
+    INDEX idx_date_token_model (date, token_id, model_name),
+    UNIQUE KEY uk_date_token_model (date, token_id, model_name)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='用量统计汇总表';
+
+-- PostgreSQL 语法 (如果使用PostgreSQL)
+/*
+CREATE TABLE IF NOT EXISTS usage_statistics (
+    id SERIAL PRIMARY KEY,
+    date VARCHAR(10) NOT NULL,
+    token_id INTEGER NOT NULL,
+    token_name VARCHAR(255) NOT NULL DEFAULT '',
+    model_name VARCHAR(255) NOT NULL,
+    total_requests INTEGER NOT NULL DEFAULT 0,
+    successful_requests INTEGER NOT NULL DEFAULT 0,
+    failed_requests INTEGER NOT NULL DEFAULT 0,
+    total_tokens INTEGER NOT NULL DEFAULT 0,
+    prompt_tokens INTEGER NOT NULL DEFAULT 0,
+    completion_tokens INTEGER NOT NULL DEFAULT 0,
+    total_quota INTEGER NOT NULL DEFAULT 0,
+    created_time BIGINT NOT NULL,
+    updated_time BIGINT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS idx_usage_statistics_date ON usage_statistics(date);
+CREATE INDEX IF NOT EXISTS idx_usage_statistics_token_id ON usage_statistics(token_id);
+CREATE INDEX IF NOT EXISTS idx_usage_statistics_model_name ON usage_statistics(model_name);
+CREATE INDEX IF NOT EXISTS idx_usage_statistics_date_token_model ON usage_statistics(date, token_id, model_name);
+CREATE UNIQUE INDEX IF NOT EXISTS uk_usage_statistics_date_token_model ON usage_statistics(date, token_id, model_name);
+
+COMMENT ON TABLE usage_statistics IS '用量统计汇总表';
+COMMENT ON COLUMN usage_statistics.date IS '统计日期(YYYY-MM-DD)';
+COMMENT ON COLUMN usage_statistics.token_id IS '令牌ID';
+COMMENT ON COLUMN usage_statistics.token_name IS '令牌名称';
+COMMENT ON COLUMN usage_statistics.model_name IS '模型名称';
+COMMENT ON COLUMN usage_statistics.total_requests IS '总请求次数';
+COMMENT ON COLUMN usage_statistics.successful_requests IS '成功请求次数';
+COMMENT ON COLUMN usage_statistics.failed_requests IS '失败请求次数';
+COMMENT ON COLUMN usage_statistics.total_tokens IS '总Token消耗';
+COMMENT ON COLUMN usage_statistics.prompt_tokens IS '提示Token数';
+COMMENT ON COLUMN usage_statistics.completion_tokens IS '完成Token数';
+COMMENT ON COLUMN usage_statistics.total_quota IS '总额度消耗';
+COMMENT ON COLUMN usage_statistics.created_time IS '创建时间戳';
+COMMENT ON COLUMN usage_statistics.updated_time IS '更新时间戳';
+*/
+
+-- SQLite 语法 (如果使用SQLite)
+/*
+CREATE TABLE IF NOT EXISTS usage_statistics (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    date TEXT NOT NULL,
+    token_id INTEGER NOT NULL,
+    token_name TEXT NOT NULL DEFAULT '',
+    model_name TEXT NOT NULL,
+    total_requests INTEGER NOT NULL DEFAULT 0,
+    successful_requests INTEGER NOT NULL DEFAULT 0,
+    failed_requests INTEGER NOT NULL DEFAULT 0,
+    total_tokens INTEGER NOT NULL DEFAULT 0,
+    prompt_tokens INTEGER NOT NULL DEFAULT 0,
+    completion_tokens INTEGER NOT NULL DEFAULT 0,
+    total_quota INTEGER NOT NULL DEFAULT 0,
+    created_time INTEGER NOT NULL,
+    updated_time INTEGER NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS idx_usage_statistics_date ON usage_statistics(date);
+CREATE INDEX IF NOT EXISTS idx_usage_statistics_token_id ON usage_statistics(token_id);
+CREATE INDEX IF NOT EXISTS idx_usage_statistics_model_name ON usage_statistics(model_name);
+CREATE INDEX IF NOT EXISTS idx_usage_statistics_date_token_model ON usage_statistics(date, token_id, model_name);
+CREATE UNIQUE INDEX IF NOT EXISTS uk_usage_statistics_date_token_model ON usage_statistics(date, token_id, model_name);
+*/

+ 143 - 0
bin/migration_token_usage_count.sql

@@ -0,0 +1,143 @@
+-- ====================================================================
+-- Token使用次数统计功能数据库迁移脚本
+-- ====================================================================
+-- 版本: v1.0
+-- 创建日期: 2025-08-25
+-- 作者: Assistant
+-- 描述: 为tokens表添加使用次数统计功能的相关字段
+-- 
+-- 新增字段:
+-- - daily_usage_count: 今日使用次数,每日自动重置
+-- - total_usage_count: 总使用次数,只增不减
+-- - last_usage_date: 最后使用日期,用于判断是否需要重置今日次数
+-- ====================================================================
+
+-- ====================================================================
+-- 向前迁移 (UP) - 添加字段
+-- ====================================================================
+
+-- MySQL 数据库
+-- 检查字段是否已存在,避免重复添加
+SET @sql_daily = 'ALTER TABLE tokens ADD COLUMN daily_usage_count INT NOT NULL DEFAULT 0 COMMENT ''今日使用次数''';
+SET @sql_total = 'ALTER TABLE tokens ADD COLUMN total_usage_count INT NOT NULL DEFAULT 0 COMMENT ''总使用次数''';
+SET @sql_date = 'ALTER TABLE tokens ADD COLUMN last_usage_date VARCHAR(10) NOT NULL DEFAULT '''' COMMENT ''最后使用日期(YYYY-MM-DD)''';
+
+-- 执行添加字段的SQL (仅在字段不存在时执行)
+PREPARE stmt FROM @sql_daily;
+EXECUTE stmt;
+DEALLOCATE PREPARE stmt;
+
+PREPARE stmt FROM @sql_total;
+EXECUTE stmt;
+DEALLOCATE PREPARE stmt;
+
+PREPARE stmt FROM @sql_date;
+EXECUTE stmt;
+DEALLOCATE PREPARE stmt;
+
+-- 为新字段添加索引 (可选,提升查询性能)
+-- CREATE INDEX idx_tokens_last_usage_date ON tokens(last_usage_date);
+-- CREATE INDEX idx_tokens_total_usage_count ON tokens(total_usage_count);
+
+-- ====================================================================
+-- PostgreSQL 数据库 (如果使用PostgreSQL,取消下面注释)
+-- ====================================================================
+/*
+-- 添加字段
+ALTER TABLE tokens ADD COLUMN IF NOT EXISTS daily_usage_count INTEGER NOT NULL DEFAULT 0;
+ALTER TABLE tokens ADD COLUMN IF NOT EXISTS total_usage_count INTEGER NOT NULL DEFAULT 0;
+ALTER TABLE tokens ADD COLUMN IF NOT EXISTS last_usage_date VARCHAR(10) NOT NULL DEFAULT '';
+
+-- 添加注释
+COMMENT ON COLUMN tokens.daily_usage_count IS '今日使用次数';
+COMMENT ON COLUMN tokens.total_usage_count IS '总使用次数';
+COMMENT ON COLUMN tokens.last_usage_date IS '最后使用日期(YYYY-MM-DD)';
+
+-- 添加索引 (可选)
+-- CREATE INDEX IF NOT EXISTS idx_tokens_last_usage_date ON tokens(last_usage_date);
+-- CREATE INDEX IF NOT EXISTS idx_tokens_total_usage_count ON tokens(total_usage_count);
+*/
+
+-- ====================================================================
+-- SQLite 数据库 (如果使用SQLite,取消下面注释)
+-- ====================================================================
+/*
+-- SQLite不支持IF NOT EXISTS语法,需要手动检查
+-- 添加字段
+ALTER TABLE tokens ADD COLUMN daily_usage_count INTEGER NOT NULL DEFAULT 0;
+ALTER TABLE tokens ADD COLUMN total_usage_count INTEGER NOT NULL DEFAULT 0;
+ALTER TABLE tokens ADD COLUMN last_usage_date TEXT NOT NULL DEFAULT '';
+
+-- 添加索引 (可选)
+-- CREATE INDEX IF NOT EXISTS idx_tokens_last_usage_date ON tokens(last_usage_date);
+-- CREATE INDEX IF NOT EXISTS idx_tokens_total_usage_count ON tokens(total_usage_count);
+*/
+
+-- ====================================================================
+-- 向后迁移 (DOWN) - 删除字段 (回滚操作)
+-- ====================================================================
+/*
+-- 如果需要回滚,请执行以下SQL语句
+
+-- MySQL 回滚
+ALTER TABLE tokens DROP COLUMN IF EXISTS daily_usage_count;
+ALTER TABLE tokens DROP COLUMN IF EXISTS total_usage_count;
+ALTER TABLE tokens DROP COLUMN IF EXISTS last_usage_date;
+
+-- PostgreSQL 回滚
+-- ALTER TABLE tokens DROP COLUMN IF EXISTS daily_usage_count;
+-- ALTER TABLE tokens DROP COLUMN IF EXISTS total_usage_count;
+-- ALTER TABLE tokens DROP COLUMN IF EXISTS last_usage_date;
+
+-- SQLite 回滚 (SQLite不支持DROP COLUMN,需要重建表)
+-- 1. 创建备份表
+-- CREATE TABLE tokens_backup AS SELECT id, user_id, key, status, name, created_time, accessed_time, expired_time, remain_quota, unlimited_quota, model_limits_enabled, model_limits, allow_ips, used_quota, group FROM tokens;
+-- 2. 删除原表
+-- DROP TABLE tokens;
+-- 3. 重命名备份表
+-- ALTER TABLE tokens_backup RENAME TO tokens;
+*/
+
+-- ====================================================================
+-- 验证脚本 - 检查字段是否成功添加
+-- ====================================================================
+/*
+-- 检查字段是否存在 (MySQL)
+SELECT 
+    COLUMN_NAME,
+    DATA_TYPE,
+    IS_NULLABLE,
+    COLUMN_DEFAULT,
+    COLUMN_COMMENT
+FROM INFORMATION_SCHEMA.COLUMNS 
+WHERE TABLE_SCHEMA = DATABASE() 
+  AND TABLE_NAME = 'tokens' 
+  AND COLUMN_NAME IN ('daily_usage_count', 'total_usage_count', 'last_usage_date');
+
+-- 检查字段是否存在 (PostgreSQL)
+-- SELECT column_name, data_type, is_nullable, column_default 
+-- FROM information_schema.columns 
+-- WHERE table_name = 'tokens' 
+--   AND column_name IN ('daily_usage_count', 'total_usage_count', 'last_usage_date');
+
+-- 检查字段是否存在 (SQLite)
+-- PRAGMA table_info(tokens);
+*/
+
+-- ====================================================================
+-- 使用说明
+-- ====================================================================
+/*
+1. 根据您的数据库类型,选择对应的SQL语句执行
+2. 默认启用MySQL语法,如使用其他数据库请取消相应注释
+3. 建议在执行前备份数据库
+4. 执行后可运行验证脚本确认字段添加成功
+5. 如需回滚,请执行"向后迁移"部分的SQL语句
+
+新字段说明:
+- daily_usage_count: 记录Token今日使用次数,每天首次使用时重置
+- total_usage_count: 记录Token总使用次数,永不重置
+- last_usage_date: 记录最后使用日期,格式为YYYY-MM-DD,用于判断是否跨天
+
+这些字段会在Token被使用时自动更新,无需手动维护。
+*/

+ 6 - 0
bin/migration_v0.2-v0.3.sql

@@ -0,0 +1,6 @@
+UPDATE users
+SET quota = quota + (
+    SELECT SUM(remain_quota)
+    FROM tokens
+    WHERE tokens.user_id = users.id
+)

+ 17 - 0
bin/migration_v0.3-v0.4.sql

@@ -0,0 +1,17 @@
+INSERT INTO abilities (`group`, model, channel_id, enabled)
+SELECT c.`group`, m.model, c.id, 1
+FROM channels c
+CROSS JOIN (
+    SELECT 'gpt-3.5-turbo' AS model UNION ALL
+    SELECT 'gpt-3.5-turbo-0301' AS model UNION ALL
+    SELECT 'gpt-4' AS model UNION ALL
+    SELECT 'gpt-4-0314' AS model
+) AS m
+WHERE c.status = 1
+  AND NOT EXISTS (
+    SELECT 1
+    FROM abilities a
+    WHERE a.`group` = c.`group`
+      AND a.model = m.model
+      AND a.channel_id = c.id
+);

+ 51 - 0
bin/test_usage_statistics_data.sql

@@ -0,0 +1,51 @@
+-- 测试用量统计功能的SQL脚本
+-- 执行此脚本可以插入一些测试数据来验证用量统计功能
+
+-- 插入测试数据到usage_statistics表
+INSERT INTO usage_statistics (
+    date, token_id, token_name, model_name, 
+    total_requests, successful_requests, failed_requests,
+    total_tokens, prompt_tokens, completion_tokens, total_quota,
+    created_time, updated_time
+) VALUES 
+-- 今天的数据
+('2025-08-25', 1, 'test-token-1', 'gpt-3.5-turbo', 10, 8, 2, 5000, 3000, 2000, 1000, 1724568000, 1724568000),
+('2025-08-25', 1, 'test-token-1', 'gpt-4', 5, 5, 0, 8000, 5000, 3000, 2000, 1724568000, 1724568000),
+('2025-08-25', 2, 'test-token-2', 'gpt-3.5-turbo', 15, 12, 3, 7500, 4500, 3000, 1500, 1724568000, 1724568000),
+
+-- 昨天的数据  
+('2025-08-24', 1, 'test-token-1', 'gpt-3.5-turbo', 20, 18, 2, 10000, 6000, 4000, 2000, 1724481600, 1724481600),
+('2025-08-24', 1, 'test-token-1', 'gpt-4', 8, 7, 1, 12000, 7000, 5000, 3000, 1724481600, 1724481600),
+('2025-08-24', 2, 'test-token-2', 'gpt-3.5-turbo', 12, 10, 2, 6000, 3600, 2400, 1200, 1724481600, 1724481600),
+
+-- 前天的数据
+('2025-08-23', 1, 'test-token-1', 'gpt-3.5-turbo', 25, 22, 3, 12500, 7500, 5000, 2500, 1724395200, 1724395200),
+('2025-08-23', 2, 'test-token-2', 'gpt-4', 6, 5, 1, 9000, 5400, 3600, 1800, 1724395200, 1724395200),
+
+-- 一周前的数据
+('2025-08-18', 1, 'test-token-1', 'gpt-3.5-turbo', 30, 28, 2, 15000, 9000, 6000, 3000, 1723939200, 1723939200),
+('2025-08-18', 2, 'test-token-2', 'gpt-3.5-turbo', 18, 15, 3, 9000, 5400, 3600, 1800, 1723939200, 1723939200);
+
+-- 查询插入的数据验证
+SELECT 
+    date,
+    token_name,
+    model_name,
+    total_requests,
+    successful_requests,
+    failed_requests,
+    total_tokens,
+    total_quota
+FROM usage_statistics 
+ORDER BY date DESC, token_id ASC, model_name ASC;
+
+-- 查询汇总统计
+SELECT 
+    SUM(total_requests) as total_requests,
+    SUM(successful_requests) as successful_requests,
+    SUM(failed_requests) as failed_requests,
+    ROUND(SUM(successful_requests) * 100.0 / SUM(total_requests), 2) as success_rate,
+    SUM(total_tokens) as total_tokens,
+    SUM(total_quota) as total_quota
+FROM usage_statistics 
+WHERE date >= '2025-08-18';

+ 40 - 0
bin/time_test.sh

@@ -0,0 +1,40 @@
+#!/bin/bash
+
+if [ $# -lt 3 ]; then
+  echo "Usage: time_test.sh <domain> <key> <count> [<model>]"
+  exit 1
+fi
+
+domain=$1
+key=$2
+count=$3
+model=${4:-"gpt-3.5-turbo"} # 设置默认模型为 gpt-3.5-turbo
+
+total_time=0
+times=()
+
+for ((i=1; i<=count; i++)); do
+  result=$(curl -o /dev/null -s -w "%{http_code} %{time_total}\\n" \
+           https://"$domain"/v1/chat/completions \
+           -H "Content-Type: application/json" \
+           -H "Authorization: Bearer $key" \
+           -d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "'"$model"'", "stream": false, "max_tokens": 1}')
+  http_code=$(echo "$result" | awk '{print $1}')
+  time=$(echo "$result" | awk '{print $2}')
+  echo "HTTP status code: $http_code, Time taken: $time"
+  total_time=$(bc <<< "$total_time + $time")
+  times+=("$time")
+done
+
+average_time=$(echo "scale=4; $total_time / $count" | bc)
+
+sum_of_squares=0
+for time in "${times[@]}"; do
+  difference=$(echo "scale=4; $time - $average_time" | bc)
+  square=$(echo "scale=4; $difference * $difference" | bc)
+  sum_of_squares=$(echo "scale=4; $sum_of_squares + $square" | bc)
+done
+
+standard_deviation=$(echo "scale=4; sqrt($sum_of_squares / $count)" | bc)
+
+echo "Average time: $average_time±$standard_deviation"

+ 73 - 0
common/api_type.go

@@ -0,0 +1,73 @@
+package common
+
+import "one-api/constant"
+
+func ChannelType2APIType(channelType int) (int, bool) {
+	apiType := -1
+	switch channelType {
+	case constant.ChannelTypeOpenAI:
+		apiType = constant.APITypeOpenAI
+	case constant.ChannelTypeAnthropic:
+		apiType = constant.APITypeAnthropic
+	case constant.ChannelTypeBaidu:
+		apiType = constant.APITypeBaidu
+	case constant.ChannelTypePaLM:
+		apiType = constant.APITypePaLM
+	case constant.ChannelTypeZhipu:
+		apiType = constant.APITypeZhipu
+	case constant.ChannelTypeAli:
+		apiType = constant.APITypeAli
+	case constant.ChannelTypeXunfei:
+		apiType = constant.APITypeXunfei
+	case constant.ChannelTypeAIProxyLibrary:
+		apiType = constant.APITypeAIProxyLibrary
+	case constant.ChannelTypeTencent:
+		apiType = constant.APITypeTencent
+	case constant.ChannelTypeGemini:
+		apiType = constant.APITypeGemini
+	case constant.ChannelTypeZhipu_v4:
+		apiType = constant.APITypeZhipuV4
+	case constant.ChannelTypeOllama:
+		apiType = constant.APITypeOllama
+	case constant.ChannelTypePerplexity:
+		apiType = constant.APITypePerplexity
+	case constant.ChannelTypeAws:
+		apiType = constant.APITypeAws
+	case constant.ChannelTypeCohere:
+		apiType = constant.APITypeCohere
+	case constant.ChannelTypeDify:
+		apiType = constant.APITypeDify
+	case constant.ChannelTypeJina:
+		apiType = constant.APITypeJina
+	case constant.ChannelCloudflare:
+		apiType = constant.APITypeCloudflare
+	case constant.ChannelTypeSiliconFlow:
+		apiType = constant.APITypeSiliconFlow
+	case constant.ChannelTypeVertexAi:
+		apiType = constant.APITypeVertexAi
+	case constant.ChannelTypeMistral:
+		apiType = constant.APITypeMistral
+	case constant.ChannelTypeDeepSeek:
+		apiType = constant.APITypeDeepSeek
+	case constant.ChannelTypeMokaAI:
+		apiType = constant.APITypeMokaAI
+	case constant.ChannelTypeVolcEngine:
+		apiType = constant.APITypeVolcEngine
+	case constant.ChannelTypeBaiduV2:
+		apiType = constant.APITypeBaiduV2
+	case constant.ChannelTypeOpenRouter:
+		apiType = constant.APITypeOpenRouter
+	case constant.ChannelTypeXinference:
+		apiType = constant.APITypeXinference
+	case constant.ChannelTypeXai:
+		apiType = constant.APITypeXai
+	case constant.ChannelTypeCoze:
+		apiType = constant.APITypeCoze
+	case constant.ChannelTypeJimeng:
+		apiType = constant.APITypeJimeng
+	}
+	if apiType == -1 {
+		return constant.APITypeOpenAI, false
+	}
+	return apiType, true
+}

+ 201 - 0
common/constants.go

@@ -0,0 +1,201 @@
+package common
+
+import (
+	//"os"
+	//"strconv"
+	"sync"
+	"time"
+
+	"github.com/google/uuid"
+)
+
+var StartTime = time.Now().Unix() // unit: second
+var Version = "v0.0.0"            // this hard coding will be replaced automatically when building, no need to manually change
+var SystemName = "New API"
+var Footer = ""
+var Logo = ""
+var TopUpLink = ""
+
+// var ChatLink = ""
+// var ChatLink2 = ""
+var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
+var DisplayInCurrencyEnabled = true
+var DisplayTokenStatEnabled = true
+var DrawingEnabled = true
+var TaskEnabled = true
+var DataExportEnabled = true
+var DataExportInterval = 5         // unit: minute
+var DataExportDefaultTime = "hour" // unit: minute
+var DefaultCollapseSidebar = false // default value of collapse sidebar
+
+// Any options with "Secret", "Token" in its key won't be return by GetOptions
+
+var SessionSecret = uuid.New().String()
+var CryptoSecret = uuid.New().String()
+
+var OptionMap map[string]string
+var OptionMapRWMutex sync.RWMutex
+
+var ItemsPerPage = 10
+var MaxRecentItems = 100
+
+var PasswordLoginEnabled = true
+var PasswordRegisterEnabled = true
+var EmailVerificationEnabled = false
+var GitHubOAuthEnabled = false
+var LinuxDOOAuthEnabled = false
+var WeChatAuthEnabled = false
+var TelegramOAuthEnabled = false
+var TurnstileCheckEnabled = false
+var RegisterEnabled = true
+
+var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
+var EmailAliasRestrictionEnabled = false  // 是否启用邮箱别名限制
+var EmailDomainWhitelist = []string{
+	"gmail.com",
+	"163.com",
+	"126.com",
+	"qq.com",
+	"outlook.com",
+	"hotmail.com",
+	"icloud.com",
+	"yahoo.com",
+	"foxmail.com",
+}
+var EmailLoginAuthServerList = []string{
+	"smtp.sendcloud.net",
+	"smtp.azurecomm.net",
+}
+
+var DebugEnabled bool
+var MemoryCacheEnabled bool
+
+var LogConsumeEnabled = true
+
+var SMTPServer = ""
+var SMTPPort = 587
+var SMTPSSLEnabled = false
+var SMTPAccount = ""
+var SMTPFrom = ""
+var SMTPToken = ""
+
+var GitHubClientId = ""
+var GitHubClientSecret = ""
+var LinuxDOClientId = ""
+var LinuxDOClientSecret = ""
+
+var WeChatServerAddress = ""
+var WeChatServerToken = ""
+var WeChatAccountQRCodeImageURL = ""
+
+var TurnstileSiteKey = ""
+var TurnstileSecretKey = ""
+
+var TelegramBotToken = ""
+var TelegramBotName = ""
+
+var QuotaForNewUser = 0
+var QuotaForInviter = 0
+var QuotaForInvitee = 0
+var ChannelDisableThreshold = 5.0
+var AutomaticDisableChannelEnabled = false
+var AutomaticEnableChannelEnabled = false
+var QuotaRemindThreshold = 1000
+var PreConsumedQuota = 500
+
+var RetryTimes = 0
+
+//var RootUserEmail = ""
+
+var IsMasterNode bool
+
+var requestInterval int
+var RequestInterval time.Duration
+
+var SyncFrequency int // unit is second
+
+var BatchUpdateEnabled = false
+var BatchUpdateInterval int
+
+var RelayTimeout int // unit is second
+
+var GeminiSafetySetting string
+
+// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
+var CohereSafetySetting string
+
+const (
+	RequestIdKey = "X-Oneapi-Request-Id"
+)
+
+const (
+	RoleGuestUser  = 0
+	RoleCommonUser = 1
+	RoleAdminUser  = 10
+	RoleRootUser   = 100
+)
+
+func IsValidateRole(role int) bool {
+	return role == RoleGuestUser || role == RoleCommonUser || role == RoleAdminUser || role == RoleRootUser
+}
+
+var (
+	FileUploadPermission    = RoleGuestUser
+	FileDownloadPermission  = RoleGuestUser
+	ImageUploadPermission   = RoleGuestUser
+	ImageDownloadPermission = RoleGuestUser
+)
+
+// All duration's unit is seconds
+// Shouldn't larger then RateLimitKeyExpirationDuration
+var (
+	GlobalApiRateLimitEnable   bool
+	GlobalApiRateLimitNum      int
+	GlobalApiRateLimitDuration int64
+
+	GlobalWebRateLimitEnable   bool
+	GlobalWebRateLimitNum      int
+	GlobalWebRateLimitDuration int64
+
+	UploadRateLimitNum            = 10
+	UploadRateLimitDuration int64 = 60
+
+	DownloadRateLimitNum            = 10
+	DownloadRateLimitDuration int64 = 60
+
+	CriticalRateLimitNum            = 20
+	CriticalRateLimitDuration int64 = 20 * 60
+)
+
+var RateLimitKeyExpirationDuration = 20 * time.Minute
+
+const (
+	UserStatusEnabled  = 1 // don't use 0, 0 is the default value!
+	UserStatusDisabled = 2 // also don't use 0
+)
+
+const (
+	TokenStatusEnabled   = 1 // don't use 0, 0 is the default value!
+	TokenStatusDisabled  = 2 // also don't use 0
+	TokenStatusExpired   = 3
+	TokenStatusExhausted = 4
+)
+
+const (
+	RedemptionCodeStatusEnabled  = 1 // don't use 0, 0 is the default value!
+	RedemptionCodeStatusDisabled = 2 // also don't use 0
+	RedemptionCodeStatusUsed     = 3 // also don't use 0
+)
+
+const (
+	ChannelStatusUnknown          = 0
+	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value!
+	ChannelStatusManuallyDisabled = 2 // also don't use 0
+	ChannelStatusAutoDisabled     = 3
+)
+
+const (
+	TopUpStatusPending = "pending"
+	TopUpStatusSuccess = "success"
+	TopUpStatusExpired = "expired"
+)

+ 31 - 0
common/crypto.go

@@ -0,0 +1,31 @@
+package common
+
+import (
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/hex"
+	"golang.org/x/crypto/bcrypt"
+)
+
+func GenerateHMACWithKey(key []byte, data string) string {
+	h := hmac.New(sha256.New, key)
+	h.Write([]byte(data))
+	return hex.EncodeToString(h.Sum(nil))
+}
+
+func GenerateHMAC(data string) string {
+	h := hmac.New(sha256.New, []byte(CryptoSecret))
+	h.Write([]byte(data))
+	return hex.EncodeToString(h.Sum(nil))
+}
+
+func Password2Hash(password string) (string, error) {
+	passwordBytes := []byte(password)
+	hashedPassword, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost)
+	return string(hashedPassword), err
+}
+
+func ValidatePasswordAndHash(password string, hash string) bool {
+	err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
+	return err == nil
+}

+ 82 - 0
common/custom-event.go

@@ -0,0 +1,82 @@
+// Copyright 2014 Manu Martinez-Almeida.  All rights reserved.
+// Use of this source code is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+package common
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+)
+
+type stringWriter interface {
+	io.Writer
+	writeString(string) (int, error)
+}
+
+type stringWrapper struct {
+	io.Writer
+}
+
+func (w stringWrapper) writeString(str string) (int, error) {
+	return w.Writer.Write([]byte(str))
+}
+
+func checkWriter(writer io.Writer) stringWriter {
+	if w, ok := writer.(stringWriter); ok {
+		return w
+	} else {
+		return stringWrapper{writer}
+	}
+}
+
+// Server-Sent Events
+// W3C Working Draft 29 October 2009
+// http://www.w3.org/TR/2009/WD-eventsource-20091029/
+
+var contentType = []string{"text/event-stream"}
+var noCache = []string{"no-cache"}
+
+var fieldReplacer = strings.NewReplacer(
+	"\n", "\\n",
+	"\r", "\\r")
+
+var dataReplacer = strings.NewReplacer(
+	"\n", "\n",
+	"\r", "\\r")
+
+type CustomEvent struct {
+	Event string
+	Id    string
+	Retry uint
+	Data  interface{}
+}
+
+func encode(writer io.Writer, event CustomEvent) error {
+	w := checkWriter(writer)
+	return writeData(w, event.Data)
+}
+
+func writeData(w stringWriter, data interface{}) error {
+	dataReplacer.WriteString(w, fmt.Sprint(data))
+	if strings.HasPrefix(data.(string), "data") {
+		w.writeString("\n\n")
+	}
+	return nil
+}
+
+func (r CustomEvent) Render(w http.ResponseWriter) error {
+	r.WriteContentType(w)
+	return encode(w, r)
+}
+
+func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
+	header := w.Header()
+	header["Content-Type"] = contentType
+
+	if _, exist := header["Cache-Control"]; !exist {
+		header["Cache-Control"] = noCache
+	}
+}

+ 15 - 0
common/database.go

@@ -0,0 +1,15 @@
+package common
+
+const (
+	DatabaseTypeMySQL      = "mysql"
+	DatabaseTypeSQLite     = "sqlite"
+	DatabaseTypePostgreSQL = "postgres"
+)
+
+var UsingSQLite = false
+var UsingPostgreSQL = false
+var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
+var UsingMySQL = false
+var UsingClickHouse = false
+
+var SQLitePath = "one-api.db?_busy_timeout=5000"

+ 40 - 0
common/email-outlook-auth.go

@@ -0,0 +1,40 @@
+package common
+
+import (
+	"errors"
+	"net/smtp"
+	"strings"
+)
+
+type outlookAuth struct {
+	username, password string
+}
+
+func LoginAuth(username, password string) smtp.Auth {
+	return &outlookAuth{username, password}
+}
+
+func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
+	return "LOGIN", []byte{}, nil
+}
+
+func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
+	if more {
+		switch string(fromServer) {
+		case "Username:":
+			return []byte(a.username), nil
+		case "Password:":
+			return []byte(a.password), nil
+		default:
+			return nil, errors.New("unknown fromServer")
+		}
+	}
+	return nil, nil
+}
+
+func isOutlookServer(server string) bool {
+	// 兼容多地区的outlook邮箱和ofb邮箱
+	// 其实应该加一个Option来区分是否用LOGIN的方式登录
+	// 先临时兼容一下
+	return strings.Contains(server, "outlook") || strings.Contains(server, "onmicrosoft")
+}

+ 90 - 0
common/email.go

@@ -0,0 +1,90 @@
+package common
+
+import (
+	"crypto/tls"
+	"encoding/base64"
+	"fmt"
+	"net/smtp"
+	"slices"
+	"strings"
+	"time"
+)
+
+func generateMessageID() (string, error) {
+	split := strings.Split(SMTPFrom, "@")
+	if len(split) < 2 {
+		return "", fmt.Errorf("invalid SMTP account")
+	}
+	domain := strings.Split(SMTPFrom, "@")[1]
+	return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain), nil
+}
+
+func SendEmail(subject string, receiver string, content string) error {
+	if SMTPFrom == "" { // for compatibility
+		SMTPFrom = SMTPAccount
+	}
+	id, err2 := generateMessageID()
+	if err2 != nil {
+		return err2
+	}
+	if SMTPServer == "" && SMTPAccount == "" {
+		return fmt.Errorf("SMTP 服务器未配置")
+	}
+	encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
+	mail := []byte(fmt.Sprintf("To: %s\r\n"+
+		"From: %s<%s>\r\n"+
+		"Subject: %s\r\n"+
+		"Date: %s\r\n"+
+		"Message-ID: %s\r\n"+ // 添加 Message-ID 头
+		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
+		receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), id, content))
+	auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
+	addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
+	to := strings.Split(receiver, ";")
+	var err error
+	if SMTPPort == 465 || SMTPSSLEnabled {
+		tlsConfig := &tls.Config{
+			InsecureSkipVerify: true,
+			ServerName:         SMTPServer,
+		}
+		conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
+		if err != nil {
+			return err
+		}
+		client, err := smtp.NewClient(conn, SMTPServer)
+		if err != nil {
+			return err
+		}
+		defer client.Close()
+		if err = client.Auth(auth); err != nil {
+			return err
+		}
+		if err = client.Mail(SMTPFrom); err != nil {
+			return err
+		}
+		receiverEmails := strings.Split(receiver, ";")
+		for _, receiver := range receiverEmails {
+			if err = client.Rcpt(receiver); err != nil {
+				return err
+			}
+		}
+		w, err := client.Data()
+		if err != nil {
+			return err
+		}
+		_, err = w.Write(mail)
+		if err != nil {
+			return err
+		}
+		err = w.Close()
+		if err != nil {
+			return err
+		}
+	} else if isOutlookServer(SMTPAccount) || slices.Contains(EmailLoginAuthServerList, SMTPServer) {
+		auth = LoginAuth(SMTPAccount, SMTPToken)
+		err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
+	} else {
+		err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
+	}
+	return err
+}

+ 32 - 0
common/embed-file-system.go

@@ -0,0 +1,32 @@
+package common
+
+import (
+	"embed"
+	"github.com/gin-contrib/static"
+	"io/fs"
+	"net/http"
+)
+
+// Credit: https://github.com/gin-contrib/static/issues/19
+
+type embedFileSystem struct {
+	http.FileSystem
+}
+
+func (e embedFileSystem) Exists(prefix string, path string) bool {
+	_, err := e.Open(path)
+	if err != nil {
+		return false
+	}
+	return true
+}
+
+func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
+	efs, err := fs.Sub(fsEmbed, targetPath)
+	if err != nil {
+		panic(err)
+	}
+	return embedFileSystem{
+		FileSystem: http.FS(efs),
+	}
+}

+ 41 - 0
common/endpoint_type.go

@@ -0,0 +1,41 @@
+package common
+
+import "one-api/constant"
+
+// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
+func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
+	var endpointTypes []constant.EndpointType
+	switch channelType {
+	case constant.ChannelTypeJina:
+		endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
+	//case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
+	//	endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
+	//case constant.ChannelTypeSunoAPI:
+	//	endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
+	//case constant.ChannelTypeKling:
+	//	endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
+	//case constant.ChannelTypeJimeng:
+	//	endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
+	case constant.ChannelTypeAws:
+		fallthrough
+	case constant.ChannelTypeAnthropic:
+		endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
+	case constant.ChannelTypeVertexAi:
+		fallthrough
+	case constant.ChannelTypeGemini:
+		endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
+	case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
+		endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
+	default:
+		if IsOpenAIResponseOnlyModel(modelName) {
+			endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
+		} else {
+			endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
+		}
+	}
+	if IsImageGenerationModel(modelName) {
+		// add to first
+		endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
+	}
+	return endpointTypes
+}

+ 38 - 0
common/env.go

@@ -0,0 +1,38 @@
+package common
+
+import (
+	"fmt"
+	"os"
+	"strconv"
+)
+
+func GetEnvOrDefault(env string, defaultValue int) int {
+	if env == "" || os.Getenv(env) == "" {
+		return defaultValue
+	}
+	num, err := strconv.Atoi(os.Getenv(env))
+	if err != nil {
+		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
+		return defaultValue
+	}
+	return num
+}
+
+func GetEnvOrDefaultString(env string, defaultValue string) string {
+	if env == "" || os.Getenv(env) == "" {
+		return defaultValue
+	}
+	return os.Getenv(env)
+}
+
+func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
+	if env == "" || os.Getenv(env) == "" {
+		return defaultValue
+	}
+	b, err := strconv.ParseBool(os.Getenv(env))
+	if err != nil {
+		SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
+		return defaultValue
+	}
+	return b
+}

+ 111 - 0
common/gin.go

@@ -0,0 +1,111 @@
+package common
+
+import (
+	"bytes"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/constant"
+	"strings"
+	"time"
+)
+
+const KeyRequestBody = "key_request_body"
+
+func GetRequestBody(c *gin.Context) ([]byte, error) {
+	requestBody, _ := c.Get(KeyRequestBody)
+	if requestBody != nil {
+		return requestBody.([]byte), nil
+	}
+	requestBody, err := io.ReadAll(c.Request.Body)
+	if err != nil {
+		return nil, err
+	}
+	_ = c.Request.Body.Close()
+	c.Set(KeyRequestBody, requestBody)
+	return requestBody.([]byte), nil
+}
+
+func UnmarshalBodyReusable(c *gin.Context, v any) error {
+	requestBody, err := GetRequestBody(c)
+	if err != nil {
+		return err
+	}
+	contentType := c.Request.Header.Get("Content-Type")
+	if strings.HasPrefix(contentType, "application/json") {
+		err = Unmarshal(requestBody, &v)
+	} else {
+		// skip for now
+		// TODO: someday non json request have variant model, we will need to implementation this
+	}
+	if err != nil {
+		return err
+	}
+	// Reset request body
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+	return nil
+}
+
+func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
+	c.Set(string(key), value)
+}
+
+func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
+	return c.Get(string(key))
+}
+
+func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
+	return c.GetString(string(key))
+}
+
+func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
+	return c.GetInt(string(key))
+}
+
+func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
+	return c.GetBool(string(key))
+}
+
+func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
+	return c.GetStringSlice(string(key))
+}
+
+func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
+	return c.GetStringMap(string(key))
+}
+
+func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
+	return c.GetTime(string(key))
+}
+
+func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
+	if value, ok := c.Get(string(key)); ok {
+		if v, ok := value.(T); ok {
+			return v, true
+		}
+	}
+	var t T
+	return t, false
+}
+
+func ApiError(c *gin.Context, err error) {
+	c.JSON(http.StatusOK, gin.H{
+		"success": false,
+		"message": err.Error(),
+	})
+}
+
+func ApiErrorMsg(c *gin.Context, msg string) {
+	c.JSON(http.StatusOK, gin.H{
+		"success": false,
+		"message": msg,
+	})
+}
+
+func ApiSuccess(c *gin.Context, data any) {
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    data,
+	})
+}

+ 53 - 0
common/go-channel.go

@@ -0,0 +1,53 @@
+package common
+
+import (
+	"time"
+)
+
+func SafeSendBool(ch chan bool, value bool) (closed bool) {
+	defer func() {
+		// Recover from panic if one occured. A panic would mean the channel was closed.
+		if recover() != nil {
+			closed = true
+		}
+	}()
+
+	// This will panic if the channel is closed.
+	ch <- value
+
+	// If the code reaches here, then the channel was not closed.
+	return false
+}
+
+func SafeSendString(ch chan string, value string) (closed bool) {
+	defer func() {
+		// Recover from panic if one occured. A panic would mean the channel was closed.
+		if recover() != nil {
+			closed = true
+		}
+	}()
+
+	// This will panic if the channel is closed.
+	ch <- value
+
+	// If the code reaches here, then the channel was not closed.
+	return false
+}
+
+// SafeSendStringTimeout send, return true, else return false
+func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
+	defer func() {
+		// Recover from panic if one occured. A panic would mean the channel was closed.
+		if recover() != nil {
+			closed = false
+		}
+	}()
+
+	// This will panic if the channel is closed.
+	select {
+	case ch <- value:
+		return true
+	case <-time.After(time.Duration(timeout) * time.Second):
+		return false
+	}
+}

+ 24 - 0
common/gopool.go

@@ -0,0 +1,24 @@
+package common
+
+import (
+	"context"
+	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
+	"math"
+)
+
+var relayGoPool gopool.Pool
+
+func init() {
+	relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
+	relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
+		if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
+			SafeSendBool(stopChan, true)
+		}
+		SysError(fmt.Sprintf("panic in gopool.RelayPool: %v", i))
+	})
+}
+
+func RelayCtxGo(ctx context.Context, f func()) {
+	relayGoPool.CtxGo(ctx, f)
+}

+ 34 - 0
common/hash.go

@@ -0,0 +1,34 @@
+package common
+
+import (
+	"crypto/hmac"
+	"crypto/sha1"
+	"crypto/sha256"
+	"encoding/hex"
+)
+
+func Sha256Raw(data []byte) []byte {
+	h := sha256.New()
+	h.Write(data)
+	return h.Sum(nil)
+}
+
+func Sha1Raw(data []byte) []byte {
+	h := sha1.New()
+	h.Write(data)
+	return h.Sum(nil)
+}
+
+func Sha1(data []byte) string {
+	return hex.EncodeToString(Sha1Raw(data))
+}
+
+func HmacSha256Raw(message, key []byte) []byte {
+	h := hmac.New(sha256.New, key)
+	h.Write(message)
+	return h.Sum(nil)
+}
+
+func HmacSha256(message, key string) string {
+	return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
+}

+ 57 - 0
common/http.go

@@ -0,0 +1,57 @@
+package common
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+)
+
+func CloseResponseBodyGracefully(httpResponse *http.Response) {
+	if httpResponse == nil || httpResponse.Body == nil {
+		return
+	}
+	err := httpResponse.Body.Close()
+	if err != nil {
+		SysError("failed to close response body: " + err.Error())
+	}
+}
+
+func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
+	if c.Writer == nil {
+		return
+	}
+
+	body := io.NopCloser(bytes.NewBuffer(data))
+
+	// We shouldn't set the header before we parse the response body, because the parse part may fail.
+	// And then we will have to send an error response, but in this case, the header has already been set.
+	// So the httpClient will be confused by the response.
+	// For example, Postman will report error, and we cannot check the response at all.
+	if src != nil {
+		for k, v := range src.Header {
+			// avoid setting Content-Length
+			if k == "Content-Length" {
+				continue
+			}
+			c.Writer.Header().Set(k, v[0])
+		}
+	}
+
+	// set Content-Length header manually BEFORE calling WriteHeader
+	c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
+
+	// Write header with status code (this sends the headers)
+	if src != nil {
+		c.Writer.WriteHeader(src.StatusCode)
+	} else {
+		c.Writer.WriteHeader(http.StatusOK)
+	}
+
+	_, err := io.Copy(c.Writer, body)
+	if err != nil {
+		LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
+	}
+}

+ 120 - 0
common/init.go

@@ -0,0 +1,120 @@
+package common
+
+import (
+	"flag"
+	"fmt"
+	"log"
+	"one-api/constant"
+	"os"
+	"path/filepath"
+	"strconv"
+	"time"
+)
+
+var (
+	Port         = flag.Int("port", 3000, "the listening port")
+	PrintVersion = flag.Bool("version", false, "print version and exit")
+	PrintHelp    = flag.Bool("help", false, "print help and exit")
+	LogDir       = flag.String("log-dir", "./logs", "specify the log directory")
+)
+
+func printHelp() {
+	fmt.Println("New API " + Version + " - All in one API service for OpenAI API.")
+	fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.")
+	fmt.Println("GitHub: https://github.com/songquanpeng/one-api")
+	fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
+}
+
+func InitEnv() {
+	flag.Parse()
+
+	if *PrintVersion {
+		fmt.Println(Version)
+		os.Exit(0)
+	}
+
+	if *PrintHelp {
+		printHelp()
+		os.Exit(0)
+	}
+
+	if os.Getenv("SESSION_SECRET") != "" {
+		ss := os.Getenv("SESSION_SECRET")
+		if ss == "random_string" {
+			log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
+			log.Println("警告:SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
+			log.Fatal("Please set SESSION_SECRET to a random string.")
+		} else {
+			SessionSecret = ss
+		}
+	}
+	if os.Getenv("CRYPTO_SECRET") != "" {
+		CryptoSecret = os.Getenv("CRYPTO_SECRET")
+	} else {
+		CryptoSecret = SessionSecret
+	}
+	if os.Getenv("SQLITE_PATH") != "" {
+		SQLitePath = os.Getenv("SQLITE_PATH")
+	}
+	if *LogDir != "" {
+		var err error
+		*LogDir, err = filepath.Abs(*LogDir)
+		if err != nil {
+			log.Fatal(err)
+		}
+		if _, err := os.Stat(*LogDir); os.IsNotExist(err) {
+			err = os.Mkdir(*LogDir, 0777)
+			if err != nil {
+				log.Fatal(err)
+			}
+		}
+	}
+
+	// Initialize variables from constants.go that were using environment variables
+	DebugEnabled = os.Getenv("DEBUG") == "true"
+	MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
+	IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
+
+	// Parse requestInterval and set RequestInterval
+	requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
+	RequestInterval = time.Duration(requestInterval) * time.Second
+
+	// Initialize variables with GetEnvOrDefault
+	SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
+	BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
+	RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
+
+	// Initialize string variables with GetEnvOrDefaultString
+	GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
+	CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
+
+	// Initialize rate limit variables
+	GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
+	GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
+	GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
+
+	GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
+	GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
+	GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
+
+	initConstantEnv()
+}
+
+func initConstantEnv() {
+	constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
+	constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
+	constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
+	// ForceStreamOption 覆盖请求参数,强制返回usage信息
+	constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
+	constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
+	constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
+	constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
+	constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
+	constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
+	constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
+	constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
+	// GenerateDefaultToken 是否生成初始令牌,默认关闭。
+	constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
+	// 是否启用错误日志
+	constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
+}

+ 22 - 0
common/json.go

@@ -0,0 +1,22 @@
+package common
+
+import (
+	"bytes"
+	"encoding/json"
+)
+
+func Unmarshal(data []byte, v any) error {
+	return json.Unmarshal(data, v)
+}
+
+func UnmarshalJsonStr(data string, v any) error {
+	return json.Unmarshal(StringToByteSlice(data), v)
+}
+
+func DecodeJson(reader *bytes.Reader, v any) error {
+	return json.NewDecoder(reader).Decode(v)
+}
+
+func Marshal(v any) ([]byte, error) {
+	return json.Marshal(v)
+}

+ 89 - 0
common/limiter/limiter.go

@@ -0,0 +1,89 @@
+package limiter
+
+import (
+	"context"
+	_ "embed"
+	"fmt"
+	"github.com/go-redis/redis/v8"
+	"one-api/common"
+	"sync"
+)
+
+//go:embed lua/rate_limit.lua
+var rateLimitScript string
+
+type RedisLimiter struct {
+	client         *redis.Client
+	limitScriptSHA string
+}
+
+var (
+	instance *RedisLimiter
+	once     sync.Once
+)
+
+func New(ctx context.Context, r *redis.Client) *RedisLimiter {
+	once.Do(func() {
+		// 预加载脚本
+		limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
+		if err != nil {
+			common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
+		}
+		instance = &RedisLimiter{
+			client:         r,
+			limitScriptSHA: limitSHA,
+		}
+	})
+
+	return instance
+}
+
+func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
+	// 默认配置
+	config := &Config{
+		Capacity:  10,
+		Rate:      1,
+		Requested: 1,
+	}
+
+	// 应用选项模式
+	for _, opt := range opts {
+		opt(config)
+	}
+
+	// 执行限流
+	result, err := rl.client.EvalSha(
+		ctx,
+		rl.limitScriptSHA,
+		[]string{key},
+		config.Requested,
+		config.Rate,
+		config.Capacity,
+	).Int()
+
+	if err != nil {
+		return false, fmt.Errorf("rate limit failed: %w", err)
+	}
+	return result == 1, nil
+}
+
+// Config 配置选项模式
+type Config struct {
+	Capacity  int64
+	Rate      int64
+	Requested int64
+}
+
+type Option func(*Config)
+
+func WithCapacity(c int64) Option {
+	return func(cfg *Config) { cfg.Capacity = c }
+}
+
+func WithRate(r int64) Option {
+	return func(cfg *Config) { cfg.Rate = r }
+}
+
+func WithRequested(n int64) Option {
+	return func(cfg *Config) { cfg.Requested = n }
+}

+ 44 - 0
common/limiter/lua/rate_limit.lua

@@ -0,0 +1,44 @@
+-- 令牌桶限流器
+-- KEYS[1]: 限流器唯一标识
+-- ARGV[1]: 请求令牌数 (通常为1)
+-- ARGV[2]: 令牌生成速率 (每秒)
+-- ARGV[3]: 桶容量
+
+local key = KEYS[1]
+local requested = tonumber(ARGV[1])
+local rate = tonumber(ARGV[2])
+local capacity = tonumber(ARGV[3])
+
+-- 获取当前时间(Redis服务器时间)
+local now = redis.call('TIME')
+local nowInSeconds = tonumber(now[1])
+
+-- 获取桶状态
+local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
+local tokens = tonumber(bucket[1])
+local last_time = tonumber(bucket[2])
+
+-- 初始化桶(首次请求或过期)
+if not tokens or not last_time then
+    tokens = capacity
+    last_time = nowInSeconds
+else
+    -- 计算新增令牌
+    local elapsed = nowInSeconds - last_time
+    local add_tokens = elapsed * rate
+    tokens = math.min(capacity, tokens + add_tokens)
+    last_time = nowInSeconds
+end
+
+-- 判断是否允许请求
+local allowed = false
+if tokens >= requested then
+    tokens = tokens - requested
+    allowed = true
+end
+
+---- 更新桶状态并设置过期时间
+redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
+--redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- 适当延长过期时间
+
+return allowed and 1 or 0

+ 123 - 0
common/logger.go

@@ -0,0 +1,123 @@
+package common
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
+	"github.com/gin-gonic/gin"
+	"io"
+	"log"
+	"os"
+	"path/filepath"
+	"sync"
+	"time"
+)
+
+const (
+	loggerINFO  = "INFO"
+	loggerWarn  = "WARN"
+	loggerError = "ERR"
+)
+
+const maxLogCount = 1000000
+
+var logCount int
+var setupLogLock sync.Mutex
+var setupLogWorking bool
+
+func SetupLogger() {
+	if *LogDir != "" {
+		ok := setupLogLock.TryLock()
+		if !ok {
+			log.Println("setup log is already working")
+			return
+		}
+		defer func() {
+			setupLogLock.Unlock()
+			setupLogWorking = false
+		}()
+		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
+		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+		if err != nil {
+			log.Fatal("failed to open log file")
+		}
+		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
+		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
+	}
+}
+
+func SysLog(s string) {
+	t := time.Now()
+	_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
+}
+
+func SysError(s string) {
+	t := time.Now()
+	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
+}
+
+func LogInfo(ctx context.Context, msg string) {
+	logHelper(ctx, loggerINFO, msg)
+}
+
+func LogWarn(ctx context.Context, msg string) {
+	logHelper(ctx, loggerWarn, msg)
+}
+
+func LogError(ctx context.Context, msg string) {
+	logHelper(ctx, loggerError, msg)
+}
+
+func logHelper(ctx context.Context, level string, msg string) {
+	writer := gin.DefaultErrorWriter
+	if level == loggerINFO {
+		writer = gin.DefaultWriter
+	}
+	id := ctx.Value(RequestIdKey)
+	if id == nil {
+		id = "SYSTEM"
+	}
+	now := time.Now()
+	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
+	logCount++ // we don't need accurate count, so no lock here
+	if logCount > maxLogCount && !setupLogWorking {
+		logCount = 0
+		setupLogWorking = true
+		gopool.Go(func() {
+			SetupLogger()
+		})
+	}
+}
+
+func FatalLog(v ...any) {
+	t := time.Now()
+	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
+	os.Exit(1)
+}
+
+func LogQuota(quota int) string {
+	if DisplayInCurrencyEnabled {
+		return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
+	} else {
+		return fmt.Sprintf("%d 点额度", quota)
+	}
+}
+
+func FormatQuota(quota int) string {
+	if DisplayInCurrencyEnabled {
+		return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
+	} else {
+		return fmt.Sprintf("%d", quota)
+	}
+}
+
+// LogJson 仅供测试使用 only for test
+func LogJson(ctx context.Context, msg string, obj any) {
+	jsonStr, err := json.Marshal(obj)
+	if err != nil {
+		LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
+		return
+	}
+	LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
+}

+ 42 - 0
common/model.go

@@ -0,0 +1,42 @@
+package common
+
+import "strings"
+
+var (
+	// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
+	OpenAIResponseOnlyModels = []string{
+		"o3-pro",
+		"o3-deep-research",
+		"o4-mini-deep-research",
+	}
+	ImageGenerationModels = []string{
+		"dall-e-3",
+		"dall-e-2",
+		"gpt-image-1",
+		"prefix:imagen-",
+		"flux-",
+		"flux.1-",
+	}
+)
+
+func IsOpenAIResponseOnlyModel(modelName string) bool {
+	for _, m := range OpenAIResponseOnlyModels {
+		if strings.Contains(modelName, m) {
+			return true
+		}
+	}
+	return false
+}
+
+func IsImageGenerationModel(modelName string) bool {
+	modelName = strings.ToLower(modelName)
+	for _, m := range ImageGenerationModels {
+		if strings.Contains(modelName, m) {
+			return true
+		}
+		if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
+			return true
+		}
+	}
+	return false
+}

+ 82 - 0
common/page_info.go

@@ -0,0 +1,82 @@
+package common
+
+import (
+	"strconv"
+
+	"github.com/gin-gonic/gin"
+)
+
+type PageInfo struct {
+	Page     int `json:"page"`      // page num 页码
+	PageSize int `json:"page_size"` // page size 页大小
+
+	Total int `json:"total"` // 总条数,后设置
+	Items any `json:"items"` // 数据,后设置
+}
+
+func (p *PageInfo) GetStartIdx() int {
+	return (p.Page - 1) * p.PageSize
+}
+
+func (p *PageInfo) GetEndIdx() int {
+	return p.Page * p.PageSize
+}
+
+func (p *PageInfo) GetPageSize() int {
+	return p.PageSize
+}
+
+func (p *PageInfo) GetPage() int {
+	return p.Page
+}
+
+func (p *PageInfo) SetTotal(total int) {
+	p.Total = total
+}
+
+func (p *PageInfo) SetItems(items any) {
+	p.Items = items
+}
+
+func GetPageQuery(c *gin.Context) *PageInfo {
+	pageInfo := &PageInfo{}
+	// 手动获取并处理每个参数
+	if page, err := strconv.Atoi(c.Query("page")); err == nil {
+		pageInfo.Page = page
+	}
+	if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
+		pageInfo.PageSize = pageSize
+	}
+	if pageInfo.Page < 1 {
+		// 兼容
+		page, _ := strconv.Atoi(c.Query("p"))
+		if page != 0 {
+			pageInfo.Page = page
+		} else {
+			pageInfo.Page = 1
+		}
+	}
+
+	if pageInfo.PageSize == 0 {
+		// 兼容
+		pageSize, _ := strconv.Atoi(c.Query("ps"))
+		if pageSize != 0 {
+			pageInfo.PageSize = pageSize
+		}
+		if pageInfo.PageSize == 0 {
+			pageSize, _ = strconv.Atoi(c.Query("size")) // token page
+			if pageSize != 0 {
+				pageInfo.PageSize = pageSize
+			}
+		}
+		if pageInfo.PageSize == 0 {
+			pageInfo.PageSize = ItemsPerPage
+		}
+	}
+
+	if pageInfo.PageSize > 100 {
+		pageInfo.PageSize = 100
+	}
+
+	return pageInfo
+}

+ 44 - 0
common/pprof.go

@@ -0,0 +1,44 @@
+package common
+
+import (
+	"fmt"
+	"github.com/shirou/gopsutil/cpu"
+	"os"
+	"runtime/pprof"
+	"time"
+)
+
+// Monitor 定时监控cpu使用率,超过阈值输出pprof文件
+func Monitor() {
+	for {
+		percent, err := cpu.Percent(time.Second, false)
+		if err != nil {
+			panic(err)
+		}
+		if percent[0] > 80 {
+			fmt.Println("cpu usage too high")
+			// write pprof file
+			if _, err := os.Stat("./pprof"); os.IsNotExist(err) {
+				err := os.Mkdir("./pprof", os.ModePerm)
+				if err != nil {
+					SysLog("创建pprof文件夹失败 " + err.Error())
+					continue
+				}
+			}
+			f, err := os.Create("./pprof/" + fmt.Sprintf("cpu-%s.pprof", time.Now().Format("20060102150405")))
+			if err != nil {
+				SysLog("创建pprof文件失败 " + err.Error())
+				continue
+			}
+			err = pprof.StartCPUProfile(f)
+			if err != nil {
+				SysLog("启动pprof失败 " + err.Error())
+				continue
+			}
+			time.Sleep(10 * time.Second) // profile for 30 seconds
+			pprof.StopCPUProfile()
+			f.Close()
+		}
+		time.Sleep(30 * time.Second)
+	}
+}

+ 70 - 0
common/rate-limit.go

@@ -0,0 +1,70 @@
+package common
+
+import (
+	"sync"
+	"time"
+)
+
+type InMemoryRateLimiter struct {
+	store              map[string]*[]int64
+	mutex              sync.Mutex
+	expirationDuration time.Duration
+}
+
+func (l *InMemoryRateLimiter) Init(expirationDuration time.Duration) {
+	if l.store == nil {
+		l.mutex.Lock()
+		if l.store == nil {
+			l.store = make(map[string]*[]int64)
+			l.expirationDuration = expirationDuration
+			if expirationDuration > 0 {
+				go l.clearExpiredItems()
+			}
+		}
+		l.mutex.Unlock()
+	}
+}
+
+func (l *InMemoryRateLimiter) clearExpiredItems() {
+	for {
+		time.Sleep(l.expirationDuration)
+		l.mutex.Lock()
+		now := time.Now().Unix()
+		for key := range l.store {
+			queue := l.store[key]
+			size := len(*queue)
+			if size == 0 || now-(*queue)[size-1] > int64(l.expirationDuration.Seconds()) {
+				delete(l.store, key)
+			}
+		}
+		l.mutex.Unlock()
+	}
+}
+
+// Request parameter duration's unit is seconds
+func (l *InMemoryRateLimiter) Request(key string, maxRequestNum int, duration int64) bool {
+	l.mutex.Lock()
+	defer l.mutex.Unlock()
+	// [old <-- new]
+	queue, ok := l.store[key]
+	now := time.Now().Unix()
+	if ok {
+		if len(*queue) < maxRequestNum {
+			*queue = append(*queue, now)
+			return true
+		} else {
+			if now-(*queue)[0] >= duration {
+				*queue = (*queue)[1:]
+				*queue = append(*queue, now)
+				return true
+			} else {
+				return false
+			}
+		}
+	} else {
+		s := make([]int64, 0, maxRequestNum)
+		l.store[key] = &s
+		*(l.store[key]) = append(*(l.store[key]), now)
+	}
+	return true
+}

+ 327 - 0
common/redis.go

@@ -0,0 +1,327 @@
+package common
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"os"
+	"reflect"
+	"strconv"
+	"time"
+
+	"github.com/go-redis/redis/v8"
+	"gorm.io/gorm"
+)
+
+var RDB *redis.Client
+var RedisEnabled = true
+
+func RedisKeyCacheSeconds() int {
+	return SyncFrequency
+}
+
+// InitRedisClient This function is called after init()
+func InitRedisClient() (err error) {
+	if os.Getenv("REDIS_CONN_STRING") == "" {
+		RedisEnabled = false
+		SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
+		return nil
+	}
+	if os.Getenv("SYNC_FREQUENCY") == "" {
+		SysLog("SYNC_FREQUENCY not set, use default value 60")
+		SyncFrequency = 60
+	}
+	SysLog("Redis is enabled")
+	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
+	if err != nil {
+		FatalLog("failed to parse Redis connection string: " + err.Error())
+	}
+	opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
+	RDB = redis.NewClient(opt)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	_, err = RDB.Ping(ctx).Result()
+	if err != nil {
+		FatalLog("Redis ping test failed: " + err.Error())
+	}
+	if DebugEnabled {
+		SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
+		SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
+	}
+	return err
+}
+
+func ParseRedisOption() *redis.Options {
+	opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
+	if err != nil {
+		FatalLog("failed to parse Redis connection string: " + err.Error())
+	}
+	return opt
+}
+
+func RedisSet(key string, value string, expiration time.Duration) error {
+	if DebugEnabled {
+		SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
+	}
+	ctx := context.Background()
+	return RDB.Set(ctx, key, value, expiration).Err()
+}
+
+func RedisGet(key string) (string, error) {
+	if DebugEnabled {
+		SysLog(fmt.Sprintf("Redis GET: key=%s", key))
+	}
+	ctx := context.Background()
+	val, err := RDB.Get(ctx, key).Result()
+	return val, err
+}
+
+//func RedisExpire(key string, expiration time.Duration) error {
+//	ctx := context.Background()
+//	return RDB.Expire(ctx, key, expiration).Err()
+//}
+//
+//func RedisGetEx(key string, expiration time.Duration) (string, error) {
+//	ctx := context.Background()
+//	return RDB.GetSet(ctx, key, expiration).Result()
+//}
+
+func RedisDel(key string) error {
+	if DebugEnabled {
+		SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
+	}
+	ctx := context.Background()
+	return RDB.Del(ctx, key).Err()
+}
+
+func RedisDelKey(key string) error {
+	if DebugEnabled {
+		SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key))
+	}
+	ctx := context.Background()
+	return RDB.Del(ctx, key).Err()
+}
+
+func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
+	if DebugEnabled {
+		SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
+	}
+	ctx := context.Background()
+
+	data := make(map[string]interface{})
+
+	// 使用反射遍历结构体字段
+	v := reflect.ValueOf(obj).Elem()
+	t := v.Type()
+	for i := 0; i < v.NumField(); i++ {
+		field := t.Field(i)
+		value := v.Field(i)
+
+		// Skip DeletedAt field
+		if field.Type.String() == "gorm.DeletedAt" {
+			continue
+		}
+
+		// 处理指针类型
+		if value.Kind() == reflect.Ptr {
+			if value.IsNil() {
+				data[field.Name] = ""
+				continue
+			}
+			value = value.Elem()
+		}
+
+		// 处理布尔类型
+		if value.Kind() == reflect.Bool {
+			data[field.Name] = strconv.FormatBool(value.Bool())
+			continue
+		}
+
+		// 其他类型直接转换为字符串
+		data[field.Name] = fmt.Sprintf("%v", value.Interface())
+	}
+
+	txn := RDB.TxPipeline()
+	txn.HSet(ctx, key, data)
+
+	// 只有在 expiration 大于 0 时才设置过期时间
+	if expiration > 0 {
+		txn.Expire(ctx, key, expiration)
+	}
+
+	_, err := txn.Exec(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to execute transaction: %w", err)
+	}
+	return nil
+}
+
+func RedisHGetObj(key string, obj interface{}) error {
+	if DebugEnabled {
+		SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
+	}
+	ctx := context.Background()
+
+	result, err := RDB.HGetAll(ctx, key).Result()
+	if err != nil {
+		return fmt.Errorf("failed to load hash from Redis: %w", err)
+	}
+
+	if len(result) == 0 {
+		return fmt.Errorf("key %s not found in Redis", key)
+	}
+
+	// Handle both pointer and non-pointer values
+	val := reflect.ValueOf(obj)
+	if val.Kind() != reflect.Ptr {
+		return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
+	}
+
+	v := val.Elem()
+	if v.Kind() != reflect.Struct {
+		return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
+	}
+
+	t := v.Type()
+	for i := 0; i < v.NumField(); i++ {
+		field := t.Field(i)
+		fieldName := field.Name
+		if value, ok := result[fieldName]; ok {
+			fieldValue := v.Field(i)
+
+			// Handle pointer types
+			if fieldValue.Kind() == reflect.Ptr {
+				if value == "" {
+					continue
+				}
+				if fieldValue.IsNil() {
+					fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
+				}
+				fieldValue = fieldValue.Elem()
+			}
+
+			// Enhanced type handling for Token struct
+			switch fieldValue.Kind() {
+			case reflect.String:
+				fieldValue.SetString(value)
+			case reflect.Int, reflect.Int64:
+				intValue, err := strconv.ParseInt(value, 10, 64)
+				if err != nil {
+					return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
+				}
+				fieldValue.SetInt(intValue)
+			case reflect.Bool:
+				boolValue, err := strconv.ParseBool(value)
+				if err != nil {
+					return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
+				}
+				fieldValue.SetBool(boolValue)
+			case reflect.Struct:
+				// Special handling for gorm.DeletedAt
+				if fieldValue.Type().String() == "gorm.DeletedAt" {
+					if value != "" {
+						timeValue, err := time.Parse(time.RFC3339, value)
+						if err != nil {
+							return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
+						}
+						fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
+					}
+				}
+			default:
+				return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
+			}
+		}
+	}
+
+	return nil
+}
+
+// RedisIncr Add this function to handle atomic increments
+func RedisIncr(key string, delta int64) error {
+	if DebugEnabled {
+		SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
+	}
+	// 检查键的剩余生存时间
+	ttlCmd := RDB.TTL(context.Background(), key)
+	ttl, err := ttlCmd.Result()
+	if err != nil && !errors.Is(err, redis.Nil) {
+		return fmt.Errorf("failed to get TTL: %w", err)
+	}
+
+	// 只有在 key 存在且有 TTL 时才需要特殊处理
+	if ttl > 0 {
+		ctx := context.Background()
+		// 开始一个Redis事务
+		txn := RDB.TxPipeline()
+
+		// 减少余额
+		decrCmd := txn.IncrBy(ctx, key, delta)
+		if err := decrCmd.Err(); err != nil {
+			return err // 如果减少失败,则直接返回错误
+		}
+
+		// 重新设置过期时间,使用原来的过期时间
+		txn.Expire(ctx, key, ttl)
+
+		// 执行事务
+		_, err = txn.Exec(ctx)
+		return err
+	}
+	return nil
+}
+
+func RedisHIncrBy(key, field string, delta int64) error {
+	if DebugEnabled {
+		SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
+	}
+	ttlCmd := RDB.TTL(context.Background(), key)
+	ttl, err := ttlCmd.Result()
+	if err != nil && !errors.Is(err, redis.Nil) {
+		return fmt.Errorf("failed to get TTL: %w", err)
+	}
+
+	if ttl > 0 {
+		ctx := context.Background()
+		txn := RDB.TxPipeline()
+
+		incrCmd := txn.HIncrBy(ctx, key, field, delta)
+		if err := incrCmd.Err(); err != nil {
+			return err
+		}
+
+		txn.Expire(ctx, key, ttl)
+
+		_, err = txn.Exec(ctx)
+		return err
+	}
+	return nil
+}
+
+func RedisHSetField(key, field string, value interface{}) error {
+	if DebugEnabled {
+		SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
+	}
+	ttlCmd := RDB.TTL(context.Background(), key)
+	ttl, err := ttlCmd.Result()
+	if err != nil && !errors.Is(err, redis.Nil) {
+		return fmt.Errorf("failed to get TTL: %w", err)
+	}
+
+	if ttl > 0 {
+		ctx := context.Background()
+		txn := RDB.TxPipeline()
+
+		hsetCmd := txn.HSet(ctx, key, field, value)
+		if err := hsetCmd.Err(); err != nil {
+			return err
+		}
+
+		txn.Expire(ctx, key, ttl)
+
+		_, err = txn.Exec(ctx)
+		return err
+	}
+	return nil
+}

+ 97 - 0
common/str.go

@@ -0,0 +1,97 @@
+package common
+
+import (
+	"encoding/base64"
+	"encoding/json"
+	"math/rand"
+	"strconv"
+	"unsafe"
+)
+
+func GetStringIfEmpty(str string, defaultValue string) string {
+	if str == "" {
+		return defaultValue
+	}
+	return str
+}
+
+func GetRandomString(length int) string {
+	//rand.Seed(time.Now().UnixNano())
+	key := make([]byte, length)
+	for i := 0; i < length; i++ {
+		key[i] = keyChars[rand.Intn(len(keyChars))]
+	}
+	return string(key)
+}
+
+func MapToJsonStr(m map[string]interface{}) string {
+	bytes, err := json.Marshal(m)
+	if err != nil {
+		return ""
+	}
+	return string(bytes)
+}
+
+func StrToMap(str string) (map[string]interface{}, error) {
+	m := make(map[string]interface{})
+	err := Unmarshal([]byte(str), &m)
+	if err != nil {
+		return nil, err
+	}
+	return m, nil
+}
+
+func StrToJsonArray(str string) ([]interface{}, error) {
+	var js []interface{}
+	err := json.Unmarshal([]byte(str), &js)
+	if err != nil {
+		return nil, err
+	}
+	return js, nil
+}
+
+func IsJsonArray(str string) bool {
+	var js []interface{}
+	return json.Unmarshal([]byte(str), &js) == nil
+}
+
+func IsJsonObject(str string) bool {
+	var js map[string]interface{}
+	return json.Unmarshal([]byte(str), &js) == nil
+}
+
+func String2Int(str string) int {
+	num, err := strconv.Atoi(str)
+	if err != nil {
+		return 0
+	}
+	return num
+}
+
+func StringsContains(strs []string, str string) bool {
+	for _, s := range strs {
+		if s == str {
+			return true
+		}
+	}
+	return false
+}
+
+// StringToByteSlice []byte only read, panic on append
+func StringToByteSlice(s string) []byte {
+	tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
+	tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
+	return *(*[]byte)(unsafe.Pointer(&tmp2))
+}
+
+func EncodeBase64(str string) string {
+	return base64.StdEncoding.EncodeToString([]byte(str))
+}
+
+func GetJsonString(data any) string {
+	if data == nil {
+		return ""
+	}
+	b, _ := json.Marshal(data)
+	return string(b)
+}

+ 33 - 0
common/topup-ratio.go

@@ -0,0 +1,33 @@
+package common
+
+import (
+	"encoding/json"
+)
+
+var TopupGroupRatio = map[string]float64{
+	"default": 1,
+	"vip":     1,
+	"svip":    1,
+}
+
+func TopupGroupRatio2JSONString() string {
+	jsonBytes, err := json.Marshal(TopupGroupRatio)
+	if err != nil {
+		SysError("error marshalling model ratio: " + err.Error())
+	}
+	return string(jsonBytes)
+}
+
+func UpdateTopupGroupRatioByJSONString(jsonStr string) error {
+	TopupGroupRatio = make(map[string]float64)
+	return json.Unmarshal([]byte(jsonStr), &TopupGroupRatio)
+}
+
+func GetTopupGroupRatio(name string) float64 {
+	ratio, ok := TopupGroupRatio[name]
+	if !ok {
+		SysError("topup group ratio not found: " + name)
+		return 1
+	}
+	return ratio
+}

+ 304 - 0
common/utils.go

@@ -0,0 +1,304 @@
+package common
+
+import (
+	"bytes"
+	"context"
+	crand "crypto/rand"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"html/template"
+	"io"
+	"log"
+	"math/big"
+	"math/rand"
+	"net"
+	"net/url"
+	"os"
+	"os/exec"
+	"runtime"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/pkg/errors"
+)
+
+func OpenBrowser(url string) {
+	var err error
+
+	switch runtime.GOOS {
+	case "linux":
+		err = exec.Command("xdg-open", url).Start()
+	case "windows":
+		err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
+	case "darwin":
+		err = exec.Command("open", url).Start()
+	}
+	if err != nil {
+		log.Println(err)
+	}
+}
+
+func GetIp() (ip string) {
+	ips, err := net.InterfaceAddrs()
+	if err != nil {
+		log.Println(err)
+		return ip
+	}
+
+	for _, a := range ips {
+		if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
+			if ipNet.IP.To4() != nil {
+				ip = ipNet.IP.String()
+				if strings.HasPrefix(ip, "10") {
+					return
+				}
+				if strings.HasPrefix(ip, "172") {
+					return
+				}
+				if strings.HasPrefix(ip, "192.168") {
+					return
+				}
+				ip = ""
+			}
+		}
+	}
+	return
+}
+
+var sizeKB = 1024
+var sizeMB = sizeKB * 1024
+var sizeGB = sizeMB * 1024
+
+func Bytes2Size(num int64) string {
+	numStr := ""
+	unit := "B"
+	if num/int64(sizeGB) > 1 {
+		numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
+		unit = "GB"
+	} else if num/int64(sizeMB) > 1 {
+		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
+		unit = "MB"
+	} else if num/int64(sizeKB) > 1 {
+		numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
+		unit = "KB"
+	} else {
+		numStr = fmt.Sprintf("%d", num)
+	}
+	return numStr + " " + unit
+}
+
+func Seconds2Time(num int) (time string) {
+	if num/31104000 > 0 {
+		time += strconv.Itoa(num/31104000) + " 年 "
+		num %= 31104000
+	}
+	if num/2592000 > 0 {
+		time += strconv.Itoa(num/2592000) + " 个月 "
+		num %= 2592000
+	}
+	if num/86400 > 0 {
+		time += strconv.Itoa(num/86400) + " 天 "
+		num %= 86400
+	}
+	if num/3600 > 0 {
+		time += strconv.Itoa(num/3600) + " 小时 "
+		num %= 3600
+	}
+	if num/60 > 0 {
+		time += strconv.Itoa(num/60) + " 分钟 "
+		num %= 60
+	}
+	time += strconv.Itoa(num) + " 秒"
+	return
+}
+
+func Interface2String(inter interface{}) string {
+	switch inter.(type) {
+	case string:
+		return inter.(string)
+	case int:
+		return fmt.Sprintf("%d", inter.(int))
+	case float64:
+		return fmt.Sprintf("%f", inter.(float64))
+	}
+	return "Not Implemented"
+}
+
+func UnescapeHTML(x string) interface{} {
+	return template.HTML(x)
+}
+
+func IntMax(a int, b int) int {
+	if a >= b {
+		return a
+	} else {
+		return b
+	}
+}
+
+func IsIP(s string) bool {
+	ip := net.ParseIP(s)
+	return ip != nil
+}
+
+func GetUUID() string {
+	code := uuid.New().String()
+	code = strings.Replace(code, "-", "", -1)
+	return code
+}
+
+const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
+func init() {
+	rand.New(rand.NewSource(time.Now().UnixNano()))
+}
+
+func GenerateRandomCharsKey(length int) (string, error) {
+	b := make([]byte, length)
+	maxI := big.NewInt(int64(len(keyChars)))
+
+	for i := range b {
+		n, err := crand.Int(crand.Reader, maxI)
+		if err != nil {
+			return "", err
+		}
+		b[i] = keyChars[n.Int64()]
+	}
+
+	return string(b), nil
+}
+
+func GenerateRandomKey(length int) (string, error) {
+	bytes := make([]byte, length*3/4) // 对于48位的输出,这里应该是36
+	if _, err := crand.Read(bytes); err != nil {
+		return "", err
+	}
+	return base64.StdEncoding.EncodeToString(bytes), nil
+}
+
+func GenerateKey() (string, error) {
+	//rand.Seed(time.Now().UnixNano())
+	return GenerateRandomCharsKey(48)
+}
+
+func GetRandomInt(max int) int {
+	//rand.Seed(time.Now().UnixNano())
+	return rand.Intn(max)
+}
+
+func GetTimestamp() int64 {
+	return time.Now().Unix()
+}
+
+func GetTimeString() string {
+	now := time.Now()
+	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
+}
+
+func Max(a int, b int) int {
+	if a >= b {
+		return a
+	} else {
+		return b
+	}
+}
+
+func MessageWithRequestId(message string, id string) string {
+	return fmt.Sprintf("%s (request id: %s)", message, id)
+}
+
+func RandomSleep() {
+	// Sleep for 0-3000 ms
+	time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
+}
+
+func GetPointer[T any](v T) *T {
+	return &v
+}
+
+func Any2Type[T any](data any) (T, error) {
+	var zero T
+	bytes, err := json.Marshal(data)
+	if err != nil {
+		return zero, err
+	}
+	var res T
+	err = json.Unmarshal(bytes, &res)
+	if err != nil {
+		return zero, err
+	}
+	return res, nil
+}
+
+// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
+func SaveTmpFile(filename string, data io.Reader) (string, error) {
+	f, err := os.CreateTemp(os.TempDir(), filename)
+	if err != nil {
+		return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
+	}
+	defer f.Close()
+
+	_, err = io.Copy(f, data)
+	if err != nil {
+		return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
+	}
+
+	return f.Name(), nil
+}
+
+// GetAudioDuration returns the duration of an audio file in seconds.
+func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
+	// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
+	c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
+	output, err := c.Output()
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to get audio duration")
+	}
+  durationStr := string(bytes.TrimSpace(output))
+  if durationStr == "N/A" {
+    // Create a temporary output file name
+    tmpFp, err := os.CreateTemp("", "audio-*"+ext)
+    if err != nil {
+      return 0, errors.Wrap(err, "failed to create temporary file")
+    }
+    tmpName := tmpFp.Name()
+    // Close immediately so ffmpeg can open the file on Windows.
+    _ = tmpFp.Close()
+    defer os.Remove(tmpName)
+
+    // ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
+    ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
+    if err := ffmpegCmd.Run(); err != nil {
+      return 0, errors.Wrap(err, "failed to run ffmpeg")
+    }
+
+    // Recalculate the duration of the new file
+    c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
+    output, err := c.Output()
+    if err != nil {
+      return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
+    }
+    durationStr = string(bytes.TrimSpace(output))
+  }
+	return strconv.ParseFloat(durationStr, 64)
+}
+
+// BuildURL concatenates base and endpoint, returns the complete url string
+func BuildURL(base string, endpoint string) string {
+	u, err := url.Parse(base)
+	if err != nil {
+		return base + endpoint
+	}
+	end := endpoint
+	if end == "" {
+		end = "/"
+	}
+	ref, err := url.Parse(end)
+	if err != nil {
+		return base + endpoint
+	}
+	return u.ResolveReference(ref).String()
+}

+ 9 - 0
common/validate.go

@@ -0,0 +1,9 @@
+package common
+
+import "github.com/go-playground/validator/v10"
+
+var Validate *validator.Validate
+
+func init() {
+	Validate = validator.New()
+}

+ 77 - 0
common/verification.go

@@ -0,0 +1,77 @@
+package common
+
+import (
+	"github.com/google/uuid"
+	"strings"
+	"sync"
+	"time"
+)
+
+type verificationValue struct {
+	code string
+	time time.Time
+}
+
+const (
+	EmailVerificationPurpose = "v"
+	PasswordResetPurpose     = "r"
+)
+
+var verificationMutex sync.Mutex
+var verificationMap map[string]verificationValue
+var verificationMapMaxSize = 10
+var VerificationValidMinutes = 10
+
+func GenerateVerificationCode(length int) string {
+	code := uuid.New().String()
+	code = strings.Replace(code, "-", "", -1)
+	if length == 0 {
+		return code
+	}
+	return code[:length]
+}
+
+func RegisterVerificationCodeWithKey(key string, code string, purpose string) {
+	verificationMutex.Lock()
+	defer verificationMutex.Unlock()
+	verificationMap[purpose+key] = verificationValue{
+		code: code,
+		time: time.Now(),
+	}
+	if len(verificationMap) > verificationMapMaxSize {
+		removeExpiredPairs()
+	}
+}
+
+func VerifyCodeWithKey(key string, code string, purpose string) bool {
+	verificationMutex.Lock()
+	defer verificationMutex.Unlock()
+	value, okay := verificationMap[purpose+key]
+	now := time.Now()
+	if !okay || int(now.Sub(value.time).Seconds()) >= VerificationValidMinutes*60 {
+		return false
+	}
+	return code == value.code
+}
+
+func DeleteKey(key string, purpose string) {
+	verificationMutex.Lock()
+	defer verificationMutex.Unlock()
+	delete(verificationMap, purpose+key)
+}
+
+// no lock inside, so the caller must lock the verificationMap before calling!
+func removeExpiredPairs() {
+	now := time.Now()
+	for key := range verificationMap {
+		if int(now.Sub(verificationMap[key].time).Seconds()) >= VerificationValidMinutes*60 {
+			delete(verificationMap, key)
+		}
+	}
+}
+
+func init() {
+	verificationMutex.Lock()
+	defer verificationMutex.Unlock()
+	verificationMap = make(map[string]verificationValue)
+}

+ 26 - 0
constant/README.md

@@ -0,0 +1,26 @@
+# constant 包 (`/constant`)
+
+该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
+
+## 当前文件
+
+| 文件                   | 说明                                                                  |
+|----------------------|---------------------------------------------------------------------|
+| `azure.go`           | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。         |
+| `cache_key.go`       | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。                                    |
+| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。                          |
+| `context_key.go`     | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 |
+| `env.go`             | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。                                     |
+| `finish_reason.go`   | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。                           |
+| `midjourney.go`      | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。                            |
+| `setup.go`           | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。                                       |
+| `task.go`            | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。                     |
+| `user_setting.go`    | 用户设置相关键常量以及通知类型(Email/Webhook)等。                                    |
+
+## 使用约定
+
+1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。
+2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
+3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
+
+> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。

+ 35 - 0
constant/api_type.go

@@ -0,0 +1,35 @@
+package constant
+
+const (
+	APITypeOpenAI = iota
+	APITypeAnthropic
+	APITypePaLM
+	APITypeBaidu
+	APITypeZhipu
+	APITypeAli
+	APITypeXunfei
+	APITypeAIProxyLibrary
+	APITypeTencent
+	APITypeGemini
+	APITypeZhipuV4
+	APITypeOllama
+	APITypePerplexity
+	APITypeAws
+	APITypeCohere
+	APITypeDify
+	APITypeJina
+	APITypeCloudflare
+	APITypeSiliconFlow
+	APITypeVertexAi
+	APITypeMistral
+	APITypeDeepSeek
+	APITypeMokaAI
+	APITypeVolcEngine
+	APITypeBaiduV2
+	APITypeOpenRouter
+	APITypeXinference
+	APITypeXai
+	APITypeCoze
+	APITypeJimeng
+	APITypeDummy // this one is only for count, do not add any channel after this
+)

+ 5 - 0
constant/azure.go

@@ -0,0 +1,5 @@
+package constant
+
+import "time"
+
+var AzureNoRemoveDotTime = time.Date(2025, time.May, 10, 0, 0, 0, 0, time.UTC).Unix()

+ 14 - 0
constant/cache_key.go

@@ -0,0 +1,14 @@
+package constant
+
+// Cache keys
+const (
+	UserGroupKeyFmt    = "user_group:%d"
+	UserQuotaKeyFmt    = "user_quota:%d"
+	UserEnabledKeyFmt  = "user_enabled:%d"
+	UserUsernameKeyFmt = "user_name:%d"
+)
+
+const (
+	TokenFiledRemainQuota = "RemainQuota"
+	TokenFieldGroup       = "Group"
+)

+ 109 - 0
constant/channel.go

@@ -0,0 +1,109 @@
+package constant
+
+const (
+	ChannelTypeUnknown        = 0
+	ChannelTypeOpenAI         = 1
+	ChannelTypeMidjourney     = 2
+	ChannelTypeAzure          = 3
+	ChannelTypeOllama         = 4
+	ChannelTypeMidjourneyPlus = 5
+	ChannelTypeOpenAIMax      = 6
+	ChannelTypeOhMyGPT        = 7
+	ChannelTypeCustom         = 8
+	ChannelTypeAILS           = 9
+	ChannelTypeAIProxy        = 10
+	ChannelTypePaLM           = 11
+	ChannelTypeAPI2GPT        = 12
+	ChannelTypeAIGC2D         = 13
+	ChannelTypeAnthropic      = 14
+	ChannelTypeBaidu          = 15
+	ChannelTypeZhipu          = 16
+	ChannelTypeAli            = 17
+	ChannelTypeXunfei         = 18
+	ChannelType360            = 19
+	ChannelTypeOpenRouter     = 20
+	ChannelTypeAIProxyLibrary = 21
+	ChannelTypeFastGPT        = 22
+	ChannelTypeTencent        = 23
+	ChannelTypeGemini         = 24
+	ChannelTypeMoonshot       = 25
+	ChannelTypeZhipu_v4       = 26
+	ChannelTypePerplexity     = 27
+	ChannelTypeLingYiWanWu    = 31
+	ChannelTypeAws            = 33
+	ChannelTypeCohere         = 34
+	ChannelTypeMiniMax        = 35
+	ChannelTypeSunoAPI        = 36
+	ChannelTypeDify           = 37
+	ChannelTypeJina           = 38
+	ChannelCloudflare         = 39
+	ChannelTypeSiliconFlow    = 40
+	ChannelTypeVertexAi       = 41
+	ChannelTypeMistral        = 42
+	ChannelTypeDeepSeek       = 43
+	ChannelTypeMokaAI         = 44
+	ChannelTypeVolcEngine     = 45
+	ChannelTypeBaiduV2        = 46
+	ChannelTypeXinference     = 47
+	ChannelTypeXai            = 48
+	ChannelTypeCoze           = 49
+	ChannelTypeKling          = 50
+	ChannelTypeJimeng         = 51
+	ChannelTypeDummy          // this one is only for count, do not add any channel after this
+
+)
+
+var ChannelBaseURLs = []string{
+	"",                                    // 0
+	"https://api.openai.com",              // 1
+	"https://oa.api2d.net",                // 2
+	"",                                    // 3
+	"http://localhost:11434",              // 4
+	"https://api.openai-sb.com",           // 5
+	"https://api.openaimax.com",           // 6
+	"https://api.ohmygpt.com",             // 7
+	"",                                    // 8
+	"https://api.caipacity.com",           // 9
+	"https://api.aiproxy.io",              // 10
+	"",                                    // 11
+	"https://api.api2gpt.com",             // 12
+	"https://api.aigc2d.com",              // 13
+	"https://api.anthropic.com",           // 14
+	"https://aip.baidubce.com",            // 15
+	"https://open.bigmodel.cn",            // 16
+	"https://dashscope.aliyuncs.com",      // 17
+	"",                                    // 18
+	"https://api.360.cn",                  // 19
+	"https://openrouter.ai/api",           // 20
+	"https://api.aiproxy.io",              // 21
+	"https://fastgpt.run/api/openapi",     // 22
+	"https://hunyuan.tencentcloudapi.com", //23
+	"https://generativelanguage.googleapis.com", //24
+	"https://api.moonshot.cn",                   //25
+	"https://open.bigmodel.cn",                  //26
+	"https://api.perplexity.ai",                 //27
+	"",                                          //28
+	"",                                          //29
+	"",                                          //30
+	"https://api.lingyiwanwu.com",               //31
+	"",                                          //32
+	"",                                          //33
+	"https://api.cohere.ai",                     //34
+	"https://api.minimax.chat",                  //35
+	"",                                          //36
+	"https://api.dify.ai",                       //37
+	"https://api.jina.ai",                       //38
+	"https://api.cloudflare.com",                //39
+	"https://api.siliconflow.cn",                //40
+	"",                                          //41
+	"https://api.mistral.ai",                    //42
+	"https://api.deepseek.com",                  //43
+	"https://api.moka.ai",                       //44
+	"https://ark.cn-beijing.volces.com",         //45
+	"https://qianfan.baidubce.com",              //46
+	"",                                          //47
+	"https://api.x.ai",                          //48
+	"https://api.coze.cn",                       //49
+	"https://api.klingai.com",                   //50
+	"https://visual.volcengineapi.com",          //51
+}

+ 44 - 0
constant/context_key.go

@@ -0,0 +1,44 @@
+package constant
+
+type ContextKey string
+
+const (
+	ContextKeyOriginalModel    ContextKey = "original_model"
+	ContextKeyRequestStartTime ContextKey = "request_start_time"
+
+	/* token related keys */
+	ContextKeyTokenUnlimited         ContextKey = "token_unlimited_quota"
+	ContextKeyTokenKey               ContextKey = "token_key"
+	ContextKeyTokenId                ContextKey = "token_id"
+	ContextKeyTokenGroup             ContextKey = "token_group"
+	ContextKeyTokenAllowIps          ContextKey = "allow_ips"
+	ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
+	ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
+	ContextKeyTokenModelLimit        ContextKey = "token_model_limit"
+
+	/* channel related keys */
+	ContextKeyChannelId                ContextKey = "channel_id"
+	ContextKeyChannelName              ContextKey = "channel_name"
+	ContextKeyChannelCreateTime        ContextKey = "channel_create_time"
+	ContextKeyChannelBaseUrl           ContextKey = "base_url"
+	ContextKeyChannelType              ContextKey = "channel_type"
+	ContextKeyChannelSetting           ContextKey = "channel_setting"
+	ContextKeyChannelParamOverride     ContextKey = "param_override"
+	ContextKeyChannelOrganization      ContextKey = "channel_organization"
+	ContextKeyChannelAutoBan           ContextKey = "auto_ban"
+	ContextKeyChannelModelMapping      ContextKey = "model_mapping"
+	ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
+	ContextKeyChannelIsMultiKey        ContextKey = "channel_is_multi_key"
+	ContextKeyChannelMultiKeyIndex     ContextKey = "channel_multi_key_index"
+	ContextKeyChannelKey               ContextKey = "channel_key"
+
+	/* user related keys */
+	ContextKeyUserId      ContextKey = "id"
+	ContextKeyUserSetting ContextKey = "user_setting"
+	ContextKeyUserQuota   ContextKey = "user_quota"
+	ContextKeyUserStatus  ContextKey = "user_status"
+	ContextKeyUserEmail   ContextKey = "user_email"
+	ContextKeyUserGroup   ContextKey = "user_group"
+	ContextKeyUsingGroup  ContextKey = "group"
+	ContextKeyUserName    ContextKey = "username"
+)

+ 16 - 0
constant/endpoint_type.go

@@ -0,0 +1,16 @@
+package constant
+
+type EndpointType string
+
+const (
+	EndpointTypeOpenAI          EndpointType = "openai"
+	EndpointTypeOpenAIResponse  EndpointType = "openai-response"
+	EndpointTypeAnthropic       EndpointType = "anthropic"
+	EndpointTypeGemini          EndpointType = "gemini"
+	EndpointTypeJinaRerank      EndpointType = "jina-rerank"
+	EndpointTypeImageGeneration EndpointType = "image-generation"
+	//EndpointTypeMidjourney     EndpointType = "midjourney-proxy"
+	//EndpointTypeSuno           EndpointType = "suno-proxy"
+	//EndpointTypeKling          EndpointType = "kling"
+	//EndpointTypeJimeng         EndpointType = "jimeng"
+)

+ 15 - 0
constant/env.go

@@ -0,0 +1,15 @@
+package constant
+
+var StreamingTimeout int
+var DifyDebug bool
+var MaxFileDownloadMB int
+var ForceStreamOption bool
+var GetMediaToken bool
+var GetMediaTokenNotStream bool
+var UpdateTask bool
+var AzureDefaultAPIVersion string
+var GeminiVisionMaxImageNum int
+var NotifyLimitCount int
+var NotificationLimitDurationMinute int
+var GenerateDefaultToken bool
+var ErrorLogEnabled bool

+ 9 - 0
constant/finish_reason.go

@@ -0,0 +1,9 @@
+package constant
+
+var (
+	FinishReasonStop          = "stop"
+	FinishReasonToolCalls     = "tool_calls"
+	FinishReasonLength        = "length"
+	FinishReasonFunctionCall  = "function_call"
+	FinishReasonContentFilter = "content_filter"
+)

+ 48 - 0
constant/midjourney.go

@@ -0,0 +1,48 @@
+package constant
+
+const (
+	MjErrorUnknown = 5
+	MjRequestError = 4
+)
+
+const (
+	MjActionImagine       = "IMAGINE"
+	MjActionDescribe      = "DESCRIBE"
+	MjActionBlend         = "BLEND"
+	MjActionUpscale       = "UPSCALE"
+	MjActionVariation     = "VARIATION"
+	MjActionReRoll        = "REROLL"
+	MjActionInPaint       = "INPAINT"
+	MjActionModal         = "MODAL"
+	MjActionZoom          = "ZOOM"
+	MjActionCustomZoom    = "CUSTOM_ZOOM"
+	MjActionShorten       = "SHORTEN"
+	MjActionHighVariation = "HIGH_VARIATION"
+	MjActionLowVariation  = "LOW_VARIATION"
+	MjActionPan           = "PAN"
+	MjActionSwapFace      = "SWAP_FACE"
+	MjActionUpload        = "UPLOAD"
+	MjActionVideo         = "VIDEO"
+	MjActionEdits         = "EDITS"
+)
+
+var MidjourneyModel2Action = map[string]string{
+	"mj_imagine":        MjActionImagine,
+	"mj_describe":       MjActionDescribe,
+	"mj_blend":          MjActionBlend,
+	"mj_upscale":        MjActionUpscale,
+	"mj_variation":      MjActionVariation,
+	"mj_reroll":         MjActionReRoll,
+	"mj_modal":          MjActionModal,
+	"mj_inpaint":        MjActionInPaint,
+	"mj_zoom":           MjActionZoom,
+	"mj_custom_zoom":    MjActionCustomZoom,
+	"mj_shorten":        MjActionShorten,
+	"mj_high_variation": MjActionHighVariation,
+	"mj_low_variation":  MjActionLowVariation,
+	"mj_pan":            MjActionPan,
+	"swap_face":         MjActionSwapFace,
+	"mj_upload":         MjActionUpload,
+	"mj_video":          MjActionVideo,
+	"mj_edits":          MjActionEdits,
+}

+ 8 - 0
constant/multi_key_mode.go

@@ -0,0 +1,8 @@
+package constant
+
+type MultiKeyMode string
+
+const (
+	MultiKeyModeRandom  MultiKeyMode = "random"  // 随机
+	MultiKeyModePolling MultiKeyMode = "polling" // 轮询
+)

+ 3 - 0
constant/setup.go

@@ -0,0 +1,3 @@
+package constant
+
+var Setup = false

+ 23 - 0
constant/task.go

@@ -0,0 +1,23 @@
+package constant
+
+type TaskPlatform string
+
+const (
+	TaskPlatformSuno       TaskPlatform = "suno"
+	TaskPlatformMidjourney              = "mj"
+	TaskPlatformKling      TaskPlatform = "kling"
+	TaskPlatformJimeng     TaskPlatform = "jimeng"
+)
+
+const (
+	SunoActionMusic  = "MUSIC"
+	SunoActionLyrics = "LYRICS"
+
+	TaskActionGenerate     = "generate"
+	TaskActionTextGenerate = "textGenerate"
+)
+
+var SunoModel2Action = map[string]string{
+	"suno_music":  SunoActionMusic,
+	"suno_lyrics": SunoActionLyrics,
+}

+ 92 - 0
controller/billing.go

@@ -0,0 +1,92 @@
+package controller
+
+import (
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+	"one-api/dto"
+	"one-api/model"
+)
+
+func GetSubscription(c *gin.Context) {
+	var remainQuota int
+	var usedQuota int
+	var err error
+	var token *model.Token
+	var expiredTime int64
+	if common.DisplayTokenStatEnabled {
+		tokenId := c.GetInt("token_id")
+		token, err = model.GetTokenById(tokenId)
+		expiredTime = token.ExpiredTime
+		remainQuota = token.RemainQuota
+		usedQuota = token.UsedQuota
+	} else {
+		userId := c.GetInt("id")
+		remainQuota, err = model.GetUserQuota(userId, false)
+		usedQuota, err = model.GetUserUsedQuota(userId)
+	}
+	if expiredTime <= 0 {
+		expiredTime = 0
+	}
+	if err != nil {
+		openAIError := dto.OpenAIError{
+			Message: err.Error(),
+			Type:    "upstream_error",
+		}
+		c.JSON(200, gin.H{
+			"error": openAIError,
+		})
+		return
+	}
+	quota := remainQuota + usedQuota
+	amount := float64(quota)
+	if common.DisplayInCurrencyEnabled {
+		amount /= common.QuotaPerUnit
+	}
+	if token != nil && token.UnlimitedQuota {
+		amount = 100000000
+	}
+	subscription := OpenAISubscriptionResponse{
+		Object:             "billing_subscription",
+		HasPaymentMethod:   true,
+		SoftLimitUSD:       amount,
+		HardLimitUSD:       amount,
+		SystemHardLimitUSD: amount,
+		AccessUntil:        expiredTime,
+	}
+	c.JSON(200, subscription)
+	return
+}
+
+func GetUsage(c *gin.Context) {
+	var quota int
+	var err error
+	var token *model.Token
+	if common.DisplayTokenStatEnabled {
+		tokenId := c.GetInt("token_id")
+		token, err = model.GetTokenById(tokenId)
+		quota = token.UsedQuota
+	} else {
+		userId := c.GetInt("id")
+		quota, err = model.GetUserUsedQuota(userId)
+	}
+	if err != nil {
+		openAIError := dto.OpenAIError{
+			Message: err.Error(),
+			Type:    "new_api_error",
+		}
+		c.JSON(200, gin.H{
+			"error": openAIError,
+		})
+		return
+	}
+	amount := float64(quota)
+	if common.DisplayInCurrencyEnabled {
+		amount /= common.QuotaPerUnit
+	}
+	usage := OpenAIUsageResponse{
+		Object:     "list",
+		TotalUsage: amount * 100,
+	}
+	c.JSON(200, usage)
+	return
+}

+ 492 - 0
controller/channel-billing.go

@@ -0,0 +1,492 @@
+package controller
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/model"
+	"one-api/service"
+	"one-api/setting"
+	"one-api/types"
+	"strconv"
+	"time"
+
+	"github.com/shopspring/decimal"
+
+	"github.com/gin-gonic/gin"
+)
+
+// https://github.com/songquanpeng/one-api/issues/79
+
+type OpenAISubscriptionResponse struct {
+	Object             string  `json:"object"`
+	HasPaymentMethod   bool    `json:"has_payment_method"`
+	SoftLimitUSD       float64 `json:"soft_limit_usd"`
+	HardLimitUSD       float64 `json:"hard_limit_usd"`
+	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
+	AccessUntil        int64   `json:"access_until"`
+}
+
+type OpenAIUsageDailyCost struct {
+	Timestamp float64 `json:"timestamp"`
+	LineItems []struct {
+		Name string  `json:"name"`
+		Cost float64 `json:"cost"`
+	}
+}
+
+type OpenAICreditGrants struct {
+	Object         string  `json:"object"`
+	TotalGranted   float64 `json:"total_granted"`
+	TotalUsed      float64 `json:"total_used"`
+	TotalAvailable float64 `json:"total_available"`
+}
+
+type OpenAIUsageResponse struct {
+	Object string `json:"object"`
+	//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
+	TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
+}
+
+type OpenAISBUsageResponse struct {
+	Msg  string `json:"msg"`
+	Data *struct {
+		Credit string `json:"credit"`
+	} `json:"data"`
+}
+
+type AIProxyUserOverviewResponse struct {
+	Success   bool   `json:"success"`
+	Message   string `json:"message"`
+	ErrorCode int    `json:"error_code"`
+	Data      struct {
+		TotalPoints float64 `json:"totalPoints"`
+	} `json:"data"`
+}
+
+type API2GPTUsageResponse struct {
+	Object         string  `json:"object"`
+	TotalGranted   float64 `json:"total_granted"`
+	TotalUsed      float64 `json:"total_used"`
+	TotalRemaining float64 `json:"total_remaining"`
+}
+
+type APGC2DGPTUsageResponse struct {
+	//Grants         interface{} `json:"grants"`
+	Object         string  `json:"object"`
+	TotalAvailable float64 `json:"total_available"`
+	TotalGranted   float64 `json:"total_granted"`
+	TotalUsed      float64 `json:"total_used"`
+}
+
+type SiliconFlowUsageResponse struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+	Status  bool   `json:"status"`
+	Data    struct {
+		ID            string `json:"id"`
+		Name          string `json:"name"`
+		Image         string `json:"image"`
+		Email         string `json:"email"`
+		IsAdmin       bool   `json:"isAdmin"`
+		Balance       string `json:"balance"`
+		Status        string `json:"status"`
+		Introduction  string `json:"introduction"`
+		Role          string `json:"role"`
+		ChargeBalance string `json:"chargeBalance"`
+		TotalBalance  string `json:"totalBalance"`
+		Category      string `json:"category"`
+	} `json:"data"`
+}
+
+type DeepSeekUsageResponse struct {
+	IsAvailable  bool `json:"is_available"`
+	BalanceInfos []struct {
+		Currency        string `json:"currency"`
+		TotalBalance    string `json:"total_balance"`
+		GrantedBalance  string `json:"granted_balance"`
+		ToppedUpBalance string `json:"topped_up_balance"`
+	} `json:"balance_infos"`
+}
+
+type OpenRouterCreditResponse struct {
+	Data struct {
+		TotalCredits float64 `json:"total_credits"`
+		TotalUsage   float64 `json:"total_usage"`
+	} `json:"data"`
+}
+
+// GetAuthHeader get auth header
+func GetAuthHeader(token string) http.Header {
+	h := http.Header{}
+	h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
+	return h
+}
+
+func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
+	req, err := http.NewRequest(method, url, nil)
+	if err != nil {
+		return nil, err
+	}
+	for k := range headers {
+		req.Header.Add(k, headers.Get(k))
+	}
+	res, err := service.GetHttpClient().Do(req)
+	if err != nil {
+		return nil, err
+	}
+	if res.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("status code: %d", res.StatusCode)
+	}
+	body, err := io.ReadAll(res.Body)
+	if err != nil {
+		return nil, err
+	}
+	err = res.Body.Close()
+	if err != nil {
+		return nil, err
+	}
+	return body, nil
+}
+
+func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
+	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+
+	if err != nil {
+		return 0, err
+	}
+	response := OpenAICreditGrants{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	channel.UpdateBalance(response.TotalAvailable)
+	return response.TotalAvailable, nil
+}
+
+func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
+	url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		return 0, err
+	}
+	response := OpenAISBUsageResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	if response.Data == nil {
+		return 0, errors.New(response.Msg)
+	}
+	balance, err := strconv.ParseFloat(response.Data.Credit, 64)
+	if err != nil {
+		return 0, err
+	}
+	channel.UpdateBalance(balance)
+	return balance, nil
+}
+
+func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
+	url := "https://aiproxy.io/api/report/getUserOverview"
+	headers := http.Header{}
+	headers.Add("Api-Key", channel.Key)
+	body, err := GetResponseBody("GET", url, channel, headers)
+	if err != nil {
+		return 0, err
+	}
+	response := AIProxyUserOverviewResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	if !response.Success {
+		return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
+	}
+	channel.UpdateBalance(response.Data.TotalPoints)
+	return response.Data.TotalPoints, nil
+}
+
+func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
+	url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+
+	if err != nil {
+		return 0, err
+	}
+	response := API2GPTUsageResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	channel.UpdateBalance(response.TotalRemaining)
+	return response.TotalRemaining, nil
+}
+
+func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
+	url := "https://api.siliconflow.cn/v1/user/info"
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		return 0, err
+	}
+	response := SiliconFlowUsageResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	if response.Code != 20000 {
+		return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
+	}
+	balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
+	if err != nil {
+		return 0, err
+	}
+	channel.UpdateBalance(balance)
+	return balance, nil
+}
+
+func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
+	url := "https://api.deepseek.com/user/balance"
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		return 0, err
+	}
+	response := DeepSeekUsageResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	index := -1
+	for i, balanceInfo := range response.BalanceInfos {
+		if balanceInfo.Currency == "CNY" {
+			index = i
+			break
+		}
+	}
+	if index == -1 {
+		return 0, errors.New("currency CNY not found")
+	}
+	balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
+	if err != nil {
+		return 0, err
+	}
+	channel.UpdateBalance(balance)
+	return balance, nil
+}
+
+func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
+	url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		return 0, err
+	}
+	response := APGC2DGPTUsageResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	channel.UpdateBalance(response.TotalAvailable)
+	return response.TotalAvailable, nil
+}
+
+func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
+	url := "https://openrouter.ai/api/v1/credits"
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		return 0, err
+	}
+	response := OpenRouterCreditResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	balance := response.Data.TotalCredits - response.Data.TotalUsage
+	channel.UpdateBalance(balance)
+	return balance, nil
+}
+
+func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
+	url := "https://api.moonshot.cn/v1/users/me/balance"
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		return 0, err
+	}
+
+	type MoonshotBalanceData struct {
+		AvailableBalance float64 `json:"available_balance"`
+		VoucherBalance   float64 `json:"voucher_balance"`
+		CashBalance      float64 `json:"cash_balance"`
+	}
+
+	type MoonshotBalanceResponse struct {
+		Code   int                 `json:"code"`
+		Data   MoonshotBalanceData `json:"data"`
+		Scode  string              `json:"scode"`
+		Status bool                `json:"status"`
+	}
+
+	response := MoonshotBalanceResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	if !response.Status || response.Code != 0 {
+		return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
+	}
+	availableBalanceCny := response.Data.AvailableBalance
+	availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
+	channel.UpdateBalance(availableBalanceUsd)
+	return availableBalanceUsd, nil
+}
+
+func updateChannelBalance(channel *model.Channel) (float64, error) {
+	baseURL := constant.ChannelBaseURLs[channel.Type]
+	if channel.GetBaseURL() == "" {
+		channel.BaseURL = &baseURL
+	}
+	switch channel.Type {
+	case constant.ChannelTypeOpenAI:
+		if channel.GetBaseURL() != "" {
+			baseURL = channel.GetBaseURL()
+		}
+	case constant.ChannelTypeAzure:
+		return 0, errors.New("尚未实现")
+	case constant.ChannelTypeCustom:
+		baseURL = channel.GetBaseURL()
+	//case common.ChannelTypeOpenAISB:
+	//	return updateChannelOpenAISBBalance(channel)
+	case constant.ChannelTypeAIProxy:
+		return updateChannelAIProxyBalance(channel)
+	case constant.ChannelTypeAPI2GPT:
+		return updateChannelAPI2GPTBalance(channel)
+	case constant.ChannelTypeAIGC2D:
+		return updateChannelAIGC2DBalance(channel)
+	case constant.ChannelTypeSiliconFlow:
+		return updateChannelSiliconFlowBalance(channel)
+	case constant.ChannelTypeDeepSeek:
+		return updateChannelDeepSeekBalance(channel)
+	case constant.ChannelTypeOpenRouter:
+		return updateChannelOpenRouterBalance(channel)
+	case constant.ChannelTypeMoonshot:
+		return updateChannelMoonshotBalance(channel)
+	default:
+		return 0, errors.New("尚未实现")
+	}
+	url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
+
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		return 0, err
+	}
+	subscription := OpenAISubscriptionResponse{}
+	err = json.Unmarshal(body, &subscription)
+	if err != nil {
+		return 0, err
+	}
+	now := time.Now()
+	startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
+	endDate := now.Format("2006-01-02")
+	if !subscription.HasPaymentMethod {
+		startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
+	}
+	url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
+	body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		return 0, err
+	}
+	usage := OpenAIUsageResponse{}
+	err = json.Unmarshal(body, &usage)
+	if err != nil {
+		return 0, err
+	}
+	balance := subscription.HardLimitUSD - usage.TotalUsage/100
+	channel.UpdateBalance(balance)
+	return balance, nil
+}
+
+func UpdateChannelBalance(c *gin.Context) {
+	id, err := strconv.Atoi(c.Param("id"))
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	channel, err := model.CacheGetChannel(id)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	if channel.ChannelInfo.IsMultiKey {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "多密钥渠道不支持余额查询",
+		})
+		return
+	}
+	balance, err := updateChannelBalance(channel)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"balance": balance,
+	})
+}
+
+func updateAllChannelsBalance() error {
+	channels, err := model.GetAllChannels(0, 0, true, false)
+	if err != nil {
+		return err
+	}
+	for _, channel := range channels {
+		if channel.Status != common.ChannelStatusEnabled {
+			continue
+		}
+		if channel.ChannelInfo.IsMultiKey {
+			continue // skip multi-key channels
+		}
+		// TODO: support Azure
+		//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
+		//	continue
+		//}
+		balance, err := updateChannelBalance(channel)
+		if err != nil {
+			continue
+		} else {
+			// err is nil & balance <= 0 means quota is used up
+			if balance <= 0 {
+				service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
+			}
+		}
+		time.Sleep(common.RequestInterval)
+	}
+	return nil
+}
+
+func UpdateAllChannelsBalance(c *gin.Context) {
+	// TODO: make it async
+	err := updateAllChannelsBalance()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
+
+func AutomaticallyUpdateChannels(frequency int) {
+	for {
+		time.Sleep(time.Duration(frequency) * time.Minute)
+		common.SysLog("updating all channels")
+		_ = updateAllChannelsBalance()
+		common.SysLog("channels update done")
+	}
+}

+ 465 - 0
controller/channel-test.go

@@ -0,0 +1,465 @@
+package controller
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"math"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
+	"one-api/middleware"
+	"one-api/model"
+	"one-api/relay"
+	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
+	"one-api/relay/helper"
+	"one-api/service"
+	"one-api/types"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/bytedance/gopkg/util/gopool"
+
+	"github.com/gin-gonic/gin"
+)
+
+type testResult struct {
+	context     *gin.Context
+	localErr    error
+	newAPIError *types.NewAPIError
+}
+
+func testChannel(channel *model.Channel, testModel string) testResult {
+	tik := time.Now()
+	if channel.Type == constant.ChannelTypeMidjourney {
+		return testResult{
+			localErr:    errors.New("midjourney channel test is not supported"),
+			newAPIError: nil,
+		}
+	}
+	if channel.Type == constant.ChannelTypeMidjourneyPlus {
+		return testResult{
+			localErr:    errors.New("midjourney plus channel test is not supported"),
+			newAPIError: nil,
+		}
+	}
+	if channel.Type == constant.ChannelTypeSunoAPI {
+		return testResult{
+			localErr:    errors.New("suno channel test is not supported"),
+			newAPIError: nil,
+		}
+	}
+	if channel.Type == constant.ChannelTypeKling {
+		return testResult{
+			localErr:    errors.New("kling channel test is not supported"),
+			newAPIError: nil,
+		}
+	}
+	if channel.Type == constant.ChannelTypeJimeng {
+		return testResult{
+			localErr:    errors.New("jimeng channel test is not supported"),
+			newAPIError: nil,
+		}
+	}
+	w := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(w)
+
+	requestPath := "/v1/chat/completions"
+
+	// 先判断是否为 Embedding 模型
+	if strings.Contains(strings.ToLower(testModel), "embedding") ||
+		strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
+		strings.Contains(testModel, "bge-") || // bge 系列模型
+		strings.Contains(testModel, "embed") ||
+		channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
+		requestPath = "/v1/embeddings" // 修改请求路径
+	}
+
+	c.Request = &http.Request{
+		Method: "POST",
+		URL:    &url.URL{Path: requestPath}, // 使用动态路径
+		Body:   nil,
+		Header: make(http.Header),
+	}
+
+	if testModel == "" {
+		if channel.TestModel != nil && *channel.TestModel != "" {
+			testModel = *channel.TestModel
+		} else {
+			if len(channel.GetModels()) > 0 {
+				testModel = channel.GetModels()[0]
+			} else {
+				testModel = "gpt-4o-mini"
+			}
+		}
+	}
+
+	cache, err := model.GetUserCache(1)
+	if err != nil {
+		return testResult{
+			localErr:    err,
+			newAPIError: nil,
+		}
+	}
+	cache.WriteContext(c)
+
+	//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+	c.Request.Header.Set("Content-Type", "application/json")
+	c.Set("channel", channel.Type)
+	c.Set("base_url", channel.GetBaseURL())
+	group, _ := model.GetUserGroup(1, false)
+	c.Set("group", group)
+
+	newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
+	if newAPIError != nil {
+		return testResult{
+			context:     c,
+			localErr:    newAPIError,
+			newAPIError: newAPIError,
+		}
+	}
+
+	info := relaycommon.GenRelayInfo(c)
+
+	err = helper.ModelMappedHelper(c, info, nil)
+	if err != nil {
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
+		}
+	}
+	testModel = info.UpstreamModelName
+
+	apiType, _ := common.ChannelType2APIType(channel.Type)
+	adaptor := relay.GetAdaptor(apiType)
+	if adaptor == nil {
+		return testResult{
+			context:     c,
+			localErr:    fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
+			newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
+		}
+	}
+
+	request := buildTestRequest(testModel)
+	// 创建一个用于日志的 info 副本,移除 ApiKey
+	logInfo := *info
+	logInfo.ApiKey = ""
+	common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
+
+	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
+	if err != nil {
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
+		}
+	}
+
+	adaptor.Init(info)
+
+	var convertedRequest any
+	// 根据 RelayMode 选择正确的转换函数
+	if info.RelayMode == relayconstant.RelayModeEmbeddings {
+		// 创建一个 EmbeddingRequest
+		embeddingRequest := dto.EmbeddingRequest{
+			Input: request.Input,
+			Model: request.Model,
+		}
+		// 调用专门用于 Embedding 的转换函数
+		convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
+	} else {
+		// 对其他所有请求类型(如 Chat),保持原有逻辑
+		convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
+	}
+
+	if err != nil {
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
+		}
+	}
+	jsonData, err := json.Marshal(convertedRequest)
+	if err != nil {
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
+		}
+	}
+	requestBody := bytes.NewBuffer(jsonData)
+	c.Request.Body = io.NopCloser(requestBody)
+	resp, err := adaptor.DoRequest(c, info, requestBody)
+	if err != nil {
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
+		}
+	}
+	var httpResp *http.Response
+	if resp != nil {
+		httpResp = resp.(*http.Response)
+		if httpResp.StatusCode != http.StatusOK {
+			err := service.RelayErrorHandler(httpResp, true)
+			return testResult{
+				context:     c,
+				localErr:    err,
+				newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
+			}
+		}
+	}
+	usageA, respErr := adaptor.DoResponse(c, httpResp, info)
+	if respErr != nil {
+		return testResult{
+			context:     c,
+			localErr:    respErr,
+			newAPIError: respErr,
+		}
+	}
+	if usageA == nil {
+		return testResult{
+			context:     c,
+			localErr:    errors.New("usage is nil"),
+			newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
+		}
+	}
+	usage := usageA.(*dto.Usage)
+	result := w.Result()
+	respBody, err := io.ReadAll(result.Body)
+	if err != nil {
+		return testResult{
+			context:     c,
+			localErr:    err,
+			newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
+		}
+	}
+	info.PromptTokens = usage.PromptTokens
+
+	quota := 0
+	if !priceData.UsePrice {
+		quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
+		quota = int(math.Round(float64(quota) * priceData.ModelRatio))
+		if priceData.ModelRatio != 0 && quota <= 0 {
+			quota = 1
+		}
+	} else {
+		quota = int(priceData.ModelPrice * common.QuotaPerUnit)
+	}
+	tok := time.Now()
+	milliseconds := tok.Sub(tik).Milliseconds()
+	consumedTime := float64(milliseconds) / 1000.0
+	other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
+		usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+	model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
+		ChannelId:        channel.Id,
+		PromptTokens:     usage.PromptTokens,
+		CompletionTokens: usage.CompletionTokens,
+		ModelName:        info.OriginModelName,
+		TokenName:        "模型测试",
+		Quota:            quota,
+		Content:          "模型测试",
+		UserInput:        "", // 测试请求不记录用户输入
+		UseTimeSeconds:   int(consumedTime),
+		IsStream:         false,
+		Group:            info.UsingGroup,
+		Other:            other,
+	})
+	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
+	return testResult{
+		context:     c,
+		localErr:    nil,
+		newAPIError: nil,
+	}
+}
+
+func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
+	testRequest := &dto.GeneralOpenAIRequest{
+		Model:  "", // this will be set later
+		Stream: false,
+	}
+
+	// 先判断是否为 Embedding 模型
+	if strings.Contains(strings.ToLower(model), "embedding") || // 其他 embedding 模型
+		strings.HasPrefix(model, "m3e") || // m3e 系列模型
+		strings.Contains(model, "bge-") {
+		testRequest.Model = model
+		// Embedding 请求
+		testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
+		return testRequest
+	}
+	// 并非Embedding 模型
+	if strings.HasPrefix(model, "o") {
+		testRequest.MaxCompletionTokens = 10
+	} else if strings.Contains(model, "thinking") {
+		if !strings.Contains(model, "claude") {
+			testRequest.MaxTokens = 50
+		}
+	} else if strings.Contains(model, "gemini") {
+		testRequest.MaxTokens = 3000
+	} else {
+		testRequest.MaxTokens = 10
+	}
+
+	testMessage := dto.Message{
+		Role:    "user",
+		Content: "hi",
+	}
+	testRequest.Model = model
+	testRequest.Messages = append(testRequest.Messages, testMessage)
+	return testRequest
+}
+
+func TestChannel(c *gin.Context) {
+	channelId, err := strconv.Atoi(c.Param("id"))
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	channel, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	//defer func() {
+	//	if channel.ChannelInfo.IsMultiKey {
+	//		go func() { _ = channel.SaveChannelInfo() }()
+	//	}
+	//}()
+	testModel := c.Query("model")
+	tik := time.Now()
+	result := testChannel(channel, testModel)
+	if result.localErr != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": result.localErr.Error(),
+			"time":    0.0,
+		})
+		return
+	}
+	tok := time.Now()
+	milliseconds := tok.Sub(tik).Milliseconds()
+	go channel.UpdateResponseTime(milliseconds)
+	consumedTime := float64(milliseconds) / 1000.0
+	if result.newAPIError != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": result.newAPIError.Error(),
+			"time":    consumedTime,
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"time":    consumedTime,
+	})
+	return
+}
+
+var testAllChannelsLock sync.Mutex
+var testAllChannelsRunning bool = false
+
+func testAllChannels(notify bool) error {
+
+	testAllChannelsLock.Lock()
+	if testAllChannelsRunning {
+		testAllChannelsLock.Unlock()
+		return errors.New("测试已在运行中")
+	}
+	testAllChannelsRunning = true
+	testAllChannelsLock.Unlock()
+	channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
+	if getChannelErr != nil {
+		return getChannelErr
+	}
+	var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
+	if disableThreshold == 0 {
+		disableThreshold = 10000000 // a impossible value
+	}
+	gopool.Go(func() {
+		// 使用 defer 确保无论如何都会重置运行状态,防止死锁
+		defer func() {
+			testAllChannelsLock.Lock()
+			testAllChannelsRunning = false
+			testAllChannelsLock.Unlock()
+		}()
+
+		for _, channel := range channels {
+			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
+			tik := time.Now()
+			result := testChannel(channel, "")
+			tok := time.Now()
+			milliseconds := tok.Sub(tik).Milliseconds()
+
+			shouldBanChannel := false
+			newAPIError := result.newAPIError
+			// request error disables the channel
+			if newAPIError != nil {
+				shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
+			}
+
+			// 当错误检查通过,才检查响应时间
+			if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
+				if milliseconds > disableThreshold {
+					err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+					newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
+					shouldBanChannel = true
+				}
+			}
+
+			// disable channel
+			if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
+				go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
+			}
+
+			// enable channel
+			if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
+				service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
+			}
+
+			channel.UpdateResponseTime(milliseconds)
+			time.Sleep(common.RequestInterval)
+		}
+
+		if notify {
+			service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
+		}
+	})
+	return nil
+}
+
+func TestAllChannels(c *gin.Context) {
+	err := testAllChannels(true)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
+
+func AutomaticallyTestChannels(frequency int) {
+	if frequency <= 0 {
+		common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
+		return
+	}
+	for {
+		time.Sleep(time.Duration(frequency) * time.Minute)
+		common.SysLog("testing all channels")
+		_ = testAllChannels(false)
+		common.SysLog("channel test finished")
+	}
+}

+ 916 - 0
controller/channel.go

@@ -0,0 +1,916 @@
+package controller
+
+import (
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/model"
+	"strconv"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+type OpenAIModel struct {
+	ID         string `json:"id"`
+	Object     string `json:"object"`
+	Created    int64  `json:"created"`
+	OwnedBy    string `json:"owned_by"`
+	Permission []struct {
+		ID                 string `json:"id"`
+		Object             string `json:"object"`
+		Created            int64  `json:"created"`
+		AllowCreateEngine  bool   `json:"allow_create_engine"`
+		AllowSampling      bool   `json:"allow_sampling"`
+		AllowLogprobs      bool   `json:"allow_logprobs"`
+		AllowSearchIndices bool   `json:"allow_search_indices"`
+		AllowView          bool   `json:"allow_view"`
+		AllowFineTuning    bool   `json:"allow_fine_tuning"`
+		Organization       string `json:"organization"`
+		Group              string `json:"group"`
+		IsBlocking         bool   `json:"is_blocking"`
+	} `json:"permission"`
+	Root   string `json:"root"`
+	Parent string `json:"parent"`
+}
+
+type OpenAIModelsResponse struct {
+	Data    []OpenAIModel `json:"data"`
+	Success bool          `json:"success"`
+}
+
+func parseStatusFilter(statusParam string) int {
+	switch strings.ToLower(statusParam) {
+	case "enabled", "1":
+		return common.ChannelStatusEnabled
+	case "disabled", "0":
+		return 0
+	default:
+		return -1
+	}
+}
+
+func GetAllChannels(c *gin.Context) {
+	pageInfo := common.GetPageQuery(c)
+	channelData := make([]*model.Channel, 0)
+	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
+	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
+	statusParam := c.Query("status")
+	// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
+	statusFilter := parseStatusFilter(statusParam)
+	// type filter
+	typeStr := c.Query("type")
+	typeFilter := -1
+	if typeStr != "" {
+		if t, err := strconv.Atoi(typeStr); err == nil {
+			typeFilter = t
+		}
+	}
+
+	var total int64
+
+	if enableTagMode {
+		tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+			return
+		}
+		for _, tag := range tags {
+			if tag == nil || *tag == "" {
+				continue
+			}
+			tagChannels, err := model.GetChannelsByTag(*tag, idSort)
+			if err != nil {
+				continue
+			}
+			filtered := make([]*model.Channel, 0)
+			for _, ch := range tagChannels {
+				if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
+					continue
+				}
+				if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
+					continue
+				}
+				if typeFilter >= 0 && ch.Type != typeFilter {
+					continue
+				}
+				filtered = append(filtered, ch)
+			}
+			channelData = append(channelData, filtered...)
+		}
+		total, _ = model.CountAllTags()
+	} else {
+		baseQuery := model.DB.Model(&model.Channel{})
+		if typeFilter >= 0 {
+			baseQuery = baseQuery.Where("type = ?", typeFilter)
+		}
+		if statusFilter == common.ChannelStatusEnabled {
+			baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
+		} else if statusFilter == 0 {
+			baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
+		}
+
+		baseQuery.Count(&total)
+
+		order := "priority desc"
+		if idSort {
+			order = "id desc"
+		}
+
+		err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+			return
+		}
+	}
+
+	countQuery := model.DB.Model(&model.Channel{})
+	if statusFilter == common.ChannelStatusEnabled {
+		countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
+	} else if statusFilter == 0 {
+		countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
+	}
+	var results []struct {
+		Type  int64
+		Count int64
+	}
+	_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
+	typeCounts := make(map[int64]int64)
+	for _, r := range results {
+		typeCounts[r.Type] = r.Count
+	}
+	common.ApiSuccess(c, gin.H{
+		"items":       channelData,
+		"total":       total,
+		"page":        pageInfo.GetPage(),
+		"page_size":   pageInfo.GetPageSize(),
+		"type_counts": typeCounts,
+	})
+	return
+}
+
+func FetchUpstreamModels(c *gin.Context) {
+	id, err := strconv.Atoi(c.Param("id"))
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	channel, err := model.GetChannelById(id, true)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	baseURL := constant.ChannelBaseURLs[channel.Type]
+	if channel.GetBaseURL() != "" {
+		baseURL = channel.GetBaseURL()
+	}
+	url := fmt.Sprintf("%s/v1/models", baseURL)
+	switch channel.Type {
+	case constant.ChannelTypeGemini:
+		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
+	case constant.ChannelTypeAli:
+		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
+	}
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	var result OpenAIModelsResponse
+	if err = json.Unmarshal(body, &result); err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
+		})
+		return
+	}
+
+	var ids []string
+	for _, model := range result.Data {
+		id := model.ID
+		if channel.Type == constant.ChannelTypeGemini {
+			id = strings.TrimPrefix(id, "models/")
+		}
+		ids = append(ids, id)
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    ids,
+	})
+}
+
+func FixChannelsAbilities(c *gin.Context) {
+	success, fails, err := model.FixAbility()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data": gin.H{
+			"success": success,
+			"fails":   fails,
+		},
+	})
+}
+
+func SearchChannels(c *gin.Context) {
+	keyword := c.Query("keyword")
+	group := c.Query("group")
+	modelKeyword := c.Query("model")
+	statusParam := c.Query("status")
+	statusFilter := parseStatusFilter(statusParam)
+	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
+	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
+	channelData := make([]*model.Channel, 0)
+	if enableTagMode {
+		tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+		for _, tag := range tags {
+			if tag != nil && *tag != "" {
+				tagChannel, err := model.GetChannelsByTag(*tag, idSort)
+				if err == nil {
+					channelData = append(channelData, tagChannel...)
+				}
+			}
+		}
+	} else {
+		channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+		channelData = channels
+	}
+
+	if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
+		filtered := make([]*model.Channel, 0, len(channelData))
+		for _, ch := range channelData {
+			if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
+				continue
+			}
+			if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
+				continue
+			}
+			filtered = append(filtered, ch)
+		}
+		channelData = filtered
+	}
+
+	// calculate type counts for search results
+	typeCounts := make(map[int64]int64)
+	for _, channel := range channelData {
+		typeCounts[int64(channel.Type)]++
+	}
+
+	typeParam := c.Query("type")
+	typeFilter := -1
+	if typeParam != "" {
+		if tp, err := strconv.Atoi(typeParam); err == nil {
+			typeFilter = tp
+		}
+	}
+
+	if typeFilter >= 0 {
+		filtered := make([]*model.Channel, 0, len(channelData))
+		for _, ch := range channelData {
+			if ch.Type == typeFilter {
+				filtered = append(filtered, ch)
+			}
+		}
+		channelData = filtered
+	}
+
+	page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
+	pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
+	if page < 1 {
+		page = 1
+	}
+	if pageSize <= 0 {
+		pageSize = 20
+	}
+
+	total := len(channelData)
+	startIdx := (page - 1) * pageSize
+	if startIdx > total {
+		startIdx = total
+	}
+	endIdx := startIdx + pageSize
+	if endIdx > total {
+		endIdx = total
+	}
+
+	pagedData := channelData[startIdx:endIdx]
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data": gin.H{
+			"items":       pagedData,
+			"total":       total,
+			"type_counts": typeCounts,
+		},
+	})
+	return
+}
+
+func GetChannel(c *gin.Context) {
+	id, err := strconv.Atoi(c.Param("id"))
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	channel, err := model.GetChannelById(id, false)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    channel,
+	})
+	return
+}
+
+// validateChannel 通用的渠道校验函数
+func validateChannel(channel *model.Channel, isAdd bool) error {
+	// 校验 channel settings
+	if err := channel.ValidateSettings(); err != nil {
+		return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error())
+	}
+
+	// 如果是添加操作,检查 channel 和 key 是否为空
+	if isAdd {
+		if channel == nil || channel.Key == "" {
+			return fmt.Errorf("channel cannot be empty")
+		}
+
+		// 检查模型名称长度是否超过 255
+		for _, m := range channel.GetModels() {
+			if len(m) > 255 {
+				return fmt.Errorf("模型名称过长: %s", m)
+			}
+		}
+	}
+
+	// VertexAI 特殊校验
+	if channel.Type == constant.ChannelTypeVertexAi {
+		if channel.Other == "" {
+			return fmt.Errorf("部署地区不能为空")
+		}
+
+		regionMap, err := common.StrToMap(channel.Other)
+		if err != nil {
+			return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
+		}
+
+		if regionMap["default"] == nil {
+			return fmt.Errorf("部署地区必须包含default字段")
+		}
+	}
+
+	return nil
+}
+
+type AddChannelRequest struct {
+	Mode         string                `json:"mode"`
+	MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
+	Channel      *model.Channel        `json:"channel"`
+}
+
+func getVertexArrayKeys(keys string) ([]string, error) {
+	if keys == "" {
+		return nil, nil
+	}
+	var keyArray []interface{}
+	err := common.Unmarshal([]byte(keys), &keyArray)
+	if err != nil {
+		return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
+	}
+	cleanKeys := make([]string, 0, len(keyArray))
+	for _, key := range keyArray {
+		var keyStr string
+		switch v := key.(type) {
+		case string:
+			keyStr = strings.TrimSpace(v)
+		default:
+			bytes, err := json.Marshal(v)
+			if err != nil {
+				return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
+			}
+			keyStr = string(bytes)
+		}
+		if keyStr != "" {
+			cleanKeys = append(cleanKeys, keyStr)
+		}
+	}
+	if len(cleanKeys) == 0 {
+		return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
+	}
+	return cleanKeys, nil
+}
+
+func AddChannel(c *gin.Context) {
+	addChannelRequest := AddChannelRequest{}
+	err := c.ShouldBindJSON(&addChannelRequest)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	// 使用统一的校验函数
+	if err := validateChannel(addChannelRequest.Channel, true); err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
+	keys := make([]string, 0)
+	switch addChannelRequest.Mode {
+	case "multi_to_single":
+		addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
+		addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
+		if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
+			array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
+			if err != nil {
+				c.JSON(http.StatusOK, gin.H{
+					"success": false,
+					"message": err.Error(),
+				})
+				return
+			}
+			addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
+			addChannelRequest.Channel.Key = strings.Join(array, "\n")
+		} else {
+			cleanKeys := make([]string, 0)
+			for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
+				if key == "" {
+					continue
+				}
+				key = strings.TrimSpace(key)
+				cleanKeys = append(cleanKeys, key)
+			}
+			addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
+			addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
+		}
+		keys = []string{addChannelRequest.Channel.Key}
+	case "batch":
+		if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
+			// multi json
+			keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
+			if err != nil {
+				c.JSON(http.StatusOK, gin.H{
+					"success": false,
+					"message": err.Error(),
+				})
+				return
+			}
+		} else {
+			keys = strings.Split(addChannelRequest.Channel.Key, "\n")
+		}
+	case "single":
+		keys = []string{addChannelRequest.Channel.Key}
+	default:
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "不支持的添加模式",
+		})
+		return
+	}
+
+	channels := make([]model.Channel, 0, len(keys))
+	for _, key := range keys {
+		if key == "" {
+			continue
+		}
+		localChannel := addChannelRequest.Channel
+		localChannel.Key = key
+		channels = append(channels, *localChannel)
+	}
+	err = model.BatchInsertChannels(channels)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
+
+func DeleteChannel(c *gin.Context) {
+	id, _ := strconv.Atoi(c.Param("id"))
+	channel := model.Channel{Id: id}
+	err := channel.Delete()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	model.InitChannelCache()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
+
+func DeleteDisabledChannel(c *gin.Context) {
+	rows, err := model.DeleteDisabledChannel()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	model.InitChannelCache()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    rows,
+	})
+	return
+}
+
+type ChannelTag struct {
+	Tag          string  `json:"tag"`
+	NewTag       *string `json:"new_tag"`
+	Priority     *int64  `json:"priority"`
+	Weight       *uint   `json:"weight"`
+	ModelMapping *string `json:"model_mapping"`
+	Models       *string `json:"models"`
+	Groups       *string `json:"groups"`
+}
+
+func DisableTagChannels(c *gin.Context) {
+	channelTag := ChannelTag{}
+	err := c.ShouldBindJSON(&channelTag)
+	if err != nil || channelTag.Tag == "" {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "参数错误",
+		})
+		return
+	}
+	err = model.DisableChannelByTag(channelTag.Tag)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	model.InitChannelCache()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
+
+func EnableTagChannels(c *gin.Context) {
+	channelTag := ChannelTag{}
+	err := c.ShouldBindJSON(&channelTag)
+	if err != nil || channelTag.Tag == "" {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "参数错误",
+		})
+		return
+	}
+	err = model.EnableChannelByTag(channelTag.Tag)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	model.InitChannelCache()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
+
+func EditTagChannels(c *gin.Context) {
+	channelTag := ChannelTag{}
+	err := c.ShouldBindJSON(&channelTag)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "参数错误",
+		})
+		return
+	}
+	if channelTag.Tag == "" {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "tag不能为空",
+		})
+		return
+	}
+	err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	model.InitChannelCache()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
+
+type ChannelBatch struct {
+	Ids []int   `json:"ids"`
+	Tag *string `json:"tag"`
+}
+
+func DeleteChannelBatch(c *gin.Context) {
+	channelBatch := ChannelBatch{}
+	err := c.ShouldBindJSON(&channelBatch)
+	if err != nil || len(channelBatch.Ids) == 0 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "参数错误",
+		})
+		return
+	}
+	err = model.BatchDeleteChannels(channelBatch.Ids)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	model.InitChannelCache()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    len(channelBatch.Ids),
+	})
+	return
+}
+
+type PatchChannel struct {
+	model.Channel
+	MultiKeyMode *string `json:"multi_key_mode"`
+}
+
+func UpdateChannel(c *gin.Context) {
+	channel := PatchChannel{}
+	err := c.ShouldBindJSON(&channel)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	// 使用统一的校验函数
+	if err := validateChannel(&channel.Channel, false); err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
+	originChannel, err := model.GetChannelById(channel.Id, false)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	// Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
+	channel.ChannelInfo = originChannel.ChannelInfo
+
+	// If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
+	if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
+		channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
+	}
+	err = channel.Update()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	model.InitChannelCache()
+	channel.Key = ""
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    channel,
+	})
+	return
+}
+
+func FetchModels(c *gin.Context) {
+	var req struct {
+		BaseURL string `json:"base_url"`
+		Type    int    `json:"type"`
+		Key     string `json:"key"`
+	}
+
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"success": false,
+			"message": "Invalid request",
+		})
+		return
+	}
+
+	baseURL := req.BaseURL
+	if baseURL == "" {
+		baseURL = constant.ChannelBaseURLs[req.Type]
+	}
+
+	client := &http.Client{}
+	url := fmt.Sprintf("%s/v1/models", baseURL)
+
+	request, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	// remove line breaks and extra spaces.
+	key := strings.TrimSpace(req.Key)
+	// If the key contains a line break, only take the first part.
+	key = strings.Split(key, "\n")[0]
+	request.Header.Set("Authorization", "Bearer "+key)
+
+	response, err := client.Do(request)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	//check status code
+	if response.StatusCode != http.StatusOK {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": "Failed to fetch models",
+		})
+		return
+	}
+	defer response.Body.Close()
+
+	var result struct {
+		Data []struct {
+			ID string `json:"id"`
+		} `json:"data"`
+	}
+
+	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	var models []string
+	for _, model := range result.Data {
+		models = append(models, model.ID)
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"data":    models,
+	})
+}
+
+func BatchSetChannelTag(c *gin.Context) {
+	channelBatch := ChannelBatch{}
+	err := c.ShouldBindJSON(&channelBatch)
+	if err != nil || len(channelBatch.Ids) == 0 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "参数错误",
+		})
+		return
+	}
+	err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	model.InitChannelCache()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    len(channelBatch.Ids),
+	})
+	return
+}
+
+func GetTagModels(c *gin.Context) {
+	tag := c.Query("tag")
+	if tag == "" {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"success": false,
+			"message": "tag不能为空",
+		})
+		return
+	}
+
+	channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	var longestModels string
+	maxLength := 0
+
+	// Find the longest models string among all channels with the given tag
+	for _, channel := range channels {
+		if channel.Models != "" {
+			currentModels := strings.Split(channel.Models, ",")
+			if len(currentModels) > maxLength {
+				maxLength = len(currentModels)
+				longestModels = channel.Models
+			}
+		}
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    longestModels,
+	})
+	return
+}
+
+// CopyChannel handles cloning an existing channel with its key.
+// POST /api/channel/copy/:id
+// Optional query params:
+//
+//	suffix         - string appended to the original name (default "_复制")
+//	reset_balance  - bool, when true will reset balance & used_quota to 0 (default true)
+func CopyChannel(c *gin.Context) {
+	id, err := strconv.Atoi(c.Param("id"))
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
+		return
+	}
+
+	suffix := c.DefaultQuery("suffix", "_复制")
+	resetBalance := true
+	if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
+		if v, err := strconv.ParseBool(rbStr); err == nil {
+			resetBalance = v
+		}
+	}
+
+	// fetch original channel with key
+	origin, err := model.GetChannelById(id, true)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+		return
+	}
+
+	// clone channel
+	clone := *origin // shallow copy is sufficient as we will overwrite primitives
+	clone.Id = 0     // let DB auto-generate
+	clone.CreatedTime = common.GetTimestamp()
+	clone.Name = origin.Name + suffix
+	clone.TestTime = 0
+	clone.ResponseTime = 0
+	if resetBalance {
+		clone.Balance = 0
+		clone.UsedQuota = 0
+	}
+
+	// insert
+	if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
+		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+		return
+	}
+	model.InitChannelCache()
+	// success
+	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
+}

+ 103 - 0
controller/console_migrate.go

@@ -0,0 +1,103 @@
+// 用于迁移检测的旧键,该文件下个版本会删除
+
+package controller
+
+import (
+    "encoding/json"
+    "net/http"
+    "one-api/common"
+    "one-api/model"
+    "github.com/gin-gonic/gin"
+)
+
+// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
+func MigrateConsoleSetting(c *gin.Context) {
+    // 读取全部 option
+    opts, err := model.AllOption()
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
+        return
+    }
+    // 建立 map
+    valMap := map[string]string{}
+    for _, o := range opts {
+        valMap[o.Key] = o.Value
+    }
+
+    // 处理 APIInfo
+    if v := valMap["ApiInfo"]; v != "" {
+        var arr []map[string]interface{}
+        if err := json.Unmarshal([]byte(v), &arr); err == nil {
+            if len(arr) > 50 {
+                arr = arr[:50]
+            }
+            bytes, _ := json.Marshal(arr)
+            model.UpdateOption("console_setting.api_info", string(bytes))
+        }
+        model.UpdateOption("ApiInfo", "")
+    }
+    // Announcements 直接搬
+    if v := valMap["Announcements"]; v != "" {
+        model.UpdateOption("console_setting.announcements", v)
+        model.UpdateOption("Announcements", "")
+    }
+    // FAQ 转换
+    if v := valMap["FAQ"]; v != "" {
+        var arr []map[string]interface{}
+        if err := json.Unmarshal([]byte(v), &arr); err == nil {
+            out := []map[string]interface{}{}
+            for _, item := range arr {
+                q, _ := item["question"].(string)
+                if q == "" {
+                    q, _ = item["title"].(string)
+                }
+                a, _ := item["answer"].(string)
+                if a == "" {
+                    a, _ = item["content"].(string)
+                }
+                if q != "" && a != "" {
+                    out = append(out, map[string]interface{}{"question": q, "answer": a})
+                }
+            }
+            if len(out) > 50 {
+                out = out[:50]
+            }
+            bytes, _ := json.Marshal(out)
+            model.UpdateOption("console_setting.faq", string(bytes))
+        }
+        model.UpdateOption("FAQ", "")
+    }
+    // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
+    url := valMap["UptimeKumaUrl"]
+    slug := valMap["UptimeKumaSlug"]
+    if url != "" && slug != "" {
+        // 仅当同时存在 URL 与 Slug 时才进行迁移
+        groups := []map[string]interface{}{
+            {
+                "id":           1,
+                "categoryName": "old",
+                "url":          url,
+                "slug":         slug,
+                "description":  "",
+            },
+        }
+        bytes, _ := json.Marshal(groups)
+        model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
+    }
+    // 清空旧键内容
+    if url != "" {
+        model.UpdateOption("UptimeKumaUrl", "")
+    }
+    if slug != "" {
+        model.UpdateOption("UptimeKumaSlug", "")
+    }
+
+    // 删除旧键记录
+    oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
+    model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
+
+    // 重新加载 OptionMap
+    model.InitOptionMap()
+    common.SysLog("console setting migrated")
+    c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
+} 

+ 239 - 0
controller/github.go

@@ -0,0 +1,239 @@
+package controller
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+	"strconv"
+	"time"
+
+	"github.com/gin-contrib/sessions"
+	"github.com/gin-gonic/gin"
+)
+
+type GitHubOAuthResponse struct {
+	AccessToken string `json:"access_token"`
+	Scope       string `json:"scope"`
+	TokenType   string `json:"token_type"`
+}
+
+type GitHubUser struct {
+	Login string `json:"login"`
+	Name  string `json:"name"`
+	Email string `json:"email"`
+}
+
+func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
+	if code == "" {
+		return nil, errors.New("无效的参数")
+	}
+	values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
+	jsonData, err := json.Marshal(values)
+	if err != nil {
+		return nil, err
+	}
+	req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "application/json")
+	client := http.Client{
+		Timeout: 5 * time.Second,
+	}
+	res, err := client.Do(req)
+	if err != nil {
+		common.SysLog(err.Error())
+		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
+	}
+	defer res.Body.Close()
+	var oAuthResponse GitHubOAuthResponse
+	err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
+	if err != nil {
+		return nil, err
+	}
+	req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
+	res2, err := client.Do(req)
+	if err != nil {
+		common.SysLog(err.Error())
+		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
+	}
+	defer res2.Body.Close()
+	var githubUser GitHubUser
+	err = json.NewDecoder(res2.Body).Decode(&githubUser)
+	if err != nil {
+		return nil, err
+	}
+	if githubUser.Login == "" {
+		return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
+	}
+	return &githubUser, nil
+}
+
+func GitHubOAuth(c *gin.Context) {
+	session := sessions.Default(c)
+	state := c.Query("state")
+	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
+		c.JSON(http.StatusForbidden, gin.H{
+			"success": false,
+			"message": "state is empty or not same",
+		})
+		return
+	}
+	username := session.Get("username")
+	if username != nil {
+		GitHubBind(c)
+		return
+	}
+
+	if !common.GitHubOAuthEnabled {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "管理员未开启通过 GitHub 登录以及注册",
+		})
+		return
+	}
+	code := c.Query("code")
+	githubUser, err := getGitHubUserInfoByCode(code)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	user := model.User{
+		GitHubId: githubUser.Login,
+	}
+	// IsGitHubIdAlreadyTaken is unscoped
+	if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
+		// FillUserByGitHubId is scoped
+		err := user.FillUserByGitHubId()
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+		// if user.Id == 0 , user has been deleted
+		if user.Id == 0 {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "用户已注销",
+			})
+			return
+		}
+	} else {
+		if common.RegisterEnabled {
+			user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
+			if githubUser.Name != "" {
+				user.DisplayName = githubUser.Name
+			} else {
+				user.DisplayName = "GitHub User"
+			}
+			user.Email = githubUser.Email
+			user.Role = common.RoleCommonUser
+			user.Status = common.UserStatusEnabled
+			affCode := session.Get("aff")
+			inviterId := 0
+			if affCode != nil {
+				inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
+			}
+
+			if err := user.Insert(inviterId); err != nil {
+				c.JSON(http.StatusOK, gin.H{
+					"success": false,
+					"message": err.Error(),
+				})
+				return
+			}
+		} else {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "管理员关闭了新用户注册",
+			})
+			return
+		}
+	}
+
+	if user.Status != common.UserStatusEnabled {
+		c.JSON(http.StatusOK, gin.H{
+			"message": "用户已被封禁",
+			"success": false,
+		})
+		return
+	}
+	setupLogin(&user, c)
+}
+
+func GitHubBind(c *gin.Context) {
+	if !common.GitHubOAuthEnabled {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "管理员未开启通过 GitHub 登录以及注册",
+		})
+		return
+	}
+	code := c.Query("code")
+	githubUser, err := getGitHubUserInfoByCode(code)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	user := model.User{
+		GitHubId: githubUser.Login,
+	}
+	if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "该 GitHub 账户已被绑定",
+		})
+		return
+	}
+	session := sessions.Default(c)
+	id := session.Get("id")
+	// id := c.GetInt("id")  // critical bug!
+	user.Id = id.(int)
+	err = user.FillUserById()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	user.GitHubId = githubUser.Login
+	err = user.Update(false)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "bind",
+	})
+	return
+}
+
+func GenerateOAuthCode(c *gin.Context) {
+	session := sessions.Default(c)
+	state := common.GetRandomString(12)
+	affCode := c.Query("aff")
+	if affCode != "" {
+		session.Set("aff", affCode)
+	}
+	session.Set("oauth_state", state)
+	err := session.Save()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    state,
+	})
+}

+ 50 - 0
controller/group.go

@@ -0,0 +1,50 @@
+package controller
+
+import (
+	"net/http"
+	"one-api/model"
+	"one-api/setting"
+	"one-api/setting/ratio_setting"
+
+	"github.com/gin-gonic/gin"
+)
+
+func GetGroups(c *gin.Context) {
+	groupNames := make([]string, 0)
+	for groupName := range ratio_setting.GetGroupRatioCopy() {
+		groupNames = append(groupNames, groupName)
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    groupNames,
+	})
+}
+
+func GetUserGroups(c *gin.Context) {
+	usableGroups := make(map[string]map[string]interface{})
+	userGroup := ""
+	userId := c.GetInt("id")
+	userGroup, _ = model.GetUserGroup(userId, false)
+	for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
+		// UserUsableGroups contains the groups that the user can use
+		userUsableGroups := setting.GetUserUsableGroups(userGroup)
+		if desc, ok := userUsableGroups[groupName]; ok {
+			usableGroups[groupName] = map[string]interface{}{
+				"ratio": ratio,
+				"desc":  desc,
+			}
+		}
+	}
+	if setting.GroupInUserUsableGroups("auto") {
+		usableGroups["auto"] = map[string]interface{}{
+			"ratio": "自动",
+			"desc":  setting.GetUsableGroupDescription("auto"),
+		}
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    usableGroups,
+	})
+}

+ 9 - 0
controller/image.go

@@ -0,0 +1,9 @@
+package controller
+
+import (
+	"github.com/gin-gonic/gin"
+)
+
+func GetImage(c *gin.Context) {
+
+}

+ 259 - 0
controller/linuxdo.go

@@ -0,0 +1,259 @@
+package controller
+
+import (
+	"encoding/base64"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/url"
+	"one-api/common"
+	"one-api/model"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/gin-contrib/sessions"
+	"github.com/gin-gonic/gin"
+)
+
+type LinuxdoUser struct {
+	Id         int    `json:"id"`
+	Username   string `json:"username"`
+	Name       string `json:"name"`
+	Active     bool   `json:"active"`
+	TrustLevel int    `json:"trust_level"`
+	Silenced   bool   `json:"silenced"`
+}
+
+func LinuxDoBind(c *gin.Context) {
+	if !common.LinuxDOOAuthEnabled {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "管理员未开启通过 Linux DO 登录以及注册",
+		})
+		return
+	}
+
+	code := c.Query("code")
+	linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	user := model.User{
+		LinuxDOId: strconv.Itoa(linuxdoUser.Id),
+	}
+
+	if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "该 Linux DO 账户已被绑定",
+		})
+		return
+	}
+
+	session := sessions.Default(c)
+	id := session.Get("id")
+	user.Id = id.(int)
+
+	err = user.FillUserById()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
+	err = user.Update(false)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "bind",
+	})
+}
+
+func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
+	if code == "" {
+		return nil, errors.New("invalid code")
+	}
+
+	// Get access token using Basic auth
+	tokenEndpoint := "https://connect.linux.do/oauth2/token"
+	credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
+	basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
+
+	// Get redirect URI from request
+	scheme := "http"
+	if c.Request.TLS != nil {
+		scheme = "https"
+	}
+	redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
+
+	data := url.Values{}
+	data.Set("grant_type", "authorization_code")
+	data.Set("code", code)
+	data.Set("redirect_uri", redirectURI)
+
+	req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Authorization", basicAuth)
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	req.Header.Set("Accept", "application/json")
+
+	client := http.Client{Timeout: 5 * time.Second}
+	res, err := client.Do(req)
+	if err != nil {
+		return nil, errors.New("failed to connect to Linux DO server")
+	}
+	defer res.Body.Close()
+
+	var tokenRes struct {
+		AccessToken string `json:"access_token"`
+		Message     string `json:"message"`
+	}
+	if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
+		return nil, err
+	}
+
+	if tokenRes.AccessToken == "" {
+		return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
+	}
+
+	// Get user info
+	userEndpoint := "https://connect.linux.do/api/user"
+	req, err = http.NewRequest("GET", userEndpoint, nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
+	req.Header.Set("Accept", "application/json")
+
+	res2, err := client.Do(req)
+	if err != nil {
+		return nil, errors.New("failed to get user info from Linux DO")
+	}
+	defer res2.Body.Close()
+
+	var linuxdoUser LinuxdoUser
+	if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
+		return nil, err
+	}
+
+	if linuxdoUser.Id == 0 {
+		return nil, errors.New("invalid user info returned")
+	}
+
+	return &linuxdoUser, nil
+}
+
+func LinuxdoOAuth(c *gin.Context) {
+	session := sessions.Default(c)
+
+	errorCode := c.Query("error")
+	if errorCode != "" {
+		errorDescription := c.Query("error_description")
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": errorDescription,
+		})
+		return
+	}
+
+	state := c.Query("state")
+	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
+		c.JSON(http.StatusForbidden, gin.H{
+			"success": false,
+			"message": "state is empty or not same",
+		})
+		return
+	}
+
+	username := session.Get("username")
+	if username != nil {
+		LinuxDoBind(c)
+		return
+	}
+
+	if !common.LinuxDOOAuthEnabled {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "管理员未开启通过 Linux DO 登录以及注册",
+		})
+		return
+	}
+
+	code := c.Query("code")
+	linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	user := model.User{
+		LinuxDOId: strconv.Itoa(linuxdoUser.Id),
+	}
+
+	// Check if user exists
+	if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
+		err := user.FillUserByLinuxDOId()
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+		if user.Id == 0 {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "用户已注销",
+			})
+			return
+		}
+	} else {
+		if common.RegisterEnabled {
+			user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
+			user.DisplayName = linuxdoUser.Name
+			user.Role = common.RoleCommonUser
+			user.Status = common.UserStatusEnabled
+
+			affCode := session.Get("aff")
+			inviterId := 0
+			if affCode != nil {
+				inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
+			}
+
+			if err := user.Insert(inviterId); err != nil {
+				c.JSON(http.StatusOK, gin.H{
+					"success": false,
+					"message": err.Error(),
+				})
+				return
+			}
+		} else {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "管理员关闭了新用户注册",
+			})
+			return
+		}
+	}
+
+	if user.Status != common.UserStatusEnabled {
+		c.JSON(http.StatusOK, gin.H{
+			"message": "用户已被封禁",
+			"success": false,
+		})
+		return
+	}
+
+	setupLogin(&user, c)
+}

+ 168 - 0
controller/log.go

@@ -0,0 +1,168 @@
+package controller
+
+import (
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+	"strconv"
+
+	"github.com/gin-gonic/gin"
+)
+
+func GetAllLogs(c *gin.Context) {
+	pageInfo := common.GetPageQuery(c)
+	logType, _ := strconv.Atoi(c.Query("type"))
+	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+	username := c.Query("username")
+	tokenName := c.Query("token_name")
+	modelName := c.Query("model_name")
+	channel, _ := strconv.Atoi(c.Query("channel"))
+	group := c.Query("group")
+	logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	pageInfo.SetTotal(int(total))
+	pageInfo.SetItems(logs)
+	common.ApiSuccess(c, pageInfo)
+	return
+}
+
+func GetUserLogs(c *gin.Context) {
+	pageInfo := common.GetPageQuery(c)
+	userId := c.GetInt("id")
+	logType, _ := strconv.Atoi(c.Query("type"))
+	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+	tokenName := c.Query("token_name")
+	modelName := c.Query("model_name")
+	group := c.Query("group")
+	logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	pageInfo.SetTotal(int(total))
+	pageInfo.SetItems(logs)
+	common.ApiSuccess(c, pageInfo)
+	return
+}
+
+func SearchAllLogs(c *gin.Context) {
+	keyword := c.Query("keyword")
+	logs, err := model.SearchAllLogs(keyword)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    logs,
+	})
+	return
+}
+
+func SearchUserLogs(c *gin.Context) {
+	keyword := c.Query("keyword")
+	userId := c.GetInt("id")
+	logs, err := model.SearchUserLogs(userId, keyword)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    logs,
+	})
+	return
+}
+
+func GetLogByKey(c *gin.Context) {
+	key := c.Query("key")
+	logs, err := model.GetLogByKey(key)
+	if err != nil {
+		c.JSON(200, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	c.JSON(200, gin.H{
+		"success": true,
+		"message": "",
+		"data":    logs,
+	})
+}
+
+func GetLogsStat(c *gin.Context) {
+	logType, _ := strconv.Atoi(c.Query("type"))
+	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+	tokenName := c.Query("token_name")
+	username := c.Query("username")
+	modelName := c.Query("model_name")
+	channel, _ := strconv.Atoi(c.Query("channel"))
+	group := c.Query("group")
+	stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
+	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data": gin.H{
+			"quota": stat.Quota,
+			"rpm":   stat.Rpm,
+			"tpm":   stat.Tpm,
+		},
+	})
+	return
+}
+
+func GetLogsSelfStat(c *gin.Context) {
+	username := c.GetString("username")
+	logType, _ := strconv.Atoi(c.Query("type"))
+	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+	tokenName := c.Query("token_name")
+	modelName := c.Query("model_name")
+	channel, _ := strconv.Atoi(c.Query("channel"))
+	group := c.Query("group")
+	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
+	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
+	c.JSON(200, gin.H{
+		"success": true,
+		"message": "",
+		"data": gin.H{
+			"quota": quotaNum.Quota,
+			"rpm":   quotaNum.Rpm,
+			"tpm":   quotaNum.Tpm,
+			//"token": tokenNum,
+		},
+	})
+	return
+}
+
+func DeleteHistoryLogs(c *gin.Context) {
+	targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
+	if targetTimestamp == 0 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "target timestamp is required",
+		})
+		return
+	}
+	count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    count,
+	})
+	return
+}

+ 263 - 0
controller/midjourney.go

@@ -0,0 +1,263 @@
+package controller
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	"one-api/model"
+	"one-api/service"
+	"one-api/setting"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+func UpdateMidjourneyTaskBulk() {
+	//imageModel := "midjourney"
+	ctx := context.TODO()
+	for {
+		time.Sleep(time.Duration(15) * time.Second)
+
+		tasks := model.GetAllUnFinishTasks()
+		if len(tasks) == 0 {
+			continue
+		}
+
+		common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
+		taskChannelM := make(map[int][]string)
+		taskM := make(map[string]*model.Midjourney)
+		nullTaskIds := make([]int, 0)
+		for _, task := range tasks {
+			if task.MjId == "" {
+				// 统计失败的未完成任务
+				nullTaskIds = append(nullTaskIds, task.Id)
+				continue
+			}
+			taskM[task.MjId] = task
+			taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
+		}
+		if len(nullTaskIds) > 0 {
+			err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
+				"status":   "FAILURE",
+				"progress": "100%",
+			})
+			if err != nil {
+				common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
+			} else {
+				common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
+			}
+		}
+		if len(taskChannelM) == 0 {
+			continue
+		}
+
+		for channelId, taskIds := range taskChannelM {
+			common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+			if len(taskIds) == 0 {
+				continue
+			}
+			midjourneyChannel, err := model.CacheGetChannel(channelId)
+			if err != nil {
+				common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
+				err := model.MjBulkUpdate(taskIds, map[string]any{
+					"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
+					"status":      "FAILURE",
+					"progress":    "100%",
+				})
+				if err != nil {
+					common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
+				}
+				continue
+			}
+			requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
+
+			body, _ := json.Marshal(map[string]any{
+				"ids": taskIds,
+			})
+			req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
+			if err != nil {
+				common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
+				continue
+			}
+			// 设置超时时间
+			timeout := time.Second * 15
+			ctx, cancel := context.WithTimeout(context.Background(), timeout)
+			// 使用带有超时的 context 创建新的请求
+			req = req.WithContext(ctx)
+			req.Header.Set("Content-Type", "application/json")
+			req.Header.Set("mj-api-secret", midjourneyChannel.Key)
+			resp, err := service.GetHttpClient().Do(req)
+			if err != nil {
+				common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
+				continue
+			}
+			if resp.StatusCode != http.StatusOK {
+				common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+				continue
+			}
+			responseBody, err := io.ReadAll(resp.Body)
+			if err != nil {
+				common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
+				continue
+			}
+			var responseItems []dto.MidjourneyDto
+			err = json.Unmarshal(responseBody, &responseItems)
+			if err != nil {
+				common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+				continue
+			}
+			resp.Body.Close()
+			req.Body.Close()
+			cancel()
+
+			for _, responseItem := range responseItems {
+				task := taskM[responseItem.MjId]
+
+				useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
+				// 如果时间超过一小时,且进度不是100%,则认为任务失败
+				if useTime > 3600000 && task.Progress != "100%" {
+					responseItem.FailReason = "上游任务超时(超过1小时)"
+					responseItem.Status = "FAILURE"
+				}
+				if !checkMjTaskNeedUpdate(task, responseItem) {
+					continue
+				}
+				task.Code = 1
+				task.Progress = responseItem.Progress
+				task.PromptEn = responseItem.PromptEn
+				task.State = responseItem.State
+				task.SubmitTime = responseItem.SubmitTime
+				task.StartTime = responseItem.StartTime
+				task.FinishTime = responseItem.FinishTime
+				task.ImageUrl = responseItem.ImageUrl
+				task.Status = responseItem.Status
+				task.FailReason = responseItem.FailReason
+				if responseItem.Properties != nil {
+					propertiesStr, _ := json.Marshal(responseItem.Properties)
+					task.Properties = string(propertiesStr)
+				}
+				if responseItem.Buttons != nil {
+					buttonStr, _ := json.Marshal(responseItem.Buttons)
+					task.Buttons = string(buttonStr)
+				}
+				shouldReturnQuota := false
+				if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
+					common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
+					task.Progress = "100%"
+					if task.Quota != 0 {
+						shouldReturnQuota = true
+					}
+				}
+				err = task.Update()
+				if err != nil {
+					common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
+				} else {
+					if shouldReturnQuota {
+						err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
+						if err != nil {
+							common.LogError(ctx, "fail to increase user quota: "+err.Error())
+						}
+						logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
+						model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+					}
+				}
+			}
+		}
+	}
+}
+
+func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
+	if oldTask.Code != 1 {
+		return true
+	}
+	if oldTask.Progress != newTask.Progress {
+		return true
+	}
+	if oldTask.PromptEn != newTask.PromptEn {
+		return true
+	}
+	if oldTask.State != newTask.State {
+		return true
+	}
+	if oldTask.SubmitTime != newTask.SubmitTime {
+		return true
+	}
+	if oldTask.StartTime != newTask.StartTime {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+	if oldTask.ImageUrl != newTask.ImageUrl {
+		return true
+	}
+	if oldTask.Status != newTask.Status {
+		return true
+	}
+	if oldTask.FailReason != newTask.FailReason {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+	if oldTask.Progress != "100%" && newTask.FailReason != "" {
+		return true
+	}
+
+	return false
+}
+
+func GetAllMidjourney(c *gin.Context) {
+	pageInfo := common.GetPageQuery(c)
+
+	// 解析其他查询参数
+	queryParams := model.TaskQueryParams{
+		ChannelID:      c.Query("channel_id"),
+		MjID:           c.Query("mj_id"),
+		StartTimestamp: c.Query("start_timestamp"),
+		EndTimestamp:   c.Query("end_timestamp"),
+	}
+
+	items := model.GetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
+	total := model.CountAllTasks(queryParams)
+
+	if setting.MjForwardUrlEnabled {
+		for i, midjourney := range items {
+			midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
+			items[i] = midjourney
+		}
+	}
+	pageInfo.SetTotal(int(total))
+	pageInfo.SetItems(items)
+	common.ApiSuccess(c, pageInfo)
+}
+
+func GetUserMidjourney(c *gin.Context) {
+	pageInfo := common.GetPageQuery(c)
+
+	userId := c.GetInt("id")
+
+	queryParams := model.TaskQueryParams{
+		MjID:           c.Query("mj_id"),
+		StartTimestamp: c.Query("start_timestamp"),
+		EndTimestamp:   c.Query("end_timestamp"),
+	}
+
+	items := model.GetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
+	total := model.CountAllUserTask(userId, queryParams)
+
+	if setting.MjForwardUrlEnabled {
+		for i, midjourney := range items {
+			midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
+			items[i] = midjourney
+		}
+	}
+	pageInfo.SetTotal(int(total))
+	pageInfo.SetItems(items)
+	common.ApiSuccess(c, pageInfo)
+}

+ 302 - 0
controller/misc.go

@@ -0,0 +1,302 @@
+package controller
+
+import (
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/middleware"
+	"one-api/model"
+	"one-api/setting"
+	"one-api/setting/console_setting"
+	"one-api/setting/operation_setting"
+	"one-api/setting/system_setting"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+func TestStatus(c *gin.Context) {
+	err := model.PingDB()
+	if err != nil {
+		c.JSON(http.StatusServiceUnavailable, gin.H{
+			"success": false,
+			"message": "数据库连接失败",
+		})
+		return
+	}
+	// 获取HTTP统计信息
+	httpStats := middleware.GetStats()
+	c.JSON(http.StatusOK, gin.H{
+		"success":    true,
+		"message":    "Server is running",
+		"http_stats": httpStats,
+	})
+	return
+}
+
+func GetStatus(c *gin.Context) {
+
+	cs := console_setting.GetConsoleSetting()
+
+	data := gin.H{
+		"version":                  common.Version,
+		"start_time":               common.StartTime,
+		"email_verification":       common.EmailVerificationEnabled,
+		"github_oauth":             common.GitHubOAuthEnabled,
+		"github_client_id":         common.GitHubClientId,
+		"linuxdo_oauth":            common.LinuxDOOAuthEnabled,
+		"linuxdo_client_id":        common.LinuxDOClientId,
+		"telegram_oauth":           common.TelegramOAuthEnabled,
+		"telegram_bot_name":        common.TelegramBotName,
+		"system_name":              common.SystemName,
+		"logo":                     common.Logo,
+		"footer_html":              common.Footer,
+		"wechat_qrcode":            common.WeChatAccountQRCodeImageURL,
+		"wechat_login":             common.WeChatAuthEnabled,
+		"server_address":           setting.ServerAddress,
+		"price":                    setting.Price,
+		"stripe_unit_price":        setting.StripeUnitPrice,
+		"min_topup":                setting.MinTopUp,
+		"stripe_min_topup":         setting.StripeMinTopUp,
+		"turnstile_check":          common.TurnstileCheckEnabled,
+		"turnstile_site_key":       common.TurnstileSiteKey,
+		"top_up_link":              common.TopUpLink,
+		"docs_link":                operation_setting.GetGeneralSetting().DocsLink,
+		"quota_per_unit":           common.QuotaPerUnit,
+		"display_in_currency":      common.DisplayInCurrencyEnabled,
+		"enable_batch_update":      common.BatchUpdateEnabled,
+		"enable_drawing":           common.DrawingEnabled,
+		"enable_task":              common.TaskEnabled,
+		"enable_data_export":       common.DataExportEnabled,
+		"data_export_default_time": common.DataExportDefaultTime,
+		"default_collapse_sidebar": common.DefaultCollapseSidebar,
+		"enable_online_topup":      setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
+		"enable_stripe_topup":      setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
+		"mj_notify_enabled":        setting.MjNotifyEnabled,
+		"chats":                    setting.Chats,
+		"demo_site_enabled":        operation_setting.DemoSiteEnabled,
+		"self_use_mode_enabled":    operation_setting.SelfUseModeEnabled,
+		"default_use_auto_group":   setting.DefaultUseAutoGroup,
+		"pay_methods":              setting.PayMethods,
+		"usd_exchange_rate":        setting.USDExchangeRate,
+
+		// 面板启用开关
+		"api_info_enabled":      cs.ApiInfoEnabled,
+		"uptime_kuma_enabled":   cs.UptimeKumaEnabled,
+		"announcements_enabled": cs.AnnouncementsEnabled,
+		"faq_enabled":           cs.FAQEnabled,
+
+		"oidc_enabled":                system_setting.GetOIDCSettings().Enabled,
+		"oidc_client_id":              system_setting.GetOIDCSettings().ClientId,
+		"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
+		"setup":                       constant.Setup,
+	}
+
+	// 根据启用状态注入可选内容
+	if cs.ApiInfoEnabled {
+		data["api_info"] = console_setting.GetApiInfo()
+	}
+	if cs.AnnouncementsEnabled {
+		data["announcements"] = console_setting.GetAnnouncements()
+	}
+	if cs.FAQEnabled {
+		data["faq"] = console_setting.GetFAQ()
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    data,
+	})
+	return
+}
+
+func GetNotice(c *gin.Context) {
+	common.OptionMapRWMutex.RLock()
+	defer common.OptionMapRWMutex.RUnlock()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    common.OptionMap["Notice"],
+	})
+	return
+}
+
+func GetAbout(c *gin.Context) {
+	common.OptionMapRWMutex.RLock()
+	defer common.OptionMapRWMutex.RUnlock()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    common.OptionMap["About"],
+	})
+	return
+}
+
+func GetMidjourney(c *gin.Context) {
+	common.OptionMapRWMutex.RLock()
+	defer common.OptionMapRWMutex.RUnlock()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    common.OptionMap["Midjourney"],
+	})
+	return
+}
+
+func GetHomePageContent(c *gin.Context) {
+	common.OptionMapRWMutex.RLock()
+	defer common.OptionMapRWMutex.RUnlock()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    common.OptionMap["HomePageContent"],
+	})
+	return
+}
+
+func SendEmailVerification(c *gin.Context) {
+	email := c.Query("email")
+	if err := common.Validate.Var(email, "required,email"); err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "无效的参数",
+		})
+		return
+	}
+	parts := strings.Split(email, "@")
+	if len(parts) != 2 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "无效的邮箱地址",
+		})
+		return
+	}
+	localPart := parts[0]
+	domainPart := parts[1]
+	if common.EmailDomainRestrictionEnabled {
+		allowed := false
+		for _, domain := range common.EmailDomainWhitelist {
+			if domainPart == domain {
+				allowed = true
+				break
+			}
+		}
+		if !allowed {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.",
+			})
+			return
+		}
+	}
+	if common.EmailAliasRestrictionEnabled {
+		containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Contains(localPart, ".")
+		if containsSpecialSymbols {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "管理员已启用邮箱地址别名限制,您的邮箱地址由于包含特殊符号而被拒绝。",
+			})
+			return
+		}
+	}
+
+	if model.IsEmailAlreadyTaken(email) {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "邮箱地址已被占用",
+		})
+		return
+	}
+	code := common.GenerateVerificationCode(6)
+	common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
+	subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName)
+	content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+
+		"<p>您的验证码为: <strong>%s</strong></p>"+
+		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
+	err := common.SendEmail(subject, email, content)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
+
+func SendPasswordResetEmail(c *gin.Context) {
+	email := c.Query("email")
+	if err := common.Validate.Var(email, "required,email"); err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "无效的参数",
+		})
+		return
+	}
+	if !model.IsEmailAlreadyTaken(email) {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "该邮箱地址未注册",
+		})
+		return
+	}
+	code := common.GenerateVerificationCode(0)
+	common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
+	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
+	subject := fmt.Sprintf("%s密码重置", common.SystemName)
+	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
+		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
+		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
+		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
+	err := common.SendEmail(subject, email, content)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
+
+type PasswordResetRequest struct {
+	Email string `json:"email"`
+	Token string `json:"token"`
+}
+
+func ResetPassword(c *gin.Context) {
+	var req PasswordResetRequest
+	err := json.NewDecoder(c.Request.Body).Decode(&req)
+	if req.Email == "" || req.Token == "" {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "无效的参数",
+		})
+		return
+	}
+	if !common.VerifyCodeWithKey(req.Email, req.Token, common.PasswordResetPurpose) {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "重置链接非法或已过期",
+		})
+		return
+	}
+	password := common.GenerateVerificationCode(12)
+	err = model.ResetUserPasswordByEmail(req.Email, password)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	common.DeleteKey(req.Email, common.PasswordResetPurpose)
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    password,
+	})
+	return
+}

+ 216 - 0
controller/model.go

@@ -0,0 +1,216 @@
+package controller
+
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"github.com/samber/lo"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
+	"one-api/model"
+	"one-api/relay"
+	"one-api/relay/channel/ai360"
+	"one-api/relay/channel/lingyiwanwu"
+	"one-api/relay/channel/minimax"
+	"one-api/relay/channel/moonshot"
+	relaycommon "one-api/relay/common"
+	"one-api/setting"
+)
+
+// https://platform.openai.com/docs/api-reference/models/list
+
+var openAIModels []dto.OpenAIModels
+var openAIModelsMap map[string]dto.OpenAIModels
+var channelId2Models map[int][]string
+
+func init() {
+	// https://platform.openai.com/docs/models/model-endpoint-compatibility
+	for i := 0; i < constant.APITypeDummy; i++ {
+		if i == constant.APITypeAIProxyLibrary {
+			continue
+		}
+		adaptor := relay.GetAdaptor(i)
+		channelName := adaptor.GetChannelName()
+		modelNames := adaptor.GetModelList()
+		for _, modelName := range modelNames {
+			openAIModels = append(openAIModels, dto.OpenAIModels{
+				Id:      modelName,
+				Object:  "model",
+				Created: 1626777600,
+				OwnedBy: channelName,
+			})
+		}
+	}
+	for _, modelName := range ai360.ModelList {
+		openAIModels = append(openAIModels, dto.OpenAIModels{
+			Id:      modelName,
+			Object:  "model",
+			Created: 1626777600,
+			OwnedBy: ai360.ChannelName,
+		})
+	}
+	for _, modelName := range moonshot.ModelList {
+		openAIModels = append(openAIModels, dto.OpenAIModels{
+			Id:      modelName,
+			Object:  "model",
+			Created: 1626777600,
+			OwnedBy: moonshot.ChannelName,
+		})
+	}
+	for _, modelName := range lingyiwanwu.ModelList {
+		openAIModels = append(openAIModels, dto.OpenAIModels{
+			Id:      modelName,
+			Object:  "model",
+			Created: 1626777600,
+			OwnedBy: lingyiwanwu.ChannelName,
+		})
+	}
+	for _, modelName := range minimax.ModelList {
+		openAIModels = append(openAIModels, dto.OpenAIModels{
+			Id:      modelName,
+			Object:  "model",
+			Created: 1626777600,
+			OwnedBy: minimax.ChannelName,
+		})
+	}
+	for modelName, _ := range constant.MidjourneyModel2Action {
+		openAIModels = append(openAIModels, dto.OpenAIModels{
+			Id:      modelName,
+			Object:  "model",
+			Created: 1626777600,
+			OwnedBy: "midjourney",
+		})
+	}
+	openAIModelsMap = make(map[string]dto.OpenAIModels)
+	for _, aiModel := range openAIModels {
+		openAIModelsMap[aiModel.Id] = aiModel
+	}
+	channelId2Models = make(map[int][]string)
+	for i := 1; i <= constant.ChannelTypeDummy; i++ {
+		apiType, success := common.ChannelType2APIType(i)
+		if !success || apiType == constant.APITypeAIProxyLibrary {
+			continue
+		}
+		meta := &relaycommon.RelayInfo{ChannelType: i}
+		adaptor := relay.GetAdaptor(apiType)
+		adaptor.Init(meta)
+		channelId2Models[i] = adaptor.GetModelList()
+	}
+	openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
+		return m.Id
+	})
+}
+
+func ListModels(c *gin.Context) {
+	userOpenAiModels := make([]dto.OpenAIModels, 0)
+
+	modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
+	if modelLimitEnable {
+		s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
+		var tokenModelLimit map[string]bool
+		if ok {
+			tokenModelLimit = s.(map[string]bool)
+		} else {
+			tokenModelLimit = map[string]bool{}
+		}
+		for allowModel, _ := range tokenModelLimit {
+			if oaiModel, ok := openAIModelsMap[allowModel]; ok {
+				oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
+				userOpenAiModels = append(userOpenAiModels, oaiModel)
+			} else {
+				userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
+					Id:                     allowModel,
+					Object:                 "model",
+					Created:                1626777600,
+					OwnedBy:                "custom",
+					SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
+				})
+			}
+		}
+	} else {
+		userId := c.GetInt("id")
+		userGroup, err := model.GetUserGroup(userId, false)
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "get user group failed",
+			})
+			return
+		}
+		group := userGroup
+		tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
+		if tokenGroup != "" {
+			group = tokenGroup
+		}
+		var models []string
+		if tokenGroup == "auto" {
+			for _, autoGroup := range setting.AutoGroups {
+				groupModels := model.GetGroupEnabledModels(autoGroup)
+				for _, g := range groupModels {
+					if !common.StringsContains(models, g) {
+						models = append(models, g)
+					}
+				}
+			}
+		} else {
+			models = model.GetGroupEnabledModels(group)
+		}
+		for _, modelName := range models {
+			if oaiModel, ok := openAIModelsMap[modelName]; ok {
+				oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
+				userOpenAiModels = append(userOpenAiModels, oaiModel)
+			} else {
+				userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
+					Id:                     modelName,
+					Object:                 "model",
+					Created:                1626777600,
+					OwnedBy:                "custom",
+					SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
+				})
+			}
+		}
+	}
+	c.JSON(200, gin.H{
+		"success": true,
+		"data":    userOpenAiModels,
+	})
+}
+
+func ChannelListModels(c *gin.Context) {
+	c.JSON(200, gin.H{
+		"success": true,
+		"data":    openAIModels,
+	})
+}
+
+func DashboardListModels(c *gin.Context) {
+	c.JSON(200, gin.H{
+		"success": true,
+		"data":    channelId2Models,
+	})
+}
+
+func EnabledListModels(c *gin.Context) {
+	c.JSON(200, gin.H{
+		"success": true,
+		"data":    model.GetEnabledModels(),
+	})
+}
+
+func RetrieveModel(c *gin.Context) {
+	modelId := c.Param("model")
+	if aiModel, ok := openAIModelsMap[modelId]; ok {
+		c.JSON(200, aiModel)
+	} else {
+		openAIError := dto.OpenAIError{
+			Message: fmt.Sprintf("The model '%s' does not exist", modelId),
+			Type:    "invalid_request_error",
+			Param:   "model",
+			Code:    "model_not_found",
+		}
+		c.JSON(200, gin.H{
+			"error": openAIError,
+		})
+	}
+}

+ 228 - 0
controller/oidc.go

@@ -0,0 +1,228 @@
+package controller
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/url"
+	"one-api/common"
+	"one-api/model"
+	"one-api/setting"
+	"one-api/setting/system_setting"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/gin-contrib/sessions"
+	"github.com/gin-gonic/gin"
+)
+
+type OidcResponse struct {
+	AccessToken  string `json:"access_token"`
+	IDToken      string `json:"id_token"`
+	RefreshToken string `json:"refresh_token"`
+	TokenType    string `json:"token_type"`
+	ExpiresIn    int    `json:"expires_in"`
+	Scope        string `json:"scope"`
+}
+
+type OidcUser struct {
+	OpenID            string `json:"sub"`
+	Email             string `json:"email"`
+	Name              string `json:"name"`
+	PreferredUsername string `json:"preferred_username"`
+	Picture           string `json:"picture"`
+}
+
+func getOidcUserInfoByCode(code string) (*OidcUser, error) {
+	if code == "" {
+		return nil, errors.New("无效的参数")
+	}
+
+	values := url.Values{}
+	values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
+	values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
+	values.Set("code", code)
+	values.Set("grant_type", "authorization_code")
+	values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
+	formData := values.Encode()
+	req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	req.Header.Set("Accept", "application/json")
+	client := http.Client{
+		Timeout: 5 * time.Second,
+	}
+	res, err := client.Do(req)
+	if err != nil {
+		common.SysLog(err.Error())
+		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
+	}
+	defer res.Body.Close()
+	var oidcResponse OidcResponse
+	err = json.NewDecoder(res.Body).Decode(&oidcResponse)
+	if err != nil {
+		return nil, err
+	}
+
+	if oidcResponse.AccessToken == "" {
+		common.SysError("OIDC 获取 Token 失败,请检查设置!")
+		return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
+	}
+
+	req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
+	res2, err := client.Do(req)
+	if err != nil {
+		common.SysLog(err.Error())
+		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
+	}
+	defer res2.Body.Close()
+	if res2.StatusCode != http.StatusOK {
+		common.SysError("OIDC 获取用户信息失败!请检查设置!")
+		return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
+	}
+
+	var oidcUser OidcUser
+	err = json.NewDecoder(res2.Body).Decode(&oidcUser)
+	if err != nil {
+		return nil, err
+	}
+	if oidcUser.OpenID == "" || oidcUser.Email == "" {
+		common.SysError("OIDC 获取用户信息为空!请检查设置!")
+		return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
+	}
+	return &oidcUser, nil
+}
+
+func OidcAuth(c *gin.Context) {
+	session := sessions.Default(c)
+	state := c.Query("state")
+	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
+		c.JSON(http.StatusForbidden, gin.H{
+			"success": false,
+			"message": "state is empty or not same",
+		})
+		return
+	}
+	username := session.Get("username")
+	if username != nil {
+		OidcBind(c)
+		return
+	}
+	if !system_setting.GetOIDCSettings().Enabled {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "管理员未开启通过 OIDC 登录以及注册",
+		})
+		return
+	}
+	code := c.Query("code")
+	oidcUser, err := getOidcUserInfoByCode(code)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	user := model.User{
+		OidcId: oidcUser.OpenID,
+	}
+	if model.IsOidcIdAlreadyTaken(user.OidcId) {
+		err := user.FillUserByOidcId()
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+	} else {
+		if common.RegisterEnabled {
+			user.Email = oidcUser.Email
+			if oidcUser.PreferredUsername != "" {
+				user.Username = oidcUser.PreferredUsername
+			} else {
+				user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
+			}
+			if oidcUser.Name != "" {
+				user.DisplayName = oidcUser.Name
+			} else {
+				user.DisplayName = "OIDC User"
+			}
+			err := user.Insert(0)
+			if err != nil {
+				c.JSON(http.StatusOK, gin.H{
+					"success": false,
+					"message": err.Error(),
+				})
+				return
+			}
+		} else {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "管理员关闭了新用户注册",
+			})
+			return
+		}
+	}
+
+	if user.Status != common.UserStatusEnabled {
+		c.JSON(http.StatusOK, gin.H{
+			"message": "用户已被封禁",
+			"success": false,
+		})
+		return
+	}
+	setupLogin(&user, c)
+}
+
+func OidcBind(c *gin.Context) {
+	if !system_setting.GetOIDCSettings().Enabled {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "管理员未开启通过 OIDC 登录以及注册",
+		})
+		return
+	}
+	code := c.Query("code")
+	oidcUser, err := getOidcUserInfoByCode(code)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	user := model.User{
+		OidcId: oidcUser.OpenID,
+	}
+	if model.IsOidcIdAlreadyTaken(user.OidcId) {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "该 OIDC 账户已被绑定",
+		})
+		return
+	}
+	session := sessions.Default(c)
+	id := session.Get("id")
+	// id := c.GetInt("id")  // critical bug!
+	user.Id = id.(int)
+	err = user.FillUserById()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	user.OidcId = oidcUser.OpenID
+	err = user.Update(false)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "bind",
+	})
+	return
+}

+ 171 - 0
controller/option.go

@@ -0,0 +1,171 @@
+package controller
+
+import (
+	"encoding/json"
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+	"one-api/setting"
+	"one-api/setting/console_setting"
+	"one-api/setting/ratio_setting"
+	"one-api/setting/system_setting"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+func GetOptions(c *gin.Context) {
+	var options []*model.Option
+	common.OptionMapRWMutex.Lock()
+	for k, v := range common.OptionMap {
+		if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") {
+			continue
+		}
+		options = append(options, &model.Option{
+			Key:   k,
+			Value: common.Interface2String(v),
+		})
+	}
+	common.OptionMapRWMutex.Unlock()
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    options,
+	})
+	return
+}
+
+func UpdateOption(c *gin.Context) {
+	var option model.Option
+	err := json.NewDecoder(c.Request.Body).Decode(&option)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"success": false,
+			"message": "无效的参数",
+		})
+		return
+	}
+	switch option.Key {
+	case "GitHubOAuthEnabled":
+		if option.Value == "true" && common.GitHubClientId == "" {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
+			})
+			return
+		}
+	case "oidc.enabled":
+		if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "无法启用 OIDC 登录,请先填入 OIDC Client Id 以及 OIDC Client Secret!",
+			})
+			return
+		}
+	case "LinuxDOOAuthEnabled":
+		if option.Value == "true" && common.LinuxDOClientId == "" {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "无法启用 LinuxDO OAuth,请先填入 LinuxDO Client Id 以及 LinuxDO Client Secret!",
+			})
+			return
+		}
+	case "EmailDomainRestrictionEnabled":
+		if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
+			})
+			return
+		}
+	case "WeChatAuthEnabled":
+		if option.Value == "true" && common.WeChatServerAddress == "" {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "无法启用微信登录,请先填入微信登录相关配置信息!",
+			})
+			return
+		}
+	case "TurnstileCheckEnabled":
+		if option.Value == "true" && common.TurnstileSiteKey == "" {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
+			})
+
+			return
+		}
+	case "TelegramOAuthEnabled":
+		if option.Value == "true" && common.TelegramBotToken == "" {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "无法启用 Telegram OAuth,请先填入 Telegram Bot Token!",
+			})
+			return
+		}
+	case "GroupRatio":
+		err = ratio_setting.CheckGroupRatio(option.Value)
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+	case "ModelRequestRateLimitGroup":
+		err = setting.CheckModelRequestRateLimitGroup(option.Value)
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+	case "console_setting.api_info":
+		err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+	case "console_setting.announcements":
+		err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+	case "console_setting.faq":
+		err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+	case "console_setting.uptime_kuma_groups":
+		err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+	}
+	err = model.UpdateOption(option.Key, option.Value)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}

+ 84 - 0
controller/playground.go

@@ -0,0 +1,84 @@
+package controller
+
+import (
+	"errors"
+	"fmt"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
+	"one-api/middleware"
+	"one-api/model"
+	"one-api/setting"
+	"one-api/types"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+func Playground(c *gin.Context) {
+	var newAPIError *types.NewAPIError
+
+	defer func() {
+		if newAPIError != nil {
+			c.JSON(newAPIError.StatusCode, gin.H{
+				"error": newAPIError.ToOpenAIError(),
+			})
+		}
+	}()
+
+	useAccessToken := c.GetBool("use_access_token")
+	if useAccessToken {
+		newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
+		return
+	}
+
+	playgroundRequest := &dto.PlayGroundRequest{}
+	err := common.UnmarshalBodyReusable(c, playgroundRequest)
+	if err != nil {
+		newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
+		return
+	}
+
+	if playgroundRequest.Model == "" {
+		newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
+		return
+	}
+	c.Set("original_model", playgroundRequest.Model)
+	group := playgroundRequest.Group
+	userGroup := c.GetString("group")
+
+	if group == "" {
+		group = userGroup
+	} else {
+		if !setting.GroupInUserUsableGroups(group) && group != userGroup {
+			newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
+			return
+		}
+		c.Set("group", group)
+	}
+
+	userId := c.GetInt("id")
+
+	// Write user context to ensure acceptUnsetRatio is available
+	userCache, err := model.GetUserCache(userId)
+	if err != nil {
+		newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
+		return
+	}
+	userCache.WriteContext(c)
+
+	tempToken := &model.Token{
+		UserId: userId,
+		Name:   fmt.Sprintf("playground-%s", group),
+		Group:  group,
+	}
+	_ = middleware.SetupContextForToken(c, tempToken)
+	_, newAPIError = getChannel(c, group, playgroundRequest.Model, 0)
+	if newAPIError != nil {
+		return
+	}
+	//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
+	common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
+
+	Relay(c)
+}

+ 71 - 0
controller/pricing.go

@@ -0,0 +1,71 @@
+package controller
+
+import (
+	"one-api/model"
+	"one-api/setting"
+	"one-api/setting/ratio_setting"
+
+	"github.com/gin-gonic/gin"
+)
+
+func GetPricing(c *gin.Context) {
+	pricing := model.GetPricing()
+	userId, exists := c.Get("id")
+	usableGroup := map[string]string{}
+	groupRatio := map[string]float64{}
+	for s, f := range ratio_setting.GetGroupRatioCopy() {
+		groupRatio[s] = f
+	}
+	var group string
+	if exists {
+		user, err := model.GetUserCache(userId.(int))
+		if err == nil {
+			group = user.Group
+			for g := range groupRatio {
+				ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
+				if ok {
+					groupRatio[g] = ratio
+				}
+			}
+		}
+	}
+
+	usableGroup = setting.GetUserUsableGroups(group)
+	// check groupRatio contains usableGroup
+	for group := range ratio_setting.GetGroupRatioCopy() {
+		if _, ok := usableGroup[group]; !ok {
+			delete(groupRatio, group)
+		}
+	}
+
+	c.JSON(200, gin.H{
+		"success":      true,
+		"data":         pricing,
+		"group_ratio":  groupRatio,
+		"usable_group": usableGroup,
+	})
+}
+
+func ResetModelRatio(c *gin.Context) {
+	defaultStr := ratio_setting.DefaultModelRatio2JSONString()
+	err := model.UpdateOption("ModelRatio", defaultStr)
+	if err != nil {
+		c.JSON(200, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
+	if err != nil {
+		c.JSON(200, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	c.JSON(200, gin.H{
+		"success": true,
+		"message": "重置模型倍率成功",
+	})
+}

+ 24 - 0
controller/ratio_config.go

@@ -0,0 +1,24 @@
+package controller
+
+import (
+    "net/http"
+    "one-api/setting/ratio_setting"
+
+    "github.com/gin-gonic/gin"
+)
+
+func GetRatioConfig(c *gin.Context) {
+    if !ratio_setting.IsExposeRatioEnabled() {
+        c.JSON(http.StatusForbidden, gin.H{
+            "success": false,
+            "message": "倍率配置接口未启用",
+        })
+        return
+    }
+
+    c.JSON(http.StatusOK, gin.H{
+        "success": true,
+        "message": "",
+        "data":    ratio_setting.GetExposedData(),
+    })
+} 

+ 474 - 0
controller/ratio_sync.go

@@ -0,0 +1,474 @@
+package controller
+
+import (
+    "context"
+    "encoding/json"
+    "fmt"
+    "net/http"
+    "strings"
+    "sync"
+    "time"
+
+    "one-api/common"
+    "one-api/dto"
+    "one-api/model"
+    "one-api/setting/ratio_setting"
+
+    "github.com/gin-gonic/gin"
+)
+
+const (
+    defaultTimeoutSeconds  = 10
+    defaultEndpoint        = "/api/ratio_config"
+    maxConcurrentFetches   = 8
+)
+
+var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
+
+type upstreamResult struct {
+    Name string                 `json:"name"`
+    Data map[string]any         `json:"data,omitempty"`
+    Err  string                 `json:"err,omitempty"`
+}
+
+func FetchUpstreamRatios(c *gin.Context) {
+    var req dto.UpstreamRequest
+    if err := c.ShouldBindJSON(&req); err != nil {
+        c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
+        return
+    }
+
+    if req.Timeout <= 0 {
+        req.Timeout = defaultTimeoutSeconds
+    }
+
+    var upstreams []dto.UpstreamDTO
+
+    if len(req.Upstreams) > 0 {
+        for _, u := range req.Upstreams {
+            if strings.HasPrefix(u.BaseURL, "http") {
+                if u.Endpoint == "" {
+                    u.Endpoint = defaultEndpoint
+                }
+                u.BaseURL = strings.TrimRight(u.BaseURL, "/")
+                upstreams = append(upstreams, u)
+            }
+        }
+    } else if len(req.ChannelIDs) > 0 {
+        intIds := make([]int, 0, len(req.ChannelIDs))
+        for _, id64 := range req.ChannelIDs {
+            intIds = append(intIds, int(id64))
+        }
+        dbChannels, err := model.GetChannelsByIds(intIds)
+        if err != nil {
+            common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
+            c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
+            return
+        }
+        for _, ch := range dbChannels {
+            if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
+                upstreams = append(upstreams, dto.UpstreamDTO{
+                    ID:       ch.Id,
+                    Name:     ch.Name,
+                    BaseURL:  strings.TrimRight(base, "/"),
+                    Endpoint: "",
+                })
+            }
+        }
+    }
+
+    if len(upstreams) == 0 {
+        c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
+        return
+    }
+
+    var wg sync.WaitGroup
+    ch := make(chan upstreamResult, len(upstreams))
+
+    sem := make(chan struct{}, maxConcurrentFetches)
+
+    client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
+
+    for _, chn := range upstreams {
+        wg.Add(1)
+        go func(chItem dto.UpstreamDTO) {
+            defer wg.Done()
+
+            sem <- struct{}{}
+            defer func() { <-sem }()
+
+            endpoint := chItem.Endpoint
+            if endpoint == "" {
+                endpoint = defaultEndpoint
+            } else if !strings.HasPrefix(endpoint, "/") {
+                endpoint = "/" + endpoint
+            }
+            fullURL := chItem.BaseURL + endpoint
+
+            uniqueName := chItem.Name
+            if chItem.ID != 0 {
+                uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
+            }
+
+            ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
+            defer cancel()
+
+            httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
+            if err != nil {
+                common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
+                ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+                return
+            }
+
+            resp, err := client.Do(httpReq)
+            if err != nil {
+                common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
+                ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+                return
+            }
+            defer resp.Body.Close()
+            if resp.StatusCode != http.StatusOK {
+                common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
+                ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
+                return
+            }
+            // 兼容两种上游接口格式:
+            //  type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
+            //  type2: /api/pricing      -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
+            var body struct {
+                Success bool            `json:"success"`
+                Data    json.RawMessage `json:"data"`
+                Message string          `json:"message"`
+            }
+
+            if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+                common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
+                ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+                return
+            }
+
+            if !body.Success {
+                ch <- upstreamResult{Name: uniqueName, Err: body.Message}
+                return
+            }
+
+            // 尝试按 type1 解析
+            var type1Data map[string]any
+            if err := json.Unmarshal(body.Data, &type1Data); err == nil {
+                // 如果包含至少一个 ratioTypes 字段,则认为是 type1
+                isType1 := false
+                for _, rt := range ratioTypes {
+                    if _, ok := type1Data[rt]; ok {
+                        isType1 = true
+                        break
+                    }
+                }
+                if isType1 {
+                    ch <- upstreamResult{Name: uniqueName, Data: type1Data}
+                    return
+                }
+            }
+
+            // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
+            var pricingItems []struct {
+                ModelName       string  `json:"model_name"`
+                QuotaType       int     `json:"quota_type"`
+                ModelRatio      float64 `json:"model_ratio"`
+                ModelPrice      float64 `json:"model_price"`
+                CompletionRatio float64 `json:"completion_ratio"`
+            }
+            if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
+                common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
+                ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
+                return
+            }
+
+            modelRatioMap := make(map[string]float64)
+            completionRatioMap := make(map[string]float64)
+            modelPriceMap := make(map[string]float64)
+
+            for _, item := range pricingItems {
+                if item.QuotaType == 1 {
+                    modelPriceMap[item.ModelName] = item.ModelPrice
+                } else {
+                    modelRatioMap[item.ModelName] = item.ModelRatio
+                    // completionRatio 可能为 0,此时也直接赋值,保持与上游一致
+                    completionRatioMap[item.ModelName] = item.CompletionRatio
+                }
+            }
+
+            converted := make(map[string]any)
+
+            if len(modelRatioMap) > 0 {
+                ratioAny := make(map[string]any, len(modelRatioMap))
+                for k, v := range modelRatioMap {
+                    ratioAny[k] = v
+                }
+                converted["model_ratio"] = ratioAny
+            }
+
+            if len(completionRatioMap) > 0 {
+                compAny := make(map[string]any, len(completionRatioMap))
+                for k, v := range completionRatioMap {
+                    compAny[k] = v
+                }
+                converted["completion_ratio"] = compAny
+            }
+
+            if len(modelPriceMap) > 0 {
+                priceAny := make(map[string]any, len(modelPriceMap))
+                for k, v := range modelPriceMap {
+                    priceAny[k] = v
+                }
+                converted["model_price"] = priceAny
+            }
+
+            ch <- upstreamResult{Name: uniqueName, Data: converted}
+        }(chn)
+    }
+
+    wg.Wait()
+    close(ch)
+
+    localData := ratio_setting.GetExposedData()
+
+    var testResults []dto.TestResult
+    var successfulChannels []struct {
+        name string
+        data map[string]any
+    }
+
+    for r := range ch {
+        if r.Err != "" {
+            testResults = append(testResults, dto.TestResult{
+                Name:   r.Name,
+                Status: "error",
+                Error:  r.Err,
+            })
+        } else {
+            testResults = append(testResults, dto.TestResult{
+                Name:   r.Name,
+                Status: "success",
+            })
+            successfulChannels = append(successfulChannels, struct {
+                name string
+                data map[string]any
+            }{name: r.Name, data: r.Data})
+        }
+    }
+
+    differences := buildDifferences(localData, successfulChannels)
+
+    c.JSON(http.StatusOK, gin.H{
+        "success": true,
+        "data": gin.H{
+            "differences":  differences,
+            "test_results": testResults,
+        },
+    })
+}
+
+func buildDifferences(localData map[string]any, successfulChannels []struct {
+    name string
+    data map[string]any
+}) map[string]map[string]dto.DifferenceItem {
+    differences := make(map[string]map[string]dto.DifferenceItem)
+
+    allModels := make(map[string]struct{})
+    
+    for _, ratioType := range ratioTypes {
+        if localRatioAny, ok := localData[ratioType]; ok {
+            if localRatio, ok := localRatioAny.(map[string]float64); ok {
+                for modelName := range localRatio {
+                    allModels[modelName] = struct{}{}
+                }
+            }
+        }
+    }
+    
+    for _, channel := range successfulChannels {
+        for _, ratioType := range ratioTypes {
+            if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+                for modelName := range upstreamRatio {
+                    allModels[modelName] = struct{}{}
+                }
+            }
+        }
+    }
+
+    confidenceMap := make(map[string]map[string]bool)
+    
+    // 预处理阶段:检查pricing接口的可信度
+    for _, channel := range successfulChannels {
+        confidenceMap[channel.name] = make(map[string]bool)
+        
+        modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
+        completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
+        
+        if hasModelRatio && hasCompletionRatio {
+            // 遍历所有模型,检查是否满足不可信条件
+            for modelName := range allModels {
+                // 默认为可信
+                confidenceMap[channel.name][modelName] = true
+                
+                // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
+                if modelRatioVal, ok := modelRatios[modelName]; ok {
+                    if completionRatioVal, ok := completionRatios[modelName]; ok {
+                        // 转换为float64进行比较
+                        if modelRatioFloat, ok := modelRatioVal.(float64); ok {
+                            if completionRatioFloat, ok := completionRatioVal.(float64); ok {
+                                if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
+                                    confidenceMap[channel.name][modelName] = false
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        } else {
+            // 如果不是从pricing接口获取的数据,则全部标记为可信
+            for modelName := range allModels {
+                confidenceMap[channel.name][modelName] = true
+            }
+        }
+    }
+
+    for modelName := range allModels {
+        for _, ratioType := range ratioTypes {
+            var localValue interface{} = nil
+            if localRatioAny, ok := localData[ratioType]; ok {
+                if localRatio, ok := localRatioAny.(map[string]float64); ok {
+                    if val, exists := localRatio[modelName]; exists {
+                        localValue = val
+                    }
+                }
+            }
+
+            upstreamValues := make(map[string]interface{})
+            confidenceValues := make(map[string]bool)
+            hasUpstreamValue := false
+            hasDifference := false
+
+            for _, channel := range successfulChannels {
+                var upstreamValue interface{} = nil
+                
+                if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+                    if val, exists := upstreamRatio[modelName]; exists {
+                        upstreamValue = val
+                        hasUpstreamValue = true
+                        
+                        if localValue != nil && localValue != val {
+                            hasDifference = true
+                        } else if localValue == val {
+                            upstreamValue = "same"
+                        }
+                    }
+                }
+                if upstreamValue == nil && localValue == nil {
+                    upstreamValue = "same"
+                }
+                
+                if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
+                    hasDifference = true
+                }
+                
+                upstreamValues[channel.name] = upstreamValue
+                
+                confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
+            }
+
+            shouldInclude := false
+            
+            if localValue != nil {
+                if hasDifference {
+                    shouldInclude = true
+                }
+            } else {
+                if hasUpstreamValue {
+                    shouldInclude = true
+                }
+            }
+
+            if shouldInclude {
+                if differences[modelName] == nil {
+                    differences[modelName] = make(map[string]dto.DifferenceItem)
+                }
+                differences[modelName][ratioType] = dto.DifferenceItem{
+                    Current:   localValue,
+                    Upstreams: upstreamValues,
+                    Confidence: confidenceValues,
+                }
+            }
+        }
+    }
+
+    channelHasDiff := make(map[string]bool)
+    for _, ratioMap := range differences {
+        for _, item := range ratioMap {
+            for chName, val := range item.Upstreams {
+                if val != nil && val != "same" {
+                    channelHasDiff[chName] = true
+                }
+            }
+        }
+    }
+
+    for modelName, ratioMap := range differences {
+        for ratioType, item := range ratioMap {
+            for chName := range item.Upstreams {
+                if !channelHasDiff[chName] {
+                    delete(item.Upstreams, chName)
+                    delete(item.Confidence, chName)
+                }
+            }
+
+            allSame := true
+            for _, v := range item.Upstreams {
+                if v != "same" {
+                    allSame = false
+                    break
+                }
+            }
+            if len(item.Upstreams) == 0 || allSame {
+                delete(ratioMap, ratioType)
+            } else {
+                differences[modelName][ratioType] = item
+            }
+        }
+
+        if len(ratioMap) == 0 {
+            delete(differences, modelName)
+        }
+    }
+
+    return differences
+}
+
+func GetSyncableChannels(c *gin.Context) {
+    channels, err := model.GetAllChannels(0, 0, true, false)
+    if err != nil {
+        c.JSON(http.StatusOK, gin.H{
+            "success": false,
+            "message": err.Error(),
+        })
+        return
+    }
+
+    var syncableChannels []dto.SyncableChannel
+    for _, channel := range channels {
+        if channel.GetBaseURL() != "" {
+            syncableChannels = append(syncableChannels, dto.SyncableChannel{
+                ID:      channel.Id,
+                Name:    channel.Name,
+                BaseURL: channel.GetBaseURL(),
+                Status:  channel.Status,
+            })
+        }
+    }
+
+    c.JSON(http.StatusOK, gin.H{
+        "success": true,
+        "message": "",
+        "data":    syncableChannels,
+    })
+} 

+ 193 - 0
controller/redemption.go

@@ -0,0 +1,193 @@
+package controller
+
+import (
+	"errors"
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+	"strconv"
+
+	"github.com/gin-gonic/gin"
+)
+
+func GetAllRedemptions(c *gin.Context) {
+	pageInfo := common.GetPageQuery(c)
+	redemptions, total, err := model.GetAllRedemptions(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	pageInfo.SetTotal(int(total))
+	pageInfo.SetItems(redemptions)
+	common.ApiSuccess(c, pageInfo)
+	return
+}
+
+func SearchRedemptions(c *gin.Context) {
+	keyword := c.Query("keyword")
+	pageInfo := common.GetPageQuery(c)
+	redemptions, total, err := model.SearchRedemptions(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	pageInfo.SetTotal(int(total))
+	pageInfo.SetItems(redemptions)
+	common.ApiSuccess(c, pageInfo)
+	return
+}
+
+func GetRedemption(c *gin.Context) {
+	id, err := strconv.Atoi(c.Param("id"))
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	redemption, err := model.GetRedemptionById(id)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    redemption,
+	})
+	return
+}
+
+func AddRedemption(c *gin.Context) {
+	redemption := model.Redemption{}
+	err := c.ShouldBindJSON(&redemption)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "兑换码名称长度必须在1-20之间",
+		})
+		return
+	}
+	if redemption.Count <= 0 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "兑换码个数必须大于0",
+		})
+		return
+	}
+	if redemption.Count > 100 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "一次兑换码批量生成的个数不能大于 100",
+		})
+		return
+	}
+	if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
+		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+		return
+	}
+	var keys []string
+	for i := 0; i < redemption.Count; i++ {
+		key := common.GetUUID()
+		cleanRedemption := model.Redemption{
+			UserId:      c.GetInt("id"),
+			Name:        redemption.Name,
+			Key:         key,
+			CreatedTime: common.GetTimestamp(),
+			Quota:       redemption.Quota,
+			ExpiredTime: redemption.ExpiredTime,
+		}
+		err = cleanRedemption.Insert()
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+				"data":    keys,
+			})
+			return
+		}
+		keys = append(keys, key)
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    keys,
+	})
+	return
+}
+
+func DeleteRedemption(c *gin.Context) {
+	id, _ := strconv.Atoi(c.Param("id"))
+	err := model.DeleteRedemptionById(id)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
+
+func UpdateRedemption(c *gin.Context) {
+	statusOnly := c.Query("status_only")
+	redemption := model.Redemption{}
+	err := c.ShouldBindJSON(&redemption)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	cleanRedemption, err := model.GetRedemptionById(redemption.Id)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	if statusOnly == "" {
+		if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
+			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+			return
+		}
+		// If you add more fields, please also update redemption.Update()
+		cleanRedemption.Name = redemption.Name
+		cleanRedemption.Quota = redemption.Quota
+		cleanRedemption.ExpiredTime = redemption.ExpiredTime
+	}
+	if statusOnly != "" {
+		cleanRedemption.Status = redemption.Status
+	}
+	err = cleanRedemption.Update()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    cleanRedemption,
+	})
+	return
+}
+
+func DeleteInvalidRedemption(c *gin.Context) {
+	rows, err := model.DeleteInvalidRedemptions()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    rows,
+	})
+	return
+}
+
+func validateExpiredTime(expired int64) error {
+	if expired != 0 && expired < common.GetTimestamp() {
+		return errors.New("过期时间不能早于当前时间")
+	}
+	return nil
+}

+ 478 - 0
controller/relay.go

@@ -0,0 +1,478 @@
+package controller
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	constant2 "one-api/constant"
+	"one-api/dto"
+	"one-api/middleware"
+	"one-api/model"
+	"one-api/relay"
+	relayconstant "one-api/relay/constant"
+	"one-api/relay/helper"
+	"one-api/service"
+	"one-api/types"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
+)
+
+func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
+	var err *types.NewAPIError
+	switch relayMode {
+	case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
+		err = relay.ImageHelper(c)
+	case relayconstant.RelayModeAudioSpeech:
+		fallthrough
+	case relayconstant.RelayModeAudioTranslation:
+		fallthrough
+	case relayconstant.RelayModeAudioTranscription:
+		err = relay.AudioHelper(c)
+	case relayconstant.RelayModeRerank:
+		err = relay.RerankHelper(c, relayMode)
+	case relayconstant.RelayModeEmbeddings:
+		err = relay.EmbeddingHelper(c)
+	case relayconstant.RelayModeResponses:
+		err = relay.ResponsesHelper(c)
+	case relayconstant.RelayModeGemini:
+		err = relay.GeminiHelper(c)
+	default:
+		err = relay.TextHelper(c)
+	}
+
+	if constant2.ErrorLogEnabled && err != nil {
+		// 保存错误日志到mysql中
+		userId := c.GetInt("id")
+		tokenName := c.GetString("token_name")
+		modelName := c.GetString("original_model")
+		tokenId := c.GetInt("token_id")
+		userGroup := c.GetString("group")
+		channelId := c.GetInt("channel_id")
+		other := make(map[string]interface{})
+		other["error_type"] = err.ErrorType
+		other["error_code"] = err.GetErrorCode()
+		other["status_code"] = err.StatusCode
+		other["channel_id"] = channelId
+		other["channel_name"] = c.GetString("channel_name")
+		other["channel_type"] = c.GetInt("channel_type")
+
+		model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
+	}
+
+	return err
+}
+
+func Relay(c *gin.Context) {
+	log.Println("===========Relay==========", c)
+	relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
+	requestId := c.GetString(common.RequestIdKey)
+	group := c.GetString("group")
+	originalModel := c.GetString("original_model")
+	var newAPIError *types.NewAPIError
+
+	for i := 0; i <= common.RetryTimes; i++ {
+		channel, err := getChannel(c, group, originalModel, i)
+		if err != nil {
+			common.LogError(c, err.Error())
+			newAPIError = err
+			break
+		}
+
+		newAPIError = relayRequest(c, relayMode, channel)
+
+		if newAPIError == nil {
+			return // 成功处理请求,直接返回
+		}
+
+		go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
+
+		if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
+			break
+		}
+	}
+	useChannel := c.GetStringSlice("use_channel")
+	if len(useChannel) > 1 {
+		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+		common.LogInfo(c, retryLogStr)
+	}
+
+	if newAPIError != nil {
+		//if newAPIError.StatusCode == http.StatusTooManyRequests {
+		//	common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
+		//	newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
+		//}
+		newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
+		c.JSON(newAPIError.StatusCode, gin.H{
+			"error": newAPIError.ToOpenAIError(),
+		})
+	}
+}
+
+var upgrader = websocket.Upgrader{
+	Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
+	CheckOrigin: func(r *http.Request) bool {
+		return true // 允许跨域
+	},
+}
+
+func WssRelay(c *gin.Context) {
+	// 将 HTTP 连接升级为 WebSocket 连接
+
+	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
+	defer ws.Close()
+
+	if err != nil {
+		helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
+		return
+	}
+
+	relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
+	requestId := c.GetString(common.RequestIdKey)
+	group := c.GetString("group")
+	//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
+	originalModel := c.GetString("original_model")
+	var newAPIError *types.NewAPIError
+
+	for i := 0; i <= common.RetryTimes; i++ {
+		channel, err := getChannel(c, group, originalModel, i)
+		if err != nil {
+			common.LogError(c, err.Error())
+			newAPIError = err
+			break
+		}
+
+		newAPIError = wssRequest(c, ws, relayMode, channel)
+
+		if newAPIError == nil {
+			return // 成功处理请求,直接返回
+		}
+
+		go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
+
+		if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
+			break
+		}
+	}
+	useChannel := c.GetStringSlice("use_channel")
+	if len(useChannel) > 1 {
+		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+		common.LogInfo(c, retryLogStr)
+	}
+
+	if newAPIError != nil {
+		//if newAPIError.StatusCode == http.StatusTooManyRequests {
+		//	newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
+		//}
+		newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
+		helper.WssError(c, ws, newAPIError.ToOpenAIError())
+	}
+}
+
+func RelayClaude(c *gin.Context) {
+	//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	log.Println("===========RelayClaude==========", c)
+	requestId := c.GetString(common.RequestIdKey)
+	group := c.GetString("group")
+	originalModel := c.GetString("original_model")
+	var newAPIError *types.NewAPIError
+
+	for i := 0; i <= common.RetryTimes; i++ {
+		channel, err := getChannel(c, group, originalModel, i)
+		if err != nil {
+			common.LogError(c, err.Error())
+			newAPIError = err
+			break
+		}
+
+		newAPIError = claudeRequest(c, channel)
+
+		if newAPIError == nil {
+			return // 成功处理请求,直接返回
+		}
+
+		go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
+
+		if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
+			break
+		}
+	}
+	useChannel := c.GetStringSlice("use_channel")
+	if len(useChannel) > 1 {
+		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+		common.LogInfo(c, retryLogStr)
+	}
+
+	if newAPIError != nil {
+		newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
+		c.JSON(newAPIError.StatusCode, gin.H{
+			"type":  "error",
+			"error": newAPIError.ToClaudeError(),
+		})
+	}
+}
+
+func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
+	addUsedChannel(c, channel.Id)
+	requestBody, _ := common.GetRequestBody(c)
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+	return relayHandler(c, relayMode)
+}
+
+func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
+	addUsedChannel(c, channel.Id)
+	requestBody, _ := common.GetRequestBody(c)
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+	return relay.WssHelper(c, ws)
+}
+
+func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
+	addUsedChannel(c, channel.Id)
+	requestBody, _ := common.GetRequestBody(c)
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+	return relay.ClaudeHelper(c)
+}
+
+func addUsedChannel(c *gin.Context, channelId int) {
+	useChannel := c.GetStringSlice("use_channel")
+	useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
+	c.Set("use_channel", useChannel)
+}
+
+func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
+	if retryCount == 0 {
+		autoBan := c.GetBool("auto_ban")
+		autoBanInt := 1
+		if !autoBan {
+			autoBanInt = 0
+		}
+		return &model.Channel{
+			Id:      c.GetInt("channel_id"),
+			Type:    c.GetInt("channel_type"),
+			Name:    c.GetString("channel_name"),
+			AutoBan: &autoBanInt,
+		}, nil
+	}
+	channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
+	if err != nil {
+		if group == "auto" {
+			return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
+		}
+		return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
+	}
+	newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+	if newAPIError != nil {
+		return nil, newAPIError
+	}
+	return channel, nil
+}
+
+func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
+	if openaiErr == nil {
+		return false
+	}
+	if types.IsChannelError(openaiErr) {
+		return true
+	}
+	if types.IsLocalError(openaiErr) {
+		return false
+	}
+	if retryTimes <= 0 {
+		return false
+	}
+	if _, ok := c.Get("specific_channel_id"); ok {
+		return false
+	}
+	if openaiErr.StatusCode == http.StatusTooManyRequests {
+		return true
+	}
+	if openaiErr.StatusCode == 307 {
+		return true
+	}
+	if openaiErr.StatusCode/100 == 5 {
+		// 超时不重试
+		if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
+			return false
+		}
+		return true
+	}
+	if openaiErr.StatusCode == http.StatusBadRequest {
+		channelType := c.GetInt("channel_type")
+		if channelType == constant.ChannelTypeAnthropic {
+			return true
+		}
+		return false
+	}
+	if openaiErr.StatusCode == 408 {
+		// azure处理超时不重试
+		return false
+	}
+	if openaiErr.StatusCode/100 == 2 {
+		return false
+	}
+	return true
+}
+
+func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
+	// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
+	// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
+	common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
+	if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
+		service.DisableChannel(channelError, err.Error())
+	}
+}
+
+func RelayMidjourney(c *gin.Context) {
+	relayMode := c.GetInt("relay_mode")
+	var err *dto.MidjourneyResponse
+	switch relayMode {
+	case relayconstant.RelayModeMidjourneyNotify:
+		err = relay.RelayMidjourneyNotify(c)
+	case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
+		err = relay.RelayMidjourneyTask(c, relayMode)
+	case relayconstant.RelayModeMidjourneyTaskImageSeed:
+		err = relay.RelayMidjourneyTaskImageSeed(c)
+	case relayconstant.RelayModeSwapFace:
+		err = relay.RelaySwapFace(c)
+	default:
+		err = relay.RelayMidjourneySubmit(c, relayMode)
+	}
+	//err = relayMidjourneySubmit(c, relayMode)
+	log.Println(err)
+	if err != nil {
+		statusCode := http.StatusBadRequest
+		if err.Code == 30 {
+			err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
+			statusCode = http.StatusTooManyRequests
+		}
+		c.JSON(statusCode, gin.H{
+			"description": fmt.Sprintf("%s %s", err.Description, err.Result),
+			"type":        "upstream_error",
+			"code":        err.Code,
+		})
+		channelId := c.GetInt("channel_id")
+		common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
+	}
+}
+
+func RelayNotImplemented(c *gin.Context) {
+	err := dto.OpenAIError{
+		Message: "API not implemented",
+		Type:    "new_api_error",
+		Param:   "",
+		Code:    "api_not_implemented",
+	}
+	c.JSON(http.StatusNotImplemented, gin.H{
+		"error": err,
+	})
+}
+
+func RelayNotFound(c *gin.Context) {
+	err := dto.OpenAIError{
+		Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
+		Type:    "invalid_request_error",
+		Param:   "",
+		Code:    "",
+	}
+	c.JSON(http.StatusNotFound, gin.H{
+		"error": err,
+	})
+}
+
+func RelayTask(c *gin.Context) {
+	retryTimes := common.RetryTimes
+	channelId := c.GetInt("channel_id")
+	relayMode := c.GetInt("relay_mode")
+	group := c.GetString("group")
+	originalModel := c.GetString("original_model")
+	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
+	taskErr := taskRelayHandler(c, relayMode)
+	if taskErr == nil {
+		retryTimes = 0
+	}
+	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
+		channel, newAPIError := getChannel(c, group, originalModel, i)
+		if newAPIError != nil {
+			common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
+			taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
+			break
+		}
+		channelId = channel.Id
+		useChannel := c.GetStringSlice("use_channel")
+		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
+		c.Set("use_channel", useChannel)
+		common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
+		//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+
+		requestBody, _ := common.GetRequestBody(c)
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+		taskErr = taskRelayHandler(c, relayMode)
+	}
+	useChannel := c.GetStringSlice("use_channel")
+	if len(useChannel) > 1 {
+		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+		common.LogInfo(c, retryLogStr)
+	}
+	if taskErr != nil {
+		if taskErr.StatusCode == http.StatusTooManyRequests {
+			taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
+		}
+		c.JSON(taskErr.StatusCode, taskErr)
+	}
+}
+
+func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
+	var err *dto.TaskError
+	switch relayMode {
+	case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
+		err = relay.RelayTaskFetch(c, relayMode)
+	default:
+		err = relay.RelayTaskSubmit(c, relayMode)
+	}
+	return err
+}
+
+func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
+	if taskErr == nil {
+		return false
+	}
+	if retryTimes <= 0 {
+		return false
+	}
+	if _, ok := c.Get("specific_channel_id"); ok {
+		return false
+	}
+	if taskErr.StatusCode == http.StatusTooManyRequests {
+		return true
+	}
+	if taskErr.StatusCode == 307 {
+		return true
+	}
+	if taskErr.StatusCode/100 == 5 {
+		// 超时不重试
+		if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
+			return false
+		}
+		return true
+	}
+	if taskErr.StatusCode == http.StatusBadRequest {
+		return false
+	}
+	if taskErr.StatusCode == 408 {
+		// azure处理超时不重试
+		return false
+	}
+	if taskErr.LocalError {
+		return false
+	}
+	if taskErr.StatusCode/100 == 2 {
+		return false
+	}
+	return true
+}

+ 181 - 0
controller/setup.go

@@ -0,0 +1,181 @@
+package controller
+
+import (
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/model"
+	"one-api/setting/operation_setting"
+	"time"
+)
+
+type Setup struct {
+	Status       bool   `json:"status"`
+	RootInit     bool   `json:"root_init"`
+	DatabaseType string `json:"database_type"`
+}
+
+type SetupRequest struct {
+	Username           string `json:"username"`
+	Password           string `json:"password"`
+	ConfirmPassword    string `json:"confirmPassword"`
+	SelfUseModeEnabled bool   `json:"SelfUseModeEnabled"`
+	DemoSiteEnabled    bool   `json:"DemoSiteEnabled"`
+}
+
+func GetSetup(c *gin.Context) {
+	setup := Setup{
+		Status: constant.Setup,
+	}
+	if constant.Setup {
+		c.JSON(200, gin.H{
+			"success": true,
+			"data":    setup,
+		})
+		return
+	}
+	setup.RootInit = model.RootUserExists()
+	if common.UsingMySQL {
+		setup.DatabaseType = "mysql"
+	}
+	if common.UsingPostgreSQL {
+		setup.DatabaseType = "postgres"
+	}
+	if common.UsingSQLite {
+		setup.DatabaseType = "sqlite"
+	}
+	c.JSON(200, gin.H{
+		"success": true,
+		"data":    setup,
+	})
+}
+
+func PostSetup(c *gin.Context) {
+	// Check if setup is already completed
+	if constant.Setup {
+		c.JSON(400, gin.H{
+			"success": false,
+			"message": "系统已经初始化完成",
+		})
+		return
+	}
+
+	// Check if root user already exists
+	rootExists := model.RootUserExists()
+
+	var req SetupRequest
+	err := c.ShouldBindJSON(&req)
+	if err != nil {
+		c.JSON(400, gin.H{
+			"success": false,
+			"message": "请求参数有误",
+		})
+		return
+	}
+
+	// If root doesn't exist, validate and create admin account
+	if !rootExists {
+		// Validate username length: max 12 characters to align with model.User validation
+		if len(req.Username) > 12 {
+			c.JSON(400, gin.H{
+				"success": false,
+				"message": "用户名长度不能超过12个字符",
+			})
+			return
+		}
+		// Validate password
+		if req.Password != req.ConfirmPassword {
+			c.JSON(400, gin.H{
+				"success": false,
+				"message": "两次输入的密码不一致",
+			})
+			return
+		}
+
+		if len(req.Password) < 8 {
+			c.JSON(400, gin.H{
+				"success": false,
+				"message": "密码长度至少为8个字符",
+			})
+			return
+		}
+
+		// Create root user
+		hashedPassword, err := common.Password2Hash(req.Password)
+		if err != nil {
+			c.JSON(500, gin.H{
+				"success": false,
+				"message": "系统错误: " + err.Error(),
+			})
+			return
+		}
+		rootUser := model.User{
+			Username:    req.Username,
+			Password:    hashedPassword,
+			Role:        common.RoleRootUser,
+			Status:      common.UserStatusEnabled,
+			DisplayName: "Root User",
+			AccessToken: nil,
+			Quota:       100000000,
+		}
+		err = model.DB.Create(&rootUser).Error
+		if err != nil {
+			c.JSON(500, gin.H{
+				"success": false,
+				"message": "创建管理员账号失败: " + err.Error(),
+			})
+			return
+		}
+	}
+
+	// Set operation modes
+	operation_setting.SelfUseModeEnabled = req.SelfUseModeEnabled
+	operation_setting.DemoSiteEnabled = req.DemoSiteEnabled
+
+	// Save operation modes to database for persistence
+	err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
+	if err != nil {
+		c.JSON(500, gin.H{
+			"success": false,
+			"message": "保存自用模式设置失败: " + err.Error(),
+		})
+		return
+	}
+
+	err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
+	if err != nil {
+		c.JSON(500, gin.H{
+			"success": false,
+			"message": "保存演示站点模式设置失败: " + err.Error(),
+		})
+		return
+	}
+
+	// Update setup status
+	constant.Setup = true
+
+	setup := model.Setup{
+		Version:       common.Version,
+		InitializedAt: time.Now().Unix(),
+	}
+	err = model.DB.Create(&setup).Error
+	if err != nil {
+		c.JSON(500, gin.H{
+			"success": false,
+			"message": "系统初始化失败: " + err.Error(),
+		})
+		return
+	}
+
+	c.JSON(200, gin.H{
+		"success": true,
+		"message": "系统初始化成功",
+	})
+}
+
+func boolToString(b bool) string {
+	if b {
+		return "true"
+	}
+	return "false"
+}

+ 116 - 0
controller/swag_video.go

@@ -0,0 +1,116 @@
+package controller
+
+import (
+	"github.com/gin-gonic/gin"
+)
+
+// VideoGenerations
+// @Summary 生成视频
+// @Description 调用视频生成接口生成视频
+// @Description 支持多种视频生成服务:
+// @Description - 可灵AI (Kling): https://app.klingai.com/cn/dev/document-api/apiReference/commonInfo
+// @Description - 即梦 (Jimeng): https://www.volcengine.com/docs/85621/1538636
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
+// @Param request body dto.VideoRequest true "视频生成请求参数"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /v1/video/generations [post]
+func VideoGenerations(c *gin.Context) {
+}
+
+// VideoGenerationsTaskId
+// @Summary 查询视频
+// @Description 根据任务ID查询视频生成任务的状态和结果
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Security BearerAuth
+// @Param task_id path string true "Task ID"
+// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /v1/video/generations/{task_id} [get]
+func VideoGenerationsTaskId(c *gin.Context) {
+}
+
+// KlingText2VideoGenerations
+// @Summary 可灵文生视频
+// @Description 调用可灵AI文生视频接口,生成视频内容
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
+// @Param request body KlingText2VideoRequest true "视频生成请求参数"
+// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /kling/v1/videos/text2video [post]
+func KlingText2VideoGenerations(c *gin.Context) {
+}
+
+type KlingText2VideoRequest struct {
+	ModelName      string              `json:"model_name,omitempty" example:"kling-v1"`
+	Prompt         string              `json:"prompt" binding:"required" example:"A cat playing piano in the garden"`
+	NegativePrompt string              `json:"negative_prompt,omitempty" example:"blurry, low quality"`
+	CfgScale       float64             `json:"cfg_scale,omitempty" example:"0.7"`
+	Mode           string              `json:"mode,omitempty" example:"std"`
+	CameraControl  *KlingCameraControl `json:"camera_control,omitempty"`
+	AspectRatio    string              `json:"aspect_ratio,omitempty" example:"16:9"`
+	Duration       string              `json:"duration,omitempty" example:"5"`
+	CallbackURL    string              `json:"callback_url,omitempty" example:"https://your.domain/callback"`
+	ExternalTaskId string              `json:"external_task_id,omitempty" example:"custom-task-001"`
+}
+
+type KlingCameraControl struct {
+	Type   string             `json:"type,omitempty" example:"simple"`
+	Config *KlingCameraConfig `json:"config,omitempty"`
+}
+
+type KlingCameraConfig struct {
+	Horizontal float64 `json:"horizontal,omitempty" example:"2.5"`
+	Vertical   float64 `json:"vertical,omitempty" example:"0"`
+	Pan        float64 `json:"pan,omitempty" example:"0"`
+	Tilt       float64 `json:"tilt,omitempty" example:"0"`
+	Roll       float64 `json:"roll,omitempty" example:"0"`
+	Zoom       float64 `json:"zoom,omitempty" example:"0"`
+}
+
+// KlingImage2VideoGenerations
+// @Summary 可灵官方-图生视频
+// @Description 调用可灵AI图生视频接口,生成视频内容
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
+// @Param request body KlingImage2VideoRequest true "图生视频请求参数"
+// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /kling/v1/videos/image2video [post]
+func KlingImage2VideoGenerations(c *gin.Context) {
+}
+
+type KlingImage2VideoRequest struct {
+	ModelName      string              `json:"model_name,omitempty" example:"kling-v2-master"`
+	Image          string              `json:"image" binding:"required" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"`
+	Prompt         string              `json:"prompt,omitempty" example:"A cat playing piano in the garden"`
+	NegativePrompt string              `json:"negative_prompt,omitempty" example:"blurry, low quality"`
+	CfgScale       float64             `json:"cfg_scale,omitempty" example:"0.7"`
+	Mode           string              `json:"mode,omitempty" example:"std"`
+	CameraControl  *KlingCameraControl `json:"camera_control,omitempty"`
+	AspectRatio    string              `json:"aspect_ratio,omitempty" example:"16:9"`
+	Duration       string              `json:"duration,omitempty" example:"5"`
+	CallbackURL    string              `json:"callback_url,omitempty" example:"https://your.domain/callback"`
+	ExternalTaskId string              `json:"external_task_id,omitempty" example:"custom-task-002"`
+}

+ 273 - 0
controller/task.go

@@ -0,0 +1,273 @@
+package controller
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
+	"one-api/model"
+	"one-api/relay"
+	"sort"
+	"strconv"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/samber/lo"
+)
+
+func UpdateTaskBulk() {
+	//revocer
+	//imageModel := "midjourney"
+	for {
+		time.Sleep(time.Duration(15) * time.Second)
+		common.SysLog("任务进度轮询开始")
+		ctx := context.TODO()
+		allTasks := model.GetAllUnFinishSyncTasks(500)
+		platformTask := make(map[constant.TaskPlatform][]*model.Task)
+		for _, t := range allTasks {
+			platformTask[t.Platform] = append(platformTask[t.Platform], t)
+		}
+		for platform, tasks := range platformTask {
+			if len(tasks) == 0 {
+				continue
+			}
+			taskChannelM := make(map[int][]string)
+			taskM := make(map[string]*model.Task)
+			nullTaskIds := make([]int64, 0)
+			for _, task := range tasks {
+				if task.TaskID == "" {
+					// 统计失败的未完成任务
+					nullTaskIds = append(nullTaskIds, task.ID)
+					continue
+				}
+				taskM[task.TaskID] = task
+				taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
+			}
+			if len(nullTaskIds) > 0 {
+				err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
+					"status":   "FAILURE",
+					"progress": "100%",
+				})
+				if err != nil {
+					common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
+				} else {
+					common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
+				}
+			}
+			if len(taskChannelM) == 0 {
+				continue
+			}
+
+			UpdateTaskByPlatform(platform, taskChannelM, taskM)
+		}
+		common.SysLog("任务进度轮询完成")
+	}
+}
+
+func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
+	switch platform {
+	case constant.TaskPlatformMidjourney:
+		//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
+	case constant.TaskPlatformSuno:
+		_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
+	case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
+		_ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
+	default:
+		common.SysLog("未知平台")
+	}
+}
+
+func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+	for channelId, taskIds := range taskChannelM {
+		err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
+		if err != nil {
+			common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+	common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+	if len(taskIds) == 0 {
+		return nil
+	}
+	channel, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
+		err = model.TaskBulkUpdate(taskIds, map[string]any{
+			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
+			"status":      "FAILURE",
+			"progress":    "100%",
+		})
+		if err != nil {
+			common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
+		}
+		return err
+	}
+	adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
+	if adaptor == nil {
+		return errors.New("adaptor not found")
+	}
+	resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
+		"ids": taskIds,
+	})
+	if err != nil {
+		common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
+		return err
+	}
+	if resp.StatusCode != http.StatusOK {
+		common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+		return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+	}
+	defer resp.Body.Close()
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
+		return err
+	}
+	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
+	err = json.Unmarshal(responseBody, &responseItems)
+	if err != nil {
+		common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+		return err
+	}
+	if !responseItems.IsSuccess() {
+		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
+		return err
+	}
+
+	for _, responseItem := range responseItems.Data {
+		task := taskM[responseItem.TaskID]
+		if !checkTaskNeedUpdate(task, responseItem) {
+			continue
+		}
+
+		task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
+		task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
+		task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
+		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
+		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
+		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
+			common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
+			task.Progress = "100%"
+			//err = model.CacheUpdateUserQuota(task.UserId) ?
+			if err != nil {
+				common.LogError(ctx, "error update user quota cache: "+err.Error())
+			} else {
+				quota := task.Quota
+				if quota != 0 {
+					err = model.IncreaseUserQuota(task.UserId, quota, false)
+					if err != nil {
+						common.LogError(ctx, "fail to increase user quota: "+err.Error())
+					}
+					logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
+					model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+				}
+			}
+		}
+		if responseItem.Status == model.TaskStatusSuccess {
+			task.Progress = "100%"
+		}
+		task.Data = responseItem.Data
+
+		err = task.Update()
+		if err != nil {
+			common.SysError("UpdateMidjourneyTask task error: " + err.Error())
+		}
+	}
+	return nil
+}
+
+func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
+
+	if oldTask.SubmitTime != newTask.SubmitTime {
+		return true
+	}
+	if oldTask.StartTime != newTask.StartTime {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+	if string(oldTask.Status) != newTask.Status {
+		return true
+	}
+	if oldTask.FailReason != newTask.FailReason {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+
+	if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
+		return true
+	}
+
+	oldData, _ := json.Marshal(oldTask.Data)
+	newData, _ := json.Marshal(newTask.Data)
+
+	sort.Slice(oldData, func(i, j int) bool {
+		return oldData[i] < oldData[j]
+	})
+	sort.Slice(newData, func(i, j int) bool {
+		return newData[i] < newData[j]
+	})
+
+	if string(oldData) != string(newData) {
+		return true
+	}
+	return false
+}
+
+func GetAllTask(c *gin.Context) {
+	pageInfo := common.GetPageQuery(c)
+
+	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+	// 解析其他查询参数
+	queryParams := model.SyncTaskQueryParams{
+		Platform:       constant.TaskPlatform(c.Query("platform")),
+		TaskID:         c.Query("task_id"),
+		Status:         c.Query("status"),
+		Action:         c.Query("action"),
+		StartTimestamp: startTimestamp,
+		EndTimestamp:   endTimestamp,
+		ChannelID:      c.Query("channel_id"),
+	}
+
+	items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
+	total := model.TaskCountAllTasks(queryParams)
+	pageInfo.SetTotal(int(total))
+	pageInfo.SetItems(items)
+	common.ApiSuccess(c, pageInfo)
+}
+
+func GetUserTask(c *gin.Context) {
+	pageInfo := common.GetPageQuery(c)
+
+	userId := c.GetInt("id")
+
+	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+
+	queryParams := model.SyncTaskQueryParams{
+		Platform:       constant.TaskPlatform(c.Query("platform")),
+		TaskID:         c.Query("task_id"),
+		Status:         c.Query("status"),
+		Action:         c.Query("action"),
+		StartTimestamp: startTimestamp,
+		EndTimestamp:   endTimestamp,
+	}
+
+	items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
+	total := model.TaskCountAllUserTask(userId, queryParams)
+	pageInfo.SetTotal(int(total))
+	pageInfo.SetItems(items)
+	common.ApiSuccess(c, pageInfo)
+}

+ 138 - 0
controller/task_video.go

@@ -0,0 +1,138 @@
+package controller
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/model"
+	"one-api/relay"
+	"one-api/relay/channel"
+	"time"
+)
+
+func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+	for channelId, taskIds := range taskChannelM {
+		if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
+			common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+	common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
+	if len(taskIds) == 0 {
+		return nil
+	}
+	cacheGetChannel, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
+			"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
+			"status":      "FAILURE",
+			"progress":    "100%",
+		})
+		if errUpdate != nil {
+			common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
+		}
+		return fmt.Errorf("CacheGetChannel failed: %w", err)
+	}
+	adaptor := relay.GetTaskAdaptor(platform)
+	if adaptor == nil {
+		return fmt.Errorf("video adaptor not found")
+	}
+	for _, taskId := range taskIds {
+		if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
+			common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
+	baseURL := constant.ChannelBaseURLs[channel.Type]
+	if channel.GetBaseURL() != "" {
+		baseURL = channel.GetBaseURL()
+	}
+
+	task := taskM[taskId]
+	if task == nil {
+		common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+		return fmt.Errorf("task %s not found", taskId)
+	}
+	resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
+		"task_id": taskId,
+		"action":  task.Action,
+	})
+	if err != nil {
+		return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
+	}
+	//if resp.StatusCode != http.StatusOK {
+	//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
+	//}
+	defer resp.Body.Close()
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
+	}
+
+	taskResult, err := adaptor.ParseTaskResult(responseBody)
+	if err != nil {
+		return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
+	}
+	//if taskResult.Code != 0 {
+	//	return fmt.Errorf("video task fetch failed for task %s", taskId)
+	//}
+
+	now := time.Now().Unix()
+	if taskResult.Status == "" {
+		return fmt.Errorf("task %s status is empty", taskId)
+	}
+	task.Status = model.TaskStatus(taskResult.Status)
+	switch taskResult.Status {
+	case model.TaskStatusSubmitted:
+		task.Progress = "10%"
+	case model.TaskStatusQueued:
+		task.Progress = "20%"
+	case model.TaskStatusInProgress:
+		task.Progress = "30%"
+		if task.StartTime == 0 {
+			task.StartTime = now
+		}
+	case model.TaskStatusSuccess:
+		task.Progress = "100%"
+		if task.FinishTime == 0 {
+			task.FinishTime = now
+		}
+		task.FailReason = taskResult.Url
+	case model.TaskStatusFailure:
+		task.Status = model.TaskStatusFailure
+		task.Progress = "100%"
+		if task.FinishTime == 0 {
+			task.FinishTime = now
+		}
+		task.FailReason = taskResult.Reason
+		common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
+		quota := task.Quota
+		if quota != 0 {
+			if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
+				common.LogError(ctx, "Failed to increase user quota: "+err.Error())
+			}
+			logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
+			model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+		}
+	default:
+		return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
+	}
+	if taskResult.Progress != "" {
+		task.Progress = taskResult.Progress
+	}
+
+	task.Data = responseBody
+	if err := task.Update(); err != nil {
+		common.SysError("UpdateVideoTask task error: " + err.Error())
+	}
+
+	return nil
+}

+ 124 - 0
controller/telegram.go

@@ -0,0 +1,124 @@
+package controller
+
+import (
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/hex"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+	"sort"
+
+	"github.com/gin-contrib/sessions"
+	"github.com/gin-gonic/gin"
+)
+
+func TelegramBind(c *gin.Context) {
+	if !common.TelegramOAuthEnabled {
+		c.JSON(200, gin.H{
+			"message": "管理员未开启通过 Telegram 登录以及注册",
+			"success": false,
+		})
+		return
+	}
+	params := c.Request.URL.Query()
+	if !checkTelegramAuthorization(params, common.TelegramBotToken) {
+		c.JSON(200, gin.H{
+			"message": "无效的请求",
+			"success": false,
+		})
+		return
+	}
+	telegramId := params["id"][0]
+	if model.IsTelegramIdAlreadyTaken(telegramId) {
+		c.JSON(200, gin.H{
+			"message": "该 Telegram 账户已被绑定",
+			"success": false,
+		})
+		return
+	}
+
+	session := sessions.Default(c)
+	id := session.Get("id")
+	user := model.User{Id: id.(int)}
+	if err := user.FillUserById(); err != nil {
+		c.JSON(200, gin.H{
+			"message": err.Error(),
+			"success": false,
+		})
+		return
+	}
+	if user.Id == 0 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "用户已注销",
+		})
+		return
+	}
+	user.TelegramId = telegramId
+	if err := user.Update(false); err != nil {
+		c.JSON(200, gin.H{
+			"message": err.Error(),
+			"success": false,
+		})
+		return
+	}
+
+	c.Redirect(302, "/setting")
+}
+
+func TelegramLogin(c *gin.Context) {
+	if !common.TelegramOAuthEnabled {
+		c.JSON(200, gin.H{
+			"message": "管理员未开启通过 Telegram 登录以及注册",
+			"success": false,
+		})
+		return
+	}
+	params := c.Request.URL.Query()
+	if !checkTelegramAuthorization(params, common.TelegramBotToken) {
+		c.JSON(200, gin.H{
+			"message": "无效的请求",
+			"success": false,
+		})
+		return
+	}
+
+	telegramId := params["id"][0]
+	user := model.User{TelegramId: telegramId}
+	if err := user.FillUserByTelegramId(); err != nil {
+		c.JSON(200, gin.H{
+			"message": err.Error(),
+			"success": false,
+		})
+		return
+	}
+	setupLogin(&user, c)
+}
+
+func checkTelegramAuthorization(params map[string][]string, token string) bool {
+	strs := []string{}
+	var hash = ""
+	for k, v := range params {
+		if k == "hash" {
+			hash = v[0]
+			continue
+		}
+		strs = append(strs, k+"="+v[0])
+	}
+	sort.Strings(strs)
+	var imploded = ""
+	for _, s := range strs {
+		if imploded != "" {
+			imploded += "\n"
+		}
+		imploded += s
+	}
+	sha256hash := sha256.New()
+	io.WriteString(sha256hash, token)
+	hmachash := hmac.New(sha256.New, sha256hash.Sum(nil))
+	io.WriteString(hmachash, imploded)
+	ss := hex.EncodeToString(hmachash.Sum(nil))
+	return hash == ss
+}

+ 241 - 0
controller/token.go

@@ -0,0 +1,241 @@
+package controller
+
+import (
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+	"strconv"
+
+	"github.com/gin-gonic/gin"
+)
+
+func GetAllTokens(c *gin.Context) {
+	userId := c.GetInt("id")
+	pageInfo := common.GetPageQuery(c)
+	tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	total, _ := model.CountUserTokens(userId)
+	pageInfo.SetTotal(int(total))
+	pageInfo.SetItems(tokens)
+	common.ApiSuccess(c, pageInfo)
+	return
+}
+
+func SearchTokens(c *gin.Context) {
+	userId := c.GetInt("id")
+	keyword := c.Query("keyword")
+	token := c.Query("token")
+	tokens, err := model.SearchUserTokens(userId, keyword, token)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    tokens,
+	})
+	return
+}
+
+func GetToken(c *gin.Context) {
+	id, err := strconv.Atoi(c.Param("id"))
+	userId := c.GetInt("id")
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	token, err := model.GetTokenByIds(id, userId)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    token,
+	})
+	return
+}
+
+func GetTokenStatus(c *gin.Context) {
+	tokenId := c.GetInt("token_id")
+	userId := c.GetInt("id")
+	token, err := model.GetTokenByIds(tokenId, userId)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	expiredAt := token.ExpiredTime
+	if expiredAt == -1 {
+		expiredAt = 0
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"object":          "credit_summary",
+		"total_granted":   token.RemainQuota,
+		"total_used":      0, // not supported currently
+		"total_available": token.RemainQuota,
+		"expires_at":      expiredAt * 1000,
+	})
+}
+
+func AddToken(c *gin.Context) {
+	token := model.Token{}
+	err := c.ShouldBindJSON(&token)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	if len(token.Name) > 30 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "令牌名称过长",
+		})
+		return
+	}
+	key, err := common.GenerateKey()
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "生成令牌失败",
+		})
+		common.SysError("failed to generate token key: " + err.Error())
+		return
+	}
+	cleanToken := model.Token{
+		UserId:             c.GetInt("id"),
+		Name:               token.Name,
+		Key:                key,
+		CreatedTime:        common.GetTimestamp(),
+		AccessedTime:       common.GetTimestamp(),
+		ExpiredTime:        token.ExpiredTime,
+		RemainQuota:        token.RemainQuota,
+		UnlimitedQuota:     token.UnlimitedQuota,
+		ModelLimitsEnabled: token.ModelLimitsEnabled,
+		ModelLimits:        token.ModelLimits,
+		AllowIps:           token.AllowIps,
+		Group:              token.Group,
+		RateLimitPerMinute: token.RateLimitPerMinute,
+		RateLimitPerDay:    token.RateLimitPerDay,
+		LastRateLimitReset: 0,
+	}
+	err = cleanToken.Insert()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
+
+func DeleteToken(c *gin.Context) {
+	id, _ := strconv.Atoi(c.Param("id"))
+	userId := c.GetInt("id")
+	err := model.DeleteTokenById(id, userId)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+	})
+	return
+}
+
+func UpdateToken(c *gin.Context) {
+	userId := c.GetInt("id")
+	statusOnly := c.Query("status_only")
+	token := model.Token{}
+	err := c.ShouldBindJSON(&token)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	if len(token.Name) > 30 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "令牌名称过长",
+		})
+		return
+	}
+	cleanToken, err := model.GetTokenByIds(token.Id, userId)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	if token.Status == common.TokenStatusEnabled {
+		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
+			})
+			return
+		}
+		if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
+			})
+			return
+		}
+	}
+	if statusOnly != "" {
+		cleanToken.Status = token.Status
+	} else {
+		// If you add more fields, please also update token.Update()
+		cleanToken.Name = token.Name
+		cleanToken.ExpiredTime = token.ExpiredTime
+		cleanToken.RemainQuota = token.RemainQuota
+		cleanToken.UnlimitedQuota = token.UnlimitedQuota
+		cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
+		cleanToken.ModelLimits = token.ModelLimits
+		cleanToken.AllowIps = token.AllowIps
+		cleanToken.Group = token.Group
+		cleanToken.RateLimitPerMinute = token.RateLimitPerMinute
+		cleanToken.RateLimitPerDay = token.RateLimitPerDay
+	}
+	err = cleanToken.Update()
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    cleanToken,
+	})
+	return
+}
+
+type TokenBatch struct {
+	Ids []int `json:"ids"`
+}
+
+func DeleteTokenBatch(c *gin.Context) {
+	tokenBatch := TokenBatch{}
+	if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "参数错误",
+		})
+		return
+	}
+	userId := c.GetInt("id")
+	count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    count,
+	})
+}

Beberapa file tidak ditampilkan karena terlalu banyak file yang berubah dalam diff ini