Merge pull request #241 from gabriel-samfira/some-cleanup

Slightly simplify code
This commit is contained in:
Gabriel 2024-03-30 20:35:24 +02:00 committed by GitHub
commit 9525e013da
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 42 additions and 146 deletions

View file

@ -74,7 +74,7 @@ type UserStore interface {
type InstanceStore interface {
CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error)
DeleteInstance(ctx context.Context, poolID string, instanceName string) error
UpdateInstance(ctx context.Context, instanceID string, param params.UpdateInstanceParams) (params.Instance, error)
UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error)
// Probably a bad idea without some king of filter or at least pagination
//
@ -83,8 +83,7 @@ type InstanceStore interface {
ListAllInstances(ctx context.Context) ([]params.Instance, error)
GetInstanceByName(ctx context.Context, instanceName string) (params.Instance, error)
AddInstanceEvent(ctx context.Context, instanceID string, event params.EventType, eventLevel params.EventLevel, eventMessage string) error
ListInstanceEvents(ctx context.Context, instanceID string, eventType params.EventType, eventLevel params.EventLevel) ([]params.StatusMessage, error)
AddInstanceEvent(ctx context.Context, instanceName string, event params.EventType, eventLevel params.EventLevel, eventMessage string) error
}
type JobsStore interface {

View file

@ -14,9 +14,9 @@ type Store struct {
mock.Mock
}
// AddInstanceEvent provides a mock function with given fields: ctx, instanceID, event, eventLevel, eventMessage
func (_m *Store) AddInstanceEvent(ctx context.Context, instanceID string, event params.EventType, eventLevel params.EventLevel, eventMessage string) error {
ret := _m.Called(ctx, instanceID, event, eventLevel, eventMessage)
// AddInstanceEvent provides a mock function with given fields: ctx, instanceName, event, eventLevel, eventMessage
func (_m *Store) AddInstanceEvent(ctx context.Context, instanceName string, event params.EventType, eventLevel params.EventLevel, eventMessage string) error {
ret := _m.Called(ctx, instanceName, event, eventLevel, eventMessage)
if len(ret) == 0 {
panic("no return value specified for AddInstanceEvent")
@ -24,7 +24,7 @@ func (_m *Store) AddInstanceEvent(ctx context.Context, instanceID string, event
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, params.EventType, params.EventLevel, string) error); ok {
r0 = rf(ctx, instanceID, event, eventLevel, eventMessage)
r0 = rf(ctx, instanceName, event, eventLevel, eventMessage)
} else {
r0 = ret.Error(0)
}
@ -1068,36 +1068,6 @@ func (_m *Store) ListEntityPools(ctx context.Context, entity params.GithubEntity
return r0, r1
}
// ListInstanceEvents provides a mock function with given fields: ctx, instanceID, eventType, eventLevel
func (_m *Store) ListInstanceEvents(ctx context.Context, instanceID string, eventType params.EventType, eventLevel params.EventLevel) ([]params.StatusMessage, error) {
ret := _m.Called(ctx, instanceID, eventType, eventLevel)
if len(ret) == 0 {
panic("no return value specified for ListInstanceEvents")
}
var r0 []params.StatusMessage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, params.EventType, params.EventLevel) ([]params.StatusMessage, error)); ok {
return rf(ctx, instanceID, eventType, eventLevel)
}
if rf, ok := ret.Get(0).(func(context.Context, string, params.EventType, params.EventLevel) []params.StatusMessage); ok {
r0 = rf(ctx, instanceID, eventType, eventLevel)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]params.StatusMessage)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string, params.EventType, params.EventLevel) error); ok {
r1 = rf(ctx, instanceID, eventType, eventLevel)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListJobsByStatus provides a mock function with given fields: ctx, status
func (_m *Store) ListJobsByStatus(ctx context.Context, status params.JobStatus) ([]params.Job, error) {
ret := _m.Called(ctx, status)
@ -1338,9 +1308,9 @@ func (_m *Store) UpdateEntityPool(ctx context.Context, entity params.GithubEntit
return r0, r1
}
// UpdateInstance provides a mock function with given fields: ctx, instanceID, param
func (_m *Store) UpdateInstance(ctx context.Context, instanceID string, param params.UpdateInstanceParams) (params.Instance, error) {
ret := _m.Called(ctx, instanceID, param)
// UpdateInstance provides a mock function with given fields: ctx, instanceName, param
func (_m *Store) UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error) {
ret := _m.Called(ctx, instanceName, param)
if len(ret) == 0 {
panic("no return value specified for UpdateInstance")
@ -1349,16 +1319,16 @@ func (_m *Store) UpdateInstance(ctx context.Context, instanceID string, param pa
var r0 params.Instance
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, params.UpdateInstanceParams) (params.Instance, error)); ok {
return rf(ctx, instanceID, param)
return rf(ctx, instanceName, param)
}
if rf, ok := ret.Get(0).(func(context.Context, string, params.UpdateInstanceParams) params.Instance); ok {
r0 = rf(ctx, instanceID, param)
r0 = rf(ctx, instanceName, param)
} else {
r0 = ret.Get(0).(params.Instance)
}
if rf, ok := ret.Get(1).(func(context.Context, string, params.UpdateInstanceParams) error); ok {
r1 = rf(ctx, instanceID, param)
r1 = rf(ctx, instanceName, param)
} else {
r1 = ret.Error(1)
}

View file

@ -92,22 +92,6 @@ func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param par
return s.sqlToParamsInstance(newInstance)
}
func (s *sqlDatabase) getInstanceByID(_ context.Context, instanceID string) (Instance, error) {
u, err := uuid.Parse(instanceID)
if err != nil {
return Instance{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
var instance Instance
q := s.conn.Model(&Instance{}).
Preload(clause.Associations).
Where("id = ?", u).
First(&instance)
if q.Error != nil {
return Instance{}, errors.Wrap(q.Error, "fetching instance")
}
return instance, nil
}
func (s *sqlDatabase) getPoolInstanceByName(poolID string, instanceName string) (Instance, error) {
pool, err := s.getPoolByID(s.conn, poolID)
if err != nil {
@ -184,34 +168,8 @@ func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceN
return nil
}
func (s *sqlDatabase) ListInstanceEvents(_ context.Context, instanceID string, eventType params.EventType, eventLevel params.EventLevel) ([]params.StatusMessage, error) {
var events []InstanceStatusUpdate
query := s.conn.Model(&InstanceStatusUpdate{}).Where("instance_id = ?", instanceID)
if eventLevel != "" {
query = query.Where("event_level = ?", eventLevel)
}
if eventType != "" {
query = query.Where("event_type = ?", eventType)
}
if result := query.Find(&events); result.Error != nil {
return nil, errors.Wrap(result.Error, "fetching events")
}
eventParams := make([]params.StatusMessage, len(events))
for idx, val := range events {
eventParams[idx] = params.StatusMessage{
Message: val.Message,
EventType: val.EventType,
EventLevel: val.EventLevel,
}
}
return eventParams, nil
}
func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceID string, event params.EventType, eventLevel params.EventLevel, statusMessage string) error {
instance, err := s.getInstanceByID(ctx, instanceID)
func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceName string, event params.EventType, eventLevel params.EventLevel, statusMessage string) error {
instance, err := s.getInstanceByName(ctx, instanceName)
if err != nil {
return errors.Wrap(err, "updating instance")
}
@ -228,8 +186,8 @@ func (s *sqlDatabase) AddInstanceEvent(ctx context.Context, instanceID string, e
return nil
}
func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceID string, param params.UpdateInstanceParams) (params.Instance, error) {
instance, err := s.getInstanceByID(ctx, instanceID)
func (s *sqlDatabase) UpdateInstance(ctx context.Context, instanceName string, param params.UpdateInstanceParams) (params.Instance, error) {
instance, err := s.getInstanceByName(ctx, instanceName)
if err != nil {
return params.Instance{}, errors.Wrap(err, "updating instance")
}

View file

@ -357,7 +357,7 @@ func (s *InstancesTestSuite) TestAddInstanceEvent() {
storeInstance := s.Fixtures.Instances[0]
statusMsg := "test-status-message"
err := s.Store.AddInstanceEvent(context.Background(), storeInstance.ID, params.StatusEvent, params.EventInfo, statusMsg)
err := s.Store.AddInstanceEvent(context.Background(), storeInstance.Name, params.StatusEvent, params.EventInfo, statusMsg)
s.Require().Nil(err)
instance, err := s.Store.GetInstanceByName(context.Background(), storeInstance.Name)
@ -368,19 +368,13 @@ func (s *InstancesTestSuite) TestAddInstanceEvent() {
s.Require().Equal(statusMsg, instance.StatusMessages[0].Message)
}
func (s *InstancesTestSuite) TestAddInstanceEventInvalidPoolID() {
err := s.Store.AddInstanceEvent(context.Background(), "dummy-id", params.StatusEvent, params.EventInfo, "dummy-message")
s.Require().Equal("updating instance: parsing id: invalid request", err.Error())
}
func (s *InstancesTestSuite) TestAddInstanceEventDBUpdateErr() {
instance := s.Fixtures.Instances[0]
statusMsg := "test-status-message"
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instances` WHERE id = ? AND `instances`.`deleted_at` IS NULL ORDER BY `instances`.`id` LIMIT 1")).
WithArgs(instance.ID).
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instances` WHERE name = ? AND `instances`.`deleted_at` IS NULL ORDER BY `instances`.`id` LIMIT 1")).
WithArgs(instance.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(instance.ID))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `addresses` WHERE `addresses`.`instance_id` = ? AND `addresses`.`deleted_at` IS NULL")).
@ -404,15 +398,15 @@ func (s *InstancesTestSuite) TestAddInstanceEventDBUpdateErr() {
WillReturnError(fmt.Errorf("mocked add status message error"))
s.Fixtures.SQLMock.ExpectRollback()
err := s.StoreSQLMocked.AddInstanceEvent(context.Background(), instance.ID, params.StatusEvent, params.EventInfo, statusMsg)
err := s.StoreSQLMocked.AddInstanceEvent(context.Background(), instance.Name, params.StatusEvent, params.EventInfo, statusMsg)
s.assertSQLMockExpectations()
s.Require().NotNil(err)
s.Require().Equal("adding status message: mocked add status message error", err.Error())
s.assertSQLMockExpectations()
}
func (s *InstancesTestSuite) TestUpdateInstance() {
instance, err := s.Store.UpdateInstance(context.Background(), s.Fixtures.Instances[0].ID, s.Fixtures.UpdateInstanceParams)
instance, err := s.Store.UpdateInstance(context.Background(), s.Fixtures.Instances[0].Name, s.Fixtures.UpdateInstanceParams)
s.Require().Nil(err)
s.Require().Equal(s.Fixtures.UpdateInstanceParams.ProviderID, instance.ProviderID)
@ -424,18 +418,12 @@ func (s *InstancesTestSuite) TestUpdateInstance() {
s.Require().Equal(s.Fixtures.UpdateInstanceParams.CreateAttempt, instance.CreateAttempt)
}
func (s *InstancesTestSuite) TestUpdateInstanceInvalidPoolID() {
_, err := s.Store.UpdateInstance(context.Background(), "dummy-id", params.UpdateInstanceParams{})
s.Require().Equal("updating instance: parsing id: invalid request", err.Error())
}
func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateInstanceErr() {
instance := s.Fixtures.Instances[0]
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instances` WHERE id = ? AND `instances`.`deleted_at` IS NULL ORDER BY `instances`.`id` LIMIT 1")).
WithArgs(instance.ID).
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instances` WHERE name = ? AND `instances`.`deleted_at` IS NULL ORDER BY `instances`.`id` LIMIT 1")).
WithArgs(instance.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(instance.ID))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `addresses` WHERE `addresses`.`instance_id` = ? AND `addresses`.`deleted_at` IS NULL")).
@ -455,19 +443,19 @@ func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateInstanceErr() {
WillReturnError(fmt.Errorf("mocked update instance error"))
s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateInstance(context.Background(), instance.ID, s.Fixtures.UpdateInstanceParams)
_, err := s.StoreSQLMocked.UpdateInstance(context.Background(), instance.Name, s.Fixtures.UpdateInstanceParams)
s.assertSQLMockExpectations()
s.Require().NotNil(err)
s.Require().Equal("updating instance: mocked update instance error", err.Error())
s.assertSQLMockExpectations()
}
func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateAddressErr() {
instance := s.Fixtures.Instances[0]
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instances` WHERE id = ? AND `instances`.`deleted_at` IS NULL ORDER BY `instances`.`id` LIMIT 1")).
WithArgs(instance.ID).
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `instances` WHERE name = ? AND `instances`.`deleted_at` IS NULL ORDER BY `instances`.`id` LIMIT 1")).
WithArgs(instance.Name).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(instance.ID))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `addresses` WHERE `addresses`.`instance_id` = ? AND `addresses`.`deleted_at` IS NULL")).
@ -501,11 +489,11 @@ func (s *InstancesTestSuite) TestUpdateInstanceDBUpdateAddressErr() {
WillReturnError(fmt.Errorf("update addresses mock error"))
s.Fixtures.SQLMock.ExpectRollback()
_, err := s.StoreSQLMocked.UpdateInstance(context.Background(), instance.ID, s.Fixtures.UpdateInstanceParams)
_, err := s.StoreSQLMocked.UpdateInstance(context.Background(), instance.Name, s.Fixtures.UpdateInstanceParams)
s.assertSQLMockExpectations()
s.Require().NotNil(err)
s.Require().Equal("updating addresses: update addresses mock error", err.Error())
s.assertSQLMockExpectations()
}
func (s *InstancesTestSuite) TestListPoolInstances() {

View file

@ -166,11 +166,11 @@ func (r *Runner) GetInstanceGithubRegistrationToken(ctx context.Context) (string
TokenFetched: &tokenFetched,
}
if _, err := r.store.UpdateInstance(r.ctx, instance.ID, updateParams); err != nil {
if _, err := r.store.UpdateInstance(r.ctx, instance.Name, updateParams); err != nil {
return "", errors.Wrap(err, "setting token_fetched for instance")
}
if err := r.store.AddInstanceEvent(ctx, instance.ID, params.FetchTokenEvent, params.EventInfo, "runner registration token was retrieved"); err != nil {
if err := r.store.AddInstanceEvent(ctx, instance.Name, params.FetchTokenEvent, params.EventInfo, "runner registration token was retrieved"); err != nil {
return "", errors.Wrap(err, "recording event")
}

View file

@ -745,20 +745,6 @@ func (r *basePoolManager) waitForErrorGroupOrContextCancelled(g *errgroup.Group)
}
}
func (r *basePoolManager) fetchInstance(runnerName string) (params.Instance, error) {
runner, err := r.store.GetInstanceByName(r.ctx, runnerName)
if err != nil {
return params.Instance{}, errors.Wrap(err, "fetching instance")
}
_, err = r.GetPoolByID(runner.PoolID)
if err != nil {
return params.Instance{}, errors.Wrap(err, "fetching pool")
}
return runner, nil
}
func (r *basePoolManager) setInstanceRunnerStatus(runnerName string, status params.RunnerStatus) (params.Instance, error) {
updateParams := params.UpdateInstanceParams{
RunnerStatus: status,
@ -772,12 +758,7 @@ func (r *basePoolManager) setInstanceRunnerStatus(runnerName string, status para
}
func (r *basePoolManager) updateInstance(runnerName string, update params.UpdateInstanceParams) (params.Instance, error) {
runner, err := r.fetchInstance(runnerName)
if err != nil {
return params.Instance{}, errors.Wrap(err, "fetching instance")
}
instance, err := r.store.UpdateInstance(r.ctx, runner.ID, update)
instance, err := r.store.UpdateInstance(r.ctx, runnerName, update)
if err != nil {
return params.Instance{}, errors.Wrap(err, "updating runner state")
}
@ -980,7 +961,7 @@ func (r *basePoolManager) addInstanceToProvider(instance params.Instance) error
}
updateInstanceArgs := r.updateArgsFromProviderInstance(providerInstance)
if _, err := r.store.UpdateInstance(r.ctx, instance.ID, updateInstanceArgs); err != nil {
if _, err := r.store.UpdateInstance(r.ctx, instance.Name, updateInstanceArgs); err != nil {
return errors.Wrap(err, "updating instance")
}
return nil

View file

@ -870,12 +870,12 @@ func (r *Runner) ListAllInstances(ctx context.Context) ([]params.Instance, error
}
func (r *Runner) AddInstanceStatusMessage(ctx context.Context, param params.InstanceUpdateMessage) error {
instanceID := auth.InstanceID(ctx)
if instanceID == "" {
instanceName := auth.InstanceName(ctx)
if instanceName == "" {
return runnerErrors.ErrUnauthorized
}
if err := r.store.AddInstanceEvent(ctx, instanceID, params.StatusEvent, params.EventInfo, param.Message); err != nil {
if err := r.store.AddInstanceEvent(ctx, instanceName, params.StatusEvent, params.EventInfo, param.Message); err != nil {
return errors.Wrap(err, "adding status update")
}
@ -887,7 +887,7 @@ func (r *Runner) AddInstanceStatusMessage(ctx context.Context, param params.Inst
updateParams.AgentID = *param.AgentID
}
if _, err := r.store.UpdateInstance(r.ctx, instanceID, updateParams); err != nil {
if _, err := r.store.UpdateInstance(r.ctx, instanceName, updateParams); err != nil {
return errors.Wrap(err, "updating runner agent ID")
}
@ -895,9 +895,9 @@ func (r *Runner) AddInstanceStatusMessage(ctx context.Context, param params.Inst
}
func (r *Runner) UpdateSystemInfo(ctx context.Context, param params.UpdateSystemInfoParams) error {
instanceID := auth.InstanceID(ctx)
if instanceID == "" {
slog.ErrorContext(ctx, "missing instance ID")
instanceName := auth.InstanceName(ctx)
if instanceName == "" {
slog.ErrorContext(ctx, "missing instance name")
return runnerErrors.ErrUnauthorized
}
@ -915,7 +915,7 @@ func (r *Runner) UpdateSystemInfo(ctx context.Context, param params.UpdateSystem
updateParams.AgentID = *param.AgentID
}
if _, err := r.store.UpdateInstance(r.ctx, instanceID, updateParams); err != nil {
if _, err := r.store.UpdateInstance(r.ctx, instanceName, updateParams); err != nil {
return errors.Wrap(err, "updating runner system info")
}