image.go 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package controller
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. "github.com/labring/aiproxy/core/model"
  7. "github.com/labring/aiproxy/core/relay/adaptor/openai"
  8. relaymodel "github.com/labring/aiproxy/core/relay/model"
  9. "github.com/labring/aiproxy/core/relay/utils"
  10. )
  11. func getImagesRequest(c *gin.Context) (*relaymodel.ImageRequest, error) {
  12. imageRequest, err := utils.UnmarshalImageRequest(c.Request)
  13. if err != nil {
  14. return nil, err
  15. }
  16. if imageRequest.Prompt == "" {
  17. return nil, errors.New("prompt is required")
  18. }
  19. if imageRequest.N == 0 {
  20. imageRequest.N = 1
  21. }
  22. return imageRequest, nil
  23. }
  24. func GetImagesOutputPrice(modelConfig model.ModelConfig, size, quality string) (float64, bool) {
  25. switch {
  26. case len(modelConfig.ImagePrices) == 0 && len(modelConfig.ImageQualityPrices) == 0:
  27. return float64(modelConfig.Price.OutputPrice), true
  28. case len(modelConfig.ImageQualityPrices) != 0:
  29. price, ok := modelConfig.ImageQualityPrices[size][quality]
  30. return price, ok
  31. case len(modelConfig.ImagePrices) != 0:
  32. price, ok := modelConfig.ImagePrices[size]
  33. return price, ok
  34. default:
  35. return 0, false
  36. }
  37. }
  38. func GetImagesRequestPrice(c *gin.Context, mc model.ModelConfig) (model.Price, error) {
  39. imageRequest, err := getImagesRequest(c)
  40. if err != nil {
  41. return model.Price{}, err
  42. }
  43. imageCostPrice, ok := GetImagesOutputPrice(mc, imageRequest.Size, imageRequest.Quality)
  44. if !ok {
  45. return model.Price{}, fmt.Errorf(
  46. "invalid image size `%s` or quality `%s`",
  47. imageRequest.Size,
  48. imageRequest.Quality,
  49. )
  50. }
  51. return model.Price{
  52. PerRequestPrice: mc.Price.PerRequestPrice,
  53. InputPrice: mc.Price.InputPrice,
  54. InputPriceUnit: mc.Price.InputPriceUnit,
  55. ImageInputPrice: mc.Price.ImageInputPrice,
  56. ImageInputPriceUnit: mc.Price.ImageInputPriceUnit,
  57. OutputPrice: model.ZeroNullFloat64(imageCostPrice),
  58. OutputPriceUnit: mc.Price.OutputPriceUnit,
  59. }, nil
  60. }
  61. func GetImagesRequestUsage(c *gin.Context, _ model.ModelConfig) (model.Usage, error) {
  62. imageRequest, err := getImagesRequest(c)
  63. if err != nil {
  64. return model.Usage{}, err
  65. }
  66. return model.Usage{
  67. InputTokens: model.ZeroNullInt64(openai.CountTokenInput(
  68. imageRequest.Prompt,
  69. imageRequest.Model,
  70. )),
  71. OutputTokens: model.ZeroNullInt64(imageRequest.N),
  72. }, nil
  73. }