Come get the world’s best LINQ Extensions

Ok, so maybe I’m exaggerating a little bit, but really this list of extensions for all things IEnumerable<T> is quite useful.

Feel free to chop and change and use whatever you like. A little credit to me in your code wouldn’t go astray, though. ;)

The Code

using System;
using System.Collections.Generic;
using System.Data.Linq;
using System.Data.SqlClient;
using System.Diagnostics;
using System.Linq;
using System.Threading;

namespace FlatlinerDOA
{
    /// <summary>
    /// World's best LINQ extensions by Flatliner DOA.
    /// Helpers for merging/removing items in collections, locating elements, applying
    /// side effects while enumerating, walking trees, single-item set operations and
    /// retrying LINQ to SQL submits on deadlock.
    /// </summary>
    public static class LinqExtensions
    {
        /// <summary>
        /// Number of retries (in addition to the first attempt) used by the parameterless SubmitWithRetry overload.
        /// </summary>
        public const int DefaultMaxRetries = 1;

        /// <summary>
        /// Delay in milliseconds before each retry, used by the parameterless SubmitWithRetry overload.
        /// </summary>
        public const int DefaultRetryDelayMilliseconds = 500;

        /// <summary>SQL Server severity level this class treats as a transaction deadlock (retryable).</summary>
        private const int DeadlockSeverity = 13;

        /// <summary>SQL Server severity at or above which errors are traced as errors rather than warnings.</summary>
        private const int FatalSeverity = 20;

        /// <summary>
        /// Replaces every item in <paramref name="target"/> that <paramref name="comparer"/> considers equal
        /// to an item of <paramref name="source"/>, and appends the source items that matched nothing.
        /// </summary>
        /// <param name="target">List updated in place.</param>
        /// <param name="source">Items to merge in; a null source leaves the target untouched.</param>
        /// <param name="comparer">Equality used for matching; null means <see cref="EqualityComparer{T}.Default"/>.</param>
        public static void AddOrReplace<T>(this IList<T> target, IEnumerable<T> source, IEqualityComparer<T> comparer)
        {
            IEqualityComparer<T> equality = comparer ?? EqualityComparer<T>.Default;
            // Argument order (source item, target item) is kept as before for asymmetric comparers.
            target.AddOrReplaceWhere(source, (eachSource, eachTarget) => equality.Equals(eachSource, eachTarget));
        }

        /// <summary>
        /// Merges <paramref name="source"/> into <paramref name="target"/>. Every target item for which
        /// <paramref name="compareClause"/> returns true is overwritten, at the same index, by the matching
        /// source item. Source items that matched no target item are then appended in source order.
        /// </summary>
        /// <remarks>
        /// Matching is evaluated against the target's contents as they were before any replacement.
        /// If several source items match the same target item, the last of them wins.
        /// </remarks>
        /// <param name="target">List updated in place.</param>
        /// <param name="source">Items to merge in; a null source leaves the target untouched.</param>
        /// <param name="compareClause">Predicate (source item, target item) deciding whether the target item is replaced.</param>
        /// <exception cref="ArgumentNullException"><paramref name="target"/> or <paramref name="compareClause"/> is null (and source is not).</exception>
        public static void AddOrReplaceWhere<TTarget, TSource>(this IList<TTarget> target, IEnumerable<TSource> source, Func<TSource, TTarget, bool> compareClause) where TSource : TTarget
        {
            if (source == null)
                return;
            if (target == null)
                throw new ArgumentNullException("target");
            if (compareClause == null)
                throw new ArgumentNullException("compareClause");

            // Snapshot both sides first: writing through the IList indexer while a query is still
            // enumerating the same list invalidates its enumerator (List<T> throws), and the
            // snapshot also makes it safe to merge a list into itself.
            List<TSource> sourceItems = source.ToList();
            List<TTarget> originalTargets = target.ToList();
            bool[] sourceMatched = new bool[sourceItems.Count];

            for (int targetIndex = 0; targetIndex < originalTargets.Count; targetIndex++)
            {
                for (int sourceIndex = 0; sourceIndex < sourceItems.Count; sourceIndex++)
                {
                    if (compareClause(sourceItems[sourceIndex], originalTargets[targetIndex]))
                    {
                        // Replace by position so duplicate values in the target are each handled.
                        target[targetIndex] = sourceItems[sourceIndex];
                        sourceMatched[sourceIndex] = true;
                    }
                }
            }

            for (int sourceIndex = 0; sourceIndex < sourceItems.Count; sourceIndex++)
            {
                if (!sourceMatched[sourceIndex])
                    target.Add(sourceItems[sourceIndex]);
            }
        }

