| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- package model
- import (
- "encoding/base64"
- "encoding/json"
- "errors"
- "fmt"
- "strings"
- "time"
- "github.com/QuantumNous/new-api/common"
- "github.com/go-webauthn/webauthn/protocol"
- "github.com/go-webauthn/webauthn/webauthn"
- "gorm.io/gorm"
- )
// ErrPasskeyNotFound reports that no passkey row exists for the requested
// user. GetPasskeyByUserID returns it (without logging) on a plain
// "record not found", so callers can tell "not bound" apart from a failure.
var ErrPasskeyNotFound = errors.New("passkey credential not found")

// ErrFriendlyPasskeyNotFound is the user-facing error ("Passkey verification
// failed, please retry or contact the administrator") returned for invalid
// lookup input, database errors, and unknown credential IDs.
var ErrFriendlyPasskeyNotFound = errors.New("Passkey 验证失败,请重试或联系管理员")
- type PasskeyCredential struct {
- ID int `json:"id" gorm:"primaryKey"`
- UserID int `json:"user_id" gorm:"uniqueIndex;not null"`
- CredentialID string `json:"credential_id" gorm:"type:varchar(512);uniqueIndex;not null"` // base64 encoded
- PublicKey string `json:"public_key" gorm:"type:text;not null"` // base64 encoded
- AttestationType string `json:"attestation_type" gorm:"type:varchar(255)"`
- AAGUID string `json:"aaguid" gorm:"type:varchar(512)"` // base64 encoded
- SignCount uint32 `json:"sign_count" gorm:"default:0"`
- CloneWarning bool `json:"clone_warning"`
- UserPresent bool `json:"user_present"`
- UserVerified bool `json:"user_verified"`
- BackupEligible bool `json:"backup_eligible"`
- BackupState bool `json:"backup_state"`
- Transports string `json:"transports" gorm:"type:text"`
- Attachment string `json:"attachment" gorm:"type:varchar(32)"`
- LastUsedAt *time.Time `json:"last_used_at"`
- CreatedAt time.Time `json:"created_at"`
- UpdatedAt time.Time `json:"updated_at"`
- DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
- }
- func (p *PasskeyCredential) TransportList() []protocol.AuthenticatorTransport {
- if p == nil || strings.TrimSpace(p.Transports) == "" {
- return nil
- }
- var transports []string
- if err := json.Unmarshal([]byte(p.Transports), &transports); err != nil {
- return nil
- }
- result := make([]protocol.AuthenticatorTransport, 0, len(transports))
- for _, transport := range transports {
- result = append(result, protocol.AuthenticatorTransport(transport))
- }
- return result
- }
- func (p *PasskeyCredential) SetTransports(list []protocol.AuthenticatorTransport) {
- if len(list) == 0 {
- p.Transports = ""
- return
- }
- stringList := make([]string, len(list))
- for i, transport := range list {
- stringList[i] = string(transport)
- }
- encoded, err := json.Marshal(stringList)
- if err != nil {
- return
- }
- p.Transports = string(encoded)
- }
- func (p *PasskeyCredential) ToWebAuthnCredential() webauthn.Credential {
- flags := webauthn.CredentialFlags{
- UserPresent: p.UserPresent,
- UserVerified: p.UserVerified,
- BackupEligible: p.BackupEligible,
- BackupState: p.BackupState,
- }
- credID, _ := base64.StdEncoding.DecodeString(p.CredentialID)
- pubKey, _ := base64.StdEncoding.DecodeString(p.PublicKey)
- aaguid, _ := base64.StdEncoding.DecodeString(p.AAGUID)
- return webauthn.Credential{
- ID: credID,
- PublicKey: pubKey,
- AttestationType: p.AttestationType,
- Transport: p.TransportList(),
- Flags: flags,
- Authenticator: webauthn.Authenticator{
- AAGUID: aaguid,
- SignCount: p.SignCount,
- CloneWarning: p.CloneWarning,
- Attachment: protocol.AuthenticatorAttachment(p.Attachment),
- },
- }
- }
- func NewPasskeyCredentialFromWebAuthn(userID int, credential *webauthn.Credential) *PasskeyCredential {
- if credential == nil {
- return nil
- }
- passkey := &PasskeyCredential{
- UserID: userID,
- CredentialID: base64.StdEncoding.EncodeToString(credential.ID),
- PublicKey: base64.StdEncoding.EncodeToString(credential.PublicKey),
- AttestationType: credential.AttestationType,
- AAGUID: base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID),
- SignCount: credential.Authenticator.SignCount,
- CloneWarning: credential.Authenticator.CloneWarning,
- UserPresent: credential.Flags.UserPresent,
- UserVerified: credential.Flags.UserVerified,
- BackupEligible: credential.Flags.BackupEligible,
- BackupState: credential.Flags.BackupState,
- Attachment: string(credential.Authenticator.Attachment),
- }
- passkey.SetTransports(credential.Transport)
- return passkey
- }
- func (p *PasskeyCredential) ApplyValidatedCredential(credential *webauthn.Credential) {
- if credential == nil || p == nil {
- return
- }
- p.CredentialID = base64.StdEncoding.EncodeToString(credential.ID)
- p.PublicKey = base64.StdEncoding.EncodeToString(credential.PublicKey)
- p.AttestationType = credential.AttestationType
- p.AAGUID = base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID)
- p.SignCount = credential.Authenticator.SignCount
- p.CloneWarning = credential.Authenticator.CloneWarning
- p.UserPresent = credential.Flags.UserPresent
- p.UserVerified = credential.Flags.UserVerified
- p.BackupEligible = credential.Flags.BackupEligible
- p.BackupState = credential.Flags.BackupState
- p.Attachment = string(credential.Authenticator.Attachment)
- p.SetTransports(credential.Transport)
- }
- func GetPasskeyByUserID(userID int) (*PasskeyCredential, error) {
- if userID == 0 {
- common.SysLog("GetPasskeyByUserID: empty user ID")
- return nil, ErrFriendlyPasskeyNotFound
- }
- var credential PasskeyCredential
- if err := DB.Where("user_id = ?", userID).First(&credential).Error; err != nil {
- if errors.Is(err, gorm.ErrRecordNotFound) {
- // 未找到记录是正常情况(用户未绑定),返回 ErrPasskeyNotFound 而不记录日志
- return nil, ErrPasskeyNotFound
- }
- // 只有真正的数据库错误才记录日志
- common.SysLog(fmt.Sprintf("GetPasskeyByUserID: database error for user %d: %v", userID, err))
- return nil, ErrFriendlyPasskeyNotFound
- }
- return &credential, nil
- }
- func GetPasskeyByCredentialID(credentialID []byte) (*PasskeyCredential, error) {
- if len(credentialID) == 0 {
- common.SysLog("GetPasskeyByCredentialID: empty credential ID")
- return nil, ErrFriendlyPasskeyNotFound
- }
- credIDStr := base64.StdEncoding.EncodeToString(credentialID)
- var credential PasskeyCredential
- if err := DB.Where("credential_id = ?", credIDStr).First(&credential).Error; err != nil {
- if errors.Is(err, gorm.ErrRecordNotFound) {
- common.SysLog(fmt.Sprintf("GetPasskeyByCredentialID: passkey not found for credential ID length %d", len(credentialID)))
- return nil, ErrFriendlyPasskeyNotFound
- }
- common.SysLog(fmt.Sprintf("GetPasskeyByCredentialID: database error for credential ID: %v", err))
- return nil, ErrFriendlyPasskeyNotFound
- }
- return &credential, nil
- }
- func UpsertPasskeyCredential(credential *PasskeyCredential) error {
- if credential == nil {
- common.SysLog("UpsertPasskeyCredential: nil credential provided")
- return fmt.Errorf("Passkey 保存失败,请重试")
- }
- return DB.Transaction(func(tx *gorm.DB) error {
- // 使用Unscoped()进行硬删除,避免唯一索引冲突
- if err := tx.Unscoped().Where("user_id = ?", credential.UserID).Delete(&PasskeyCredential{}).Error; err != nil {
- common.SysLog(fmt.Sprintf("UpsertPasskeyCredential: failed to delete existing credential for user %d: %v", credential.UserID, err))
- return fmt.Errorf("Passkey 保存失败,请重试")
- }
- if err := tx.Create(credential).Error; err != nil {
- common.SysLog(fmt.Sprintf("UpsertPasskeyCredential: failed to create credential for user %d: %v", credential.UserID, err))
- return fmt.Errorf("Passkey 保存失败,请重试")
- }
- return nil
- })
- }
- func DeletePasskeyByUserID(userID int) error {
- if userID == 0 {
- common.SysLog("DeletePasskeyByUserID: empty user ID")
- return fmt.Errorf("删除失败,请重试")
- }
- // 使用Unscoped()进行硬删除,避免唯一索引冲突
- if err := DB.Unscoped().Where("user_id = ?", userID).Delete(&PasskeyCredential{}).Error; err != nil {
- common.SysLog(fmt.Sprintf("DeletePasskeyByUserID: failed to delete passkey for user %d: %v", userID, err))
- return fmt.Errorf("删除失败,请重试")
- }
- return nil
- }
|