device.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. // Package hyper provides functions to handle Hyper device flow authentication.
  2. package hyper
  3. import (
  4. "bytes"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "net/http"
  11. "os"
  12. "strings"
  13. "time"
  14. "github.com/charmbracelet/crush/internal/agent/hyper"
  15. "github.com/charmbracelet/crush/internal/event"
  16. "github.com/charmbracelet/crush/internal/oauth"
  17. )
// DeviceAuthResponse contains the response from the device authorization endpoint.
//
// It is the decoded JSON body of a successful InitiateDeviceAuth call.
type DeviceAuthResponse struct {
	// DeviceCode identifies this authorization attempt; presumably it is the
	// value to pass to PollForToken as deviceCode (confirm with callers).
	DeviceCode string `json:"device_code"`

	// UserCode is presumably the short code the user enters at
	// VerificationURL — confirm against the Hyper API.
	UserCode string `json:"user_code"`

	// VerificationURL is assumed (from its name) to be the page where the
	// user approves this device — verify.
	VerificationURL string `json:"verification_url"`

	// ExpiresIn is the lifetime reported by the server; presumably passed to
	// PollForToken as expiresIn, which interprets it as seconds.
	ExpiresIn int `json:"expires_in"`
}
// TokenResponse contains the response from the polling endpoint.
//
// pollOnce decodes each poll's JSON body into this type. PollForToken treats a
// non-empty RefreshToken as success and otherwise switches on Error.
type TokenResponse struct {
	// RefreshToken is what PollForToken returns once it is non-empty.
	RefreshToken string `json:"refresh_token,omitempty"`

	// UserID is passed to event.Alias when polling succeeds.
	UserID string `json:"user_id"`

	// OrganizationID and OrganizationName are decoded but not read by
	// PollForToken.
	OrganizationID   string `json:"organization_id"`
	OrganizationName string `json:"organization_name"`

	// Error is the server's error code; PollForToken keeps polling while it
	// is "authorization_pending".
	Error string `json:"error,omitempty"`

	// ErrorDescription is the human-readable message PollForToken uses as its
	// error text when polling fails.
	ErrorDescription string `json:"error_description,omitempty"`
}
  34. // InitiateDeviceAuth calls the /device/auth endpoint to start the device flow.
  35. func InitiateDeviceAuth(ctx context.Context) (*DeviceAuthResponse, error) {
  36. url := hyper.BaseURL() + "/device/auth"
  37. req, err := http.NewRequestWithContext(
  38. ctx, http.MethodPost, url,
  39. strings.NewReader(fmt.Sprintf(`{"device_name":%q}`, deviceName())),
  40. )
  41. if err != nil {
  42. return nil, fmt.Errorf("create request: %w", err)
  43. }
  44. req.Header.Set("Content-Type", "application/json")
  45. req.Header.Set("User-Agent", "crush")
  46. client := &http.Client{Timeout: 30 * time.Second}
  47. resp, err := client.Do(req)
  48. if err != nil {
  49. return nil, fmt.Errorf("execute request: %w", err)
  50. }
  51. defer resp.Body.Close()
  52. body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
  53. if err != nil {
  54. return nil, fmt.Errorf("read response: %w", err)
  55. }
  56. if resp.StatusCode != http.StatusOK {
  57. return nil, fmt.Errorf("device auth failed: status %d, body %q", resp.StatusCode, string(body))
  58. }
  59. var authResp DeviceAuthResponse
  60. if err := json.Unmarshal(body, &authResp); err != nil {
  61. return nil, fmt.Errorf("unmarshal response: %w", err)
  62. }
  63. return &authResp, nil
  64. }
  65. func deviceName() string {
  66. if hostname, err := os.Hostname(); err == nil && hostname != "" {
  67. return "Crush (" + hostname + ")"
  68. }
  69. return "Crush"
  70. }
  71. // PollForToken polls the /device/token endpoint until authorization is complete.
  72. // It respects the polling interval and handles various error states.
  73. func PollForToken(ctx context.Context, deviceCode string, expiresIn int) (string, error) {
  74. ctx, cancel := context.WithTimeout(ctx, time.Duration(expiresIn)*time.Second)
  75. defer cancel()
  76. d := 5 * time.Second
  77. ticker := time.NewTicker(d)
  78. defer ticker.Stop()
  79. for {
  80. select {
  81. case <-ctx.Done():
  82. return "", ctx.Err()
  83. case <-ticker.C:
  84. result, err := pollOnce(ctx, deviceCode)
  85. if err != nil {
  86. return "", err
  87. }
  88. if result.RefreshToken != "" {
  89. event.Alias(result.UserID)
  90. return result.RefreshToken, nil
  91. }
  92. switch result.Error {
  93. case "authorization_pending":
  94. continue
  95. default:
  96. return "", errors.New(result.ErrorDescription)
  97. }
  98. }
  99. }
  100. }
  101. func pollOnce(ctx context.Context, deviceCode string) (TokenResponse, error) {
  102. var result TokenResponse
  103. url := fmt.Sprintf("%s/device/auth/%s", hyper.BaseURL(), deviceCode)
  104. req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
  105. if err != nil {
  106. return result, fmt.Errorf("create request: %w", err)
  107. }
  108. req.Header.Set("Content-Type", "application/json")
  109. req.Header.Set("User-Agent", "crush")
  110. client := &http.Client{Timeout: 30 * time.Second}
  111. resp, err := client.Do(req)
  112. if err != nil {
  113. return result, fmt.Errorf("execute request: %w", err)
  114. }
  115. defer resp.Body.Close()
  116. body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
  117. if err != nil {
  118. return result, fmt.Errorf("read response: %w", err)
  119. }
  120. if err := json.Unmarshal(body, &result); err != nil {
  121. return result, fmt.Errorf("unmarshal response: %w: %s", err, string(body))
  122. }
  123. if resp.StatusCode != http.StatusOK {
  124. return result, fmt.Errorf("token request failed: status %d body %q", resp.StatusCode, string(body))
  125. }
  126. return result, nil
  127. }
  128. // ExchangeToken exchanges a refresh token for an access token.
  129. func ExchangeToken(ctx context.Context, refreshToken string) (*oauth.Token, error) {
  130. reqBody := map[string]string{
  131. "refresh_token": refreshToken,
  132. }
  133. data, err := json.Marshal(reqBody)
  134. if err != nil {
  135. return nil, fmt.Errorf("marshal request: %w", err)
  136. }
  137. url := hyper.BaseURL() + "/token/exchange"
  138. req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
  139. if err != nil {
  140. return nil, fmt.Errorf("create request: %w", err)
  141. }
  142. req.Header.Set("Content-Type", "application/json")
  143. req.Header.Set("User-Agent", "crush")
  144. client := &http.Client{Timeout: 30 * time.Second}
  145. resp, err := client.Do(req)
  146. if err != nil {
  147. return nil, fmt.Errorf("execute request: %w", err)
  148. }
  149. defer resp.Body.Close()
  150. body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
  151. if err != nil {
  152. return nil, fmt.Errorf("read response: %w", err)
  153. }
  154. if resp.StatusCode != http.StatusOK {
  155. return nil, fmt.Errorf("token exchange failed: status %d body %q", resp.StatusCode, string(body))
  156. }
  157. var token oauth.Token
  158. if err := json.Unmarshal(body, &token); err != nil {
  159. return nil, fmt.Errorf("unmarshal response: %w", err)
  160. }
  161. token.SetExpiresAt()
  162. return &token, nil
  163. }
