Просмотр исходного кода

tsnet: cleanup resources upon start failure (#5301)

In a partially initialized state, we should cleanup
all prior resources when an error occurs.

Signed-off-by: Joe Tsai <[email protected]>
Joe Tsai 3 лет назад
Родитель
Сommit
b1fff4499f
1 измененных файлов с 23 добавлено и 1 удалено
  1. 23 1
      tsnet/tsnet.go

+ 23 - 1
tsnet/tsnet.go

@@ -10,6 +10,7 @@ package tsnet
 import (
 	"context"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"log"
 	"net"
@@ -182,7 +183,10 @@ func (s *Server) getAuthKey() string {
 	return os.Getenv("TS_AUTHKEY")
 }
 
-func (s *Server) start() error {
+func (s *Server) start() (reterr error) {
+	var closePool closeOnErrorPool
+	defer closePool.closeAllIfError(&reterr)
+
 	exe, err := os.Executable()
 	if err != nil {
 		return err
@@ -244,6 +248,7 @@ func (s *Server) start() error {
 	if err != nil {
 		return fmt.Errorf("error creating filch: %w", err)
 	}
+	closePool.add(s.logbuffer)
 	c := logtail.Config{
 		Collection: lpc.Collection,
 		PrivateID:  lpc.PrivateID,
@@ -259,11 +264,13 @@ func (s *Server) start() error {
 		HTTPC: &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost)},
 	}
 	s.logtail = logtail.NewLogger(c, logf)
+	closePool.addFunc(func() { s.logtail.Shutdown(context.Background()) })
 
 	s.linkMon, err = monitor.New(logf)
 	if err != nil {
 		return err
 	}
+	closePool.add(s.linkMon)
 
 	s.dialer = new(tsdial.Dialer) // mutated below (before used)
 	eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
@@ -274,6 +281,7 @@ func (s *Server) start() error {
 	if err != nil {
 		return err
 	}
+	closePool.add(s.dialer)
 
 	tunDev, magicConn, dns, ok := eng.(wgengine.InternalsGetter).GetInternals()
 	if !ok {
@@ -317,6 +325,7 @@ func (s *Server) start() error {
 	lb.SetVarRoot(s.rootPath)
 	logf("tsnet starting with hostname %q, varRoot %q", s.hostname, s.rootPath)
 	s.lb = lb
+	closePool.addFunc(func() { s.lb.Shutdown() })
 	lb.SetDecompressor(func() (controlclient.Decompressor, error) {
 		return smallzstd.NewDecoder(nil)
 	})
@@ -357,9 +366,22 @@ func (s *Server) start() error {
 			logf("localapi serve error: %v", err)
 		}
 	}()
+	closePool.add(s.localAPIListener)
 	return nil
 }
 
+type closeOnErrorPool []func()
+
+func (p *closeOnErrorPool) add(c io.Closer)   { *p = append(*p, func() { c.Close() }) }
+func (p *closeOnErrorPool) addFunc(fn func()) { *p = append(*p, fn) }
+func (p closeOnErrorPool) closeAllIfError(errp *error) {
+	if *errp != nil {
+		for _, closeFn := range p {
+			closeFn()
+		}
+	}
+}
+
 func (s *Server) logf(format string, a ...interface{}) {
 	if s.logtail != nil {
 		s.logtail.Logf(format, a...)