From 6eb2054faf5c1022f9fb69246577b74700ee5013 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Sun, 14 Sep 2025 15:16:02 +0300 Subject: [PATCH 1/3] basic implementation of server selection feature --- dialer/selection.go | 131 ++++++++++++++++++++++++++++ main.go | 208 ++++++++++++++++++++++++++++---------------- 2 files changed, 265 insertions(+), 74 deletions(-) create mode 100644 dialer/selection.go diff --git a/dialer/selection.go b/dialer/selection.go new file mode 100644 index 0000000..75dd7fc --- /dev/null +++ b/dialer/selection.go @@ -0,0 +1,131 @@ +package dialer + +import ( + "context" + "errors" + "fmt" + "io" + "math/rand/v2" + "net/http" + "strings" + "time" + + "github.com/hashicorp/go-multierror" +) + +type ServerSelection int + +const ( + _ = iota + ServerSelectionFirst + ServerSelectionRandom + ServerSelectionFastest +) + +func (ss ServerSelection) String() string { + switch ss { + case ServerSelectionFirst: + return "first" + case ServerSelectionRandom: + return "random" + case ServerSelectionFastest: + return "fastest" + default: + return fmt.Sprintf("ServerSelection(%d)", int(ss)) + } +} + +func ParseServerSelection(s string) (ServerSelection, error) { + switch strings.ToLower(s) { + case "first": + return ServerSelectionFirst, nil + case "random": + return ServerSelectionRandom, nil + case "fastest": + return ServerSelectionFastest, nil + } + return 0, errors.New("unknown server selection strategy") +} + +type SelectionFunc = func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error) + +func SelectFirst(_ context.Context, dialers []ContextDialer) (ContextDialer, error) { + if len(dialers) == 0 { + return nil, errors.New("empty dialers list") + } + return dialers[0], nil +} + +func SelectRandom(_ context.Context, dialers []ContextDialer) (ContextDialer, error) { + if len(dialers) == 0 { + return nil, errors.New("empty dialers list") + } + return dialers[rand.IntN(len(dialers))], nil +} + +func probeDialer(ctx context.Context, dialer ContextDialer, url string, dlLimit int64) error { + httpClient := http.Client{ + Transport: &http.Transport{ + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + DialContext: dialer.DialContext, + }, + } + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return err + } + resp, err := httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("bad status code %d for URL %q", resp.StatusCode, url) + } + var rd io.Reader = resp.Body + if dlLimit > 0 { + rd = io.LimitReader(rd, dlLimit) + } + _, err = io.Copy(io.Discard, rd) + return err +} + +func NewFastestServerSelectionFunc(url string, dlLimit int64) SelectionFunc { + return func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error) { + var resErr error + masterNotInterested := make(chan struct{}) + defer close(masterNotInterested) + errors := make(chan error) + success := make(chan ContextDialer) + for _, dialer := range dialers { + go func(dialer ContextDialer) { + err := probeDialer(ctx, dialer, url, dlLimit) + if err == nil { + select { + case success <- dialer: + case <-masterNotInterested: + } + } else { + select { + case errors <- err: + case <-masterNotInterested: + } + } + }(dialer) + } + for _ = range dialers { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case d := <-success: + return d, nil + case err := <-errors: + resErr = multierror.Append(resErr, err) + } + } + return nil, resErr + } +} diff --git a/main.go b/main.go index 988def2..2c4fbb0 100644 --- a/main.go +++ b/main.go @@ -81,31 +81,52 @@ func (a *CSVArg) Set(line string) error { return nil } +type serverSelectionArg struct { + value dialer.ServerSelection +} + +func (a *serverSelectionArg) Set(s string) error { + v, err := dialer.ParseServerSelection(s) + if err != nil { + return err + } + a.value = v + return nil +} + +func (a *serverSelectionArg) String() string { + return a.value.String() +} + type CLIArgs struct { - country string - listCountries bool - listProxies bool - bindAddress string - socksMode bool - verbosity int - timeout time.Duration - showVersion bool - proxy string - apiLogin string - apiPassword string - apiAddress string - apiClientType string - apiClientVersion string - apiUserAgent string - bootstrapDNS *CSVArg - refresh time.Duration - refreshRetry time.Duration - initRetries int - initRetryInterval time.Duration - certChainWorkaround bool - caFile string - fakeSNI string - overrideProxyAddress string + country string + listCountries bool + listProxies bool + bindAddress string + socksMode bool + verbosity int + timeout time.Duration + showVersion bool + proxy string + apiLogin string + apiPassword string + apiAddress string + apiClientType string + apiClientVersion string + apiUserAgent string + bootstrapDNS *CSVArg + refresh time.Duration + refreshRetry time.Duration + initRetries int + initRetryInterval time.Duration + certChainWorkaround bool + caFile string + fakeSNI string + overrideProxyAddress string + serverSelection serverSelectionArg + serverSelectionTimeout time.Duration + serverSelectionTestURL string + serverSelectionDLLimit int64 } func parse_args() *CLIArgs { @@ -123,6 +144,7 @@ func parse_args() *CLIArgs { "https://doh.cleanbrowsing.org/doh/adult-filter/", }, }, + serverSelection: serverSelectionArg{dialer.ServerSelectionFastest}, } flag.StringVar(&args.country, "country", "EU", "desired proxy location") flag.BoolVar(&args.listCountries, "list-countries", false, "list available countries and exit") @@ -155,6 +177,11 @@ func parse_args() *CLIArgs { flag.StringVar(&args.caFile, "cafile", "", "use custom CA certificate bundle file") flag.StringVar(&args.fakeSNI, "fake-SNI", "", "domain name to use as SNI in communications with servers") flag.StringVar(&args.overrideProxyAddress, "override-proxy-address", "", "use fixed proxy address instead of server address returned by SurfEasy API") + flag.Var(&args.serverSelection, "server-selection", "server selection policy (first/random/fastest)") + flag.DurationVar(&args.serverSelectionTimeout, "server-selection-timeout", 30*time.Second, "timeout given for server selection function to produce result") + flag.StringVar(&args.serverSelectionTestURL, "server-selection-test-url", "https://ajax.googleapis.com/ajax/libs/angularjs/1.8.2/angular.min.js", + "URL used for download benchmark by fastest server selection policy") + flag.Int64Var(&args.serverSelectionDLLimit, "server-selection-dl-limit", 0, "restrict amount of downloaded data per connection by fastest server selection") flag.Parse() if args.country == "" { arg_fail("Country can't be empty string.") @@ -200,6 +227,20 @@ func run() int { KeepAlive: 30 * time.Second, } + var caPool *x509.CertPool + if args.caFile != "" { + caPool = x509.NewCertPool() + certs, err := ioutil.ReadFile(args.caFile) + if err != nil { + mainLogger.Error("Can't load CA file: %v", err) + return 15 + } + if ok := caPool.AppendCertsFromPEM(certs); !ok { + mainLogger.Error("Can't load certificates from CA file") + return 15 + } + } + if args.proxy != "" { xproxy.RegisterDialerType("http", proxyFromURLWrapper) xproxy.RegisterDialerType("https", proxyFromURLWrapper) @@ -218,7 +259,7 @@ func run() int { seclientDialer := d if args.apiAddress != "" { - mainLogger.Info("Using fixed API host IP address = %s", args.apiAddress) + mainLogger.Info("Using fixed API host address = %s", args.apiAddress) seclientDialer = dialer.NewFixedDialer(args.apiAddress, d) } else if len(args.bootstrapDNS.values) > 0 { resolver, err := resolver.FastFromURLs(args.bootstrapDNS.values...) @@ -282,28 +323,74 @@ func run() int { return printCountries(try, mainLogger, args.timeout, seclient) } + handlerDialerFactory := func(endpointAddr string) dialer.ContextDialer { + return dialer.NewProxyDialer( + dialer.WrapStringToCb(endpointAddr), + dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)), + dialer.WrapStringToCb(args.fakeSNI), + func() (string, error) { + return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil + }, + args.certChainWorkaround, + caPool, + d) + } + var ips []se.SEIPEntry - err = try("discover", func() error { - ctx, cl := context.WithTimeout(context.Background(), args.timeout) - defer cl() - // TODO: learn about requested_geo value format - res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country)) - ips = res - return err - }) - if err != nil { - return 12 + var handlerDialer dialer.ContextDialer + + if args.overrideProxyAddress == "" || args.listProxies { + err = try("discover", func() error { + ctx, cl := context.WithTimeout(context.Background(), args.timeout) + defer cl() + // TODO: learn about requested_geo value format + res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country)) + if err != nil { + return err + } + if len(res) == 0 { + return errors.New("empty endpoints list!") + } + if args.listProxies { + ips = res + return nil + } + var ss dialer.SelectionFunc + switch args.serverSelection.value { + case dialer.ServerSelectionFirst: + ss = dialer.SelectFirst + case dialer.ServerSelectionRandom: + ss = dialer.SelectRandom + case dialer.ServerSelectionFastest: + ss = dialer.NewFastestServerSelectionFunc( + args.serverSelectionTestURL, + args.serverSelectionDLLimit, + ) + default: + panic("unhandled server selection value got past parsing") + } + dialers := make([]dialer.ContextDialer, len(res)) + for i, ep := range res { + dialers[i] = handlerDialerFactory(ep.NetAddr()) + } + ctx, cl = context.WithTimeout(context.Background(), args.serverSelectionTimeout) + defer cl() + handlerDialer, err = ss(ctx, dialers) + return err + }) + if err != nil { + return 12 + } + } else { + sanitizedEndpoint := sanitizeFixedProxyAddress(args.overrideProxyAddress) + handlerDialer = handlerDialerFactory(sanitizedEndpoint) + mainLogger.Info("Endpoint override: %s", sanitizedEndpoint) } if args.listProxies { return printProxies(ips, seclient) } - if len(ips) == 0 { - mainLogger.Critical("Empty endpoint list!") - return 13 - } - clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error { mainLogger.Info("Refreshing login...") reqCtx, cl := context.WithTimeout(ctx, args.timeout) @@ -327,40 +414,6 @@ func run() int { return nil }) - endpoint := ips[0] - - var caPool *x509.CertPool - if args.caFile != "" { - caPool = x509.NewCertPool() - certs, err := ioutil.ReadFile(args.caFile) - if err != nil { - mainLogger.Error("Can't load CA file: %v", err) - return 15 - } - if ok := caPool.AppendCertsFromPEM(certs); !ok { - mainLogger.Error("Can't load certificates from CA file") - return 15 - } - } - - var handlerBaseDialer = d - if args.overrideProxyAddress != "" { - mainLogger.Info("Original endpoint: %s", endpoint.IP) - handlerBaseDialer = dialer.NewFixedDialer(args.overrideProxyAddress, handlerBaseDialer) - mainLogger.Info("Endpoint override: %s", args.overrideProxyAddress) - } else { - mainLogger.Info("Endpoint: %s", endpoint.NetAddr()) - } - handlerDialer := dialer.NewProxyDialer( - dialer.WrapStringToCb(endpoint.NetAddr()), - dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)), - dialer.WrapStringToCb(args.fakeSNI), - func() (string, error) { - return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil - }, - args.certChainWorkaround, - caPool, - handlerBaseDialer) mainLogger.Info("Starting proxy server...") if args.socksMode { socks, initError := handler.NewSocksServer(handlerDialer, socksLogger) @@ -423,6 +476,13 @@ func printProxies(ips []se.SEIPEntry, seclient *se.SEClient) int { return 0 } +func sanitizeFixedProxyAddress(addr string) string { + if _, _, err := net.SplitHostPort(addr); err == nil { + return addr + } + return net.JoinHostPort(addr, "443") +} + func main() { os.Exit(run()) } From ddea656d9409706a28c49493ad1b42210e69651d Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Sun, 14 Sep 2025 22:56:19 +0300 Subject: [PATCH 2/3] proper logging of selected endpoint --- dialer/upstream.go | 4 ++++ main.go | 13 +++++++++++-- seclient/messages.go | 4 ++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/dialer/upstream.go b/dialer/upstream.go index 7ec045c..aba7226 100644 --- a/dialer/upstream.go +++ b/dialer/upstream.go @@ -221,6 +221,10 @@ func (d *ProxyDialer) Dial(network, address string) (net.Conn, error) { return d.DialContext(context.Background(), network, address) } +func (d *ProxyDialer) Address() (string, error) { + return d.address() +} + func readResponse(r io.Reader, req *http.Request) (*http.Response, error) { endOfResponse := []byte("\r\n\r\n") buf := &bytes.Buffer{} diff --git a/main.go b/main.go index 2c4fbb0..a895b9e 100644 --- a/main.go +++ b/main.go @@ -343,7 +343,6 @@ func run() int { err = try("discover", func() error { ctx, cl := context.WithTimeout(context.Background(), args.timeout) defer cl() - // TODO: learn about requested_geo value format res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country)) if err != nil { return err @@ -355,6 +354,8 @@ func run() int { ips = res return nil } + + mainLogger.Info("Discovered endpoints: %v. Starting server selection routine %q.", res, args.serverSelection.value) var ss dialer.SelectionFunc switch args.serverSelection.value { case dialer.ServerSelectionFirst: @@ -376,7 +377,15 @@ func run() int { ctx, cl = context.WithTimeout(context.Background(), args.serverSelectionTimeout) defer cl() handlerDialer, err = ss(ctx, dialers) - return err + if err != nil { + return err + } + if addresser, ok := handlerDialer.(interface{ Address() (string, error) }); ok { + if epAddr, err := addresser.Address(); err == nil { + mainLogger.Info("Selected endpoint address: %s", epAddr) + } + } + return nil }) if err != nil { return 12 diff --git a/seclient/messages.go b/seclient/messages.go index 36fa3a1..b4d9f9f 100644 --- a/seclient/messages.go +++ b/seclient/messages.go @@ -97,6 +97,10 @@ func (e *SEIPEntry) NetAddr() string { } } +func (e SEIPEntry) String() string { + return e.NetAddr() +} + type SEDiscoverResponse struct { Data struct { IPs []SEIPEntry `json:"ips"` From 389231de6f4eb46ce3e7ad52f6c959e0bb756bc3 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Sun, 14 Sep 2025 23:13:21 +0300 Subject: [PATCH 3/3] upd doc --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 4e20ee2..ee338d3 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,10 @@ eu3.sec-tunnel.com,77.111.244.22,443 | proxy | String | sets base proxy to use for all dial-outs. Format: `://[login:password@]host[:port]` Examples: `http://user:password@192.168.1.1:3128`, `socks5://10.0.0.1:1080` | | refresh | Duration | login refresh interval (default 4h0m0s) | | refresh-retry | Duration | login refresh retry interval (default 5s) | +| server-selection | Enum | server selection policy (first/random/fastest) (default fastest) | +| server-selection-dl-limit | Number | restrict amount of downloaded data per connection by fastest server selection | +| server-selection-test-url | String | URL used for download benchmark by fastest server selection policy (default `https://ajax.googleapis.com/ajax/libs/angularjs/1.8.2/angular.min.js`) | +| server-selection-timeout | Duration | timeout given for server selection function to produce result (default 30s) | | timeout | Duration | timeout for network operations (default 10s) | | verbosity | Number | logging verbosity (10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical) (default 20) | | version | - | show program version and exit |