# replicate.py
  1. import json
  2. import uuid
  3. import os
  4. from pathlib import Path
  5. from typing import Dict, Optional, Tuple, Iterator, List, AsyncGenerator, Any
  6. import struct
  7. import httpx
  8. class StreamTracker:
  9. def __init__(self):
  10. self.has_content = False
  11. async def track(self, gen: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
  12. async for item in gen:
  13. if item:
  14. self.has_content = True
  15. yield item
  16. def _get_proxies() -> Optional[Dict[str, str]]:
  17. proxy = os.getenv("HTTP_PROXY", "").strip()
  18. if proxy:
  19. return {"http": proxy, "https": proxy}
  20. return None
# Directory this module lives in; used to resolve bundled data files.
BASE_DIR = Path(__file__).resolve().parent
# JSON template for the upstream streaming request — presumably captured from a
# real request log (the headers are passed around as `as_log`); TODO confirm provenance.
TEMPLATE_PATH = BASE_DIR / "templates" / "streaming_request.json"
  23. def load_template() -> Tuple[str, Dict[str, str], Dict[str, Any]]:
  24. data = json.loads(TEMPLATE_PATH.read_text(encoding="utf-8"))
  25. url, headers, body = data
  26. assert isinstance(url, str) and isinstance(headers, dict) and isinstance(body, dict)
  27. return url, headers, body
  28. def _merge_headers(as_log: Dict[str, str], bearer_token: str) -> Dict[str, str]:
  29. headers = dict(as_log)
  30. for k in list(headers.keys()):
  31. kl = k.lower()
  32. if kl in ("content-length","host","connection","transfer-encoding"):
  33. headers.pop(k, None)
  34. def set_header(name: str, value: str):
  35. for key in list(headers.keys()):
  36. if key.lower() == name.lower():
  37. del headers[key]
  38. headers[name] = value
  39. set_header("Authorization", f"Bearer {bearer_token}")
  40. set_header("amz-sdk-invocation-id", str(uuid.uuid4()))
  41. return headers
def _parse_event_headers(raw: bytes) -> Dict[str, object]:
    """Decode AWS event-stream binary message headers into a dict.

    Each header is: 1-byte name length, the name (utf-8), a 1-byte value-type
    tag, then a type-dependent value. Parsing is best-effort and never raises:
    any truncated or unknown data stops the loop and whatever was decoded so
    far is returned.
    """
    headers: Dict[str, object] = {}
    i = 0
    n = len(raw)
    while i < n:
        if i + 1 > n:
            break
        name_len = raw[i]
        i += 1
        # Need the full name plus at least the 1-byte value-type tag.
        if i + name_len + 1 > n:
            break
        name = raw[i : i + name_len].decode("utf-8", errors="ignore")
        i += name_len
        htype = raw[i]
        i += 1
        # Value-type tags below match the AWS eventstream header types
        # (bool true/false, byte, short, int, long, byte array, string,
        # timestamp, uuid) — presumed from the wire format; verify against spec.
        if htype == 0:
            val = True  # boolean true: no value bytes follow
        elif htype == 1:
            val = False  # boolean false: no value bytes follow
        elif htype == 2:
            # Single byte.
            if i + 1 > n: break
            val = raw[i]; i += 1
        elif htype == 3:
            # 16-bit signed big-endian integer.
            if i + 2 > n: break
            val = int.from_bytes(raw[i:i+2],"big",signed=True); i += 2
        elif htype == 4:
            # 32-bit signed big-endian integer.
            if i + 4 > n: break
            val = int.from_bytes(raw[i:i+4],"big",signed=True); i += 4
        elif htype == 5:
            # 64-bit signed big-endian integer.
            if i + 8 > n: break
            val = int.from_bytes(raw[i:i+8],"big",signed=True); i += 8
        elif htype == 6:
            # Length-prefixed byte array (16-bit length).
            if i + 2 > n: break
            l = int.from_bytes(raw[i:i+2],"big"); i += 2
            if i + l > n: break
            val = raw[i:i+l]; i += l
        elif htype == 7:
            # Length-prefixed UTF-8 string (16-bit length).
            if i + 2 > n: break
            l = int.from_bytes(raw[i:i+2],"big"); i += 2
            if i + l > n: break
            val = raw[i:i+l].decode("utf-8", errors="ignore"); i += l
        elif htype == 8:
            # 64-bit unsigned value (timestamp per the eventstream format —
            # presumably epoch milliseconds; TODO confirm).
            if i + 8 > n: break
            val = int.from_bytes(raw[i:i+8],"big",signed=False); i += 8
        elif htype == 9:
            # 16-byte UUID, rendered in canonical string form.
            if i + 16 > n: break
            import uuid as _uuid
            val = str(_uuid.UUID(bytes=bytes(raw[i:i+16]))); i += 16
        else:
            # Unknown type tag: cannot determine value length, stop parsing.
            break
        headers[name] = val
    return headers
  94. class AwsEventStreamParser:
  95. def __init__(self):
  96. self._buf = bytearray()
  97. def feed(self, data: bytes) -> List[Tuple[Dict[str, object], bytes]]:
  98. if not data:
  99. return []
  100. self._buf.extend(data)
  101. out: List[Tuple[Dict[str, object], bytes]] = []
  102. while True:
  103. if len(self._buf) < 12:
  104. break
  105. total_len, headers_len, _prelude_crc = struct.unpack(">I I I", self._buf[:12])
  106. if total_len < 16 or headers_len > total_len:
  107. self._buf.pop(0)
  108. continue
  109. if len(self._buf) < total_len:
  110. break
  111. msg = bytes(self._buf[:total_len])
  112. del self._buf[:total_len]
  113. headers_raw = msg[12:12+headers_len]
  114. payload = msg[12+headers_len: total_len-4]
  115. headers = _parse_event_headers(headers_raw)
  116. out.append((headers, payload))
  117. return out
  118. def _try_decode_event_payload(payload: bytes) -> Optional[dict]:
  119. try:
  120. return json.loads(payload.decode("utf-8"))
  121. except Exception:
  122. return None
  123. def _extract_text_from_event(ev: dict) -> Optional[str]:
  124. for key in ("assistantResponseEvent","assistantMessage","message","delta","data"):
  125. if key in ev and isinstance(ev[key], dict):
  126. inner = ev[key]
  127. if isinstance(inner.get("content"), str) and inner.get("content"):
  128. return inner["content"]
  129. if isinstance(ev.get("content"), str) and ev.get("content"):
  130. return ev["content"]
  131. for list_key in ("chunks","content"):
  132. if isinstance(ev.get(list_key), list):
  133. buf = []
  134. for item in ev[list_key]:
  135. if isinstance(item, dict):
  136. if isinstance(item.get("content"), str):
  137. buf.append(item["content"])
  138. elif isinstance(item.get("text"), str):
  139. buf.append(item["text"])
  140. elif isinstance(item, str):
  141. buf.append(item)
  142. if buf:
  143. return "".join(buf)
  144. for k in ("text","delta","payload"):
  145. v = ev.get(k)
  146. if isinstance(v, str) and v:
  147. return v
  148. return None
  149. def openai_messages_to_text(messages: List[Dict[str, Any]]) -> str:
  150. lines: List[str] = []
  151. for m in messages:
  152. role = m.get("role","user")
  153. content = m.get("content","")
  154. if isinstance(content, list):
  155. parts = []
  156. for seg in content:
  157. if isinstance(seg, dict) and isinstance(seg.get("text"), str):
  158. parts.append(seg["text"])
  159. elif isinstance(seg, str):
  160. parts.append(seg)
  161. content = "\n".join(parts)
  162. elif not isinstance(content, str):
  163. content = str(content)
  164. lines.append(f"{role}:\n{content}")
  165. return "\n\n".join(lines)
  166. def inject_history(body_json: Dict[str, Any], history_text: str) -> None:
  167. try:
  168. cur = body_json["conversationState"]["currentMessage"]["userInputMessage"]
  169. content = cur.get("content","")
  170. if isinstance(content, str):
  171. cur["content"] = content.replace("你好,你必须讲个故事", history_text)
  172. except Exception:
  173. pass
  174. def inject_model(body_json: Dict[str, Any], model: Optional[str]) -> None:
  175. if not model:
  176. return
  177. try:
  178. body_json["conversationState"]["currentMessage"]["userInputMessage"]["modelId"] = model
  179. except Exception:
  180. pass
  181. async def send_chat_request(access_token: str, messages: List[Dict[str, Any]], model: Optional[str] = None, stream: bool = False, timeout: Tuple[int,int] = (15,300)) -> Tuple[Optional[str], Optional[AsyncGenerator[str, None]], StreamTracker]:
  182. url, headers_from_log, body_json = load_template()
  183. headers_from_log["amz-sdk-invocation-id"] = str(uuid.uuid4())
  184. try:
  185. body_json["conversationState"]["conversationId"] = str(uuid.uuid4())
  186. except Exception:
  187. pass
  188. history_text = openai_messages_to_text(messages)
  189. inject_history(body_json, history_text)
  190. inject_model(body_json, model)
  191. payload_str = json.dumps(body_json, ensure_ascii=False)
  192. headers = _merge_headers(headers_from_log, access_token)
  193. # Build mounts with proxy if available
  194. proxies = _get_proxies()
  195. mounts = None
  196. if proxies:
  197. proxy_url = proxies.get("https") or proxies.get("http")
  198. if proxy_url:
  199. mounts = {
  200. "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  201. "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  202. }
  203. async with httpx.AsyncClient(mounts=mounts, timeout=httpx.Timeout(timeout[0], read=timeout[1])) as client:
  204. async with client.stream("POST", url, headers=headers, content=payload_str) as resp:
  205. if resp.status_code >= 400:
  206. try:
  207. err = await resp.aread()
  208. err = err.decode("utf-8", errors="ignore")
  209. except Exception:
  210. err = f"HTTP {resp.status_code}"
  211. raise httpx.HTTPError(f"Upstream error {resp.status_code}: {err}")
  212. parser = AwsEventStreamParser()
  213. tracker = StreamTracker()
  214. async def _iter_text() -> AsyncGenerator[str, None]:
  215. async for chunk in resp.aiter_bytes():
  216. if not chunk:
  217. continue
  218. events = parser.feed(chunk)
  219. for _ev_headers, payload in events:
  220. parsed = _try_decode_event_payload(payload)
  221. if parsed is not None:
  222. text = _extract_text_from_event(parsed)
  223. if isinstance(text, str) and text:
  224. yield text
  225. else:
  226. try:
  227. txt = payload.decode("utf-8", errors="ignore")
  228. if txt:
  229. yield txt
  230. except Exception:
  231. pass
  232. if stream:
  233. return None, tracker.track(_iter_text()), tracker
  234. else:
  235. buf = []
  236. async for t in tracker.track(_iter_text()):
  237. buf.append(t)
  238. return "".join(buf), None, tracker