1
0

oauth2.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. // Copyright (C) 2019-2023 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
// Package smtp provides support for sending emails.
  15. package smtp
  16. import (
  17. "context"
  18. "errors"
  19. "fmt"
  20. "sync"
  21. "time"
  22. "golang.org/x/oauth2"
  23. "golang.org/x/oauth2/google"
  24. "golang.org/x/oauth2/microsoft"
  25. "github.com/drakkan/sftpgo/v2/internal/logger"
  26. "github.com/drakkan/sftpgo/v2/internal/util"
  27. )
  // Supported OAuth2 providers. These numeric values are what the "provider"
// configuration field holds, so existing values must never be renumbered.
const (
	// OAuth2ProviderGoogle selects Google as the OAuth2 provider.
	OAuth2ProviderGoogle = 0
	// OAuth2ProviderMicrosoft selects Microsoft (Azure AD) as the OAuth2 provider.
	OAuth2ProviderMicrosoft = 1
)

// supportedOAuth2Providers lists the provider values accepted by Validate.
var supportedOAuth2Providers = []int{
	OAuth2ProviderGoogle,
	OAuth2ProviderMicrosoft,
}
  34. // OAuth2Config defines OAuth2 settings
  35. type OAuth2Config struct {
  36. Provider int `json:"provider" mapstructure:"provider"`
  37. // Tenant for Microsoft provider, if empty "common" is used
  38. Tenant string `json:"tenant" mapstructure:"tenant"`
  39. // ClientID is the application's ID
  40. ClientID string `json:"client_id" mapstructure:"client_id"`
  41. // ClientSecret is the application's secret
  42. ClientSecret string `json:"client_secret" mapstructure:"client_secret"`
  43. // Token to use to get/renew access tokens
  44. RefreshToken string `json:"refresh_token" mapstructure:"refresh_token"`
  45. mu *sync.RWMutex
  46. config *oauth2.Config
  47. accessToken *oauth2.Token
  48. }
  49. // Validate validates and initializes the configuration
  50. func (c *OAuth2Config) Validate() error {
  51. if !util.Contains(supportedOAuth2Providers, c.Provider) {
  52. return fmt.Errorf("smtp oauth2: unsupported provider %d", c.Provider)
  53. }
  54. if c.ClientID == "" {
  55. return errors.New("smtp oauth2: client id is required")
  56. }
  57. if c.ClientSecret == "" {
  58. return errors.New("smtp oauth2: client secret is required")
  59. }
  60. if c.RefreshToken == "" {
  61. return errors.New("smtp oauth2: refresh token is required")
  62. }
  63. c.initialize()
  64. return nil
  65. }
  66. func (c *OAuth2Config) isEqual(other *OAuth2Config) bool {
  67. if c.Provider != other.Provider {
  68. return false
  69. }
  70. if c.Tenant != other.Tenant {
  71. return false
  72. }
  73. if c.ClientID != other.ClientID {
  74. return false
  75. }
  76. if c.ClientSecret != other.ClientSecret {
  77. return false
  78. }
  79. if c.RefreshToken != other.RefreshToken {
  80. return false
  81. }
  82. return true
  83. }
  84. func (c *OAuth2Config) getAccessToken() (string, error) {
  85. c.mu.RLock()
  86. if c.accessToken.Expiry.After(time.Now().Add(30 * time.Second)) {
  87. accessToken := c.accessToken.AccessToken
  88. c.mu.RUnlock()
  89. return accessToken, nil
  90. }
  91. logger.Debug(logSender, "", "renew oauth2 token required, current token expires at %s", c.accessToken.Expiry)
  92. token := new(oauth2.Token)
  93. *token = *c.accessToken
  94. c.mu.RUnlock()
  95. ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
  96. defer cancel()
  97. newToken, err := c.config.TokenSource(ctx, token).Token()
  98. if err != nil {
  99. logger.Error(logSender, "", "unable to get new token: %v", err)
  100. return "", err
  101. }
  102. accessToken := newToken.AccessToken
  103. refreshToken := newToken.RefreshToken
  104. if refreshToken != "" && refreshToken != token.RefreshToken {
  105. c.mu.Lock()
  106. c.RefreshToken = refreshToken
  107. c.accessToken = newToken
  108. c.mu.Unlock()
  109. logger.Debug(logSender, "", "oauth2 refresh token changed")
  110. go updateRefreshToken(refreshToken)
  111. }
  112. if accessToken != token.AccessToken {
  113. c.mu.Lock()
  114. c.accessToken = newToken
  115. c.mu.Unlock()
  116. logger.Debug(logSender, "", "new oauth2 token saved, expires at %s", c.accessToken.Expiry)
  117. }
  118. return accessToken, nil
  119. }
  120. func (c *OAuth2Config) initialize() {
  121. c.mu = new(sync.RWMutex)
  122. c.config = c.GetOAuth2()
  123. c.accessToken = &oauth2.Token{
  124. TokenType: "Bearer",
  125. RefreshToken: c.RefreshToken,
  126. }
  127. }
  128. // GetOAuth2 returns the oauth2 configuration for the provided parameters.
  129. func (c *OAuth2Config) GetOAuth2() *oauth2.Config {
  130. var endpoint oauth2.Endpoint
  131. var scopes []string
  132. switch c.Provider {
  133. case OAuth2ProviderMicrosoft:
  134. endpoint = microsoft.AzureADEndpoint(c.Tenant)
  135. scopes = []string{"offline_access", "https://outlook.office.com/SMTP.Send"}
  136. default:
  137. endpoint = google.Endpoint
  138. scopes = []string{"https://mail.google.com/"}
  139. }
  140. return &oauth2.Config{
  141. ClientID: c.ClientID,
  142. ClientSecret: c.ClientSecret,
  143. Scopes: scopes,
  144. Endpoint: endpoint,
  145. }
  146. }