download.go 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. package service
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "strings"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/setting/system_setting"
  10. )
  // WorkerRequest is the JSON payload sent to the worker.
  //
  // Method, Headers and Body are optional: each is tagged omitempty and is
  // left out of the encoded payload when it is empty.
  type WorkerRequest struct {
  	// URL is the target address; DoWorkerRequest validates it before sending.
  	URL string `json:"url"`
  	// Key is sent alongside the URL; DoDownloadRequest fills it from
  	// system_setting.WorkerValidKey (presumably the worker's auth key — confirm).
  	Key string `json:"key"`
  	// Method is presumably the HTTP method the worker should use — confirm
  	// against the worker implementation.
  	Method string `json:"method,omitempty"`
  	// Headers is presumably the set of headers the worker should forward —
  	// confirm against the worker implementation.
  	Headers map[string]string `json:"headers,omitempty"`
  	// Body is raw, pre-encoded JSON that is embedded verbatim in the payload.
  	Body json.RawMessage `json:"body,omitempty"`
  }
  19. // DoWorkerRequest 通过Worker发送请求
  20. func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
  21. if !system_setting.EnableWorker() {
  22. return nil, fmt.Errorf("worker not enabled")
  23. }
  24. if !system_setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
  25. return nil, fmt.Errorf("only support https url")
  26. }
  27. // SSRF防护:验证请求URL
  28. fetchSetting := system_setting.GetFetchSetting()
  29. if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
  30. return nil, fmt.Errorf("request reject: %v", err)
  31. }
  32. workerUrl := system_setting.WorkerUrl
  33. if !strings.HasSuffix(workerUrl, "/") {
  34. workerUrl += "/"
  35. }
  36. // 序列化worker请求数据
  37. workerPayload, err := json.Marshal(req)
  38. if err != nil {
  39. return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
  40. }
  41. return GetHttpClient().Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
  42. }
  43. func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
  44. if system_setting.EnableWorker() {
  45. common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
  46. req := &WorkerRequest{
  47. URL: originUrl,
  48. Key: system_setting.WorkerValidKey,
  49. }
  50. return DoWorkerRequest(req)
  51. } else {
  52. // SSRF防护:验证请求URL(非Worker模式)
  53. fetchSetting := system_setting.GetFetchSetting()
  54. if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
  55. return nil, fmt.Errorf("request reject: %v", err)
  56. }
  57. common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", ")))
  58. return GetHttpClient().Get(originUrl)
  59. }
  60. }