        /// <summary>
        /// Calls <see cref="ICollection{T}.Remove"/> on <paramref name="target"/> once for each item of
        /// <paramref name="source"/>.
        /// </summary>
        /// <param name="target">Collection items are removed from.</param>
        /// <param name="source">Items to remove; a null source is ignored.</param>
        /// <exception cref="ArgumentNullException"><paramref name="target"/> is null (and source is not).</exception>
        public static void RemoveAll<T>(this ICollection<T> target, IEnumerable<T> source)
        {
            if (source == null)
                return;
            if (target == null)
                throw new ArgumentNullException("target");

            // Materialise first: source may be the target itself, or a lazy query over it.
            foreach (T item in source.ToArray())
            {
                target.Remove(item);
            }
        }

        /// <summary>
        /// Adds every item of <paramref name="source"/> to <paramref name="target"/>, in order.
        /// </summary>
        /// <param name="target">Collection items are added to.</param>
        /// <param name="source">Items to add; a null source is ignored.</param>
        /// <exception cref="ArgumentNullException"><paramref name="target"/> is null (and source is not).</exception>
        public static void AddRange<T>(this ICollection<T> target, IEnumerable<T> source)
        {
            if (source == null)
                return;
            if (target == null)
                throw new ArgumentNullException("target");

            // Adding a collection to itself would modify it while it is being enumerated; copy first.
            IEnumerable<T> items = object.ReferenceEquals(source, target) ? source.ToArray() : source;
            foreach (T item in items)
            {
                target.Add(item);
            }
        }

        /// <summary>
        /// Returns the zero-based position of the first item in <paramref name="source"/> equal to
        /// <paramref name="element"/>, or -1 if there is none.
        /// </summary>
        /// <param name="source">Sequence to search; null yields -1.</param>
        /// <param name="element">Value to find; null yields -1.</param>
        public static int ElementIndex<T>(this IEnumerable<T> source, T element) where T : class
        {
            if (source == null || element == null)
                return -1;

            EqualityComparer<T> comparer = EqualityComparer<T>.Default;
            int index = 0;
            foreach (T item in source)
            {
                // Null-safe: a null item in the sequence must not throw.
                if (comparer.Equals(item, element))
                    return index;
                index++;
            }
            return -1;
        }

        /// <summary>
        /// For each of <paramref name="elements"/>, yields <see cref="IList{T}.IndexOf"/> in
        /// <paramref name="source"/>, so -1 is yielded for elements that are not present.
        /// </summary>
        /// <param name="source">List to search; null yields an empty sequence.</param>
        /// <param name="elements">Values to locate; null yields an empty sequence.</param>
        public static IEnumerable<int> ElementIndexes<T>(this IList<T> source, IEnumerable<T> elements) where T : class
        {
            if (source == null)
                yield break;
            if (elements == null)
                yield break;

            // TODO: Use a sorted list to improve indexing performance
            foreach (T element in elements)
            {
                yield return source.IndexOf(element);
            }
        }

        /// <summary>
        /// For each of <paramref name="elements"/> in turn, yields the zero-based position of every item
        /// in <paramref name="source"/> equal to it. Elements that are not present yield nothing.
        /// Unlike the <see cref="IList{T}"/> overload, this yields no -1 and may yield several positions per element.
        /// </summary>
        /// <param name="source">Sequence to search (enumerated once); null yields an empty sequence.</param>
        /// <param name="elements">Values to locate; null yields an empty sequence.</param>
        public static IEnumerable<int> ElementIndexes<T>(this IEnumerable<T> source, IEnumerable<T> elements) where T : class
        {
            if (source == null)
                yield break;
            if (elements == null)
                yield break;

            EqualityComparer<T> comparer = EqualityComparer<T>.Default;
            // Materialise once so positions are stable and source is not re-enumerated per element.
            List<T> items = source.ToList();
            foreach (T element in elements)
            {
                // Positions are counted from 0 for every element searched.
                for (int index = 0; index < items.Count; index++)
                {
                    if (comparer.Equals(items[index], element))
                        yield return index;
                }
            }
        }

