auth_flow.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import json
  2. import time
  3. import uuid
  4. import os
  5. import asyncio
  6. from typing import Dict, Tuple, Optional
  7. import httpx
  8. def _get_proxies() -> Optional[Dict[str, str]]:
  9. proxy = os.getenv("HTTP_PROXY", "").strip()
  10. if proxy:
  11. return {"http": proxy, "https": proxy}
  12. return None
  13. # OIDC endpoints and constants (aligned with v1/auth_client.py)
  14. OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
  15. REGISTER_URL = f"{OIDC_BASE}/client/register"
  16. DEVICE_AUTH_URL = f"{OIDC_BASE}/device_authorization"
  17. TOKEN_URL = f"{OIDC_BASE}/token"
  18. START_URL = "https://view.awsapps.com/start"
  19. USER_AGENT = "aws-sdk-rust/1.3.9 os/windows lang/rust/1.87.0"
  20. X_AMZ_USER_AGENT = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/windows lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
  21. AMZ_SDK_REQUEST = "attempt=1; max=3"
  22. def make_headers() -> Dict[str, str]:
  23. return {
  24. "content-type": "application/json",
  25. "user-agent": USER_AGENT,
  26. "x-amz-user-agent": X_AMZ_USER_AGENT,
  27. "amz-sdk-request": AMZ_SDK_REQUEST,
  28. "amz-sdk-invocation-id": str(uuid.uuid4()),
  29. }
  30. async def post_json(client: httpx.AsyncClient, url: str, payload: Dict) -> httpx.Response:
  31. # Keep JSON order and mimic body closely to v1
  32. payload_str = json.dumps(payload, ensure_ascii=False)
  33. headers = make_headers()
  34. resp = await client.post(url, headers=headers, content=payload_str, timeout=httpx.Timeout(15.0, read=60.0))
  35. return resp
  36. async def register_client_min() -> Tuple[str, str]:
  37. """
  38. Register an OIDC client (minimal) and return (clientId, clientSecret).
  39. """
  40. payload = {
  41. "clientName": "Amazon Q Developer for command line",
  42. "clientType": "public",
  43. "scopes": [
  44. "codewhisperer:completions",
  45. "codewhisperer:analysis",
  46. "codewhisperer:conversations",
  47. ],
  48. }
  49. proxies = _get_proxies()
  50. mounts = None
  51. if proxies:
  52. proxy_url = proxies.get("https") or proxies.get("http")
  53. if proxy_url:
  54. mounts = {
  55. "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  56. "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  57. }
  58. async with httpx.AsyncClient(mounts=mounts) as client:
  59. r = await post_json(client, REGISTER_URL, payload)
  60. r.raise_for_status()
  61. data = r.json()
  62. return data["clientId"], data["clientSecret"]
  63. async def device_authorize(client_id: str, client_secret: str) -> Dict:
  64. """
  65. Start device authorization. Returns dict that includes:
  66. - deviceCode
  67. - interval
  68. - expiresIn
  69. - verificationUriComplete
  70. - userCode
  71. """
  72. payload = {
  73. "clientId": client_id,
  74. "clientSecret": client_secret,
  75. "startUrl": START_URL,
  76. }
  77. proxies = _get_proxies()
  78. mounts = None
  79. if proxies:
  80. proxy_url = proxies.get("https") or proxies.get("http")
  81. if proxy_url:
  82. mounts = {
  83. "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  84. "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  85. }
  86. async with httpx.AsyncClient(mounts=mounts) as client:
  87. r = await post_json(client, DEVICE_AUTH_URL, payload)
  88. r.raise_for_status()
  89. return r.json()
  90. async def poll_token_device_code(
  91. client_id: str,
  92. client_secret: str,
  93. device_code: str,
  94. interval: int,
  95. expires_in: int,
  96. max_timeout_sec: Optional[int] = 300,
  97. ) -> Dict:
  98. """
  99. Poll token with device_code until approved or timeout.
  100. - Respects upstream expires_in, but caps total time by max_timeout_sec (default 5 minutes).
  101. Returns token dict with at least 'accessToken' and optionally 'refreshToken'.
  102. Raises:
  103. - TimeoutError on timeout
  104. - httpx.HTTPError for non-recoverable HTTP errors
  105. """
  106. payload = {
  107. "clientId": client_id,
  108. "clientSecret": client_secret,
  109. "deviceCode": device_code,
  110. "grantType": "urn:ietf:params:oauth:grant-type:device_code",
  111. }
  112. now = time.time()
  113. upstream_deadline = now + max(1, int(expires_in))
  114. cap_deadline = now + max_timeout_sec if (max_timeout_sec and max_timeout_sec > 0) else upstream_deadline
  115. deadline = min(upstream_deadline, cap_deadline)
  116. # Ensure interval sane
  117. poll_interval = max(1, int(interval or 1))
  118. proxies = _get_proxies()
  119. mounts = None
  120. if proxies:
  121. proxy_url = proxies.get("https") or proxies.get("http")
  122. if proxy_url:
  123. mounts = {
  124. "https://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  125. "http://": httpx.AsyncHTTPTransport(proxy=proxy_url),
  126. }
  127. async with httpx.AsyncClient(mounts=mounts) as client:
  128. while time.time() < deadline:
  129. r = await post_json(client, TOKEN_URL, payload)
  130. if r.status_code == 200:
  131. return r.json()
  132. if r.status_code == 400:
  133. # Expect AuthorizationPendingException early on
  134. try:
  135. err = r.json()
  136. except Exception:
  137. err = {"error": r.text}
  138. if str(err.get("error")) == "authorization_pending":
  139. await asyncio.sleep(poll_interval)
  140. continue
  141. # Other 4xx are errors
  142. r.raise_for_status()
  143. # Non-200, non-400
  144. r.raise_for_status()
  145. raise TimeoutError("Device authorization expired before approval (timeout reached)")