basic implementation of server selection feature

This commit is contained in:
Vladislav Yarmak
2025-09-14 15:16:02 +03:00
parent 7bfaeb9878
commit 6eb2054faf
2 changed files with 265 additions and 74 deletions
+131
View File
@@ -0,0 +1,131 @@
package dialer
import (
"context"
"errors"
"fmt"
"io"
"math/rand/v2"
"net/http"
"strings"
"time"
"github.com/hashicorp/go-multierror"
)
type ServerSelection int
const (
_ = iota
ServerSelectionFirst
ServerSelectionRandom
ServerSelectionFastest
)
func (ss ServerSelection) String() string {
switch ss {
case ServerSelectionFirst:
return "first"
case ServerSelectionRandom:
return "random"
case ServerSelectionFastest:
return "fastest"
default:
return fmt.Sprintf("ServerSelection(%d)", int(ss))
}
}
func ParseServerSelection(s string) (ServerSelection, error) {
switch strings.ToLower(s) {
case "first":
return ServerSelectionFirst, nil
case "random":
return ServerSelectionRandom, nil
case "fastest":
return ServerSelectionFastest, nil
}
return 0, errors.New("unknown server selection strategy")
}
type SelectionFunc = func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error)
func SelectFirst(_ context.Context, dialers []ContextDialer) (ContextDialer, error) {
if len(dialers) == 0 {
return nil, errors.New("empty dialers list")
}
return dialers[0], nil
}
func SelectRandom(_ context.Context, dialers []ContextDialer) (ContextDialer, error) {
if len(dialers) == 0 {
return nil, errors.New("empty dialers list")
}
return dialers[rand.IntN(len(dialers))], nil
}
func probeDialer(ctx context.Context, dialer ContextDialer, url string, dlLimit int64) error {
httpClient := http.Client{
Transport: &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: dialer.DialContext,
},
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
resp, err := httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status code %d for URL %q", resp.StatusCode, url)
}
var rd io.Reader = resp.Body
if dlLimit > 0 {
rd = io.LimitReader(rd, dlLimit)
}
_, err = io.Copy(io.Discard, rd)
return err
}
func NewFastestServerSelectionFunc(url string, dlLimit int64) SelectionFunc {
return func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error) {
var resErr error
masterNotInterested := make(chan struct{})
defer close(masterNotInterested)
errors := make(chan error)
success := make(chan ContextDialer)
for _, dialer := range dialers {
go func(dialer ContextDialer) {
err := probeDialer(ctx, dialer, url, dlLimit)
if err == nil {
select {
case success <- dialer:
case <-masterNotInterested:
}
} else {
select {
case errors <- err:
case <-masterNotInterested:
}
}
}(dialer)
}
for _ = range dialers {
select {
case <-ctx.Done():
return nil, ctx.Err()
case d := <-success:
return d, nil
case err := <-errors:
resErr = multierror.Append(resErr, err)
}
}
return nil, resErr
}
}
+134 -74
View File
@@ -81,31 +81,52 @@ func (a *CSVArg) Set(line string) error {
return nil
}
type serverSelectionArg struct {
value dialer.ServerSelection
}
func (a *serverSelectionArg) Set(s string) error {
v, err := dialer.ParseServerSelection(s)
if err != nil {
return err
}
a.value = v
return nil
}
func (a *serverSelectionArg) String() string {
return a.value.String()
}
type CLIArgs struct {
country string
listCountries bool
listProxies bool
bindAddress string
socksMode bool
verbosity int
timeout time.Duration
showVersion bool
proxy string
apiLogin string
apiPassword string
apiAddress string
apiClientType string
apiClientVersion string
apiUserAgent string
bootstrapDNS *CSVArg
refresh time.Duration
refreshRetry time.Duration
initRetries int
initRetryInterval time.Duration
certChainWorkaround bool
caFile string
fakeSNI string
overrideProxyAddress string
country string
listCountries bool
listProxies bool
bindAddress string
socksMode bool
verbosity int
timeout time.Duration
showVersion bool
proxy string
apiLogin string
apiPassword string
apiAddress string
apiClientType string
apiClientVersion string
apiUserAgent string
bootstrapDNS *CSVArg
refresh time.Duration
refreshRetry time.Duration
initRetries int
initRetryInterval time.Duration
certChainWorkaround bool
caFile string
fakeSNI string
overrideProxyAddress string
serverSelection serverSelectionArg
serverSelectionTimeout time.Duration
serverSelectionTestURL string
serverSelectionDLLimit int64
}
func parse_args() *CLIArgs {
@@ -123,6 +144,7 @@ func parse_args() *CLIArgs {
"https://doh.cleanbrowsing.org/doh/adult-filter/",
},
},
serverSelection: serverSelectionArg{dialer.ServerSelectionFastest},
}
flag.StringVar(&args.country, "country", "EU", "desired proxy location")
flag.BoolVar(&args.listCountries, "list-countries", false, "list available countries and exit")
@@ -155,6 +177,11 @@ func parse_args() *CLIArgs {
flag.StringVar(&args.caFile, "cafile", "", "use custom CA certificate bundle file")
flag.StringVar(&args.fakeSNI, "fake-SNI", "", "domain name to use as SNI in communications with servers")
flag.StringVar(&args.overrideProxyAddress, "override-proxy-address", "", "use fixed proxy address instead of server address returned by SurfEasy API")
flag.Var(&args.serverSelection, "server-selection", "server selection policy (first/random/fastest)")
flag.DurationVar(&args.serverSelectionTimeout, "server-selection-timeout", 30*time.Second, "timeout given for server selection function to produce result")
flag.StringVar(&args.serverSelectionTestURL, "server-selection-test-url", "https://ajax.googleapis.com/ajax/libs/angularjs/1.8.2/angular.min.js",
"URL used for download benchmark by fastest server selection policy")
flag.Int64Var(&args.serverSelectionDLLimit, "server-selection-dl-limit", 0, "restrict amount of downloaded data per connection by fastest server selection")
flag.Parse()
if args.country == "" {
arg_fail("Country can't be empty string.")
@@ -200,6 +227,20 @@ func run() int {
KeepAlive: 30 * time.Second,
}
var caPool *x509.CertPool
if args.caFile != "" {
caPool = x509.NewCertPool()
certs, err := ioutil.ReadFile(args.caFile)
if err != nil {
mainLogger.Error("Can't load CA file: %v", err)
return 15
}
if ok := caPool.AppendCertsFromPEM(certs); !ok {
mainLogger.Error("Can't load certificates from CA file")
return 15
}
}
if args.proxy != "" {
xproxy.RegisterDialerType("http", proxyFromURLWrapper)
xproxy.RegisterDialerType("https", proxyFromURLWrapper)
@@ -218,7 +259,7 @@ func run() int {
seclientDialer := d
if args.apiAddress != "" {
mainLogger.Info("Using fixed API host IP address = %s", args.apiAddress)
mainLogger.Info("Using fixed API host address = %s", args.apiAddress)
seclientDialer = dialer.NewFixedDialer(args.apiAddress, d)
} else if len(args.bootstrapDNS.values) > 0 {
resolver, err := resolver.FastFromURLs(args.bootstrapDNS.values...)
@@ -282,28 +323,74 @@ func run() int {
return printCountries(try, mainLogger, args.timeout, seclient)
}
handlerDialerFactory := func(endpointAddr string) dialer.ContextDialer {
return dialer.NewProxyDialer(
dialer.WrapStringToCb(endpointAddr),
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
dialer.WrapStringToCb(args.fakeSNI),
func() (string, error) {
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
},
args.certChainWorkaround,
caPool,
d)
}
var ips []se.SEIPEntry
err = try("discover", func() error {
ctx, cl := context.WithTimeout(context.Background(), args.timeout)
defer cl()
// TODO: learn about requested_geo value format
res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country))
ips = res
return err
})
if err != nil {
return 12
var handlerDialer dialer.ContextDialer
if args.overrideProxyAddress == "" || args.listProxies {
err = try("discover", func() error {
ctx, cl := context.WithTimeout(context.Background(), args.timeout)
defer cl()
// TODO: learn about requested_geo value format
res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country))
if err != nil {
return err
}
if len(res) == 0 {
return errors.New("empty endpoints list!")
}
if args.listProxies {
ips = res
return nil
}
var ss dialer.SelectionFunc
switch args.serverSelection.value {
case dialer.ServerSelectionFirst:
ss = dialer.SelectFirst
case dialer.ServerSelectionRandom:
ss = dialer.SelectRandom
case dialer.ServerSelectionFastest:
ss = dialer.NewFastestServerSelectionFunc(
args.serverSelectionTestURL,
args.serverSelectionDLLimit,
)
default:
panic("unhandled server selection value got past parsing")
}
dialers := make([]dialer.ContextDialer, len(res))
for i, ep := range res {
dialers[i] = handlerDialerFactory(ep.NetAddr())
}
ctx, cl = context.WithTimeout(context.Background(), args.serverSelectionTimeout)
defer cl()
handlerDialer, err = ss(ctx, dialers)
return err
})
if err != nil {
return 12
}
} else {
sanitizedEndpoint := sanitizeFixedProxyAddress(args.overrideProxyAddress)
handlerDialer = handlerDialerFactory(sanitizedEndpoint)
mainLogger.Info("Endpoint override: %s", sanitizedEndpoint)
}
if args.listProxies {
return printProxies(ips, seclient)
}
if len(ips) == 0 {
mainLogger.Critical("Empty endpoint list!")
return 13
}
clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error {
mainLogger.Info("Refreshing login...")
reqCtx, cl := context.WithTimeout(ctx, args.timeout)
@@ -327,40 +414,6 @@ func run() int {
return nil
})
endpoint := ips[0]
var caPool *x509.CertPool
if args.caFile != "" {
caPool = x509.NewCertPool()
certs, err := ioutil.ReadFile(args.caFile)
if err != nil {
mainLogger.Error("Can't load CA file: %v", err)
return 15
}
if ok := caPool.AppendCertsFromPEM(certs); !ok {
mainLogger.Error("Can't load certificates from CA file")
return 15
}
}
var handlerBaseDialer = d
if args.overrideProxyAddress != "" {
mainLogger.Info("Original endpoint: %s", endpoint.IP)
handlerBaseDialer = dialer.NewFixedDialer(args.overrideProxyAddress, handlerBaseDialer)
mainLogger.Info("Endpoint override: %s", args.overrideProxyAddress)
} else {
mainLogger.Info("Endpoint: %s", endpoint.NetAddr())
}
handlerDialer := dialer.NewProxyDialer(
dialer.WrapStringToCb(endpoint.NetAddr()),
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
dialer.WrapStringToCb(args.fakeSNI),
func() (string, error) {
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
},
args.certChainWorkaround,
caPool,
handlerBaseDialer)
mainLogger.Info("Starting proxy server...")
if args.socksMode {
socks, initError := handler.NewSocksServer(handlerDialer, socksLogger)
@@ -423,6 +476,13 @@ func printProxies(ips []se.SEIPEntry, seclient *se.SEClient) int {
return 0
}
func sanitizeFixedProxyAddress(addr string) string {
if _, _, err := net.SplitHostPort(addr); err == nil {
return addr
}
return net.JoinHostPort(addr, "443")
}
func main() {
os.Exit(run())
}