middleware.ts 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import { Provider } from "../provider/provider"
  2. import { NamedError } from "@opencode-ai/util/error"
  3. import { NotFoundError } from "../storage/db"
  4. import { Session } from "../session"
  5. import type { ContentfulStatusCode } from "hono/utils/http-status"
  6. import type { ErrorHandler, MiddlewareHandler } from "hono"
  7. import { HTTPException } from "hono/http-exception"
  8. import { Log } from "../util/log"
  9. import { Flag } from "@/flag/flag"
  10. import { basicAuth } from "hono/basic-auth"
  11. import { cors } from "hono/cors"
  12. import { compress } from "hono/compress"
  const log = Log.create({ service: "server" })

  /**
   * Pick the HTTP status for a NamedError response:
   * 404 for NotFoundError; 400 for Provider.ModelNotFoundError,
   * "ProviderAuthValidationFailed" and any error whose name starts with
   * "Worktree"; 500 for every other NamedError.
   */
  function namedErrorStatus(err: NamedError): ContentfulStatusCode {
    if (err instanceof NotFoundError) return 404
    const isBadRequest =
      err instanceof Provider.ModelNotFoundError ||
      err.name === "ProviderAuthValidationFailed" ||
      err.name.startsWith("Worktree")
    return isBadRequest ? 400 : 500
  }

  /**
   * Server-wide error handler.
   *
   * Every error is logged first. Then, in order:
   * - NamedError      -> its own `toObject()` payload with a status from `namedErrorStatus`
   * - Session.BusyError -> wrapped as NamedError.Unknown, status 400
   * - HTTPException   -> the exception's own prepared response
   * - anything else   -> wrapped as NamedError.Unknown, status 500
   */
  export const ErrorMiddleware: ErrorHandler = (err, c) => {
    log.error("failed", { error: err })

    if (err instanceof NamedError) {
      return c.json(err.toObject(), { status: namedErrorStatus(err) })
    }

    if (err instanceof Session.BusyError) {
      const busy = new NamedError.Unknown({ message: err.message }).toObject()
      return c.json(busy, { status: 400 })
    }

    if (err instanceof HTTPException) return err.getResponse()

    // Prefer the stack trace over the plain string form when one exists.
    // NOTE(review): this text is returned to the client in the response body.
    const detail = err instanceof Error && err.stack ? err.stack : err.toString()
    return c.json(new NamedError.Unknown({ message: detail }).toObject(), { status: 500 })
  }
  /**
   * Optional HTTP Basic auth guard.
   *
   * - OPTIONS requests always pass, so CORS preflights (which browsers send
   *   before requests carrying an Authorization header) are not rejected.
   * - With no `Flag.OPENCODE_SERVER_PASSWORD` configured, every request passes.
   * - Otherwise credentials are checked with Hono's basicAuth. The username
   *   defaults to "opencode" unless `Flag.OPENCODE_SERVER_USERNAME` is set.
   * - A non-empty `auth_token` query parameter is copied into the request's
   *   `authorization` header as `Basic <token>`, replacing any existing value.
   *   Presumably this serves clients that cannot set headers; confirm with callers.
   */
  export const AuthMiddleware: MiddlewareHandler = (c, next) => {
    const isPreflight = c.req.method === "OPTIONS"
    if (isPreflight) return next()

    const password = Flag.OPENCODE_SERVER_PASSWORD
    if (!password) return next()

    const username = Flag.OPENCODE_SERVER_USERNAME ?? "opencode"
    const queryToken = c.req.query("auth_token")
    if (queryToken) {
      c.req.raw.headers.set("authorization", `Basic ${queryToken}`)
    }

    const guard = basicAuth({ username, password })
    return guard(c, next)
  }
  /**
   * Request logger.
   *
   * For every path except `/log`, this does two things:
   * - writes a "request" info entry with the method and path;
   * - wraps the downstream handlers in a `log.time("request", …)` timer.
   *
   * `/log` requests are passed straight through with no log entry and no timer.
   */
  export const LoggerMiddleware: MiddlewareHandler = async (c, next) => {
    // Bail out before creating the timer. Previously a timer was created for
    // /log as well and simply never stopped.
    if (c.req.path === "/log") return next()

    const method = c.req.method
    const path = c.req.path
    log.info("request", { method, path })
    const timer = log.time("request", { method, path })
    try {
      await next()
    } finally {
      // Stop the timer even if a downstream handler throws.
      timer.stop()
    }
  }
  /** Origin prefixes accepted for local HTTP development servers (any port). */
  const LOOPBACK_ORIGIN_PREFIXES = ["http://localhost:", "http://127.0.0.1:"]
  /** Exact origins accepted for the Tauri webview. */
  const TAURI_ORIGINS = new Set(["tauri://localhost", "http://tauri.localhost", "https://tauri.localhost"])
  /** opencode.ai itself or any of its subdomains, over HTTPS only. */
  const OPENCODE_ORIGIN = /^https:\/\/([a-z0-9-]+\.)*opencode\.ai$/

  /**
   * CORS middleware with an origin allowlist.
   *
   * The request origin is echoed back, with a preflight cache of 86 400 s,
   * when it matches any of these:
   * - a loopback prefix;
   * - a Tauri origin;
   * - opencode.ai or one of its subdomains;
   * - an entry in `opts.cors`.
   *
   * Any other (or missing) origin yields `undefined`.
   *
   * @param opts.cors Extra exact origins to allow. The array is consulted on
   *   every request.
   */
  export function CorsMiddleware(opts?: { cors?: string[] }): MiddlewareHandler {
    return cors({
      maxAge: 86_400,
      origin(input) {
        if (!input) return undefined
        const allowed =
          LOOPBACK_ORIGIN_PREFIXES.some((prefix) => input.startsWith(prefix)) ||
          TAURI_ORIGINS.has(input) ||
          OPENCODE_ORIGIN.test(input) ||
          opts?.cors?.includes(input) === true
        return allowed ? input : undefined
      },
    })
  }
  const zipped = compress()

  /** Paths that bypass compression for every HTTP method. */
  const UNCOMPRESSED_PATHS = new Set(["/event", "/global/event", "/global/sync-event"])
  /** Session message/prompt_async routes that bypass compression when POSTed. */
  const SESSION_POST_ROUTE = /\/session\/[^/]+\/(message|prompt_async)$/

  /**
   * Applies Hono's compress middleware, except on the event endpoints and on
   * POSTs to `/session/:id/message` or `/session/:id/prompt_async`.
   * Those requests go straight to `next()`. Presumably these responses are
   * streamed, and compression would interfere with incremental delivery.
   * TODO confirm.
   */
  export const CompressionMiddleware: MiddlewareHandler = (c, next) => {
    const path = c.req.path
    const method = c.req.method
    const bypass = UNCOMPRESSED_PATHS.has(path) || (method === "POST" && SESSION_POST_ROUTE.test(path))
    return bypass ? next() : zipped(c, next)
  }