diff --git a/cmd/garm-cli/cmd/log.go b/cmd/garm-cli/cmd/log.go index c8e61fa0..19708afc 100644 --- a/cmd/garm-cli/cmd/log.go +++ b/cmd/garm-cli/cmd/log.go @@ -63,7 +63,7 @@ var logCmd = &cobra.Command{ slog.With(slog.Any("error", err)).Error("reading log message") return } - fmt.Print(util.SanitizeLogEntry(string(message))) + fmt.Println(util.SanitizeLogEntry(string(message))) } }() diff --git a/cmd/garm/main.go b/cmd/garm/main.go index 217ba994..4b10fbaa 100644 --- a/cmd/garm/main.go +++ b/cmd/garm/main.go @@ -71,8 +71,7 @@ func maybeInitController(db common.Store) error { return nil } -func setupLogging(ctx context.Context, cfg *config.Config, hub *websocket.Hub) { - logCfg := cfg.GetLoggingConfig() +func setupLogging(ctx context.Context, logCfg config.Logging, hub *websocket.Hub) { logWriter, err := util.GetLoggingWriter(logCfg.LogFile) if err != nil { log.Fatalf("fetching log writer: %+v", err) @@ -157,16 +156,16 @@ func main() { log.Fatalf("Fetching config: %+v", err) } + logCfg := cfg.GetLoggingConfig() var hub *websocket.Hub - if cfg.Default.EnableLogStreamer != nil && *cfg.Default.EnableLogStreamer { + if logCfg.EnableLogStreamer != nil && *logCfg.EnableLogStreamer { hub = websocket.NewHub(ctx) if err := hub.Start(); err != nil { log.Fatal(err) } defer hub.Stop() //nolint } - - setupLogging(ctx, cfg, hub) + setupLogging(ctx, logCfg, hub) db, err := database.NewDatabase(ctx, cfg.Database) if err != nil { diff --git a/websocket/websocket.go b/websocket/websocket.go index a482fdcd..a650088d 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -3,6 +3,7 @@ package websocket import ( "context" "fmt" + "sync" "time" ) @@ -33,41 +34,64 @@ type Hub struct { // Unregister requests from clients. unregister chan *Client + + mux sync.Mutex + once sync.Once } func (h *Hub) run() { + defer func() { + close(h.closed) + }() for { select { case <-h.quit: - close(h.closed) return case <-h.ctx.Done(): - close(h.closed) return case client := <-h.register: if client != nil { + h.mux.Lock() h.clients[client.id] = client + h.mux.Unlock() } case client := <-h.unregister: if client != nil { + h.mux.Lock() if _, ok := h.clients[client.id]; ok { - delete(h.clients, client.id) + client.conn.Close() close(client.send) + delete(h.clients, client.id) } + h.mux.Unlock() } case message := <-h.broadcast: + staleClients := []string{} for id, client := range h.clients { if client == nil { + staleClients = append(staleClients, id) continue } select { case client.send <- message: case <-time.After(5 * time.Second): - close(client.send) - delete(h.clients, id) + staleClients = append(staleClients, id) } } + if len(staleClients) > 0 { + h.mux.Lock() + for _, id := range staleClients { + if client, ok := h.clients[id]; ok { + if client != nil { + client.conn.Close() + close(client.send) + } + delete(h.clients, id) + } + } + h.mux.Unlock() + } } } } @@ -78,13 +102,15 @@ func (h *Hub) Register(client *Client) error { } func (h *Hub) Write(msg []byte) (int, error) { + tmp := make([]byte, len(msg)) + copy(tmp, msg) + select { case <-time.After(5 * time.Second): return 0, fmt.Errorf("timed out sending message to client") - case h.broadcast <- msg: - + case h.broadcast <- tmp: } - return len(msg), nil + return len(tmp), nil } func (h *Hub) Start() error { @@ -92,8 +118,15 @@ func (h *Hub) Start() error { return nil } +func (h *Hub) Close() error { + h.once.Do(func() { + close(h.quit) + }) + return nil +} + func (h *Hub) Stop() error { - close(h.quit) + h.Close() select { case <-h.closed: return nil