diff --git a/resolver/fast.go b/resolver/fast.go index 3bef339..eadb01b 100644 --- a/resolver/fast.go +++ b/resolver/fast.go @@ -16,11 +16,6 @@ type FastResolver struct { upstreams []LookupNetIPer } -type lookupReply struct { - addrs []netip.Addr - err error -} - func FastFromURLs(urls ...string) (*FastResolver, error) { resolvers := make([]LookupNetIPer, 0, len(urls)) for i, u := range urls { @@ -40,33 +35,37 @@ func NewFastResolver(resolvers ...LookupNetIPer) *FastResolver { } func (r FastResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { - ctx, cl := context.WithCancel(ctx) - drain := make(chan lookupReply, len(r.upstreams)) + masterNotInterested := make(chan struct{}) + defer close(masterNotInterested) + errors := make(chan error) + success := make(chan []netip.Addr) for _, res := range r.upstreams { go func(res LookupNetIPer) { addrs, err := res.LookupNetIP(ctx, network, host) - drain <- lookupReply{addrs, err} + if err == nil { + select { + case success <- addrs: + case <-masterNotInterested: + } + } else { + select { + case errors <-err: + case <-masterNotInterested: + } + } }(res) } - i := 0 - var resAddrs []netip.Addr var resErr error - for ; i < len(r.upstreams); i++ { - pair := <-drain - if pair.err != nil { - resErr = multierror.Append(resErr, pair.err) - } else { - cl() - resAddrs = pair.addrs - resErr = nil - break + for _ = range r.upstreams { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case resAddrs := <-success: + return resAddrs, nil + case err := <-errors: + resErr = multierror.Append(resErr, err) } } - go func() { - for i = i + 1; i < len(r.upstreams); i++ { - <-drain - } - }() - return resAddrs, resErr + return nil, resErr }