auth_flow.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import json
  2. import time
  3. import uuid
  4. from typing import Dict, Tuple, Optional
  5. import requests
  6. # OIDC endpoints and constants (aligned with v1/auth_client.py)
  7. OIDC_BASE = "https://oidc.us-east-1.amazonaws.com"
  8. REGISTER_URL = f"{OIDC_BASE}/client/register"
  9. DEVICE_AUTH_URL = f"{OIDC_BASE}/device_authorization"
  10. TOKEN_URL = f"{OIDC_BASE}/token"
  11. START_URL = "https://view.awsapps.com/start"
  12. USER_AGENT = "aws-sdk-rust/1.3.9 os/windows lang/rust/1.87.0"
  13. X_AMZ_USER_AGENT = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/windows lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
  14. AMZ_SDK_REQUEST = "attempt=1; max=3"
  15. def make_headers() -> Dict[str, str]:
  16. return {
  17. "content-type": "application/json",
  18. "user-agent": USER_AGENT,
  19. "x-amz-user-agent": X_AMZ_USER_AGENT,
  20. "amz-sdk-request": AMZ_SDK_REQUEST,
  21. "amz-sdk-invocation-id": str(uuid.uuid4()),
  22. }
  23. def post_json(url: str, payload: Dict) -> requests.Response:
  24. # Keep JSON order and mimic body closely to v1
  25. payload_str = json.dumps(payload, ensure_ascii=False)
  26. headers = make_headers()
  27. resp = requests.post(url, headers=headers, data=payload_str, timeout=(15, 60))
  28. return resp
  29. def register_client_min() -> Tuple[str, str]:
  30. """
  31. Register an OIDC client (minimal) and return (clientId, clientSecret).
  32. """
  33. payload = {
  34. "clientName": "Amazon Q Developer for command line",
  35. "clientType": "public",
  36. "scopes": [
  37. "codewhisperer:completions",
  38. "codewhisperer:analysis",
  39. "codewhisperer:conversations",
  40. ],
  41. }
  42. r = post_json(REGISTER_URL, payload)
  43. r.raise_for_status()
  44. data = r.json()
  45. return data["clientId"], data["clientSecret"]
  46. def device_authorize(client_id: str, client_secret: str) -> Dict:
  47. """
  48. Start device authorization. Returns dict that includes:
  49. - deviceCode
  50. - interval
  51. - expiresIn
  52. - verificationUriComplete
  53. - userCode
  54. """
  55. payload = {
  56. "clientId": client_id,
  57. "clientSecret": client_secret,
  58. "startUrl": START_URL,
  59. }
  60. r = post_json(DEVICE_AUTH_URL, payload)
  61. r.raise_for_status()
  62. return r.json()
  63. def poll_token_device_code(
  64. client_id: str,
  65. client_secret: str,
  66. device_code: str,
  67. interval: int,
  68. expires_in: int,
  69. max_timeout_sec: Optional[int] = 300,
  70. ) -> Dict:
  71. """
  72. Poll token with device_code until approved or timeout.
  73. - Respects upstream expires_in, but caps total time by max_timeout_sec (default 5 minutes).
  74. Returns token dict with at least 'accessToken' and optionally 'refreshToken'.
  75. Raises:
  76. - TimeoutError on timeout
  77. - requests.HTTPError for non-recoverable HTTP errors
  78. """
  79. payload = {
  80. "clientId": client_id,
  81. "clientSecret": client_secret,
  82. "deviceCode": device_code,
  83. "grantType": "urn:ietf:params:oauth:grant-type:device_code",
  84. }
  85. now = time.time()
  86. upstream_deadline = now + max(1, int(expires_in))
  87. cap_deadline = now + max_timeout_sec if (max_timeout_sec and max_timeout_sec > 0) else upstream_deadline
  88. deadline = min(upstream_deadline, cap_deadline)
  89. # Ensure interval sane
  90. poll_interval = max(1, int(interval or 1))
  91. while time.time() < deadline:
  92. r = post_json(TOKEN_URL, payload)
  93. if r.status_code == 200:
  94. return r.json()
  95. if r.status_code == 400:
  96. # Expect AuthorizationPendingException early on
  97. try:
  98. err = r.json()
  99. except Exception:
  100. err = {"error": r.text}
  101. if str(err.get("error")) == "authorization_pending":
  102. time.sleep(poll_interval)
  103. continue
  104. # Other 4xx are errors
  105. r.raise_for_status()
  106. # Non-200, non-400
  107. r.raise_for_status()
  108. raise TimeoutError("Device authorization expired before approval (timeout reached)")