Refactor the websocket client and add fixes
The websocket client and hub interaction has been simplified a bit. The hub now acts only as a tee writer to the various clients that register. Clients must register and unregister explicitly. The hub is no longer passed in to the client. Websocket clients now watch for password changes or JWT token expiration times. Clients are disconnected if the auth token expires or if the password is changed. Various additional safety checks have been added. Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
parent
ca7f20b62d
commit
dd1740c189
17 changed files with 426 additions and 143 deletions
|
|
@ -1,11 +1,20 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/cloudbase/garm/auth"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/database/watcher"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -22,13 +31,34 @@ const (
|
|||
maxMessageSize = 1024
|
||||
)
|
||||
|
||||
func NewClient(conn *websocket.Conn, hub *Hub) (*Client, error) {
|
||||
// HandleWebsocketMessage is the callback signature for messages read from a
// websocket connection. It receives the raw message payload and returns an
// error when the message could not be handled.
type HandleWebsocketMessage func([]byte) error
|
||||
|
||||
func NewClient(ctx context.Context, conn *websocket.Conn) (*Client, error) {
|
||||
clientID := uuid.New()
|
||||
consumerID := fmt.Sprintf("ws-client-watcher-%s", clientID.String())
|
||||
|
||||
user := auth.UserID(ctx)
|
||||
if user == "" {
|
||||
return nil, fmt.Errorf("user not found in context")
|
||||
}
|
||||
generation := auth.PasswordGeneration(ctx)
|
||||
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
ctx, consumerID,
|
||||
watcher.WithUserIDFilter(user),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "registering consumer")
|
||||
}
|
||||
return &Client{
|
||||
id: clientID.String(),
|
||||
conn: conn,
|
||||
hub: hub,
|
||||
send: make(chan []byte, 100),
|
||||
id: clientID.String(),
|
||||
conn: conn,
|
||||
ctx: ctx,
|
||||
userID: user,
|
||||
passwordGeneration: generation,
|
||||
consumer: consumer,
|
||||
done: make(chan struct{}),
|
||||
send: make(chan []byte, 100),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -37,21 +67,84 @@ type Client struct {
|
|||
conn *websocket.Conn
|
||||
// Buffered channel of outbound messages.
|
||||
send chan []byte
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
|
||||
hub *Hub
|
||||
userID string
|
||||
passwordGeneration uint
|
||||
consumer common.Consumer
|
||||
|
||||
messageHandler HandleWebsocketMessage
|
||||
|
||||
running bool
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func (c *Client) Go() {
|
||||
func (c *Client) ID() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (c *Client) Stop() {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if !c.running {
|
||||
return
|
||||
}
|
||||
|
||||
c.running = false
|
||||
c.conn.Close()
|
||||
close(c.send)
|
||||
close(c.done)
|
||||
}
|
||||
|
||||
func (c *Client) Done() <-chan struct{} {
|
||||
return c.done
|
||||
}
|
||||
|
||||
func (c *Client) SetMessageHandler(handler HandleWebsocketMessage) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
c.messageHandler = handler
|
||||
}
|
||||
|
||||
func (c *Client) Start() error {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
c.running = true
|
||||
|
||||
go c.runWatcher()
|
||||
go c.clientReader()
|
||||
go c.clientWriter()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Write(msg []byte) (int, error) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if !c.running {
|
||||
return 0, fmt.Errorf("client is stopped")
|
||||
}
|
||||
|
||||
tmp := make([]byte, len(msg))
|
||||
copy(tmp, msg)
|
||||
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
return 0, fmt.Errorf("timed out sending message to client")
|
||||
case c.send <- tmp:
|
||||
}
|
||||
return len(tmp), nil
|
||||
}
|
||||
|
||||
// clientReader waits for options changes from the client. The client can at any time
|
||||
// change the log level and binary name it watches.
|
||||
func (c *Client) clientReader() {
|
||||
defer func() {
|
||||
c.hub.unregister <- c
|
||||
c.conn.Close()
|
||||
c.Stop()
|
||||
}()
|
||||
c.conn.SetReadLimit(maxMessageSize)
|
||||
if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
|
||||
|
|
@ -64,10 +157,19 @@ func (c *Client) clientReader() {
|
|||
return nil
|
||||
})
|
||||
for {
|
||||
mt, _, err := c.conn.ReadMessage()
|
||||
mt, data, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if IsErrorOfInterest(err) {
|
||||
slog.ErrorContext(c.ctx, "error reading websocket message", slog.Any("error", err))
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if c.messageHandler != nil {
|
||||
if err := c.messageHandler(data); err != nil {
|
||||
slog.ErrorContext(c.ctx, "error handling message", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
if mt == websocket.CloseMessage {
|
||||
break
|
||||
}
|
||||
|
|
@ -78,9 +180,14 @@ func (c *Client) clientReader() {
|
|||
func (c *Client) clientWriter() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
c.Stop()
|
||||
ticker.Stop()
|
||||
c.conn.Close()
|
||||
}()
|
||||
var authExpires time.Time
|
||||
expires := auth.Expires(c.ctx)
|
||||
if expires != nil {
|
||||
authExpires = *expires
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.send:
|
||||
|
|
@ -90,13 +197,17 @@ func (c *Client) clientWriter() {
|
|||
if !ok {
|
||||
// The hub closed the channel.
|
||||
if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil {
|
||||
slog.With(slog.Any("error", err)).Error("failed to write message")
|
||||
if IsErrorOfInterest(err) {
|
||||
slog.With(slog.Any("error", err)).Error("failed to write message")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
slog.With(slog.Any("error", err)).Error("error sending message")
|
||||
if IsErrorOfInterest(err) {
|
||||
slog.With(slog.Any("error", err)).Error("error sending message")
|
||||
}
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
|
|
@ -104,8 +215,81 @@ func (c *Client) clientWriter() {
|
|||
slog.With(slog.Any("error", err)).Error("failed to set write deadline")
|
||||
}
|
||||
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
if IsErrorOfInterest(err) {
|
||||
slog.With(slog.Any("error", err)).Error("failed to write ping message")
|
||||
}
|
||||
return
|
||||
}
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-time.After(time.Until(authExpires)):
|
||||
// Auth has expired
|
||||
slog.DebugContext(c.ctx, "auth expired, closing connection")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) runWatcher() {
|
||||
defer func() {
|
||||
c.Stop()
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case event, ok := <-c.consumer.Watch():
|
||||
if !ok {
|
||||
slog.InfoContext(c.ctx, "watcher closed")
|
||||
return
|
||||
}
|
||||
go func(event common.ChangePayload) {
|
||||
if event.EntityType != common.UserEntityType {
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := event.Payload.(params.User)
|
||||
if !ok {
|
||||
slog.ErrorContext(c.ctx, "failed to cast payload to user")
|
||||
return
|
||||
}
|
||||
|
||||
if user.ID != c.userID {
|
||||
return
|
||||
}
|
||||
|
||||
if user.Generation != c.passwordGeneration {
|
||||
slog.InfoContext(c.ctx, "password generation mismatch; closing connection")
|
||||
c.Stop()
|
||||
}
|
||||
}(event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func IsErrorOfInterest(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if errors.Is(err, websocket.ErrCloseSent) {
|
||||
return false
|
||||
}
|
||||
|
||||
if errors.Is(err, websocket.ErrBadHandshake) {
|
||||
return false
|
||||
}
|
||||
|
||||
asCloseErr, ok := err.(*websocket.CloseError)
|
||||
if ok {
|
||||
switch asCloseErr.Code {
|
||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway,
|
||||
websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,19 +3,18 @@ package websocket
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewHub(ctx context.Context) *Hub {
|
||||
return &Hub{
|
||||
clients: map[string]*Client{},
|
||||
broadcast: make(chan []byte, 100),
|
||||
register: make(chan *Client, 100),
|
||||
unregister: make(chan *Client, 100),
|
||||
ctx: ctx,
|
||||
closed: make(chan struct{}),
|
||||
quit: make(chan struct{}),
|
||||
clients: map[string]*Client{},
|
||||
broadcast: make(chan []byte, 100),
|
||||
ctx: ctx,
|
||||
closed: make(chan struct{}),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -29,12 +28,6 @@ type Hub struct {
|
|||
// Inbound messages from the clients.
|
||||
broadcast chan []byte
|
||||
|
||||
// Register requests from the clients.
|
||||
register chan *Client
|
||||
|
||||
// Unregister requests from clients.
|
||||
unregister chan *Client
|
||||
|
||||
mux sync.Mutex
|
||||
once sync.Once
|
||||
}
|
||||
|
|
@ -49,22 +42,6 @@ func (h *Hub) run() {
|
|||
return
|
||||
case <-h.ctx.Done():
|
||||
return
|
||||
case client := <-h.register:
|
||||
if client != nil {
|
||||
h.mux.Lock()
|
||||
h.clients[client.id] = client
|
||||
h.mux.Unlock()
|
||||
}
|
||||
case client := <-h.unregister:
|
||||
if client != nil {
|
||||
h.mux.Lock()
|
||||
if _, ok := h.clients[client.id]; ok {
|
||||
client.conn.Close()
|
||||
close(client.send)
|
||||
delete(h.clients, client.id)
|
||||
}
|
||||
h.mux.Unlock()
|
||||
}
|
||||
case message := <-h.broadcast:
|
||||
staleClients := []string{}
|
||||
for id, client := range h.clients {
|
||||
|
|
@ -73,9 +50,7 @@ func (h *Hub) run() {
|
|||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case client.send <- message:
|
||||
case <-time.After(5 * time.Second):
|
||||
if _, err := client.Write(message); err != nil {
|
||||
staleClients = append(staleClients, id)
|
||||
}
|
||||
}
|
||||
|
|
@ -97,7 +72,35 @@ func (h *Hub) run() {
|
|||
}
|
||||
|
||||
func (h *Hub) Register(client *Client) error {
|
||||
h.register <- client
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
cli, ok := h.clients[client.ID()]
|
||||
if ok {
|
||||
if cli != nil {
|
||||
return fmt.Errorf("client already registered")
|
||||
}
|
||||
}
|
||||
slog.DebugContext(h.ctx, "registering client", "client_id", client.ID())
|
||||
h.clients[client.id] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hub) Unregister(client *Client) error {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
cli, ok := h.clients[client.ID()]
|
||||
if ok {
|
||||
cli.Stop()
|
||||
slog.DebugContext(h.ctx, "unregistering client", "client_id", cli.ID())
|
||||
delete(h.clients, cli.ID())
|
||||
slog.DebugContext(h.ctx, "current client count", "count", len(h.clients))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue