garm/workers/cache/cache.go

339 lines
8.2 KiB
Go
Raw Normal View History

package cache
import (
"context"
"fmt"
"log/slog"
"sync"
"github.com/cloudbase/garm/cache"
"github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
"github.com/cloudbase/garm/params"
garmUtil "github.com/cloudbase/garm/util"
)
// NewWorker builds a cache Worker backed by store. The worker is returned in
// the stopped state; call Start to seed the cache and begin consuming events.
//
// The given context is decorated through garmUtil.WithSlogContext with a
// "worker" attribute set to the worker's consumer ID ("cache"), and the
// decorated context is the one the worker keeps and logs with.
func NewWorker(ctx context.Context, store common.Store) *Worker {
	const id = "cache"
	logCtx := garmUtil.WithSlogContext(ctx, slog.Any("worker", id))

	w := &Worker{
		ctx:        logCtx,
		store:      store,
		consumerID: id,
		quit:       make(chan struct{}),
	}
	return w
}
// Worker keeps the in-process cache package in sync with the database. It
// seeds the cache from the store in Start and then applies change events
// delivered by the database watcher until it is stopped.
type Worker struct {
	// ctx is the worker's base context (carrying the slog attributes added in
	// NewWorker); when it is done, the event loop returns.
	ctx context.Context

	// consumerID is the name the worker registers under with the watcher.
	consumerID string
	// consumer is the watcher subscription created by Start.
	consumer common.Consumer

	// store is the database used to load the initial cache contents.
	store common.Store

	// mux serializes Start and Stop, which read and write running,
	// consumer and quit.
	mux     sync.Mutex
	running bool
	// quit is closed by Stop to tell the event loop to return.
	quit chan struct{}
}
// setCacheForEntity stores the entity produced by entityGetter in the cache
// and replaces that entity's cached pools and scale sets with the members of
// pools and scaleSets that belong to it.
//
// Ownership is determined by resolving each pool / scale set to its own
// owning entity via GetEntity and comparing IDs. This matches pools and scale
// sets owned by repositories, organizations and enterprises alike. (Comparing
// only RepoID would leave org and enterprise entities with an empty list.)
// Items whose owner cannot be resolved are skipped and logged at debug level
// rather than aborting the whole load.
//
// It returns an error only if the entity itself cannot be obtained from
// entityGetter.
func (w *Worker) setCacheForEntity(entityGetter params.EntityGetter, pools []params.Pool, scaleSets []params.ScaleSet) error {
	entity, err := entityGetter.GetEntity()
	if err != nil {
		return fmt.Errorf("getting entity: %w", err)
	}
	cache.SetEntity(entity)

	var entityPools []params.Pool
	for _, pool := range pools {
		owner, ownerErr := pool.GetEntity()
		if ownerErr != nil {
			slog.DebugContext(w.ctx, "getting entity from pool", "pool_id", pool.ID, "error", ownerErr)
			continue
		}
		if owner.ID == entity.ID {
			entityPools = append(entityPools, pool)
		}
	}

	var entityScaleSets []params.ScaleSet
	for _, scaleSet := range scaleSets {
		owner, ownerErr := scaleSet.GetEntity()
		if ownerErr != nil {
			slog.DebugContext(w.ctx, "getting entity from scale set", "scale_set_id", scaleSet.ID, "error", ownerErr)
			continue
		}
		if owner.ID == entity.ID {
			entityScaleSets = append(entityScaleSets, scaleSet)
		}
	}

	cache.ReplaceEntityPools(entity.ID, entityPools)
	cache.ReplaceEntityScaleSets(entity.ID, entityScaleSets)
	return nil
}
// loadAllEntities seeds the cache with every repository, organization and
// enterprise in the store, in that order, each with the pools and scale sets
// it owns. All pools and scale sets are listed once up front and handed to
// setCacheForEntity, which does the per-entity filtering. The first store or
// cache error aborts the load and is returned wrapped. Entities cached before
// the failure stay cached.
func (w *Worker) loadAllEntities() error {
	allPools, err := w.store.ListAllPools(w.ctx)
	if err != nil {
		return fmt.Errorf("listing pools: %w", err)
	}
	allScaleSets, err := w.store.ListAllScaleSets(w.ctx)
	if err != nil {
		return fmt.Errorf("listing scale sets: %w", err)
	}
	repositories, err := w.store.ListRepositories(w.ctx)
	if err != nil {
		return fmt.Errorf("listing repositories: %w", err)
	}
	organizations, err := w.store.ListOrganizations(w.ctx)
	if err != nil {
		return fmt.Errorf("listing organizations: %w", err)
	}
	allEnterprises, err := w.store.ListEnterprises(w.ctx)
	if err != nil {
		return fmt.Errorf("listing enterprises: %w", err)
	}

	// wrap produces the per-kind error, e.g. "setting cache for repo: ...".
	wrap := func(kind string, cause error) error {
		return fmt.Errorf("setting cache for %s: %w", kind, cause)
	}

	for i := range repositories {
		if cacheErr := w.setCacheForEntity(repositories[i], allPools, allScaleSets); cacheErr != nil {
			return wrap("repo", cacheErr)
		}
	}
	for i := range organizations {
		if cacheErr := w.setCacheForEntity(organizations[i], allPools, allScaleSets); cacheErr != nil {
			return wrap("org", cacheErr)
		}
	}
	for i := range allEnterprises {
		if cacheErr := w.setCacheForEntity(allEnterprises[i], allPools, allScaleSets); cacheErr != nil {
			return wrap("enterprise", cacheErr)
		}
	}
	return nil
}
// loadAllInstances copies every instance returned by the store into the
// instance cache. A listing failure is returned wrapped; nothing is cached in
// that case.
func (w *Worker) loadAllInstances() error {
	all, err := w.store.ListAllInstances(w.ctx)
	if err != nil {
		return fmt.Errorf("listing instances: %w", err)
	}
	for idx := range all {
		cache.SetInstanceCache(all[idx])
	}
	return nil
}
// loadAllCredentials copies every set of GitHub credentials returned by the
// store into the credentials cache. A listing failure is returned wrapped;
// nothing is cached in that case.
func (w *Worker) loadAllCredentials() error {
	all, err := w.store.ListGithubCredentials(w.ctx)
	if err != nil {
		return fmt.Errorf("listing github credentials: %w", err)
	}
	for idx := range all {
		cache.SetGithubCredentials(all[idx])
	}
	return nil
}
// Start seeds the cache from the store, registers the worker with the
// database watcher and launches the event loop in a new goroutine.
//
// Seeding happens in three steps, in order: entities (with their pools and
// scale sets), instances, then GitHub credentials. Registration uses
// watcher.WithAll(), presumably subscribing to every change type. If the
// worker is already running, Start returns nil and does nothing. Any seeding
// or registration error is returned wrapped and leaves the worker not
// running.
func (w *Worker) Start() error {
	slog.DebugContext(w.ctx, "starting cache worker")
	w.mux.Lock()
	defer w.mux.Unlock()

	if w.running {
		return nil
	}

	// Each loader runs in turn; the first failure aborts Start.
	loaders := []struct {
		what string
		run  func() error
	}{
		{"loading all entities", w.loadAllEntities},
		{"loading all instances", w.loadAllInstances},
		{"loading all credentials", w.loadAllCredentials},
	}
	for _, l := range loaders {
		if err := l.run(); err != nil {
			return fmt.Errorf("%s: %w", l.what, err)
		}
	}

	sub, err := watcher.RegisterConsumer(w.ctx, w.consumerID, watcher.WithAll())
	if err != nil {
		return fmt.Errorf("registering consumer: %w", err)
	}

	w.consumer = sub
	w.quit = make(chan struct{})
	w.running = true

	go w.loop()
	return nil
}
// Stop shuts the worker down. It closes the watcher consumer, marks the
// worker as not running and closes the quit channel so the event loop
// returns. If the worker is not running, nothing beyond the debug log
// happens. Stop always returns nil.
func (w *Worker) Stop() error {
	slog.DebugContext(w.ctx, "stopping cache worker")
	w.mux.Lock()
	defer w.mux.Unlock()

	if w.running {
		w.consumer.Close()
		w.running = false
		close(w.quit)
	}
	return nil
}
// handleEntityEvent applies a repository/organization/enterprise change to
// the entity cache. Create and update operations store the entity; delete
// removes it by ID. Other operations are ignored. If the entity cannot be
// obtained from entityGetter, the event is logged at debug level and dropped.
func (w *Worker) handleEntityEvent(entityGetter params.EntityGetter, op common.OperationType) {
	entity, err := entityGetter.GetEntity()
	if err != nil {
		slog.DebugContext(w.ctx, "getting entity from event", "error", err)
		return
	}

	if op == common.CreateOperation || op == common.UpdateOperation {
		cache.SetEntity(entity)
	} else if op == common.DeleteOperation {
		cache.DeleteEntity(entity.ID)
	}
}
// handleRepositoryEvent forwards a repository change to handleEntityEvent.
// A payload that is not a params.Repository is logged at debug level and
// ignored.
func (w *Worker) handleRepositoryEvent(event common.ChangePayload) {
	if repo, isRepo := event.Payload.(params.Repository); isRepo {
		w.handleEntityEvent(repo, event.Operation)
		return
	}
	slog.DebugContext(w.ctx, "invalid payload type for repository event", "payload", event.Payload)
}
// handleOrgEvent forwards an organization change to handleEntityEvent.
// A payload that is not a params.Organization is logged at debug level and
// ignored.
func (w *Worker) handleOrgEvent(event common.ChangePayload) {
	if org, isOrg := event.Payload.(params.Organization); isOrg {
		w.handleEntityEvent(org, event.Operation)
		return
	}
	slog.DebugContext(w.ctx, "invalid payload type for org event", "payload", event.Payload)
}
// handleEnterpriseEvent forwards an enterprise change to handleEntityEvent.
// A payload that is not a params.Enterprise is logged at debug level and
// ignored.
func (w *Worker) handleEnterpriseEvent(event common.ChangePayload) {
	if ent, isEnterprise := event.Payload.(params.Enterprise); isEnterprise {
		w.handleEntityEvent(ent, event.Operation)
		return
	}
	slog.DebugContext(w.ctx, "invalid payload type for enterprise event", "payload", event.Payload)
}
// handlePoolEvent mirrors a pool change into the per-entity pool cache.
// Create and update store the pool under its owning entity; delete removes
// it by pool ID. Other operations are ignored. Events whose payload is not a
// params.Pool, or whose owning entity cannot be resolved, are logged at debug
// level and dropped.
func (w *Worker) handlePoolEvent(event common.ChangePayload) {
	pool, isPool := event.Payload.(params.Pool)
	if !isPool {
		slog.DebugContext(w.ctx, "invalid payload type for pool event", "payload", event.Payload)
		return
	}

	owner, err := pool.GetEntity()
	if err != nil {
		slog.DebugContext(w.ctx, "getting entity from pool", "error", err)
		return
	}

	op := event.Operation
	if op == common.CreateOperation || op == common.UpdateOperation {
		cache.SetEntityPool(owner.ID, pool)
	} else if op == common.DeleteOperation {
		cache.DeleteEntityPool(owner.ID, pool.ID)
	}
}
// handleScaleSetEvent mirrors a scale set change into the per-entity scale
// set cache. Create and update store the scale set under its owning entity;
// delete removes it by scale set ID. Other operations are ignored. Events
// whose payload is not a params.ScaleSet, or whose owning entity cannot be
// resolved, are logged at debug level and dropped.
//
// The log messages name the scale set. They previously said "pool", copied
// from handlePoolEvent, which misattributed these failures.
func (w *Worker) handleScaleSetEvent(event common.ChangePayload) {
	scaleSet, ok := event.Payload.(params.ScaleSet)
	if !ok {
		slog.DebugContext(w.ctx, "invalid payload type for scale set event", "payload", event.Payload)
		return
	}
	entity, err := scaleSet.GetEntity()
	if err != nil {
		slog.DebugContext(w.ctx, "getting entity from scale set", "error", err)
		return
	}
	switch event.Operation {
	case common.CreateOperation, common.UpdateOperation:
		cache.SetEntityScaleSet(entity.ID, scaleSet)
	case common.DeleteOperation:
		cache.DeleteEntityScaleSet(entity.ID, scaleSet.ID)
	}
}
// handleInstanceEvent mirrors an instance change into the instance cache.
// Create and update store the instance; delete removes it by instance name.
// Other operations are ignored. A payload that is not a params.Instance is
// logged at debug level and dropped.
func (w *Worker) handleInstanceEvent(event common.ChangePayload) {
	inst, isInstance := event.Payload.(params.Instance)
	if !isInstance {
		slog.DebugContext(w.ctx, "invalid payload type for instance event", "payload", event.Payload)
		return
	}

	op := event.Operation
	if op == common.CreateOperation || op == common.UpdateOperation {
		cache.SetInstanceCache(inst)
	} else if op == common.DeleteOperation {
		cache.DeleteInstanceCache(inst.Name)
	}
}
// handleCredentialsEvent mirrors a GitHub credentials change into the
// credentials cache. Create and update store the credentials; delete removes
// them by ID. Other operations are ignored. A payload that is not a
// params.GithubCredentials is logged at debug level and dropped.
func (w *Worker) handleCredentialsEvent(event common.ChangePayload) {
	creds, isCreds := event.Payload.(params.GithubCredentials)
	if !isCreds {
		slog.DebugContext(w.ctx, "invalid payload type for credentials event", "payload", event.Payload)
		return
	}

	op := event.Operation
	if op == common.CreateOperation || op == common.UpdateOperation {
		cache.SetGithubCredentials(creds)
	} else if op == common.DeleteOperation {
		cache.DeleteGithubCredentials(creds.ID)
	}
}
// handleEvent routes a watcher change event to the handler for its entity
// type. Events of an unrecognized entity type are logged at debug level and
// ignored.
func (w *Worker) handleEvent(event common.ChangePayload) {
	slog.DebugContext(w.ctx, "handling event", "event", event)

	var handler func(common.ChangePayload)
	switch event.EntityType {
	case common.PoolEntityType:
		handler = w.handlePoolEvent
	case common.ScaleSetEntityType:
		handler = w.handleScaleSetEvent
	case common.InstanceEntityType:
		handler = w.handleInstanceEvent
	case common.RepositoryEntityType:
		handler = w.handleRepositoryEvent
	case common.OrganizationEntityType:
		handler = w.handleOrgEvent
	case common.EnterpriseEntityType:
		handler = w.handleEnterpriseEvent
	case common.GithubCredentialsEntityType:
		handler = w.handleCredentialsEvent
	default:
		slog.DebugContext(w.ctx, "unknown entity type", "entity_type", event.EntityType)
		return
	}
	handler(event)
}
// loop is the event pump that Start runs in its own goroutine. It applies
// each watcher event to the cache until one of three things happens: the
// quit channel is closed, the consumer's channel is closed, or the worker
// context is done. On exit it always calls Stop, which returns early if the
// worker is already marked as not running.
//
// NOTE(review): w.quit and w.consumer are read on every iteration without
// holding w.mux. If Stop and then Start run again before this goroutine
// observes the closed quit channel, it could pick up the new quit/consumer,
// and its deferred Stop could stop the restarted worker. Confirm whether
// restarts are supported.
func (w *Worker) loop() {
	defer func() {
		_ = w.Stop()
	}()

	for {
		select {
		case <-w.quit:
			return
		case event, open := <-w.consumer.Watch():
			if open {
				w.handleEvent(event)
				continue
			}
			slog.InfoContext(w.ctx, "consumer channel closed")
			return
		case <-w.ctx.Done():
			slog.DebugContext(w.ctx, "context done")
			return
		}
	}
}