diff --git a/api/private/v1/configtype.go b/api/private/v1/configtype.go index 5d6f45b..05a888f 100644 --- a/api/private/v1/configtype.go +++ b/api/private/v1/configtype.go @@ -3,6 +3,7 @@ package v1 type WaitConfig struct { Endpoint string `json:"endpoint"` Duration string `json:"duration"` + DurationAutoExtend string `json:"duration_auto_extend"` AuthorizedKeys []string `json:"authorized_keys"` AuthorizedGithubUsers []string `json:"authorized_github_users"` Shell []string `json:"shell"` diff --git a/cmd/breakpoint/wait.go b/cmd/breakpoint/wait.go index 9193207..e677a79 100644 --- a/cmd/breakpoint/wait.go +++ b/cmd/breakpoint/wait.go @@ -75,6 +75,11 @@ func newWaitCmd() *cobra.Command { _, _ = w.Write(ww.Bytes()) }, + WriteNotify: func() { + if cfg.ParsedDurationAutoExtend > 0 { + mgr.ExtendWait(cfg.ParsedDurationAutoExtend, false) + } + }, }) if err != nil { return err diff --git a/pkg/config/config.go b/pkg/config/config.go index 4ad6d68..d809785 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -75,9 +75,19 @@ func LoadConfig(ctx context.Context, file string) (ParsedConfig, error) { if err != nil { return cfg, err } - cfg.ParsedDuration = dur + if cfg.DurationAutoExtend != "" { + dur, err = time.ParseDuration(cfg.DurationAutoExtend) + if err != nil { + return cfg, err + } + if dur < 1*time.Minute { + dur = 1 * time.Minute + } + cfg.ParsedDurationAutoExtend = dur + } + keyMap, err := github.ResolveSSHKeys(ctx, cfg.AuthorizedGithubUsers) if err != nil { return cfg, err @@ -102,7 +112,8 @@ func LoadConfig(ctx context.Context, file string) (ParsedConfig, error) { type ParsedConfig struct { internalv1.WaitConfig - AllKeys map[string]string // Key ID -> Owned name - ParsedDuration time.Duration - RegisterMetadata metadata.MD + AllKeys map[string]string // Key ID -> Owned name + ParsedDuration time.Duration + ParsedDurationAutoExtend time.Duration + RegisterMetadata metadata.MD } diff --git a/pkg/internalserver/internalserver.go b/pkg/internalserver/internalserver.go index 5371c35..893f605 100644 --- a/pkg/internalserver/internalserver.go +++ b/pkg/internalserver/internalserver.go @@ -64,7 +64,7 @@ func ListenAndServe(ctx context.Context, mgr *waiter.Manager) error { } func (g waiterService) Extend(ctx context.Context, req *pb.ExtendRequest) (*pb.ExtendResponse, error) { - expiration := g.manager.ExtendWait(req.WaitFor.AsDuration()) + expiration := g.manager.ExtendWait(req.WaitFor.AsDuration(), true) return &pb.ExtendResponse{ Expiration: timestamppb.New(expiration), }, nil diff --git a/pkg/sshd/sshd.go b/pkg/sshd/sshd.go index 0ed9f9c..30087c4 100644 --- a/pkg/sshd/sshd.go +++ b/pkg/sshd/sshd.go @@ -27,6 +27,7 @@ type SSHServerOpts struct { Dir string InteractiveMOTD func(io.Writer) + WriteNotify func() } type sshKey struct { @@ -94,23 +95,26 @@ func MakeServer(ctx context.Context, opts SSHServerOpts) (*SSHServer, error) { // Make sure that the connection with the client is kept alive. go keepAlive(ctx, sessionLog, session) + // Wrapping the session lets us know when writes are happening. + nsess := newNotifyingSession(ctx, session, opts.WriteNotify) + if isPty { // Print MOTD only if no command was provided - if opts.InteractiveMOTD != nil && session.RawCommand() == "" { - opts.InteractiveMOTD(session) + if opts.InteractiveMOTD != nil && nsess.RawCommand() == "" { + opts.InteractiveMOTD(nsess) } - if err := handlePty(session, ptyReq, winCh, cmd); err != nil { + if err := handlePty(nsess, ptyReq, winCh, cmd); err != nil { sessionLog.Err(err).Msg("pty start failed") - session.Exit(1) + nsess.Exit(1) return } } else { - cmd.Stdout = session - cmd.Stderr = session + cmd.Stdout = nsess + cmd.Stderr = nsess if err := cmd.Start(); err != nil { sessionLog.Err(err).Msg("start failed") - session.Exit(1) + nsess.Exit(1) return } } @@ -182,3 +186,49 @@ func lookupKey(allowed []sshKey, key ssh.PublicKey) (sshKey, bool) { } return sshKey{}, false } + +type notifyingSession struct { + ssh.Session + notifyCh chan struct{} + notify func() +} + +func newNotifyingSession(ctx context.Context, s ssh.Session, notify func()) ssh.Session { + if notify == nil { + return s + } + + sess := notifyingSession{ + Session: s, + notifyCh: make(chan struct{}), + notify: notify, + } + go sess.listen(ctx) + return sess +} + +func (s notifyingSession) listen(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-s.notifyCh: + } + + s.notify() + + select { + case <-ctx.Done(): + return + case <-time.After(1 * time.Second): + } + } +} + +func (s notifyingSession) Write(p []byte) (int, error) { + select { + case s.notifyCh <- struct{}{}: + default: // avoid blocking + } + return s.Session.Write(p) +} diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go index 5f689c0..9fa5787 100644 --- a/pkg/waiter/waiter.go +++ b/pkg/waiter/waiter.go @@ -35,6 +35,10 @@ type ManagerStatus struct { NumConnections uint32 `json:"num_connections"` } +type update struct { + announce bool +} + type Manager struct { ctx context.Context logger zerolog.Logger @@ -42,7 +46,7 @@ type Manager struct { opts ManagerOpts mu sync.Mutex - updated chan struct{} + updated chan update expiration time.Time endpoint string resources []io.Closer @@ -56,7 +60,7 @@ func NewManager(ctx context.Context, opts ManagerOpts) (*Manager, context.Contex ctx: ctx, logger: l, opts: opts, - updated: make(chan struct{}, 1), + updated: make(chan update, 1), expiration: time.Now().Add(opts.InitialDur), } @@ -95,7 +99,7 @@ func (m *Manager) loop(ctx context.Context) { for { select { - case _, ok := <-m.updated: + case update, ok := <-m.updated: if !ok { return } @@ -105,7 +109,9 @@ func (m *Manager) loop(ctx context.Context) { m.mu.Unlock() exitTimer.Reset(time.Until(newExp)) - m.announce() + if update.announce { + m.announce() + } case <-exitTimer.C: // Timer has expired, terminate the program @@ -130,13 +136,13 @@ func logTick() time.Duration { return math.MaxInt64 } -func (m *Manager) ExtendWait(dur time.Duration) time.Time { +func (m *Manager) ExtendWait(dur time.Duration, announce bool) time.Time { m.mu.Lock() defer m.mu.Unlock() m.expiration = m.expiration.Add(dur) - m.updated <- struct{}{} + m.updated <- update{announce: announce} m.logger.Info(). Dur("dur", dur). @@ -188,7 +194,7 @@ func (m *Manager) SetEndpoint(addr string) { m.resources = resources m.mu.Unlock() - m.updated <- struct{}{} + m.updated <- update{announce: true} expandf := expand(addr, m.Expiration())