helpers.go 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. package taskcommon
  import (
  	"encoding/base64"
  	"fmt"
  	"strings"

  	"github.com/QuantumNous/new-api/common"
  	"github.com/QuantumNous/new-api/model"
  	relaycommon "github.com/QuantumNous/new-api/relay/common"
  	"github.com/QuantumNous/new-api/setting/system_setting"
  	"github.com/gin-gonic/gin"
  )
  11. // UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip.
  12. // This replaces the repeated pattern: json.Marshal(metadata) → json.Unmarshal(bytes, &target).
  13. func UnmarshalMetadata(metadata map[string]any, target any) error {
  14. if metadata == nil {
  15. return nil
  16. }
  17. // Prevent metadata from overriding model fields to avoid billing bypass.
  18. delete(metadata, "model")
  19. metaBytes, err := common.Marshal(metadata)
  20. if err != nil {
  21. return fmt.Errorf("marshal metadata failed: %w", err)
  22. }
  23. if err := common.Unmarshal(metaBytes, target); err != nil {
  24. return fmt.Errorf("unmarshal metadata failed: %w", err)
  25. }
  26. return nil
  27. }
  // DefaultString yields val when it holds any characters; for the empty
  // string it yields fallback instead.
  func DefaultString(val, fallback string) string {
  	if val != "" {
  		return val
  	}
  	return fallback
  }
  // DefaultInt substitutes fallback for a zero val. Any other value, negative
  // numbers included, is returned unchanged.
  func DefaultInt(val, fallback int) int {
  	switch val {
  	case 0:
  		return fallback
  	default:
  		return val
  	}
  }
  // EncodeLocalTaskID renders an upstream operation name as unpadded,
  // URL-safe base64 (base64.RawURLEncoding), so the result can be used directly
  // as a task ID. Gemini/Vertex use this to store upstream names as task IDs.
  func EncodeLocalTaskID(name string) string {
  	enc := base64.RawURLEncoding
  	raw := []byte(name)
  	return enc.EncodeToString(raw)
  }
  // DecodeLocalTaskID reverses EncodeLocalTaskID. It parses id as unpadded,
  // URL-safe base64 and returns the original upstream operation name.
  // On malformed input it returns "" together with the decoder's error.
  func DecodeLocalTaskID(id string) (string, error) {
  	decoded, decodeErr := base64.RawURLEncoding.DecodeString(id)
  	if decodeErr == nil {
  		return string(decoded), nil
  	}
  	return "", decodeErr
  }
  55. // BuildProxyURL constructs the video proxy URL using the public task ID.
  56. // e.g., "https://your-server.com/v1/videos/task_xxxx/content"
  57. func BuildProxyURL(taskID string) string {
  58. return fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
  59. }
  // Progress values reported for each task status during polling updates.
  // Each one is a percentage string.
  const (
  	// ProgressSubmitted is the progress reported for a submitted task.
  	ProgressSubmitted = "10%"
  	// ProgressQueued is the progress reported for a queued task.
  	ProgressQueued = "20%"
  	// ProgressInProgress is the progress reported while a task is in progress.
  	ProgressInProgress = "30%"
  	// ProgressComplete is the progress reported for a completed task.
  	ProgressComplete = "100%"
  )
  67. // ---------------------------------------------------------------------------
  68. // BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods.
  69. // Adaptors that do not need custom billing can embed this struct directly.
  70. // ---------------------------------------------------------------------------
  71. type BaseBilling struct{}
  72. // EstimateBilling returns nil (no extra ratios; use base model price).
  73. func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
  74. return nil
  75. }
  76. // AdjustBillingOnSubmit returns nil (no submit-time adjustment).
  77. func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 {
  78. return nil
  79. }
  80. // AdjustBillingOnComplete returns 0 (keep pre-charged amount).
  81. func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
  82. return 0
  83. }