body_storage.go 7.0 KB


  1. package common
  import (
      "bytes"
      "errors"
      "fmt"
      "io"
      "math"
      "os"
      "sync"
      "sync/atomic"
      "time"
  )
  // BodyStorage holds a buffered request body that can be read, rewound and
  // released. In this file it is implemented by memoryStorage (a byte slice)
  // and diskStorage (a temporary cache file).
  type BodyStorage interface {
      // Read and Seek provide streaming, rewindable access to the body.
      io.Reader
      io.Seeker
      // Close releases the underlying buffer or file.
      io.Closer

      // Bytes returns the whole body content.
      Bytes() ([]byte, error)
      // Size reports the body length in bytes.
      Size() int64
      // IsDisk reports whether the body is backed by disk storage.
      IsDisk() bool
  }
  // ErrStorageClosed is returned by Read, Seek and Bytes once the storage
  // has been closed. Compare against it with errors.Is.
  var ErrStorageClosed = errors.New("body storage is closed")
  // memoryStorage is the in-memory BodyStorage: the body is kept as a byte
  // slice and read through a bytes.Reader over that same slice.
  type memoryStorage struct {
      mu     sync.Mutex    // locked by Read, Seek, Bytes and Close
      closed int32         // 0 while open, 1 after Close; accessed via sync/atomic
      data   []byte        // the body itself; Bytes returns it without copying
      reader *bytes.Reader // read cursor over data, used by Read and Seek
      size   int64         // len(data), recorded at construction
  }
  32. func newMemoryStorage(data []byte) *memoryStorage {
  33. size := int64(len(data))
  34. IncrementMemoryBuffers(size)
  35. return &memoryStorage{
  36. data: data,
  37. reader: bytes.NewReader(data),
  38. size: size,
  39. }
  40. }
  41. func (m *memoryStorage) Read(p []byte) (n int, err error) {
  42. m.mu.Lock()
  43. defer m.mu.Unlock()
  44. if atomic.LoadInt32(&m.closed) == 1 {
  45. return 0, ErrStorageClosed
  46. }
  47. return m.reader.Read(p)
  48. }
  49. func (m *memoryStorage) Seek(offset int64, whence int) (int64, error) {
  50. m.mu.Lock()
  51. defer m.mu.Unlock()
  52. if atomic.LoadInt32(&m.closed) == 1 {
  53. return 0, ErrStorageClosed
  54. }
  55. return m.reader.Seek(offset, whence)
  56. }
  57. func (m *memoryStorage) Close() error {
  58. m.mu.Lock()
  59. defer m.mu.Unlock()
  60. if atomic.CompareAndSwapInt32(&m.closed, 0, 1) {
  61. DecrementMemoryBuffers(m.size)
  62. }
  63. return nil
  64. }
  65. func (m *memoryStorage) Bytes() ([]byte, error) {
  66. m.mu.Lock()
  67. defer m.mu.Unlock()
  68. if atomic.LoadInt32(&m.closed) == 1 {
  69. return nil, ErrStorageClosed
  70. }
  71. return m.data, nil
  72. }
  73. func (m *memoryStorage) Size() int64 {
  74. return m.size
  75. }
  76. func (m *memoryStorage) IsDisk() bool {
  77. return false
  78. }
  // diskStorage is the disk-backed BodyStorage: the body lives in a cache
  // file that is closed and deleted when the storage is closed.
  type diskStorage struct {
      mu       sync.Mutex // locked by Read, Seek, Bytes and Close
      closed   int32      // 0 while open, 1 after Close; accessed via sync/atomic
      file     *os.File   // open handle used for reading the body
      filePath string     // path of the cache file, removed on Close
      size     int64      // number of bytes written when the file was created
  }
  87. func newDiskStorage(data []byte, cachePath string) (*diskStorage, error) {
  88. // 使用统一的缓存目录管理
  89. filePath, file, err := CreateDiskCacheFile(DiskCacheTypeBody)
  90. if err != nil {
  91. return nil, err
  92. }
  93. // 写入数据
  94. n, err := file.Write(data)
  95. if err != nil {
  96. file.Close()
  97. os.Remove(filePath)
  98. return nil, fmt.Errorf("failed to write to temp file: %w", err)
  99. }
  100. // 重置文件指针
  101. if _, err := file.Seek(0, io.SeekStart); err != nil {
  102. file.Close()
  103. os.Remove(filePath)
  104. return nil, fmt.Errorf("failed to seek temp file: %w", err)
  105. }
  106. size := int64(n)
  107. IncrementDiskFiles(size)
  108. return &diskStorage{
  109. file: file,
  110. filePath: filePath,
  111. size: size,
  112. }, nil
  113. }
  114. func newDiskStorageFromReader(reader io.Reader, maxBytes int64, cachePath string) (*diskStorage, error) {
  115. // 使用统一的缓存目录管理
  116. filePath, file, err := CreateDiskCacheFile(DiskCacheTypeBody)
  117. if err != nil {
  118. return nil, err
  119. }
  120. // 从 reader 读取并写入文件
  121. written, err := io.Copy(file, io.LimitReader(reader, maxBytes+1))
  122. if err != nil {
  123. file.Close()
  124. os.Remove(filePath)
  125. return nil, fmt.Errorf("failed to write to temp file: %w", err)
  126. }
  127. if written > maxBytes {
  128. file.Close()
  129. os.Remove(filePath)
  130. return nil, ErrRequestBodyTooLarge
  131. }
  132. // 重置文件指针
  133. if _, err := file.Seek(0, io.SeekStart); err != nil {
  134. file.Close()
  135. os.Remove(filePath)
  136. return nil, fmt.Errorf("failed to seek temp file: %w", err)
  137. }
  138. IncrementDiskFiles(written)
  139. return &diskStorage{
  140. file: file,
  141. filePath: filePath,
  142. size: written,
  143. }, nil
  144. }
  145. func (d *diskStorage) Read(p []byte) (n int, err error) {
  146. d.mu.Lock()
  147. defer d.mu.Unlock()
  148. if atomic.LoadInt32(&d.closed) == 1 {
  149. return 0, ErrStorageClosed
  150. }
  151. return d.file.Read(p)
  152. }
  153. func (d *diskStorage) Seek(offset int64, whence int) (int64, error) {
  154. d.mu.Lock()
  155. defer d.mu.Unlock()
  156. if atomic.LoadInt32(&d.closed) == 1 {
  157. return 0, ErrStorageClosed
  158. }
  159. return d.file.Seek(offset, whence)
  160. }
  161. func (d *diskStorage) Close() error {
  162. d.mu.Lock()
  163. defer d.mu.Unlock()
  164. if atomic.CompareAndSwapInt32(&d.closed, 0, 1) {
  165. d.file.Close()
  166. os.Remove(d.filePath)
  167. DecrementDiskFiles(d.size)
  168. }
  169. return nil
  170. }
  171. func (d *diskStorage) Bytes() ([]byte, error) {
  172. d.mu.Lock()
  173. defer d.mu.Unlock()
  174. if atomic.LoadInt32(&d.closed) == 1 {
  175. return nil, ErrStorageClosed
  176. }
  177. // 保存当前位置
  178. currentPos, err := d.file.Seek(0, io.SeekCurrent)
  179. if err != nil {
  180. return nil, err
  181. }
  182. // 移动到开头
  183. if _, err := d.file.Seek(0, io.SeekStart); err != nil {
  184. return nil, err
  185. }
  186. // 读取全部内容
  187. data := make([]byte, d.size)
  188. _, err = io.ReadFull(d.file, data)
  189. if err != nil {
  190. return nil, err
  191. }
  192. // 恢复位置
  193. if _, err := d.file.Seek(currentPos, io.SeekStart); err != nil {
  194. return nil, err
  195. }
  196. return data, nil
  197. }
  198. func (d *diskStorage) Size() int64 {
  199. return d.size
  200. }
  201. func (d *diskStorage) IsDisk() bool {
  202. return true
  203. }
  204. // CreateBodyStorage 根据数据大小创建合适的存储
  205. func CreateBodyStorage(data []byte) (BodyStorage, error) {
  206. size := int64(len(data))
  207. threshold := GetDiskCacheThresholdBytes()
  208. // 检查是否应该使用磁盘缓存
  209. if IsDiskCacheEnabled() &&
  210. size >= threshold &&
  211. IsDiskCacheAvailable(size) {
  212. storage, err := newDiskStorage(data, GetDiskCachePath())
  213. if err != nil {
  214. // 如果磁盘存储失败,回退到内存存储
  215. SysError(fmt.Sprintf("failed to create disk storage, falling back to memory: %v", err))
  216. return newMemoryStorage(data), nil
  217. }
  218. return storage, nil
  219. }
  220. return newMemoryStorage(data), nil
  221. }
  222. // CreateBodyStorageFromReader 从 Reader 创建存储(用于大请求的流式处理)
  223. func CreateBodyStorageFromReader(reader io.Reader, contentLength int64, maxBytes int64) (BodyStorage, error) {
  224. threshold := GetDiskCacheThresholdBytes()
  225. // 如果启用了磁盘缓存且内容长度超过阈值,直接使用磁盘存储
  226. if IsDiskCacheEnabled() &&
  227. contentLength > 0 &&
  228. contentLength >= threshold &&
  229. IsDiskCacheAvailable(contentLength) {
  230. storage, err := newDiskStorageFromReader(reader, maxBytes, GetDiskCachePath())
  231. if err != nil {
  232. if IsRequestBodyTooLargeError(err) {
  233. return nil, err
  234. }
  235. // 磁盘存储失败,reader 已被消费,无法安全回退
  236. // 直接返回错误而非尝试回退(因为 reader 数据已丢失)
  237. return nil, fmt.Errorf("disk storage creation failed: %w", err)
  238. }
  239. IncrementDiskCacheHits()
  240. return storage, nil
  241. }
  242. // 使用内存读取
  243. data, err := io.ReadAll(io.LimitReader(reader, maxBytes+1))
  244. if err != nil {
  245. return nil, err
  246. }
  247. if int64(len(data)) > maxBytes {
  248. return nil, ErrRequestBodyTooLarge
  249. }
  250. storage, err := CreateBodyStorage(data)
  251. if err != nil {
  252. return nil, err
  253. }
  254. // 如果最终使用内存存储,记录内存缓存命中
  255. if !storage.IsDisk() {
  256. IncrementMemoryCacheHits()
  257. } else {
  258. IncrementDiskCacheHits()
  259. }
  260. return storage, nil
  261. }
  // ReaderOnly returns a view of r that exposes only Read. Because Close is
  // hidden, http.NewRequest cannot type-assert io.ReadCloser and close the
  // underlying BodyStorage.
  //
  // NOTE: the concrete reader type is hidden as well. http.NewRequest infers
  // ContentLength only for a few concrete types, so set ContentLength
  // explicitly on the request if it is needed.
  func ReaderOnly(r io.Reader) io.Reader {
      wrapped := struct{ io.Reader }{Reader: r}
      return wrapped
  }
  267. // CleanupOldCacheFiles 清理旧的缓存文件(用于启动时清理残留)
  268. func CleanupOldCacheFiles() {
  269. // 使用统一的缓存管理
  270. CleanupOldDiskCacheFiles(5 * time.Minute)
  271. }