Merge pull request #102 from Snawoot/srv_select

Automatic server selection feature
This commit is contained in:
Snawoot
2025-09-14 23:19:34 +03:00
committed by GitHub
5 changed files with 286 additions and 74 deletions
+4
View File
@@ -109,6 +109,10 @@ eu3.sec-tunnel.com,77.111.244.22,443
| proxy | String | sets base proxy to use for all dial-outs. Format: `<http\|https\|socks5\|socks5h>://[login:password@]host[:port]` Examples: `http://user:password@192.168.1.1:3128`, `socks5://10.0.0.1:1080` | | proxy | String | sets base proxy to use for all dial-outs. Format: `<http\|https\|socks5\|socks5h>://[login:password@]host[:port]` Examples: `http://user:password@192.168.1.1:3128`, `socks5://10.0.0.1:1080` |
| refresh | Duration | login refresh interval (default 4h0m0s) | | refresh | Duration | login refresh interval (default 4h0m0s) |
| refresh-retry | Duration | login refresh retry interval (default 5s) | | refresh-retry | Duration | login refresh retry interval (default 5s) |
| server-selection | Enum | server selection policy (first/random/fastest) (default fastest) |
| server-selection-dl-limit | Number | restrict amount of downloaded data per connection by fastest server selection |
| server-selection-test-url | String | URL used for download benchmark by fastest server selection policy (default `https://ajax.googleapis.com/ajax/libs/angularjs/1.8.2/angular.min.js`) |
| server-selection-timeout | Duration | timeout given for server selection function to produce result (default 30s) |
| timeout | Duration | timeout for network operations (default 10s) | | timeout | Duration | timeout for network operations (default 10s) |
| verbosity | Number | logging verbosity (10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical) (default 20) | | verbosity | Number | logging verbosity (10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical) (default 20) |
| version | - | show program version and exit | | version | - | show program version and exit |
+131
View File
@@ -0,0 +1,131 @@
package dialer
import (
"context"
"errors"
"fmt"
"io"
"math/rand/v2"
"net/http"
"strings"
"time"
"github.com/hashicorp/go-multierror"
)
type ServerSelection int
const (
_ = iota
ServerSelectionFirst
ServerSelectionRandom
ServerSelectionFastest
)
func (ss ServerSelection) String() string {
switch ss {
case ServerSelectionFirst:
return "first"
case ServerSelectionRandom:
return "random"
case ServerSelectionFastest:
return "fastest"
default:
return fmt.Sprintf("ServerSelection(%d)", int(ss))
}
}
func ParseServerSelection(s string) (ServerSelection, error) {
switch strings.ToLower(s) {
case "first":
return ServerSelectionFirst, nil
case "random":
return ServerSelectionRandom, nil
case "fastest":
return ServerSelectionFastest, nil
}
return 0, errors.New("unknown server selection strategy")
}
type SelectionFunc = func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error)
func SelectFirst(_ context.Context, dialers []ContextDialer) (ContextDialer, error) {
if len(dialers) == 0 {
return nil, errors.New("empty dialers list")
}
return dialers[0], nil
}
func SelectRandom(_ context.Context, dialers []ContextDialer) (ContextDialer, error) {
if len(dialers) == 0 {
return nil, errors.New("empty dialers list")
}
return dialers[rand.IntN(len(dialers))], nil
}
func probeDialer(ctx context.Context, dialer ContextDialer, url string, dlLimit int64) error {
httpClient := http.Client{
Transport: &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: dialer.DialContext,
},
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
resp, err := httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status code %d for URL %q", resp.StatusCode, url)
}
var rd io.Reader = resp.Body
if dlLimit > 0 {
rd = io.LimitReader(rd, dlLimit)
}
_, err = io.Copy(io.Discard, rd)
return err
}
func NewFastestServerSelectionFunc(url string, dlLimit int64) SelectionFunc {
return func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error) {
var resErr error
masterNotInterested := make(chan struct{})
defer close(masterNotInterested)
errors := make(chan error)
success := make(chan ContextDialer)
for _, dialer := range dialers {
go func(dialer ContextDialer) {
err := probeDialer(ctx, dialer, url, dlLimit)
if err == nil {
select {
case success <- dialer:
case <-masterNotInterested:
}
} else {
select {
case errors <- err:
case <-masterNotInterested:
}
}
}(dialer)
}
for _ = range dialers {
select {
case <-ctx.Done():
return nil, ctx.Err()
case d := <-success:
return d, nil
case err := <-errors:
resErr = multierror.Append(resErr, err)
}
}
return nil, resErr
}
}
+4
View File
@@ -221,6 +221,10 @@ func (d *ProxyDialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address) return d.DialContext(context.Background(), network, address)
} }
func (d *ProxyDialer) Address() (string, error) {
return d.address()
}
func readResponse(r io.Reader, req *http.Request) (*http.Response, error) { func readResponse(r io.Reader, req *http.Request) (*http.Response, error) {
endOfResponse := []byte("\r\n\r\n") endOfResponse := []byte("\r\n\r\n")
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
+111 -42
View File
@@ -81,6 +81,23 @@ func (a *CSVArg) Set(line string) error {
return nil return nil
} }
type serverSelectionArg struct {
value dialer.ServerSelection
}
func (a *serverSelectionArg) Set(s string) error {
v, err := dialer.ParseServerSelection(s)
if err != nil {
return err
}
a.value = v
return nil
}
func (a *serverSelectionArg) String() string {
return a.value.String()
}
type CLIArgs struct { type CLIArgs struct {
country string country string
listCountries bool listCountries bool
@@ -106,6 +123,10 @@ type CLIArgs struct {
caFile string caFile string
fakeSNI string fakeSNI string
overrideProxyAddress string overrideProxyAddress string
serverSelection serverSelectionArg
serverSelectionTimeout time.Duration
serverSelectionTestURL string
serverSelectionDLLimit int64
} }
func parse_args() *CLIArgs { func parse_args() *CLIArgs {
@@ -123,6 +144,7 @@ func parse_args() *CLIArgs {
"https://doh.cleanbrowsing.org/doh/adult-filter/", "https://doh.cleanbrowsing.org/doh/adult-filter/",
}, },
}, },
serverSelection: serverSelectionArg{dialer.ServerSelectionFastest},
} }
flag.StringVar(&args.country, "country", "EU", "desired proxy location") flag.StringVar(&args.country, "country", "EU", "desired proxy location")
flag.BoolVar(&args.listCountries, "list-countries", false, "list available countries and exit") flag.BoolVar(&args.listCountries, "list-countries", false, "list available countries and exit")
@@ -155,6 +177,11 @@ func parse_args() *CLIArgs {
flag.StringVar(&args.caFile, "cafile", "", "use custom CA certificate bundle file") flag.StringVar(&args.caFile, "cafile", "", "use custom CA certificate bundle file")
flag.StringVar(&args.fakeSNI, "fake-SNI", "", "domain name to use as SNI in communications with servers") flag.StringVar(&args.fakeSNI, "fake-SNI", "", "domain name to use as SNI in communications with servers")
flag.StringVar(&args.overrideProxyAddress, "override-proxy-address", "", "use fixed proxy address instead of server address returned by SurfEasy API") flag.StringVar(&args.overrideProxyAddress, "override-proxy-address", "", "use fixed proxy address instead of server address returned by SurfEasy API")
flag.Var(&args.serverSelection, "server-selection", "server selection policy (first/random/fastest)")
flag.DurationVar(&args.serverSelectionTimeout, "server-selection-timeout", 30*time.Second, "timeout given for server selection function to produce result")
flag.StringVar(&args.serverSelectionTestURL, "server-selection-test-url", "https://ajax.googleapis.com/ajax/libs/angularjs/1.8.2/angular.min.js",
"URL used for download benchmark by fastest server selection policy")
flag.Int64Var(&args.serverSelectionDLLimit, "server-selection-dl-limit", 0, "restrict amount of downloaded data per connection by fastest server selection")
flag.Parse() flag.Parse()
if args.country == "" { if args.country == "" {
arg_fail("Country can't be empty string.") arg_fail("Country can't be empty string.")
@@ -200,6 +227,20 @@ func run() int {
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
} }
var caPool *x509.CertPool
if args.caFile != "" {
caPool = x509.NewCertPool()
certs, err := ioutil.ReadFile(args.caFile)
if err != nil {
mainLogger.Error("Can't load CA file: %v", err)
return 15
}
if ok := caPool.AppendCertsFromPEM(certs); !ok {
mainLogger.Error("Can't load certificates from CA file")
return 15
}
}
if args.proxy != "" { if args.proxy != "" {
xproxy.RegisterDialerType("http", proxyFromURLWrapper) xproxy.RegisterDialerType("http", proxyFromURLWrapper)
xproxy.RegisterDialerType("https", proxyFromURLWrapper) xproxy.RegisterDialerType("https", proxyFromURLWrapper)
@@ -218,7 +259,7 @@ func run() int {
seclientDialer := d seclientDialer := d
if args.apiAddress != "" { if args.apiAddress != "" {
mainLogger.Info("Using fixed API host IP address = %s", args.apiAddress) mainLogger.Info("Using fixed API host address = %s", args.apiAddress)
seclientDialer = dialer.NewFixedDialer(args.apiAddress, d) seclientDialer = dialer.NewFixedDialer(args.apiAddress, d)
} else if len(args.bootstrapDNS.values) > 0 { } else if len(args.bootstrapDNS.values) > 0 {
resolver, err := resolver.FastFromURLs(args.bootstrapDNS.values...) resolver, err := resolver.FastFromURLs(args.bootstrapDNS.values...)
@@ -282,28 +323,83 @@ func run() int {
return printCountries(try, mainLogger, args.timeout, seclient) return printCountries(try, mainLogger, args.timeout, seclient)
} }
handlerDialerFactory := func(endpointAddr string) dialer.ContextDialer {
return dialer.NewProxyDialer(
dialer.WrapStringToCb(endpointAddr),
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
dialer.WrapStringToCb(args.fakeSNI),
func() (string, error) {
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
},
args.certChainWorkaround,
caPool,
d)
}
var ips []se.SEIPEntry var ips []se.SEIPEntry
var handlerDialer dialer.ContextDialer
if args.overrideProxyAddress == "" || args.listProxies {
err = try("discover", func() error { err = try("discover", func() error {
ctx, cl := context.WithTimeout(context.Background(), args.timeout) ctx, cl := context.WithTimeout(context.Background(), args.timeout)
defer cl() defer cl()
// TODO: learn about requested_geo value format
res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country)) res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country))
ips = res if err != nil {
return err return err
}
if len(res) == 0 {
return errors.New("empty endpoints list!")
}
if args.listProxies {
ips = res
return nil
}
mainLogger.Info("Discovered endpoints: %v. Starting server selection routine %q.", res, args.serverSelection.value)
var ss dialer.SelectionFunc
switch args.serverSelection.value {
case dialer.ServerSelectionFirst:
ss = dialer.SelectFirst
case dialer.ServerSelectionRandom:
ss = dialer.SelectRandom
case dialer.ServerSelectionFastest:
ss = dialer.NewFastestServerSelectionFunc(
args.serverSelectionTestURL,
args.serverSelectionDLLimit,
)
default:
panic("unhandled server selection value got past parsing")
}
dialers := make([]dialer.ContextDialer, len(res))
for i, ep := range res {
dialers[i] = handlerDialerFactory(ep.NetAddr())
}
ctx, cl = context.WithTimeout(context.Background(), args.serverSelectionTimeout)
defer cl()
handlerDialer, err = ss(ctx, dialers)
if err != nil {
return err
}
if addresser, ok := handlerDialer.(interface{ Address() (string, error) }); ok {
if epAddr, err := addresser.Address(); err == nil {
mainLogger.Info("Selected endpoint address: %s", epAddr)
}
}
return nil
}) })
if err != nil { if err != nil {
return 12 return 12
} }
} else {
sanitizedEndpoint := sanitizeFixedProxyAddress(args.overrideProxyAddress)
handlerDialer = handlerDialerFactory(sanitizedEndpoint)
mainLogger.Info("Endpoint override: %s", sanitizedEndpoint)
}
if args.listProxies { if args.listProxies {
return printProxies(ips, seclient) return printProxies(ips, seclient)
} }
if len(ips) == 0 {
mainLogger.Critical("Empty endpoint list!")
return 13
}
clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error { clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error {
mainLogger.Info("Refreshing login...") mainLogger.Info("Refreshing login...")
reqCtx, cl := context.WithTimeout(ctx, args.timeout) reqCtx, cl := context.WithTimeout(ctx, args.timeout)
@@ -327,40 +423,6 @@ func run() int {
return nil return nil
}) })
endpoint := ips[0]
var caPool *x509.CertPool
if args.caFile != "" {
caPool = x509.NewCertPool()
certs, err := ioutil.ReadFile(args.caFile)
if err != nil {
mainLogger.Error("Can't load CA file: %v", err)
return 15
}
if ok := caPool.AppendCertsFromPEM(certs); !ok {
mainLogger.Error("Can't load certificates from CA file")
return 15
}
}
var handlerBaseDialer = d
if args.overrideProxyAddress != "" {
mainLogger.Info("Original endpoint: %s", endpoint.IP)
handlerBaseDialer = dialer.NewFixedDialer(args.overrideProxyAddress, handlerBaseDialer)
mainLogger.Info("Endpoint override: %s", args.overrideProxyAddress)
} else {
mainLogger.Info("Endpoint: %s", endpoint.NetAddr())
}
handlerDialer := dialer.NewProxyDialer(
dialer.WrapStringToCb(endpoint.NetAddr()),
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
dialer.WrapStringToCb(args.fakeSNI),
func() (string, error) {
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
},
args.certChainWorkaround,
caPool,
handlerBaseDialer)
mainLogger.Info("Starting proxy server...") mainLogger.Info("Starting proxy server...")
if args.socksMode { if args.socksMode {
socks, initError := handler.NewSocksServer(handlerDialer, socksLogger) socks, initError := handler.NewSocksServer(handlerDialer, socksLogger)
@@ -423,6 +485,13 @@ func printProxies(ips []se.SEIPEntry, seclient *se.SEClient) int {
return 0 return 0
} }
func sanitizeFixedProxyAddress(addr string) string {
if _, _, err := net.SplitHostPort(addr); err == nil {
return addr
}
return net.JoinHostPort(addr, "443")
}
func main() { func main() {
os.Exit(run()) os.Exit(run())
} }
+4
View File
@@ -97,6 +97,10 @@ func (e *SEIPEntry) NetAddr() string {
} }
} }
func (e SEIPEntry) String() string {
return e.NetAddr()
}
type SEDiscoverResponse struct { type SEDiscoverResponse struct {
Data struct { Data struct {
IPs []SEIPEntry `json:"ips"` IPs []SEIPEntry `json:"ips"`