Add some basic auth

This commit is contained in:
Gabriel Adrian Samfira 2022-04-28 16:13:20 +00:00
parent 66b46ae0ab
commit 0883fcd5cd
24 changed files with 1687 additions and 674 deletions

View file

@ -6,12 +6,12 @@ import (
)
type Store interface {
CreateRepository(ctx context.Context, owner, name, webhookSecret string) (params.Repository, error)
CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string) (params.Repository, error)
GetRepository(ctx context.Context, owner, name string) (params.Repository, error)
ListRepositories(ctx context.Context) ([]params.Repository, error)
DeleteRepository(ctx context.Context, owner, name string) error
CreateOrganization(ctx context.Context, name, webhookSecret string) (params.Organization, error)
CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string) (params.Organization, error)
GetOrganization(ctx context.Context, name string) (params.Organization, error)
ListOrganizations(ctx context.Context) ([]params.Organization, error)
DeleteOrganization(ctx context.Context, name string) error
@ -41,4 +41,12 @@ type Store interface {
// GetInstance(ctx context.Context, poolID string, instanceID string) (params.Instance, error)
GetInstanceByName(ctx context.Context, poolID string, instanceName string) (params.Instance, error)
CreateUser(ctx context.Context, user params.NewUserParams) (params.User, error)
GetUser(ctx context.Context, user string) (params.User, error)
UpdateUser(ctx context.Context, user string, param params.UpdateUserParams) (params.User, error)
HasAdminUser(ctx context.Context) bool
ControllerInfo() (params.ControllerInfo, error)
InitController() (params.ControllerInfo, error)
}

View file

