zhaolei
8 days ago 4a2e5b9a21940f11757be37d99f0944e240e908b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using System;
using System.Linq;
using System.Net;
using System.Threading.Tasks;
 
namespace Microsoft.AspNetCore.Builder
{
    /// <summary>
    /// IP 地址过滤器中间件
    /// </summary>
    public class IPAddressSafeListMiddleware
    {
        private readonly RequestDelegate _next;
        private readonly ILogger<IPAddressSafeListMiddleware> _logger;
        private readonly string _sectionName;
        /// <summary>
        /// 构造函数
        /// </summary>
        /// <param name="next"></param>
        /// <param name="logger"></param>
        /// <param name="sectionName">IP 白名单配置项名称</param>
        public IPAddressSafeListMiddleware(RequestDelegate next, ILogger<IPAddressSafeListMiddleware> logger, string sectionName)
        {
            _next = next;
            _logger = logger;
            _sectionName = sectionName;
        }
 
        /// <summary>
        /// 调用方法
        /// </summary>
        /// <param name="context"></param>
        /// <returns></returns>
        public async Task Invoke(HttpContext context)
        {
            var safeList = context.RequestServices.GetRequiredService<IConfiguration>().GetSafeList(_sectionName);
            var valid = context.Connection.RemoteIpAddress.ValidateSafeList(safeList);
            if (!valid)
            {
                _logger.LogInformation("Forbidden Request from Remote IP address: {RemoteIp}", context.Connection.RemoteIpAddress);
                context.Response.StatusCode = 403;
            }
            await _next.Invoke(context);
        }
    }
 
    /// <summary>
    /// IP 地址过滤器中间件扩展操作类
    /// </summary>
    public static class IPAddressSafeListExtensions
    {
        /// <summary>
        /// 添加 IP 地址过滤器中间件到管道中
        /// </summary>
        /// <param name="builder"></param>
        /// <param name="sectionName">IP 地址过滤器配置项名称</param>
        /// <returns></returns>
        public static IApplicationBuilder UseIPAddressSafeList(this IApplicationBuilder builder, string sectionName)
        {
            builder.UseMiddleware<IPAddressSafeListMiddleware>(sectionName);
            return builder;
        }
 
        /// <summary>
        /// IPAddress 过滤扩展方法 验证是否在白名单内
        /// </summary>
        /// <param name="ip"></param>
        /// <param name="safeList">安全 IP 列表</param>
        /// <returns></returns>
        internal static bool ValidateSafeList(this IPAddress ip, string safeList)
        {
            bool allow = true;
            if (!string.IsNullOrEmpty(safeList))
            {
                var allowIpList = safeList.SpanSplitAny(";,| ", StringSplitOptions.RemoveEmptyEntries);
                var bytes = ip.GetAddressBytes();
                allow = allowIpList.Any(p => IPAddress.TryParse(p, out var allowIp) && allowIp.GetAddressBytes().SequenceEqual(bytes));
            }
            return allow;
        }
 
        internal static string GetSafeList(this IConfiguration config, string sectionName) => sectionName.IsNullOrEmpty() ? "" : config.GetValue(sectionName, "");
    }
}