Browse Source

fix: simplify parallel map using channels (#582)

Craig Andrews 7 months ago
parent
commit
73c012c76c

+ 22 - 32
packages/tui/internal/util/concurrency.go

@@ -2,49 +2,39 @@ package util
 
 import (
 	"strings"
-	"sync"
 )
 
-// MapReducePar performs a parallel map-reduce operation on a slice of items.
-// It applies a function to each item in the slice concurrently,
-// and combines the results serially using a reducer returned from
-// each one of the functions, allowing the use of closures.
-func MapReducePar[a, b any](items []a, init b, fn func(a) func(b) b) b {
-	itemCount := len(items)
-	locks := make([]*sync.Mutex, itemCount)
-	mapped := make([]func(b) b, itemCount)
-
-	for i, value := range items {
-		lock := &sync.Mutex{}
-		lock.Lock()
-		locks[i] = lock
+func mapParallel[in, out any](items []in, fn func(in) out) chan out {
+	mapChans := make([]chan out, 0, len(items))
+
+	for _, v := range items {
+		ch := make(chan out)
+		mapChans = append(mapChans, ch)
 		go func() {
-			defer lock.Unlock()
-			mapped[i] = fn(value)
+			defer close(ch)
+			ch <- fn(v)
 		}()
 	}
 
-	result := init
-	for i := range itemCount {
-		locks[i].Lock()
-		defer locks[i].Unlock()
-		f := mapped[i]
-		if f != nil {
-			result = f(result)
+	resultChan := make(chan out)
+
+	go func() {
+		defer close(resultChan)
+		for _, ch := range mapChans {
+			v := <-ch
+			resultChan <- v
 		}
-	}
+	}()
 
-	return result
+	return resultChan
 }
 
 // WriteStringsPar allows to iterate over a list and compute strings in parallel,
 // yet write them in order.
 func WriteStringsPar[a any](sb *strings.Builder, items []a, fn func(a) string) {
-	MapReducePar(items, sb, func(item a) func(*strings.Builder) *strings.Builder {
-		str := fn(item)
-		return func(sbdr *strings.Builder) *strings.Builder {
-			sbdr.WriteString(str)
-			return sbdr
-		}
-	})
+	ch := mapParallel(items, fn)
+
+	for v := range ch {
+		sb.WriteString(v)
+	}
 }

+ 23 - 0
packages/tui/internal/util/concurrency_test.go

@@ -0,0 +1,23 @@
+package util_test
+
+import (
+	"strconv"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/sst/opencode/internal/util"
+)
+
+func TestWriteStringsPar(t *testing.T) {
+	items := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
+	sb := strings.Builder{}
+	util.WriteStringsPar(&sb, items, func(i int) string {
+		// sleep for the inverse duration so that later items finish first
+		time.Sleep(time.Duration(10-i) * time.Millisecond)
+		return strconv.Itoa(i)
+	})
+	if sb.String() != "0123456789" {
+		t.Fatalf("expected 0123456789, got %s", sb.String())
+	}
+}