garm/websocket/websocket.go
Gabriel Adrian Samfira dd1740c189 Refactor the websocket client and add fixes
The websocket client and hub interaction has been simplified a bit.
The hub now acts only as a tee writer to the various clients that
register. Clients must register and unregister explicitly. The hub
is no longer passed in to the client.

Websocket clients now watch for password changes or jwt token expiration
times. Clients are disconnected if auth token expires or if the password
is changed.

Various additional safety checks have been added.

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
2024-07-05 12:55:35 +00:00

143 lines
2.6 KiB
Go

package websocket
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
)
// NewHub allocates a Hub bound to ctx. The returned hub starts with an empty
// client registry and a broadcast queue buffered to 100 messages. Nothing is
// delivered until Start is called.
func NewHub(ctx context.Context) *Hub {
	hub := new(Hub)
	hub.clients = make(map[string]*Client)
	hub.broadcast = make(chan []byte, 100)
	hub.ctx = ctx
	hub.closed = make(chan struct{})
	hub.quit = make(chan struct{})
	return hub
}
// Hub fans every message handed to Write out to all registered clients.
type Hub struct {
	// ctx is watched by the run loop; once it is done the loop exits.
	ctx context.Context
	// closed is closed by the run loop when it exits. quit is closed by
	// Close to ask the run loop to exit.
	closed, quit chan struct{}

	// clients holds the registered clients, keyed by client ID.
	clients map[string]*Client

	// broadcast queues messages accepted by Write until the run loop
	// delivers them to the clients.
	broadcast chan []byte

	// mux guards clients in Register, Unregister and stale-client cleanup.
	mux sync.Mutex
	// once ensures quit is closed a single time.
	once sync.Once
}
// run is the hub's delivery loop. It takes messages off the broadcast queue
// and hands each one to every registered client, until either quit is closed
// or the hub context is done. On exit it closes h.closed so that Wait can
// return.
func (h *Hub) run() {
	defer close(h.closed)

	for {
		select {
		case <-h.quit:
			return
		case <-h.ctx.Done():
			return
		case message := <-h.broadcast:
			h.deliver(message)
		}
	}
}

// snapshotClients returns a copy of the client registry taken under h.mux.
func (h *Hub) snapshotClients() map[string]*Client {
	h.mux.Lock()
	defer h.mux.Unlock()

	snapshot := make(map[string]*Client, len(h.clients))
	for id, client := range h.clients {
		snapshot[id] = client
	}
	return snapshot
}

// deliver writes message to every registered client and evicts the clients
// that are nil or whose Write returned an error.
func (h *Hub) deliver(message []byte) {
	// Register and Unregister mutate h.clients under h.mux from other
	// goroutines, so the map must not be ranged over without the lock.
	// Working on a snapshot keeps the lock released while writing, so
	// registration is not stalled behind client writes.
	targets := h.snapshotClients()

	var stale []string
	for id, client := range targets {
		if client == nil {
			stale = append(stale, id)
			continue
		}
		if _, err := client.Write(message); err != nil {
			stale = append(stale, id)
		}
	}
	if len(stale) == 0 {
		return
	}

	h.mux.Lock()
	defer h.mux.Unlock()
	for _, id := range stale {
		current, ok := h.clients[id]
		if !ok {
			// Already unregistered while we were writing.
			continue
		}
		if current != targets[id] {
			// A different client was registered under this ID after the
			// snapshot was taken; it did not fail, so leave it alone.
			continue
		}
		if current != nil {
			current.conn.Close()
			close(current.send)
		}
		delete(h.clients, id)
	}
}
// Register adds client to the hub so that it receives broadcast messages.
// A nil client is ignored and nil is returned. An error is returned when a
// non-nil client is already stored under the same ID; a nil entry under that
// ID is overwritten.
func (h *Hub) Register(client *Client) error {
	if client == nil {
		return nil
	}

	h.mux.Lock()
	defer h.mux.Unlock()

	if existing, found := h.clients[client.ID()]; found && existing != nil {
		return fmt.Errorf("client already registered")
	}

	slog.DebugContext(h.ctx, "registering client", "client_id", client.ID())
	// NOTE(review): the lookup above goes through ID() while the store uses
	// the id field directly; presumably both yield the same value. Confirm
	// against the Client type.
	h.clients[client.id] = client
	return nil
}
// Unregister removes the client stored under client's ID from the hub and
// stops it. A nil client, or a client whose ID is not registered, is a no-op.
// The returned error is always nil.
func (h *Hub) Unregister(client *Client) error {
	if client == nil {
		return nil
	}

	h.mux.Lock()
	defer h.mux.Unlock()

	id := client.ID()
	cli, ok := h.clients[id]
	if !ok {
		return nil
	}

	// The run loop and Register both tolerate nil entries in the registry,
	// so do the same here instead of calling methods on a nil client.
	if cli != nil {
		cli.Stop()
	}
	slog.DebugContext(h.ctx, "unregistering client", "client_id", id)
	// Delete by the key the entry was found under.
	delete(h.clients, id)
	slog.DebugContext(h.ctx, "current client count", "count", len(h.clients))
	return nil
}
// Write queues a copy of msg for delivery to the registered clients. The
// copy means the caller may reuse msg as soon as Write returns.
//
// It returns len(msg) once the message is queued. It returns 0 and an error
// if the queue stays full for 5 seconds, or if the hub has been closed or its
// context is done while waiting for room in the queue.
func (h *Hub) Write(msg []byte) (int, error) {
	tmp := make([]byte, len(msg))
	copy(tmp, msg)

	// An explicit timer can be stopped on return, so it is released as soon
	// as the send completes rather than lingering until it fires.
	timer := time.NewTimer(5 * time.Second)
	defer timer.Stop()

	select {
	case h.broadcast <- tmp:
		return len(tmp), nil
	case <-h.quit:
		// Nothing drains the queue once the run loop has been told to
		// exit; fail right away instead of waiting out the timeout.
		return 0, fmt.Errorf("hub is stopped")
	case <-h.ctx.Done():
		return 0, fmt.Errorf("hub context done: %w", h.ctx.Err())
	case <-timer.C:
		return 0, fmt.Errorf("timed out sending message to client")
	}
}
// Start launches the hub's run loop in its own goroutine and returns without
// waiting for it. The returned error is always nil.
func (h *Hub) Start() error {
	go func() {
		h.run()
	}()
	return nil
}
// Close asks the run loop to exit by closing the quit channel. It does not
// wait for the loop to finish. Repeated calls are harmless because the
// channel is closed through a sync.Once. The returned error is always nil.
func (h *Hub) Close() error {
	signalQuit := func() {
		close(h.quit)
	}
	h.once.Do(signalQuit)
	return nil
}
// Stop closes the hub and then blocks until the run loop has exited,
// returning whatever Wait returns.
func (h *Hub) Stop() error {
	// The result of Close is deliberately discarded; only the outcome of
	// Wait is reported to the caller.
	_ = h.Close()
	return h.Wait()
}
// Wait blocks until the run loop has exited, signalled by the closed channel
// being closed. It gives up with an error after 60 seconds.
func (h *Hub) Wait() error {
	deadline := time.After(60 * time.Second)

	select {
	case <-deadline:
		return fmt.Errorf("timed out waiting for hub stop")
	case <-h.closed:
		return nil
	}
}