diff --git a/main.go b/main.go index 3c2f6b0..7be846d 100644 --- a/main.go +++ b/main.go @@ -54,6 +54,7 @@ type CLIArgs struct { apiPassword string apiAddress string bootstrapDNS string + refresh time.Duration } func parse_args() CLIArgs { @@ -76,6 +77,7 @@ func parse_args() CLIArgs { "DNS/DoH/DoT/DoQ resolver for initial discovering of SurfEasy API address. "+ "See https://github.com/ameshkov/dnslookup/ for upstream DNS URL format. "+ "Examples: https://1.1.1.1/dns-query, quic://dns.adguard.com") + flag.DurationVar(&args.refresh, "refresh", 4*time.Hour, "login refresh interval") flag.Parse() if args.country == "" { arg_fail("Country can't be empty string.") @@ -206,6 +208,18 @@ func run() int { } cl() + runTicker(context.Background(), args.refresh, func (ctx context.Context) { + mainLogger.Info("Refreshing login...") + loginCtx, cl := context.WithTimeout(ctx, args.timeout) + defer cl() + err := seclient.Login(loginCtx) + if err != nil { + mainLogger.Critical("Login refresh failed: %v", err) + return + } + mainLogger.Info("Login refreshed.") + }) + if args.listCountries { return printCountries(mainLogger, args.timeout, seclient) } diff --git a/utils.go b/utils.go index 17b1e65..8cbeafc 100644 --- a/utils.go +++ b/utils.go @@ -12,7 +12,10 @@ import ( "time" ) -const COPY_BUF = 128 * 1024 +const ( + COPY_BUF = 128 * 1024 + WALLCLOCK_PRECISION = 1 * time.Second +) func basic_auth_header(login, password string) string { return "Basic " + base64.StdEncoding.EncodeToString( @@ -143,3 +146,40 @@ func copyBody(wr io.Writer, body io.Reader) { } } } + +func AfterWallClock(d time.Duration) <-chan time.Time { + ch := make(chan time.Time, 1) + deadline := time.Now().Add(d).Truncate(0) + after_ch := time.After(d) + ticker := time.NewTicker(WALLCLOCK_PRECISION) + go func() { + var t time.Time + defer ticker.Stop() + for { + select { + case t = <-after_ch: + ch <-t + return + case t = <-ticker.C: + if t.After(deadline) { + ch <-t + return + } + } + } + }() + return ch +} + +func runTicker(ctx context.Context, interval time.Duration, cb func (context.Context)) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-AfterWallClock(interval): + cb(ctx) + } + } + }() +}