Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/private/v1/configtype.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package v1
type WaitConfig struct {
Endpoint string `json:"endpoint"`
Duration string `json:"duration"`
DurationAutoExtend string `json:"duration_auto_extend"`
AuthorizedKeys []string `json:"authorized_keys"`
AuthorizedGithubUsers []string `json:"authorized_github_users"`
Shell []string `json:"shell"`
Expand Down
5 changes: 5 additions & 0 deletions cmd/breakpoint/wait.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ func newWaitCmd() *cobra.Command {

_, _ = w.Write(ww.Bytes())
},
WriteNotify: func() {
if cfg.ParsedDurationAutoExtend > 0 {
mgr.ExtendWait(cfg.ParsedDurationAutoExtend, false)
}
},
})
if err != nil {
return err
Expand Down
19 changes: 15 additions & 4 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,19 @@ func LoadConfig(ctx context.Context, file string) (ParsedConfig, error) {
if err != nil {
return cfg, err
}

cfg.ParsedDuration = dur

if cfg.DurationAutoExtend != "" {
dur, err = time.ParseDuration(cfg.DurationAutoExtend)
if err != nil {
return cfg, err
}
if dur < 1*time.Minute {
dur = 1 * time.Minute
}
cfg.ParsedDurationAutoExtend = dur
}

keyMap, err := github.ResolveSSHKeys(ctx, cfg.AuthorizedGithubUsers)
if err != nil {
return cfg, err
Expand All @@ -102,7 +112,8 @@ func LoadConfig(ctx context.Context, file string) (ParsedConfig, error) {
type ParsedConfig struct {
internalv1.WaitConfig

AllKeys map[string]string // Key ID -> Owned name
ParsedDuration time.Duration
RegisterMetadata metadata.MD
AllKeys map[string]string // Key ID -> Owned name
ParsedDuration time.Duration
ParsedDurationAutoExtend time.Duration
RegisterMetadata metadata.MD
}
2 changes: 1 addition & 1 deletion pkg/internalserver/internalserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func ListenAndServe(ctx context.Context, mgr *waiter.Manager) error {
}

func (g waiterService) Extend(ctx context.Context, req *pb.ExtendRequest) (*pb.ExtendResponse, error) {
expiration := g.manager.ExtendWait(req.WaitFor.AsDuration())
expiration := g.manager.ExtendWait(req.WaitFor.AsDuration(), true)
return &pb.ExtendResponse{
Expiration: timestamppb.New(expiration),
}, nil
Expand Down
64 changes: 57 additions & 7 deletions pkg/sshd/sshd.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type SSHServerOpts struct {
Dir string

InteractiveMOTD func(io.Writer)
WriteNotify func()
}

type sshKey struct {
Expand Down Expand Up @@ -94,23 +95,26 @@ func MakeServer(ctx context.Context, opts SSHServerOpts) (*SSHServer, error) {
// Make sure that the connection with the client is kept alive.
go keepAlive(ctx, sessionLog, session)

// Wrapping the session lets us know when writes are happening.
nsess := newNotifyingSession(ctx, session, opts.WriteNotify)

if isPty {
// Print MOTD only if no command was provided
if opts.InteractiveMOTD != nil && session.RawCommand() == "" {
opts.InteractiveMOTD(session)
if opts.InteractiveMOTD != nil && nsess.RawCommand() == "" {
opts.InteractiveMOTD(nsess)
}

if err := handlePty(session, ptyReq, winCh, cmd); err != nil {
if err := handlePty(nsess, ptyReq, winCh, cmd); err != nil {
sessionLog.Err(err).Msg("pty start failed")
session.Exit(1)
nsess.Exit(1)
return
}
} else {
cmd.Stdout = session
cmd.Stderr = session
cmd.Stdout = nsess
cmd.Stderr = nsess
if err := cmd.Start(); err != nil {
sessionLog.Err(err).Msg("start failed")
session.Exit(1)
nsess.Exit(1)
return
}
}
Expand Down Expand Up @@ -182,3 +186,49 @@ func lookupKey(allowed []sshKey, key ssh.PublicKey) (sshKey, bool) {
}
return sshKey{}, false
}

type notifyingSession struct {
ssh.Session
notifyCh chan struct{}
notify func()
}

func newNotifyingSession(ctx context.Context, s ssh.Session, notify func()) ssh.Session {
if notify == nil {
return s
}

sess := notifyingSession{
Session: s,
notifyCh: make(chan struct{}),
notify: notify,
}
go sess.listen(ctx)
return sess
}

func (s notifyingSession) listen(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-s.notifyCh:
}

s.notify()

select {
case <-ctx.Done():
return
case <-time.After(1 * time.Second):
}
}
}

func (s notifyingSession) Write(p []byte) (int, error) {
select {
case s.notifyCh <- struct{}{}:
default: // avoid blocking
}
return s.Session.Write(p)
}
20 changes: 13 additions & 7 deletions pkg/waiter/waiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,18 @@ type ManagerStatus struct {
NumConnections uint32 `json:"num_connections"`
}

type update struct {
announce bool
}

type Manager struct {
ctx context.Context
logger zerolog.Logger

opts ManagerOpts

mu sync.Mutex
updated chan struct{}
updated chan update
expiration time.Time
endpoint string
resources []io.Closer
Expand All @@ -56,7 +60,7 @@ func NewManager(ctx context.Context, opts ManagerOpts) (*Manager, context.Contex
ctx: ctx,
logger: l,
opts: opts,
updated: make(chan struct{}, 1),
updated: make(chan update, 1),
expiration: time.Now().Add(opts.InitialDur),
}

Expand Down Expand Up @@ -95,7 +99,7 @@ func (m *Manager) loop(ctx context.Context) {

for {
select {
case _, ok := <-m.updated:
case update, ok := <-m.updated:
if !ok {
return
}
Expand All @@ -105,7 +109,9 @@ func (m *Manager) loop(ctx context.Context) {
m.mu.Unlock()

exitTimer.Reset(time.Until(newExp))
m.announce()
if update.announce {
m.announce()
}

case <-exitTimer.C:
// Timer has expired, terminate the program
Expand All @@ -130,13 +136,13 @@ func logTick() time.Duration {
return math.MaxInt64
}

func (m *Manager) ExtendWait(dur time.Duration) time.Time {
func (m *Manager) ExtendWait(dur time.Duration, announce bool) time.Time {
m.mu.Lock()
defer m.mu.Unlock()

m.expiration = m.expiration.Add(dur)

m.updated <- struct{}{}
m.updated <- update{announce: announce}

m.logger.Info().
Dur("dur", dur).
Expand Down Expand Up @@ -188,7 +194,7 @@ func (m *Manager) SetEndpoint(addr string) {
m.resources = resources
m.mu.Unlock()

m.updated <- struct{}{}
m.updated <- update{announce: true}

expandf := expand(addr, m.Expiration())

Expand Down