auth_flow.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import json
  2. import time
  3. import uuid
  4. import os
  5. from typing import Dict, Tuple, Optional
  6. import requests
  7. def _get_proxies() -> Optional[Dict[str, str]]:
  8. proxy = os.getenv("HTTP_PROXY", "").strip()
  9. if proxy:
  10. return {"http": proxy, "https": proxy}
  11. return None
  12. # OIDC endpoints and constants (aligned with v1/auth_client.py)
  13. OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
  14. REGISTER_URL = f"{OIDC_BASE}/client/register"
  15. DEVICE_AUTH_URL = f"{OIDC_BASE}/device_authorization"
  16. TOKEN_URL = f"{OIDC_BASE}/token"
  17. START_URL = "https://view.awsapps.com/start"
  18. USER_AGENT = "aws-sdk-rust/1.3.9 os/windows lang/rust/1.87.0"
  19. X_AMZ_USER_AGENT = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/windows lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
  20. AMZ_SDK_REQUEST = "attempt=1; max=3"
  21. def make_headers() -> Dict[str, str]:
  22. return {
  23. "content-type": "application/json",
  24. "user-agent": USER_AGENT,
  25. "x-amz-user-agent": X_AMZ_USER_AGENT,
  26. "amz-sdk-request": AMZ_SDK_REQUEST,
  27. "amz-sdk-invocation-id": str(uuid.uuid4()),
  28. }
  29. def post_json(url: str, payload: Dict) -> requests.Response:
  30. # Keep JSON order and mimic body closely to v1
  31. payload_str = json.dumps(payload, ensure_ascii=False)
  32. headers = make_headers()
  33. resp = requests.post(url, headers=headers, data=payload_str, timeout=(15, 60), proxies=_get_proxies())
  34. return resp
  35. def register_client_min() -> Tuple[str, str]:
  36. """
  37. Register an OIDC client (minimal) and return (clientId, clientSecret).
  38. """
  39. payload = {
  40. "clientName": "Amazon Q Developer for command line",
  41. "clientType": "public",
  42. "scopes": [
  43. "codewhisperer:completions",
  44. "codewhisperer:analysis",
  45. "codewhisperer:conversations",
  46. ],
  47. }
  48. r = post_json(REGISTER_URL, payload)
  49. r.raise_for_status()
  50. data = r.json()
  51. return data["clientId"], data["clientSecret"]
  52. def device_authorize(client_id: str, client_secret: str) -> Dict:
  53. """
  54. Start device authorization. Returns dict that includes:
  55. - deviceCode
  56. - interval
  57. - expiresIn
  58. - verificationUriComplete
  59. - userCode
  60. """
  61. payload = {
  62. "clientId": client_id,
  63. "clientSecret": client_secret,
  64. "startUrl": START_URL,
  65. }
  66. r = post_json(DEVICE_AUTH_URL, payload)
  67. r.raise_for_status()
  68. return r.json()
  69. def poll_token_device_code(
  70. client_id: str,
  71. client_secret: str,
  72. device_code: str,
  73. interval: int,
  74. expires_in: int,
  75. max_timeout_sec: Optional[int] = 300,
  76. ) -> Dict:
  77. """
  78. Poll token with device_code until approved or timeout.
  79. - Respects upstream expires_in, but caps total time by max_timeout_sec (default 5 minutes).
  80. Returns token dict with at least 'accessToken' and optionally 'refreshToken'.
  81. Raises:
  82. - TimeoutError on timeout
  83. - requests.HTTPError for non-recoverable HTTP errors
  84. """
  85. payload = {
  86. "clientId": client_id,
  87. "clientSecret": client_secret,
  88. "deviceCode": device_code,
  89. "grantType": "urn:ietf:params:oauth:grant-type:device_code",
  90. }
  91. now = time.time()
  92. upstream_deadline = now + max(1, int(expires_in))
  93. cap_deadline = now + max_timeout_sec if (max_timeout_sec and max_timeout_sec > 0) else upstream_deadline
  94. deadline = min(upstream_deadline, cap_deadline)
  95. # Ensure interval sane
  96. poll_interval = max(1, int(interval or 1))
  97. while time.time() < deadline:
  98. r = post_json(TOKEN_URL, payload)
  99. if r.status_code == 200:
  100. return r.json()
  101. if r.status_code == 400:
  102. # Expect AuthorizationPendingException early on
  103. try:
  104. err = r.json()
  105. except Exception:
  106. err = {"error": r.text}
  107. if str(err.get("error")) == "authorization_pending":
  108. time.sleep(poll_interval)
  109. continue
  110. # Other 4xx are errors
  111. r.raise_for_status()
  112. # Non-200, non-400
  113. r.raise_for_status()
  114. raise TimeoutError("Device authorization expired before approval (timeout reached)")