# replicate.py

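"""Replicate a captured streaming chat request against an upstream AWS
event-stream endpoint.

The request URL, headers, and body come from a logged template
(templates/streaming_request.json). The conversation history, model id, and
bearer token are injected before sending; the binary event-stream response is
decoded by the claude_parser module when available, or by the fallback parser
defined below.
"""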

import asyncio
import importlib.util
import json
import os
import struct
import uuid
import weakref
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

import httpx

def _load_claude_parser():
    """Dynamically load the claude_parser module from this file's directory."""
    base_dir = Path(__file__).resolve().parent
    spec = importlib.util.spec_from_file_location(
        "v2_claude_parser", str(base_dir / "claude_parser.py")
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


try:
    _parser = _load_claude_parser()
    EventStreamParser = _parser.EventStreamParser
    extract_event_info = _parser.extract_event_info
except Exception as e:
    print(f"Warning: Failed to load claude_parser: {e}")
    EventStreamParser = None
    extract_event_info = None

class StreamTracker:
    """Records whether the wrapped stream ever yielded non-empty content."""

    def __init__(self):
        self.has_content = False

    async def track(self, gen: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
        async for item in gen:
            if item:
                self.has_content = True
            yield item

def _get_proxies() -> Optional[Dict[str, str]]:
    proxy = os.getenv("HTTP_PROXY", "").strip()
    if proxy:
        return {"http": proxy, "https": proxy}
    return None

BASE_DIR = Path(__file__).resolve().parent
TEMPLATE_PATH = BASE_DIR / "templates" / "streaming_request.json"


def load_template() -> Tuple[str, Dict[str, str], Dict[str, Any]]:
    data = json.loads(TEMPLATE_PATH.read_text(encoding="utf-8"))
    url, headers, body = data
    assert isinstance(url, str) and isinstance(headers, dict) and isinstance(body, dict)
    return url, headers, body
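
# The template is a JSON array captured from a logged request: [url, headers,
# body]. A minimal sketch of the expected shape (only the body fields this
# module touches are meaningful; everything else is illustrative):
#
#   [
#     "https://<upstream-endpoint>/...",
#     {"Content-Type": "application/json", "...": "..."},
#     {"conversationState": {"currentMessage": {"userInputMessage":
#        {"content": "...", "modelId": "..."}}}}
#   ]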

def _merge_headers(as_log: Dict[str, str], bearer_token: str) -> Dict[str, str]:
    headers = dict(as_log)
    # Drop hop-by-hop headers; httpx computes these itself
    for k in list(headers.keys()):
        if k.lower() in ("content-length", "host", "connection", "transfer-encoding"):
            headers.pop(k, None)

    def set_header(name: str, value: str):
        # Replace any existing value case-insensitively
        for key in list(headers.keys()):
            if key.lower() == name.lower():
                del headers[key]
        headers[name] = value

    set_header("Authorization", f"Bearer {bearer_token}")
    set_header("amz-sdk-invocation-id", str(uuid.uuid4()))
    return headers
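
# Example (sketch): hop-by-hop headers from the log are dropped, and the
# auth/invocation-id headers are replaced case-insensitively:
#
#   _merge_headers({"Host": "example.com", "Accept": "*/*"}, "tok")
#   # -> {"Accept": "*/*", "Authorization": "Bearer tok",
#   #     "amz-sdk-invocation-id": "<fresh uuid4>"}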

def _parse_event_headers(raw: bytes) -> Dict[str, object]:
    """Parse AWS event-stream message headers: (name-len, name, type, value) triples."""
    headers: Dict[str, object] = {}
    i = 0
    n = len(raw)
    while i < n:
        name_len = raw[i]
        i += 1
        if i + name_len + 1 > n:
            break
        name = raw[i : i + name_len].decode("utf-8", errors="ignore")
        i += name_len
        htype = raw[i]
        i += 1
        if htype == 0:  # boolean true
            val = True
        elif htype == 1:  # boolean false
            val = False
        elif htype == 2:  # byte
            if i + 1 > n:
                break
            val = raw[i]
            i += 1
        elif htype == 3:  # int16
            if i + 2 > n:
                break
            val = int.from_bytes(raw[i : i + 2], "big", signed=True)
            i += 2
        elif htype == 4:  # int32
            if i + 4 > n:
                break
            val = int.from_bytes(raw[i : i + 4], "big", signed=True)
            i += 4
        elif htype == 5:  # int64
            if i + 8 > n:
                break
            val = int.from_bytes(raw[i : i + 8], "big", signed=True)
            i += 8
        elif htype == 6:  # byte array with 2-byte length prefix
            if i + 2 > n:
                break
            length = int.from_bytes(raw[i : i + 2], "big")
            i += 2
            if i + length > n:
                break
            val = raw[i : i + length]
            i += length
        elif htype == 7:  # UTF-8 string with 2-byte length prefix
            if i + 2 > n:
                break
            length = int.from_bytes(raw[i : i + 2], "big")
            i += 2
            if i + length > n:
                break
            val = raw[i : i + length].decode("utf-8", errors="ignore")
            i += length
        elif htype == 8:  # timestamp (milliseconds since epoch)
            if i + 8 > n:
                break
            val = int.from_bytes(raw[i : i + 8], "big", signed=False)
            i += 8
        elif htype == 9:  # UUID (16 bytes)
            if i + 16 > n:
                break
            val = str(uuid.UUID(bytes=bytes(raw[i : i + 16])))
            i += 16
        else:  # unknown header type: stop parsing this block
            break
        headers[name] = val
    return headers

class AwsEventStreamParser:
    """Incremental parser for the AWS event-stream binary framing.

    Each message is: 4-byte total length, 4-byte headers length, 4-byte
    prelude CRC, headers, payload, and a trailing 4-byte message CRC.
    """

    def __init__(self):
        self._buf = bytearray()

    def feed(self, data: bytes) -> List[Tuple[Dict[str, object], bytes]]:
        if not data:
            return []
        self._buf.extend(data)
        out: List[Tuple[Dict[str, object], bytes]] = []
        while True:
            if len(self._buf) < 12:
                break
            total_len, headers_len, _prelude_crc = struct.unpack(">I I I", self._buf[:12])
            if total_len < 16 or headers_len > total_len:
                # Implausible prelude: drop one byte and try to resynchronize
                self._buf.pop(0)
                continue
            if len(self._buf) < total_len:
                break  # wait for the rest of this message
            msg = bytes(self._buf[:total_len])
            del self._buf[:total_len]
            headers_raw = msg[12 : 12 + headers_len]
            payload = msg[12 + headers_len : total_len - 4]  # strip message CRC
            headers = _parse_event_headers(headers_raw)
            out.append((headers, payload))
        return out
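
# Usage sketch: feed() buffers partial frames across network chunks, so it can
# be called with arbitrary byte slices and only emits complete events:
#
#   parser = AwsEventStreamParser()
#   for chunk in network_chunks:  # e.g. chunks from resp.aiter_bytes()
#       for headers, payload in parser.feed(chunk):
#           print(headers.get(":event-type"), payload)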

def _try_decode_event_payload(payload: bytes) -> Optional[dict]:
    try:
        return json.loads(payload.decode("utf-8"))
    except Exception:
        return None


def _extract_text_from_event(ev: dict) -> Optional[str]:
    """Pull assistant text out of the various event payload shapes."""
    for key in ("assistantResponseEvent", "assistantMessage", "message", "delta", "data"):
        if key in ev and isinstance(ev[key], dict):
            inner = ev[key]
            if isinstance(inner.get("content"), str) and inner.get("content"):
                return inner["content"]
    if isinstance(ev.get("content"), str) and ev.get("content"):
        return ev["content"]
    for list_key in ("chunks", "content"):
        if isinstance(ev.get(list_key), list):
            buf = []
            for item in ev[list_key]:
                if isinstance(item, dict):
                    if isinstance(item.get("content"), str):
                        buf.append(item["content"])
                    elif isinstance(item.get("text"), str):
                        buf.append(item["text"])
                elif isinstance(item, str):
                    buf.append(item)
            if buf:
                return "".join(buf)
    for k in ("text", "delta", "payload"):
        v = ev.get(k)
        if isinstance(v, str) and v:
            return v
    return None
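
# Examples of the payload shapes the extractor understands:
#
#   _extract_text_from_event({"assistantResponseEvent": {"content": "Hi"}})
#   # -> "Hi"
#   _extract_text_from_event({"chunks": [{"text": "a"}, {"text": "b"}]})
#   # -> "ab"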

def openai_messages_to_text(messages: List[Dict[str, Any]]) -> str:
    lines: List[str] = []
    for m in messages:
        role = m.get("role", "user")
        content = m.get("content", "")
        if isinstance(content, list):
            parts = []
            for seg in content:
                if isinstance(seg, dict) and isinstance(seg.get("text"), str):
                    parts.append(seg["text"])
                elif isinstance(seg, str):
                    parts.append(seg)
            content = "\n".join(parts)
        elif not isinstance(content, str):
            content = str(content)
        lines.append(f"{role}:\n{content}")
    return "\n\n".join(lines)
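
# Example: the history is flattened to a simple role-prefixed transcript:
#
#   openai_messages_to_text([
#       {"role": "system", "content": "Be brief."},
#       {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
#   ])
#   # -> "system:\nBe brief.\n\nuser:\nHi"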

def inject_history(body_json: Dict[str, Any], history_text: str) -> None:
    """Replace the template's placeholder prompt with the flattened history.

    The literal below ("你好,你必须讲个故事", i.e. "hello, you must tell a
    story") is the placeholder content captured in streaming_request.json.
    """
    try:
        cur = body_json["conversationState"]["currentMessage"]["userInputMessage"]
        content = cur.get("content", "")
        if isinstance(content, str):
            cur["content"] = content.replace("你好,你必须讲个故事", history_text)
    except Exception:
        pass


def inject_model(body_json: Dict[str, Any], model: Optional[str]) -> None:
    if not model:
        return
    try:
        body_json["conversationState"]["currentMessage"]["userInputMessage"]["modelId"] = model
    except Exception:
        pass

async def send_chat_request(
    access_token: str,
    messages: List[Dict[str, Any]],
    model: Optional[str] = None,
    stream: bool = False,
    timeout: Tuple[int, int] = (15, 300),
    client: Optional[httpx.AsyncClient] = None,
    raw_payload: Optional[Dict[str, Any]] = None,
) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], StreamTracker, Optional[AsyncGenerator[Any, None]]]:
    """Send one chat request.

    Returns (text, text_stream, tracker, event_stream); which of the three
    result slots is populated depends on the stream and raw_payload flags.
    """
    url, headers_from_log, body_json = load_template()
    headers_from_log["amz-sdk-invocation-id"] = str(uuid.uuid4())
    if raw_payload:
        # Use the raw payload as-is if provided (Claude API path)
        body_json = raw_payload
        # Ensure conversationId is set if missing
        if "conversationState" in body_json and "conversationId" not in body_json["conversationState"]:
            body_json["conversationState"]["conversationId"] = str(uuid.uuid4())
    else:
        # Standard OpenAI-compatible path
        try:
            body_json["conversationState"]["conversationId"] = str(uuid.uuid4())
        except Exception:
            pass
        history_text = openai_messages_to_text(messages)
        inject_history(body_json, history_text)
        inject_model(body_json, model)
    payload_str = json.dumps(body_json, ensure_ascii=False)
    headers = _merge_headers(headers_from_log, access_token)

    local_client = False
    if client is None:
        local_client = True
        proxies = _get_proxies()
        mounts = None
        if proxies:
            proxy_url = proxies.get("https") or proxies.get("http")
            if proxy_url:
                mounts = {
                    "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
                    "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
                }
        client = httpx.AsyncClient(mounts=mounts, timeout=httpx.Timeout(timeout[0], read=timeout[1]))

    # Build and send the request manually so we control the stream's lifetime
    req = client.build_request("POST", url, headers=headers, content=payload_str)
    resp = None
    try:
        resp = await client.send(req, stream=True)
        if resp.status_code >= 400:
            try:
                await resp.aread()
                err = resp.text
            except Exception:
                err = f"HTTP {resp.status_code}"
            await resp.aclose()
            if local_client:
                await client.aclose()
            raise httpx.HTTPError(f"Upstream error {resp.status_code}: {err}")

        parser = AwsEventStreamParser()
        tracker = StreamTracker()
        # Track whether the response has been consumed to avoid a double close
        response_consumed = False

        async def _iter_events() -> AsyncGenerator[Any, None]:
            nonlocal response_consumed
            try:
                if EventStreamParser and extract_event_info:
                    # Preferred path: the dedicated claude_parser implementation
                    async def byte_gen():
                        async for chunk in resp.aiter_bytes():
                            if chunk:
                                yield chunk

                    async for message in EventStreamParser.parse_stream(byte_gen()):
                        event_info = extract_event_info(message)
                        if event_info:
                            event_type = event_info.get("event_type")
                            payload = event_info.get("payload")
                            if event_type and payload:
                                yield (event_type, payload)
                else:
                    # Fallback: the built-in AwsEventStreamParser
                    async for chunk in resp.aiter_bytes():
                        if not chunk:
                            continue
                        for ev_headers, payload in parser.feed(chunk):
                            parsed = _try_decode_event_payload(payload)
                            if parsed is not None:
                                yield (ev_headers.get(":event-type"), parsed)
            except GeneratorExit:
                # Client disconnected; clean up in finally without re-raising
                pass
            except Exception:
                # Only surface errors that occurred before any content arrived
                if not tracker.has_content:
                    raise
            finally:
                response_consumed = True
                if resp and not resp.is_closed:
                    await resp.aclose()
                if local_client and client:
                    await client.aclose()

        async def _iter_text() -> AsyncGenerator[str, None]:
            async for _event_type, parsed in _iter_events():
                text = _extract_text_from_event(parsed)
                if isinstance(text, str) and text:
                    yield text

        def _schedule_cleanup():
            """Finalizer: close the response (and a locally created client) if
            a generator is garbage-collected without being consumed.

            - With a running loop in this thread: spawn tasks for aclose().
            - Otherwise: best-effort close via asyncio.run(), then sync close().
            """

            def _best_effort_close():
                try:
                    # Prefer an async close on a temporary event loop
                    asyncio.run(resp.aclose())
                    if local_client and client:
                        asyncio.run(client.aclose())
                except Exception:
                    # Last resort: synchronous close, if available
                    if hasattr(resp, "close"):
                        try:
                            resp.close()  # type: ignore[attr-defined]
                        except Exception:
                            pass
                    if local_client and client and hasattr(client, "close"):
                        try:
                            client.close()  # type: ignore[attr-defined]
                        except Exception:
                            pass

            try:
                if not resp or getattr(resp, "is_closed", True):
                    return
                try:
                    loop = asyncio.get_running_loop()
                except RuntimeError:
                    loop = None
                if loop is not None and loop.is_running():
                    asyncio.create_task(resp.aclose())
                    if local_client and client:
                        asyncio.create_task(client.aclose())
                else:
                    _best_effort_close()
            except Exception:
                # Never raise from a finalizer
                pass

        if stream:
            # With raw_payload, the caller gets the raw event stream
            if raw_payload:
                # Return the event stream with a finalizer as a safety net
                event_gen = _iter_events()
                weakref.finalize(event_gen, _schedule_cleanup)
                return None, None, tracker, event_gen
            # Otherwise return the tracked text stream, also with a finalizer
            text_gen = tracker.track(_iter_text())
            weakref.finalize(text_gen, _schedule_cleanup)
            return None, text_gen, tracker, None
        else:
            buf = []
            try:
                async for t in tracker.track(_iter_text()):
                    buf.append(t)
            finally:
                # Ensure the response is closed even if iteration stopped early
                if not response_consumed and resp:
                    await resp.aclose()
                    if local_client:
                        await client.aclose()
            return "".join(buf), None, tracker, None
    except Exception:
        # Close the response on any failure before generators are handed out
        if resp and not resp.is_closed:
            await resp.aclose()
        if local_client and client:
            await client.aclose()
        raise
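
# Usage sketch (the token and message below are placeholders, not real values):
#
#   async def main():
#       _text, text_gen, tracker, _events = await send_chat_request(
#           access_token="YOUR_TOKEN",
#           messages=[{"role": "user", "content": "Hello"}],
#           stream=True,
#       )
#       async for piece in text_gen:
#           print(piece, end="", flush=True)
#
#   asyncio.run(main())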