ASP.NET Core Attribute to Rate-Limit an Endpoint

Untitled

An easy code snippet that might help you keep an endpoint from being DDoS-ed. It uses MemoryCache, which of course could also be swapped for e.g. a Redis-backed distributed cache, but as written it uses a hash table with an int key and a boolean, so it is very lightweight and works on a fail-fast principle (instead of consulting backend storage first). Also keep in mind that it works in memory, so it does not protect an array of endpoints: each server instance keeps its own cache and throttles independently.


Usage:


      [HttpGet("myurl")]
        [ThrottleFilter(Name= "myurl", Milliseconds = 5*1000 )]
        public async Task<ActionResult> MyUrl()





using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.Filters;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
using System;
using System.IO;
using System.Net;

namespace AdcCure.ActionFilters
{
    /// <summary>
    /// Decorates any MVC route that needs to have client requests limited by time.
    /// </summary>
    /// <remarks>
    /// Uses the current Microsoft.Extensions.Caching.Memory extensions
    /// to store each client request to the decorated route, using the least amount of memory requirements possible.
    /// Each (Name, client IP) pair is reduced to a 32-bit hash, and the cache stores a boolean under that key for
    /// the duration of the throttle window. The state lives in the process's memory cache, so several instances
    /// behind a load balancer each throttle independently.
    /// </remarks>
    [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
    public class ThrottleFilterAttribute : ActionFilterAttribute
    {
        /// <summary>Header a reverse proxy uses to report the originating client address.</summary>
        private const string ForwardedForHeader = "X-Forwarded-For";

        /// <summary>
        /// Namespaces this filter's cache keys so they cannot collide with (and overwrite) other entries,
        /// such as plain <see cref="int"/> keys, stored in the shared <see cref="IMemoryCache"/>.
        /// </summary>
        private const string CacheKeyPrefix = nameof(ThrottleFilterAttribute);

        /// <summary>
        /// A unique name for this Throttle.
        /// </summary>
        /// <remarks>
        /// The cache key is derived from a hash of the UTF-8 bytes of this name followed by the client IP address
        /// bytes, so different throttles applied to the same client do not interfere with each other.
        /// The entry only exists for [Milliseconds], so it is a minimal resource requirement.
        /// </remarks>
        public string Name { get; set; }

        /// <summary>
        /// The number of Milliseconds clients must wait before executing this decorated route again.
        /// Every request, including rejected ones, restarts the window.
        /// A value of zero or less disables throttling.
        /// </summary>
        public int Milliseconds { get; set; }

        /// <summary>
        /// A text message that will be sent to the client upon throttling. You can include the token {n} to
        /// show the throttle window in seconds, e.g. "Wait {n} seconds before trying again".
        /// When empty, a default message stating the window in milliseconds is sent.
        /// </summary>
        public string Message { get; set; }

        /// <summary>
        /// Records the request in the memory cache. If the same client already hit this throttle within the
        /// window, it short-circuits the action with a 429 plain-text response.
        /// </summary>
        /// <param name="c">The executing action's context.</param>
        /// <exception cref="InvalidOperationException">No <see cref="IMemoryCache"/> is registered.</exception>
        public override void OnActionExecuting(ActionExecutingContext c)
        {
            var services = c.HttpContext.RequestServices;
            if (!(services.GetService(typeof(IMemoryCache)) is IMemoryCache memCache))
            {
                throw new InvalidOperationException($"{this.GetType()} configure a memory cache in startup");
            }

            // MemoryCacheEntryOptions rejects a non-positive relative expiration, so a zero/negative window
            // would otherwise throw on every request; treat it as "no throttling".
            if (Milliseconds <= 0)
            {
                return;
            }

            var logger = (services.GetService(typeof(ILoggerFactory)) as ILoggerFactory)?.CreateLogger<ThrottleFilterAttribute>();

            var ip = ResolveClientAddress(c.HttpContext);
            var hash = ComputeHash(BuildKeyBytes(Name, ip));
            var key = (CacheKeyPrefix, hash);

            memCache.TryGetValue(key, out bool forbidExecute);

            // Always (re)arm the entry: a client that keeps calling stays throttled until it has been quiet
            // for a whole window.
            memCache.Set(key, true, new MemoryCacheEntryOptions { AbsoluteExpirationRelativeToNow = TimeSpan.FromMilliseconds(Milliseconds) });
            logger?.LogDebug("key for Throttle {0}", hash);

            if (!forbidExecute)
            {
                return;
            }

            logger?.LogWarning("ip {0}, route {1} was denied execute at {2}", ip, Name ?? "unknown", DateTimeOffset.Now);
            c.Result = new ContentResult
            {
                Content = BuildMessage(),
                ContentType = "text/plain",
                // 429 Too Many Requests (RFC 6585, section 4) is the status defined for rate limiting.
                StatusCode = StatusCodes.Status429TooManyRequests,
            };
        }

        /// <summary>
        /// Determines the address that identifies the caller. This is the first (left-most) entry of
        /// X-Forwarded-For when it parses as an IP address, otherwise the connection's remote address.
        /// </summary>
        /// <remarks>
        /// Unless a trusted proxy overwrites it, X-Forwarded-For is supplied by the client, so a caller can
        /// vary it to evade the throttle. Consider ASP.NET Core's forwarded-headers middleware with known
        /// proxies configured if that matters.
        /// </remarks>
        /// <returns>The client address, or <c>null</c> when the host provides none.</returns>
        private static IPAddress ResolveClientAddress(HttpContext httpContext)
        {
            if (httpContext.Request.Headers.TryGetValue(ForwardedForHeader, out var forwarded))
            {
                // The header may carry a chain "client, proxy1, proxy2" and may also be repeated;
                // StringValues.ToString() joins repeated values with commas.
                var first = forwarded.ToString().Split(',')[0].Trim();
                if (IPAddress.TryParse(first, out var parsed))
                {
                    return parsed;
                }
            }

            return httpContext.Connection.RemoteIpAddress;
        }

        /// <summary>
        /// Concatenates the UTF-8 bytes of <paramref name="name"/> (if any) and the raw bytes of
        /// <paramref name="address"/> (if any) into the buffer that is hashed into the cache key.
        /// </summary>
        private static byte[] BuildKeyBytes(string name, IPAddress address)
        {
            var nameBytes = string.IsNullOrEmpty(name) ? Array.Empty<byte>() : System.Text.Encoding.UTF8.GetBytes(name);

            // RemoteIpAddress can be null on some hosts; such requests share one bucket per throttle name
            // instead of crashing the filter.
            var addressBytes = address?.GetAddressBytes() ?? Array.Empty<byte>();

            var buffer = new byte[nameBytes.Length + addressBytes.Length];
            Buffer.BlockCopy(nameBytes, 0, buffer, 0, nameBytes.Length);
            Buffer.BlockCopy(addressBytes, 0, buffer, nameBytes.Length, addressBytes.Length);
            return buffer;
        }

        /// <summary>
        /// Builds the throttling response text. It does not modify <see cref="Message"/>, because a filter
        /// attribute instance is typically reused across requests, so per-request values stay local.
        /// </summary>
        private string BuildMessage()
        {
            if (string.IsNullOrEmpty(Message))
            {
                return $"You may only perform this action every {Milliseconds}ms.";
            }

            // Honour the documented {n} token: the throttle window expressed in seconds.
            var seconds = TimeSpan.FromMilliseconds(Milliseconds).TotalSeconds;
            return Message.Replace("{n}", seconds.ToString());
        }

        /// <summary>
        /// Hashes a byte array to a 32-bit value. The result depends only on the bytes, so it is identical
        /// in every process (unlike <see cref="string.GetHashCode()"/>).
        /// </summary>
        /// <remarks>
        /// The core loop is 32-bit FNV-1a (offset basis 2166136261, prime 16777619). It is followed by extra
        /// shift/xor/add mixing steps that spread the bits further.
        /// </remarks>
        /// <param name="data">Bytes to hash; may be empty.</param>
        /// <returns>The 32-bit hash.</returns>
        private static int ComputeHash(byte[] data)
        {
            unchecked
            {
                const int p = 16777619;
                int hash = (int)2166136261;
                var len = data.Length;
                for (int i = 0; i < len; i++)
                {
                    hash = (hash ^ data[i]) * p;
                }

                hash += hash << 13;
                hash ^= hash >> 7;
                hash += hash << 3;
                hash ^= hash >> 17;
                hash += hash << 5;
                return hash;
            }
        }
    }
}

blog comments powered by Disqus