implement concurrent bootstrap DNS resolver

This commit is contained in:
Vladislav Yarmak
2024-08-01 12:44:23 +03:00
parent 34dde845dd
commit 9125854fa9
2 changed files with 33 additions and 69 deletions
+21 -66
View File
@@ -1,83 +1,38 @@
package main
import (
"context"
"fmt"
"net/netip"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
type Resolver struct {
upstream upstream.Upstream
resolvers upstream.ParallelResolver
timeout time.Duration
}
const DOT = 0x2e
func NewResolver(address string, timeout time.Duration) (*Resolver, error) {
opts := &upstream.Options{Timeout: timeout}
u, err := upstream.AddressToUpstream(address, opts)
if err != nil {
return nil, err
func NewResolver(addresses []string, timeout time.Duration) (*Resolver, error) {
resolvers := make([]upstream.Resolver, 0, len(addresses))
opts := &upstream.Options{
Timeout: timeout,
}
return &Resolver{upstream: u}, nil
}
func (r *Resolver) ResolveA(domain string) []string {
res := make([]string, 0)
if len(domain) == 0 {
return res
}
if domain[len(domain)-1] != DOT {
domain = domain + "."
}
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: domain, Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
reply, err := r.upstream.Exchange(&req)
if err != nil {
return res
}
for _, rr := range reply.Answer {
if a, ok := rr.(*dns.A); ok {
res = append(res, a.A.String())
for _, addr := range addresses {
u, err := upstream.AddressToUpstream(addr, opts)
if err != nil {
return nil, fmt.Errorf("unable to construct upstream resolver from string %q: %w",
addr, err)
}
resolvers = append(resolvers, &upstream.UpstreamResolver{Upstream: u})
}
return res
return &Resolver{
resolvers: resolvers,
timeout: timeout,
}, nil
}
func (r *Resolver) ResolveAAAA(domain string) []string {
res := make([]string, 0)
if len(domain) == 0 {
return res
}
if domain[len(domain)-1] != DOT {
domain = domain + "."
}
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: domain, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET},
}
reply, err := r.upstream.Exchange(&req)
if err != nil {
return res
}
for _, rr := range reply.Answer {
if a, ok := rr.(*dns.AAAA); ok {
res = append(res, a.AAAA.String())
}
}
return res
}
func (r *Resolver) Resolve(domain string) []string {
res := r.ResolveA(domain)
if len(res) == 0 {
res = r.ResolveAAAA(domain)
}
return res
func (r *Resolver) LookupNetIP(ctx context.Context, network string, host string) (addrs []netip.Addr, err error) {
return r.resolvers.LookupNetIP(ctx, network, host)
}