flight2handler.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. package dtls
  2. import (
  3. "bytes"
  4. "context"
  5. "fmt"
  6. "github.com/pion/dtls/v2/pkg/protocol"
  7. "github.com/pion/dtls/v2/pkg/protocol/alert"
  8. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  9. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  10. )
  11. func flight2Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
  12. seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
  13. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  14. )
  15. if !ok {
  16. // Client may retransmit the first ClientHello when HelloVerifyRequest is dropped.
  17. // Parse as flight 0 in this case.
  18. return flight0Parse(ctx, c, state, cache, cfg)
  19. }
  20. state.handshakeRecvSequence = seq
  21. var clientHello *handshake.MessageClientHello
  22. // Validate type
  23. if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok {
  24. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
  25. }
  26. if !clientHello.Version.Equal(protocol.Version1_2) {
  27. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
  28. }
  29. if len(clientHello.Cookie) == 0 {
  30. return 0, nil, nil
  31. }
  32. if !bytes.Equal(state.cookie, clientHello.Cookie) {
  33. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.AccessDenied}, errCookieMismatch
  34. }
  35. // TODO 添加 CiscoCompat 支持
  36. if cfg.localCiscoCompatCallback != nil {
  37. var err error
  38. state.SessionID = clientHello.SessionID
  39. if len(state.SessionID) == 0 {
  40. err = fmt.Errorf("clientHello SessionID is nil")
  41. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
  42. }
  43. state.masterSecret, err = cfg.localCiscoCompatCallback(state.SessionID)
  44. if err != nil {
  45. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
  46. }
  47. }
  48. return flight4, nil, nil
  49. }
  50. func flight2Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
  51. state.handshakeSendSequence = 0
  52. return []*packet{
  53. {
  54. record: &recordlayer.RecordLayer{
  55. Header: recordlayer.Header{
  56. Version: protocol.Version1_2,
  57. },
  58. Content: &handshake.Handshake{
  59. Message: &handshake.MessageHelloVerifyRequest{
  60. Version: protocol.Version1_2,
  61. Cookie: state.cookie,
  62. },
  63. },
  64. },
  65. },
  66. }, nil, nil
  67. }