|  | @@ -1,15 +1,12 @@
 | 
	
		
			
				|  |  |  package azure
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import (
 | 
	
		
			
				|  |  | -	"bufio"
 | 
	
		
			
				|  |  |  	"context"
 | 
	
		
			
				|  |  |  	"fmt"
 | 
	
		
			
				|  |  |  	"io"
 | 
	
		
			
				|  |  | +	"io/ioutil"
 | 
	
		
			
				|  |  |  	"net/http"
 | 
	
		
			
				|  |  |  	"os"
 | 
	
		
			
				|  |  | -	"os/signal"
 | 
	
		
			
				|  |  | -	"runtime"
 | 
	
		
			
				|  |  | -	"strings"
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	"github.com/docker/api/context/store"
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -81,18 +78,17 @@ func createACIContainers(ctx context.Context, aciContext store.AciContext, group
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		containers := *containerGroup.Containers
 | 
	
		
			
				|  |  |  		container := containers[0]
 | 
	
		
			
				|  |  | -		response, err := execACIContainer(ctx, "/bin/sh", *containerGroup.Name, *container.Name, aciContext)
 | 
	
		
			
				|  |  | +		response, err := execACIContainer(ctx, aciContext, "/bin/sh", *containerGroup.Name, *container.Name)
 | 
	
		
			
				|  |  |  		if err != nil {
 | 
	
		
			
				|  |  |  			return c, err
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		err = execWebSocketLoopWithCmd(
 | 
	
		
			
				|  |  | +		if err = execCommands(
 | 
	
		
			
				|  |  |  			ctx,
 | 
	
		
			
				|  |  |  			*response.WebSocketURI,
 | 
	
		
			
				|  |  |  			*response.Password,
 | 
	
		
			
				|  |  |  			commands,
 | 
	
		
			
				|  |  | -			false)
 | 
	
		
			
				|  |  | -		if err != nil {
 | 
	
		
			
				|  |  | +		); err != nil {
 | 
	
		
			
				|  |  |  			return containerinstance.ContainerGroup{}, err
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  	}
 | 
	
	
		
			
				|  | @@ -122,7 +118,7 @@ func listACIContainers(aciContext store.AciContext) (c []containerinstance.Conta
 | 
	
		
			
				|  |  |  	return containers, err
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -func execACIContainer(ctx context.Context, command, containerGroup string, containerName string, aciContext store.AciContext) (c containerinstance.ContainerExecResponse, err error) {
 | 
	
		
			
				|  |  | +func execACIContainer(ctx context.Context, aciContext store.AciContext, command, containerGroup string, containerName string) (c containerinstance.ContainerExecResponse, err error) {
 | 
	
		
			
				|  |  |  	containerClient := getContainerClient(aciContext.SubscriptionID)
 | 
	
		
			
				|  |  |  	rows, cols := getTermSize()
 | 
	
		
			
				|  |  |  	containerExecRequest := containerinstance.ContainerExecRequest{
 | 
	
	
		
			
				|  | @@ -132,6 +128,7 @@ func execACIContainer(ctx context.Context, command, containerGroup string, conta
 | 
	
		
			
				|  |  |  			Cols: cols,
 | 
	
		
			
				|  |  |  		},
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	return containerClient.ExecuteCommand(
 | 
	
		
			
				|  |  |  		ctx,
 | 
	
		
			
				|  |  |  		aciContext.ResourceGroup,
 | 
	
	
		
			
				|  | @@ -146,93 +143,85 @@ func getTermSize() (*int32, *int32) {
 | 
	
		
			
				|  |  |  	return to.Int32Ptr(int32(rows)), to.Int32Ptr(int32(cols))
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -func execWebSocketLoop(ctx context.Context, wsURL, passwd string) error {
 | 
	
		
			
				|  |  | -	return execWebSocketLoopWithCmd(ctx, wsURL, passwd, []string{}, true)
 | 
	
		
			
				|  |  | +type commandSender struct {
 | 
	
		
			
				|  |  | +	commands []string
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -func execWebSocketLoopWithCmd(ctx context.Context, wsURL, passwd string, commands []string, outputEnabled bool) error {
 | 
	
		
			
				|  |  | +func (cs commandSender) Read(p []byte) (int, error) {
 | 
	
		
			
				|  |  | +	if len(cs.commands) == 0 {
 | 
	
		
			
				|  |  | +		return 0, io.EOF
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	command := cs.commands[0]
 | 
	
		
			
				|  |  | +	cs.commands = cs.commands[1:]
 | 
	
		
			
				|  |  | +	copy(p, command)
 | 
	
		
			
				|  |  | +	return len(command), nil
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func execCommands(ctx context.Context, address string, password string, commands []string) error {
 | 
	
		
			
				|  |  | +	writer := ioutil.Discard
 | 
	
		
			
				|  |  | +	reader := commandSender{
 | 
	
		
			
				|  |  | +		commands: commands,
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	return exec(ctx, address, password, reader, writer)
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func exec(ctx context.Context, address string, password string, reader io.Reader, writer io.Writer) error {
 | 
	
		
			
				|  |  |  	ctx, cancel := context.WithCancel(ctx)
 | 
	
		
			
				|  |  | -	conn, _, _, err := ws.DefaultDialer.Dial(ctx, wsURL)
 | 
	
		
			
				|  |  | +	conn, _, _, err := ws.DefaultDialer.Dial(ctx, address)
 | 
	
		
			
				|  |  |  	if err != nil {
 | 
	
		
			
				|  |  |  		cancel()
 | 
	
		
			
				|  |  |  		return err
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	err = wsutil.WriteClientMessage(conn, ws.OpText, []byte(passwd))
 | 
	
		
			
				|  |  | +	err = wsutil.WriteClientMessage(conn, ws.OpText, []byte(password))
 | 
	
		
			
				|  |  |  	if err != nil {
 | 
	
		
			
				|  |  |  		cancel()
 | 
	
		
			
				|  |  |  		return err
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	lastCommandLen := 0
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	done := make(chan struct{})
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	go func() {
 | 
	
		
			
				|  |  |  		defer close(done)
 | 
	
		
			
				|  |  |  		for {
 | 
	
		
			
				|  |  |  			msg, _, err := wsutil.ReadServerData(conn)
 | 
	
		
			
				|  |  |  			if err != nil {
 | 
	
		
			
				|  |  | -				if err != io.EOF {
 | 
	
		
			
				|  |  | -					fmt.Printf("read error: %s\n", err)
 | 
	
		
			
				|  |  | -				}
 | 
	
		
			
				|  |  |  				return
 | 
	
		
			
				|  |  |  			}
 | 
	
		
			
				|  |  | -			lines := strings.Split(string(msg), "\n")
 | 
	
		
			
				|  |  | -			lastCommandLen = len(lines[len(lines)-1])
 | 
	
		
			
				|  |  | -			if outputEnabled {
 | 
	
		
			
				|  |  | -				fmt.Printf("%s", msg)
 | 
	
		
			
				|  |  | -			}
 | 
	
		
			
				|  |  | +			fmt.Fprint(writer, string(msg))
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  	}()
 | 
	
		
			
				|  |  | -	interrupt := make(chan os.Signal, 1)
 | 
	
		
			
				|  |  | -	signal.Notify(interrupt, os.Interrupt)
 | 
	
		
			
				|  |  | -	scanner := bufio.NewScanner(os.Stdin)
 | 
	
		
			
				|  |  | -	rc := make(chan string, 10)
 | 
	
		
			
				|  |  | -	if len(commands) > 0 {
 | 
	
		
			
				|  |  | -		for _, command := range commands {
 | 
	
		
			
				|  |  | -			rc <- command
 | 
	
		
			
				|  |  | -		}
 | 
	
		
			
				|  |  | -	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	readChannel := make(chan []byte, 10)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	go func() {
 | 
	
		
			
				|  |  |  		for {
 | 
	
		
			
				|  |  | -			if !scanner.Scan() {
 | 
	
		
			
				|  |  | +			// We send each byte, byte-per-byte over the
 | 
	
		
			
				|  |  | +			// websocket because the console is in raw mode
 | 
	
		
			
				|  |  | +			buffer := make([]byte, 1)
 | 
	
		
			
				|  |  | +			n, err := reader.Read(buffer)
 | 
	
		
			
				|  |  | +			if err != nil {
 | 
	
		
			
				|  |  |  				close(done)
 | 
	
		
			
				|  |  |  				cancel()
 | 
	
		
			
				|  |  | -				fmt.Println("exiting...")
 | 
	
		
			
				|  |  |  				break
 | 
	
		
			
				|  |  |  			}
 | 
	
		
			
				|  |  | -			t := scanner.Text()
 | 
	
		
			
				|  |  | -			rc <- t
 | 
	
		
			
				|  |  | -			cleanLastCommand(lastCommandLen)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +			if n > 0 {
 | 
	
		
			
				|  |  | +				readChannel <- buffer
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  	}()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	for {
 | 
	
		
			
				|  |  |  		select {
 | 
	
		
			
				|  |  |  		case <-done:
 | 
	
		
			
				|  |  |  			return nil
 | 
	
		
			
				|  |  | -		case line := <-rc:
 | 
	
		
			
				|  |  | -			err = wsutil.WriteClientMessage(conn, ws.OpText, []byte(line+"\n"))
 | 
	
		
			
				|  |  | +		case bytes := <-readChannel:
 | 
	
		
			
				|  |  | +			err := wsutil.WriteClientMessage(conn, ws.OpText, bytes)
 | 
	
		
			
				|  |  |  			if err != nil {
 | 
	
		
			
				|  |  | -				fmt.Println("write: ", err)
 | 
	
		
			
				|  |  | -				return nil
 | 
	
		
			
				|  |  | +				return err
 | 
	
		
			
				|  |  |  			}
 | 
	
		
			
				|  |  | -		case <-interrupt:
 | 
	
		
			
				|  |  | -			fmt.Println("interrupted...")
 | 
	
		
			
				|  |  | -			close(done)
 | 
	
		
			
				|  |  | -			cancel()
 | 
	
		
			
				|  |  | -			return nil
 | 
	
		
			
				|  |  | -		}
 | 
	
		
			
				|  |  | -	}
 | 
	
		
			
				|  |  | -}
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -func cleanLastCommand(lastCommandLen int) {
 | 
	
		
			
				|  |  | -	tm.MoveCursorUp(1)
 | 
	
		
			
				|  |  | -	tm.MoveCursorForward(lastCommandLen)
 | 
	
		
			
				|  |  | -	if runtime.GOOS != "windows" {
 | 
	
		
			
				|  |  | -		for i := 0; i < tm.Width(); i++ {
 | 
	
		
			
				|  |  | -			_, _ = tm.Print(" ")
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  | -		tm.MoveCursorUp(1)
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -	tm.Flush()
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  func getContainerGroupsClient(subscriptionID string) (containerinstance.ContainerGroupsClient, error) {
 |