| | | 1 | | using System.Collections.Concurrent; |
| | | 2 | | |
| | | 3 | | namespace Elsa.Testing.Shared.Services; |
| | | 4 | | |
/// <summary>
/// Coordinates one-shot signals, keyed by an arbitrary object, between test code and the code under test.
/// A signal may be triggered before or after someone waits on it: a trigger that arrives first is kept
/// until a waiter consumes it, after which the key is reset and can be signalled again.
/// </summary>
public class SignalManager
{
    // One completion source per signal key. RunContinuationsAsynchronously keeps Trigger from running a
    // waiter's continuation inline on the triggering thread.
    private readonly ConcurrentDictionary<object, TaskCompletionSource<object?>> _signals = new();

    /// <summary>
    /// Waits for <paramref name="signal"/> and returns its result as <typeparamref name="T"/>.
    /// </summary>
    /// <param name="signal">The signal key.</param>
    /// <param name="millisecondsTimeout">How long to wait before giving up, in milliseconds.</param>
    /// <returns>The result passed to <see cref="Trigger"/>.</returns>
    /// <exception cref="TimeoutException">The signal was not triggered within <paramref name="millisecondsTimeout"/>.</exception>
    /// <exception cref="InvalidCastException">
    /// The result is not a <typeparamref name="T"/>. A <c>null</c> result also fails this check.
    /// </exception>
    public async Task<T> WaitAsync<T>(object signal, int millisecondsTimeout = 60000)
    {
        var result = await WaitAsync(signal, millisecondsTimeout);

        if (result is not T typedResult)
            throw new InvalidCastException($"Signal '{signal}' was not of type '{typeof(T).Name}'.");

        return typedResult;
    }

    /// <summary>
    /// Waits for <paramref name="signal"/> and returns the result it was triggered with.
    /// Returns immediately if the signal was already triggered. The signal is consumed (removed)
    /// whether the wait succeeds or times out.
    /// </summary>
    /// <param name="signal">The signal key.</param>
    /// <param name="millisecondsTimeout">How long to wait before giving up, in milliseconds.</param>
    /// <returns>The result passed to <see cref="Trigger"/>.</returns>
    /// <exception cref="TimeoutException">The signal was not triggered within <paramref name="millisecondsTimeout"/>.</exception>
    public async Task<object?> WaitAsync(object signal, int millisecondsTimeout = 60000)
    {
        var taskCompletionSource = GetOrCreate(signal);

        // This token only stops the delay timer once the signal wins the race; the timeout itself comes
        // from Task.Delay, so the token source needs no timer of its own.
        using var cancellationTokenSource = new CancellationTokenSource();
        var delayTask = Task.Delay(millisecondsTimeout, cancellationTokenSource.Token);
        var completedTask = await Task.WhenAny(taskCompletionSource.Task, delayTask);

        // Remove only the instance we waited on. Another caller may already have removed it and created
        // (and possibly triggered) a fresh one under the same key. Removing by key alone would drop that
        // newer signal.
        _signals.TryRemove(new KeyValuePair<object, TaskCompletionSource<object?>>(signal, taskCompletionSource));

        if (completedTask == delayTask)
            throw new TimeoutException($"Signal '{signal}' timed out after {millisecondsTimeout} milliseconds.");

        cancellationTokenSource.Cancel();
        return await taskCompletionSource.Task;
    }

    /// <summary>
    /// Triggers <paramref name="signal"/> with <paramref name="result"/>. If the signal is already
    /// triggered and not yet consumed, the call is ignored and the earlier result is kept.
    /// </summary>
    /// <param name="signal">The signal key.</param>
    /// <param name="result">The value handed to waiters.</param>
    public void Trigger(object signal, object? result = null)
    {
        // TrySetResult is atomic. A duplicate or concurrent trigger is a no-op, whereas checking
        // IsCompleted and then calling SetResult could race and throw InvalidOperationException.
        GetOrCreate(signal).TrySetResult(result);
    }

    /// <summary>Returns the completion source stored for <paramref name="eventName"/>, creating it if absent.</summary>
    private TaskCompletionSource<object?> GetOrCreate(object eventName)
    {
        return _signals.GetOrAdd(eventName, _ => new(TaskCreationOptions.RunContinuationsAsynchronously));
    }
}