Refactor the websocket client and add fixes

The websocket client and hub interaction has been simplified a bit.
The hub now acts only as a tee writer to the various clients that
register. Clients must register and unregister explicitly. The hub
is no longer passed in to the client.

Websocket clients now watch for password changes or jwt token expiration
times. Clients are disconnected if auth token expires or if the password
is changed.

Various aditional safety checks have been added.

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
Gabriel Adrian Samfira 2024-07-02 22:26:12 +00:00
parent ca7f20b62d
commit dd1740c189
17 changed files with 426 additions and 143 deletions

View file

@ -183,14 +183,9 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request)
slog.With(slog.Any("error", err)).ErrorContext(ctx, "error upgrading to websockets") slog.With(slog.Any("error", err)).ErrorContext(ctx, "error upgrading to websockets")
return return
} }
defer conn.Close()
// nolint:golangci-lint,godox client, err := wsWriter.NewClient(ctx, conn)
// TODO (gsamfira): Handle ExpiresAt. Right now, if a client uses
// a valid token to authenticate, and keeps the websocket connection
// open, it will allow that client to stream logs via websockets
// until the connection is broken. We need to forcefully disconnect
// the client once the token expires.
client, err := wsWriter.NewClient(conn, a.hub)
if err != nil { if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to create new client") slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to create new client")
return return
@ -199,7 +194,14 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request)
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to register new client") slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to register new client")
return return
} }
client.Go() defer a.hub.Unregister(client)
if err := client.Start(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to start client")
return
}
<-client.Done()
slog.Info("client disconnected", "client_id", client.ID())
} }
// NotFoundHandler is returned when an invalid URL is acccessed // NotFoundHandler is returned when an invalid URL is acccessed

View file

@ -55,6 +55,7 @@ func (a *Authenticator) GetJWTToken(ctx context.Context) (string, error) {
expires := &jwt.NumericDate{ expires := &jwt.NumericDate{
Time: expireToken, Time: expireToken,
} }
generation := PasswordGeneration(ctx)
claims := JWTClaims{ claims := JWTClaims{
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: expires, ExpiresAt: expires,
@ -62,10 +63,11 @@ func (a *Authenticator) GetJWTToken(ctx context.Context) (string, error) {
// TODO: make this configurable // TODO: make this configurable
Issuer: "garm", Issuer: "garm",
}, },
UserID: UserID(ctx), UserID: UserID(ctx),
TokenID: tokenID, TokenID: tokenID,
IsAdmin: IsAdmin(ctx), IsAdmin: IsAdmin(ctx),
FullName: FullName(ctx), FullName: FullName(ctx),
Generation: generation,
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(a.cfg.Secret)) tokenString, err := token.SignedString([]byte(a.cfg.Secret))
@ -182,5 +184,5 @@ func (a *Authenticator) AuthenticateUser(ctx context.Context, info params.Passwo
return ctx, runnerErrors.ErrUnauthorized return ctx, runnerErrors.ErrUnauthorized
} }
return PopulateContext(ctx, user), nil return PopulateContext(ctx, user, nil), nil
} }

View file