@ -5,6 +5,7 @@ import (
"runner-manager/runner/providers/common"
"time"
"github.com/pkg/errors"
uuid "github.com/satori/go.uuid"
"gorm.io/gorm"
)
@ -21,7 +22,11 @@ func (b *Base) BeforeCreate(tx *gorm.DB) error {
if b.ID != emptyId {
return nil
}
b.ID = uuid.NewV4()
newID, err := uuid.NewV4()
if err != nil {
return errors.Wrap(err, "generating id")
}
b.ID = newID
return nil
}
@ -57,18 +62,20 @@ type Pool struct {
type Repository struct {
Base
Owner string `gorm:"index:idx_owner,unique"`
Name string `gorm:"index:idx_owner,unique"`
WebhookSecret []byte
Pools []Pool `gorm:"foreignKey:RepoID"`
CredentialsName string
Owner string `gorm:"index:idx_owner,unique"`
Name string `gorm:"index:idx_owner,unique"`
WebhookSecret []byte
Pools []Pool `gorm:"foreignKey:RepoID"`
}
type Organization struct {
Base
Name string `gorm:"uniqueIndex"`
WebhookSecret []byte
Pools []Pool `gorm:"foreignKey:OrgID"`
CredentialsName string
Name string `gorm:"uniqueIndex"`
WebhookSecret []byte
Pools []Pool `gorm:"foreignKey:OrgID"`
}
type Address struct {
@ -95,3 +102,20 @@ type Instance struct {
PoolID uuid.UUID
Pool Pool `gorm:"foreignKey:PoolID"`
}
type User struct {
Base
Username string `gorm:"uniqueIndex;varchar(64)"`
FullName string `gorm:"type:varchar(254)"`
Email string `gorm:"type:varchar(254);unique;index:idx_email"`
Password string `gorm:"type:varchar(60)"`
IsAdmin bool
Enabled bool
}
type ControllerInfo struct {
Base
ControllerID uuid.UUID
}

View file

@ -9,8 +9,8 @@ import (
"runner-manager/params"
"runner-manager/util"
"github.com/pborman/uuid"
"github.com/pkg/errors"
uuid "github.com/satori/go.uuid"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
@ -41,12 +41,13 @@ type sqlDatabase struct {
func (s *sqlDatabase) migrateDB() error {
if err := s.conn.AutoMigrate(
&Tag{},
// &Runner{},
&Pool{},
&Repository{},
&Organization{},
&Address{},
&Instance{},
&ControllerInfo{},
&User{},
); err != nil {
return err
}
@ -85,10 +86,11 @@ func (s *sqlDatabase) sqlToCommonPool(pool Pool) params.Pool {
func (s *sqlDatabase) sqlToCommonRepository(repo Repository) params.Repository {
ret := params.Repository{
ID: repo.ID.String(),
Name: repo.Name,
Owner: repo.Owner,
Pools: make([]params.Pool, len(repo.Pools)),
ID: repo.ID.String(),
Name: repo.Name,
Owner: repo.Owner,
CredentialsName: repo.CredentialsName,
Pools: make([]params.Pool, len(repo.Pools)),
}
for idx, pool := range repo.Pools {
@ -100,15 +102,16 @@ func (s *sqlDatabase) sqlToCommonRepository(repo Repository) params.Repository {
func (s *sqlDatabase) sqlToCommonOrganization(org Organization) params.Organization {
ret := params.Organization{
ID: org.ID.String(),
Name: org.Name,
Pools: make([]params.Pool, len(org.Pools)),
ID: org.ID.String(),
Name: org.Name,
CredentialsName: org.CredentialsName,
Pools: make([]params.Pool, len(org.Pools)),
}
return ret
}
func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, webhookSecret string) (params.Repository, error) {
func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, credentialsName, webhookSecret string) (params.Repository, error) {
secret := []byte{}
var err error
if webhookSecret != "" {
@ -118,9 +121,10 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, webhook
}
}
newRepo := Repository{
Name: name,
Owner: owner,
WebhookSecret: secret,
Name: name,
Owner: owner,
WebhookSecret: secret,
CredentialsName: credentialsName,
}
q := s.conn.Create(&newRepo)
@ -134,12 +138,18 @@ func (s *sqlDatabase) CreateRepository(ctx context.Context, owner, name, webhook
return param, nil
}
func (s *sqlDatabase) getRepo(ctx context.Context, owner, name string) (Repository, error) {
func (s *sqlDatabase) getRepo(ctx context.Context, owner, name string, preloadAll bool) (Repository, error) {
var repo Repository
q := s.conn.Preload(clause.Associations).
Where("name = ? and owner = ?", name, owner).
q := s.conn.Where("name = ? and owner = ?", name, owner).
First(&repo)
if preloadAll {
q = q.Preload(clause.Associations)
}
q = q.First(&repo)
if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
return Repository{}, runnerErrors.ErrNotFound
@ -150,12 +160,12 @@ func (s *sqlDatabase) getRepo(ctx context.Context, owner, name string) (Reposito
}
func (s *sqlDatabase) getRepoByID(ctx context.Context, id string) (Repository, error) {
u := uuid.Parse(id)
if u == nil {
return Repository{}, errors.Wrap(runnerErrors.NewBadRequestError(""), "parsing id")
u, err := uuid.FromString(id)
if err != nil {
return Repository{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
var repo Repository
q := s.conn.Preload(clause.Associations).
q := s.conn.
Where("id = ?", u).
First(&repo)
@ -169,7 +179,7 @@ func (s *sqlDatabase) getRepoByID(ctx context.Context, id string) (Repository, e
}
func (s *sqlDatabase) GetRepository(ctx context.Context, owner, name string) (params.Repository, error) {
repo, err := s.getRepo(ctx, owner, name)
repo, err := s.getRepo(ctx, owner, name, false)
if err != nil {
return params.Repository{}, errors.Wrap(err, "fetching repo")
}
@ -200,7 +210,7 @@ func (s *sqlDatabase) ListRepositories(ctx context.Context) ([]params.Repository
}
func (s *sqlDatabase) DeleteRepository(ctx context.Context, owner, name string) error {
repo, err := s.getRepo(ctx, owner, name)
repo, err := s.getRepo(ctx, owner, name, false)
if err != nil {
if err == runnerErrors.ErrNotFound {
return nil
@ -216,7 +226,7 @@ func (s *sqlDatabase) DeleteRepository(ctx context.Context, owner, name string)
return nil
}
func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, webhookSecret string) (params.Organization, error) {
func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, credentialsName, webhookSecret string) (params.Organization, error) {
secret := []byte{}
var err error
if webhookSecret != "" {
@ -226,8 +236,9 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, webhookSecre
}
}
newOrg := Organization{
Name: name,
WebhookSecret: secret,
Name: name,
WebhookSecret: secret,
CredentialsName: credentialsName,
}
q := s.conn.Create(&newOrg)
@ -241,9 +252,15 @@ func (s *sqlDatabase) CreateOrganization(ctx context.Context, name, webhookSecre
return param, nil
}
func (s *sqlDatabase) getOrg(ctx context.Context, name string) (Organization, error) {
func (s *sqlDatabase) getOrg(ctx context.Context, name string, preloadAll bool) (Organization, error) {
var org Organization
q := s.conn.Preload(clause.Associations).Where("name = ?", name).First(&org)
q := s.conn.Where("name = ?", name)
if preloadAll {
q = q.Preload(clause.Associations)
}
q = q.First(&org)
if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
return Organization{}, runnerErrors.ErrNotFound
@ -253,13 +270,19 @@ func (s *sqlDatabase) getOrg(ctx context.Context, name string) (Organization, er
return org, nil
}
func (s *sqlDatabase) getOrgByID(ctx context.Context, id string) (Organization, error) {
u := uuid.Parse(id)
if u == nil {
return Organization{}, errors.Wrap(runnerErrors.NewBadRequestError(""), "parsing id")
func (s *sqlDatabase) getOrgByID(ctx context.Context, id string, preloadAll bool) (Organization, error) {
u, err := uuid.FromString(id)
if err != nil {
return Organization{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
q := s.conn.Where("id = ?", u)
if preloadAll {
q = q.Preload(clause.Associations)
}
var org Organization
q := s.conn.Preload(clause.Associations).Where("id = ?", u).First(&org)
q = q.First(&org)
if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
return Organization{}, runnerErrors.ErrNotFound
@ -270,7 +293,7 @@ func (s *sqlDatabase) getOrgByID(ctx context.Context, id string) (Organization,
}
func (s *sqlDatabase) GetOrganization(ctx context.Context, name string) (params.Organization, error) {
org, err := s.getOrg(ctx, name)
org, err := s.getOrg(ctx, name, false)
if err != nil {
return params.Organization{}, errors.Wrap(err, "fetching repo")
}
@ -301,7 +324,7 @@ func (s *sqlDatabase) ListOrganizations(ctx context.Context) ([]params.Organizat
}
func (s *sqlDatabase) DeleteOrganization(ctx context.Context, name string) error {
org, err := s.getOrg(ctx, name)
org, err := s.getOrg(ctx, name, false)
if err != nil {
if err == runnerErrors.ErrNotFound {
return nil
@ -377,11 +400,7 @@ func (s *sqlDatabase) CreateRepositoryPool(ctx context.Context, repoId string, p
s.conn.Model(&newPool).Association("Tags").Append(&tt)
}
repo, err = s.getRepoByID(ctx, repoId)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching repo")
}
return s.sqlToCommonPool(repo.Pools[0]), nil
return s.sqlToCommonPool(newPool), nil
}
func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgId string, param params.CreatePoolParams) (params.Pool, error) {
@ -389,7 +408,7 @@ func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgId string,
return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified")
}
org, err := s.getOrgByID(ctx, orgId)
org, err := s.getOrgByID(ctx, orgId, false)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching org")
}
@ -422,14 +441,53 @@ func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgId string,
return s.sqlToCommonPool(newPool), nil
}
func (s *sqlDatabase) getRepoPools(ctx context.Context, repoID string, preloadAll bool) ([]Pool, error) {
repo, err := s.getRepoByID(ctx, repoID)
if err != nil {
return nil, errors.Wrap(err, "fetching repo")
}
var pools []Pool
q := s.conn.Model(&repo)
if preloadAll {
q = q.Preload(clause.Associations)
}
err = q.Association("Pools").Find(&pools)
if err != nil {
return nil, errors.Wrap(err, "fetching pool")
}
return pools, nil
}
func (s *sqlDatabase) getOrgPools(ctx context.Context, orgID string, preloadAll bool) ([]Pool, error) {
org, err := s.getOrgByID(ctx, orgID, preloadAll)
if err != nil {
return nil, errors.Wrap(err, "fetching repo")
}
var pools []Pool
q := s.conn.Model(&org)
if preloadAll {
q = q.Preload(clause.Associations)
}
err = q.Association("Pools").Find(&pools)
if err != nil {
return nil, errors.Wrap(err, "fetching pool")
}
return pools, nil
}
func (s *sqlDatabase) getRepoPool(ctx context.Context, repoID, poolID string) (Pool, error) {
repo, err := s.getRepoByID(ctx, repoID)
if err != nil {
return Pool{}, errors.Wrap(err, "fetching repo")
}
u := uuid.Parse(poolID)
if u == nil {
return Pool{}, fmt.Errorf("invalid pool id")
u, err := uuid.FromString(poolID)
if err != nil {
return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
var pool []Pool
err = s.conn.Model(&repo).Association("Pools").Find(&pool, "id = ?", u)
@ -451,22 +509,24 @@ func (s *sqlDatabase) GetRepositoryPool(ctx context.Context, repoID, poolID stri
return s.sqlToCommonPool(pool), nil
}
func (s *sqlDatabase) getOrgPool(ctx context.Context, orgID, poolID string) (Pool, error) {
org, err := s.getOrgByID(ctx, orgID)
func (s *sqlDatabase) getOrgPool(ctx context.Context, orgID, poolID string, preloadAll bool) (Pool, error) {
org, err := s.getOrgByID(ctx, orgID, preloadAll)
if err != nil {
return Pool{}, errors.Wrap(err, "fetching repo")
}
u := uuid.Parse(poolID)
if u == nil {
return Pool{}, fmt.Errorf("invalid pool id")
u, err := uuid.FromString(poolID)
if err != nil {
return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
var pool []Pool
err = s.conn.Model(&org).
Association(clause.Associations).
Find(&pool, "id = ?", u)
q := s.conn.Model(&org)
if preloadAll {
q = q.Preload(clause.Associations)
}
q = q.Find(&pool, "id = ?", u)
if err != nil {
return Pool{}, errors.Wrap(err, "fetching pool")
if q.Error != nil {
return Pool{}, errors.Wrap(q.Error, "fetching pool")
}
if len(pool) == 0 {
return Pool{}, runnerErrors.ErrNotFound
@ -475,15 +535,18 @@ func (s *sqlDatabase) getOrgPool(ctx context.Context, orgID, poolID string) (Poo
return pool[0], nil
}
func (s *sqlDatabase) getPoolByID(ctx context.Context, poolID string) (Pool, error) {
u := uuid.Parse(poolID)
if u == nil {
func (s *sqlDatabase) getPoolByID(ctx context.Context, poolID string, preloadAll bool) (Pool, error) {
u, err := uuid.FromString(poolID)
if err != nil {
return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
var pool Pool
q := s.conn.Model(&Pool{}).
Preload(clause.Associations).
Where("id = ?", u).First(&pool)
q := s.conn.Model(&Pool{})
if preloadAll {
q = q.Preload(clause.Associations)
}
q = q.Where("id = ?", u).First(&pool)
if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
@ -495,7 +558,7 @@ func (s *sqlDatabase) getPoolByID(ctx context.Context, poolID string) (Pool, err
}
func (s *sqlDatabase) GetOrganizationPool(ctx context.Context, orgID, poolID string) (params.Pool, error) {
pool, err := s.getOrgPool(ctx, orgID, poolID)
pool, err := s.getOrgPool(ctx, orgID, poolID, false)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
@ -518,7 +581,7 @@ func (s *sqlDatabase) DeleteRepositoryPool(ctx context.Context, repoID, poolID s
}
func (s *sqlDatabase) DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error {
pool, err := s.getOrgPool(ctx, orgID, poolID)
pool, err := s.getOrgPool(ctx, orgID, poolID, false)
if err != nil {
if errors.Is(err, runnerErrors.ErrNotFound) {
return nil
@ -536,9 +599,9 @@ func (s *sqlDatabase) findPoolByTags(id, poolType string, tags []string) (params
if len(tags) == 0 {
return params.Pool{}, runnerErrors.NewBadRequestError("missing tags")
}
u := uuid.Parse(id)
if u == nil {
return params.Pool{}, errors.Wrap(runnerErrors.NewBadRequestError(""), "parsing id")
u, err := uuid.FromString(id)
if err != nil {
return params.Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
var pool Pool
@ -548,7 +611,7 @@ func (s *sqlDatabase) findPoolByTags(id, poolType string, tags []string) (params
Group("pools.id").
Preload("Tags").
Having("count(1) = ?", len(tags)).
Where(where, tags, id).First(&pool)
Where(where, tags, u).First(&pool)
if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
@ -605,7 +668,7 @@ func (s *sqlDatabase) sqlToParamsInstance(instance Instance) params.Instance {
}
func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) {
pool, err := s.getPoolByID(ctx, param.Pool)
pool, err := s.getPoolByID(ctx, param.Pool, false)
if err != nil {
return params.Instance{}, errors.Wrap(err, "fetching pool")
}
@ -631,8 +694,8 @@ func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param p
// }
func (s *sqlDatabase) getInstanceByID(ctx context.Context, instanceID string) (Instance, error) {
u := uuid.Parse(instanceID)
if u == nil {
u, err := uuid.FromString(instanceID)
if err != nil {
return Instance{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
var instance Instance
@ -647,7 +710,7 @@ func (s *sqlDatabase) getInstanceByID(ctx context.Context, instanceID string) (I
}
func (s *sqlDatabase) getInstanceByName(ctx context.Context, poolID string, instanceName string) (Instance, error) {
pool, err := s.getPoolByID(ctx, poolID)
pool, err := s.getPoolByID(ctx, poolID, false)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return Instance{}, errors.Wrap(runnerErrors.ErrNotFound, "fetching instance")
@ -738,7 +801,7 @@ func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceID string, par
}
func (s *sqlDatabase) ListInstances(ctx context.Context, poolID string) ([]params.Instance, error) {
pool, err := s.getPoolByID(ctx, poolID)
pool, err := s.getPoolByID(ctx, poolID, true)
if err != nil {
return nil, errors.Wrap(err, "fetching pool")
}
@ -751,13 +814,13 @@ func (s *sqlDatabase) ListInstances(ctx context.Context, poolID string) ([]param
}
func (s *sqlDatabase) ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) {
repo, err := s.getRepoByID(ctx, repoID)
pools, err := s.getRepoPools(ctx, repoID, true)
if err != nil {
return nil, errors.Wrap(err, "fetching repo")
}
ret := []params.Instance{}
for _, pool := range repo.Pools {
for _, pool := range pools {
for _, instance := range pool.Instances {
ret = append(ret, s.sqlToParamsInstance(instance))
}
@ -766,7 +829,7 @@ func (s *sqlDatabase) ListRepoInstances(ctx context.Context, repoID string) ([]p
}
func (s *sqlDatabase) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) {
org, err := s.getOrgByID(ctx, orgID)
org, err := s.getOrgByID(ctx, orgID, true)
if err != nil {
return nil, errors.Wrap(err, "fetching org")
}
@ -779,10 +842,208 @@ func (s *sqlDatabase) ListOrgInstances(ctx context.Context, orgID string) ([]par
return ret, nil
}
func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (params.Pool, error) {
if param.Enabled != nil && pool.Enabled != *param.Enabled {
pool.Enabled = *param.Enabled
}
if param.Flavor != "" {
pool.Flavor = param.Flavor
}
if param.Image != "" {
pool.Image = param.Image
}
if param.MaxRunners != nil {
pool.MaxRunners = *param.MaxRunners
}
if param.MinIdleRunners != nil {
pool.MinIdleRunners = *param.MinIdleRunners
}
if param.OSArch != "" {
pool.OSArch = param.OSArch
}
if param.OSType != "" {
pool.OSType = param.OSType
}
if q := s.conn.Save(&pool); q.Error != nil {
return params.Pool{}, errors.Wrap(q.Error, "saving database entry")
}
if len(param.Tags) > 0 {
tags := make([]Tag, len(param.Tags))
for idx, t := range param.Tags {
tags[idx] = Tag{
Name: t.Name,
}
}
if err := s.conn.Model(&pool).Association("Tags").Replace(&tags); err != nil {
return params.Pool{}, errors.Wrap(err, "replacing tags")
}
}
return s.sqlToCommonPool(pool), nil
}
func (s *sqlDatabase) UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) {
return params.Pool{}, nil
pool, err := s.getRepoPool(ctx, repoID, poolID)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
return s.updatePool(pool, param)
}
func (s *sqlDatabase) UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) {
return params.Pool{}, nil
pool, err := s.getOrgPool(ctx, orgID, poolID, true)
if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool")
}
return s.updatePool(pool, param)
}
func (s *sqlDatabase) sqlToParamsUser(user User) params.User {
return params.User{
ID: user.ID.String(),
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
Email: user.Email,
Username: user.Username,
FullName: user.FullName,
Password: user.Password,
Enabled: user.Enabled,
IsAdmin: user.IsAdmin,
}
}
func (s *sqlDatabase) getUserByUsernameOrEmail(user string) (User, error) {
field := "username"
if util.IsValidEmail(user) {
field = "email"
}
query := fmt.Sprintf("%s = ?", field)
var dbUser User
q := s.conn.Model(&User{}).Where(query, user).First(&dbUser)
if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
return User{}, runnerErrors.ErrNotFound
}
return User{}, errors.Wrap(q.Error, "fetching user")
}
return dbUser, nil
}
func (s *sqlDatabase) CreateUser(ctx context.Context, user params.NewUserParams) (params.User, error) {
if user.Username == "" || user.Email == "" {
return params.User{}, runnerErrors.NewBadRequestError("missing username or email")
}
if _, err := s.getUserByUsernameOrEmail(user.Username); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) {
return params.User{}, runnerErrors.NewConflictError("username already exists")
}
if _, err := s.getUserByUsernameOrEmail(user.Email); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) {
return params.User{}, runnerErrors.NewConflictError("email already exists")
}
newUser := User{
Username: user.Username,
Password: user.Password,
FullName: user.FullName,
Enabled: user.Enabled,
Email: user.Email,
IsAdmin: user.IsAdmin,
}
q := s.conn.Save(&newUser)
if q.Error != nil {
return params.User{}, errors.Wrap(q.Error, "creating user")
}
return params.User{}, nil
}
func (s *sqlDatabase) HasAdminUser(ctx context.Context) bool {
var user User
q := s.conn.Model(&User{}).Where("is_admin = ?", true).First(&user)
if q.Error != nil {
return false
}
return true
}
func (s *sqlDatabase) GetUser(ctx context.Context, user string) (params.User, error) {
dbUser, err := s.getUserByUsernameOrEmail(user)
if err != nil {
return params.User{}, errors.Wrap(err, "fetching user")
}
return s.sqlToParamsUser(dbUser), nil
}
func (s *sqlDatabase) UpdateUser(ctx context.Context, user string, param params.UpdateUserParams) (params.User, error) {
dbUser, err := s.getUserByUsernameOrEmail(user)
if err != nil {
return params.User{}, errors.Wrap(err, "fetching user")
}
if param.FullName != "" {
dbUser.FullName = param.FullName
}
if param.Enabled != nil {
dbUser.Enabled = *param.Enabled
}
if param.Password != "" {
dbUser.Password = param.Password
}
if q := s.conn.Save(&dbUser); q.Error != nil {
return params.User{}, errors.Wrap(q.Error, "saving user")
}
return s.sqlToParamsUser(dbUser), nil
}
func (s *sqlDatabase) ControllerInfo() (params.ControllerInfo, error) {
var info ControllerInfo
q := s.conn.Model(&ControllerInfo{}).First(&info)
if q.Error != nil {
if errors.Is(q.Error, gorm.ErrRecordNotFound) {
return params.ControllerInfo{}, errors.Wrap(runnerErrors.ErrNotFound, "fetching controller info")
}
return params.ControllerInfo{}, errors.Wrap(q.Error, "fetching controller info")
}
return params.ControllerInfo{
ControllerID: info.ControllerID,
}, nil
}
func (s *sqlDatabase) InitController() (params.ControllerInfo, error) {
if _, err := s.ControllerInfo(); err == nil {
return params.ControllerInfo{}, runnerErrors.NewConflictError("controller already initialized")
}
newID, err := uuid.NewV4()
if err != nil {
return params.ControllerInfo{}, errors.Wrap(err, "generating UUID")
}
newInfo := ControllerInfo{
ControllerID: newID,
}
q := s.conn.Save(&newInfo)
if q.Error != nil {
return params.ControllerInfo{}, errors.Wrap(q.Error, "saving controller info")
}
return params.ControllerInfo{
ControllerID: newInfo.ControllerID,
}, nil
}