// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
using System.Collections.Generic;
using System.Reactive.Disposables;
using System.Threading;
using System.Threading.Tasks;

namespace System.Reactive.Linq
{
    partial class AsyncObservable
    {
<#
// Emit one pair of Zip overloads (synchronous and asynchronous selector) per arity 2..15.
for (var i = 2; i <= 15; i++)
{
    // "T1, ..., Ti, TResult" - generic parameter list of the generated method.
    var genPars = string.Join(", ", Enumerable.Range(1, i).Select(j => "T" + j).Concat(new[] { "TResult" }));
    // "T1, ..., Ti" - input types of the selector.
    var genArgs = string.Join(", ", Enumerable.Range(1, i).Select(j => "T" + j));
    // "IAsyncObservable<T1> source1, ..., IAsyncObservable<Ti> sourcei" - source parameters.
    var args = string.Join(", ", Enumerable.Range(1, i).Select(j => "IAsyncObservable<T" + j + "> source" + j));
    // "(observer1, ..., observeri)" - deconstruction target for the tuple returned by AsyncObserver.Zip.
    var obs = "(" + string.Join(", ", Enumerable.Range(1, i).Select(j => "observer" + j)) + ")";
    // "sub1, ..., subi" - the subscription tasks awaited together.
    var subs = string.Join(", ", Enumerable.Range(1, i).Select(j => "sub" + j));
#>
        /// <summary>
        /// Combines <#=i#> sources by pairing their elements in arrival order and projecting each
        /// combination through <paramref name="selector"/>.
        /// </summary>
        /// <returns>A sequence of the values produced by <paramref name="selector"/>.</returns>
        /// <exception cref="ArgumentNullException">Any source or <paramref name="selector"/> is null.</exception>
        public static IAsyncObservable<TResult> Zip<<#=genPars#>>(this <#=args#>, Func<<#=genArgs#>, TResult> selector)
        {
<#
    for (var j = 1; j <= i; j++)
    {
#>
            if (source<#=j#> == null)
                throw new ArgumentNullException(nameof(source<#=j#>));
<#
    }
#>
            if (selector == null)
                throw new ArgumentNullException(nameof(selector));

            return Create<TResult>(async observer =>
            {
                var d = new CompositeAsyncDisposable();

                var <#=obs#> = AsyncObserver.Zip(observer, selector);

                // Start every subscription, registering each resulting disposable with the composite
                // as soon as its subscription task finishes.
<#
    for (var j = 1; j <= i; j++)
    {
#>
                var sub<#=j#> = source<#=j#>.SubscribeSafeAsync(observer<#=j#>).ContinueWith(disposable => d.AddAsync(disposable.Result)).Unwrap();
<#
    }
#>

                await Task.WhenAll(<#=subs#>).ConfigureAwait(false);

                return d;
            });
        }

        /// <summary>
        /// Combines <#=i#> sources by pairing their elements in arrival order and projecting each
        /// combination through the asynchronous <paramref name="selector"/>.
        /// </summary>
        /// <returns>A sequence of the values produced by <paramref name="selector"/>.</returns>
        /// <exception cref="ArgumentNullException">Any source or <paramref name="selector"/> is null.</exception>
        public static IAsyncObservable<TResult> Zip<<#=genPars#>>(this <#=args#>, Func<<#=genArgs#>, Task<TResult>> selector)
        {
<#
    for (var j = 1; j <= i; j++)
    {
#>
            if (source<#=j#> == null)
                throw new ArgumentNullException(nameof(source<#=j#>));
<#
    }
#>
            if (selector == null)
                throw new ArgumentNullException(nameof(selector));

            return Create<TResult>(async observer =>
            {
                var d = new CompositeAsyncDisposable();

                var <#=obs#> = AsyncObserver.Zip(observer, selector);

                // Start every subscription, registering each resulting disposable with the composite
                // as soon as its subscription task finishes.
<#
    for (var j = 1; j <= i; j++)
    {
#>
                var sub<#=j#> = source<#=j#>.SubscribeSafeAsync(observer<#=j#>).ContinueWith(disposable => d.AddAsync(disposable.Result)).Unwrap();
<#
    }
#>

                await Task.WhenAll(<#=subs#>).ConfigureAwait(false);

                return d;
            });
        }

<#
}
#>
    }

    partial class AsyncObserver
    {
<#
for (var i = 2; i <= 15; i++)
{
    // "(IAsyncObserver<T1>, ..., IAsyncObserver<Ti>)" - tuple of per-source observers returned by Zip.
    var res = "(" + string.Join(", ", Enumerable.Range(1, i).Select(j => "IAsyncObserver<T" + j + ">")) + ")";
    var genPars = string.Join(", ", Enumerable.Range(1, i).Select(j => "T" + j).Concat(new[] { "TResult" }));
    var genArgs = string.Join(", ", Enumerable.Range(1, i).Select(j => "T" + j));
    // "x1, ..., xi" - lambda parameters used to adapt the synchronous selector.
    var xs = string.Join(", ", Enumerable.Range(1, i).Select(j => "x" + j));
#>
        /// <summary>
        /// Creates <#=i#> observers that zip their inputs through the synchronous <paramref name="selector"/>
        /// and forward the results to <paramref name="observer"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="observer"/> or <paramref name="selector"/> is null.</exception>
        public static <#=res#> Zip<<#=genPars#>>(IAsyncObserver<TResult> observer, Func<<#=genArgs#>, TResult> selector)
        {
            if (observer == null)
                throw new ArgumentNullException(nameof(observer));
            if (selector == null)
                throw new ArgumentNullException(nameof(selector));

            // Delegate to the asynchronous overload by wrapping the result in a completed task.
            return Zip<<#=genPars#>>(observer, (<#=xs#>) => Task.FromResult(selector(<#=xs#>)));
        }

        /// <summary>
        /// Creates <#=i#> observers, one per source. Each observer queues the values it receives; once every
        /// queue holds at least one value, one value is dequeued from each, passed to
        /// <paramref name="selector"/>, and the result is forwarded to <paramref name="observer"/>.
        /// </summary>
        /// <remarks>
        /// An exception thrown by <paramref name="selector"/>, or an error from any source, is forwarded to
        /// <paramref name="observer"/>. Completion is forwarded once every source has completed, or when a
        /// value arrives that cannot be combined while all other sources have already completed.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="observer"/> or <paramref name="selector"/> is null.</exception>
        public static <#=res#> Zip<<#=genPars#>>(IAsyncObserver<TResult> observer, Func<<#=genArgs#>, Task<TResult>> selector)
        {
            if (observer == null)
                throw new ArgumentNullException(nameof(observer));
            if (selector == null)
                throw new ArgumentNullException(nameof(selector));

            // Single lock guarding the queues, the completion flags and all calls into the downstream observer.
            var gate = new AsyncLock();

<#
    for (var j = 1; j <= i; j++)
    {
#>
            var values<#=j#> = new Queue<T<#=j#>>();
<#
    }
#>

            // isDone[k] is set when the observer created with zero-based index k has completed.
            var isDone = new bool[<#=i#>];

<#
    // "values1.Count > 0 && ... && valuesi.Count > 0" - true when a full combination is available.
    var all = string.Join(" && ", Enumerable.Range(1, i).Select(j => "values" + j + ".Count > 0"));
    // "values1.Dequeue(), ..., valuesi.Dequeue()" - selector arguments, consuming one value per source.
    var vals = string.Join(", ", Enumerable.Range(1, i).Select(j => "values" + j + ".Dequeue()"));
#>
            // index is the zero-based position of the source within isDone.
            IAsyncObserver<T> CreateObserver<T>(int index, Queue<T> queue) =>
                Create<T>(
                    async x =>
                    {
                        using (await gate.LockAsync().ConfigureAwait(false))
                        {
                            queue.Enqueue(x);

                            if (<#=all#>)
                            {
                                TResult res;

                                try
                                {
                                    res = await selector(<#=vals#>).ConfigureAwait(false);
                                }
                                catch (Exception ex)
                                {
                                    await observer.OnErrorAsync(ex).ConfigureAwait(false);
                                    return;
                                }

                                await observer.OnNextAsync(res).ConfigureAwait(false);
                            }
                            else
                            {
                                // No combination is available; if every other source has already completed,
                                // signal completion downstream.
                                var allDone = true;

                                for (var k = 0; k < <#=i#>; k++)
                                {
                                    if (k != index && !isDone[k])
                                    {
                                        allDone = false;
                                        break;
                                    }
                                }

                                if (allDone)
                                {
                                    await observer.OnCompletedAsync().ConfigureAwait(false);
                                }
                            }
                        }
                    },
                    async ex =>
                    {
                        using (await gate.LockAsync().ConfigureAwait(false))
                        {
                            await observer.OnErrorAsync(ex).ConfigureAwait(false);
                        }
                    },
                    async () =>
                    {
                        using (await gate.LockAsync().ConfigureAwait(false))
                        {
                            isDone[index] = true;

                            var allDone = true;

                            for (var k = 0; k < <#=i#>; k++)
                            {
                                if (!isDone[k])
                                {
                                    allDone = false;
                                    break;
                                }
                            }

                            if (allDone)
                            {
                                await observer.OnCompletedAsync().ConfigureAwait(false);
                            }
                        }
                    }
                );

            // Indices are zero-based (j - 1) so they address isDone[0..<#=i - 1#>]; passing the one-based j
            // would write past the end of the array for the last source and leave isDone[0] unset forever.
            return
            (
<#
    for (var j = 1; j <= i; j++)
    {
#>
                CreateObserver<T<#=j#>>(<#=j - 1#>, values<#=j#>)<#=(j < i ? "," : "")#>
<#
    }
#>
            );
        }

<#
}
#>
    }
}