// IntrospectTokenResponse contains the response from the token introspection endpoint.
//
// Field names follow OAuth2 Token Introspection (RFC 7662 §2.2), which
// IntrospectToken implements. The meanings noted below are the RFC's; this
// code only decodes them and does not validate them.
type IntrospectTokenResponse struct {
	// Active reports whether the token is currently active (RFC 7662 "active").
	Active bool `json:"active"`

	// Sub is the subject of the token (RFC 7662 "sub").
	Sub string `json:"sub,omitempty"`

	// OrgID is not an RFC 7662 field; presumably the Hyper organization the
	// token belongs to — confirm against the server.
	OrgID string `json:"org_id,omitempty"`

	// Exp is the expiry time (RFC 7662 "exp", seconds since the Unix epoch).
	Exp int64 `json:"exp,omitempty"`

	// Iat is the issue time (RFC 7662 "iat", seconds since the Unix epoch).
	Iat int64 `json:"iat,omitempty"`

	// Iss is the issuer of the token (RFC 7662 "iss").
	Iss string `json:"iss,omitempty"`

	// Jti is the unique identifier of the token (RFC 7662 "jti").
	Jti string `json:"jti,omitempty"`
}
  174. // IntrospectToken validates an access token using the introspection endpoint.
  175. // Implements OAuth2 Token Introspection (RFC 7662).
  176. func IntrospectToken(ctx context.Context, accessToken string) (*IntrospectTokenResponse, error) {
  177. reqBody := map[string]string{
  178. "token": accessToken,
  179. }
  180. data, err := json.Marshal(reqBody)
  181. if err != nil {
  182. return nil, fmt.Errorf("marshal request: %w", err)
  183. }
  184. url := hyper.BaseURL() + "/token/introspect"
  185. req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
  186. if err != nil {
  187. return nil, fmt.Errorf("create request: %w", err)
  188. }
  189. req.Header.Set("Content-Type", "application/json")
  190. req.Header.Set("User-Agent", "crush")
  191. client := &http.Client{Timeout: 30 * time.Second}
  192. resp, err := client.Do(req)
  193. if err != nil {
  194. return nil, fmt.Errorf("execute request: %w", err)
  195. }
  196. defer resp.Body.Close()
  197. body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
  198. if err != nil {
  199. return nil, fmt.Errorf("read response: %w", err)
  200. }
  201. if resp.StatusCode != http.StatusOK {
  202. return nil, fmt.Errorf("token introspection failed: status %d body %q", resp.StatusCode, string(body))
  203. }
  204. var result IntrospectTokenResponse
  205. if err := json.Unmarshal(body, &result); err != nil {
  206. return nil, fmt.Errorf("unmarshal response: %w", err)
  207. }
  208. return &result, nil
  209. }