garm/websocket/client.go
Gabriel Adrian Samfira 9f8659abd6 Add events websocket endpoint
This change adds a new websocket endpoint for database events. The events
endpoint allows clients to stream events as they happen in GARM. Events
are defined as a structure containning the event type (create, update, delete),
the database entity involved (instances, pools, repos, etc) and the payload
consisting of the object involved in the event. The payload translates
to the types normally returned by the API and can be deserialized as one
of the types present in the params package.

The events endpoint is a websocket endpoint and it accepts filters as
a simple json send over the websocket connection. The filters allows the
user to specify which entities are of interest, and which operations should
be returned. For example, you may be interested in changes made to pools
or runners, in which case you could create a filter that only returns
update operations for pools. Or update and delete operations.

The filters can be defined as:

{
  "filters": [
    {
      "entity_type": "instance",
      "operations": ["update", "delete"]
    },
    {
      "entity_type": "pool"
    },
  ],
  "send_everything": false
}

This would return only update and delete events for instances and all events
for pools. Alternatively you can ask GARM to send you everything:

{
  "send_everything": true
}

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
2024-07-05 12:55:35 +00:00

316 lines
6.7 KiB
Go

package websocket
import (
"context"
"fmt"
"log/slog"
"net"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/cloudbase/garm/auth"
"github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
"github.com/cloudbase/garm/params"
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
// Maximum message size allowed from peer.
maxMessageSize = 1024
)
type HandleWebsocketMessage func([]byte) error
func NewClient(ctx context.Context, conn *websocket.Conn) (*Client, error) {
clientID := uuid.New()
consumerID := fmt.Sprintf("ws-client-watcher-%s", clientID.String())
user := auth.UserID(ctx)
if user == "" {
return nil, fmt.Errorf("user not found in context")
}
generation := auth.PasswordGeneration(ctx)
consumer, err := watcher.RegisterConsumer(
ctx, consumerID,
watcher.WithUserIDFilter(user),
)
if err != nil {
return nil, errors.Wrap(err, "registering consumer")
}
return &Client{
id: clientID.String(),
conn: conn,
ctx: ctx,
userID: user,
passwordGeneration: generation,
consumer: consumer,
done: make(chan struct{}),
send: make(chan []byte, 100),
}, nil
}
type Client struct {
id string
conn *websocket.Conn
// Buffered channel of outbound messages.
send chan []byte
mux sync.Mutex
writeMux sync.Mutex
ctx context.Context
userID string
passwordGeneration uint
consumer common.Consumer
messageHandler HandleWebsocketMessage
running bool
done chan struct{}
}
func (c *Client) ID() string {
return c.id
}
func (c *Client) Stop() {
c.mux.Lock()
defer c.mux.Unlock()
if !c.running {
return
}
c.running = false
c.writeMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
c.conn.Close()
close(c.send)
close(c.done)
}
func (c *Client) Done() <-chan struct{} {
return c.done
}
func (c *Client) SetMessageHandler(handler HandleWebsocketMessage) {
c.mux.Lock()
defer c.mux.Unlock()
c.messageHandler = handler
}
func (c *Client) Start() error {
c.mux.Lock()
defer c.mux.Unlock()
c.running = true
go c.runWatcher()
go c.clientReader()
go c.clientWriter()
return nil
}
func (c *Client) Write(msg []byte) (int, error) {
c.mux.Lock()
defer c.mux.Unlock()
if !c.running {
return 0, fmt.Errorf("client is stopped")
}
tmp := make([]byte, len(msg))
copy(tmp, msg)
select {
case <-time.After(5 * time.Second):
return 0, fmt.Errorf("timed out sending message to client")
case c.send <- tmp:
}
return len(tmp), nil
}
// clientReader waits for options changes from the client. The client can at any time
// change the log level and binary name it watches.
func (c *Client) clientReader() {
defer func() {
c.Stop()
}()
c.conn.SetReadLimit(maxMessageSize)
if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
slog.With(slog.Any("error", err)).Error("failed to set read deadline")
}
c.conn.SetPongHandler(func(string) error {
if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
return err
}
return nil
})
for {
mt, data, err := c.conn.ReadMessage()
if err != nil {
if IsErrorOfInterest(err) {
slog.ErrorContext(c.ctx, "error reading websocket message", slog.Any("error", err))
}
break
}
if c.messageHandler != nil {
if err := c.messageHandler(data); err != nil {
slog.ErrorContext(c.ctx, "error handling message", slog.Any("error", err))
}
}
if mt == websocket.CloseMessage {
break
}
}
}
func (c *Client) writeMessage(messageType int, message []byte) error {
c.writeMux.Lock()
defer c.writeMux.Unlock()
if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
return fmt.Errorf("failed to set write deadline: %w", err)
}
if err := c.conn.WriteMessage(messageType, message); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
return nil
}
// clientWriter
func (c *Client) clientWriter() {
ticker := time.NewTicker(pingPeriod)
defer func() {
c.Stop()
ticker.Stop()
}()
var authExpires time.Time
expires := auth.Expires(c.ctx)
if expires != nil {
authExpires = *expires
}
for {
select {
case message, ok := <-c.send:
if !ok {
// The hub closed the channel.
if err := c.writeMessage(websocket.CloseMessage, []byte{}); err != nil {
if IsErrorOfInterest(err) {
slog.With(slog.Any("error", err)).Error("failed to write message")
}
}
return
}
if err := c.writeMessage(websocket.TextMessage, message); err != nil {
if IsErrorOfInterest(err) {
slog.With(slog.Any("error", err)).Error("error sending message")
}
return
}
case <-ticker.C:
if err := c.writeMessage(websocket.PingMessage, nil); err != nil {
if IsErrorOfInterest(err) {
slog.With(slog.Any("error", err)).Error("failed to write ping message")
}
return
}
case <-c.ctx.Done():
return
case <-time.After(time.Until(authExpires)):
// Auth has expired
slog.DebugContext(c.ctx, "auth expired, closing connection")
return
}
}
}
func (c *Client) runWatcher() {
defer func() {
c.Stop()
}()
for {
select {
case <-c.Done():
return
case <-c.ctx.Done():
return
case event, ok := <-c.consumer.Watch():
if !ok {
slog.InfoContext(c.ctx, "watcher closed")
return
}
if event.EntityType != common.UserEntityType {
continue
}
user, ok := event.Payload.(params.User)
if !ok {
slog.ErrorContext(c.ctx, "failed to cast payload to user")
continue
}
if user.ID != c.userID {
continue
}
if event.Operation == common.DeleteOperation {
slog.InfoContext(c.ctx, "user deleted; closing connection")
c.Stop()
}
if !user.Enabled {
slog.InfoContext(c.ctx, "user disabled; closing connection")
c.Stop()
}
if user.Generation != c.passwordGeneration {
slog.InfoContext(c.ctx, "password generation mismatch; closing connection")
c.Stop()
}
}
}
}
func IsErrorOfInterest(err error) bool {
if err == nil {
return false
}
if errors.Is(err, websocket.ErrCloseSent) {
return false
}
if errors.Is(err, websocket.ErrBadHandshake) {
return false
}
if errors.Is(err, net.ErrClosed) {
return false
}
asCloseErr, ok := err.(*websocket.CloseError)
if ok {
switch asCloseErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway,
websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure:
return false
}
}
return true
}