Merge pull request #102 from Snawoot/srv_select

Automatic server selection feature
This commit is contained in:
Snawoot
2025-09-14 23:19:34 +03:00
committed by GitHub
5 changed files with 286 additions and 74 deletions
+4
View File
@@ -109,6 +109,10 @@ eu3.sec-tunnel.com,77.111.244.22,443
| proxy | String | sets base proxy to use for all dial-outs. Format: `<http\|https\|socks5\|socks5h>://[login:password@]host[:port]` Examples: `http://user:password@192.168.1.1:3128`, `socks5://10.0.0.1:1080` |
| refresh | Duration | login refresh interval (default 4h0m0s) |
| refresh-retry | Duration | login refresh retry interval (default 5s) |
| server-selection | Enum | server selection policy (first/random/fastest) (default fastest) |
| server-selection-dl-limit | Number | restrict amount of downloaded data per connection by fastest server selection |
| server-selection-test-url | String | URL used for download benchmark by fastest server selection policy (default `https://ajax.googleapis.com/ajax/libs/angularjs/1.8.2/angular.min.js`) |
| server-selection-timeout | Duration | timeout given for server selection function to produce result (default 30s) |
| timeout | Duration | timeout for network operations (default 10s) |
| verbosity | Number | logging verbosity (10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical) (default 20) |
| version | - | show program version and exit |
+131
View File
@@ -0,0 +1,131 @@
package dialer
import (
"context"
"errors"
"fmt"
"io"
"math/rand/v2"
"net/http"
"strings"
"time"
"github.com/hashicorp/go-multierror"
)
type ServerSelection int
const (
_ = iota
ServerSelectionFirst
ServerSelectionRandom
ServerSelectionFastest
)
func (ss ServerSelection) String() string {
switch ss {
case ServerSelectionFirst:
return "first"
case ServerSelectionRandom:
return "random"
case ServerSelectionFastest:
return "fastest"
default:
return fmt.Sprintf("ServerSelection(%d)", int(ss))
}
}
func ParseServerSelection(s string) (ServerSelection, error) {
switch strings.ToLower(s) {
case "first":
return ServerSelectionFirst, nil
case "random":
return ServerSelectionRandom, nil
case "fastest":
return ServerSelectionFastest, nil
}
return 0, errors.New("unknown server selection strategy")
}
type SelectionFunc = func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error)
func SelectFirst(_ context.Context, dialers []ContextDialer) (ContextDialer, error) {
if len(dialers) == 0 {
return nil, errors.New("empty dialers list")
}
return dialers[0], nil
}
func SelectRandom(_ context.Context, dialers []ContextDialer) (ContextDialer, error) {
if len(dialers) == 0 {
return nil, errors.New("empty dialers list")
}
return dialers[rand.IntN(len(dialers))], nil
}
func probeDialer(ctx context.Context, dialer ContextDialer, url string, dlLimit int64) error {
httpClient := http.Client{
Transport: &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: dialer.DialContext,
},
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
resp, err := httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status code %d for URL %q", resp.StatusCode, url)
}
var rd io.Reader = resp.Body
if dlLimit > 0 {
rd = io.LimitReader(rd, dlLimit)
}
_, err = io.Copy(io.Discard, rd)
return err
}
func NewFastestServerSelectionFunc(url string, dlLimit int64) SelectionFunc {
return func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error) {
var resErr error
masterNotInterested := make(chan struct{})
defer close(masterNotInterested)
errors := make(chan error)
success := make(chan ContextDialer)
for _, dialer := range dialers {
go func(dialer ContextDialer) {
err := probeDialer(ctx, dialer, url, dlLimit)
if err == nil {
select {
case success <- dialer:
case <-masterNotInterested:
}
} else {
select {
case errors <- err:
case <-masterNotInterested:
}
}
}(dialer)
}
for _ = range dialers {
select {
case <-ctx.Done():
return nil, ctx.Err()
case d := <-success:
return d, nil
case err := <-errors:
resErr = multierror.Append(resErr, err)
}
}
return nil, resErr
}
}
+4
View File
@@ -221,6 +221,10 @@ func (d *ProxyDialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
func (d *ProxyDialer) Address() (string, error) {
return d.address()
}
func readResponse(r io.Reader, req *http.Request) (*http.Response, error) {
endOfResponse := []byte("\r\n\r\n")
buf := &bytes.Buffer{}
+143 -74
View File
@@ -81,31 +81,52 @@ func (a *CSVArg) Set(line string) error {
return nil
}
type serverSelectionArg struct {
value dialer.ServerSelection
}
func (a *serverSelectionArg) Set(s string) error {
v, err := dialer.ParseServerSelection(s)
if err != nil {
return err
}
a.value = v
return nil
}
func (a *serverSelectionArg) String() string {
return a.value.String()
}
type CLIArgs struct {
country string
listCountries bool
listProxies bool
bindAddress string
socksMode bool
verbosity int
timeout time.Duration
showVersion bool
proxy string
apiLogin string
apiPassword string
apiAddress string
apiClientType string
apiClientVersion string
apiUserAgent string
bootstrapDNS *CSVArg
refresh time.Duration
refreshRetry time.Duration
initRetries int
initRetryInterval time.Duration
certChainWorkaround bool
caFile string
fakeSNI string
overrideProxyAddress string
country string
listCountries bool
listProxies bool
bindAddress string
socksMode bool
verbosity int
timeout time.Duration
showVersion bool
proxy string
apiLogin string
apiPassword string
apiAddress string
apiClientType string
apiClientVersion string
apiUserAgent string
bootstrapDNS *CSVArg
refresh time.Duration
refreshRetry time.Duration
initRetries int
initRetryInterval time.Duration
certChainWorkaround bool
caFile string
fakeSNI string
overrideProxyAddress string
serverSelection serverSelectionArg
serverSelectionTimeout time.Duration
serverSelectionTestURL string
serverSelectionDLLimit int64
}
func parse_args() *CLIArgs {
@@ -123,6 +144,7 @@ func parse_args() *CLIArgs {
"https://doh.cleanbrowsing.org/doh/adult-filter/",
},
},
serverSelection: serverSelectionArg{dialer.ServerSelectionFastest},
}
flag.StringVar(&args.country, "country", "EU", "desired proxy location")
flag.BoolVar(&args.listCountries, "list-countries", false, "list available countries and exit")
@@ -155,6 +177,11 @@ func parse_args() *CLIArgs {
flag.StringVar(&args.caFile, "cafile", "", "use custom CA certificate bundle file")
flag.StringVar(&args.fakeSNI, "fake-SNI", "", "domain name to use as SNI in communications with servers")
flag.StringVar(&args.overrideProxyAddress, "override-proxy-address", "", "use fixed proxy address instead of server address returned by SurfEasy API")
flag.Var(&args.serverSelection, "server-selection", "server selection policy (first/random/fastest)")
flag.DurationVar(&args.serverSelectionTimeout, "server-selection-timeout", 30*time.Second, "timeout given for server selection function to produce result")
flag.StringVar(&args.serverSelectionTestURL, "server-selection-test-url", "https://ajax.googleapis.com/ajax/libs/angularjs/1.8.2/angular.min.js",
"URL used for download benchmark by fastest server selection policy")
flag.Int64Var(&args.serverSelectionDLLimit, "server-selection-dl-limit", 0, "restrict amount of downloaded data per connection by fastest server selection")
flag.Parse()
if args.country == "" {
arg_fail("Country can't be empty string.")
@@ -200,6 +227,20 @@ func run() int {
KeepAlive: 30 * time.Second,
}
var caPool *x509.CertPool
if args.caFile != "" {
caPool = x509.NewCertPool()
certs, err := ioutil.ReadFile(args.caFile)
if err != nil {
mainLogger.Error("Can't load CA file: %v", err)
return 15
}
if ok := caPool.AppendCertsFromPEM(certs); !ok {
mainLogger.Error("Can't load certificates from CA file")
return 15
}
}
if args.proxy != "" {
xproxy.RegisterDialerType("http", proxyFromURLWrapper)
xproxy.RegisterDialerType("https", proxyFromURLWrapper)
@@ -218,7 +259,7 @@ func run() int {
seclientDialer := d
if args.apiAddress != "" {
mainLogger.Info("Using fixed API host IP address = %s", args.apiAddress)
mainLogger.Info("Using fixed API host address = %s", args.apiAddress)
seclientDialer = dialer.NewFixedDialer(args.apiAddress, d)
} else if len(args.bootstrapDNS.values) > 0 {
resolver, err := resolver.FastFromURLs(args.bootstrapDNS.values...)
@@ -282,28 +323,83 @@ func run() int {
return printCountries(try, mainLogger, args.timeout, seclient)
}
handlerDialerFactory := func(endpointAddr string) dialer.ContextDialer {
return dialer.NewProxyDialer(
dialer.WrapStringToCb(endpointAddr),
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
dialer.WrapStringToCb(args.fakeSNI),
func() (string, error) {
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
},
args.certChainWorkaround,
caPool,
d)
}
var ips []se.SEIPEntry
err = try("discover", func() error {
ctx, cl := context.WithTimeout(context.Background(), args.timeout)
defer cl()
// TODO: learn about requested_geo value format
res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country))
ips = res
return err
})
if err != nil {
return 12
var handlerDialer dialer.ContextDialer
if args.overrideProxyAddress == "" || args.listProxies {
err = try("discover", func() error {
ctx, cl := context.WithTimeout(context.Background(), args.timeout)
defer cl()
res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country))
if err != nil {
return err
}
if len(res) == 0 {
return errors.New("empty endpoints list!")
}
if args.listProxies {
ips = res
return nil
}
mainLogger.Info("Discovered endpoints: %v. Starting server selection routine %q.", res, args.serverSelection.value)
var ss dialer.SelectionFunc
switch args.serverSelection.value {
case dialer.ServerSelectionFirst:
ss = dialer.SelectFirst
case dialer.ServerSelectionRandom:
ss = dialer.SelectRandom
case dialer.ServerSelectionFastest:
ss = dialer.NewFastestServerSelectionFunc(
args.serverSelectionTestURL,
args.serverSelectionDLLimit,
)
default:
panic("unhandled server selection value got past parsing")
}
dialers := make([]dialer.ContextDialer, len(res))
for i, ep := range res {
dialers[i] = handlerDialerFactory(ep.NetAddr())
}
ctx, cl = context.WithTimeout(context.Background(), args.serverSelectionTimeout)
defer cl()
handlerDialer, err = ss(ctx, dialers)
if err != nil {
return err
}
if addresser, ok := handlerDialer.(interface{ Address() (string, error) }); ok {
if epAddr, err := addresser.Address(); err == nil {
mainLogger.Info("Selected endpoint address: %s", epAddr)
}
}
return nil
})
if err != nil {
return 12
}
} else {
sanitizedEndpoint := sanitizeFixedProxyAddress(args.overrideProxyAddress)
handlerDialer = handlerDialerFactory(sanitizedEndpoint)
mainLogger.Info("Endpoint override: %s", sanitizedEndpoint)
}
if args.listProxies {
return printProxies(ips, seclient)
}
if len(ips) == 0 {
mainLogger.Critical("Empty endpoint list!")
return 13
}
clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error {
mainLogger.Info("Refreshing login...")
reqCtx, cl := context.WithTimeout(ctx, args.timeout)
@@ -327,40 +423,6 @@ func run() int {
return nil
})
endpoint := ips[0]
var caPool *x509.CertPool
if args.caFile != "" {
caPool = x509.NewCertPool()
certs, err := ioutil.ReadFile(args.caFile)
if err != nil {
mainLogger.Error("Can't load CA file: %v", err)
return 15
}
if ok := caPool.AppendCertsFromPEM(certs); !ok {
mainLogger.Error("Can't load certificates from CA file")
return 15
}
}
var handlerBaseDialer = d
if args.overrideProxyAddress != "" {
mainLogger.Info("Original endpoint: %s", endpoint.IP)
handlerBaseDialer = dialer.NewFixedDialer(args.overrideProxyAddress, handlerBaseDialer)
mainLogger.Info("Endpoint override: %s", args.overrideProxyAddress)
} else {
mainLogger.Info("Endpoint: %s", endpoint.NetAddr())
}
handlerDialer := dialer.NewProxyDialer(
dialer.WrapStringToCb(endpoint.NetAddr()),
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
dialer.WrapStringToCb(args.fakeSNI),
func() (string, error) {
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
},
args.certChainWorkaround,
caPool,
handlerBaseDialer)
mainLogger.Info("Starting proxy server...")
if args.socksMode {
socks, initError := handler.NewSocksServer(handlerDialer, socksLogger)
@@ -423,6 +485,13 @@ func printProxies(ips []se.SEIPEntry, seclient *se.SEClient) int {
return 0
}
func sanitizeFixedProxyAddress(addr string) string {
if _, _, err := net.SplitHostPort(addr); err == nil {
return addr
}
return net.JoinHostPort(addr, "443")
}
func main() {
os.Exit(run())
}
+4
View File
@@ -97,6 +97,10 @@ func (e *SEIPEntry) NetAddr() string {
}
}
func (e SEIPEntry) String() string {
return e.NetAddr()
}
type SEDiscoverResponse struct {
Data struct {
IPs []SEIPEntry `json:"ips"`