1+ using System ;
2+ using System . Collections . Generic ;
3+ using System . IO ;
4+ using System . Linq ;
5+ using System . Net . Http ;
6+ using System . Threading ;
7+ using System . Threading . Tasks ;
8+ using System . Net . Sockets ;
9+ using Ae . Dns . Client ;
10+ using Ae . Dns . Protocol ;
11+ using Ae . Dns . Protocol . Enums ;
12+ using Ae . Dns . Protocol . Records ;
13+ using PCL . Core . Logging ;
14+ using System . Runtime . Caching ;
15+
16+ namespace PCL . Core . Net ;
17+
18+ public class HostConnectionHandler
19+ {
20+ public static HostConnectionHandler Instance { get ; } = new ( ) ;
21+ private const string ModuleName = "DoH" ;
22+
23+ private static IDnsClient ? _resolver ;
24+
25+ private HostConnectionHandler ( )
26+ {
27+ // 使用Ae.Dns创建DoH客户端,支持多个DoH服务器
28+ IDnsClient [ ] clients =
29+ [
30+ new DnsHttpClient ( new HttpClient ( )
31+ {
32+ BaseAddress = new Uri ( "https://doh.pub/" )
33+ } ) ,
34+ new DnsHttpClient ( new HttpClient ( )
35+ {
36+ BaseAddress = new Uri ( "https://doh.pysio.online/" )
37+ } ) ,
38+ new DnsHttpClient ( new HttpClient ( )
39+ {
40+ BaseAddress = new Uri ( "https://cloudflare-dns.com/" )
41+ } )
42+ ] ;
43+
44+ // 使用DnsRacerClient实现快速获胜策略
45+ _resolver = new DnsCachingClient ( new DnsRacerClient ( clients ) , new MemoryCache ( "DNS Query Cache" ) ) ;
46+ }
47+
48+ public async ValueTask < Stream > GetConnectionAsync ( SocketsHttpConnectionContext context , CancellationToken cts )
49+ {
50+ ArgumentNullException . ThrowIfNull ( _resolver , "_resolver != null" ) ;
51+ // 获取主机名和端口
52+ var host = context . DnsEndPoint . Host ;
53+ var port = context . DnsEndPoint . Port ;
54+
55+ // 使用Ae.Dns解析IPv4和IPv6地址
56+
57+ var queryA = _resolver . Query ( DnsQueryFactory . CreateQuery ( host ) , cts ) ;
58+ var queryAAAA = _resolver . Query ( DnsQueryFactory . CreateQuery ( host , DnsQueryType . AAAA ) , cts ) ;
59+
60+ var resolveTasks = new List < Task < DnsMessage > > ( )
61+ {
62+ queryA ,
63+ queryAAAA
64+ } ;
65+
66+ var results = await Task . WhenAll ( resolveTasks ) . ConfigureAwait ( false ) ;
67+ var addresses = ( from result in results
68+ from answer in result . Answers
69+ where answer . Resource is DnsIpAddressResource
70+ select ( ( answer . Resource as DnsIpAddressResource ) ! ) . IPAddress ) . ToArray ( ) ;
71+
72+ if ( addresses . Length == 0 )
73+ throw new HttpRequestException ( $ "No IP address for { host } ") ;
74+
75+ // 并行连接所有地址,返回第一个成功的连接
76+ var connectionTasks = addresses . Select ( ip => _ConnectToAddressAsync ( ip . ToString ( ) , port , cts ) ) . ToList ( ) ;
77+
78+ try
79+ {
80+ var completedTask = await Task . WhenAny ( connectionTasks ) . ConfigureAwait ( false ) ;
81+ var stream = await completedTask . ConfigureAwait ( false ) ;
82+
83+ // 取消其他连接任务
84+ foreach ( var task in connectionTasks . Where ( task => task != completedTask ) )
85+ {
86+ _ = task . ContinueWith ( t => {
87+ if ( t . IsCompletedSuccessfully )
88+ {
89+ t . Result ? . Dispose ( ) ;
90+ }
91+ } , cts ) ;
92+ }
93+
94+ LogWrapper . Debug ( ModuleName , $ "Success resolve DoH endpoint: { host } -> { stream . Socket . RemoteEndPoint } ") ;
95+ return stream ;
96+ }
97+ catch
98+ {
99+ throw new HttpRequestException ( $ "No address reachable: { host } -> { string . Join ( ", " , addresses . Select ( x => x . ToString ( ) ) ) } ") ;
100+ }
101+ }
102+
103+ private static async Task < NetworkStream > _ConnectToAddressAsync ( string ip , int port , CancellationToken cts )
104+ {
105+ var socket = new Socket ( SocketType . Stream , ProtocolType . Tcp ) ;
106+ try
107+ {
108+ using var ctsSocket = CancellationTokenSource . CreateLinkedTokenSource ( cts ) ;
109+ ctsSocket . CancelAfter ( 5000 ) ; // 5秒超时
110+
111+ await socket . ConnectAsync ( ip , port , ctsSocket . Token ) ;
112+ return new NetworkStream ( socket , ownsSocket : true ) ;
113+ }
114+ catch
115+ {
116+ socket . Dispose ( ) ;
117+ throw ;
118+ }
119+ }
120+ }
0 commit comments