sign.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. package jimeng
  2. import (
  3. "bytes"
  4. "crypto/hmac"
  5. "crypto/sha256"
  6. "encoding/hex"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "github.com/gin-gonic/gin"
  11. "io"
  12. "net/http"
  13. "net/url"
  14. "one-api/common"
  15. "sort"
  16. "strings"
  17. "time"
  18. )
  19. // SignRequestForJimeng 对即梦 API 请求进行签名,支持 http.Request 或 header+url+body 方式
  20. //func SignRequestForJimeng(req *http.Request, accessKey, secretKey string) error {
  21. // var bodyBytes []byte
  22. // var err error
  23. //
  24. // if req.Body != nil {
  25. // bodyBytes, err = io.ReadAll(req.Body)
  26. // if err != nil {
  27. // return fmt.Errorf("read request body failed: %w", err)
  28. // }
  29. // _ = req.Body.Close()
  30. // req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // rewind
  31. // } else {
  32. // bodyBytes = []byte{}
  33. // }
  34. //
  35. // return signJimengHeaders(&req.Header, req.Method, req.URL, bodyBytes, accessKey, secretKey)
  36. //}
  37. const HexPayloadHashKey = "HexPayloadHash"
  38. func SetPayloadHash(c *gin.Context, req any) error {
  39. body, err := json.Marshal(req)
  40. if err != nil {
  41. return err
  42. }
  43. common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
  44. payloadHash := sha256.Sum256(body)
  45. hexPayloadHash := hex.EncodeToString(payloadHash[:])
  46. c.Set(HexPayloadHashKey, hexPayloadHash)
  47. return nil
  48. }
  49. func getPayloadHash(c *gin.Context) string {
  50. return c.GetString(HexPayloadHashKey)
  51. }
  52. func Sign(c *gin.Context, req *http.Request, apiKey string) error {
  53. header := req.Header
  54. var bodyBytes []byte
  55. var err error
  56. if req.Body != nil {
  57. bodyBytes, err = io.ReadAll(req.Body)
  58. if err != nil {
  59. return err
  60. }
  61. _ = req.Body.Close()
  62. req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
  63. }
  64. payloadHash := sha256.Sum256(bodyBytes)
  65. hexPayloadHash := hex.EncodeToString(payloadHash[:])
  66. method := c.Request.Method
  67. u := req.URL
  68. keyParts := strings.Split(apiKey, "|")
  69. if len(keyParts) != 2 {
  70. return errors.New("invalid api key format for jimeng: expected 'ak|sk'")
  71. }
  72. accessKey := strings.TrimSpace(keyParts[0])
  73. secretKey := strings.TrimSpace(keyParts[1])
  74. t := time.Now().UTC()
  75. xDate := t.Format("20060102T150405Z")
  76. shortDate := t.Format("20060102")
  77. host := u.Host
  78. header.Set("Host", host)
  79. header.Set("X-Date", xDate)
  80. header.Set("X-Content-Sha256", hexPayloadHash)
  81. // Sort and encode query parameters to create canonical query string
  82. queryParams := u.Query()
  83. sortedKeys := make([]string, 0, len(queryParams))
  84. for k := range queryParams {
  85. sortedKeys = append(sortedKeys, k)
  86. }
  87. sort.Strings(sortedKeys)
  88. var queryParts []string
  89. for _, k := range sortedKeys {
  90. values := queryParams[k]
  91. sort.Strings(values)
  92. for _, v := range values {
  93. queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
  94. }
  95. }
  96. canonicalQueryString := strings.Join(queryParts, "&")
  97. headersToSign := map[string]string{
  98. "host": host,
  99. "x-date": xDate,
  100. "x-content-sha256": hexPayloadHash,
  101. }
  102. if header.Get("Content-Type") == "" {
  103. header.Set("Content-Type", "application/json")
  104. }
  105. headersToSign["content-type"] = header.Get("Content-Type")
  106. var signedHeaderKeys []string
  107. for k := range headersToSign {
  108. signedHeaderKeys = append(signedHeaderKeys, k)
  109. }
  110. sort.Strings(signedHeaderKeys)
  111. var canonicalHeaders strings.Builder
  112. for _, k := range signedHeaderKeys {
  113. canonicalHeaders.WriteString(k)
  114. canonicalHeaders.WriteString(":")
  115. canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
  116. canonicalHeaders.WriteString("\n")
  117. }
  118. signedHeaders := strings.Join(signedHeaderKeys, ";")
  119. canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
  120. method,
  121. u.Path,
  122. canonicalQueryString,
  123. canonicalHeaders.String(),
  124. signedHeaders,
  125. hexPayloadHash,
  126. )
  127. hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
  128. hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
  129. region := "cn-north-1"
  130. serviceName := "cv"
  131. credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
  132. stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
  133. xDate,
  134. credentialScope,
  135. hexHashedCanonicalRequest,
  136. )
  137. kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
  138. kRegion := hmacSHA256(kDate, []byte(region))
  139. kService := hmacSHA256(kRegion, []byte(serviceName))
  140. kSigning := hmacSHA256(kService, []byte("request"))
  141. signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
  142. authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
  143. accessKey,
  144. credentialScope,
  145. signedHeaders,
  146. signature,
  147. )
  148. header.Set("Authorization", authorization)
  149. return nil
  150. }
  151. // hmacSHA256 计算 HMAC-SHA256
  152. func hmacSHA256(key []byte, data []byte) []byte {
  153. h := hmac.New(sha256.New, key)
  154. h.Write(data)
  155. return h.Sum(nil)
  156. }