Browse Source

ipn/localapi: add support for multipart POST to file-put

This allows sending multiple files via Taildrop in one request.
Progress is tracked via ipn.Notify.

Updates tailscale/corp#18202

Signed-off-by: Percy Wegmann <[email protected]>
Percy Wegmann 1 year ago
parent
commit
bed818a978
5 changed files with 300 additions and 17 deletions
  1. 21 3
      ipn/backend.go
  2. 3 0
      ipn/ipnlocal/local.go
  3. 34 0
      ipn/ipnlocal/taildrop.go
  4. 203 14
      ipn/localapi/localapi.go
  5. 39 0
      util/progresstracking/progresstracking.go

+ 21 - 3
ipn/backend.go

@@ -67,8 +67,9 @@ const (
 	NotifyInitialPrefs  // if set, the first Notify message (sent immediately) will contain the current Prefs
 	NotifyInitialNetMap // if set, the first Notify message (sent immediately) will contain the current NetMap
 
-	NotifyNoPrivateKeys       // if set, private keys that would normally be sent in updates are zeroed out
-	NotifyInitialTailFSShares // if set, the first Notify message (sent immediately) will contain the current TailFS Shares
+	NotifyNoPrivateKeys        // if set, private keys that would normally be sent in updates are zeroed out
+	NotifyInitialTailFSShares  // if set, the first Notify message (sent immediately) will contain the current TailFS Shares
+	NotifyInitialOutgoingFiles // if set, the first Notify message (sent immediately) will contain the current Taildrop OutgoingFiles
 )
 
 // Notify is a communication from a backend (e.g. tailscaled) to a frontend
