main.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. // Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
  2. package main
  3. import (
  4. "bufio"
  5. "context"
  6. "crypto/tls"
  7. "flag"
  8. "log"
  9. "net"
  10. "net/url"
  11. "os"
  12. "path/filepath"
  13. "time"
  14. _ "github.com/syncthing/syncthing/lib/automaxprocs"
  15. syncthingprotocol "github.com/syncthing/syncthing/lib/protocol"
  16. "github.com/syncthing/syncthing/lib/relay/client"
  17. "github.com/syncthing/syncthing/lib/relay/protocol"
  18. )
  19. func main() {
  20. ctx, cancel := context.WithCancel(context.Background())
  21. defer cancel()
  22. log.SetOutput(os.Stdout)
  23. log.SetFlags(log.LstdFlags | log.Lshortfile)
  24. var connect, relay, dir string
  25. var join, test bool
  26. flag.StringVar(&connect, "connect", "", "Device ID to which to connect to")
  27. flag.BoolVar(&join, "join", false, "Join relay")
  28. flag.BoolVar(&test, "test", false, "Generic relay test")
  29. flag.StringVar(&relay, "relay", "relay://127.0.0.1:22067", "Relay address")
  30. flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored")
  31. flag.Parse()
  32. certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem")
  33. cert, err := tls.LoadX509KeyPair(certFile, keyFile)
  34. if err != nil {
  35. log.Fatalln("Failed to load X509 key pair:", err)
  36. }
  37. id := syncthingprotocol.NewDeviceID(cert.Certificate[0])
  38. log.Println("ID:", id)
  39. uri, err := url.Parse(relay)
  40. if err != nil {
  41. log.Fatal(err)
  42. }
  43. stdin := make(chan string)
  44. go stdinReader(stdin)
  45. if join {
  46. log.Println("Creating client")
  47. relay, err := client.NewClient(uri, []tls.Certificate{cert}, 10*time.Second)
  48. if err != nil {
  49. log.Fatal(err)
  50. }
  51. log.Println("Created client")
  52. go relay.Serve(ctx)
  53. recv := make(chan protocol.SessionInvitation)
  54. go func() {
  55. log.Println("Starting invitation receiver")
  56. for invite := range relay.Invitations() {
  57. select {
  58. case recv <- invite:
  59. log.Println("Received invitation", invite)
  60. default:
  61. log.Println("Discarding invitation", invite)
  62. }
  63. }
  64. }()
  65. for {
  66. conn, err := client.JoinSession(ctx, <-recv)
  67. if err != nil {
  68. log.Fatalln("Failed to join", err)
  69. }
  70. log.Println("Joined", conn.RemoteAddr(), conn.LocalAddr())
  71. connectToStdio(stdin, conn)
  72. log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr())
  73. }
  74. } else if connect != "" {
  75. id, err := syncthingprotocol.DeviceIDFromString(connect)
  76. if err != nil {
  77. log.Fatal(err)
  78. }
  79. invite, err := client.GetInvitationFromRelay(ctx, uri, id, []tls.Certificate{cert}, 10*time.Second)
  80. if err != nil {
  81. log.Fatal(err)
  82. }
  83. log.Println("Received invitation", invite)
  84. conn, err := client.JoinSession(ctx, invite)
  85. if err != nil {
  86. log.Fatalln("Failed to join", err)
  87. }
  88. log.Println("Joined", conn.RemoteAddr(), conn.LocalAddr())
  89. connectToStdio(stdin, conn)
  90. log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr())
  91. } else if test {
  92. if err := client.TestRelay(ctx, uri, []tls.Certificate{cert}, time.Second, 2*time.Second, 4); err == nil {
  93. log.Println("OK")
  94. } else {
  95. log.Println("FAIL:", err)
  96. }
  97. } else {
  98. log.Fatal("Requires either join or connect")
  99. }
  100. }
  101. func stdinReader(c chan<- string) {
  102. scanner := bufio.NewScanner(os.Stdin)
  103. for scanner.Scan() {
  104. c <- scanner.Text()
  105. c <- "\n"
  106. }
  107. }
  108. func connectToStdio(stdin <-chan string, conn net.Conn) {
  109. buf := make([]byte, 1024)
  110. for {
  111. conn.SetReadDeadline(time.Now().Add(time.Millisecond))
  112. n, err := conn.Read(buf[0:])
  113. if err != nil {
  114. nerr, ok := err.(net.Error)
  115. if !ok || !nerr.Timeout() {
  116. log.Println(err)
  117. return
  118. }
  119. }
  120. os.Stdout.Write(buf[:n])
  121. select {
  122. case msg := <-stdin:
  123. _, err := conn.Write([]byte(msg))
  124. if err != nil {
  125. return
  126. }
  127. default:
  128. }
  129. }
  130. }