Create common utility function for ws

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2024-07-04 13:40:59 +00:00
parent 246f826b76
commit 5c5b2256bb
4 changed files with 181 additions and 165 deletions

View file

@ -1,92 +1,47 @@
package cmd
import (
"fmt"
"log/slog"
"context"
"os"
"os/signal"
"time"
"syscall"
"github.com/gorilla/websocket"
"github.com/spf13/cobra"
"github.com/cloudbase/garm-provider-common/util"
garmWs "github.com/cloudbase/garm/websocket"
)
var signals = []os.Signal{
os.Interrupt,
syscall.SIGTERM,
}
var eventsCmd = &cobra.Command{
Use: "debug-events",
SilenceUsage: true,
Short: "Stream garm events",
Long: `Stream all garm events to the terminal.`,
RunE: func(_ *cobra.Command, _ []string) error {
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt)
ctx, stop := signal.NotifyContext(context.Background(), signals...)
defer stop()
conn, err := getWebsocketConnection("/api/v1/events")
reader, err := garmWs.NewReader(ctx, mgr.BaseURL, "/api/v1/events", mgr.Token)
if err != nil {
return err
}
defer conn.Close()
done := make(chan struct{})
go func() {
defer close(done)
conn.SetReadDeadline(time.Now().Add(pongWait))
conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
for {
_, message, err := conn.ReadMessage()
if err != nil {
if garmWs.IsErrorOfInterest(err) {
slog.With(slog.Any("error", err)).Error("reading event message")
}
return
}
fmt.Println(util.SanitizeLogEntry(string(message)))
}
}()
if err := reader.Start(); err != nil {
return err
}
if eventsFilters != "" {
conn.SetWriteDeadline(time.Now().Add(writeWait))
err = conn.WriteMessage(websocket.TextMessage, []byte(eventsFilters))
if err != nil {
if err := reader.WriteMessage(websocket.TextMessage, []byte(eventsFilters)); err != nil {
return err
}
}
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-done:
slog.Info("done")
return nil
case <-ticker.C:
conn.SetWriteDeadline(time.Now().Add(writeWait))
err := conn.WriteMessage(websocket.PingMessage, nil)
if err != nil {
return err
}
case <-interrupt:
// Cleanly close the connection by sending a close message and then
// waiting (with timeout) for the server to close the connection.
conn.SetWriteDeadline(time.Now().Add(writeWait))
err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
return err
}
slog.Info("waiting for server to close connection")
select {
case <-done:
slog.Info("done")
case <-time.After(time.Second):
slog.Info("timeout")
}
return nil
}
}
<-reader.Done()
return nil
},
}

View file

@ -1,131 +1,36 @@
package cmd
import (
"encoding/json"
"fmt"
"log/slog"
"net/http"
"net/url"
"os"
"context"
"os/signal"
"time"
"github.com/gorilla/websocket"
"github.com/spf13/cobra"
"github.com/cloudbase/garm-provider-common/util"
apiParams "github.com/cloudbase/garm/apiserver/params"
garmWs "github.com/cloudbase/garm/websocket"
)
var eventsFilters string
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 30 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
)
func getWebsocketConnection(pth string) (*websocket.Conn, error) {
parsedURL, err := url.Parse(mgr.BaseURL)
if err != nil {
return nil, err
}
wsScheme := "ws"
if parsedURL.Scheme == "https" {
wsScheme = "wss"
}
u := url.URL{Scheme: wsScheme, Host: parsedURL.Host, Path: pth}
slog.Debug("connecting", "url", u.String())
header := http.Header{}
header.Add("Authorization", fmt.Sprintf("Bearer %s", mgr.Token))
c, response, err := websocket.DefaultDialer.Dial(u.String(), header)
if err != nil {
var resp apiParams.APIErrorResponse
var msg string
var status string
if response != nil {
if response.Body != nil {
if err := json.NewDecoder(response.Body).Decode(&resp); err == nil {
msg = resp.Details
}
}
status = response.Status
}
return nil, fmt.Errorf("failed to stream logs: %q %s (%s)", err, msg, status)
}
return c, nil
}
var logCmd = &cobra.Command{
Use: "debug-log",
SilenceUsage: true,
Short: "Stream garm log",
Long: `Stream all garm logging to the terminal.`,
RunE: func(_ *cobra.Command, _ []string) error {
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt)
ctx, stop := signal.NotifyContext(context.Background(), signals...)
defer stop()
conn, err := getWebsocketConnection("/api/v1/ws")
reader, err := garmWs.NewReader(ctx, mgr.BaseURL, "/api/v1/ws", mgr.Token)
if err != nil {
return err
}
defer conn.Close()
done := make(chan struct{})
go func() {
defer close(done)
conn.SetReadDeadline(time.Now().Add(pongWait))
conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
for {
_, message, err := conn.ReadMessage()
if err != nil {
if garmWs.IsErrorOfInterest(err) {
slog.With(slog.Any("error", err)).Error("reading log message")
}
return
}
fmt.Println(util.SanitizeLogEntry(string(message)))
}
}()
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-done:
return nil
case <-ticker.C:
conn.SetWriteDeadline(time.Now().Add(writeWait))
err := conn.WriteMessage(websocket.PingMessage, nil)
if err != nil {
return err
}
case <-interrupt:
// Cleanly close the connection by sending a close message and then
// waiting (with timeout) for the server to close the connection.
conn.SetWriteDeadline(time.Now().Add(writeWait))
err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
return err
}
select {
case <-done:
case <-time.After(time.Second):
}
return nil
}
if err := reader.Start(); err != nil {
return err
}
<-reader.Done()
return nil
},
}

