zhaolei
2020-11-20 4a2e5b9a21940f11757be37d99f0944e240e908b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authentication.OAuth;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Security.Claims;
using System.Text;
using System.Text.Encodings.Web;
using System.Threading.Tasks;
#if !NETSTANDARD2_0
using System.Text.Json;
#else
using Newtonsoft.Json.Linq;
#endif
 
namespace Prow.OAuth
{
    /// <summary>
    /// Gitee 认证处理类
    /// </summary>
    public class LgbOAuthHandler<TOptions> : OAuthHandler<TOptions> where TOptions : LgbOAuthOptions, new()
    {
        /// <summary>
        /// 默认构造函数
        /// </summary>
        /// <param name="options"></param>
        /// <param name="logger"></param>
        /// <param name="encoder"></param>
        /// <param name="clock"></param>
        public LgbOAuthHandler(IOptionsMonitor<TOptions> options, ILoggerFactory logger, UrlEncoder encoder, ISystemClock clock) : base(options, logger, encoder, clock)
        {
 
        }
 
        /// <summary>
        /// 生成票据方法
        /// </summary>
        /// <param name="identity"></param>
        /// <param name="properties"></param>
        /// <param name="tokens"></param>
        /// <returns></returns>
        protected override async Task<AuthenticationTicket> CreateTicketAsync(ClaimsIdentity identity, AuthenticationProperties properties, OAuthTokenResponse tokens)
        {
            properties.RedirectUri = Options.HomePath;
            properties.IsPersistent = true;
            if (!string.IsNullOrEmpty(tokens.ExpiresIn) && int.TryParse(tokens.ExpiresIn, out var expiresIn)) properties.ExpiresUtc = Clock.UtcNow.AddSeconds(expiresIn);
 
            // add roles
            Options.Roles.ToList().ForEach(r => identity.AddClaim(new Claim(ClaimTypes.Role, r)));
 
            // 获取用户信息
            var user = await HandleUserInfoAsync(tokens);
            var context = new OAuthCreatingTicketContext(new ClaimsPrincipal(identity), properties, Context, Scheme, Options, Backchannel, tokens, user);
            context.RunClaimActions();
 
            await Events.CreatingTicket(context);
            return new AuthenticationTicket(context.Principal, context.Properties, Scheme.Name);
        }
 
        /// <summary>
        /// 刷新 Token 方法
        /// </summary>
        /// <param name="oAuthToken"></param>
        /// <returns></returns>
        protected virtual async Task<OAuthTokenResponse> RefreshTokenAsync(OAuthTokenResponse oAuthToken)
        {
            var tokenRequestParameters = new Dictionary<string, string>()
            {
                { "refresh_token", oAuthToken.RefreshToken },
                { "grant_type", "refresh_token" },
            };
            var url = QueryHelpers.AddQueryString(Options.TokenEndpoint, tokenRequestParameters);
            var requestMessage = new HttpRequestMessage(HttpMethod.Post, url);
            requestMessage.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
            var response = await Backchannel.SendAsync(requestMessage, Context.RequestAborted);
            if (response.IsSuccessStatusCode)
            {
#if !NETSTANDARD2_0
                var payload = JsonDocument.Parse(await response.Content.ReadAsStringAsync());
#else
                var payload = JObject.Parse(await response.Content.ReadAsStringAsync());
#endif
                return OAuthTokenResponse.Success(payload);
            }
            else
            {
                var error = "OAuth token endpoint failure: " + await Display(response);
                return OAuthTokenResponse.Failed(new Exception(error));
            }
        }
 
        /// <summary>
        /// 生成错误信息方法
        /// </summary>
        /// <param name="response"></param>
        /// <returns></returns>
        protected static async Task<string> Display(HttpResponseMessage response)
        {
            var output = new StringBuilder();
            output.Append("Status: " + response.StatusCode + ";");
            output.Append("Headers: " + response.Headers.ToString() + ";");
            output.Append("Body: " + await response.Content.ReadAsStringAsync() + ";");
            return output.ToString();
        }
 
        /// <summary>
        /// 处理用户信息方法
        /// </summary>
        /// <param name="tokens"></param>
        /// <returns></returns>
#if !NETSTANDARD2_0
        protected virtual async Task<JsonElement> HandleUserInfoAsync(OAuthTokenResponse tokens)
#else
        protected virtual async Task<JObject> HandleUserInfoAsync(OAuthTokenResponse tokens)
#endif
        {
            var url = BuildUserInfoUrl(tokens);
            var requestMessage = new HttpRequestMessage(HttpMethod.Get, url);
            requestMessage.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
            var response = await Backchannel.SendAsync(requestMessage, Context.RequestAborted);
            if (response.IsSuccessStatusCode)
            {
#if !NETSTANDARD2_0
                return JsonDocument.Parse(await response.Content.ReadAsStringAsync()).RootElement;
#else
                return JObject.Parse(await response.Content.ReadAsStringAsync());
#endif
            }
            else
            {
                var error = "OAuth user information endpoint failure: " + await Display(response);
                throw new Exception(error);
            }
        }
 
        /// <summary>
        /// 生成用户信息请求地址方法
        /// </summary>
        /// <param name="tokens"></param>
        /// <returns></returns>
        protected virtual string BuildUserInfoUrl(OAuthTokenResponse tokens)
        {
            var tokenRequestParameters = new Dictionary<string, string>()
            {
                { "access_token", tokens.AccessToken }
            };
            return QueryHelpers.AddQueryString(Options.UserInformationEndpoint, tokenRequestParameters);
        }
    }
}