Update garm-provider-common

Use the websocket reader from within garm-provider-common.

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2024-08-03 16:07:21 +00:00
parent a5b15789a1
commit bc4285dc80
25 changed files with 298 additions and 70 deletions

View file

@ -77,6 +77,10 @@ func GetEnvironment() (Environment, error) {
if err := json.Unmarshal(data.Bytes(), &bootstrapParams); err != nil {
return Environment{}, fmt.Errorf("failed to decode instance params: %w", err)
}
if bootstrapParams.ExtraSpecs == nil {
// Initialize ExtraSpecs as an empty JSON object
bootstrapParams.ExtraSpecs = json.RawMessage([]byte("{}"))
}
env.BootstrapParams = bootstrapParams
}

View file

@ -1,3 +1,17 @@
// Copyright 2023 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package params
// RunnerApplicationDownload represents a binary for the self-hosted runner application that can be downloaded.

View file

@ -0,0 +1,184 @@
package websocket
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"net/url"
"sync"
"time"
"github.com/gorilla/websocket"
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
// Maximum message size allowed from peer.
maxMessageSize = 16384 // 16 KB
)
// MessageHandler is a function that processes a message received from a websocket connection.
type MessageHandler func(msgType int, msg []byte) error
type APIErrorResponse struct {
Error string `json:"error"`
Details string `json:"details"`
}
// NewReader creates a new websocket reader. The reader will pass on any message it receives to the
// handler function. The handler function should return an error if it fails to process the message.
func NewReader(ctx context.Context, baseURL, pth, token string, handler MessageHandler) (*Reader, error) {
parsedURL, err := url.Parse(baseURL)
if err != nil {
return nil, err
}
wsScheme := "ws"
if parsedURL.Scheme == "https" {
wsScheme = "wss"
}
u := url.URL{Scheme: wsScheme, Host: parsedURL.Host, Path: pth}
header := http.Header{}
header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
return &Reader{
ctx: ctx,
url: u,
header: header,
handler: handler,
done: make(chan struct{}),
}, nil
}
type Reader struct {
ctx context.Context
url url.URL
header http.Header
done chan struct{}
running bool
handler MessageHandler
conn *websocket.Conn
mux sync.Mutex
writeMux sync.Mutex
}
func (w *Reader) Stop() {
w.mux.Lock()
defer w.mux.Unlock()
if !w.running {
return
}
w.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
w.conn.Close()
close(w.done)
w.running = false
}
func (w *Reader) Done() <-chan struct{} {
return w.done
}
func (w *Reader) WriteMessage(messageType int, data []byte) error {
// The websocket package does not support concurrent writes and panics if it
// detects that one has occurred, so we need to lock the writeMux to prevent
// concurrent writes to the same connection.
w.writeMux.Lock()
defer w.writeMux.Unlock()
if !w.running {
return fmt.Errorf("websocket is not running")
}
if err := w.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
return err
}
return w.conn.WriteMessage(messageType, data)
}
func (w *Reader) Start() error {
w.mux.Lock()
defer w.mux.Unlock()
if w.running {
return nil
}
c, response, err := websocket.DefaultDialer.Dial(w.url.String(), w.header)
if err != nil {
var resp APIErrorResponse
var msg string
var status string
if response != nil {
if response.Body != nil {
if err := json.NewDecoder(response.Body).Decode(&resp); err == nil {
msg = resp.Details
}
}
status = response.Status
}
return fmt.Errorf("failed to stream logs: %q %s (%s)", err, msg, status)
}
w.conn = c
w.running = true
go w.loop()
go w.handlerReader()
return nil
}
func (w *Reader) handlerReader() {
defer w.Stop()
w.writeMux.Lock()
w.conn.SetReadLimit(maxMessageSize)
w.conn.SetReadDeadline(time.Now().Add(pongWait))
w.conn.SetPongHandler(func(string) error { w.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
w.writeMux.Unlock()
for {
msgType, message, err := w.conn.ReadMessage()
if err != nil {
if IsErrorOfInterest(err) {
// TODO(gabriel-samfira): we should allow for an error channel that can be used to signal
// the caller that the connection has been closed.
slog.With(slog.Any("error", err)).Error("reading log message")
}
return
}
if w.handler != nil {
if err := w.handler(msgType, message); err != nil {
slog.With(slog.Any("error", err)).Error("handling log message")
}
}
}
}
func (w *Reader) loop() {
defer w.Stop()
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
return
case <-w.Done():
return
case <-ticker.C:
w.writeMux.Lock()
w.conn.SetWriteDeadline(time.Now().Add(writeWait))
err := w.conn.WriteMessage(websocket.PingMessage, nil)
if err != nil {
w.writeMux.Unlock()
return
}
w.writeMux.Unlock()
}
}
}

View file

@ -0,0 +1,37 @@
package websocket
import (
"errors"
"net"
"github.com/gorilla/websocket"
)
func IsErrorOfInterest(err error) bool {
if err == nil {
return false
}
if errors.Is(err, websocket.ErrCloseSent) {
return false
}
if errors.Is(err, websocket.ErrBadHandshake) {
return false
}
if errors.Is(err, net.ErrClosed) {
return false
}
asCloseErr, ok := err.(*websocket.CloseError)
if ok {
switch asCloseErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway,
websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure:
return false
}
}
return true
}

View file

@ -5,19 +5,26 @@ linters-settings:
misspell:
locale: US
staticcheck:
checks: ['all', '-SA6002']
linters:
disable-all: true
enable:
- typecheck
- durationcheck
- gocritic
- gofumpt
- goimports
- misspell
- gomodguard
- govet
- ineffassign
- gosimple
- unused
- prealloc
- unconvert
- misspell
- revive
- staticcheck
- tenv
- typecheck
- unconvert
- unused
issues:
exclude-use-default: false
@ -25,5 +32,3 @@ issues:
- should have a package comment
- error strings should not be capitalized or end with punctuation or a newline
- don't use ALL_CAPS in Go names
service:
golangci-lint-version: 1.33.0 # use the fixed version to not introduce new linters unexpectedly

View file

@ -31,7 +31,7 @@ func (h headerV10) SetVersion() { h[0] = Version10 }
func (h headerV10) SetCipher(suite byte) { h[1] = suite }
func (h headerV10) SetLen(length int) { binary.LittleEndian.PutUint16(h[2:], uint16(length-1)) }
func (h headerV10) SetSequenceNumber(num uint32) { binary.LittleEndian.PutUint32(h[4:], num) }
func (h headerV10) SetRand(randVal []byte) { copy(h[8:headerSize], randVal[:]) }
func (h headerV10) SetRand(randVal []byte) { copy(h[8:headerSize], randVal) }
func (h headerV10) Nonce() []byte { return h[4:headerSize] }
func (h headerV10) AddData() []byte { return h[:4] }
@ -256,7 +256,7 @@ func (ad *authDecV20) Open(dst, src []byte) error {
ad.finalized = true
refNonce[0] |= 0x80 // set final flag
}
if subtle.ConstantTimeCompare(header.Nonce(), refNonce[:]) != 1 {
if subtle.ConstantTimeCompare(header.Nonce(), refNonce) != 1 {
return errNonceMismatch
}