do relogin periodically to refresh account on API side

This commit is contained in:
Vladislav Yarmak
2021-03-30 16:29:54 +03:00
parent 397bf28ac6
commit dfea7e62bc
2 changed files with 55 additions and 1 deletions
+14
View File
@@ -54,6 +54,7 @@ type CLIArgs struct {
apiPassword string apiPassword string
apiAddress string apiAddress string
bootstrapDNS string bootstrapDNS string
refresh time.Duration
} }
func parse_args() CLIArgs { func parse_args() CLIArgs {
@@ -76,6 +77,7 @@ func parse_args() CLIArgs {
"DNS/DoH/DoT/DoQ resolver for initial discovering of SurfEasy API address. "+ "DNS/DoH/DoT/DoQ resolver for initial discovering of SurfEasy API address. "+
"See https://github.com/ameshkov/dnslookup/ for upstream DNS URL format. "+ "See https://github.com/ameshkov/dnslookup/ for upstream DNS URL format. "+
"Examples: https://1.1.1.1/dns-query, quic://dns.adguard.com") "Examples: https://1.1.1.1/dns-query, quic://dns.adguard.com")
flag.DurationVar(&args.refresh, "refresh", 4*time.Hour, "login refresh interval")
flag.Parse() flag.Parse()
if args.country == "" { if args.country == "" {
arg_fail("Country can't be empty string.") arg_fail("Country can't be empty string.")
@@ -206,6 +208,18 @@ func run() int {
} }
cl() cl()
runTicker(context.Background(), args.refresh, func (ctx context.Context) {
mainLogger.Info("Refreshing login...")
loginCtx, cl := context.WithTimeout(ctx, args.timeout)
defer cl()
err := seclient.Login(loginCtx)
if err != nil {
mainLogger.Critical("Login refresh failed: %v", err)
return
}
mainLogger.Info("Login refreshed.")
})
if args.listCountries { if args.listCountries {
return printCountries(mainLogger, args.timeout, seclient) return printCountries(mainLogger, args.timeout, seclient)
} }
+41 -1
View File
@@ -12,7 +12,10 @@ import (
"time" "time"
) )
const COPY_BUF = 128 * 1024 const (
COPY_BUF = 128 * 1024
WALLCLOCK_PRECISION = 1 * time.Second
)
func basic_auth_header(login, password string) string { func basic_auth_header(login, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString( return "Basic " + base64.StdEncoding.EncodeToString(
@@ -143,3 +146,40 @@ func copyBody(wr io.Writer, body io.Reader) {
} }
} }
} }
func AfterWallClock(d time.Duration) <-chan time.Time {
ch := make(chan time.Time, 1)
deadline := time.Now().Add(d).Truncate(0)
after_ch := time.After(d)
ticker := time.NewTicker(WALLCLOCK_PRECISION)
go func() {
var t time.Time
defer ticker.Stop()
for {
select {
case t = <-after_ch:
ch <-t
return
case t = <-ticker.C:
if t.After(deadline) {
ch <-t
return
}
}
}
}()
return ch
}
func runTicker(ctx context.Context, interval time.Duration, cb func (context.Context)) {
go func() {
for {
select {
case <-ctx.Done():
return
case <-AfterWallClock(interval):
cb(ctx)
}
}
}()
}