@ -16,6 +16,7 @@ package auth
import ( import (
"context" "context"
"time"
runnerErrors "github.com/cloudbase/garm-provider-common/errors" runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/params" "github.com/cloudbase/garm/params"
@ -28,9 +29,11 @@ const (
fullNameKey contextFlags = "full_name" fullNameKey contextFlags = "full_name"
readMetricsKey contextFlags = "read_metrics" readMetricsKey contextFlags = "read_metrics"
// UserIDFlag is the User ID flag we set in the context // UserIDFlag is the User ID flag we set in the context
UserIDFlag contextFlags = "user_id" UserIDFlag contextFlags = "user_id"
isEnabledFlag contextFlags = "is_enabled" isEnabledFlag contextFlags = "is_enabled"
jwtTokenFlag contextFlags = "jwt_token" jwtTokenFlag contextFlags = "jwt_token"
authExpiresFlag contextFlags = "auth_expires"
passwordGenerationFlag contextFlags = "password_generation"
instanceIDKey contextFlags = "id" instanceIDKey contextFlags = "id"
instanceNameKey contextFlags = "name" instanceNameKey contextFlags = "name"
@ -169,14 +172,43 @@ func PopulateInstanceContext(ctx context.Context, instance params.Instance) cont
// PopulateContext sets the appropriate fields in the context, based on // PopulateContext sets the appropriate fields in the context, based on
// the user object // the user object
func PopulateContext(ctx context.Context, user params.User) context.Context { func PopulateContext(ctx context.Context, user params.User, authExpires *time.Time) context.Context {
ctx = SetUserID(ctx, user.ID) ctx = SetUserID(ctx, user.ID)
ctx = SetAdmin(ctx, user.IsAdmin) ctx = SetAdmin(ctx, user.IsAdmin)
ctx = SetIsEnabled(ctx, user.Enabled) ctx = SetIsEnabled(ctx, user.Enabled)
ctx = SetFullName(ctx, user.FullName) ctx = SetFullName(ctx, user.FullName)
ctx = SetExpires(ctx, authExpires)
ctx = SetPasswordGeneration(ctx, user.Generation)
return ctx return ctx
} }
func SetExpires(ctx context.Context, expires *time.Time) context.Context {
if expires == nil {
return ctx
}
return context.WithValue(ctx, authExpiresFlag, expires)
}
func Expires(ctx context.Context) *time.Time {
elem := ctx.Value(authExpiresFlag)
if elem == nil {
return nil
}
return elem.(*time.Time)
}
func SetPasswordGeneration(ctx context.Context, val uint) context.Context {
return context.WithValue(ctx, passwordGenerationFlag, val)
}
func PasswordGeneration(ctx context.Context) uint {
elem := ctx.Value(passwordGenerationFlag)
if elem == nil {
return 0
}
return elem.(uint)
}
// SetFullName sets the user full name in the context // SetFullName sets the user full name in the context
func SetFullName(ctx context.Context, fullName string) context.Context { func SetFullName(ctx context.Context, fullName string) context.Context {
return context.WithValue(ctx, fullNameKey, fullName) return context.WithValue(ctx, fullNameKey, fullName)

View file

@ -21,6 +21,7 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"strings" "strings"
"time"
jwt "github.com/golang-jwt/jwt/v5" jwt "github.com/golang-jwt/jwt/v5"
@ -37,6 +38,7 @@ type JWTClaims struct {
FullName string `json:"full_name"` FullName string `json:"full_name"`
IsAdmin bool `json:"is_admin"` IsAdmin bool `json:"is_admin"`
ReadMetrics bool `json:"read_metrics"` ReadMetrics bool `json:"read_metrics"`
Generation uint `json:"generation"`
jwt.RegisteredClaims jwt.RegisteredClaims
} }
@ -69,7 +71,18 @@ func (amw *jwtMiddleware) claimsToContext(ctx context.Context, claims *JWTClaims
return ctx, runnerErrors.ErrUnauthorized return ctx, runnerErrors.ErrUnauthorized
} }
ctx = PopulateContext(ctx, userInfo) var expiresAt *time.Time
if claims.ExpiresAt != nil {
expires := claims.ExpiresAt.Time.UTC()
expiresAt = &expires
}
if userInfo.Generation != claims.Generation {
// Password was reset since token was issued. Invalidate.
return ctx, runnerErrors.ErrUnauthorized
}
ctx = PopulateContext(ctx, userInfo, expiresAt)
return ctx, nil return ctx, nil
} }

View file

@ -16,6 +16,7 @@ import (
"github.com/cloudbase/garm-provider-common/util" "github.com/cloudbase/garm-provider-common/util"
apiParams "github.com/cloudbase/garm/apiserver/params" apiParams "github.com/cloudbase/garm/apiserver/params"
garmWs "github.com/cloudbase/garm/websocket"
) )
var logCmd = &cobra.Command{ var logCmd = &cobra.Command{
@ -66,7 +67,9 @@ var logCmd = &cobra.Command{
for { for {
_, message, err := c.ReadMessage() _, message, err := c.ReadMessage()
if err != nil { if err != nil {
slog.With(slog.Any("error", err)).Error("reading log message") if garmWs.IsErrorOfInterest(err) {
slog.With(slog.Any("error", err)).Error("reading log message")
}
return return
} }
fmt.Println(util.SanitizeLogEntry(string(message))) fmt.Println(util.SanitizeLogEntry(string(message)))

View file

@ -320,7 +320,7 @@ func main() {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "graceful api server shutdown failed") slog.With(slog.Any("error", err)).ErrorContext(ctx, "graceful api server shutdown failed")
} }
slog.With(slog.Any("error", err)).ErrorContext(ctx, "waiting for runner to stop") slog.With(slog.Any("error", err)).InfoContext(ctx, "waiting for runner to stop")
if err := runner.Wait(); err != nil { if err := runner.Wait(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to shutdown workers") slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to shutdown workers")
os.Exit(1) os.Exit(1)

View file

@ -284,7 +284,7 @@ func (s *GithubTestSuite) TestCreateCredentials() {
func (s *GithubTestSuite) TestCreateCredentialsFailsOnDuplicateCredentials() { func (s *GithubTestSuite) TestCreateCredentialsFailsOnDuplicateCredentials() {
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T()) ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
testUser := garmTesting.CreateGARMTestUser(ctx, "testuser", s.db, s.T()) testUser := garmTesting.CreateGARMTestUser(ctx, "testuser", s.db, s.T())
testUserCtx := auth.PopulateContext(context.Background(), testUser) testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
credParams := params.CreateGithubCredentialsParams{ credParams := params.CreateGithubCredentialsParams{
Name: testCredsName, Name: testCredsName,
@ -313,8 +313,8 @@ func (s *GithubTestSuite) TestNormalUsersCanOnlySeeTheirOwnCredentialsAdminCanSe
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T()) ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
testUser := garmTesting.CreateGARMTestUser(ctx, "testuser1", s.db, s.T()) testUser := garmTesting.CreateGARMTestUser(ctx, "testuser1", s.db, s.T())
testUser2 := garmTesting.CreateGARMTestUser(ctx, "testuser2", s.db, s.T()) testUser2 := garmTesting.CreateGARMTestUser(ctx, "testuser2", s.db, s.T())
testUserCtx := auth.PopulateContext(context.Background(), testUser) testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
testUser2Ctx := auth.PopulateContext(context.Background(), testUser2) testUser2Ctx := auth.PopulateContext(context.Background(), testUser2, nil)
credParams := params.CreateGithubCredentialsParams{ credParams := params.CreateGithubCredentialsParams{
Name: testCredsName, Name: testCredsName,
@ -370,7 +370,7 @@ func (s *GithubTestSuite) TestGetGithubCredentialsFailsWhenCredentialsDontExist(
func (s *GithubTestSuite) TestGetGithubCredentialsByNameReturnsOnlyCurrentUserCredentials() { func (s *GithubTestSuite) TestGetGithubCredentialsByNameReturnsOnlyCurrentUserCredentials() {
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T()) ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user1", s.db, s.T()) testUser := garmTesting.CreateGARMTestUser(ctx, "test-user1", s.db, s.T())
testUserCtx := auth.PopulateContext(context.Background(), testUser) testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
credParams := params.CreateGithubCredentialsParams{ credParams := params.CreateGithubCredentialsParams{
Name: testCredsName, Name: testCredsName,
@ -472,7 +472,7 @@ func (s *GithubTestSuite) TestDeleteGithubCredentials() {
func (s *GithubTestSuite) TestDeleteGithubCredentialsByNonAdminUser() { func (s *GithubTestSuite) TestDeleteGithubCredentialsByNonAdminUser() {
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T()) ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user4", s.db, s.T()) testUser := garmTesting.CreateGARMTestUser(ctx, "test-user4", s.db, s.T())
testUserCtx := auth.PopulateContext(context.Background(), testUser) testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
credParams := params.CreateGithubCredentialsParams{ credParams := params.CreateGithubCredentialsParams{
Name: testCredsName, Name: testCredsName,
@ -682,7 +682,7 @@ func (s *GithubTestSuite) TestUpdateCredentialsFailsForNonExistingCredentials()
func (s *GithubTestSuite) TestUpdateCredentialsFailsIfCredentialsAreOwnedByNonAdminUser() { func (s *GithubTestSuite) TestUpdateCredentialsFailsIfCredentialsAreOwnedByNonAdminUser() {
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T()) ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user5", s.db, s.T()) testUser := garmTesting.CreateGARMTestUser(ctx, "test-user5", s.db, s.T())
testUserCtx := auth.PopulateContext(context.Background(), testUser) testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
credParams := params.CreateGithubCredentialsParams{ credParams := params.CreateGithubCredentialsParams{
Name: testCredsName, Name: testCredsName,
@ -711,7 +711,7 @@ func (s *GithubTestSuite) TestUpdateCredentialsFailsIfCredentialsAreOwnedByNonAd
func (s *GithubTestSuite) TestAdminUserCanUpdateAnyGithubCredentials() { func (s *GithubTestSuite) TestAdminUserCanUpdateAnyGithubCredentials() {
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T()) ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user5", s.db, s.T()) testUser := garmTesting.CreateGARMTestUser(ctx, "test-user5", s.db, s.T())
testUserCtx := auth.PopulateContext(context.Background(), testUser) testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
credParams := params.CreateGithubCredentialsParams{ credParams := params.CreateGithubCredentialsParams{
Name: testCredsName, Name: testCredsName,

View file

@ -195,12 +195,13 @@ type Instance struct {
type User struct { type User struct {
Base Base
Username string `gorm:"uniqueIndex;varchar(64)"` Username string `gorm:"uniqueIndex;varchar(64)"`
FullName string `gorm:"type:varchar(254)"` FullName string `gorm:"type:varchar(254)"`
Email string `gorm:"type:varchar(254);unique;index:idx_email"` Email string `gorm:"type:varchar(254);unique;index:idx_email"`
Password string `gorm:"type:varchar(60)"` Password string `gorm:"type:varchar(60)"`
IsAdmin bool Generation uint
Enabled bool IsAdmin bool
Enabled bool
} }
type ControllerInfo struct { type ControllerInfo struct {

View file

@ -239,7 +239,7 @@ func (s *sqlDatabase) migrateCredentialsToDB() (err error) {
// user. GARM is not yet multi-user, so it's safe to assume we only have this // user. GARM is not yet multi-user, so it's safe to assume we only have this
// one user. // one user.
adminCtx := context.Background() adminCtx := context.Background()
adminCtx = auth.PopulateContext(adminCtx, adminUser) adminCtx = auth.PopulateContext(adminCtx, adminUser, nil)
slog.Info("migrating credentials to DB") slog.Info("migrating credentials to DB")
slog.Info("creating github endpoints table") slog.Info("creating github endpoints table")

View file

@ -26,7 +26,7 @@ import (
"github.com/cloudbase/garm/params" "github.com/cloudbase/garm/params"
) )
func (s *sqlDatabase) getUserByUsernameOrEmail(user string) (User, error) { func (s *sqlDatabase) getUserByUsernameOrEmail(tx *gorm.DB, user string) (User, error) {
field := "username" field := "username"
if util.IsValidEmail(user) { if util.IsValidEmail(user) {
field = "email" field = "email"
@ -34,7 +34,7 @@ func (s *sqlDatabase) getUserByUsernameOrEmail(user string) (User, error) {
query := fmt.Sprintf("%s = ?", field) query := fmt.Sprintf("%s = ?", field)
var dbUser User var dbUser User
q := s.conn.Model(&User{}).Where(query, user).First(&dbUser) q := tx.Model(&User{}).Where(query, user).First(&dbUser)
if q.Error != nil { if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) { if errors.Is(q.Error, gorm.ErrRecordNotFound) {
return User{}, runnerErrors.ErrNotFound return User{}, runnerErrors.ErrNotFound
@ -44,9 +44,9 @@ func (s *sqlDatabase) getUserByUsernameOrEmail(user string) (User, error) {
return dbUser, nil return dbUser, nil
} }
func (s *sqlDatabase) getUserByID(userID string) (User, error) { func (s *sqlDatabase) getUserByID(tx *gorm.DB, userID string) (User, error) {
var dbUser User var dbUser User
q := s.conn.Model(&User{}).Where("id = ?", userID).First(&dbUser) q := tx.Model(&User{}).Where("id = ?", userID).First(&dbUser)
if q.Error != nil { if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) { if errors.Is(q.Error, gorm.ErrRecordNotFound) {
return User{}, runnerErrors.ErrNotFound return User{}, runnerErrors.ErrNotFound
@ -57,20 +57,9 @@ func (s *sqlDatabase) getUserByID(userID string) (User, error) {
} }
func (s *sqlDatabase) CreateUser(_ context.Context, user params.NewUserParams) (params.User, error) { func (s *sqlDatabase) CreateUser(_ context.Context, user params.NewUserParams) (params.User, error) {
if user.Username == "" || user.Email == "" { if user.Username == "" || user.Email == "" || user.Password == "" {
return params.User{}, runnerErrors.NewBadRequestError("missing username or email") return params.User{}, runnerErrors.NewBadRequestError("missing username, password or email")
} }
if _, err := s.getUserByUsernameOrEmail(user.Username); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) {
return params.User{}, runnerErrors.NewConflictError("username already exists")
}
if _, err := s.getUserByUsernameOrEmail(user.Email); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) {
return params.User{}, runnerErrors.NewConflictError("email already exists")
}
if s.HasAdminUser(context.Background()) && user.IsAdmin {
return params.User{}, runnerErrors.NewBadRequestError("admin user already exists")
}
newUser := User{ newUser := User{
Username: user.Username, Username: user.Username,
Password: user.Password, Password: user.Password,
@ -79,22 +68,42 @@ func (s *sqlDatabase) CreateUser(_ context.Context, user params.NewUserParams) (
Email: user.Email, Email: user.Email,
IsAdmin: user.IsAdmin, IsAdmin: user.IsAdmin,
} }
err := s.conn.Transaction(func(tx *gorm.DB) error {
if _, err := s.getUserByUsernameOrEmail(tx, user.Username); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) {
return runnerErrors.NewConflictError("username already exists")
}
if _, err := s.getUserByUsernameOrEmail(tx, user.Email); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) {
return runnerErrors.NewConflictError("email already exists")
}
q := s.conn.Save(&newUser) if s.hasAdmin(tx) && user.IsAdmin {
if q.Error != nil { return runnerErrors.NewBadRequestError("admin user already exists")
return params.User{}, errors.Wrap(q.Error, "creating user") }
q := tx.Save(&newUser)
if q.Error != nil {
return errors.Wrap(q.Error, "creating user")
}
return nil
})
if err != nil {
return params.User{}, errors.Wrap(err, "creating user")
} }
return s.sqlToParamsUser(newUser), nil return s.sqlToParamsUser(newUser), nil
} }
func (s *sqlDatabase) HasAdminUser(_ context.Context) bool { func (s *sqlDatabase) hasAdmin(tx *gorm.DB) bool {
var user User var user User
q := s.conn.Model(&User{}).Where("is_admin = ?", true).First(&user) q := tx.Model(&User{}).Where("is_admin = ?", true).First(&user)
return q.Error == nil return q.Error == nil
} }
func (s *sqlDatabase) HasAdminUser(_ context.Context) bool {
return s.hasAdmin(s.conn)
}
func (s *sqlDatabase) GetUser(_ context.Context, user string) (params.User, error) { func (s *sqlDatabase) GetUser(_ context.Context, user string) (params.User, error) {
dbUser, err := s.getUserByUsernameOrEmail(user) dbUser, err := s.getUserByUsernameOrEmail(s.conn, user)
if err != nil { if err != nil {
return params.User{}, errors.Wrap(err, "fetching user") return params.User{}, errors.Wrap(err, "fetching user")
} }
@ -102,7 +111,7 @@ func (s *sqlDatabase) GetUser(_ context.Context, user string) (params.User, erro
} }
func (s *sqlDatabase) GetUserByID(_ context.Context, userID string) (params.User, error) { func (s *sqlDatabase) GetUserByID(_ context.Context, userID string) (params.User, error) {
dbUser, err := s.getUserByID(userID) dbUser, err := s.getUserByID(s.conn, userID)
if err != nil { if err != nil {
return params.User{}, errors.Wrap(err, "fetching user") return params.User{}, errors.Wrap(err, "fetching user")
} }
@ -110,27 +119,35 @@ func (s *sqlDatabase) GetUserByID(_ context.Context, userID string) (params.User
} }
func (s *sqlDatabase) UpdateUser(_ context.Context, user string, param params.UpdateUserParams) (params.User, error) { func (s *sqlDatabase) UpdateUser(_ context.Context, user string, param params.UpdateUserParams) (params.User, error) {
dbUser, err := s.getUserByUsernameOrEmail(user) var err error
var dbUser User
err = s.conn.Transaction(func(tx *gorm.DB) error {
dbUser, err = s.getUserByUsernameOrEmail(tx, user)
if err != nil {
return errors.Wrap(err, "fetching user")
}
if param.FullName != "" {
dbUser.FullName = param.FullName
}
if param.Enabled != nil {
dbUser.Enabled = *param.Enabled
}
if param.Password != "" {
dbUser.Password = param.Password
dbUser.Generation++
}
if q := tx.Save(&dbUser); q.Error != nil {
return errors.Wrap(q.Error, "saving user")
}
return nil
})
if err != nil { if err != nil {
return params.User{}, errors.Wrap(err, "fetching user") return params.User{}, errors.Wrap(err, "updating user")
} }
if param.FullName != "" {
dbUser.FullName = param.FullName
}
if param.Enabled != nil {
dbUser.Enabled = *param.Enabled
}
if param.Password != "" {
dbUser.Password = param.Password
}
if q := s.conn.Save(&dbUser); q.Error != nil {
return params.User{}, errors.Wrap(q.Error, "saving user")
}
return s.sqlToParamsUser(dbUser), nil return s.sqlToParamsUser(dbUser), nil
} }

View file

@ -145,7 +145,7 @@ func (s *UserTestSuite) TestCreateUserMissingUsernameEmail() {
_, err := s.Store.CreateUser(context.Background(), s.Fixtures.NewUserParams) _, err := s.Store.CreateUser(context.Background(), s.Fixtures.NewUserParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal(("missing username or email"), err.Error()) s.Require().Equal(("missing username, password or email"), err.Error())
} }
func (s *UserTestSuite) TestCreateUserUsernameAlreadyExist() { func (s *UserTestSuite) TestCreateUserUsernameAlreadyExist() {
@ -154,7 +154,7 @@ func (s *UserTestSuite) TestCreateUserUsernameAlreadyExist() {
_, err := s.Store.CreateUser(context.Background(), s.Fixtures.NewUserParams) _, err := s.Store.CreateUser(context.Background(), s.Fixtures.NewUserParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal(("username already exists"), err.Error()) s.Require().Equal(("creating user: username already exists"), err.Error())
} }
func (s *UserTestSuite) TestCreateUserEmailAlreadyExist() { func (s *UserTestSuite) TestCreateUserEmailAlreadyExist() {
@ -163,10 +163,11 @@ func (s *UserTestSuite) TestCreateUserEmailAlreadyExist() {
_, err := s.Store.CreateUser(context.Background(), s.Fixtures.NewUserParams) _, err := s.Store.CreateUser(context.Background(), s.Fixtures.NewUserParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal(("email already exists"), err.Error()) s.Require().Equal(("creating user: email already exists"), err.Error())
} }
func (s *UserTestSuite) TestCreateUserDBCreateErr() { func (s *UserTestSuite) TestCreateUserDBCreateErr() {
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` WHERE username = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` WHERE username = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?")).
WithArgs(s.Fixtures.NewUserParams.Username, 1). WithArgs(s.Fixtures.NewUserParams.Username, 1).
@ -175,7 +176,6 @@ func (s *UserTestSuite) TestCreateUserDBCreateErr() {
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` WHERE email = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` WHERE email = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?")).
WithArgs(s.Fixtures.NewUserParams.Email, 1). WithArgs(s.Fixtures.NewUserParams.Email, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"})) WillReturnRows(sqlmock.NewRows([]string{"id"}))
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectExec("INSERT INTO `users`"). ExpectExec("INSERT INTO `users`").
WillReturnError(fmt.Errorf("creating user mock error")) WillReturnError(fmt.Errorf("creating user mock error"))
@ -183,9 +183,9 @@ func (s *UserTestSuite) TestCreateUserDBCreateErr() {
_, err := s.StoreSQLMocked.CreateUser(context.Background(), s.Fixtures.NewUserParams) _, err := s.StoreSQLMocked.CreateUser(context.Background(), s.Fixtures.NewUserParams)
s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("creating user: creating user mock error", err.Error()) s.Require().Equal("creating user: creating user: creating user mock error", err.Error())
s.assertSQLMockExpectations()
} }
func (s *UserTestSuite) TestHasAdminUserNoAdmin() { func (s *UserTestSuite) TestHasAdminUserNoAdmin() {
@ -253,15 +253,15 @@ func (s *UserTestSuite) TestUpdateUserNotFound() {
_, err := s.Store.UpdateUser(context.Background(), "dummy-user", s.Fixtures.UpdateUserParams) _, err := s.Store.UpdateUser(context.Background(), "dummy-user", s.Fixtures.UpdateUserParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching user: not found", err.Error()) s.Require().Equal("updating user: fetching user: not found", err.Error())
} }
func (s *UserTestSuite) TestUpdateUserDBSaveErr() { func (s *UserTestSuite) TestUpdateUserDBSaveErr() {
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` WHERE username = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` WHERE username = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?")).
WithArgs(s.Fixtures.Users[0].ID, 1). WithArgs(s.Fixtures.Users[0].ID, 1).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Users[0].ID)) WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Users[0].ID))
s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectExec(("UPDATE `users` SET")). ExpectExec(("UPDATE `users` SET")).
WillReturnError(fmt.Errorf("saving user mock error")) WillReturnError(fmt.Errorf("saving user mock error"))
@ -271,7 +271,7 @@ func (s *UserTestSuite) TestUpdateUserDBSaveErr() {
s.assertSQLMockExpectations() s.assertSQLMockExpectations()
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("saving user: saving user mock error", err.Error()) s.Require().Equal("updating user: saving user: saving user mock error", err.Error())
} }
func TestUserTestSuite(t *testing.T) { func TestUserTestSuite(t *testing.T) {

View file

@ -316,15 +316,16 @@ func (s *sqlDatabase) sqlToCommonRepository(repo Repository, detailed bool) (par
func (s *sqlDatabase) sqlToParamsUser(user User) params.User { func (s *sqlDatabase) sqlToParamsUser(user User) params.User {
return params.User{ return params.User{
ID: user.ID.String(), ID: user.ID.String(),
CreatedAt: user.CreatedAt, CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt, UpdatedAt: user.UpdatedAt,
Email: user.Email, Email: user.Email,
Username: user.Username, Username: user.Username,
FullName: user.FullName, FullName: user.FullName,
Password: user.Password, Password: user.Password,
Enabled: user.Enabled, Enabled: user.Enabled,
IsAdmin: user.IsAdmin, IsAdmin: user.IsAdmin,
Generation: user.Generation,
} }
} }

View file

@ -104,10 +104,19 @@ func WithEntityFilter(entity params.GithubEntity) dbCommon.PayloadFilterFunc {
var ok bool var ok bool
switch payload.EntityType { switch payload.EntityType {
case dbCommon.RepositoryEntityType: case dbCommon.RepositoryEntityType:
if entity.EntityType != params.GithubEntityTypeRepository {
return false
}
ent, ok = payload.Payload.(params.Repository) ent, ok = payload.Payload.(params.Repository)
case dbCommon.OrganizationEntityType: case dbCommon.OrganizationEntityType:
if entity.EntityType != params.GithubEntityTypeOrganization {
return false
}
ent, ok = payload.Payload.(params.Organization) ent, ok = payload.Payload.(params.Organization)
case dbCommon.EnterpriseEntityType: case dbCommon.EnterpriseEntityType:
if entity.EntityType != params.GithubEntityTypeEnterprise {
return false
}
ent, ok = payload.Payload.(params.Enterprise) ent, ok = payload.Payload.(params.Enterprise)
default: default:
return false return false
@ -165,3 +174,17 @@ func WithGithubCredentialsFilter(creds params.GithubCredentials) dbCommon.Payloa
return credsPayload.ID == creds.ID return credsPayload.ID == creds.ID
} }
} }
// WithUserIDFilter returns a filter function that filters payloads by user ID.
func WithUserIDFilter(userID string) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
if payload.EntityType != dbCommon.UserEntityType {
return false
}
userPayload, ok := payload.Payload.(params.User)
if !ok {
return false
}
return userPayload.ID == userID
}
}

View file

@ -57,7 +57,7 @@ func ImpersonateAdminContext(ctx context.Context, db common.Store, s *testing.T)
s.Fatalf("failed to create admin user: %v", err) s.Fatalf("failed to create admin user: %v", err)
} }
} }
ctx = auth.PopulateContext(ctx, adminUser) ctx = auth.PopulateContext(ctx, adminUser, nil)
return ctx return ctx
} }

View file

@ -543,9 +543,11 @@ type User struct {
Email string `json:"email"` Email string `json:"email"`
Username string `json:"username"` Username string `json:"username"`
FullName string `json:"full_name"` FullName string `json:"full_name"`
Password string `json:"-"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
IsAdmin bool `json:"is_admin"` IsAdmin bool `json:"is_admin"`
// Do not serialize sensitive info.
Password string `json:"-"`
Generation uint `json:"-"`
} }
// JWTResponse holds the JWT token returned as a result of a // JWTResponse holds the JWT token returned as a result of a

View file

@ -1,11 +1,20 @@
package websocket package websocket
import ( import (
"context"
"fmt"
"log/slog" "log/slog"
"sync"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/cloudbase/garm/auth"
"github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
"github.com/cloudbase/garm/params"
) )
const ( const (
@ -22,13 +31,34 @@ const (
maxMessageSize = 1024 maxMessageSize = 1024
) )
func NewClient(conn *websocket.Conn, hub *Hub) (*Client, error) { type HandleWebsocketMessage func([]byte) error
func NewClient(ctx context.Context, conn *websocket.Conn) (*Client, error) {
clientID := uuid.New() clientID := uuid.New()
consumerID := fmt.Sprintf("ws-client-watcher-%s", clientID.String())
user := auth.UserID(ctx)
if user == "" {
return nil, fmt.Errorf("user not found in context")
}
generation := auth.PasswordGeneration(ctx)
consumer, err := watcher.RegisterConsumer(
ctx, consumerID,
watcher.WithUserIDFilter(user),
)
if err != nil {
return nil, errors.Wrap(err, "registering consumer")
}
return &Client{ return &Client{
id: clientID.String(), id: clientID.String(),
conn: conn, conn: conn,
hub: hub, ctx: ctx,
send: make(chan []byte, 100), userID: user,
passwordGeneration: generation,
consumer: consumer,
done: make(chan struct{}),
send: make(chan []byte, 100),
}, nil }, nil
} }
@ -37,21 +67,84 @@ type Client struct {
conn *websocket.Conn conn *websocket.Conn
// Buffered channel of outbound messages. // Buffered channel of outbound messages.
send chan []byte send chan []byte
mux sync.Mutex
ctx context.Context
hub *Hub userID string
passwordGeneration uint
consumer common.Consumer
messageHandler HandleWebsocketMessage
running bool
done chan struct{}
} }
func (c *Client) Go() { func (c *Client) ID() string {
return c.id
}
func (c *Client) Stop() {
c.mux.Lock()
defer c.mux.Unlock()
if !c.running {
return
}
c.running = false
c.conn.Close()
close(c.send)
close(c.done)
}
func (c *Client) Done() <-chan struct{} {
return c.done
}
func (c *Client) SetMessageHandler(handler HandleWebsocketMessage) {
c.mux.Lock()
defer c.mux.Unlock()
c.messageHandler = handler
}
func (c *Client) Start() error {
c.mux.Lock()
defer c.mux.Unlock()
c.running = true
go c.runWatcher()
go c.clientReader() go c.clientReader()
go c.clientWriter() go c.clientWriter()
return nil
}
func (c *Client) Write(msg []byte) (int, error) {
c.mux.Lock()
defer c.mux.Unlock()
if !c.running {
return 0, fmt.Errorf("client is stopped")
}
tmp := make([]byte, len(msg))
copy(tmp, msg)
select {
case <-time.After(5 * time.Second):
return 0, fmt.Errorf("timed out sending message to client")
case c.send <- tmp:
}
return len(tmp), nil
} }
// clientReader waits for options changes from the client. The client can at any time // clientReader waits for options changes from the client. The client can at any time
// change the log level and binary name it watches. // change the log level and binary name it watches.
func (c *Client) clientReader() { func (c *Client) clientReader() {
defer func() { defer func() {
c.hub.unregister <- c c.Stop()
c.conn.Close()
}() }()
c.conn.SetReadLimit(maxMessageSize) c.conn.SetReadLimit(maxMessageSize)
if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
@ -64,10 +157,19 @@ func (c *Client) clientReader() {
return nil return nil
}) })
for { for {
mt, _, err := c.conn.ReadMessage() mt, data, err := c.conn.ReadMessage()
if err != nil { if err != nil {
if IsErrorOfInterest(err) {
slog.ErrorContext(c.ctx, "error reading websocket message", slog.Any("error", err))
}
break break
} }
if c.messageHandler != nil {
if err := c.messageHandler(data); err != nil {
slog.ErrorContext(c.ctx, "error handling message", slog.Any("error", err))
}
}
if mt == websocket.CloseMessage { if mt == websocket.CloseMessage {
break break
} }
@ -78,9 +180,14 @@ func (c *Client) clientReader() {
func (c *Client) clientWriter() { func (c *Client) clientWriter() {
ticker := time.NewTicker(pingPeriod) ticker := time.NewTicker(pingPeriod)
defer func() { defer func() {
c.Stop()
ticker.Stop() ticker.Stop()
c.conn.Close()
}() }()
var authExpires time.Time
expires := auth.Expires(c.ctx)
if expires != nil {
authExpires = *expires
}
for { for {
select { select {
case message, ok := <-c.send: case message, ok := <-c.send:
@ -90,13 +197,17 @@ func (c *Client) clientWriter() {
if !ok { if !ok {
// The hub closed the channel. // The hub closed the channel.
if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil { if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil {
slog.With(slog.Any("error", err)).Error("failed to write message") if IsErrorOfInterest(err) {
slog.With(slog.Any("error", err)).Error("failed to write message")
}
} }
return return
} }
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil { if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
slog.With(slog.Any("error", err)).Error("error sending message") if IsErrorOfInterest(err) {
slog.With(slog.Any("error", err)).Error("error sending message")
}
return return
} }
case <-ticker.C: case <-ticker.C:
@ -104,8 +215,81 @@ func (c *Client) clientWriter() {
slog.With(slog.Any("error", err)).Error("failed to set write deadline") slog.With(slog.Any("error", err)).Error("failed to set write deadline")
} }
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
if IsErrorOfInterest(err) {
slog.With(slog.Any("error", err)).Error("failed to write ping message")
}
return return
} }
case <-c.ctx.Done():
return
case <-time.After(time.Until(authExpires)):
// Auth has expired
slog.DebugContext(c.ctx, "auth expired, closing connection")
return
} }
} }
} }
func (c *Client) runWatcher() {
defer func() {
c.Stop()
}()
for {
select {
case <-c.Done():
return
case <-c.ctx.Done():
return
case event, ok := <-c.consumer.Watch():
if !ok {
slog.InfoContext(c.ctx, "watcher closed")
return
}
go func(event common.ChangePayload) {
if event.EntityType != common.UserEntityType {
return
}
user, ok := event.Payload.(params.User)
if !ok {
slog.ErrorContext(c.ctx, "failed to cast payload to user")
return
}
if user.ID != c.userID {
return
}
if user.Generation != c.passwordGeneration {
slog.InfoContext(c.ctx, "password generation mismatch; closing connection")
c.Stop()
}
}(event)
}
}
}
func IsErrorOfInterest(err error) bool {
if err == nil {
return false
}
if errors.Is(err, websocket.ErrCloseSent) {
return false
}
if errors.Is(err, websocket.ErrBadHandshake) {
return false
}
asCloseErr, ok := err.(*websocket.CloseError)
if ok {
switch asCloseErr.Code {
case websocket.CloseNormalClosure, websocket.CloseGoingAway,
websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure:
return false
}
}
return true
}

View file

@ -3,19 +3,18 @@ package websocket
import ( import (
"context" "context"
"fmt" "fmt"
"log/slog"
"sync" "sync"
"time" "time"
) )
func NewHub(ctx context.Context) *Hub { func NewHub(ctx context.Context) *Hub {
return &Hub{ return &Hub{
clients: map[string]*Client{}, clients: map[string]*Client{},
broadcast: make(chan []byte, 100), broadcast: make(chan []byte, 100),
register: make(chan *Client, 100), ctx: ctx,
unregister: make(chan *Client, 100), closed: make(chan struct{}),
ctx: ctx, quit: make(chan struct{}),
closed: make(chan struct{}),
quit: make(chan struct{}),
} }
} }
@ -29,12 +28,6 @@ type Hub struct {
// Inbound messages from the clients. // Inbound messages from the clients.
broadcast chan []byte broadcast chan []byte
// Register requests from the clients.
register chan *Client
// Unregister requests from clients.
unregister chan *Client
mux sync.Mutex mux sync.Mutex
once sync.Once once sync.Once
} }
@ -49,22 +42,6 @@ func (h *Hub) run() {
return return
case <-h.ctx.Done(): case <-h.ctx.Done():
return return
case client := <-h.register:
if client != nil {
h.mux.Lock()
h.clients[client.id] = client
h.mux.Unlock()
}
case client := <-h.unregister:
if client != nil {
h.mux.Lock()
if _, ok := h.clients[client.id]; ok {
client.conn.Close()
close(client.send)
delete(h.clients, client.id)
}
h.mux.Unlock()
}
case message := <-h.broadcast: case message := <-h.broadcast:
staleClients := []string{} staleClients := []string{}
for id, client := range h.clients { for id, client := range h.clients {
@ -73,9 +50,7 @@ func (h *Hub) run() {
continue continue
} }
select { if _, err := client.Write(message); err != nil {
case client.send <- message:
case <-time.After(5 * time.Second):
staleClients = append(staleClients, id) staleClients = append(staleClients, id)
} }
} }
@ -97,7 +72,35 @@ func (h *Hub) run() {
} }
func (h *Hub) Register(client *Client) error { func (h *Hub) Register(client *Client) error {
h.register <- client if client == nil {
return nil
}
h.mux.Lock()
defer h.mux.Unlock()
cli, ok := h.clients[client.ID()]
if ok {
if cli != nil {
return fmt.Errorf("client already registered")
}
}
slog.DebugContext(h.ctx, "registering client", "client_id", client.ID())
h.clients[client.id] = client
return nil
}
func (h *Hub) Unregister(client *Client) error {
if client == nil {
return nil
}
h.mux.Lock()
defer h.mux.Unlock()
cli, ok := h.clients[client.ID()]
if ok {
cli.Stop()
slog.DebugContext(h.ctx, "unregistering client", "client_id", cli.ID())
delete(h.clients, cli.ID())
slog.DebugContext(h.ctx, "current client count", "count", len(h.clients))
}
return nil return nil
} }