Skip to content

Commit c319ff1

Browse files
committed
1 parent 242caa6 commit c319ff1

File tree

6 files changed

+66
-70
lines changed

6 files changed

+66
-70
lines changed

connector.go broker.go

+12-12
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@ import (
1010
"github.com/antoniomika/syncmap"
1111
)
1212

13-
type Connector interface {
13+
type Broker interface {
1414
GetChannels() iter.Seq2[string, *Channel]
1515
GetClients() iter.Seq2[string, *Client]
1616
Connect(*Client, []*Channel) (error, error)
1717
}
1818

19-
type BaseConnector struct {
19+
type BaseBroker struct {
2020
Channels *syncmap.Map[string, *Channel]
2121
}
2222

23-
func (b *BaseConnector) Cleanup() {
23+
func (b *BaseBroker) Cleanup() {
2424
toRemove := []string{}
2525
for _, channel := range b.GetChannels() {
2626
count := 0
@@ -31,7 +31,7 @@ func (b *BaseConnector) Cleanup() {
3131

3232
if count == 0 {
3333
channel.Cleanup()
34-
toRemove = append(toRemove, channel.ID)
34+
toRemove = append(toRemove, channel.Topic)
3535
}
3636
}
3737

@@ -40,25 +40,25 @@ func (b *BaseConnector) Cleanup() {
4040
}
4141
}
4242

43-
func (b *BaseConnector) GetChannels() iter.Seq2[string, *Channel] {
43+
func (b *BaseBroker) GetChannels() iter.Seq2[string, *Channel] {
4444
return b.Channels.Range
4545
}
4646

47-
func (b *BaseConnector) GetClients() iter.Seq2[string, *Client] {
47+
func (b *BaseBroker) GetClients() iter.Seq2[string, *Client] {
4848
return func(yield func(string, *Client) bool) {
4949
for _, channel := range b.GetChannels() {
5050
channel.Clients.Range(yield)
5151
}
5252
}
5353
}
5454

55-
func (b *BaseConnector) Connect(client *Client, channels []*Channel) (error, error) {
55+
func (b *BaseBroker) Connect(client *Client, channels []*Channel) (error, error) {
5656
for _, channel := range channels {
5757
dataChannel := b.ensureChannel(channel)
5858
dataChannel.Clients.Store(client.ID, client)
59-
client.Channels.Store(dataChannel.ID, dataChannel)
59+
client.Channels.Store(dataChannel.Topic, dataChannel)
6060
defer func() {
61-
client.Channels.Delete(channel.ID)
61+
client.Channels.Delete(channel.Topic)
6262
dataChannel.Clients.Delete(client.ID)
6363

6464
client.Cleanup()
@@ -186,10 +186,10 @@ func (b *BaseConnector) Connect(client *Client, channels []*Channel) (error, err
186186
return inputErr, outputErr
187187
}
188188

189-
func (b *BaseConnector) ensureChannel(channel *Channel) *Channel {
190-
dataChannel, _ := b.Channels.LoadOrStore(channel.ID, channel)
189+
func (b *BaseBroker) ensureChannel(channel *Channel) *Channel {
190+
dataChannel, _ := b.Channels.LoadOrStore(channel.Topic, channel)
191191
dataChannel.Handle()
192192
return dataChannel
193193
}
194194

195-
var _ Connector = (*BaseConnector)(nil)
195+
var _ Broker = (*BaseBroker)(nil)

channel.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@ type ChannelMessage struct {
3737
Action ChannelAction
3838
}
3939

40-
func NewChannel(name string) *Channel {
40+
func NewChannel(topic string) *Channel {
4141
return &Channel{
42-
ID: name,
42+
Topic: topic,
4343
Done: make(chan struct{}),
4444
Data: make(chan ChannelMessage),
4545
Clients: syncmap.New[string, *Client](),
4646
}
4747
}
4848

4949
type Channel struct {
50-
ID string
50+
Topic string
5151
Done chan struct{}
5252
Data chan ChannelMessage
5353
Clients *syncmap.Map[string, *Client]

cmd/authorized_keys/main.go

+22-31
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"syscall"
1313
"time"
1414

15-
"github.com/antoniomika/syncmap"
1615
"github.com/charmbracelet/ssh"
1716
"github.com/charmbracelet/wish"
1817
"github.com/google/uuid"
@@ -26,66 +25,66 @@ func GetEnv(key string, defaultVal string) string {
2625
return defaultVal
2726
}
2827

29-
func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware {
28+
func PubSubMiddleware(broker pubsub.PubSub, logger *slog.Logger) wish.Middleware {
3029
return func(next ssh.Handler) ssh.Handler {
3130
return func(sesh ssh.Session) {
3231
args := sesh.Command()
3332
if len(args) < 2 {
34-
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}")
33+
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {topic}")
3534
next(sesh)
3635
return
3736
}
3837

3938
cmd := strings.TrimSpace(args[0])
40-
channel := args[1]
39+
topicsRaw := args[1]
4140

42-
channels := strings.Split(channel, ",")
41+
topics := strings.Split(topicsRaw, ",")
4342

44-
logger := cfg.Logger.With(
43+
logger := logger.With(
4544
"cmd", cmd,
46-
"channel", channels,
45+
"topics", topics,
4746
)
4847

4948
logger.Info("running cli")
5049

5150
if cmd == "help" {
52-
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}")
51+
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {topic}")
5352
} else if cmd == "sub" {
5453
var chans []*pubsub.Channel
5554

56-
for _, c := range channels {
57-
chans = append(chans, pubsub.NewChannel(c))
55+
for _, topic := range topics {
56+
chans = append(chans, pubsub.NewChannel(topic))
5857
}
5958

6059
clientID := uuid.NewString()
6160

62-
err := errors.Join(cfg.PubSub.Sub(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "keepalive"))
61+
err := errors.Join(broker.Sub(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "keepalive"))
6362
if err != nil {
6463
logger.Error("error during pub", slog.Any("error", err), slog.String("client", clientID))
6564
}
6665
} else if cmd == "pub" {
6766
var chans []*pubsub.Channel
6867

69-
for _, c := range channels {
70-
chans = append(chans, pubsub.NewChannel(c))
68+
for _, topic := range topics {
69+
chans = append(chans, pubsub.NewChannel(topic))
7170
}
7271

7372
clientID := uuid.NewString()
7473

75-
err := errors.Join(cfg.PubSub.Pub(sesh.Context(), clientID, sesh, chans))
74+
err := errors.Join(broker.Pub(sesh.Context(), clientID, sesh, chans))
7675
if err != nil {
7776
logger.Error("error during pub", slog.Any("error", err), slog.String("client", clientID))
7877
}
7978
} else if cmd == "pipe" {
8079
var chans []*pubsub.Channel
8180

82-
for _, c := range channels {
83-
chans = append(chans, pubsub.NewChannel(c))
81+
for _, topics := range topics {
82+
chans = append(chans, pubsub.NewChannel(topics))
8483
}
8584

8685
clientID := uuid.NewString()
8786

88-
err := errors.Join(cfg.PubSub.Pipe(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "replay"))
87+
err := errors.Join(broker.Pipe(sesh.Context(), clientID, sesh, chans, args[len(args)-1] == "replay"))
8988
if err != nil {
9089
logger.Error(
9190
"pipe error",
@@ -94,7 +93,7 @@ func PubSubMiddleware(cfg *pubsub.Cfg) wish.Middleware {
9493
)
9594
}
9695
} else {
97-
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {channel}")
96+
wish.Println(sesh, "USAGE: ssh send.pico.sh (sub|pub|pipe) {topic}")
9897
}
9998

10099
next(sesh)
@@ -107,23 +106,15 @@ func main() {
107106
host := GetEnv("SSH_HOST", "0.0.0.0")
108107
port := GetEnv("SSH_PORT", "2222")
109108
keyPath := GetEnv("SSH_AUTHORIZED_KEYS", "./ssh_data/authorized_keys")
110-
cfg := &pubsub.Cfg{
111-
Logger: logger,
112-
PubSub: &pubsub.PubSubMulticast{
113-
Logger: logger,
114-
Connector: &pubsub.BaseConnector{
115-
Channels: syncmap.New[string, *pubsub.Channel](),
116-
},
117-
},
118-
}
109+
broker := pubsub.NewMulticast(logger)
119110

120111
s, err := wish.NewServer(
121112
ssh.NoPty(),
122113
wish.WithAddress(fmt.Sprintf("%s:%s", host, port)),
123114
wish.WithHostKeyPath("ssh_data/term_info_ed25519"),
124115
wish.WithAuthorizedKeys(keyPath),
125116
wish.WithMiddleware(
126-
PubSubMiddleware(cfg),
117+
PubSubMiddleware(broker, logger),
127118
),
128119
)
129120
if err != nil {
@@ -149,10 +140,10 @@ func main() {
149140
slog.Info("Debug Info", slog.Int("goroutines", runtime.NumGoroutine()))
150141
select {
151142
case <-time.After(5 * time.Second):
152-
for _, channel := range cfg.PubSub.GetChannels() {
153-
slog.Info("channel online", slog.Any("channel", channel.ID))
143+
for _, channel := range broker.GetChannels() {
144+
slog.Info("channel online", slog.Any("channel topic", channel.Topic))
154145
for _, client := range channel.GetClients() {
155-
slog.Info("client online", slog.Any("channel", channel.ID), slog.Any("client", client.ID), slog.String("direction", client.Direction.String()))
146+
slog.Info("client online", slog.Any("channel topic", channel.Topic), slog.Any("client", client.ID), slog.String("direction", client.Direction.String()))
156147
}
157148
}
158149
case <-done:

multicast.go

+22-11
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,25 @@ import (
66
"io"
77
"iter"
88
"log/slog"
9+
10+
"github.com/antoniomika/syncmap"
911
)
1012

11-
type PubSubMulticast struct {
12-
Connector
13+
type Multicast struct {
14+
Broker
1315
Logger *slog.Logger
1416
}
1517

16-
func (p *PubSubMulticast) getClients(direction ChannelDirection) iter.Seq2[string, *Client] {
18+
func NewMulticast(logger *slog.Logger) *Multicast {
19+
return &Multicast{
20+
Logger: logger,
21+
Broker: &BaseBroker{
22+
Channels: syncmap.New[string, *Channel](),
23+
},
24+
}
25+
}
26+
27+
func (p *Multicast) getClients(direction ChannelDirection) iter.Seq2[string, *Client] {
1728
return func(yield func(string, *Client) bool) {
1829
for clientID, client := range p.GetClients() {
1930
if client.Direction == direction {
@@ -23,19 +34,19 @@ func (p *PubSubMulticast) getClients(direction ChannelDirection) iter.Seq2[strin
2334
}
2435
}
2536

26-
func (p *PubSubMulticast) GetPipes() iter.Seq2[string, *Client] {
37+
func (p *Multicast) GetPipes() iter.Seq2[string, *Client] {
2738
return p.getClients(ChannelDirectionInputOutput)
2839
}
2940

30-
func (p *PubSubMulticast) GetPubs() iter.Seq2[string, *Client] {
41+
func (p *Multicast) GetPubs() iter.Seq2[string, *Client] {
3142
return p.getClients(ChannelDirectionInput)
3243
}
3344

34-
func (p *PubSubMulticast) GetSubs() iter.Seq2[string, *Client] {
45+
func (p *Multicast) GetSubs() iter.Seq2[string, *Client] {
3546
return p.getClients(ChannelDirectionOutput)
3647
}
3748

38-
func (p *PubSubMulticast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool) (error, error) {
49+
func (p *Multicast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool) (error, error) {
3950
client := NewClient(ID, rw, direction, blockWrite, replay, keepAlive)
4051

4152
go func() {
@@ -46,16 +57,16 @@ func (p *PubSubMulticast) connect(ctx context.Context, ID string, rw io.ReadWrit
4657
return p.Connect(client, channels)
4758
}
4859

49-
func (p *PubSubMulticast) Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error) {
60+
func (p *Multicast) Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error) {
5061
return p.connect(ctx, ID, rw, channels, ChannelDirectionInputOutput, false, replay, false)
5162
}
5263

53-
func (p *PubSubMulticast) Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error {
64+
func (p *Multicast) Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error {
5465
return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionInput, true, false, false))
5566
}
5667

57-
func (p *PubSubMulticast) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error {
68+
func (p *Multicast) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error {
5869
return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionOutput, false, false, keepAlive))
5970
}
6071

61-
var _ PubSub = (*PubSubMulticast)(nil)
72+
var _ = (*Multicast)(nil)

multicast_test.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ func TestMulticastSubBlock(t *testing.T) {
4040
name := "test-channel"
4141
syncer := make(chan int)
4242

43-
cast := &PubSubMulticast{
43+
cast := &Multicast{
4444
Logger: slog.Default(),
45-
Connector: &BaseConnector{
45+
Broker: &BaseBroker{
4646
Channels: syncmap.New[string, *Channel](),
4747
},
4848
}
@@ -85,9 +85,9 @@ func TestMulticastPubBlock(t *testing.T) {
8585
name := "test-channel"
8686
syncer := make(chan int)
8787

88-
cast := &PubSubMulticast{
88+
cast := &Multicast{
8989
Logger: slog.Default(),
90-
Connector: &BaseConnector{
90+
Broker: &BaseBroker{
9191
Channels: syncmap.New[string, *Channel](),
9292
},
9393
}
@@ -131,9 +131,9 @@ func TestMulticastMultSubs(t *testing.T) {
131131
name := "test-channel"
132132
syncer := make(chan int)
133133

134-
cast := &PubSubMulticast{
134+
cast := &Multicast{
135135
Logger: slog.Default(),
136-
Connector: &BaseConnector{
136+
Broker: &BaseBroker{
137137
Channels: syncmap.New[string, *Channel](),
138138
},
139139
}

pubsub.go

+1-7
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,14 @@ import (
44
"context"
55
"io"
66
"iter"
7-
"log/slog"
87
)
98

109
type PubSub interface {
11-
Connector
10+
Broker
1211
GetPubs() iter.Seq2[string, *Client]
1312
GetSubs() iter.Seq2[string, *Client]
1413
GetPipes() iter.Seq2[string, *Client]
1514
Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error)
1615
Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error
1716
Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel) error
1817
}
19-
20-
type Cfg struct {
21-
Logger *slog.Logger
22-
PubSub PubSub
23-
}

0 commit comments

Comments
 (0)