View file

@ -198,6 +198,11 @@ func (c *Client) clientWriter() {
c.Stop()
ticker.Stop()
}()
// Set up expiration timer.
// NOTE: if a token is created without an expiration date
// this will be set to nil, which will close the loop bellow
// and terminate the connection immediately.
// We can't have a token without an expiration date.
var authExpires time.Time
expires := auth.Expires(c.ctx)
if expires != nil {

151
websocket/util.go Normal file
View file

@ -0,0 +1,151 @@
package websocket
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"net/url"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/cloudbase/garm-provider-common/util"
apiParams "github.com/cloudbase/garm/apiserver/params"
)
func NewReader(ctx context.Context, baseURL, pth, token string) (*Reader, error) {
parsedURL, err := url.Parse(baseURL)
if err != nil {
return nil, err
}
wsScheme := "ws"
if parsedURL.Scheme == "https" {
wsScheme = "wss"
}
u := url.URL{Scheme: wsScheme, Host: parsedURL.Host, Path: pth}
header := http.Header{}
header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
return &Reader{
ctx: ctx,
url: u,
header: header,
done: make(chan struct{}),
}, nil
}
type Reader struct {
ctx context.Context
url url.URL
header http.Header
done chan struct{}
running bool
conn *websocket.Conn
mux sync.Mutex
writeMux sync.Mutex
}
func (w *Reader) Stop() {
w.mux.Lock()
defer w.mux.Unlock()
if !w.running {
return
}
w.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
w.conn.Close()
close(w.done)
w.running = false
}
func (w *Reader) Done() <-chan struct{} {
return w.done
}
func (w *Reader) WriteMessage(messageType int, data []byte) error {
// The websocket package does not support concurrent writes and panics if it
// detects that one has occurred, so we need to lock the writeMux to prevent
// concurrent writes to the same connection.
w.writeMux.Lock()
defer w.writeMux.Unlock()
if !w.running {
return fmt.Errorf("websocket is not running")
}
if err := w.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
return err
}
return w.conn.WriteMessage(messageType, data)
}
func (w *Reader) Start() error {
w.mux.Lock()
defer w.mux.Unlock()
if w.running {
return nil
}
c, response, err := websocket.DefaultDialer.Dial(w.url.String(), w.header)
if err != nil {
var resp apiParams.APIErrorResponse
var msg string
var status string
if response != nil {
if response.Body != nil {
if err := json.NewDecoder(response.Body).Decode(&resp); err == nil {
msg = resp.Details
}
}
status = response.Status
}
return fmt.Errorf("failed to stream logs: %q %s (%s)", err, msg, status)
}
w.conn = c
w.running = true
go w.loop()
go w.printWebsocketToConsole()
return nil
}
func (w *Reader) printWebsocketToConsole() {
defer w.Stop()
w.conn.SetReadDeadline(time.Now().Add(pongWait))
w.conn.SetPongHandler(func(string) error { w.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
for {
_, message, err := w.conn.ReadMessage()
if err != nil {
if IsErrorOfInterest(err) {
slog.With(slog.Any("error", err)).Error("reading log message")
}
return
}
fmt.Println(util.SanitizeLogEntry(string(message)))
}
}
func (w *Reader) loop() {
defer w.Stop()
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
return
case <-w.Done():
return
case <-ticker.C:
w.writeMux.Lock()
w.conn.SetWriteDeadline(time.Now().Add(writeWait))
err := w.conn.WriteMessage(websocket.PingMessage, nil)
if err != nil {
w.writeMux.Unlock()
return
}
w.writeMux.Unlock()
}
}
}