stt.go 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. package controller
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "math"
  7. "mime/multipart"
  8. "os"
  9. "github.com/gin-gonic/gin"
  10. "github.com/labring/aiproxy/core/common/audio"
  11. "github.com/labring/aiproxy/core/model"
  12. "github.com/labring/aiproxy/core/relay/adaptor/openai"
  13. )
  14. func GetSTTRequestUsage(c *gin.Context, mc model.ModelConfig) (model.Usage, error) {
  15. audioFile, err := c.FormFile("file")
  16. if err != nil {
  17. return model.Usage{}, fmt.Errorf("failed to get audio file: %w", err)
  18. }
  19. duration, err := getAudioDuration(c.Request.Context(), audioFile)
  20. if err != nil {
  21. return model.Usage{}, err
  22. }
  23. durationInt := int64(math.Ceil(duration))
  24. return model.Usage{
  25. InputTokens: model.ZeroNullInt64(
  26. openai.CountTokenInput(c.PostForm("prompt"), mc.Model) + durationInt,
  27. ),
  28. AudioInputTokens: model.ZeroNullInt64(durationInt),
  29. }, nil
  30. }
  31. func getAudioDuration(ctx context.Context, audioFile *multipart.FileHeader) (float64, error) {
  32. // Try to get duration directly from audio data
  33. audioData, err := audioFile.Open()
  34. if err != nil {
  35. return 0, fmt.Errorf("failed to open audio file: %w", err)
  36. }
  37. defer audioData.Close()
  38. // If it's already an os.File, use file path method
  39. if osFile, ok := audioData.(*os.File); ok {
  40. duration, err := audio.GetAudioDurationFromFilePath(ctx, osFile.Name())
  41. if err != nil {
  42. return 0, fmt.Errorf("failed to get audio duration from temp file: %w", err)
  43. }
  44. return duration, nil
  45. }
  46. // Try to get duration from audio data
  47. duration, err := audio.GetAudioDuration(ctx, audioData)
  48. if err == nil {
  49. return duration, nil
  50. }
  51. // If duration is NaN, create temp file and try again
  52. if errors.Is(err, audio.ErrAudioDurationNAN) {
  53. return getDurationFromTempFile(ctx, audioFile)
  54. }
  55. return 0, fmt.Errorf("failed to get audio duration: %w", err)
  56. }
  57. func getDurationFromTempFile(
  58. ctx context.Context,
  59. audioFile *multipart.FileHeader,
  60. ) (float64, error) {
  61. tempFile, err := os.CreateTemp("", "audio")
  62. if err != nil {
  63. return 0, fmt.Errorf("failed to create temp file: %w", err)
  64. }
  65. defer os.Remove(tempFile.Name())
  66. defer tempFile.Close()
  67. newAudioData, err := audioFile.Open()
  68. if err != nil {
  69. return 0, fmt.Errorf("failed to open audio file: %w", err)
  70. }
  71. defer newAudioData.Close()
  72. if _, err = tempFile.ReadFrom(newAudioData); err != nil {
  73. return 0, fmt.Errorf("failed to read from temp file: %w", err)
  74. }
  75. duration, err := audio.GetAudioDurationFromFilePath(ctx, tempFile.Name())
  76. if err != nil {
  77. return 0, fmt.Errorf("failed to get audio duration from temp file: %w", err)
  78. }
  79. return duration, nil
  80. }