From 5c5b2256bb7bc51175fee901547131bdb1cea2c0 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Thu, 4 Jul 2024 13:40:59 +0000 Subject: [PATCH] Create common utility function for ws Signed-off-by: Gabriel Adrian Samfira --- cmd/garm-cli/cmd/events.go | 77 ++++--------------- cmd/garm-cli/cmd/log.go | 113 +++------------------------ websocket/client.go | 5 ++ websocket/util.go | 151 +++++++++++++++++++++++++++++++++++++ 4 files changed, 181 insertions(+), 165 deletions(-) create mode 100644 websocket/util.go diff --git a/cmd/garm-cli/cmd/events.go b/cmd/garm-cli/cmd/events.go index 5f8df697..d0adb685 100644 --- a/cmd/garm-cli/cmd/events.go +++ b/cmd/garm-cli/cmd/events.go @@ -1,92 +1,47 @@ package cmd import ( - "fmt" - "log/slog" + "context" "os" "os/signal" - "time" + "syscall" "github.com/gorilla/websocket" "github.com/spf13/cobra" - "github.com/cloudbase/garm-provider-common/util" garmWs "github.com/cloudbase/garm/websocket" ) +var signals = []os.Signal{ + os.Interrupt, + syscall.SIGTERM, +} + var eventsCmd = &cobra.Command{ Use: "debug-events", SilenceUsage: true, Short: "Stream garm events", Long: `Stream all garm events to the terminal.`, RunE: func(_ *cobra.Command, _ []string) error { - interrupt := make(chan os.Signal, 1) - signal.Notify(interrupt, os.Interrupt) + ctx, stop := signal.NotifyContext(context.Background(), signals...) + defer stop() - conn, err := getWebsocketConnection("/api/v1/events") + reader, err := garmWs.NewReader(ctx, mgr.BaseURL, "/api/v1/events", mgr.Token) if err != nil { return err } - defer conn.Close() - done := make(chan struct{}) - - go func() { - defer close(done) - conn.SetReadDeadline(time.Now().Add(pongWait)) - conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) - for { - _, message, err := conn.ReadMessage() - if err != nil { - if garmWs.IsErrorOfInterest(err) { - slog.With(slog.Any("error", err)).Error("reading event message") - } - return - } - fmt.Println(util.SanitizeLogEntry(string(message))) - } - }() + if err := reader.Start(); err != nil { + return err + } if eventsFilters != "" { - conn.SetWriteDeadline(time.Now().Add(writeWait)) - err = conn.WriteMessage(websocket.TextMessage, []byte(eventsFilters)) - if err != nil { + if err := reader.WriteMessage(websocket.TextMessage, []byte(eventsFilters)); err != nil { return err } } - - ticker := time.NewTicker(pingPeriod) - defer ticker.Stop() - - for { - select { - case <-done: - slog.Info("done") - return nil - case <-ticker.C: - conn.SetWriteDeadline(time.Now().Add(writeWait)) - err := conn.WriteMessage(websocket.PingMessage, nil) - if err != nil { - return err - } - case <-interrupt: - // Cleanly close the connection by sending a close message and then - // waiting (with timeout) for the server to close the connection. - conn.SetWriteDeadline(time.Now().Add(writeWait)) - err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - if err != nil { - return err - } - slog.Info("waiting for server to close connection") - select { - case <-done: - slog.Info("done") - case <-time.After(time.Second): - slog.Info("timeout") - } - return nil - } - } + <-reader.Done() + return nil }, } diff --git a/cmd/garm-cli/cmd/log.go b/cmd/garm-cli/cmd/log.go index 13ed0a8d..57d08814 100644 --- a/cmd/garm-cli/cmd/log.go +++ b/cmd/garm-cli/cmd/log.go @@ -1,131 +1,36 @@ package cmd import ( - "encoding/json" - "fmt" - "log/slog" - "net/http" - "net/url" - "os" + "context" "os/signal" - "time" - "github.com/gorilla/websocket" "github.com/spf13/cobra" - "github.com/cloudbase/garm-provider-common/util" - apiParams "github.com/cloudbase/garm/apiserver/params" garmWs "github.com/cloudbase/garm/websocket" ) var eventsFilters string -const ( - // Time allowed to write a message to the peer. - writeWait = 10 * time.Second - - // Time allowed to read the next pong message from the peer. - pongWait = 30 * time.Second - - // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = (pongWait * 9) / 10 -) - -func getWebsocketConnection(pth string) (*websocket.Conn, error) { - parsedURL, err := url.Parse(mgr.BaseURL) - if err != nil { - return nil, err - } - - wsScheme := "ws" - if parsedURL.Scheme == "https" { - wsScheme = "wss" - } - u := url.URL{Scheme: wsScheme, Host: parsedURL.Host, Path: pth} - slog.Debug("connecting", "url", u.String()) - - header := http.Header{} - header.Add("Authorization", fmt.Sprintf("Bearer %s", mgr.Token)) - - c, response, err := websocket.DefaultDialer.Dial(u.String(), header) - if err != nil { - var resp apiParams.APIErrorResponse - var msg string - var status string - if response != nil { - if response.Body != nil { - if err := json.NewDecoder(response.Body).Decode(&resp); err == nil { - msg = resp.Details - } - } - status = response.Status - } - return nil, fmt.Errorf("failed to stream logs: %q %s (%s)", err, msg, status) - } - return c, nil -} - var logCmd = &cobra.Command{ Use: "debug-log", SilenceUsage: true, Short: "Stream garm log", Long: `Stream all garm logging to the terminal.`, RunE: func(_ *cobra.Command, _ []string) error { - interrupt := make(chan os.Signal, 1) - signal.Notify(interrupt, os.Interrupt) + ctx, stop := signal.NotifyContext(context.Background(), signals...) + defer stop() - conn, err := getWebsocketConnection("/api/v1/ws") + reader, err := garmWs.NewReader(ctx, mgr.BaseURL, "/api/v1/ws", mgr.Token) if err != nil { return err } - defer conn.Close() - done := make(chan struct{}) - - go func() { - defer close(done) - conn.SetReadDeadline(time.Now().Add(pongWait)) - conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) - for { - _, message, err := conn.ReadMessage() - if err != nil { - if garmWs.IsErrorOfInterest(err) { - slog.With(slog.Any("error", err)).Error("reading log message") - } - return - } - fmt.Println(util.SanitizeLogEntry(string(message))) - } - }() - - ticker := time.NewTicker(pingPeriod) - defer ticker.Stop() - - for { - select { - case <-done: - return nil - case <-ticker.C: - conn.SetWriteDeadline(time.Now().Add(writeWait)) - err := conn.WriteMessage(websocket.PingMessage, nil) - if err != nil { - return err - } - case <-interrupt: - // Cleanly close the connection by sending a close message and then - // waiting (with timeout) for the server to close the connection. - conn.SetWriteDeadline(time.Now().Add(writeWait)) - err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - if err != nil { - return err - } - select { - case <-done: - case <-time.After(time.Second): - } - return nil - } + if err := reader.Start(); err != nil { + return err } + + <-reader.Done() + return nil }, } diff --git a/websocket/client.go b/websocket/client.go index 3091c8a1..fdccd8a5 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -198,6 +198,11 @@ func (c *Client) clientWriter() { c.Stop() ticker.Stop() }() + // Set up expiration timer. + // NOTE: if a token is created without an expiration date + // this will be set to nil, which will close the loop bellow + // and terminate the connection immediately. + // We can't have a token without an expiration date. var authExpires time.Time expires := auth.Expires(c.ctx) if expires != nil { diff --git a/websocket/util.go b/websocket/util.go new file mode 100644 index 00000000..bd0ad6de --- /dev/null +++ b/websocket/util.go @@ -0,0 +1,151 @@ +package websocket + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "net/url" + "sync" + "time" + + "github.com/gorilla/websocket" + + "github.com/cloudbase/garm-provider-common/util" + apiParams "github.com/cloudbase/garm/apiserver/params" +) + +func NewReader(ctx context.Context, baseURL, pth, token string) (*Reader, error) { + parsedURL, err := url.Parse(baseURL) + if err != nil { + return nil, err + } + + wsScheme := "ws" + if parsedURL.Scheme == "https" { + wsScheme = "wss" + } + u := url.URL{Scheme: wsScheme, Host: parsedURL.Host, Path: pth} + header := http.Header{} + header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) + + return &Reader{ + ctx: ctx, + url: u, + header: header, + done: make(chan struct{}), + }, nil +} + +type Reader struct { + ctx context.Context + url url.URL + header http.Header + + done chan struct{} + running bool + + conn *websocket.Conn + mux sync.Mutex + writeMux sync.Mutex +} + +func (w *Reader) Stop() { + w.mux.Lock() + defer w.mux.Unlock() + if !w.running { + return + } + w.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + w.conn.Close() + close(w.done) + w.running = false +} + +func (w *Reader) Done() <-chan struct{} { + return w.done +} + +func (w *Reader) WriteMessage(messageType int, data []byte) error { + // The websocket package does not support concurrent writes and panics if it + // detects that one has occurred, so we need to lock the writeMux to prevent + // concurrent writes to the same connection. + w.writeMux.Lock() + defer w.writeMux.Unlock() + if !w.running { + return fmt.Errorf("websocket is not running") + } + if err := w.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { + return err + } + return w.conn.WriteMessage(messageType, data) +} + +func (w *Reader) Start() error { + w.mux.Lock() + defer w.mux.Unlock() + if w.running { + return nil + } + + c, response, err := websocket.DefaultDialer.Dial(w.url.String(), w.header) + if err != nil { + var resp apiParams.APIErrorResponse + var msg string + var status string + if response != nil { + if response.Body != nil { + if err := json.NewDecoder(response.Body).Decode(&resp); err == nil { + msg = resp.Details + } + } + status = response.Status + } + return fmt.Errorf("failed to stream logs: %q %s (%s)", err, msg, status) + } + w.conn = c + w.running = true + go w.loop() + go w.printWebsocketToConsole() + return nil +} + +func (w *Reader) printWebsocketToConsole() { + defer w.Stop() + w.conn.SetReadDeadline(time.Now().Add(pongWait)) + w.conn.SetPongHandler(func(string) error { w.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, message, err := w.conn.ReadMessage() + if err != nil { + if IsErrorOfInterest(err) { + slog.With(slog.Any("error", err)).Error("reading log message") + } + return + } + fmt.Println(util.SanitizeLogEntry(string(message))) + } +} + +func (w *Reader) loop() { + defer w.Stop() + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + for { + select { + case <-w.ctx.Done(): + return + case <-w.Done(): + return + case <-ticker.C: + w.writeMux.Lock() + w.conn.SetWriteDeadline(time.Now().Add(writeWait)) + err := w.conn.WriteMessage(websocket.PingMessage, nil) + if err != nil { + w.writeMux.Unlock() + return + } + w.writeMux.Unlock() + } + } +}