garm/apiserver/events/events.go
Gabriel Adrian Samfira 6994c8ce05 Add copyright header
Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
2025-05-20 09:43:29 +00:00

194 lines
5 KiB
Go

// Copyright 2025 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package events
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"sync"
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
commonUtil "github.com/cloudbase/garm-provider-common/util"
"github.com/cloudbase/garm/auth"
"github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
"github.com/cloudbase/garm/websocket"
)
// NewHandler creates an EventHandler that forwards watcher events to the
// given websocket client and registers itself as the client's message
// handler.
//
// It returns runnerErrors.ErrUnauthorized when client is nil or when ctx
// carries no user ID. Any error from registering the watcher consumer is
// returned unchanged.
func NewHandler(ctx context.Context, client *websocket.Client) (*EventHandler, error) {
	if client == nil {
		return nil, runnerErrors.ErrUnauthorized
	}

	uniqueSuffix := commonUtil.NewID()
	uid := auth.UserID(ctx)
	if uid == "" {
		return nil, runnerErrors.ErrUnauthorized
	}

	// The consumer starts out dropping every event; the client is expected
	// to send its own filters once it is connected.
	id := fmt.Sprintf("ws-event-watcher-%s-%s", uid, uniqueSuffix)
	consumer, err := watcher.RegisterConsumer(ctx, id, watcher.WithNone())
	if err != nil {
		return nil, err
	}

	h := &EventHandler{
		ctx:      ctx,
		client:   client,
		consumer: consumer,
		done:     make(chan struct{}),
	}
	client.SetMessageHandler(h.HandleClientMessages)

	return h, nil
}
// EventHandler forwards events from a watcher consumer to a single websocket
// client. It is built by NewHandler and driven by Start and Stop.
type EventHandler struct {
	// client is the websocket connection events are written to.
	client *websocket.Client
	// consumer is the watcher consumer registered for this connection; its
	// filters are set from messages the client sends.
	consumer common.Consumer
	// ctx is the context passed to NewHandler; when it is done the event
	// loop exits.
	ctx context.Context
	// done is closed by Stop and exposed through Done.
	done chan struct{}
	// running reports whether Start has launched the event loop; read and
	// written only while holding mux.
	running bool
	// mux serializes Start and Stop.
	mux sync.Mutex
}
// loop forwards events received from the watcher consumer to the websocket
// client as JSON. It exits when the handler's context is done, the client
// reports it is done, the handler's done channel is closed, or the
// consumer's watch channel is closed, and it calls Stop on the way out.
//
// Events that fail to marshal or to be written are logged and skipped; the
// loop keeps running.
func (e *EventHandler) loop() {
	defer e.Stop()

	for {
		select {
		case <-e.ctx.Done():
			slog.DebugContext(e.ctx, "context done, stopping event handler")
			return
		case <-e.client.Done():
			slog.DebugContext(e.ctx, "client done, stopping event handler")
			return
		case <-e.Done():
			slog.DebugContext(e.ctx, "done channel closed, stopping event handler")
			// A closed channel is always ready to receive. Without this
			// return the select would keep picking this case and spin.
			return
		case event, ok := <-e.consumer.Watch():
			if !ok {
				slog.DebugContext(e.ctx, "watcher closed, stopping event handler")
				return
			}

			asJs, err := json.Marshal(event)
			if err != nil {
				slog.ErrorContext(e.ctx, "failed to marshal event", "error", err)
				continue
			}

			if _, err := e.client.Write(asJs); err != nil {
				slog.ErrorContext(e.ctx, "failed to write event", "error", err)
			}
		}
	}
}
// Start starts the websocket client and launches the event loop in a new
// goroutine. Calling Start while the handler is already running does
// nothing and returns nil. If the client fails to start, its error is
// returned and the handler stays not running.
func (e *EventHandler) Start() error {
	e.mux.Lock()
	defer e.mux.Unlock()

	if !e.running {
		if err := e.client.Start(); err != nil {
			return err
		}
		e.running = true
		go e.loop()
	}

	return nil
}
// Stop shuts the handler down: it closes the watcher consumer, stops the
// websocket client and closes the channel returned by Done. It does nothing
// if the handler is not running, so it may be reached from several places
// (the event loop defers it, HandleClientMessages calls it on bad input,
// and callers may call it directly).
func (e *EventHandler) Stop() {
	e.mux.Lock()
	defer e.mux.Unlock()

	if !e.running {
		return
	}
	e.running = false
	e.consumer.Close()
	e.client.Stop()

	// done is only created once (in NewHandler), but Start can run again
	// after Stop because it only checks running. Closing an already-closed
	// channel panics, so close it only the first time. This is race-free
	// because it happens while holding mux.
	select {
	case <-e.done:
	default:
		close(e.done)
	}
}
// Done returns a receive-only channel that Stop closes when it shuts the
// handler down.
func (e *EventHandler) Done() <-chan struct{} {
	var ch <-chan struct{} = e.done
	return ch
}
// optionsToWatcherFilters turns the client-supplied Options into a single
// PayloadFilterFunc for the watcher consumer.
//
// When SendEverything is set, the result is watcher.WithEverything().
// Otherwise, each entry in Filters is turned into a filter that combines,
// via watcher.WithAll:
//   - an entity-type filter, and
//   - when Operations is non-empty, watcher.WithAny over one
//     operation-type filter per listed operation.
//
// The per-entry filters are then combined with watcher.WithAny, so an event
// is sent to the websocket if any entry matches it. If any entry has an
// empty EntityType, the whole set is replaced by watcher.WithNone().
func (e *EventHandler) optionsToWatcherFilters(opt Options) common.PayloadFilterFunc {
	if opt.SendEverything {
		return watcher.WithEverything()
	}

	var anyOf []common.PayloadFilterFunc
	for _, f := range opt.Filters {
		if f.EntityType == "" {
			// A single entry without an entity type invalidates the whole set.
			return watcher.WithNone()
		}

		allOf := []common.PayloadFilterFunc{
			watcher.WithEntityTypeFilter(f.EntityType),
		}

		if n := len(f.Operations); n > 0 {
			ops := make([]common.PayloadFilterFunc, 0, n)
			for _, op := range f.Operations {
				ops = append(ops, watcher.WithOperationTypeFilter(op))
			}
			allOf = append(allOf, watcher.WithAny(ops...))
		}

		anyOf = append(anyOf, watcher.WithAll(allOf...))
	}

	return watcher.WithAny(anyOf...)
}
// HandleClientMessages processes a message received from the websocket
// client; NewHandler registers it as the client's message handler. The
// message is decoded as JSON into Options, validated, converted with
// optionsToWatcherFilters and passed to the consumer's SetFilters.
//
// Handling of failures:
//   - If the message cannot be decoded or fails validation, a short
//     plain-text error is written to the client (best effort) and the
//     handler is stopped.
//   - If validation reports common.ErrNoFiltersProvided, the message is
//     ignored and the current filters are left unchanged.
//
// It returns a non-nil error only when the handler has no consumer.
func (e *EventHandler) HandleClientMessages(message []byte) error {
	if e.consumer == nil {
		return errors.New("consumer not initialized")
	}

	var opt Options
	if err := json.Unmarshal(message, &opt); err != nil {
		slog.ErrorContext(e.ctx, "failed to unmarshal message from client", "error", err, "message", string(message))
		// Client is in error. Tell it why (best effort) and disconnect.
		if _, werr := e.client.Write([]byte("failed to unmarshal filter")); werr != nil {
			slog.ErrorContext(e.ctx, "failed to send error message to client", "error", werr)
		}
		// NOTE(review): Stop calls e.client.Stop() from inside the client's
		// own message handler; confirm websocket.Client.Stop does not wait
		// for the goroutine running this callback, or this could deadlock.
		e.Stop()
		return nil
	}

	if err := opt.Validate(); err != nil {
		if errors.Is(err, common.ErrNoFiltersProvided) {
			slog.DebugContext(e.ctx, "no filters provided; ignoring")
			return nil
		}
		slog.ErrorContext(e.ctx, "invalid filter", "error", err)
		if _, werr := e.client.Write([]byte("invalid filter")); werr != nil {
			slog.ErrorContext(e.ctx, "failed to send error message to client", "error", werr)
		}
		e.Stop()
		return nil
	}

	e.consumer.SetFilters(e.optionsToWatcherFilters(opt))
	return nil
}