@@ -114,6 +115,11 @@ type Notify struct {
 	// Deprecated: use LocalClient.AwaitWaitingFiles instead.
 	IncomingFiles []PartialFile `json:",omitempty"`
 
+	// OutgoingFiles, if non-nil, tracks which files are in the process of
+	// being sent via TailDrop, including files that finished, whether
+	// successful or failed. This slice is sorted by Started time, then Name.
+	OutgoingFiles []*OutgoingFile `json:",omitempty"`
+
 	// LocalTCPPort, if non-nil, informs the UI frontend which
 	// (non-zero) localhost TCP port it's listening on.
 	// This is currently only used by Tailscale when run in the
@@ -175,7 +181,7 @@ func (n Notify) String() string {
 	return s[0:len(s)-1] + "}"
 }
 
-// PartialFile represents an in-progress file transfer.
+// PartialFile represents an in-progress incoming file transfer.
 type PartialFile struct {
 	Name         string    // e.g. "foo.jpg"
 	Started      time.Time // time transfer started
@@ -194,6 +200,18 @@ type PartialFile struct {
 	Done bool `json:",omitempty"`
 }
 
+// OutgoingFile represents an in-progress outgoing file transfer.
+type OutgoingFile struct {
+	ID           string               `json:"-"` // unique identifier for this transfer (a type 4 UUID)
+	PeerID       tailcfg.StableNodeID // identifier for the peer to which this is being transferred
+	Name         string               `json:",omitempty"` // e.g. "foo.jpg"
+	Started      time.Time            // time transfer started
+	DeclaredSize int64                // or -1 if unknown
+	Sent         int64                // bytes copied thus far
+	Finished     bool                 // indicates whether or not the transfer finished
+	Succeeded    bool                 // for a finished transfer, indicates whether or not it was successful
+}
+
 // StateKey is an opaque identifier for a set of LocalBackend state
 // (preferences, private keys, etc.). It is also used as a key for
 // the various LoginProfiles that the instance may be signed into.

+ 3 - 0
ipn/ipnlocal/local.go

@@ -319,6 +319,9 @@ type LocalBackend struct {
 	// lastNotifiedTailFSShares keeps track of the last set of shares that we
 	// notified about.
 	lastNotifiedTailFSShares atomic.Pointer[views.SliceView[*tailfs.Share, tailfs.ShareView]]
+
+	// outgoingFiles keeps track of Taildrop outgoing files
+	outgoingFiles map[string]*ipn.OutgoingFile
 }
 
 type updateStatus struct {

+ 34 - 0
ipn/ipnlocal/taildrop.go

@@ -0,0 +1,34 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ipnlocal
+
+import (
+	"slices"
+	"strings"
+
+	"tailscale.com/ipn"
+)
+
+func (b *LocalBackend) UpdateOutgoingFiles(updates map[string]ipn.OutgoingFile) {
+	b.mu.Lock()
+	if b.outgoingFiles == nil {
+		b.outgoingFiles = make(map[string]*ipn.OutgoingFile, len(updates))
+	}
+	for id, file := range updates {
+		b.outgoingFiles[id] = &file
+	}
+	outgoingFiles := make([]*ipn.OutgoingFile, 0, len(b.outgoingFiles))
+	for _, file := range b.outgoingFiles {
+		outgoingFiles = append(outgoingFiles, file)
+	}
+	b.mu.Unlock()
+	slices.SortFunc(outgoingFiles, func(a, b *ipn.OutgoingFile) int {
+		t := a.Started.Compare(b.Started)
+		if t != 0 {
+			return t
+		}
+		return strings.Compare(a.Name, b.Name)
+	})
+	b.send(ipn.Notify{OutgoingFiles: outgoingFiles})
+}

+ 203 - 14
ipn/localapi/localapi.go

@@ -13,6 +13,9 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"maps"
+	"mime"
+	"mime/multipart"
 	"net"
 	"net/http"
 	"net/http/httputil"
@@ -28,6 +31,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/google/uuid"
 	"tailscale.com/client/tailscale/apitype"
 	"tailscale.com/clientupdate"
 	"tailscale.com/envknob"
@@ -57,6 +61,7 @@ import (
 	"tailscale.com/util/mak"
 	"tailscale.com/util/osdiag"
 	"tailscale.com/util/osuser"
+	"tailscale.com/util/progresstracking"
 	"tailscale.com/util/rands"
 	"tailscale.com/version"
 	"tailscale.com/wgengine/magicsock"
@@ -1529,9 +1534,17 @@ func (h *Handler) serveFileTargets(w http.ResponseWriter, r *http.Request) {
 // The Windows client currently (2021-11-30) uses the peerapi (/v0/put/)
 // directly, as the Windows GUI always runs in tun mode anyway.
 //
+// In addition to single file PUTs, this endpoint accepts multipart file
+// POSTS encoded as multipart/form-data. Each part must include a
+// "Content-Length" in the MIME header indicating the size of the file.
+// The first part should be an application/json file that contains a JSON map
+// of filename -> length, which we can use for tracking progress even before
+// reading the file parts.
+//
 // URL format:
 //
 //   - PUT /localapi/v0/file-put/:stableID/:escaped-filename
+//   - POST /localapi/v0/file-put/:stableID
 func (h *Handler) serveFilePut(w http.ResponseWriter, r *http.Request) {
 	metricFilePutCalls.Add(1)
 
@@ -1539,10 +1552,12 @@ func (h *Handler) serveFilePut(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, "file access denied", http.StatusForbidden)
 		return
 	}
-	if r.Method != "PUT" {
+
+	if r.Method != "PUT" && r.Method != "POST" {
 		http.Error(w, "want PUT to put file", http.StatusBadRequest)
 		return
 	}
+
 	fts, err := h.b.FileTargets()
 	if err != nil {
 		http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -1554,16 +1569,22 @@ func (h *Handler) serveFilePut(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, "misconfigured", http.StatusInternalServerError)
 		return
 	}
-	stableIDStr, filenameEscaped, ok := strings.Cut(upath, "/")
-	if !ok {
-		http.Error(w, "bogus URL", http.StatusBadRequest)
-		return
+	var peerIDStr, filenameEscaped string
+	if r.Method == "PUT" {
+		ok := false
+		peerIDStr, filenameEscaped, ok = strings.Cut(upath, "/")
+		if !ok {
+			http.Error(w, "bogus URL", http.StatusBadRequest)
+			return
+		}
+	} else {
+		peerIDStr = upath
 	}
-	stableID := tailcfg.StableNodeID(stableIDStr)
+	peerID := tailcfg.StableNodeID(peerIDStr)
 
 	var ft *apitype.FileTarget
 	for _, x := range fts {
-		if x.Node.StableID == stableID {
+		if x.Node.StableID == peerID {
 			ft = x
 			break
 		}
@@ -1578,20 +1599,181 @@ func (h *Handler) serveFilePut(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	// Report progress on outgoing files every 5 seconds
+	outgoingFiles := make(map[string]ipn.OutgoingFile)
+	t := time.NewTicker(5 * time.Second)
+	progressUpdates := make(chan ipn.OutgoingFile)
+	defer close(progressUpdates)
+
+	go func() {
+		defer t.Stop()
+		defer h.b.UpdateOutgoingFiles(outgoingFiles)
+		for {
+			select {
+			case u, ok := <-progressUpdates:
+				if !ok {
+					return
+				}
+				outgoingFiles[u.ID] = u
+			case <-t.C:
+				h.b.UpdateOutgoingFiles(outgoingFiles)
+			}
+		}
+	}()
+
+	switch r.Method {
+	case "PUT":
+		file := ipn.OutgoingFile{
+			ID:           uuid.Must(uuid.NewRandom()).String(),
+			PeerID:       peerID,
+			Name:         filenameEscaped,
+			DeclaredSize: r.ContentLength,
+		}
+		h.singleFilePut(r.Context(), progressUpdates, w, r.Body, dstURL, file)
+	case "POST":
+		h.multiFilePost(progressUpdates, w, r, peerID, dstURL)
+	default:
+		http.Error(w, "want PUT to put file", http.StatusBadRequest)
+		return
+	}
+}
+
+func (h *Handler) multiFilePost(progressUpdates chan (ipn.OutgoingFile), w http.ResponseWriter, r *http.Request, peerID tailcfg.StableNodeID, dstURL *url.URL) {
+	_, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
+	if err != nil {
+		http.Error(w, fmt.Sprintf("invalid Content-Type for multipart POST: %s", err), http.StatusBadRequest)
+		return
+	}
+
+	ww := &multiFilePostResponseWriter{}
+	defer func() {
+		if err := ww.Flush(w); err != nil {
+			h.logf("error: multiFilePostResponseWriter.Flush(): %s", err)
+		}
+	}()
+
+	outgoingFilesByName := make(map[string]ipn.OutgoingFile)
+	first := true
+	mr := multipart.NewReader(r.Body, params["boundary"])
+	for {
+		part, err := mr.NextPart()
+		if err == io.EOF {
+			// No more parts.
+			return
+		} else if err != nil {
+			http.Error(ww, fmt.Sprintf("failed to decode multipart/form-data: %s", err), http.StatusBadRequest)
+			return
+		}
+
+		if first {
+			first = false
+			if part.Header.Get("Content-Type") != "application/json" {
+				http.Error(ww, "first MIME part must be a JSON map of filename -> size", http.StatusBadRequest)
+				return
+			}
+
+			var manifest map[string]int64
+			err := json.NewDecoder(part).Decode(&manifest)
+			if err != nil {
+				http.Error(ww, fmt.Sprintf("invalid manifest: %s", err), http.StatusBadRequest)
+				return
+			}
+
+			for filename, size := range manifest {
+				file := ipn.OutgoingFile{
+					ID:           uuid.Must(uuid.NewRandom()).String(),
+					Name:         filename,
+					PeerID:       peerID,
+					DeclaredSize: size,
+				}
+				outgoingFilesByName[filename] = file
+				progressUpdates <- file
+			}
+
+			continue
+		}
+
+		if !h.singleFilePut(r.Context(), progressUpdates, ww, part, dstURL, outgoingFilesByName[part.FileName()]) {
+			return
+		}
+
+		if ww.statusCode >= 400 {
+			// put failed, stop immediately
+			h.logf("error: singleFilePut: failed with status %d", ww.statusCode)
+			return
+		}
+	}
+}
+
+// multiFilePostResponseWriter is a buffering http.ResponseWriter that can be
+// reused across multiple singleFilePut calls and then flushed to the client
+// when all files have been PUT.
+type multiFilePostResponseWriter struct {
+	header     http.Header
+	statusCode int
+	body       *bytes.Buffer
+}
+
+func (ww *multiFilePostResponseWriter) Header() http.Header {
+	if ww.header == nil {
+		ww.header = make(http.Header)
+	}
+	return ww.header
+}
+
+func (ww *multiFilePostResponseWriter) WriteHeader(statusCode int) {
+	ww.statusCode = statusCode
+}
+
+func (ww *multiFilePostResponseWriter) Write(p []byte) (int, error) {
+	if ww.body == nil {
+		ww.body = bytes.NewBuffer(nil)
+	}
+	return ww.body.Write(p)
+}
+
+func (ww *multiFilePostResponseWriter) Flush(w http.ResponseWriter) error {
+	maps.Copy(w.Header(), ww.Header())
+	w.WriteHeader(ww.statusCode)
+	_, err := io.Copy(w, ww.body)
+	return err
+}
+
+func (h *Handler) singleFilePut(
+	ctx context.Context,
+	progressUpdates chan (ipn.OutgoingFile),
+	w http.ResponseWriter,
+	body io.Reader,
+	dstURL *url.URL,
+	outgoingFile ipn.OutgoingFile,
+) bool {
+	outgoingFile.Started = time.Now()
+	body = progresstracking.NewReader(body, 1*time.Second, func(n int, err error) {
+		outgoingFile.Sent = int64(n)
+		progressUpdates <- outgoingFile
+	})
+
+	fail := func() {
+		outgoingFile.Finished = true
+		outgoingFile.Succeeded = false
+		progressUpdates <- outgoingFile
+	}
+
 	// Before we PUT a file we check to see if there are any existing partial file and if so,
 	// we resume the upload from where we left off by sending the remaining file instead of
 	// the full file.
 	var offset int64
 	var resumeDuration time.Duration
-	remainingBody := io.Reader(r.Body)
+	remainingBody := io.Reader(body)
 	client := &http.Client{
 		Transport: h.b.Dialer().PeerAPITransport(),
 		Timeout:   10 * time.Second,
 	}
-	req, err := http.NewRequestWithContext(r.Context(), "GET", dstURL.String()+"/v0/put/"+filenameEscaped, nil)
+	req, err := http.NewRequestWithContext(ctx, "GET", dstURL.String()+"/v0/put/"+outgoingFile.Name, nil)
 	if err != nil {
 		http.Error(w, "bogus peer URL", http.StatusInternalServerError)
-		return
+		fail()
+		return false
 	}
 	switch resp, err := client.Do(req); {
 	case err != nil:
@@ -1603,7 +1785,7 @@ func (h *Handler) serveFilePut(w http.ResponseWriter, r *http.Request) {
 	default:
 		resumeStart := time.Now()
 		dec := json.NewDecoder(resp.Body)
-		offset, remainingBody, err = taildrop.ResumeReader(r.Body, func() (out taildrop.BlockChecksum, err error) {
+		offset, remainingBody, err = taildrop.ResumeReader(body, func() (out taildrop.BlockChecksum, err error) {
 			err = dec.Decode(&out)
 			return out, err
 		})
@@ -1613,12 +1795,13 @@ func (h *Handler) serveFilePut(w http.ResponseWriter, r *http.Request) {
 		resumeDuration = time.Since(resumeStart).Round(time.Millisecond)
 	}
 
-	outReq, err := http.NewRequestWithContext(r.Context(), "PUT", "http://peer/v0/put/"+filenameEscaped, remainingBody)
+	outReq, err := http.NewRequestWithContext(ctx, "PUT", "http://peer/v0/put/"+outgoingFile.Name, remainingBody)
 	if err != nil {
 		http.Error(w, "bogus outreq", http.StatusInternalServerError)
-		return
+		fail()
+		return false
 	}
-	outReq.ContentLength = r.ContentLength
+	outReq.ContentLength = outgoingFile.DeclaredSize
 	if offset > 0 {
 		h.logf("resuming put at offset %d after %v", offset, resumeDuration)
 		rangeHdr, _ := httphdr.FormatRange([]httphdr.Range{{Start: offset, Length: 0}})
@@ -1631,6 +1814,12 @@ func (h *Handler) serveFilePut(w http.ResponseWriter, r *http.Request) {
 	rp := httputil.NewSingleHostReverseProxy(dstURL)
 	rp.Transport = h.b.Dialer().PeerAPITransport()
 	rp.ServeHTTP(w, outReq)
+
+	outgoingFile.Finished = true
+	outgoingFile.Succeeded = true
+	progressUpdates <- outgoingFile
+
+	return true
 }
 
 func (h *Handler) serveSetDNS(w http.ResponseWriter, r *http.Request) {

+ 39 - 0
util/progresstracking/progresstracking.go

@@ -0,0 +1,39 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package progresstracking provides wrappers around io.Reader and io.Writer
+// that track progress.
+package progresstracking
+
+import (
+	"io"
+	"time"
+)
+
+// NewReader wraps the given Reader with a progress tracking Reader that
+// reports progress at the following points:
+//
+// - First read
+// - Every read spaced at least interval since the prior read
+// - Last read
+func NewReader(r io.Reader, interval time.Duration, onProgress func(totalRead int, err error)) io.Reader {
+	return &reader{Reader: r, interval: interval, onProgress: onProgress}
+}
+
+type reader struct {
+	io.Reader
+	interval    time.Duration
+	onProgress  func(int, error)
+	lastTracked time.Time
+	totalRead   int
+}
+
+func (r *reader) Read(p []byte) (int, error) {
+	n, err := r.Reader.Read(p)
+	r.totalRead += n
+	if time.Since(r.lastTracked) > r.interval || err != nil {
+		r.onProgress(r.totalRead, err)
+		r.lastTracked = time.Now()
+	}
+	return n, err
+}