        /// <summary>
        /// Returns true when <paramref name="source"/> is null or contains no items.
        /// </summary>
        public static bool IsNullOrEmpty<T>(this IEnumerable<T> source)
        {
            return source == null || !source.Any();
        }

        /// <summary>
        /// Lazily invokes <paramref name="update"/> on each element of <paramref name="source"/> as it is
        /// enumerated, yielding the element afterwards. Each enumeration of the result re-applies the update.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="update"/> is null (thrown immediately).</exception>
        public static IEnumerable<TSource> Update<TSource>(this IEnumerable<TSource> source, Action<TSource> update) where TSource : class
        {
            // Validate here, outside the iterator, so bad arguments fail at the call site.
            if (source == null)
                throw new ArgumentNullException("source");
            if (update == null)
                throw new ArgumentNullException("update");

            return UpdateIterator(source, update);
        }

        /// <summary>Iterator body of <see cref="Update{TSource}"/>.</summary>
        private static IEnumerable<TSource> UpdateIterator<TSource>(IEnumerable<TSource> source, Action<TSource> update)
        {
            foreach (TSource element in source)
            {
                update(element);
                yield return element;
            }
        }

        /// <summary>
        /// Lazily pairs each element of <paramref name="target"/> with the element at the same position in
        /// <paramref name="updateWith"/>, invokes <paramref name="update"/> on the pair and yields the target
        /// element. Enumeration stops when <paramref name="updateWith"/> runs out.
        /// </summary>
        /// <exception cref="ArgumentNullException">Any argument is null (thrown immediately).</exception>
        public static IEnumerable<TTarget> UpdateWith<TTarget, TUpdateWith>(this IEnumerable<TTarget> target, IEnumerable<TUpdateWith> updateWith, Action<TTarget, TUpdateWith> update) where TTarget : class
        {
            if (target == null)
                throw new ArgumentNullException("target");
            if (updateWith == null)
                throw new ArgumentNullException("updateWith");
            if (update == null)
                throw new ArgumentNullException("update");

            return UpdateWithIterator(target, updateWith, update);
        }

        /// <summary>Iterator body of <see cref="UpdateWith{TTarget, TUpdateWith}"/>.</summary>
        private static IEnumerable<TTarget> UpdateWithIterator<TTarget, TUpdateWith>(IEnumerable<TTarget> target, IEnumerable<TUpdateWith> updateWith, Action<TTarget, TUpdateWith> update)
        {
            // 'using' guarantees the update enumerator is disposed, even on early exit.
            using (IEnumerator<TUpdateWith> updates = updateWith.GetEnumerator())
            {
                foreach (TTarget element in target)
                {
                    if (!updates.MoveNext())
                        yield break;
                    update(element, updates.Current);
                    yield return element;
                }
            }
        }

        /// <summary>
        /// Depth-first (pre-order) walk of every node in <paramref name="source"/> together with its
        /// descendants, limited to <paramref name="maxDepth"/> levels per root (the root counts as level 1).
        /// </summary>
        /// <param name="source">Root nodes.</param>
        /// <param name="childSelector">Returns a node's children; a null result is treated as no children.</param>
        /// <param name="maxDepth">Levels to include, counting each root; values &lt;= 0 yield nothing.</param>
        public static IEnumerable<TSource> DescendantsAndSelf<TSource>(this IEnumerable<TSource> source, Func<TSource, IEnumerable<TSource>> childSelector, int maxDepth) where TSource : class
        {
            if (maxDepth > 0)
            {
                foreach (TSource root in source)
                {
                    foreach (TSource node in root.DescendantsAndSelf(childSelector, maxDepth))
                    {
                        yield return node;
                    }
                }
            }
        }

