Parcourir la source

Merge pull request #38 from rumpl/fix-exec

Simplify exec on ACI
Djordje Lukic il y a 5 ans
Parent
commit
29737c2a23
1 fichiers modifiés avec 22 ajouts et 23 suppressions
  1. 22 23
      azure/aci.go

+ 22 - 23
azure/aci.go

@@ -8,15 +8,15 @@ import (
 	"net/http"
 	"os"
 
-	"github.com/docker/api/context/store"
-
-	"github.com/gobwas/ws"
-	"github.com/gobwas/ws/wsutil"
-
 	"github.com/Azure/azure-sdk-for-go/services/containerinstance/mgmt/2018-10-01/containerinstance"
 	"github.com/Azure/azure-sdk-for-go/services/keyvault/auth"
 	"github.com/Azure/go-autorest/autorest"
 	"github.com/Azure/go-autorest/autorest/to"
+	"github.com/gobwas/ws"
+	"github.com/gobwas/ws/wsutil"
+	"github.com/pkg/errors"
+
+	"github.com/docker/api/context/store"
 
 	tm "github.com/buger/goterm"
 )
@@ -166,33 +166,33 @@ func execCommands(ctx context.Context, address string, password string, commands
 }
 
 func exec(ctx context.Context, address string, password string, reader io.Reader, writer io.Writer) error {
-	ctx, cancel := context.WithCancel(ctx)
 	conn, _, _, err := ws.DefaultDialer.Dial(ctx, address)
 	if err != nil {
-		cancel()
 		return err
 	}
 	err = wsutil.WriteClientMessage(conn, ws.OpText, []byte(password))
 	if err != nil {
-		cancel()
 		return err
 	}
 
-	done := make(chan struct{})
+	downstreamChannel := make(chan error, 10)
+	upstreamChannel := make(chan error, 10)
 
 	go func() {
-		defer close(done)
 		for {
 			msg, _, err := wsutil.ReadServerData(conn)
 			if err != nil {
+				if err == io.EOF {
+					downstreamChannel <- nil
+					return
+				}
+				downstreamChannel <- err
 				return
 			}
 			fmt.Fprint(writer, string(msg))
 		}
 	}()
 
-	readChannel := make(chan []byte, 10)
-
 	go func() {
 		for {
 			// We send each byte, byte-per-byte over the
@@ -200,26 +200,25 @@ func exec(ctx context.Context, address string, password string, reader io.Reader
 			buffer := make([]byte, 1)
 			n, err := reader.Read(buffer)
 			if err != nil {
-				close(done)
-				cancel()
-				break
+				upstreamChannel <- err
+				return
 			}
 
 			if n > 0 {
-				readChannel <- buffer
+				err := wsutil.WriteClientMessage(conn, ws.OpText, buffer)
+				if err != nil {
+					upstreamChannel <- err
+				}
 			}
 		}
 	}()
 
 	for {
 		select {
-		case <-done:
-			return nil
-		case bytes := <-readChannel:
-			err := wsutil.WriteClientMessage(conn, ws.OpText, bytes)
-			if err != nil {
-				return err
-			}
+		case err := <-downstreamChannel:
+			return errors.Wrap(err, "failed to read input from container")
+		case err := <-upstreamChannel:
+			return errors.Wrap(err, "failed to send input to container")
 		}
 	}
 }