mirror of
https://github.com/Alexey71/opera-proxy.git
synced 2026-05-13 22:20:59 +00:00
basic implementation of server selection feature
This commit is contained in:
@@ -0,0 +1,131 @@
|
|||||||
|
package dialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand/v2"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServerSelection int
|
||||||
|
|
||||||
|
const (
|
||||||
|
_ = iota
|
||||||
|
ServerSelectionFirst
|
||||||
|
ServerSelectionRandom
|
||||||
|
ServerSelectionFastest
|
||||||
|
)
|
||||||
|
|
||||||
|
func (ss ServerSelection) String() string {
|
||||||
|
switch ss {
|
||||||
|
case ServerSelectionFirst:
|
||||||
|
return "first"
|
||||||
|
case ServerSelectionRandom:
|
||||||
|
return "random"
|
||||||
|
case ServerSelectionFastest:
|
||||||
|
return "fastest"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("ServerSelection(%d)", int(ss))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseServerSelection(s string) (ServerSelection, error) {
|
||||||
|
switch strings.ToLower(s) {
|
||||||
|
case "first":
|
||||||
|
return ServerSelectionFirst, nil
|
||||||
|
case "random":
|
||||||
|
return ServerSelectionRandom, nil
|
||||||
|
case "fastest":
|
||||||
|
return ServerSelectionFastest, nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("unknown server selection strategy")
|
||||||
|
}
|
||||||
|
|
||||||
|
type SelectionFunc = func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error)
|
||||||
|
|
||||||
|
func SelectFirst(_ context.Context, dialers []ContextDialer) (ContextDialer, error) {
|
||||||
|
if len(dialers) == 0 {
|
||||||
|
return nil, errors.New("empty dialers list")
|
||||||
|
}
|
||||||
|
return dialers[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SelectRandom(_ context.Context, dialers []ContextDialer) (ContextDialer, error) {
|
||||||
|
if len(dialers) == 0 {
|
||||||
|
return nil, errors.New("empty dialers list")
|
||||||
|
}
|
||||||
|
return dialers[rand.IntN(len(dialers))], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func probeDialer(ctx context.Context, dialer ContextDialer, url string, dlLimit int64) error {
|
||||||
|
httpClient := http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
DialContext: dialer.DialContext,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("bad status code %d for URL %q", resp.StatusCode, url)
|
||||||
|
}
|
||||||
|
var rd io.Reader = resp.Body
|
||||||
|
if dlLimit > 0 {
|
||||||
|
rd = io.LimitReader(rd, dlLimit)
|
||||||
|
}
|
||||||
|
_, err = io.Copy(io.Discard, rd)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFastestServerSelectionFunc(url string, dlLimit int64) SelectionFunc {
|
||||||
|
return func(ctx context.Context, dialers []ContextDialer) (ContextDialer, error) {
|
||||||
|
var resErr error
|
||||||
|
masterNotInterested := make(chan struct{})
|
||||||
|
defer close(masterNotInterested)
|
||||||
|
errors := make(chan error)
|
||||||
|
success := make(chan ContextDialer)
|
||||||
|
for _, dialer := range dialers {
|
||||||
|
go func(dialer ContextDialer) {
|
||||||
|
err := probeDialer(ctx, dialer, url, dlLimit)
|
||||||
|
if err == nil {
|
||||||
|
select {
|
||||||
|
case success <- dialer:
|
||||||
|
case <-masterNotInterested:
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
select {
|
||||||
|
case errors <- err:
|
||||||
|
case <-masterNotInterested:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(dialer)
|
||||||
|
}
|
||||||
|
for _ = range dialers {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case d := <-success:
|
||||||
|
return d, nil
|
||||||
|
case err := <-errors:
|
||||||
|
resErr = multierror.Append(resErr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, resErr
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -81,31 +81,52 @@ func (a *CSVArg) Set(line string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type serverSelectionArg struct {
|
||||||
|
value dialer.ServerSelection
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *serverSelectionArg) Set(s string) error {
|
||||||
|
v, err := dialer.ParseServerSelection(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
a.value = v
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *serverSelectionArg) String() string {
|
||||||
|
return a.value.String()
|
||||||
|
}
|
||||||
|
|
||||||
type CLIArgs struct {
|
type CLIArgs struct {
|
||||||
country string
|
country string
|
||||||
listCountries bool
|
listCountries bool
|
||||||
listProxies bool
|
listProxies bool
|
||||||
bindAddress string
|
bindAddress string
|
||||||
socksMode bool
|
socksMode bool
|
||||||
verbosity int
|
verbosity int
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
showVersion bool
|
showVersion bool
|
||||||
proxy string
|
proxy string
|
||||||
apiLogin string
|
apiLogin string
|
||||||
apiPassword string
|
apiPassword string
|
||||||
apiAddress string
|
apiAddress string
|
||||||
apiClientType string
|
apiClientType string
|
||||||
apiClientVersion string
|
apiClientVersion string
|
||||||
apiUserAgent string
|
apiUserAgent string
|
||||||
bootstrapDNS *CSVArg
|
bootstrapDNS *CSVArg
|
||||||
refresh time.Duration
|
refresh time.Duration
|
||||||
refreshRetry time.Duration
|
refreshRetry time.Duration
|
||||||
initRetries int
|
initRetries int
|
||||||
initRetryInterval time.Duration
|
initRetryInterval time.Duration
|
||||||
certChainWorkaround bool
|
certChainWorkaround bool
|
||||||
caFile string
|
caFile string
|
||||||
fakeSNI string
|
fakeSNI string
|
||||||
overrideProxyAddress string
|
overrideProxyAddress string
|
||||||
|
serverSelection serverSelectionArg
|
||||||
|
serverSelectionTimeout time.Duration
|
||||||
|
serverSelectionTestURL string
|
||||||
|
serverSelectionDLLimit int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func parse_args() *CLIArgs {
|
func parse_args() *CLIArgs {
|
||||||
@@ -123,6 +144,7 @@ func parse_args() *CLIArgs {
|
|||||||
"https://doh.cleanbrowsing.org/doh/adult-filter/",
|
"https://doh.cleanbrowsing.org/doh/adult-filter/",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
serverSelection: serverSelectionArg{dialer.ServerSelectionFastest},
|
||||||
}
|
}
|
||||||
flag.StringVar(&args.country, "country", "EU", "desired proxy location")
|
flag.StringVar(&args.country, "country", "EU", "desired proxy location")
|
||||||
flag.BoolVar(&args.listCountries, "list-countries", false, "list available countries and exit")
|
flag.BoolVar(&args.listCountries, "list-countries", false, "list available countries and exit")
|
||||||
@@ -155,6 +177,11 @@ func parse_args() *CLIArgs {
|
|||||||
flag.StringVar(&args.caFile, "cafile", "", "use custom CA certificate bundle file")
|
flag.StringVar(&args.caFile, "cafile", "", "use custom CA certificate bundle file")
|
||||||
flag.StringVar(&args.fakeSNI, "fake-SNI", "", "domain name to use as SNI in communications with servers")
|
flag.StringVar(&args.fakeSNI, "fake-SNI", "", "domain name to use as SNI in communications with servers")
|
||||||
flag.StringVar(&args.overrideProxyAddress, "override-proxy-address", "", "use fixed proxy address instead of server address returned by SurfEasy API")
|
flag.StringVar(&args.overrideProxyAddress, "override-proxy-address", "", "use fixed proxy address instead of server address returned by SurfEasy API")
|
||||||
|
flag.Var(&args.serverSelection, "server-selection", "server selection policy (first/random/fastest)")
|
||||||
|
flag.DurationVar(&args.serverSelectionTimeout, "server-selection-timeout", 30*time.Second, "timeout given for server selection function to produce result")
|
||||||
|
flag.StringVar(&args.serverSelectionTestURL, "server-selection-test-url", "https://ajax.googleapis.com/ajax/libs/angularjs/1.8.2/angular.min.js",
|
||||||
|
"URL used for download benchmark by fastest server selection policy")
|
||||||
|
flag.Int64Var(&args.serverSelectionDLLimit, "server-selection-dl-limit", 0, "restrict amount of downloaded data per connection by fastest server selection")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
if args.country == "" {
|
if args.country == "" {
|
||||||
arg_fail("Country can't be empty string.")
|
arg_fail("Country can't be empty string.")
|
||||||
@@ -200,6 +227,20 @@ func run() int {
|
|||||||
KeepAlive: 30 * time.Second,
|
KeepAlive: 30 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var caPool *x509.CertPool
|
||||||
|
if args.caFile != "" {
|
||||||
|
caPool = x509.NewCertPool()
|
||||||
|
certs, err := ioutil.ReadFile(args.caFile)
|
||||||
|
if err != nil {
|
||||||
|
mainLogger.Error("Can't load CA file: %v", err)
|
||||||
|
return 15
|
||||||
|
}
|
||||||
|
if ok := caPool.AppendCertsFromPEM(certs); !ok {
|
||||||
|
mainLogger.Error("Can't load certificates from CA file")
|
||||||
|
return 15
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if args.proxy != "" {
|
if args.proxy != "" {
|
||||||
xproxy.RegisterDialerType("http", proxyFromURLWrapper)
|
xproxy.RegisterDialerType("http", proxyFromURLWrapper)
|
||||||
xproxy.RegisterDialerType("https", proxyFromURLWrapper)
|
xproxy.RegisterDialerType("https", proxyFromURLWrapper)
|
||||||
@@ -218,7 +259,7 @@ func run() int {
|
|||||||
|
|
||||||
seclientDialer := d
|
seclientDialer := d
|
||||||
if args.apiAddress != "" {
|
if args.apiAddress != "" {
|
||||||
mainLogger.Info("Using fixed API host IP address = %s", args.apiAddress)
|
mainLogger.Info("Using fixed API host address = %s", args.apiAddress)
|
||||||
seclientDialer = dialer.NewFixedDialer(args.apiAddress, d)
|
seclientDialer = dialer.NewFixedDialer(args.apiAddress, d)
|
||||||
} else if len(args.bootstrapDNS.values) > 0 {
|
} else if len(args.bootstrapDNS.values) > 0 {
|
||||||
resolver, err := resolver.FastFromURLs(args.bootstrapDNS.values...)
|
resolver, err := resolver.FastFromURLs(args.bootstrapDNS.values...)
|
||||||
@@ -282,28 +323,74 @@ func run() int {
|
|||||||
return printCountries(try, mainLogger, args.timeout, seclient)
|
return printCountries(try, mainLogger, args.timeout, seclient)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handlerDialerFactory := func(endpointAddr string) dialer.ContextDialer {
|
||||||
|
return dialer.NewProxyDialer(
|
||||||
|
dialer.WrapStringToCb(endpointAddr),
|
||||||
|
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
|
||||||
|
dialer.WrapStringToCb(args.fakeSNI),
|
||||||
|
func() (string, error) {
|
||||||
|
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
|
||||||
|
},
|
||||||
|
args.certChainWorkaround,
|
||||||
|
caPool,
|
||||||
|
d)
|
||||||
|
}
|
||||||
|
|
||||||
var ips []se.SEIPEntry
|
var ips []se.SEIPEntry
|
||||||
err = try("discover", func() error {
|
var handlerDialer dialer.ContextDialer
|
||||||
ctx, cl := context.WithTimeout(context.Background(), args.timeout)
|
|
||||||
defer cl()
|
if args.overrideProxyAddress == "" || args.listProxies {
|
||||||
// TODO: learn about requested_geo value format
|
err = try("discover", func() error {
|
||||||
res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country))
|
ctx, cl := context.WithTimeout(context.Background(), args.timeout)
|
||||||
ips = res
|
defer cl()
|
||||||
return err
|
// TODO: learn about requested_geo value format
|
||||||
})
|
res, err := seclient.Discover(ctx, fmt.Sprintf("\"%s\",,", args.country))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 12
|
return err
|
||||||
|
}
|
||||||
|
if len(res) == 0 {
|
||||||
|
return errors.New("empty endpoints list!")
|
||||||
|
}
|
||||||
|
if args.listProxies {
|
||||||
|
ips = res
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var ss dialer.SelectionFunc
|
||||||
|
switch args.serverSelection.value {
|
||||||
|
case dialer.ServerSelectionFirst:
|
||||||
|
ss = dialer.SelectFirst
|
||||||
|
case dialer.ServerSelectionRandom:
|
||||||
|
ss = dialer.SelectRandom
|
||||||
|
case dialer.ServerSelectionFastest:
|
||||||
|
ss = dialer.NewFastestServerSelectionFunc(
|
||||||
|
args.serverSelectionTestURL,
|
||||||
|
args.serverSelectionDLLimit,
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
panic("unhandled server selection value got past parsing")
|
||||||
|
}
|
||||||
|
dialers := make([]dialer.ContextDialer, len(res))
|
||||||
|
for i, ep := range res {
|
||||||
|
dialers[i] = handlerDialerFactory(ep.NetAddr())
|
||||||
|
}
|
||||||
|
ctx, cl = context.WithTimeout(context.Background(), args.serverSelectionTimeout)
|
||||||
|
defer cl()
|
||||||
|
handlerDialer, err = ss(ctx, dialers)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 12
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sanitizedEndpoint := sanitizeFixedProxyAddress(args.overrideProxyAddress)
|
||||||
|
handlerDialer = handlerDialerFactory(sanitizedEndpoint)
|
||||||
|
mainLogger.Info("Endpoint override: %s", sanitizedEndpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.listProxies {
|
if args.listProxies {
|
||||||
return printProxies(ips, seclient)
|
return printProxies(ips, seclient)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ips) == 0 {
|
|
||||||
mainLogger.Critical("Empty endpoint list!")
|
|
||||||
return 13
|
|
||||||
}
|
|
||||||
|
|
||||||
clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error {
|
clock.RunTicker(context.Background(), args.refresh, args.refreshRetry, func(ctx context.Context) error {
|
||||||
mainLogger.Info("Refreshing login...")
|
mainLogger.Info("Refreshing login...")
|
||||||
reqCtx, cl := context.WithTimeout(ctx, args.timeout)
|
reqCtx, cl := context.WithTimeout(ctx, args.timeout)
|
||||||
@@ -327,40 +414,6 @@ func run() int {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
endpoint := ips[0]
|
|
||||||
|
|
||||||
var caPool *x509.CertPool
|
|
||||||
if args.caFile != "" {
|
|
||||||
caPool = x509.NewCertPool()
|
|
||||||
certs, err := ioutil.ReadFile(args.caFile)
|
|
||||||
if err != nil {
|
|
||||||
mainLogger.Error("Can't load CA file: %v", err)
|
|
||||||
return 15
|
|
||||||
}
|
|
||||||
if ok := caPool.AppendCertsFromPEM(certs); !ok {
|
|
||||||
mainLogger.Error("Can't load certificates from CA file")
|
|
||||||
return 15
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var handlerBaseDialer = d
|
|
||||||
if args.overrideProxyAddress != "" {
|
|
||||||
mainLogger.Info("Original endpoint: %s", endpoint.IP)
|
|
||||||
handlerBaseDialer = dialer.NewFixedDialer(args.overrideProxyAddress, handlerBaseDialer)
|
|
||||||
mainLogger.Info("Endpoint override: %s", args.overrideProxyAddress)
|
|
||||||
} else {
|
|
||||||
mainLogger.Info("Endpoint: %s", endpoint.NetAddr())
|
|
||||||
}
|
|
||||||
handlerDialer := dialer.NewProxyDialer(
|
|
||||||
dialer.WrapStringToCb(endpoint.NetAddr()),
|
|
||||||
dialer.WrapStringToCb(fmt.Sprintf("%s0.%s", args.country, PROXY_SUFFIX)),
|
|
||||||
dialer.WrapStringToCb(args.fakeSNI),
|
|
||||||
func() (string, error) {
|
|
||||||
return dialer.BasicAuthHeader(seclient.GetProxyCredentials()), nil
|
|
||||||
},
|
|
||||||
args.certChainWorkaround,
|
|
||||||
caPool,
|
|
||||||
handlerBaseDialer)
|
|
||||||
mainLogger.Info("Starting proxy server...")
|
mainLogger.Info("Starting proxy server...")
|
||||||
if args.socksMode {
|
if args.socksMode {
|
||||||
socks, initError := handler.NewSocksServer(handlerDialer, socksLogger)
|
socks, initError := handler.NewSocksServer(handlerDialer, socksLogger)
|
||||||
@@ -423,6 +476,13 @@ func printProxies(ips []se.SEIPEntry, seclient *se.SEClient) int {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sanitizeFixedProxyAddress(addr string) string {
|
||||||
|
if _, _, err := net.SplitHostPort(addr); err == nil {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
return net.JoinHostPort(addr, "443")
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
os.Exit(run())
|
os.Exit(run())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user