Refactor the websocket client and add fixes
The websocket client and hub interaction has been simplified a bit. The hub now acts only as a tee writer to the various clients that register. Clients must register and unregister explicitly. The hub is no longer passed in to the client. Websocket clients now watch for password changes or jwt token expiration times. Clients are disconnected if auth token expires or if the password is changed. Various aditional safety checks have been added. Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
This commit is contained in:
parent
ca7f20b62d
commit
dd1740c189
17 changed files with 426 additions and 143 deletions
|
|
@ -183,14 +183,9 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request)
|
|||
slog.With(slog.Any("error", err)).ErrorContext(ctx, "error upgrading to websockets")
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// nolint:golangci-lint,godox
|
||||
// TODO (gsamfira): Handle ExpiresAt. Right now, if a client uses
|
||||
// a valid token to authenticate, and keeps the websocket connection
|
||||
// open, it will allow that client to stream logs via websockets
|
||||
// until the connection is broken. We need to forcefully disconnect
|
||||
// the client once the token expires.
|
||||
client, err := wsWriter.NewClient(conn, a.hub)
|
||||
client, err := wsWriter.NewClient(ctx, conn)
|
||||
if err != nil {
|
||||
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to create new client")
|
||||
return
|
||||
|
|
@ -199,7 +194,14 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request)
|
|||
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to register new client")
|
||||
return
|
||||
}
|
||||
client.Go()
|
||||
defer a.hub.Unregister(client)
|
||||
|
||||
if err := client.Start(); err != nil {
|
||||
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to start client")
|
||||
return
|
||||
}
|
||||
<-client.Done()
|
||||
slog.Info("client disconnected", "client_id", client.ID())
|
||||
}
|
||||
|
||||
// NotFoundHandler is returned when an invalid URL is acccessed
|
||||
|
|
|
|||
12
auth/auth.go
12
auth/auth.go
|
|
@ -55,6 +55,7 @@ func (a *Authenticator) GetJWTToken(ctx context.Context) (string, error) {
|
|||
expires := &jwt.NumericDate{
|
||||
Time: expireToken,
|
||||
}
|
||||
generation := PasswordGeneration(ctx)
|
||||
claims := JWTClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: expires,
|
||||
|
|
@ -62,10 +63,11 @@ func (a *Authenticator) GetJWTToken(ctx context.Context) (string, error) {
|
|||
// TODO: make this configurable
|
||||
Issuer: "garm",
|
||||
},
|
||||
UserID: UserID(ctx),
|
||||
TokenID: tokenID,
|
||||
IsAdmin: IsAdmin(ctx),
|
||||
FullName: FullName(ctx),
|
||||
UserID: UserID(ctx),
|
||||
TokenID: tokenID,
|
||||
IsAdmin: IsAdmin(ctx),
|
||||
FullName: FullName(ctx),
|
||||
Generation: generation,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(a.cfg.Secret))
|
||||
|
|
@ -182,5 +184,5 @@ func (a *Authenticator) AuthenticateUser(ctx context.Context, info params.Passwo
|
|||
return ctx, runnerErrors.ErrUnauthorized
|
||||
}
|
||||
|
||||
return PopulateContext(ctx, user), nil
|
||||
return PopulateContext(ctx, user, nil), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ package auth
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
|
||||
"github.com/cloudbase/garm/params"
|
||||
|
|
@ -28,9 +29,11 @@ const (
|
|||
fullNameKey contextFlags = "full_name"
|
||||
readMetricsKey contextFlags = "read_metrics"
|
||||
// UserIDFlag is the User ID flag we set in the context
|
||||
UserIDFlag contextFlags = "user_id"
|
||||
isEnabledFlag contextFlags = "is_enabled"
|
||||
jwtTokenFlag contextFlags = "jwt_token"
|
||||
UserIDFlag contextFlags = "user_id"
|
||||
isEnabledFlag contextFlags = "is_enabled"
|
||||
jwtTokenFlag contextFlags = "jwt_token"
|
||||
authExpiresFlag contextFlags = "auth_expires"
|
||||
passwordGenerationFlag contextFlags = "password_generation"
|
||||
|
||||
instanceIDKey contextFlags = "id"
|
||||
instanceNameKey contextFlags = "name"
|
||||
|
|
@ -169,14 +172,43 @@ func PopulateInstanceContext(ctx context.Context, instance params.Instance) cont
|
|||
|
||||
// PopulateContext sets the appropriate fields in the context, based on
|
||||
// the user object
|
||||
func PopulateContext(ctx context.Context, user params.User) context.Context {
|
||||
func PopulateContext(ctx context.Context, user params.User, authExpires *time.Time) context.Context {
|
||||
ctx = SetUserID(ctx, user.ID)
|
||||
ctx = SetAdmin(ctx, user.IsAdmin)
|
||||
ctx = SetIsEnabled(ctx, user.Enabled)
|
||||
ctx = SetFullName(ctx, user.FullName)
|
||||
ctx = SetExpires(ctx, authExpires)
|
||||
ctx = SetPasswordGeneration(ctx, user.Generation)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func SetExpires(ctx context.Context, expires *time.Time) context.Context {
|
||||
if expires == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, authExpiresFlag, expires)
|
||||
}
|
||||
|
||||
func Expires(ctx context.Context) *time.Time {
|
||||
elem := ctx.Value(authExpiresFlag)
|
||||
if elem == nil {
|
||||
return nil
|
||||
}
|
||||
return elem.(*time.Time)
|
||||
}
|
||||
|
||||
func SetPasswordGeneration(ctx context.Context, val uint) context.Context {
|
||||
return context.WithValue(ctx, passwordGenerationFlag, val)
|
||||
}
|
||||
|
||||
func PasswordGeneration(ctx context.Context) uint {
|
||||
elem := ctx.Value(passwordGenerationFlag)
|
||||
if elem == nil {
|
||||
return 0
|
||||
}
|
||||
return elem.(uint)
|
||||
}
|
||||
|
||||
// SetFullName sets the user full name in the context
|
||||
func SetFullName(ctx context.Context, fullName string) context.Context {
|
||||
return context.WithValue(ctx, fullNameKey, fullName)
|
||||
|
|
|
|||
15
auth/jwt.go
15
auth/jwt.go
|
|
@ -21,6 +21,7 @@ import (
|
|||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
|
||||
|
|
@ -37,6 +38,7 @@ type JWTClaims struct {
|
|||
FullName string `json:"full_name"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
ReadMetrics bool `json:"read_metrics"`
|
||||
Generation uint `json:"generation"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
|
|
@ -69,7 +71,18 @@ func (amw *jwtMiddleware) claimsToContext(ctx context.Context, claims *JWTClaims
|
|||
return ctx, runnerErrors.ErrUnauthorized
|
||||
}
|
||||
|
||||
ctx = PopulateContext(ctx, userInfo)
|
||||
var expiresAt *time.Time
|
||||
if claims.ExpiresAt != nil {
|
||||
expires := claims.ExpiresAt.Time.UTC()
|
||||
expiresAt = &expires
|
||||
}
|
||||
|
||||
if userInfo.Generation != claims.Generation {
|
||||
// Password was reset since token was issued. Invalidate.
|
||||
return ctx, runnerErrors.ErrUnauthorized
|
||||
}
|
||||
|
||||
ctx = PopulateContext(ctx, userInfo, expiresAt)
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import (
|
|||
|
||||
"github.com/cloudbase/garm-provider-common/util"
|
||||
apiParams "github.com/cloudbase/garm/apiserver/params"
|
||||
garmWs "github.com/cloudbase/garm/websocket"
|
||||
)
|
||||
|
||||
var logCmd = &cobra.Command{
|
||||
|
|
@ -66,7 +67,9 @@ var logCmd = &cobra.Command{
|
|||
for {
|
||||
_, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
slog.With(slog.Any("error", err)).Error("reading log message")
|
||||
if garmWs.IsErrorOfInterest(err) {
|
||||
slog.With(slog.Any("error", err)).Error("reading log message")
|
||||
}
|
||||
return
|
||||
}
|
||||
fmt.Println(util.SanitizeLogEntry(string(message)))
|
||||
|
|
|
|||
|
|
@ -320,7 +320,7 @@ func main() {
|
|||
slog.With(slog.Any("error", err)).ErrorContext(ctx, "graceful api server shutdown failed")
|
||||
}
|
||||
|
||||
slog.With(slog.Any("error", err)).ErrorContext(ctx, "waiting for runner to stop")
|
||||
slog.With(slog.Any("error", err)).InfoContext(ctx, "waiting for runner to stop")
|
||||
if err := runner.Wait(); err != nil {
|
||||
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to shutdown workers")
|
||||
os.Exit(1)
|
||||
|
|
|
|||
|
|
@ -284,7 +284,7 @@ func (s *GithubTestSuite) TestCreateCredentials() {
|
|||
func (s *GithubTestSuite) TestCreateCredentialsFailsOnDuplicateCredentials() {
|
||||
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
|
||||
testUser := garmTesting.CreateGARMTestUser(ctx, "testuser", s.db, s.T())
|
||||
testUserCtx := auth.PopulateContext(context.Background(), testUser)
|
||||
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
|
||||
|
||||
credParams := params.CreateGithubCredentialsParams{
|
||||
Name: testCredsName,
|
||||
|
|
@ -313,8 +313,8 @@ func (s *GithubTestSuite) TestNormalUsersCanOnlySeeTheirOwnCredentialsAdminCanSe
|
|||
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
|
||||
testUser := garmTesting.CreateGARMTestUser(ctx, "testuser1", s.db, s.T())
|
||||
testUser2 := garmTesting.CreateGARMTestUser(ctx, "testuser2", s.db, s.T())
|
||||
testUserCtx := auth.PopulateContext(context.Background(), testUser)
|
||||
testUser2Ctx := auth.PopulateContext(context.Background(), testUser2)
|
||||
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
|
||||
testUser2Ctx := auth.PopulateContext(context.Background(), testUser2, nil)
|
||||
|
||||
credParams := params.CreateGithubCredentialsParams{
|
||||
Name: testCredsName,
|
||||
|
|
@ -370,7 +370,7 @@ func (s *GithubTestSuite) TestGetGithubCredentialsFailsWhenCredentialsDontExist(
|
|||
func (s *GithubTestSuite) TestGetGithubCredentialsByNameReturnsOnlyCurrentUserCredentials() {
|
||||
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
|
||||
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user1", s.db, s.T())
|
||||
testUserCtx := auth.PopulateContext(context.Background(), testUser)
|
||||
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
|
||||
|
||||
credParams := params.CreateGithubCredentialsParams{
|
||||
Name: testCredsName,
|
||||
|
|
@ -472,7 +472,7 @@ func (s *GithubTestSuite) TestDeleteGithubCredentials() {
|
|||
func (s *GithubTestSuite) TestDeleteGithubCredentialsByNonAdminUser() {
|
||||
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
|
||||
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user4", s.db, s.T())
|
||||
testUserCtx := auth.PopulateContext(context.Background(), testUser)
|
||||
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
|
||||
|
||||
credParams := params.CreateGithubCredentialsParams{
|
||||
Name: testCredsName,
|
||||
|
|
@ -682,7 +682,7 @@ func (s *GithubTestSuite) TestUpdateCredentialsFailsForNonExistingCredentials()
|
|||
func (s *GithubTestSuite) TestUpdateCredentialsFailsIfCredentialsAreOwnedByNonAdminUser() {
|
||||
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
|
||||
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user5", s.db, s.T())
|
||||
testUserCtx := auth.PopulateContext(context.Background(), testUser)
|
||||
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
|
||||
|
||||
credParams := params.CreateGithubCredentialsParams{
|
||||
Name: testCredsName,
|
||||
|
|
@ -711,7 +711,7 @@ func (s *GithubTestSuite) TestUpdateCredentialsFailsIfCredentialsAreOwnedByNonAd
|
|||
func (s *GithubTestSuite) TestAdminUserCanUpdateAnyGithubCredentials() {
|
||||
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
|
||||
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user5", s.db, s.T())
|
||||
testUserCtx := auth.PopulateContext(context.Background(), testUser)
|
||||
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
|
||||
|
||||
credParams := params.CreateGithubCredentialsParams{
|
||||
Name: testCredsName,
|
||||
|
|
|
|||
|
|
@ -195,12 +195,13 @@ type Instance struct {
|
|||
type User struct {
|
||||
Base
|
||||
|
||||
Username string `gorm:"uniqueIndex;varchar(64)"`
|
||||
FullName string `gorm:"type:varchar(254)"`
|
||||
Email string `gorm:"type:varchar(254);unique;index:idx_email"`
|
||||
Password string `gorm:"type:varchar(60)"`
|
||||
IsAdmin bool
|
||||
Enabled bool
|
||||
Username string `gorm:"uniqueIndex;varchar(64)"`
|
||||
FullName string `gorm:"type:varchar(254)"`
|
||||
Email string `gorm:"type:varchar(254);unique;index:idx_email"`
|
||||
Password string `gorm:"type:varchar(60)"`
|
||||
Generation uint
|
||||
IsAdmin bool
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
type ControllerInfo struct {
|
||||
|
|
|
|||
|
|
@ -239,7 +239,7 @@ func (s *sqlDatabase) migrateCredentialsToDB() (err error) {
|
|||
// user. GARM is not yet multi-user, so it's safe to assume we only have this
|
||||
// one user.
|
||||
adminCtx := context.Background()
|
||||
adminCtx = auth.PopulateContext(adminCtx, adminUser)
|
||||
adminCtx = auth.PopulateContext(adminCtx, adminUser, nil)
|
||||
|
||||
slog.Info("migrating credentials to DB")
|
||||
slog.Info("creating github endpoints table")
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import (
|
|||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
func (s *sqlDatabase) getUserByUsernameOrEmail(user string) (User, error) {
|
||||
func (s *sqlDatabase) getUserByUsernameOrEmail(tx *gorm.DB, user string) (User, error) {
|
||||
field := "username"
|
||||
if util.IsValidEmail(user) {
|
||||
field = "email"
|
||||
|
|
@ -34,7 +34,7 @@ func (s *sqlDatabase) getUserByUsernameOrEmail(user string) (User, error) {
|
|||
query := fmt.Sprintf("%s = ?", field)
|
||||
|
||||
var dbUser User
|
||||
q := s.conn.Model(&User{}).Where(query, user).First(&dbUser)
|
||||
q := tx.Model(&User{}).Where(query, user).First(&dbUser)
|
||||
if q.Error != nil {
|
||||
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
|
||||
return User{}, runnerErrors.ErrNotFound
|
||||
|
|
@ -44,9 +44,9 @@ func (s *sqlDatabase) getUserByUsernameOrEmail(user string) (User, error) {
|
|||
return dbUser, nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) getUserByID(userID string) (User, error) {
|
||||
func (s *sqlDatabase) getUserByID(tx *gorm.DB, userID string) (User, error) {
|
||||
var dbUser User
|
||||
q := s.conn.Model(&User{}).Where("id = ?", userID).First(&dbUser)
|
||||
q := tx.Model(&User{}).Where("id = ?", userID).First(&dbUser)
|
||||
if q.Error != nil {
|
||||
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
|
||||
return User{}, runnerErrors.ErrNotFound
|
||||
|
|
@ -57,20 +57,9 @@ func (s *sqlDatabase) getUserByID(userID string) (User, error) {
|
|||
}
|
||||
|
||||
func (s *sqlDatabase) CreateUser(_ context.Context, user params.NewUserParams) (params.User, error) {
|
||||
if user.Username == "" || user.Email == "" {
|
||||
return params.User{}, runnerErrors.NewBadRequestError("missing username or email")
|
||||
if user.Username == "" || user.Email == "" || user.Password == "" {
|
||||
return params.User{}, runnerErrors.NewBadRequestError("missing username, password or email")
|
||||
}
|
||||
if _, err := s.getUserByUsernameOrEmail(user.Username); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) {
|
||||
return params.User{}, runnerErrors.NewConflictError("username already exists")
|
||||
}
|
||||
if _, err := s.getUserByUsernameOrEmail(user.Email); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) {
|
||||
return params.User{}, runnerErrors.NewConflictError("email already exists")
|
||||
}
|
||||
|
||||
if s.HasAdminUser(context.Background()) && user.IsAdmin {
|
||||
return params.User{}, runnerErrors.NewBadRequestError("admin user already exists")
|
||||
}
|
||||
|
||||
newUser := User{
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
|
|
@ -79,22 +68,42 @@ func (s *sqlDatabase) CreateUser(_ context.Context, user params.NewUserParams) (
|
|||
Email: user.Email,
|
||||
IsAdmin: user.IsAdmin,
|
||||
}
|
||||
err := s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
if _, err := s.getUserByUsernameOrEmail(tx, user.Username); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) {
|
||||
return runnerErrors.NewConflictError("username already exists")
|
||||
}
|
||||
if _, err := s.getUserByUsernameOrEmail(tx, user.Email); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) {
|
||||
return runnerErrors.NewConflictError("email already exists")
|
||||
}
|
||||
|
||||
q := s.conn.Save(&newUser)
|
||||
if q.Error != nil {
|
||||
return params.User{}, errors.Wrap(q.Error, "creating user")
|
||||
if s.hasAdmin(tx) && user.IsAdmin {
|
||||
return runnerErrors.NewBadRequestError("admin user already exists")
|
||||
}
|
||||
|
||||
q := tx.Save(&newUser)
|
||||
if q.Error != nil {
|
||||
return errors.Wrap(q.Error, "creating user")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return params.User{}, errors.Wrap(err, "creating user")
|
||||
}
|
||||
return s.sqlToParamsUser(newUser), nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) HasAdminUser(_ context.Context) bool {
|
||||
func (s *sqlDatabase) hasAdmin(tx *gorm.DB) bool {
|
||||
var user User
|
||||
q := s.conn.Model(&User{}).Where("is_admin = ?", true).First(&user)
|
||||
q := tx.Model(&User{}).Where("is_admin = ?", true).First(&user)
|
||||
return q.Error == nil
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) HasAdminUser(_ context.Context) bool {
|
||||
return s.hasAdmin(s.conn)
|
||||
}
|
||||
|
||||
func (s *sqlDatabase) GetUser(_ context.Context, user string) (params.User, error) {
|
||||
dbUser, err := s.getUserByUsernameOrEmail(user)
|
||||
dbUser, err := s.getUserByUsernameOrEmail(s.conn, user)
|
||||
if err != nil {
|
||||
return params.User{}, errors.Wrap(err, "fetching user")
|
||||
}
|
||||
|
|
@ -102,7 +111,7 @@ func (s *sqlDatabase) GetUser(_ context.Context, user string) (params.User, erro
|
|||
}
|
||||
|
||||
func (s *sqlDatabase) GetUserByID(_ context.Context, userID string) (params.User, error) {
|
||||
dbUser, err := s.getUserByID(userID)
|
||||
dbUser, err := s.getUserByID(s.conn, userID)
|
||||
if err != nil {
|
||||
return params.User{}, errors.Wrap(err, "fetching user")
|
||||
}
|
||||
|
|
@ -110,27 +119,35 @@ func (s *sqlDatabase) GetUserByID(_ context.Context, userID string) (params.User
|
|||
}
|
||||
|
||||
func (s *sqlDatabase) UpdateUser(_ context.Context, user string, param params.UpdateUserParams) (params.User, error) {
|
||||
dbUser, err := s.getUserByUsernameOrEmail(user)
|
||||
var err error
|
||||
var dbUser User
|
||||
err = s.conn.Transaction(func(tx *gorm.DB) error {
|
||||
dbUser, err = s.getUserByUsernameOrEmail(tx, user)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "fetching user")
|
||||
}
|
||||
|
||||
if param.FullName != "" {
|
||||
dbUser.FullName = param.FullName
|
||||
}
|
||||
|
||||
if param.Enabled != nil {
|
||||
dbUser.Enabled = *param.Enabled
|
||||
}
|
||||
|
||||
if param.Password != "" {
|
||||
dbUser.Password = param.Password
|
||||
dbUser.Generation++
|
||||
}
|
||||
|
||||
if q := tx.Save(&dbUser); q.Error != nil {
|
||||
return errors.Wrap(q.Error, "saving user")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return params.User{}, errors.Wrap(err, "fetching user")
|
||||
return params.User{}, errors.Wrap(err, "updating user")
|
||||
}
|
||||
|
||||
if param.FullName != "" {
|
||||
dbUser.FullName = param.FullName
|
||||
}
|
||||
|
||||
if param.Enabled != nil {
|
||||
dbUser.Enabled = *param.Enabled
|
||||
}
|
||||
|
||||
if param.Password != "" {
|
||||
dbUser.Password = param.Password
|
||||
}
|
||||
|
||||
if q := s.conn.Save(&dbUser); q.Error != nil {
|
||||
return params.User{}, errors.Wrap(q.Error, "saving user")
|
||||
}
|
||||
|
||||
return s.sqlToParamsUser(dbUser), nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ func (s *UserTestSuite) TestCreateUserMissingUsernameEmail() {
|
|||
_, err := s.Store.CreateUser(context.Background(), s.Fixtures.NewUserParams)
|
||||
|
||||
s.Require().NotNil(err)
|
||||
s.Require().Equal(("missing username or email"), err.Error())
|
||||
s.Require().Equal(("missing username, password or email"), err.Error())
|
||||
}
|
||||
|
||||
func (s *UserTestSuite) TestCreateUserUsernameAlreadyExist() {
|
||||
|
|
@ -154,7 +154,7 @@ func (s *UserTestSuite) TestCreateUserUsernameAlreadyExist() {
|
|||
_, err := s.Store.CreateUser(context.Background(), s.Fixtures.NewUserParams)
|
||||
|
||||
s.Require().NotNil(err)
|
||||
s.Require().Equal(("username already exists"), err.Error())
|
||||
s.Require().Equal(("creating user: username already exists"), err.Error())
|
||||
}
|
||||
|
||||
func (s *UserTestSuite) TestCreateUserEmailAlreadyExist() {
|
||||
|
|
@ -163,10 +163,11 @@ func (s *UserTestSuite) TestCreateUserEmailAlreadyExist() {
|
|||
_, err := s.Store.CreateUser(context.Background(), s.Fixtures.NewUserParams)
|
||||
|
||||
s.Require().NotNil(err)
|
||||
s.Require().Equal(("email already exists"), err.Error())
|
||||
s.Require().Equal(("creating user: email already exists"), err.Error())
|
||||
}
|
||||
|
||||
func (s *UserTestSuite) TestCreateUserDBCreateErr() {
|
||||
s.Fixtures.SQLMock.ExpectBegin()
|
||||
s.Fixtures.SQLMock.
|
||||
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` WHERE username = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?")).
|
||||
WithArgs(s.Fixtures.NewUserParams.Username, 1).
|
||||
|
|
@ -175,7 +176,6 @@ func (s *UserTestSuite) TestCreateUserDBCreateErr() {
|
|||
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` WHERE email = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?")).
|
||||
WithArgs(s.Fixtures.NewUserParams.Email, 1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
s.Fixtures.SQLMock.ExpectBegin()
|
||||
s.Fixtures.SQLMock.
|
||||
ExpectExec("INSERT INTO `users`").
|
||||
WillReturnError(fmt.Errorf("creating user mock error"))
|
||||
|
|
@ -183,9 +183,9 @@ func (s *UserTestSuite) TestCreateUserDBCreateErr() {
|
|||
|
||||
_, err := s.StoreSQLMocked.CreateUser(context.Background(), s.Fixtures.NewUserParams)
|
||||
|
||||
s.assertSQLMockExpectations()
|
||||
s.Require().NotNil(err)
|
||||
s.Require().Equal("creating user: creating user mock error", err.Error())
|
||||
s.Require().Equal("creating user: creating user: creating user mock error", err.Error())
|
||||
s.assertSQLMockExpectations()
|
||||
}
|
||||
|
||||
func (s *UserTestSuite) TestHasAdminUserNoAdmin() {
|
||||
|
|
@ -253,15 +253,15 @@ func (s *UserTestSuite) TestUpdateUserNotFound() {
|
|||
_, err := s.Store.UpdateUser(context.Background(), "dummy-user", s.Fixtures.UpdateUserParams)
|
||||
|
||||
s.Require().NotNil(err)
|
||||
s.Require().Equal("fetching user: not found", err.Error())
|
||||
s.Require().Equal("updating user: fetching user: not found", err.Error())
|
||||
}
|
||||
|
||||
func (s *UserTestSuite) TestUpdateUserDBSaveErr() {
|
||||
s.Fixtures.SQLMock.ExpectBegin()
|
||||
s.Fixtures.SQLMock.
|
||||
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` WHERE username = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?")).
|
||||
WithArgs(s.Fixtures.Users[0].ID, 1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Users[0].ID))
|
||||
s.Fixtures.SQLMock.ExpectBegin()
|
||||
s.Fixtures.SQLMock.
|
||||
ExpectExec(("UPDATE `users` SET")).
|
||||
WillReturnError(fmt.Errorf("saving user mock error"))
|
||||
|
|
@ -271,7 +271,7 @@ func (s *UserTestSuite) TestUpdateUserDBSaveErr() {
|
|||
|
||||
s.assertSQLMockExpectations()
|
||||
s.Require().NotNil(err)
|
||||
s.Require().Equal("saving user: saving user mock error", err.Error())
|
||||
s.Require().Equal("updating user: saving user: saving user mock error", err.Error())
|
||||
}
|
||||
|
||||
func TestUserTestSuite(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -316,15 +316,16 @@ func (s *sqlDatabase) sqlToCommonRepository(repo Repository, detailed bool) (par
|
|||
|
||||
func (s *sqlDatabase) sqlToParamsUser(user User) params.User {
|
||||
return params.User{
|
||||
ID: user.ID.String(),
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
Email: user.Email,
|
||||
Username: user.Username,
|
||||
FullName: user.FullName,
|
||||
Password: user.Password,
|
||||
Enabled: user.Enabled,
|
||||
IsAdmin: user.IsAdmin,
|
||||
ID: user.ID.String(),
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
Email: user.Email,
|
||||
Username: user.Username,
|
||||
FullName: user.FullName,
|
||||
Password: user.Password,
|
||||
Enabled: user.Enabled,
|
||||
IsAdmin: user.IsAdmin,
|
||||
Generation: user.Generation,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -104,10 +104,19 @@ func WithEntityFilter(entity params.GithubEntity) dbCommon.PayloadFilterFunc {
|
|||
var ok bool
|
||||
switch payload.EntityType {
|
||||
case dbCommon.RepositoryEntityType:
|
||||
if entity.EntityType != params.GithubEntityTypeRepository {
|
||||
return false
|
||||
}
|
||||
ent, ok = payload.Payload.(params.Repository)
|
||||
case dbCommon.OrganizationEntityType:
|
||||
if entity.EntityType != params.GithubEntityTypeOrganization {
|
||||
return false
|
||||
}
|
||||
ent, ok = payload.Payload.(params.Organization)
|
||||
case dbCommon.EnterpriseEntityType:
|
||||
if entity.EntityType != params.GithubEntityTypeEnterprise {
|
||||
return false
|
||||
}
|
||||
ent, ok = payload.Payload.(params.Enterprise)
|
||||
default:
|
||||
return false
|
||||
|
|
@ -165,3 +174,17 @@ func WithGithubCredentialsFilter(creds params.GithubCredentials) dbCommon.Payloa
|
|||
return credsPayload.ID == creds.ID
|
||||
}
|
||||
}
|
||||
|
||||
// WithUserIDFilter returns a filter function that filters payloads by user ID.
|
||||
func WithUserIDFilter(userID string) dbCommon.PayloadFilterFunc {
|
||||
return func(payload dbCommon.ChangePayload) bool {
|
||||
if payload.EntityType != dbCommon.UserEntityType {
|
||||
return false
|
||||
}
|
||||
userPayload, ok := payload.Payload.(params.User)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return userPayload.ID == userID
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ func ImpersonateAdminContext(ctx context.Context, db common.Store, s *testing.T)
|
|||
s.Fatalf("failed to create admin user: %v", err)
|
||||
}
|
||||
}
|
||||
ctx = auth.PopulateContext(ctx, adminUser)
|
||||
ctx = auth.PopulateContext(ctx, adminUser, nil)
|
||||
return ctx
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -543,9 +543,11 @@ type User struct {
|
|||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
FullName string `json:"full_name"`
|
||||
Password string `json:"-"`
|
||||
Enabled bool `json:"enabled"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
// Do not serialize sensitive info.
|
||||
Password string `json:"-"`
|
||||
Generation uint `json:"-"`
|
||||
}
|
||||
|
||||
// JWTResponse holds the JWT token returned as a result of a
|
||||
|
|
|
|||
|
|
@ -1,11 +1,20 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/cloudbase/garm/auth"
|
||||
"github.com/cloudbase/garm/database/common"
|
||||
"github.com/cloudbase/garm/database/watcher"
|
||||
"github.com/cloudbase/garm/params"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -22,13 +31,34 @@ const (
|
|||
maxMessageSize = 1024
|
||||
)
|
||||
|
||||
func NewClient(conn *websocket.Conn, hub *Hub) (*Client, error) {
|
||||
type HandleWebsocketMessage func([]byte) error
|
||||
|
||||
func NewClient(ctx context.Context, conn *websocket.Conn) (*Client, error) {
|
||||
clientID := uuid.New()
|
||||
consumerID := fmt.Sprintf("ws-client-watcher-%s", clientID.String())
|
||||
|
||||
user := auth.UserID(ctx)
|
||||
if user == "" {
|
||||
return nil, fmt.Errorf("user not found in context")
|
||||
}
|
||||
generation := auth.PasswordGeneration(ctx)
|
||||
|
||||
consumer, err := watcher.RegisterConsumer(
|
||||
ctx, consumerID,
|
||||
watcher.WithUserIDFilter(user),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "registering consumer")
|
||||
}
|
||||
return &Client{
|
||||
id: clientID.String(),
|
||||
conn: conn,
|
||||
hub: hub,
|
||||
send: make(chan []byte, 100),
|
||||
id: clientID.String(),
|
||||
conn: conn,
|
||||
ctx: ctx,
|
||||
userID: user,
|
||||
passwordGeneration: generation,
|
||||
consumer: consumer,
|
||||
done: make(chan struct{}),
|
||||
send: make(chan []byte, 100),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -37,21 +67,84 @@ type Client struct {
|
|||
conn *websocket.Conn
|
||||
// Buffered channel of outbound messages.
|
||||
send chan []byte
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
|
||||
hub *Hub
|
||||
userID string
|
||||
passwordGeneration uint
|
||||
consumer common.Consumer
|
||||
|
||||
messageHandler HandleWebsocketMessage
|
||||
|
||||
running bool
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func (c *Client) Go() {
|
||||
func (c *Client) ID() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (c *Client) Stop() {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if !c.running {
|
||||
return
|
||||
}
|
||||
|
||||
c.running = false
|
||||
c.conn.Close()
|
||||
close(c.send)
|
||||
close(c.done)
|
||||
}
|
||||
|
||||
func (c *Client) Done() <-chan struct{} {
|
||||
return c.done
|
||||
}
|
||||
|
||||
func (c *Client) SetMessageHandler(handler HandleWebsocketMessage) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
c.messageHandler = handler
|
||||
}
|
||||
|
||||
func (c *Client) Start() error {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
c.running = true
|
||||
|
||||
go c.runWatcher()
|
||||
go c.clientReader()
|
||||
go c.clientWriter()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Write(msg []byte) (int, error) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if !c.running {
|
||||
return 0, fmt.Errorf("client is stopped")
|
||||
}
|
||||
|
||||
tmp := make([]byte, len(msg))
|
||||
copy(tmp, msg)
|
||||
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
return 0, fmt.Errorf("timed out sending message to client")
|
||||
case c.send <- tmp:
|
||||
}
|
||||
return len(tmp), nil
|
||||
}
|
||||
|
||||
// clientReader waits for options changes from the client. The client can at any time
|
||||
// change the log level and binary name it watches.
|
||||
func (c *Client) clientReader() {
|
||||
defer func() {
|
||||
c.hub.unregister <- c
|
||||
c.conn.Close()
|
||||
c.Stop()
|
||||
}()
|
||||
c.conn.SetReadLimit(maxMessageSize)
|
||||
if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
|
||||
|
|
@ -64,10 +157,19 @@ func (c *Client) clientReader() {
|
|||
return nil
|
||||
})
|
||||
for {
|
||||
mt, _, err := c.conn.ReadMessage()
|
||||
mt, data, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if IsErrorOfInterest(err) {
|
||||
slog.ErrorContext(c.ctx, "error reading websocket message", slog.Any("error", err))
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if c.messageHandler != nil {
|
||||
if err := c.messageHandler(data); err != nil {
|
||||
slog.ErrorContext(c.ctx, "error handling message", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
if mt == websocket.CloseMessage {
|
||||
break
|
||||
}
|
||||
|
|
@ -78,9 +180,14 @@ func (c *Client) clientReader() {
|
|||
func (c *Client) clientWriter() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
c.Stop()
|
||||
ticker.Stop()
|
||||
c.conn.Close()
|
||||
}()
|
||||
var authExpires time.Time
|
||||
expires := auth.Expires(c.ctx)
|
||||
if expires != nil {
|
||||
authExpires = *expires
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.send:
|
||||
|
|
@ -90,13 +197,17 @@ func (c *Client) clientWriter() {
|
|||
if !ok {
|
||||
// The hub closed the channel.
|
||||
if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil {
|
||||
slog.With(slog.Any("error", err)).Error("failed to write message")
|
||||
if IsErrorOfInterest(err) {
|
||||
slog.With(slog.Any("error", err)).Error("failed to write message")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
slog.With(slog.Any("error", err)).Error("error sending message")
|
||||
if IsErrorOfInterest(err) {
|
||||
slog.With(slog.Any("error", err)).Error("error sending message")
|
||||
}
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
|
|
@ -104,8 +215,81 @@ func (c *Client) clientWriter() {
|
|||
slog.With(slog.Any("error", err)).Error("failed to set write deadline")
|
||||
}
|
||||
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
if IsErrorOfInterest(err) {
|
||||
slog.With(slog.Any("error", err)).Error("failed to write ping message")
|
||||
}
|
||||
return
|
||||
}
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-time.After(time.Until(authExpires)):
|
||||
// Auth has expired
|
||||
slog.DebugContext(c.ctx, "auth expired, closing connection")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) runWatcher() {
|
||||
defer func() {
|
||||
c.Stop()
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case event, ok := <-c.consumer.Watch():
|
||||
if !ok {
|
||||
slog.InfoContext(c.ctx, "watcher closed")
|
||||
return
|
||||
}
|
||||
go func(event common.ChangePayload) {
|
||||
if event.EntityType != common.UserEntityType {
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := event.Payload.(params.User)
|
||||
if !ok {
|
||||
slog.ErrorContext(c.ctx, "failed to cast payload to user")
|
||||
return
|
||||
}
|
||||
|
||||
if user.ID != c.userID {
|
||||
return
|
||||
}
|
||||
|
||||
if user.Generation != c.passwordGeneration {
|
||||
slog.InfoContext(c.ctx, "password generation mismatch; closing connection")
|
||||
c.Stop()
|
||||
}
|
||||
}(event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func IsErrorOfInterest(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if errors.Is(err, websocket.ErrCloseSent) {
|
||||
return false
|
||||
}
|
||||
|
||||
if errors.Is(err, websocket.ErrBadHandshake) {
|
||||
return false
|
||||
}
|
||||
|
||||
asCloseErr, ok := err.(*websocket.CloseError)
|
||||
if ok {
|
||||
switch asCloseErr.Code {
|
||||
case websocket.CloseNormalClosure, websocket.CloseGoingAway,
|
||||
websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,19 +3,18 @@ package websocket
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewHub(ctx context.Context) *Hub {
|
||||
return &Hub{
|
||||
clients: map[string]*Client{},
|
||||
broadcast: make(chan []byte, 100),
|
||||
register: make(chan *Client, 100),
|
||||
unregister: make(chan *Client, 100),
|
||||
ctx: ctx,
|
||||
closed: make(chan struct{}),
|
||||
quit: make(chan struct{}),
|
||||
clients: map[string]*Client{},
|
||||
broadcast: make(chan []byte, 100),
|
||||
ctx: ctx,
|
||||
closed: make(chan struct{}),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -29,12 +28,6 @@ type Hub struct {
|
|||
// Inbound messages from the clients.
|
||||
broadcast chan []byte
|
||||
|
||||
// Register requests from the clients.
|
||||
register chan *Client
|
||||
|
||||
// Unregister requests from clients.
|
||||
unregister chan *Client
|
||||
|
||||
mux sync.Mutex
|
||||
once sync.Once
|
||||
}
|
||||
|
|
@ -49,22 +42,6 @@ func (h *Hub) run() {
|
|||
return
|
||||
case <-h.ctx.Done():
|
||||
return
|
||||
case client := <-h.register:
|
||||
if client != nil {
|
||||
h.mux.Lock()
|
||||
h.clients[client.id] = client
|
||||
h.mux.Unlock()
|
||||
}
|
||||
case client := <-h.unregister:
|
||||
if client != nil {
|
||||
h.mux.Lock()
|
||||
if _, ok := h.clients[client.id]; ok {
|
||||
client.conn.Close()
|
||||
close(client.send)
|
||||
delete(h.clients, client.id)
|
||||
}
|
||||
h.mux.Unlock()
|
||||
}
|
||||
case message := <-h.broadcast:
|
||||
staleClients := []string{}
|
||||
for id, client := range h.clients {
|
||||
|
|
@ -73,9 +50,7 @@ func (h *Hub) run() {
|
|||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case client.send <- message:
|
||||
case <-time.After(5 * time.Second):
|
||||
if _, err := client.Write(message); err != nil {
|
||||
staleClients = append(staleClients, id)
|
||||
}
|
||||
}
|
||||
|
|
@ -97,7 +72,35 @@ func (h *Hub) run() {
|
|||
}
|
||||
|
||||
func (h *Hub) Register(client *Client) error {
|
||||
h.register <- client
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
cli, ok := h.clients[client.ID()]
|
||||
if ok {
|
||||
if cli != nil {
|
||||
return fmt.Errorf("client already registered")
|
||||
}
|
||||
}
|
||||
slog.DebugContext(h.ctx, "registering client", "client_id", client.ID())
|
||||
h.clients[client.id] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hub) Unregister(client *Client) error {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
cli, ok := h.clients[client.ID()]
|
||||
if ok {
|
||||
cli.Stop()
|
||||
slog.DebugContext(h.ctx, "unregistering client", "client_id", cli.ID())
|
||||
delete(h.clients, cli.ID())
|
||||
slog.DebugContext(h.ctx, "current client count", "count", len(h.clients))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue