basic implementation of server selection feature

This commit is contained in:
Vladislav Yarmak
2025-09-14 15:16:02 +03:00
parent 7bfaeb9878
commit 6eb2054faf
2 changed files with 265 additions and 74 deletions
+131
View File
@@ -0,0 +1,131 @@
package dialer
import (
"context"
"errors"
"fmt"
"io"
"math/rand/v2"
"net/http"
"strings"
"time"
"github.com/hashicorp/go-multierror"
)
type ServerSelection int
const (
_ = iota
ServerSelectionFirst
ServerSelectionRandom
ServerSelectionFastest
)
func (ss ServerSelection) String() string {
switch ss {
case ServerSelectionFirst:
return "first"
case ServerSelectionRandom:
return "random"
case ServerSelectionFastest:
return "fastest"
default:
return fmt.Sprintf("ServerSelection(%d)", int(ss))
}
}
func ParseServerSelection(s string) (ServerSelection, error) {
switch strings.ToLower(s) {
case "first":
return ServerSelectionFirst, nil
case "random":
return ServerSelectionRandom, nil
case "fastest":
return ServerSelectionFastest, nil
}
return 0, errors.New("unknown server selection strategy")
}
type SelectionFunc = func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error)
func SelectFirst(_ context.Context, dialers []ContextDialer) (ContextDialer, error) {
if len(dialers) == 0 {
return nil, errors.New("empty dialers list")
}
return dialers[0], nil
}
func SelectRandom(_ context.Context, dialers []ContextDialer) (ContextDialer, error) {
if len(dialers) == 0 {
return nil, errors.New("empty dialers list")
}
return dialers[rand.IntN(len(dialers))], nil
}
func probeDialer(ctx context.Context, dialer ContextDialer, url string, dlLimit int64) error {
httpClient := http.Client{
Transport: &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: dialer.DialContext,
},
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
resp, err := httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status code %d for URL %q", resp.StatusCode, url)
}
var rd io.Reader = resp.Body
if dlLimit > 0 {
rd = io.LimitReader(rd, dlLimit)
}
_, err = io.Copy(io.Discard, rd)
return err
}
func NewFastestServerSelectionFunc(url string, dlLimit int64) SelectionFunc {
return func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error) {
var resErr error
masterNotInterested := make(chan struct{})
defer close(masterNotInterested)
errors := make(chan error)
success := make(chan ContextDialer)
for _, dialer := range dialers {
go func(dialer ContextDialer) {
err := probeDialer(ctx, dialer, url, dlLimit)
if err == nil {
select {
case success <- dialer:
case <-masterNotInterested:
}
} else {
select {
case errors <- err:
case <-masterNotInterested:
}
}
}(dialer)
}
for _ = range dialers {
select {
case <-ctx.Done():
return nil, ctx.Err()
case d := <-success:
return d, nil
case err := <-errors:
resErr = multierror.Append(resErr, err)
}
}
return nil, resErr
}
}
+134 -74
View File
@@ -81,31 +81,52 @@ func (a *CSVArg) Set(line string) error {
return nil return nil
} }
type serverSelectionArg struct {
value dialer.ServerSelection
}
func (a *serverSelectionArg) Set(s string) error {
v, err := dialer.ParseServerSelection(s)
if err != nil {
return err
}
a.value = v
return nil
}
func (a *serverSelectionArg) String() string {
return a.value.String()
}
type CLIArgs struct { type CLIArgs struct {
country string country string
listCountries bool listCountries bool
listProxies bool listProxies bool
bindAddress string bindAddress string
socksMode bool socksMode bool
verbosity int verbosity int
timeout time.Duration timeout time.Duration
showVersion bool showVersion bool
proxy string proxy string
apiLogin string apiLogin string
apiPassword string apiPassword string
apiAddress string apiAddress string
apiClientType string apiClientType string
apiClientVersion string apiClientVersion string
apiUserAgent string apiUserAgent string
bootstrapDNS *CSVArg bootstrapDNS *CSVArg
refresh time.Duration refresh time.Duration
refreshRetry time.Duration refreshRetry time.Duration
initRetries int initRetries int
initRetryInterval time.Duration initRetryInterval time.Duration
certChainWorkaround bool certChainWorkaround bool
caFile string caFile string
fakeSNI string fakeSNI string
overrideProxyAddress string overrideProxyAddress string
serverSelection serverSelectionArg
serverSelectionTimeout time.Duration
serverSelectionTestURL string
serverSelectionDLLimit int64
} }
func parse_args() *CLIArgs { func parse_args() *CLIArgs {
@@ -123,6 +144,7 @@ func parse_args() *CLIArgs {
"https://doh.cleanbrowsing.org/doh/adult-filter/", "https://doh.cleanbrowsing.org/doh/adult-filter/",
}, },
}, },
serverSelection: serverSelectionArg{dialer.ServerSelectionFastest},
} }
flag.StringVar(&args.country, "country", "EU", "desired proxy location") flag.StringVar(&args.country, "country", "EU", "desired proxy location")
flag.BoolVar(&args.listCountries, "list-countries", false, "list available countries and exit") flag.BoolVar(&args.listCountries, "list-countries", false, "list available countries and exit")
@@ -155,6 +177,11 @@ func parse_args() *CLIArgs {
flag.StringVar(&args.caFile, "cafile", "", "use custom CA certificate bundle file") flag.StringVar(&args.caFile, "cafile", "", "use custom CA certificate bundle file")
flag.StringVar(&args.fakeSNI, "fake-SNI", "", "domain name to use as SNI in communications with servers") flag.StringVar(&args.fakeSNI, "fake-SNI", "", "domain name to use as SNI in communications with servers")
flag.StringVar(&args.overrideProxyAddress, "override-proxy-address", "", "use fixed proxy address instead of server address returned by SurfEasy API") flag.StringVar(&args.overrideProxyAddress, "override-proxy-address", "", "use fixed proxy address instead of server address returned by SurfEasy API")
flag.Var(&args.serverSelection, "server-selection", "server selection policy (first/random/fastest)")
flag.DurationVar(&args.serverSelectionTimeout, "server-selection-timeout", 30*time.Second, "timeout given for server selection function to produce result")
flag.StringVar(&args.serverSelectionTestURL, "server-selection-test-url", "https://ajax.googleapis.com/ajax/libs/angularjs/1.8.2/angular.min.js",
"URL used for download benchmark by fastest server selection policy")
flag.Int64Var(&args.serverSelectionDLLimit, "server-selection-dl-limit", 0, "restrict amount of downloaded data per connection by fastest server selection")
flag.Parse() flag.Parse()
if args.country == "" { if args.country == "" {
arg_fail("Country can't be empty string.") arg_fail("Country can't be empty string.")
@@ -200,6 +227,20 @@ func run() int {
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
} }
var caPool *x509.CertPool
if args.caFile != "" {
caPool = x509.NewCertPool()
certs, err := ioutil.ReadFile(args.caFile)
if err != nil {
mainLogger.Error("Can't load CA file: %v", err)
return 15
}
if ok := caPool.AppendCertsFromPEM(certs); !ok {
mainLogger.Error("Can't load certificates from CA file")
return 15
}
}
if args.proxy != "" { if args.proxy != "" {
xproxy.RegisterDialerType("http", proxyFromURLWrapper) xproxy.RegisterDialerType("http", proxyFromURLWrapper)
xproxy.RegisterDialerType("https", proxyFromURLWrapper) xproxy.RegisterDialerType("https", proxyFromURLWrapper)
@@ -218,7 +259,7 @@ func run() int {
seclientDialer := d seclientDialer := d
if args.apiAddress != "" { if args.apiAddress != "" {
mainLogger.Info("Using fixed API host IP address = %s", args.apiAddress) mainLogger.Info("Using fixed API host address = %s", args.apiAddress)
seclientDialer = dialer.NewFixedDialer(args.apiAddress, d) seclientDialer = dialer.NewFixedDialer(args.apiAddress, d)
} else if len(args.bootstrapDNS.values) > 0 { } else if len(args.bootstrapDNS.values) > 0 {
resolver, err := resolver.FastFromURLs(args.bootstrapDNS.values...) resolver, err := resolver.FastFromURLs(args.bootstrapDNS.values...)
@@ -282,28 +323,74 @@ func run() int {
return printCountries(try, mainLogger, args.timeout, seclient) return printCountries(try, mainLogger, args.timeout, seclient)
} }
handlerDialerFactory := func(endpointAddr string) dialer.ContextDialer {
return dialer.NewProxyDialer(
dialer.WrapStringToCb(endpointAddr),
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
dialer.WrapStringToCb(args.fakeSNI),
func() (string, error) {
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
},
args.certChainWorkaround,
caPool,
d)
}
var ips []se.SEIPEntry var ips []se.SEIPEntry
err = try("discover", func() error { var handlerDialer dialer.ContextDialer
ctx, cl := context.WithTimeout(context.Background(), args.timeout)
defer cl() if args.overrideProxyAddress == "" || args.listProxies {
// TODO: learn about requested_geo value format err = try("discover", func() error {
res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country)) ctx, cl := context.WithTimeout(context.Background(), args.timeout)
ips = res defer cl()
return err // TODO: learn about requested_geo value format
}) res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country))
if err != nil { if err != nil {
return 12 return err
}
if len(res) == 0 {
return errors.New("empty endpoints list!")
}
if args.listProxies {
ips = res
return nil
}
var ss dialer.SelectionFunc
switch args.serverSelection.value {
case dialer.ServerSelectionFirst:
ss = dialer.SelectFirst
case dialer.ServerSelectionRandom:
ss = dialer.SelectRandom
case dialer.ServerSelectionFastest:
ss = dialer.NewFastestServerSelectionFunc(
args.serverSelectionTestURL,
args.serverSelectionDLLimit,
)
default:
panic("unhandled server selection value got past parsing")
}
dialers := make([]dialer.ContextDialer, len(res))
for i, ep := range res {
dialers[i] = handlerDialerFactory(ep.NetAddr())
}
ctx, cl = context.WithTimeout(context.Background(), args.serverSelectionTimeout)
defer cl()
handlerDialer, err = ss(ctx, dialers)
return err
})
if err != nil {
return 12
}
} else {
sanitizedEndpoint := sanitizeFixedProxyAddress(args.overrideProxyAddress)
handlerDialer = handlerDialerFactory(sanitizedEndpoint)
mainLogger.Info("Endpoint override: %s", sanitizedEndpoint)
} }
if args.listProxies { if args.listProxies {
return printProxies(ips, seclient) return printProxies(ips, seclient)
} }
if len(ips) == 0 {
mainLogger.Critical("Empty endpoint list!")
return 13
}
clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error { clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error {
mainLogger.Info("Refreshing login...") mainLogger.Info("Refreshing login...")
reqCtx, cl := context.WithTimeout(ctx, args.timeout) reqCtx, cl := context.WithTimeout(ctx, args.timeout)
@@ -327,40 +414,6 @@ func run() int {
return nil return nil
}) })
endpoint := ips[0]
var caPool *x509.CertPool
if args.caFile != "" {
caPool = x509.NewCertPool()
certs, err := ioutil.ReadFile(args.caFile)
if err != nil {
mainLogger.Error("Can't load CA file: %v", err)
return 15
}
if ok := caPool.AppendCertsFromPEM(certs); !ok {
mainLogger.Error("Can't load certificates from CA file")
return 15
}
}
var handlerBaseDialer = d
if args.overrideProxyAddress != "" {
mainLogger.Info("Original endpoint: %s", endpoint.IP)
handlerBaseDialer = dialer.NewFixedDialer(args.overrideProxyAddress, handlerBaseDialer)
mainLogger.Info("Endpoint override: %s", args.overrideProxyAddress)
} else {
mainLogger.Info("Endpoint: %s", endpoint.NetAddr())
}
handlerDialer := dialer.NewProxyDialer(
dialer.WrapStringToCb(endpoint.NetAddr()),
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
dialer.WrapStringToCb(args.fakeSNI),
func() (string, error) {
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
},
args.certChainWorkaround,
caPool,
handlerBaseDialer)
mainLogger.Info("Starting proxy server...") mainLogger.Info("Starting proxy server...")
if args.socksMode { if args.socksMode {
socks, initError := handler.NewSocksServer(handlerDialer, socksLogger) socks, initError := handler.NewSocksServer(handlerDialer, socksLogger)
@@ -423,6 +476,13 @@ func printProxies(ips []se.SEIPEntry, seclient *se.SEClient) int {
return 0 return 0
} }
func sanitizeFixedProxyAddress(addr string) string {
if _, _, err := net.SplitHostPort(addr); err == nil {
return addr
}
return net.JoinHostPort(addr, "443")
}
func main() { func main() {
os.Exit(run()) os.Exit(run())
} }