        /// <summary>
        /// Depth-first (pre-order) walk yielding <paramref name="source"/> followed by its descendants.
        /// <paramref name="maxDepth"/> counts levels including the start node: 1 yields only the node,
        /// 2 adds its direct children, and so on; values &lt;= 0 yield nothing.
        /// </summary>
        /// <param name="source">Start node.</param>
        /// <param name="childSelector">Returns a node's children; a null result is treated as no children.</param>
        /// <param name="maxDepth">Levels to include, counting the start node.</param>
        public static IEnumerable<TSource> DescendantsAndSelf<TSource>(this TSource source, Func<TSource, IEnumerable<TSource>> childSelector, int maxDepth) where TSource : class
        {
            if (maxDepth > 0)
            {
                yield return source;
                foreach (TSource child in ChildrenOf(source, childSelector))
                {
                    foreach (TSource node in child.DescendantsAndSelf(childSelector, maxDepth - 1))
                    {
                        yield return node;
                    }
                }
            }
        }

        /// <summary>
        /// Yields the descendants (excluding the roots themselves) of every node in <paramref name="source"/>.
        /// Each root counts as level 1 of <paramref name="maxDepth"/>, so 1 yields nothing and 2 yields direct children.
        /// </summary>
        /// <param name="source">Root nodes.</param>
        /// <param name="childSelector">Returns a node's children; a null result is treated as no children.</param>
        /// <param name="maxDepth">Levels to consider, counting each root.</param>
        public static IEnumerable<TSource> Descendants<TSource>(this IEnumerable<TSource> source, Func<TSource, IEnumerable<TSource>> childSelector, int maxDepth) where TSource : class
        {
            if (maxDepth > 0)
            {
                foreach (TSource root in source)
                {
                    foreach (TSource node in root.Descendants(childSelector, maxDepth))
                    {
                        yield return node;
                    }
                }
            }
        }

        /// <summary>
        /// Depth-first (pre-order) walk of the descendants of <paramref name="source"/>, excluding
        /// <paramref name="source"/> itself. The start node counts as level 1 of <paramref name="maxDepth"/>,
        /// so 1 yields nothing and 2 yields direct children.
        /// </summary>
        /// <param name="source">Start node (not yielded).</param>
        /// <param name="childSelector">Returns a node's children; a null result is treated as no children.</param>
        /// <param name="maxDepth">Levels to consider, counting the start node.</param>
        public static IEnumerable<TSource> Descendants<TSource>(this TSource source, Func<TSource, IEnumerable<TSource>> childSelector, int maxDepth) where TSource : class
        {
            if (maxDepth > 0)
            {
                foreach (TSource child in ChildrenOf(source, childSelector))
                {
                    foreach (TSource node in child.DescendantsAndSelf(childSelector, maxDepth - 1))
                    {
                        yield return node;
                    }
                }
            }
        }

        /// <summary>Returns the children of <paramref name="node"/>, mapping a null selector result to an empty sequence.</summary>
        private static IEnumerable<TSource> ChildrenOf<TSource>(TSource node, Func<TSource, IEnumerable<TSource>> childSelector)
        {
            return childSelector(node) ?? Enumerable.Empty<TSource>();
        }

        /// <summary>
        /// Returns true when every item of <paramref name="items"/> also appears in
        /// <paramref name="containsItems"/> (true for an empty <paramref name="items"/>).
        /// </summary>
        public static bool OnlyContains<T>(this IEnumerable<T> items, IEnumerable<T> containsItems)
        {
            return !items.Except(containsItems).Any();
        }

        /// <summary>
        /// Returns true when every item of <paramref name="items"/> equals <paramref name="containsItem"/>
        /// (true for an empty <paramref name="items"/>).
        /// </summary>
        public static bool OnlyContains<T>(this IEnumerable<T> items, T containsItem)
        {
            return !items.Except(containsItem).Any();
        }

        /// <summary>Set union of <paramref name="items"/> with the single value <paramref name="item"/>.</summary>
        public static IEnumerable<T> Union<T>(this IEnumerable<T> items, T item)
        {
            return items.Union(new T[] { item });
        }

        /// <summary>Yields <paramref name="items"/> followed by the single value <paramref name="item"/>.</summary>
        public static IEnumerable<T> Concat<T>(this IEnumerable<T> items, T item)
        {
            return items.Concat(new T[] { item });
        }

