diff --git a/cmd/garm-cli/cmd/events.go b/cmd/garm-cli/cmd/events.go index d0adb685..52a3fe9a 100644 --- a/cmd/garm-cli/cmd/events.go +++ b/cmd/garm-cli/cmd/events.go @@ -9,6 +9,7 @@ import ( "github.com/gorilla/websocket" "github.com/spf13/cobra" + "github.com/cloudbase/garm/cmd/garm-cli/common" garmWs "github.com/cloudbase/garm/websocket" ) @@ -26,7 +27,7 @@ var eventsCmd = &cobra.Command{ ctx, stop := signal.NotifyContext(context.Background(), signals...) defer stop() - reader, err := garmWs.NewReader(ctx, mgr.BaseURL, "/api/v1/events", mgr.Token) + reader, err := garmWs.NewReader(ctx, mgr.BaseURL, "/api/v1/events", mgr.Token, common.PrintWebsocketMessage) if err != nil { return err } diff --git a/cmd/garm-cli/cmd/log.go b/cmd/garm-cli/cmd/log.go index 57d08814..89f687ef 100644 --- a/cmd/garm-cli/cmd/log.go +++ b/cmd/garm-cli/cmd/log.go @@ -6,6 +6,7 @@ import ( "github.com/spf13/cobra" + "github.com/cloudbase/garm/cmd/garm-cli/common" garmWs "github.com/cloudbase/garm/websocket" ) @@ -20,7 +21,7 @@ var logCmd = &cobra.Command{ ctx, stop := signal.NotifyContext(context.Background(), signals...) defer stop() - reader, err := garmWs.NewReader(ctx, mgr.BaseURL, "/api/v1/ws", mgr.Token) + reader, err := garmWs.NewReader(ctx, mgr.BaseURL, "/api/v1/ws", mgr.Token, common.PrintWebsocketMessage) if err != nil { return err } diff --git a/cmd/garm-cli/common/common.go b/cmd/garm-cli/common/common.go index 3fc6c339..08189d21 100644 --- a/cmd/garm-cli/common/common.go +++ b/cmd/garm-cli/common/common.go @@ -20,6 +20,8 @@ import ( "github.com/manifoldco/promptui" "github.com/nbutton23/zxcvbn-go" + + "github.com/cloudbase/garm-provider-common/util" ) func PromptPassword(label string, compareTo string) (string, error) { @@ -67,3 +69,8 @@ func PromptString(label string, a ...interface{}) (string, error) { } return result, nil } + +func PrintWebsocketMessage(_ int, msg []byte) error { + fmt.Println(util.SanitizeLogEntry(string(msg))) + return nil +} diff --git a/websocket/util.go b/websocket/util.go index bb173014..e0dc256d 100644 --- a/websocket/util.go +++ b/websocket/util.go @@ -12,11 +12,12 @@ import ( "github.com/gorilla/websocket" - "github.com/cloudbase/garm-provider-common/util" apiParams "github.com/cloudbase/garm/apiserver/params" ) -func NewReader(ctx context.Context, baseURL, pth, token string) (*Reader, error) { +type MessageHandler func(msgType int, msg []byte) error + +func NewReader(ctx context.Context, baseURL, pth, token string, handler MessageHandler) (*Reader, error) { parsedURL, err := url.Parse(baseURL) if err != nil { return nil, err @@ -31,10 +32,11 @@ func NewReader(ctx context.Context, baseURL, pth, token string) (*Reader, error) header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) return &Reader{ - ctx: ctx, - url: u, - header: header, - done: make(chan struct{}), + ctx: ctx, + url: u, + header: header, + handler: handler, + done: make(chan struct{}), }, nil } @@ -46,6 +48,8 @@ type Reader struct { done chan struct{} running bool + handler MessageHandler + conn *websocket.Conn mux sync.Mutex writeMux sync.Mutex @@ -107,11 +111,11 @@ func (w *Reader) Start() error { w.conn = c w.running = true go w.loop() - go w.printWebsocketToConsole() + go w.handlerReader() return nil } -func (w *Reader) printWebsocketToConsole() { +func (w *Reader) handlerReader() { defer w.Stop() w.writeMux.Lock() w.conn.SetReadLimit(maxMessageSize) @@ -119,14 +123,18 @@ func (w *Reader) printWebsocketToConsole() { w.conn.SetPongHandler(func(string) error { w.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) w.writeMux.Unlock() for { - _, message, err := w.conn.ReadMessage() + msgType, message, err := w.conn.ReadMessage() if err != nil { if IsErrorOfInterest(err) { slog.With(slog.Any("error", err)).Error("reading log message") } return } - fmt.Println(util.SanitizeLogEntry(string(message))) + if w.handler != nil { + if err := w.handler(msgType, message); err != nil { + slog.With(slog.Any("error", err)).Error("handling log message") + } + } } }