Merge pull request #104 from Snawoot/imp_fast_resolver

Improved fast resolver
This commit is contained in:
Snawoot
2025-09-24 00:45:37 +03:00
committed by GitHub
3 changed files with 53 additions and 34 deletions
+4 -4
View File
@@ -99,8 +99,8 @@ func probeDialer(ctx context.Context, dialer ContextDialer, url string, dlLimit
func NewFastestServerSelectionFunc(url string, dlLimit int64, tlsClientConfig *tls.Config) SelectionFunc { func NewFastestServerSelectionFunc(url string, dlLimit int64, tlsClientConfig *tls.Config) SelectionFunc {
return func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error) { return func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error) {
var resErr error var resErr error
masterNotInterested := make(chan struct{}) ctx, cl := context.WithCancel(ctx)
defer close(masterNotInterested) defer cl()
errors := make(chan error) errors := make(chan error)
success := make(chan ContextDialer) success := make(chan ContextDialer)
for _, dialer := range dialers { for _, dialer := range dialers {
@@ -109,12 +109,12 @@ func NewFastestServerSelectionFunc(url string, dlLimit int64, tlsClientConfig *t
if err == nil { if err == nil {
select { select {
case success <- dialer: case success <- dialer:
case <-masterNotInterested: case <-ctx.Done():
} }
} else { } else {
select { select {
case errors <- err: case errors <- err:
case <-masterNotInterested: case <-ctx.Done():
} }
} }
}(dialer) }(dialer)
+21 -4
View File
@@ -10,14 +10,23 @@ import (
) )
func FromURL(u string) (*net.Resolver, error) { func FromURL(u string) (*net.Resolver, error) {
begin:
parsed, err := url.Parse(u) parsed, err := url.Parse(u)
if err != nil { if err != nil {
return nil, err return nil, err
} }
host := parsed.Hostname() host := parsed.Hostname()
port := parsed.Port() port := parsed.Port()
switch strings.ToLower(parsed.Scheme) { switch scheme := strings.ToLower(parsed.Scheme); scheme {
case "", "dns": case "":
switch {
case strings.HasPrefix(u, "//"):
u = "dns:" + u
default:
u = "dns://" + u
}
goto begin
case "udp", "dns":
if port == "" { if port == "" {
port = "53" port = "53"
} }
@@ -27,12 +36,20 @@ func FromURL(u string) (*net.Resolver, error) {
port = "53" port = "53"
} }
return NewTCPResolver(net.JoinHostPort(host, port)), nil return NewTCPResolver(net.JoinHostPort(host, port)), nil
case "http", "https": case "http", "https", "doh":
if port == "" { if port == "" {
if scheme == "http" {
port = "80"
} else {
port = "443" port = "443"
} }
}
if scheme == "doh" {
parsed.Scheme = "https"
u = parsed.String()
}
return dns.NewDoHResolver(u, dns.DoHAddresses(net.JoinHostPort(host, port))) return dns.NewDoHResolver(u, dns.DoHAddresses(net.JoinHostPort(host, port)))
case "tls": case "tls", "dot":
if port == "" { if port == "" {
port = "853" port = "853"
} }
+27 -25
View File
@@ -16,12 +16,7 @@ type FastResolver struct {
upstreams []LookupNetIPer upstreams []LookupNetIPer
} }
type lookupReply struct { func FastFromURLs(urls ...string) (LookupNetIPer, error) {
addrs []netip.Addr
err error
}
func FastFromURLs(urls ...string) (*FastResolver, error) {
resolvers := make([]LookupNetIPer, 0, len(urls)) resolvers := make([]LookupNetIPer, 0, len(urls))
for i, u := range urls { for i, u := range urls {
res, err := FromURL(u) res, err := FromURL(u)
@@ -30,6 +25,9 @@ func FastFromURLs(urls ...string) (*FastResolver, error) {
} }
resolvers = append(resolvers, res) resolvers = append(resolvers, res)
} }
if len(resolvers) == 1 {
return resolvers[0], nil
}
return NewFastResolver(resolvers...), nil return NewFastResolver(resolvers...), nil
} }
@@ -41,32 +39,36 @@ func NewFastResolver(resolvers ...LookupNetIPer) *FastResolver {
func (r FastResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { func (r FastResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
ctx, cl := context.WithCancel(ctx) ctx, cl := context.WithCancel(ctx)
drain := make(chan lookupReply, len(r.upstreams)) defer cl()
errors := make(chan error)
success := make(chan []netip.Addr)
for _, res := range r.upstreams { for _, res := range r.upstreams {
go func(res LookupNetIPer) { go func(res LookupNetIPer) {
addrs, err := res.LookupNetIP(ctx, network, host) addrs, err := res.LookupNetIP(ctx, network, host)
drain <- lookupReply{addrs, err} if err == nil {
select {
case success <- addrs:
case <-ctx.Done():
}
} else {
select {
case errors <- err:
case <-ctx.Done():
}
}
}(res) }(res)
} }
i := 0
var resAddrs []netip.Addr
var resErr error var resErr error
for ; i < len(r.upstreams); i++ { for _ = range r.upstreams {
pair := <-drain select {
if pair.err != nil { case <-ctx.Done():
resErr = multierror.Append(resErr, pair.err) return nil, ctx.Err()
} else { case resAddrs := <-success:
cl() return resAddrs, nil
resAddrs = pair.addrs case err := <-errors:
resErr = nil resErr = multierror.Append(resErr, err)
break
} }
} }
go func() { return nil, resErr
for i = i + 1; i < len(r.upstreams); i++ {
<-drain
}
}()
return resAddrs, resErr
} }