        /// <summary>
        /// Yields the items not equal to <paramref name="item"/>, keeping duplicates and order
        /// (unlike the set-based <see cref="Enumerable.Except{TSource}(IEnumerable{TSource}, IEnumerable{TSource})"/>).
        /// Null is handled on either side.
        /// </summary>
        public static IEnumerable<T> Except<T>(this IEnumerable<T> items, T item)
        {
            EqualityComparer<T> comparer = EqualityComparer<T>.Default;
            return items.Where(p => !comparer.Equals(item, p));
        }

        /// <summary>
        /// Submits the changes, retrying up to <see cref="DefaultMaxRetries"/> time(s) on a SQL deadlock,
        /// waiting <see cref="DefaultRetryDelayMilliseconds"/> ms before each retry.
        /// </summary>
        /// <param name="context">Data context whose pending changes are submitted.</param>
        public static void SubmitWithRetry(this DataContext context)
        {
            context.SubmitWithRetry(DefaultMaxRetries, DefaultRetryDelayMilliseconds);
        }

        /// <summary>
        /// Calls <c>SubmitChanges</c> once, and again up to <paramref name="maxRetries"/> more times while
        /// it fails with a SQL deadlock. Any other <see cref="SqlException"/> is traced and rethrown immediately.
        /// </summary>
        /// <param name="context">Data context whose pending changes are submitted.</param>
        /// <param name="maxRetries">Retries after the first attempt; 0 means a single attempt.</param>
        /// <param name="retryDelayMilliseconds">Pause before each retry; values &lt;= 0 mean no pause.</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxRetries"/> is negative.</exception>
        /// <exception cref="ApplicationException">Every attempt deadlocked; the last deadlock is the inner exception.</exception>
        public static void SubmitWithRetry(this DataContext context, int maxRetries, int retryDelayMilliseconds)
        {
            if (context == null)
                throw new ArgumentNullException("context");
            if (maxRetries < 0)
                throw new ArgumentOutOfRangeException("maxRetries", maxRetries, "The number of retries cannot be negative.");

            SqlException lastDeadlock = null;
            // attempt 0 is the initial submit; attempts 1..maxRetries are the retries.
            for (int attempt = 0; attempt <= maxRetries; attempt++)
            {
                // Wait only between attempts, never after the final failure.
                if (attempt > 0 && retryDelayMilliseconds > 0)
                    Thread.Sleep(retryDelayMilliseconds);

                try
                {
                    context.SubmitChanges();
                    return;
                }
                catch (SqlException sqlEx)
                {
                    sqlEx.TraceException();
                    if (sqlEx.Class != DeadlockSeverity)
                    {
                        // Not a deadlock, so retrying will not help.
                        throw;
                    }
                    lastDeadlock = sqlEx;
                }
            }
            throw new ApplicationException("Maximum retries exceeded", lastDeadlock);
        }

        /// <summary>
        /// Writes the server, severity, number, procedure, line, message and stack trace of
        /// <paramref name="exception"/> to <see cref="Trace"/>. It writes an error when the severity is
        /// at or above 20, and a warning otherwise (deadlocks, constraint violations, etc.).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="exception"/> is null.</exception>
        public static void TraceException(this SqlException exception)
        {
            if (exception == null)
                throw new ArgumentNullException("exception");

            string message = @"A SQL Exception occurred on server {0}.
Severity: {1}
Error: {2}
Procedure: {3}
Line No: {4}

Message:
----------
{5}

Stack Trace:
-------------
{6}";
            object[] parameters = new object[]
            {
                exception.Server,
                exception.Class,
                exception.Number,
                exception.Procedure,
                exception.LineNumber,
                exception.Message,
                exception.StackTrace
            };

            if (exception.Class >= FatalSeverity)
            {
                Trace.TraceError(message, parameters);
            }
            else
            {
                // SQL deadlocks and stuff like Foreign Key Constraint failures etc.
                Trace.TraceWarning(message, parameters);
            }
        }
    }
}

Advertisements
  1. #1 by Rob on February 25, 2009 - 8:52 am

    awesome! I was looking for a method to get an element’s index and you actually wrote one to do that — ElementIndex<T>. Works great, thanks.

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. ( Log Out / Change )

Twitter picture

You are commenting using your Twitter account. ( Log Out / Change )

Facebook photo

You are commenting using your Facebook account. ( Log Out / Change )

Google+ photo

You are commenting using your Google+ account. ( Log Out / Change )

Connecting to %s

%d bloggers like this: