garm/database/watcher/watcher.go
Gabriel Adrian Samfira 6994c8ce05 Add copyright header
Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
2025-05-20 09:43:29 +00:00

205 lines
4.5 KiB
Go

// Copyright 2025 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package watcher
import (
"context"
"log/slog"
"sync"
"github.com/pkg/errors"
"github.com/cloudbase/garm/database/common"
garmUtil "github.com/cloudbase/garm/util"
)
// databaseWatcher is the package-level watcher instance. It is assigned by
// InitWatcher and reset to nil by CloseWatcher and by (*watcher).Close.
// NOTE(review): it is read and written in this file without any
// synchronization of its own — confirm callers do not race on it.
var databaseWatcher common.Watcher
// InitWatcher creates the package-level database watcher and starts its
// background loop. When a watcher is already set the call does nothing.
//
// The supplied context is decorated with a "watcher" slog attribute and
// stored on the new watcher.
func InitWatcher(ctx context.Context) {
	if databaseWatcher != nil {
		// Already initialized; keep the existing instance.
		return
	}

	dbWatcher := &watcher{
		ctx:       garmUtil.WithSlogContext(ctx, slog.Any("watcher", "database")),
		quit:      make(chan struct{}),
		producers: make(map[string]*producer),
		consumers: make(map[string]*consumer),
	}

	go dbWatcher.loop()
	databaseWatcher = dbWatcher
}
// CloseWatcher closes the package-level watcher, if one is set, and clears
// the package-level reference. It always returns nil.
func CloseWatcher() error {
	if databaseWatcher != nil {
		databaseWatcher.Close()
		databaseWatcher = nil
	}
	return nil
}
// RegisterProducer registers a producer with the given id on the
// package-level watcher.
//
// It returns common.ErrWatcherNotInitialized when no watcher is set. The
// context handed to the watcher carries an added "producer_id" slog
// attribute.
func RegisterProducer(ctx context.Context, id string) (common.Producer, error) {
	if databaseWatcher == nil {
		return nil, common.ErrWatcherNotInitialized
	}

	producerCtx := garmUtil.WithSlogContext(ctx, slog.Any("producer_id", id))
	return databaseWatcher.RegisterProducer(producerCtx, id)
}
// RegisterConsumer registers a consumer with the given id and optional
// payload filters on the package-level watcher.
//
// It returns common.ErrWatcherNotInitialized when no watcher is set. The
// context handed to the watcher carries an added "consumer_id" slog
// attribute.
func RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
	if databaseWatcher == nil {
		return nil, common.ErrWatcherNotInitialized
	}

	consumerCtx := garmUtil.WithSlogContext(ctx, slog.Any("consumer_id", id))
	return databaseWatcher.RegisterConsumer(consumerCtx, id, filters...)
}
// watcher keeps track of registered producers and consumers and forwards
// payloads received from producers to the registered consumers.
type watcher struct {
	// producers and consumers are keyed by registration id. The methods in
	// this file access them while holding mux.
	producers map[string]*producer
	consumers map[string]*consumer

	// mux is locked around map access and around Close.
	mux sync.Mutex

	// closed is set by Close so that later calls return early.
	closed bool
	// quit is closed by Close; the loop and service goroutines select on it.
	quit chan struct{}
	// ctx is the context stored at construction; it is used for logging and
	// its cancellation also stops the loop and service goroutines.
	ctx context.Context
}
// RegisterProducer creates a producer identified by id, records it on the
// watcher and starts a goroutine that services it.
//
// It returns common.ErrProducerAlreadyRegistered, wrapped with the offending
// id, when a producer with the same id is already recorded.
func (w *watcher) RegisterProducer(ctx context.Context, id string) (common.Producer, error) {
	w.mux.Lock()
	defer w.mux.Unlock()

	if _, exists := w.producers[id]; exists {
		return nil, errors.Wrapf(common.ErrProducerAlreadyRegistered, "producer_id: %s", id)
	}

	newProducer := &producer{
		ctx:      ctx,
		id:       id,
		quit:     make(chan struct{}),
		messages: make(chan common.ChangePayload, 1),
	}
	w.producers[id] = newProducer

	go w.serviceProducer(newProducer)
	return newProducer, nil
}
// serviceProducer runs until the watcher or the given producer is shut down.
//
// Every payload received on prod.messages is handed to each registered
// consumer through c.Send, one goroutine per consumer, while the watcher
// mutex is held for the iteration over the consumers map.
//
// The loop ends when the watcher's quit channel is closed, the watcher's
// context is done, the producer's quit channel is closed, or the producer's
// context is done. On exit the producer is closed and removed from the
// watcher's producers map.
func (w *watcher) serviceProducer(prod *producer) {
	defer func() {
		w.mux.Lock()
		defer w.mux.Unlock()
		prod.Close()
		// Log under "producer_id": this is a producer being removed, and the
		// key matches the attribute attached when producers are registered.
		slog.InfoContext(w.ctx, "removing producer from watcher", "producer_id", prod.id)
		delete(w.producers, prod.id)
	}()

	for {
		select {
		case <-w.quit:
			slog.InfoContext(w.ctx, "shutting down watcher")
			return
		case <-w.ctx.Done():
			slog.InfoContext(w.ctx, "shutting down watcher")
			return
		case <-prod.quit:
			slog.InfoContext(w.ctx, "closing producer")
			return
		case <-prod.ctx.Done():
			slog.InfoContext(w.ctx, "closing producer")
			return
		case payload := <-prod.messages:
			w.mux.Lock()
			for _, c := range w.consumers {
				// Each send runs in its own goroutine, so this loop does
				// not wait for c.Send to return.
				go c.Send(payload)
			}
			w.mux.Unlock()
		}
	}
}
// RegisterConsumer creates a consumer identified by id with the given
// payload filters, records it on the watcher and starts a goroutine that
// services it.
//
// It returns common.ErrConsumerAlreadyRegistered when a consumer with the
// same id is already recorded.
func (w *watcher) RegisterConsumer(ctx context.Context, id string, filters ...common.PayloadFilterFunc) (common.Consumer, error) {
	w.mux.Lock()
	defer w.mux.Unlock()

	if _, exists := w.consumers[id]; exists {
		return nil, common.ErrConsumerAlreadyRegistered
	}

	newConsumer := &consumer{
		ctx:      ctx,
		id:       id,
		filters:  filters,
		quit:     make(chan struct{}),
		messages: make(chan common.ChangePayload, 1),
	}
	w.consumers[id] = newConsumer

	go w.serviceConsumer(newConsumer)
	return newConsumer, nil
}
// serviceConsumer blocks until the given consumer or the watcher is shut
// down, then closes the consumer and removes it from the watcher's
// consumers map.
//
// It returns when the consumer's quit channel is closed, the consumer's
// context is done, the watcher's quit channel is closed, or the watcher's
// context is done.
func (w *watcher) serviceConsumer(cons *consumer) {
	cleanup := func() {
		w.mux.Lock()
		defer w.mux.Unlock()
		cons.Close()
		slog.InfoContext(w.ctx, "removing consumer from watcher", "consumer_id", cons.id)
		delete(w.consumers, cons.id)
	}
	defer cleanup()

	slog.InfoContext(w.ctx, "starting consumer", "consumer_id", cons.id)

	// Every branch ends the goroutine, so a single blocking select is enough.
	select {
	case <-cons.quit:
	case <-cons.ctx.Done():
	case <-w.quit:
	case <-w.ctx.Done():
	}
}
// Close shuts the watcher down. The first call closes the quit channel,
// marks the watcher as closed, closes every recorded producer and consumer
// and resets the package-level databaseWatcher to nil. Later calls return
// without doing anything.
func (w *watcher) Close() {
	w.mux.Lock()
	defer w.mux.Unlock()

	if !w.closed {
		close(w.quit)
		w.closed = true

		for _, registeredProducer := range w.producers {
			registeredProducer.Close()
		}
		for _, registeredConsumer := range w.consumers {
			registeredConsumer.Close()
		}

		databaseWatcher = nil
	}
}
// loop blocks until the watcher's quit channel is closed or its context is
// done, and then calls Close on the watcher.
func (w *watcher) loop() {
	defer w.Close()

	// Both branches end the goroutine, so a single blocking select is enough.
	select {
	case <-w.quit:
	case <-w.ctx.Done():
	}
}