garm/websocket/client.go
Gabriel Adrian Samfira 0faeee703d Make the debug-log command more useful
The debug-log command now supports log level filtering and attribute
filtering. The log level filtering will only be able to set the minimum
log level as low as the server is configured to stream. If the server has
its log level set as INFO, then setting the log level in the CLI to DEBUG
will have no effect.

Anything above the level the server sends can still be filtered by the client;
that filtering is done entirely client side.

Attribute filters are useful if you need to watch the logs for a particular
worker, entity, etc.

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
2025-08-27 13:21:56 +00:00

337 lines
7.6 KiB
Go

// Copyright 2025 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package websocket
import (
"context"
"errors"
"fmt"
"log/slog"
"net"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/cloudbase/garm/auth"
"github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
"github.com/cloudbase/garm/params"
)
// Timing and size limits applied to every websocket connection.
const (
	// writeWait bounds how long a single write to the peer may take.
	writeWait = 10 * time.Second

	// pongWait is how far the read deadline is pushed out each time it is
	// (re)armed, i.e. how long we wait for the next pong from the peer.
	pongWait = 60 * time.Second

	// pingPeriod is the interval between pings sent to the peer. It is 90%
	// of pongWait so that it always stays below pongWait.
	pingPeriod = 9 * pongWait / 10

	// maxMessageSize is the largest message, in bytes, accepted from the
	// peer (16 KiB).
	maxMessageSize = 16 << 10
)
// HandleWebsocketMessage is the callback signature for messages received from
// the peer. It is handed the raw message payload. An error it returns is
// logged by the client's read loop; the loop keeps running afterwards.
type HandleWebsocketMessage func([]byte) error
// NewClient builds a websocket Client around conn for the user carried by ctx.
//
// The user ID and password generation are read from ctx through the auth
// package, and a watcher consumer filtered on that user ID is registered
// under a name derived from a freshly generated client UUID.
//
// It returns an error when ctx holds no user ID or when the consumer cannot
// be registered. The returned client does nothing until Start is called.
func NewClient(ctx context.Context, conn *websocket.Conn) (*Client, error) {
	id := uuid.New().String()

	userID := auth.UserID(ctx)
	if userID == "" {
		return nil, errors.New("user not found in context")
	}
	passwordGeneration := auth.PasswordGeneration(ctx)

	watchConsumer, err := watcher.RegisterConsumer(
		ctx,
		fmt.Sprintf("ws-client-watcher-%s", id),
		watcher.WithUserIDFilter(userID),
	)
	if err != nil {
		return nil, fmt.Errorf("error registering consumer: %w", err)
	}

	client := &Client{
		id:                 id,
		conn:               conn,
		ctx:                ctx,
		userID:             userID,
		passwordGeneration: passwordGeneration,
		consumer:           watchConsumer,
	}
	return client, nil
}
type Client struct {
id string
conn *websocket.Conn
// Buffered channel of outbound messages.
send chan []byte
mux sync.Mutex
writeMux sync.Mutex
ctx context.Context
userID string
passwordGeneration uint
consumer common.Consumer
messageHandler HandleWebsocketMessage
running bool
done chan struct{}
}
// ID returns the unique identifier assigned to this client.
func (c *Client) ID() string {
return c.id
}
// Stop shuts the client down: it sends a normal-closure close frame, closes
// the connection and closes the send and done channels.
//
// Calling Stop on a client that is not running is a no-op, so it may be
// called repeatedly and from several goroutines.
func (c *Client) Stop() {
	c.mux.Lock()
	defer c.mux.Unlock()

	if !c.running {
		return
	}
	c.running = false

	// Both calls are best effort: the peer may already be gone, and there is
	// nothing useful to do with a failure while tearing down.
	closeFrame := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
	_ = c.writeMessage(websocket.CloseMessage, closeFrame)
	_ = c.conn.Close()

	close(c.send)
	close(c.done)
}
// Done returns the channel that is closed when the client is stopped.
// The channel is created by Start, so the value returned before Start has
// been called is nil.
func (c *Client) Done() <-chan struct{} {
return c.done
}
// SetMessageHandler installs handler as the callback invoked for messages
// read from the connection, replacing any previous handler. Passing nil
// disables message handling.
func (c *Client) SetMessageHandler(handler HandleWebsocketMessage) {
	c.mux.Lock()
	c.messageHandler = handler
	c.mux.Unlock()
}
// Start marks the client as running, allocates the send and done channels and
// launches the watcher, reader and writer goroutines.
//
// Calling Start on a client that is already running is a no-op. Without that
// guard a second call would replace the channels and spawn a second reader
// and writer on the same connection. The returned error is always nil.
func (c *Client) Start() error {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.running {
		return nil
	}

	c.running = true
	c.send = make(chan []byte, 100)
	c.done = make(chan struct{})

	go c.runWatcher()
	go c.clientReader()
	go c.clientWriter()
	return nil
}
// Write queues a copy of msg for delivery to the peer and returns the number
// of bytes queued.
//
// The message is copied, so the caller may reuse msg after Write returns.
// An error is returned when the client is not running, or when the send
// buffer is full. The enqueue is non-blocking: despite the wording of the
// error, Write never waits for space to become available.
func (c *Client) Write(msg []byte) (int, error) {
	c.mux.Lock()
	defer c.mux.Unlock()

	if !c.running {
		return 0, errors.New("websocket client is stopped")
	}

	payload := make([]byte, len(msg))
	copy(payload, msg)

	select {
	case c.send <- payload:
	default:
		return 0, errors.New("timed out sending message to websocket client")
	}
	return len(payload), nil
}
// clientReader is the read loop of the connection. It hands every message
// received from the peer to the configured message handler, if any, and
// stops the client when reading fails or the loop otherwise ends.
//
// The read deadline is armed with pongWait up front and re-armed each time a
// pong arrives, so a peer that stops answering pings makes ReadMessage fail.
func (c *Client) clientReader() {
	defer c.Stop()

	c.conn.SetReadLimit(maxMessageSize)
	if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
		slog.With(slog.Any("error", err)).Error("failed to set read deadline")
	}
	c.conn.SetPongHandler(func(string) error {
		return c.conn.SetReadDeadline(time.Now().Add(pongWait))
	})

	for {
		mt, data, err := c.conn.ReadMessage()
		if err != nil {
			if IsErrorOfInterest(err) {
				slog.ErrorContext(c.ctx, "error reading websocket message", slog.Any("error", err))
			}
			return
		}

		// SetMessageHandler writes the handler under c.mux, so take the same
		// lock to read it. The lock is released before the handler runs so
		// that the handler is free to call back into the client.
		c.mux.Lock()
		handler := c.messageHandler
		c.mux.Unlock()

		if handler != nil {
			if err := handler(data); err != nil {
				slog.ErrorContext(c.ctx, "error handling message", slog.Any("error", err))
			}
		}

		// NOTE(review): ReadMessage is documented to return only text or
		// binary message types, so this is presumably never true; kept as a
		// safeguard.
		if mt == websocket.CloseMessage {
			return
		}
	}
}
// writeMessage sends one message of the given websocket message type to the
// peer. Writes are serialized through writeMux and each one is given a fresh
// deadline of writeWait. Failures from setting the deadline or from the write
// are returned wrapped.
func (c *Client) writeMessage(messageType int, message []byte) error {
	c.writeMux.Lock()
	defer c.writeMux.Unlock()

	deadline := time.Now().Add(writeWait)
	err := c.conn.SetWriteDeadline(deadline)
	if err != nil {
		return fmt.Errorf("failed to set write deadline: %w", err)
	}

	err = c.conn.WriteMessage(messageType, message)
	if err != nil {
		return fmt.Errorf("failed to write message: %w", err)
	}
	return nil
}
// clientWriter is the write loop of the connection. It forwards queued
// messages from the send channel to the peer as text messages, sends a ping
// every pingPeriod, and exits (stopping the client) when a write fails, the
// send channel is closed, the context is cancelled or the auth expiry timer
// fires.
func (c *Client) clientWriter() {
	// Arm the auth expiry timer.
	// NOTE: when the context carries no expiration, authExpires stays at the
	// zero time, which lies in the past. The timer then fires immediately
	// and the loop below terminates the connection right away. A token
	// without an expiration date is not acceptable.
	var authExpires time.Time
	if expires := auth.Expires(c.ctx); expires != nil {
		authExpires = *expires
	}
	authTimer := time.NewTimer(time.Until(authExpires))
	pinger := time.NewTicker(pingPeriod)

	// Deferred calls run last-in first-out: the client is stopped first,
	// then the ticker, then the auth timer.
	defer authTimer.Stop()
	defer pinger.Stop()
	defer c.Stop()

	// report logs err under msg unless it is an expected shutdown error.
	report := func(msg string, err error) {
		if IsErrorOfInterest(err) {
			slog.With(slog.Any("error", err)).Error(msg)
		}
	}

	for {
		select {
		case message, open := <-c.send:
			if !open {
				// The send channel was closed; say goodbye to the peer.
				err := c.writeMessage(websocket.CloseMessage, []byte{})
				if err != nil {
					report("failed to write message", err)
				}
				return
			}
			err := c.writeMessage(websocket.TextMessage, message)
			if err != nil {
				report("error sending message", err)
				return
			}
		case <-pinger.C:
			err := c.writeMessage(websocket.PingMessage, nil)
			if err != nil {
				report("failed to write ping message", err)
				return
			}
		case <-c.ctx.Done():
			return
		case <-authTimer.C:
			// Auth has expired.
			slog.DebugContext(c.ctx, "auth expired, closing connection")
			return
		}
	}
}
// runWatcher consumes watcher events and stops the client when the user that
// owns this connection is deleted, disabled, or has a password generation
// different from the one captured when the client was created.
//
// The loop also ends when the client is stopped, the context is cancelled or
// the watcher channel is closed. Stop is always called on the way out.
//
// NOTE(review): nothing in this file releases c.consumer when the loop
// exits; confirm whether common.Consumer needs an explicit close.
func (c *Client) runWatcher() {
	defer c.Stop()

	for {
		select {
		case <-c.Done():
			return
		case <-c.ctx.Done():
			return
		case event, ok := <-c.consumer.Watch():
			if !ok {
				slog.InfoContext(c.ctx, "watcher closed")
				return
			}
			if event.EntityType != common.UserEntityType {
				continue
			}
			user, ok := event.Payload.(params.User)
			if !ok {
				slog.ErrorContext(c.ctx, "failed to cast payload to user")
				continue
			}
			if user.ID != c.userID {
				continue
			}

			// Any one of these conditions ends the connection. Return on the
			// first match so that a single reason is logged and no further
			// events are processed; the deferred Stop does the shutdown.
			switch {
			case event.Operation == common.DeleteOperation:
				slog.InfoContext(c.ctx, "user deleted; closing connection")
				return
			case !user.Enabled:
				slog.InfoContext(c.ctx, "user disabled; closing connection")
				return
			case user.Generation != c.passwordGeneration:
				slog.InfoContext(c.ctx, "password generation mismatch; closing connection")
				return
			}
		}
	}
}
// IsErrorOfInterest reports whether err is worth logging.
//
// It returns false for nil, for the errors that normally accompany a
// connection shutdown (websocket.ErrCloseSent, websocket.ErrBadHandshake,
// net.ErrClosed) and for close errors carrying one of the codes
// CloseNormalClosure, CloseGoingAway, CloseNoStatusReceived or
// CloseAbnormalClosure. Every other error yields true.
//
// All checks look through wrapped errors, so errors decorated with
// fmt.Errorf("...: %w", err) are classified the same as the bare error.
func IsErrorOfInterest(err error) bool {
	if err == nil {
		return false
	}

	if errors.Is(err, websocket.ErrCloseSent) ||
		errors.Is(err, websocket.ErrBadHandshake) ||
		errors.Is(err, net.ErrClosed) {
		return false
	}

	// errors.As, unlike a plain type assertion, also finds a CloseError
	// sitting further down the wrap chain.
	var closeErr *websocket.CloseError
	if errors.As(err, &closeErr) {
		switch closeErr.Code {
		case websocket.CloseNormalClosure, websocket.CloseGoingAway,
			websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure:
			return false
		}
	}
	return true
}