Ensuring no race conditions in my concurrent/parallel ForEachAsync method


Problem

Hopefully this is the final chapter of my ForEachAsync code review. The original question was the starting point, and the second question contained the modifications suggested in the first one, aimed at making sure I was handling cancellation and exceptions correctly. The second question also contained my test scenarios (in pseudocode). This final post addresses a comment about a race condition (after applying the other suggestions from the second post). It contains a ‘complete’ LINQPad script that can be run in LINQPad to illustrate what I think is a working function that never exhibits the aforementioned race condition.

A couple of disclaimers about the script:

  1. The code is taken from a larger code base, a Windows service project. You can see my SO post about that if you so desire. So the ‘context’ of this script isn’t exactly the scenario where I’d use it, but it appears good enough to illustrate the core functionality in question. Most notably, the Main() method of the script exits and completes before the ForEachAsync task is complete. That is because in my service project this ‘same’ code runs within a System.Threading.Timer callback and is meant to ‘fire and forget’ the job, not wait for its completion.
  2. This test script is meant to cancel after 4 seconds of processing. My Task.Delay calculation may not be exactly right. In reality, ProcessJobAsync might look for a cancel state in a database and trigger the task cancellation. And, just like in this script, in my real code base any number of the tasks running due to the ForEachAsync parallelism could end up calling internalCancel.Cancel(), because it is very likely that more than one of them will get past the !cts.Token.IsCancellationRequested check and hit the database at the same time. Is it OK to call .Cancel() multiple times?
  3. Please note the question at the end of ForEachAsync about using ‘cts.Token.ThrowIfCancellationRequested() instead of throwing OperationCanceledException’. That might be a better way to handle cancellation; I’m not sure.
  4. Task.Run vs Task.Factory.StartNew: I had done some reading (mostly from Stephen Toub) before writing this, and my understanding may be wrong, but I chose Task.Factory.StartNew because:

    1. I thought I should be passing in TaskCreationOptions.LongRunning, but maybe that isn’t needed?
    2. I thought calling the overload that accepts an object state parameter was the correct way to do this, but maybe I should just access the myJob variable directly from within the async delegate? Are there any problems (either way) given that availableJobs could contain more than one item, and that the fire-and-forget nature of my task means myJob would be reassigned almost immediately?
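Item 4 hinges on a detail that is easy to miss: when Task.Factory.StartNew is given an async lambda, it returns a Task<Task>, and the outer task completes as soon as the lambda reaches its first await, not when the work finishes. Task.Run (or calling Unwrap() on the outer task) gives a task that represents the whole async body. The sketch below is written as C# top-level statements for a .NET 6+ console app rather than as a LINQPad script; the gate1/gate2 TaskCompletionSources and the "job state" string are illustrative stand-ins, not part of the original code.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Illustrative gates: the "work" finishes only when we complete them, so nothing depends on timing.
var gate1 = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
var gate2 = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

// StartNew + async lambda, using the same overload shape as the question (state, token, options, scheduler).
// DenyChildAttach mirrors what Task.Run uses. The result is Task<Task>.
Task<Task> outer = Task.Factory.StartNew(
    async state => { await gate1.Task; },
    "job state",
    CancellationToken.None,
    TaskCreationOptions.DenyChildAttach,
    TaskScheduler.Default);

Task inner = await outer;                       // returns promptly with the still-running inner task
bool innerDoneBeforeGate = inner.IsCompleted;   // the async body is still waiting on gate1

// Unwrap() returns a proxy that completes when the async body itself completes.
Task unwrapped = outer.Unwrap();
gate1.SetResult(true);
await unwrapped;

// Task.Run has Func<Task> overloads that unwrap automatically.
Task viaRun = Task.Run(async () => { await gate2.Task; });
bool runDoneBeforeGate = viaRun.IsCompleted;    // represents the whole async body, so not done yet
gate2.SetResult(true);
await viaRun;

Console.WriteLine($"StartNew inner done before gate: {innerDoneBeforeGate}");  // False
Console.WriteLine($"Task.Run done before gate: {runDoneBeforeGate}");          // False
```

In the script's Main the StartNew result is discarded anyway, so this does not change behavior there; it starts to matter as soon as anything awaits or inspects the returned task.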

Given these caveats, does the previous concern that I have a race condition (resulting in a TaskCanceledException instead of an OperationCanceledException) still exist? If so, any suggestions on how to prevent it? Other comments/concerns are welcome as well.

void Main()
{
    // This would really be hitting a DB and returning 1-MaxJobsAllowed (10) jobs
    var availableJobs = Enumerable.Range(0, 1);

    foreach (var j in availableJobs)
    {
        var myJob = new ScheduledJob
        {
            CancellationTokenSource = new CancellationTokenSource(),
            Key = j
        };

        // Main is really a simulation for a System.Threading.Timer callback method, so I want fire and forget by *not* awaiting Task.Factory.StartNew
        Task.Factory.StartNew(
            async jobState =>
            {
                var scheduledJob = jobState as ScheduledJob;

                try
                {
                    var jobProcessor = new FooAsync();
                    $"Task.Factory.StartNew - Before await ProcessJobAsync".Dump();
                    await jobProcessor.ProcessJobAsync( scheduledJob.Key, new XElement("InputPackage") /* fake 'instructions' */, scheduledJob.CancellationTokenSource.Token);
                    $"Task.Factory.StartNew - After await ProcessJobAsync".Dump();
                }
                catch (OperationCanceledException)
                {
                    $"Task.Factory.StartNew - Operation Cancelled".Dump();
                }
                catch (Exception ex)
                {
                    $"Task.Factory.StartNew - Exception - {ex.Message}".Dump();
                    // throw;
                }
                finally
                {
                    $"Task.Factory.StartNew - Finished Processing, return 'worker thread' to pool".Dump();
                }
            },
            myJob,
            myJob.CancellationTokenSource.Token,
            TaskCreationOptions.LongRunning,
            TaskScheduler.Default
        );
    }

    "Main: Loop Complete".Dump();
}

public class FooAsync
{
    public async Task<int> ProcessJobAsync(int jobKey, XElement inputPackage, CancellationToken cancellationToken)
    {
        var seconds = new Random().Next( 15, 30 ); // Simulation of how long job will take to run
        var start = DateTime.Now;

        $"Task {jobKey}: FooAsync - Start ProcessJobAsync (total of {seconds} seconds)".Dump();

        // cancellationToken - Required if controlling service needs to shut down and stop job processing
        // internalCancel - Required for this job to be able to cancel itself due to UI request
        var internalCancel = CancellationTokenSource.CreateLinkedTokenSource( cancellationToken );

        try
        {
            Action<CancellationTokenSource, int> cancelProcessIfNeeded = ( cts, jobData ) =>
            {
                if (!cts.Token.IsCancellationRequested /* not already cancelled, e.g. because the Windows service is shutting down */ && (DateTime.Now - start).TotalSeconds > 4 /* simulate an internal cancel check, e.g. a user cancelling on the website */ )
                {
                    Console.WriteLine($"Task {jobKey}: FooAsync - Cancelling Job, DataChunk {jobData}, Elapsed Time: {(DateTime.Now - start).TotalSeconds} seconds");
                    cts.Cancel();
                }

                cts.Token.ThrowIfCancellationRequested();
            };

            // Simulate running something for 1000 data batches...
            await Enumerable
                .Range(0, 1000)
                .ForEachAsync(
                    async jobData =>
                    {
                        Console.WriteLine( $"Task {jobKey}: FooAsync - Start DataChunk {jobData}" );

                        cancelProcessIfNeeded( internalCancel, jobData );

                        await Task.Delay(seconds * 100);

                        Console.WriteLine( $"Task {jobKey}: FooAsync - Finish DataChunk {jobData}" );
                    },
                    new AsyncParallelOptions { MaxDegreeOfParallelism = 100, CancellationToken = internalCancel.Token }
                );
        }
        catch (Exception ex)
        {
            Console.WriteLine( $"Task {jobKey}: FooAsync - Exception: {ex.GetType().ToString()}, internalCancel.Token.IsCancellationRequested: {internalCancel.Token.IsCancellationRequested}" );
            throw;
        }

        Console.WriteLine( $"Task {jobKey}: FooAsync - Finished ProcessJobAsync in {(DateTime.Now - start).TotalSeconds} seconds" );
        return 10;
    }
}

public static class ExtensionMethods
{
    public static async Task ForEachAsync<T>( this IEnumerable<T> source, Func<T, Task> body, AsyncParallelOptions parallelOptions )
    {
        ConcurrentQueue<Exception> exceptions = new ConcurrentQueue<Exception>();

        var maxDegreeOfConcurrency = parallelOptions.MaxDegreeOfParallelism;

        // Link to the caller's CancellationToken so that if the caller cancels, ForEachAsync cancels as well.
        // The linked source also lets this method stop processing partitions itself when FailImmediately is set.
        var cts = CancellationTokenSource.CreateLinkedTokenSource( parallelOptions.CancellationToken );

        var allDone = Task.WhenAll(
            from partition in Partitioner.Create( source ).GetPartitions( maxDegreeOfConcurrency )
            select Task.Run( async delegate {

                using ( partition )
                {
                    while ( true )
                    {
                        cts.Token.ThrowIfCancellationRequested(); /* either from caller or failImmediately */

                        // try to read the next item from this partition
                        if ( !partition.MoveNext() ) break;

                        await body( partition.Current ).ContinueWith( t => {

                            Console.WriteLine( $"ForEachAsync Extension #1: ContinueWith, t.Exception is null: {t.Exception == null}, t.IsCanceled: {t.IsCanceled}, t.IsFaulted: {t.IsFaulted}, cts.IsCancellationRequested: {cts.IsCancellationRequested}" );

                            // If body() threw, record the error and, if the caller wants immediate failure, cancel the remaining work
                            if ( t.Exception != null )
                            {
                                // Always gather the exception to throw at the end
                                foreach ( var ex in t.Exception.Flatten().InnerExceptions )
                                {
                                    exceptions.Enqueue( ex );
                                }

                                if ( parallelOptions.FailImmediately )
                                {
                                    cts.Cancel();
                                }
                            }

                        } );
                    }
                }

            }, cts.Token ) );

        // Wait until all finished (or errored out) and then return exceptions
        await allDone;

        // Question: allDone is never going to have IsCanceled or IsFaulted set, correct? Because await body() will swallow all exceptions?
        Console.WriteLine( $"ForEachAsync Extension #2: Finished, {exceptions?.Count ?? 0} total, allDone.IsCanceled: {allDone.IsCanceled}, allDone.IsFaulted: {allDone.IsFaulted}, cts.IsCancellationRequested: {cts.IsCancellationRequested}" );

        if ( exceptions.Count > 0 )
        {
            Console.WriteLine( $"ForEachAsync Extension #3: Throw Exceptions" );
            throw new AggregateException( exceptions );
        }

        // Question: should I just replace this whole if statement with cts.Token.ThrowIfCancellationRequested() instead of throwing OperationCanceledException?
        if ( cts.IsCancellationRequested )
        {
            Console.WriteLine($"ForEachAsync Extension #4: Throw OperationCanceledException");
            throw new OperationCanceledException();
        }
    }
}

public class ScheduledJob
{
    public CancellationTokenSource CancellationTokenSource { get; set; }
    public int Key { get; set; }
}

public class AsyncParallelOptions : System.Threading.Tasks.ParallelOptions
{
    public bool FailImmediately { get; set; } = true;
}
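On the question in the comment near the end of ForEachAsync: both variants end the async method with an OperationCanceledException, and in both cases the returned task ends up Canceled; what differs is whether the exception records which token was cancelled. A minimal sketch as .NET 6+ top-level statements; the two method names are hypothetical stand-ins for the end of ForEachAsync.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

using var cts = new CancellationTokenSource();
cts.Cancel();

// Ending the way the question's ForEachAsync does: a plain OperationCanceledException.
static async Task EndWithNewExceptionAsync(CancellationTokenSource source)
{
    await Task.Yield();
    if (source.IsCancellationRequested)
        throw new OperationCanceledException();      // carries CancellationToken.None
}

// Ending with ThrowIfCancellationRequested: same exception type, but the token is attached.
static async Task EndWithThrowIfAsync(CancellationTokenSource source)
{
    await Task.Yield();
    source.Token.ThrowIfCancellationRequested();     // carries source.Token
}

OperationCanceledException? plain = null, tokenAware = null;
try { await EndWithNewExceptionAsync(cts); } catch (OperationCanceledException ex) { plain = ex; }
try { await EndWithThrowIfAsync(cts); } catch (OperationCanceledException ex) { tokenAware = ex; }

Console.WriteLine($"new OperationCanceledException() carries the token: {plain?.CancellationToken == cts.Token}");      // False
Console.WriteLine($"ThrowIfCancellationRequested carries the token: {tokenAware?.CancellationToken == cts.Token}"); // True
```

In ForEachAsync the attached token would be the linked cts.Token, not the caller's own token, so checking the caller's token for IsCancellationRequested remains the simpler way to tell who cancelled.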

Solution

  1. System.Threading.Timer callbacks already execute on a thread-pool thread, and by calling Task.Factory.StartNew you are starting yet another thread, which is wasteful. If you called Wait() on the task, the TPL could inline it and keep using the same thread. Another possibility is to loop with await Task.Delay instead of using an actual timer, but that depends on whether you want the timer to fire exactly every n seconds (even if the previous iteration is still running, which is what you’re doing now), or you want each iteration to start n seconds after the last one finished (which is usually what you want, and which you would get with the await Task.Delay approach).
  2. I don’t exactly follow your cancellation code – you cancel if cancellation wasn’t requested? Anyway, yes, calling Cancel multiple times is OK.
  3. Generally you should use ThrowIfCancellationRequested when it is available; it also records the cancelled token on the exception it throws.
  4. (a) Presumably, if you start these tasks every n seconds where n isn’t very large, then they are not long-running at all, so I don’t think that flag makes sense here. Read Toub’s answer on this. (b) Passing the state parameter is slightly more performant since you avoid the lambda capture, but that is probably premature optimization. Each iteration will get its “own” myJob, so don’t worry about that either. (c) These two methods are not exactly equivalent, specifically when the task itself returns a task (which is what you’re doing). Read the part about Unwrap in Toub’s article.
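Point 1 above suggests replacing the timer with an await Task.Delay loop. A minimal sketch of that idea, assuming a hypothetical pollOnceAsync delegate that stands in for “fetch available jobs and start them”; it is written as .NET 6+ top-level statements and is not taken from the original service code.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical replacement for the Timer callback: run one polling pass, then wait `interval`
// before the next pass, so passes can never overlap.
static async Task<int> RunPollingLoopAsync(Func<CancellationToken, Task> pollOnceAsync, TimeSpan interval, CancellationToken stopToken)
{
    int completed = 0;
    while (!stopToken.IsCancellationRequested)
    {
        await pollOnceAsync(stopToken);    // e.g. fetch available jobs and start them
        completed++;
        try
        {
            await Task.Delay(interval, stopToken);
        }
        catch (OperationCanceledException)
        {
            break;                         // the service is stopping
        }
    }
    return completed;
}

// Demo: each pass takes longer than the interval, yet at most one pass is ever in flight.
int inFlight = 0, maxInFlight = 0;
using var stop = new CancellationTokenSource(TimeSpan.FromMilliseconds(500));
int passCount = await RunPollingLoopAsync(async token =>
{
    int now = Interlocked.Increment(ref inFlight);
    maxInFlight = Math.Max(maxInFlight, now);
    await Task.Delay(50);                  // simulated work, deliberately longer than the interval
    Interlocked.Decrement(ref inFlight);
}, TimeSpan.FromMilliseconds(20), stop.Token);

Console.WriteLine($"passes: {passCount}, max passes in flight: {maxInFlight}");
```

Because the next Task.Delay only starts after the previous pass returns, the interval here means “time between passes”, not “time between starts”, which is the behavior the answer describes as usually wanted.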
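Point 2 (calling Cancel multiple times is OK) is easy to check. The sketch below, as .NET 6+ top-level statements, has several tasks call Cancel() concurrently on one CancellationTokenSource, roughly like the internalCancel.Cancel() calls in the question, and counts how often a registered callback runs.

```csharp
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

using var cts = new CancellationTokenSource();
int callbackRuns = 0;
cts.Token.Register(() => Interlocked.Increment(ref callbackRuns));

// Several parallel "jobs" decide to cancel at about the same time, then one more call on this thread.
await Task.WhenAll(Enumerable.Range(0, 8).Select(_ => Task.Run(() => cts.Cancel())));
cts.Cancel();

// Only the first Cancel() transitions the source; later calls are no-ops,
// so registered callbacks run exactly once.
Console.WriteLine($"IsCancellationRequested: {cts.IsCancellationRequested}, callback runs: {callbackRuns}"); // True, 1
```

One caveat: Cancel() throws ObjectDisposedException once the source has been disposed, so repeated calls are only safe while the source is still alive.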

As for OperationCanceledException, you’re seeing it and not TaskCanceledException because CancellationToken.ThrowIfCancellationRequested, which is what you’re using to cancel your tasks, throws OperationCanceledException. Apparently that makes Task.WhenAll throw OperationCanceledException, and so on. But you shouldn’t be worried about TaskCanceledException vs OperationCanceledException: the former inherits from the latter, so just catch the latter and be done with it.
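The timing dependence behind the original race concern can be reproduced in isolation. In the question’s ForEachAsync, cts.Token is passed to Task.Run, so a partition task whose delegate had not started yet when cts was cancelled would be cancelled without ever running. A minimal sketch of both outcomes, as .NET 6+ top-level statements; CaptureAsync and the token source names are hypothetical.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper: await a task and return whatever it throws (null on success).
static async Task<Exception?> CaptureAsync(Task task)
{
    try { await task; return null; }
    catch (Exception ex) { return ex; }
}

// Case 1: the token is already cancelled when Task.Run is called. The delegate never runs and the
// task is cancelled without an exception object, so awaiting it throws TaskCanceledException.
using var alreadyCancelled = new CancellationTokenSource();
alreadyCancelled.Cancel();
Exception? beforeStart = await CaptureAsync(Task.Run(() => { }, alreadyCancelled.Token));

// Case 2: cancellation is requested after the delegate has started and the delegate observes it
// with ThrowIfCancellationRequested; awaiting rethrows that OperationCanceledException.
using var cancelledLater = new CancellationTokenSource();
Exception? insideBody = await CaptureAsync(Task.Run(() =>
{
    cancelledLater.Cancel();
    cancelledLater.Token.ThrowIfCancellationRequested();
}, cancelledLater.Token));

Console.WriteLine(beforeStart?.GetType().Name);   // TaskCanceledException
Console.WriteLine(insideBody?.GetType().Name);    // OperationCanceledException

// Both derive from OperationCanceledException, so one catch clause handles either.
bool bothCaughtByBaseType = beforeStart is OperationCanceledException && insideBody is OperationCanceledException;
Console.WriteLine(bothCaughtByBaseType);          // True
```

Which case occurs in the real code depends on timing, which is why catching OperationCanceledException is more robust than catching either specific type.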

One more thing: in your Task.Delay, I assume you meant to multiply by 1000 (not 100) to convert seconds to milliseconds. It is then not clear what you are trying to achieve, because the minimum is 15 seconds, which is more than 4, so the cancellation code will always get called.
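The arithmetic behind this last point, under a deliberately simplified model: with MaxDegreeOfParallelism = 100 and every chunk delaying by the same amount, chunks start in waves, and the cancel check (which runs before each chunk's delay) fires in the first wave that starts more than 4 seconds in. The model ignores scheduling overhead and partitioning details, and FirstCancellingWave is a hypothetical helper, written as .NET 6+ top-level statements.

```csharp
using System;

// Simplified model: wave k (0-based) starts at roughly k * delayMs after ProcessJobAsync began.
// Cancellation fires in the first wave whose start time exceeds 4000 ms.
static int FirstCancellingWave(int delayMs)
{
    int wave = 0;
    while (wave * delayMs <= 4000) wave++;   // matches the strict "> 4" seconds check
    return wave;
}

// new Random().Next(15, 30) returns 15..29: the upper bound is exclusive.
foreach (int seconds in new[] { 15, 29 })
{
    int asWrittenMs = seconds * 100;    // what Task.Delay(seconds * 100) actually waits
    int intendedMs = seconds * 1000;    // seconds converted to milliseconds
    Console.WriteLine($"{seconds} s: *100 = {asWrittenMs} ms, cancels in wave {FirstCancellingWave(asWrittenMs)}; " +
                      $"*1000 = {intendedMs} ms, cancels in wave {FirstCancellingWave(intendedMs)}");
}
```

For a 15-second job this reports wave 3 with *100 and wave 1 with *1000; for a 29-second job, wave 2 and wave 1. With *1000 every job is therefore cancelled as soon as its second wave (wave 1) begins. Writing the delay as Task.Delay(TimeSpan.FromSeconds(seconds)) sidesteps the unit conversion entirely.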
