|
|
@@ -1,15 +1,12 @@
|
|
|
package azure
|
|
|
|
|
|
import (
|
|
|
- "bufio"
|
|
|
"context"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+ "io/ioutil"
|
|
|
"net/http"
|
|
|
"os"
|
|
|
- "os/signal"
|
|
|
- "runtime"
|
|
|
- "strings"
|
|
|
|
|
|
"github.com/docker/api/context/store"
|
|
|
|
|
|
@@ -81,18 +78,17 @@ func createACIContainers(ctx context.Context, aciContext store.AciContext, group
|
|
|
|
|
|
containers := *containerGroup.Containers
|
|
|
container := containers[0]
|
|
|
- response, err := execACIContainer(ctx, "/bin/sh", *containerGroup.Name, *container.Name, aciContext)
|
|
|
+ response, err := execACIContainer(ctx, aciContext, "/bin/sh", *containerGroup.Name, *container.Name)
|
|
|
if err != nil {
|
|
|
return c, err
|
|
|
}
|
|
|
|
|
|
- err = execWebSocketLoopWithCmd(
|
|
|
+ if err = execCommands(
|
|
|
ctx,
|
|
|
*response.WebSocketURI,
|
|
|
*response.Password,
|
|
|
commands,
|
|
|
- false)
|
|
|
- if err != nil {
|
|
|
+ ); err != nil {
|
|
|
return containerinstance.ContainerGroup{}, err
|
|
|
}
|
|
|
}
|
|
|
@@ -122,7 +118,7 @@ func listACIContainers(aciContext store.AciContext) (c []containerinstance.Conta
|
|
|
return containers, err
|
|
|
}
|
|
|
|
|
|
-func execACIContainer(ctx context.Context, command, containerGroup string, containerName string, aciContext store.AciContext) (c containerinstance.ContainerExecResponse, err error) {
|
|
|
+func execACIContainer(ctx context.Context, aciContext store.AciContext, command, containerGroup string, containerName string) (c containerinstance.ContainerExecResponse, err error) {
|
|
|
containerClient := getContainerClient(aciContext.SubscriptionID)
|
|
|
rows, cols := getTermSize()
|
|
|
containerExecRequest := containerinstance.ContainerExecRequest{
|
|
|
@@ -132,6 +128,7 @@ func execACIContainer(ctx context.Context, command, containerGroup string, conta
|
|
|
Cols: cols,
|
|
|
},
|
|
|
}
|
|
|
+
|
|
|
return containerClient.ExecuteCommand(
|
|
|
ctx,
|
|
|
aciContext.ResourceGroup,
|
|
|
@@ -146,93 +143,85 @@ func getTermSize() (*int32, *int32) {
|
|
|
return to.Int32Ptr(int32(rows)), to.Int32Ptr(int32(cols))
|
|
|
}
|
|
|
|
|
|
-func execWebSocketLoop(ctx context.Context, wsURL, passwd string) error {
|
|
|
- return execWebSocketLoopWithCmd(ctx, wsURL, passwd, []string{}, true)
|
|
|
+type commandSender struct {
|
|
|
+ commands []string
|
|
|
}
|
|
|
|
|
|
-func execWebSocketLoopWithCmd(ctx context.Context, wsURL, passwd string, commands []string, outputEnabled bool) error {
|
|
|
+func (cs commandSender) Read(p []byte) (int, error) {
|
|
|
+ if len(cs.commands) == 0 {
|
|
|
+ return 0, io.EOF
|
|
|
+ }
|
|
|
+ command := cs.commands[0]
|
|
|
+ cs.commands = cs.commands[1:]
|
|
|
+ copy(p, command)
|
|
|
+ return len(command), nil
|
|
|
+}
|
|
|
+
|
|
|
+func execCommands(ctx context.Context, address string, password string, commands []string) error {
|
|
|
+ writer := ioutil.Discard
|
|
|
+ reader := commandSender{
|
|
|
+ commands: commands,
|
|
|
+ }
|
|
|
+ return exec(ctx, address, password, reader, writer)
|
|
|
+}
|
|
|
+
|
|
|
+func exec(ctx context.Context, address string, password string, reader io.Reader, writer io.Writer) error {
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
- conn, _, _, err := ws.DefaultDialer.Dial(ctx, wsURL)
|
|
|
+ conn, _, _, err := ws.DefaultDialer.Dial(ctx, address)
|
|
|
if err != nil {
|
|
|
cancel()
|
|
|
return err
|
|
|
}
|
|
|
- err = wsutil.WriteClientMessage(conn, ws.OpText, []byte(passwd))
|
|
|
+ err = wsutil.WriteClientMessage(conn, ws.OpText, []byte(password))
|
|
|
if err != nil {
|
|
|
cancel()
|
|
|
return err
|
|
|
}
|
|
|
- lastCommandLen := 0
|
|
|
+
|
|
|
done := make(chan struct{})
|
|
|
+
|
|
|
go func() {
|
|
|
defer close(done)
|
|
|
for {
|
|
|
msg, _, err := wsutil.ReadServerData(conn)
|
|
|
if err != nil {
|
|
|
- if err != io.EOF {
|
|
|
- fmt.Printf("read error: %s\n", err)
|
|
|
- }
|
|
|
return
|
|
|
}
|
|
|
- lines := strings.Split(string(msg), "\n")
|
|
|
- lastCommandLen = len(lines[len(lines)-1])
|
|
|
- if outputEnabled {
|
|
|
- fmt.Printf("%s", msg)
|
|
|
- }
|
|
|
+ fmt.Fprint(writer, string(msg))
|
|
|
}
|
|
|
}()
|
|
|
- interrupt := make(chan os.Signal, 1)
|
|
|
- signal.Notify(interrupt, os.Interrupt)
|
|
|
- scanner := bufio.NewScanner(os.Stdin)
|
|
|
- rc := make(chan string, 10)
|
|
|
- if len(commands) > 0 {
|
|
|
- for _, command := range commands {
|
|
|
- rc <- command
|
|
|
- }
|
|
|
- }
|
|
|
+
|
|
|
+ readChannel := make(chan []byte, 10)
|
|
|
+
|
|
|
go func() {
|
|
|
for {
|
|
|
- if !scanner.Scan() {
|
|
|
+ // We send each byte, byte-per-byte over the
|
|
|
+ // websocket because the console is in raw mode
|
|
|
+ buffer := make([]byte, 1)
|
|
|
+ n, err := reader.Read(buffer)
|
|
|
+ if err != nil {
|
|
|
close(done)
|
|
|
cancel()
|
|
|
- fmt.Println("exiting...")
|
|
|
break
|
|
|
}
|
|
|
- t := scanner.Text()
|
|
|
- rc <- t
|
|
|
- cleanLastCommand(lastCommandLen)
|
|
|
+
|
|
|
+ if n > 0 {
|
|
|
+ readChannel <- buffer
|
|
|
+ }
|
|
|
}
|
|
|
}()
|
|
|
+
|
|
|
for {
|
|
|
select {
|
|
|
case <-done:
|
|
|
return nil
|
|
|
- case line := <-rc:
|
|
|
- err = wsutil.WriteClientMessage(conn, ws.OpText, []byte(line+"\n"))
|
|
|
+ case bytes := <-readChannel:
|
|
|
+ err := wsutil.WriteClientMessage(conn, ws.OpText, bytes)
|
|
|
if err != nil {
|
|
|
- fmt.Println("write: ", err)
|
|
|
- return nil
|
|
|
+ return err
|
|
|
}
|
|
|
- case <-interrupt:
|
|
|
- fmt.Println("interrupted...")
|
|
|
- close(done)
|
|
|
- cancel()
|
|
|
- return nil
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-func cleanLastCommand(lastCommandLen int) {
|
|
|
- tm.MoveCursorUp(1)
|
|
|
- tm.MoveCursorForward(lastCommandLen)
|
|
|
- if runtime.GOOS != "windows" {
|
|
|
- for i := 0; i < tm.Width(); i++ {
|
|
|
- _, _ = tm.Print(" ")
|
|
|
}
|
|
|
- tm.MoveCursorUp(1)
|
|
|
}
|
|
|
-
|
|
|
- tm.Flush()
|
|
|
}
|
|
|
|
|
|
func getContainerGroupsClient(subscriptionID string) (containerinstance.ContainerGroupsClient, error) {
|