Merge pull request #71 from gabriel-samfira/add-extra-specs

Add extra specs on pools
This commit is contained in:
Gabriel 2023-01-30 18:09:50 +02:00 committed by GitHub
commit 1ad52e87ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
126 changed files with 8835 additions and 5331 deletions

View file

@ -1,66 +0,0 @@
// Copyright 2022 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package cmd
import (
"fmt"
"github.com/spf13/cobra"
)
// orgPoolCmd represents the pool command
var orgInstancesCmd = &cobra.Command{
Use: "runner",
SilenceUsage: true,
Short: "List runners",
Long: `List runners from all pools defined in this organization.`,
Run: nil,
}
var orgRunnerListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List organization runners",
Long: `List all runners for a given organization.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if needsInit {
return errNeedsInitError
}
if len(args) == 0 {
return fmt.Errorf("requires a organization ID")
}
if len(args) > 1 {
return fmt.Errorf("too many arguments")
}
instances, err := cli.ListOrgInstances(args[0])
if err != nil {
return err
}
formatInstances(instances)
return nil
},
}
func init() {
orgInstancesCmd.AddCommand(
orgRunnerListCmd,
)
organizationCmd.AddCommand(orgInstancesCmd)
}

View file

@ -1,265 +0,0 @@
// Copyright 2022 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package cmd
import (
"fmt"
"garm/config"
"garm/params"
"strings"
"github.com/spf13/cobra"
)
// orgPoolCmd represents the pool command
var orgPoolCmd = &cobra.Command{
Use: "pool",
SilenceUsage: true,
Aliases: []string{"pools"},
Short: "Manage pools",
Long: `Manage pools for a organization.
Repositories and organizations can define multiple pools with different
characteristics, which in turn will spawn github self hosted runners on
compute instances that reflect those characteristics.
For example, one pool can define a runner with tags "GPU,ML" which will
spin up instances with access to a GPU, on the desired provider.`,
Run: nil,
}
var orgPoolAddCmd = &cobra.Command{
Use: "add",
Aliases: []string{"create"},
Short: "Add pool",
Long: `Add a new pool to an organization.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if needsInit {
return errNeedsInitError
}
if len(args) == 0 {
return fmt.Errorf("requires a organization ID")
}
if len(args) > 1 {
return fmt.Errorf("too many arguments")
}
tags := strings.Split(poolTags, ",")
newPoolParams := params.CreatePoolParams{
ProviderName: poolProvider,
MaxRunners: poolMaxRunners,
RunnerPrefix: params.RunnerPrefix{
Prefix: poolRunnerPrefix,
},
MinIdleRunners: poolMinIdleRunners,
Image: poolImage,
Flavor: poolFlavor,
OSType: config.OSType(poolOSType),
OSArch: config.OSArch(poolOSArch),
Tags: tags,
Enabled: poolEnabled,
}
if err := newPoolParams.Validate(); err != nil {
return err
}
pool, err := cli.CreateOrgPool(args[0], newPoolParams)
if err != nil {
return err
}
formatOnePool(pool)
return nil
},
}
var orgPoolListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List organization pools",
Long: `List all configured pools for a given organization.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if needsInit {
return errNeedsInitError
}
if len(args) == 0 {
return fmt.Errorf("requires a organization ID")
}
if len(args) > 1 {
return fmt.Errorf("too many arguments")
}
pools, err := cli.ListOrgPools(args[0])
if err != nil {
return err
}
formatPools(pools)
return nil
},
}
var orgPoolShowCmd = &cobra.Command{
Use: "show",
Short: "Show details for one pool",
Long: `Displays detailed information about a single pool.`,
RunE: func(cmd *cobra.Command, args []string) error {
if needsInit {
return errNeedsInitError
}
if len(args) < 2 || len(args) > 2 {
return fmt.Errorf("command requires orgID and poolID")
}
pool, err := cli.GetOrgPool(args[0], args[1])
if err != nil {
return err
}
formatOnePool(pool)
return nil
},
}
var orgPoolDeleteCmd = &cobra.Command{
Use: "delete",
Aliases: []string{"remove", "rm", "del"},
Short: "Removes one pool",
Long: `Delete one organization pool from the manager.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if needsInit {
return errNeedsInitError
}
if len(args) < 2 || len(args) > 2 {
return fmt.Errorf("command requires orgID and poolID")
}
if err := cli.DeleteOrgPool(args[0], args[1]); err != nil {
return err
}
return nil
},
}
var orgPoolUpdateCmd = &cobra.Command{
Use: "update",
Short: "Update one pool",
Long: `Updates pool characteristics.
This command updates the pool characteristics. Runners already created prior to updating
the pool, will not be recreated. IF they no longer suit your needs, you will need to
explicitly remove them using the runner delete command.
`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if needsInit {
return errNeedsInitError
}
if len(args) < 2 || len(args) > 2 {
return fmt.Errorf("command requires orgID and poolID")
}
poolUpdateParams := params.UpdatePoolParams{}
if cmd.Flags().Changed("image") {
poolUpdateParams.Image = poolImage
}
if cmd.Flags().Changed("flavor") {
poolUpdateParams.Flavor = poolFlavor
}
if cmd.Flags().Changed("tags") {
poolUpdateParams.Tags = strings.Split(poolTags, ",")
}
if cmd.Flags().Changed("os-type") {
poolUpdateParams.OSType = config.OSType(poolOSType)
}
if cmd.Flags().Changed("os-arch") {
poolUpdateParams.OSArch = config.OSArch(poolOSArch)
}
if cmd.Flags().Changed("runner-prefix") {
poolUpdateParams.RunnerPrefix = params.RunnerPrefix{
Prefix: poolRunnerPrefix,
}
}
if cmd.Flags().Changed("max-runners") {
poolUpdateParams.MaxRunners = &poolMaxRunners
}
if cmd.Flags().Changed("min-idle-runners") {
poolUpdateParams.MinIdleRunners = &poolMinIdleRunners
}
if cmd.Flags().Changed("enabled") {
poolUpdateParams.Enabled = &poolEnabled
}
pool, err := cli.UpdateOrgPool(args[0], args[1], poolUpdateParams)
if err != nil {
return err
}
formatOnePool(pool)
return nil
},
}
func init() {
orgPoolAddCmd.Flags().StringVar(&poolProvider, "provider-name", "", "The name of the provider where runners will be created.")
orgPoolAddCmd.Flags().StringVar(&poolImage, "image", "", "The provider-specific image name to use for runners in this pool.")
orgPoolAddCmd.Flags().StringVar(&poolFlavor, "flavor", "", "The flavor to use for this runner.")
orgPoolAddCmd.Flags().StringVar(&poolTags, "tags", "", "A comma separated list of tags to assign to this runner.")
orgPoolAddCmd.Flags().StringVar(&poolOSType, "os-type", "linux", "Operating system type (windows, linux, etc).")
orgPoolAddCmd.Flags().StringVar(&poolOSArch, "os-arch", "amd64", "Operating system architecture (amd64, arm, etc).")
orgPoolAddCmd.Flags().StringVar(&poolRunnerPrefix, "runner-prefix", "", "The name prefix to use for runners in this pool.")
orgPoolAddCmd.Flags().UintVar(&poolMaxRunners, "max-runners", 5, "The maximum number of runner this pool will create.")
orgPoolAddCmd.Flags().UintVar(&poolMinIdleRunners, "min-idle-runners", 1, "Attempt to maintain a minimum of idle self-hosted runners of this type.")
orgPoolAddCmd.Flags().BoolVar(&poolEnabled, "enabled", false, "Enable this pool.")
orgPoolAddCmd.MarkFlagRequired("provider-name") //nolint
orgPoolAddCmd.MarkFlagRequired("image") //nolint
orgPoolAddCmd.MarkFlagRequired("flavor") //nolint
orgPoolAddCmd.MarkFlagRequired("tags") //nolint
orgPoolUpdateCmd.Flags().StringVar(&poolImage, "image", "", "The provider-specific image name to use for runners in this pool.")
orgPoolUpdateCmd.Flags().StringVar(&poolFlavor, "flavor", "", "The flavor to use for this runner.")
orgPoolUpdateCmd.Flags().StringVar(&poolTags, "tags", "", "A comma separated list of tags to assign to this runner.")
orgPoolUpdateCmd.Flags().StringVar(&poolOSType, "os-type", "linux", "Operating system type (windows, linux, etc).")
orgPoolUpdateCmd.Flags().StringVar(&poolOSArch, "os-arch", "amd64", "Operating system architecture (amd64, arm, etc).")
orgPoolUpdateCmd.Flags().StringVar(&poolRunnerPrefix, "runner-prefix", "", "The name prefix to use for runners in this pool.")
orgPoolUpdateCmd.Flags().UintVar(&poolMaxRunners, "max-runners", 5, "The maximum number of runner this pool will create.")
orgPoolUpdateCmd.Flags().UintVar(&poolMinIdleRunners, "min-idle-runners", 1, "Attempt to maintain a minimum of idle self-hosted runners of this type.")
orgPoolUpdateCmd.Flags().BoolVar(&poolEnabled, "enabled", false, "Enable this pool.")
orgPoolCmd.AddCommand(
orgPoolListCmd,
orgPoolAddCmd,
orgPoolShowCmd,
orgPoolDeleteCmd,
orgPoolUpdateCmd,
)
organizationCmd.AddCommand(orgPoolCmd)
}

View file

@ -15,20 +15,36 @@
package cmd package cmd
import ( import (
"encoding/json"
"fmt" "fmt"
"garm/config" "garm/config"
"garm/params" "garm/params"
"os" "os"
"strings" "strings"
"github.com/jedib0t/go-pretty/v6/table"
"github.com/pkg/errors"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var ( var (
poolRepository string poolProvider string
poolOrganization string poolMaxRunners uint
poolEnterprise string poolMinIdleRunners uint
poolAll bool poolRunnerPrefix string
poolImage string
poolFlavor string
poolOSType string
poolOSArch string
poolTags string
poolEnabled bool
poolRunnerBootstrapTimeout uint
poolRepository string
poolOrganization string
poolEnterprise string
poolExtraSpecsFile string
poolExtraSpecs string
poolAll bool
) )
// runnerCmd represents the runner command // runnerCmd represents the runner command
@ -181,6 +197,23 @@ var poolAddCmd = &cobra.Command{
Enabled: poolEnabled, Enabled: poolEnabled,
RunnerBootstrapTimeout: poolRunnerBootstrapTimeout, RunnerBootstrapTimeout: poolRunnerBootstrapTimeout,
} }
if cmd.Flags().Changed("extra-specs") {
data, err := asRawMessage([]byte(poolExtraSpecs))
if err != nil {
return err
}
newPoolParams.ExtraSpecs = data
}
if poolExtraSpecsFile != "" {
data, err := extraSpecsFromFile(poolExtraSpecsFile)
if err != nil {
return err
}
newPoolParams.ExtraSpecs = data
}
if err := newPoolParams.Validate(); err != nil { if err := newPoolParams.Validate(); err != nil {
return err return err
} }
@ -274,6 +307,22 @@ explicitly remove them using the runner delete command.
poolUpdateParams.RunnerBootstrapTimeout = &poolRunnerBootstrapTimeout poolUpdateParams.RunnerBootstrapTimeout = &poolRunnerBootstrapTimeout
} }
if cmd.Flags().Changed("extra-specs") {
data, err := asRawMessage([]byte(poolExtraSpecs))
if err != nil {
return err
}
poolUpdateParams.ExtraSpecs = data
}
if poolExtraSpecsFile != "" {
data, err := extraSpecsFromFile(poolExtraSpecsFile)
if err != nil {
return err
}
poolUpdateParams.ExtraSpecs = data
}
pool, err := cli.UpdatePoolByID(args[0], poolUpdateParams) pool, err := cli.UpdatePoolByID(args[0], poolUpdateParams)
if err != nil { if err != nil {
return err return err
@ -301,6 +350,9 @@ func init() {
poolUpdateCmd.Flags().UintVar(&poolMinIdleRunners, "min-idle-runners", 1, "Attempt to maintain a minimum of idle self-hosted runners of this type.") poolUpdateCmd.Flags().UintVar(&poolMinIdleRunners, "min-idle-runners", 1, "Attempt to maintain a minimum of idle self-hosted runners of this type.")
poolUpdateCmd.Flags().BoolVar(&poolEnabled, "enabled", false, "Enable this pool.") poolUpdateCmd.Flags().BoolVar(&poolEnabled, "enabled", false, "Enable this pool.")
poolUpdateCmd.Flags().UintVar(&poolRunnerBootstrapTimeout, "runner-bootstrap-timeout", 20, "Duration in minutes after which a runner is considered failed if it does not join Github.") poolUpdateCmd.Flags().UintVar(&poolRunnerBootstrapTimeout, "runner-bootstrap-timeout", 20, "Duration in minutes after which a runner is considered failed if it does not join Github.")
poolUpdateCmd.Flags().StringVar(&poolExtraSpecsFile, "extra-specs-file", "", "A file containing a valid json which will be passed to the IaaS provider managing the pool.")
poolUpdateCmd.Flags().StringVar(&poolExtraSpecs, "extra-specs", "", "A valid json which will be passed to the IaaS provider managing the pool.")
poolUpdateCmd.MarkFlagsMutuallyExclusive("extra-specs-file", "extra-specs")
poolAddCmd.Flags().StringVar(&poolProvider, "provider-name", "", "The name of the provider where runners will be created.") poolAddCmd.Flags().StringVar(&poolProvider, "provider-name", "", "The name of the provider where runners will be created.")
poolAddCmd.Flags().StringVar(&poolImage, "image", "", "The provider-specific image name to use for runners in this pool.") poolAddCmd.Flags().StringVar(&poolImage, "image", "", "The provider-specific image name to use for runners in this pool.")
@ -309,6 +361,8 @@ func init() {
poolAddCmd.Flags().StringVar(&poolTags, "tags", "", "A comma separated list of tags to assign to this runner.") poolAddCmd.Flags().StringVar(&poolTags, "tags", "", "A comma separated list of tags to assign to this runner.")
poolAddCmd.Flags().StringVar(&poolOSType, "os-type", "linux", "Operating system type (windows, linux, etc).") poolAddCmd.Flags().StringVar(&poolOSType, "os-type", "linux", "Operating system type (windows, linux, etc).")
poolAddCmd.Flags().StringVar(&poolOSArch, "os-arch", "amd64", "Operating system architecture (amd64, arm, etc).") poolAddCmd.Flags().StringVar(&poolOSArch, "os-arch", "amd64", "Operating system architecture (amd64, arm, etc).")
poolAddCmd.Flags().StringVar(&poolExtraSpecsFile, "extra-specs-file", "", "A file containing a valid json which will be passed to the IaaS provider managing the pool.")
poolAddCmd.Flags().StringVar(&poolExtraSpecs, "extra-specs", "", "A valid json which will be passed to the IaaS provider managing the pool.")
poolAddCmd.Flags().UintVar(&poolMaxRunners, "max-runners", 5, "The maximum number of runner this pool will create.") poolAddCmd.Flags().UintVar(&poolMaxRunners, "max-runners", 5, "The maximum number of runner this pool will create.")
poolAddCmd.Flags().UintVar(&poolRunnerBootstrapTimeout, "runner-bootstrap-timeout", 20, "Duration in minutes after which a runner is considered failed if it does not join Github.") poolAddCmd.Flags().UintVar(&poolRunnerBootstrapTimeout, "runner-bootstrap-timeout", 20, "Duration in minutes after which a runner is considered failed if it does not join Github.")
poolAddCmd.Flags().UintVar(&poolMinIdleRunners, "min-idle-runners", 1, "Attempt to maintain a minimum of idle self-hosted runners of this type.") poolAddCmd.Flags().UintVar(&poolMinIdleRunners, "min-idle-runners", 1, "Attempt to maintain a minimum of idle self-hosted runners of this type.")
@ -322,6 +376,7 @@ func init() {
poolAddCmd.Flags().StringVarP(&poolOrganization, "org", "o", "", "Add the new pool withing this organization.") poolAddCmd.Flags().StringVarP(&poolOrganization, "org", "o", "", "Add the new pool withing this organization.")
poolAddCmd.Flags().StringVarP(&poolEnterprise, "enterprise", "e", "", "Add the new pool withing this enterprise.") poolAddCmd.Flags().StringVarP(&poolEnterprise, "enterprise", "e", "", "Add the new pool withing this enterprise.")
poolAddCmd.MarkFlagsMutuallyExclusive("repo", "org", "enterprise") poolAddCmd.MarkFlagsMutuallyExclusive("repo", "org", "enterprise")
poolAddCmd.MarkFlagsMutuallyExclusive("extra-specs-file", "extra-specs")
poolCmd.AddCommand( poolCmd.AddCommand(
poolListCmd, poolListCmd,
@ -333,3 +388,112 @@ func init() {
rootCmd.AddCommand(poolCmd) rootCmd.AddCommand(poolCmd)
} }
func extraSpecsFromFile(specsFile string) (json.RawMessage, error) {
data, err := os.ReadFile(specsFile)
if err != nil {
return nil, errors.Wrap(err, "opening specs file")
}
return asRawMessage(data)
}
func asRawMessage(data []byte) (json.RawMessage, error) {
// unmarshaling and marshaling again will remove new lines and verify we
// have a valid json.
var unmarshaled interface{}
if err := json.Unmarshal(data, &unmarshaled); err != nil {
return nil, errors.Wrap(err, "decoding extra specs")
}
var asRawJson json.RawMessage
var err error
asRawJson, err = json.Marshal(unmarshaled)
if err != nil {
return nil, errors.Wrap(err, "marshaling json")
}
return asRawJson, nil
}
func formatPools(pools []params.Pool) {
t := table.NewWriter()
header := table.Row{"ID", "Image", "Flavor", "Tags", "Belongs to", "Level", "Enabled", "Runner Prefix"}
t.AppendHeader(header)
for _, pool := range pools {
tags := []string{}
for _, tag := range pool.Tags {
tags = append(tags, tag.Name)
}
var belongsTo string
var level string
if pool.RepoID != "" && pool.RepoName != "" {
belongsTo = pool.RepoName
level = "repo"
} else if pool.OrgID != "" && pool.OrgName != "" {
belongsTo = pool.OrgName
level = "org"
} else if pool.EnterpriseID != "" && pool.EnterpriseName != "" {
belongsTo = pool.EnterpriseName
level = "enterprise"
}
t.AppendRow(table.Row{pool.ID, pool.Image, pool.Flavor, strings.Join(tags, " "), belongsTo, level, pool.Enabled, pool.GetRunnerPrefix()})
t.AppendSeparator()
}
fmt.Println(t.Render())
}
func formatOnePool(pool params.Pool) {
t := table.NewWriter()
rowConfigAutoMerge := table.RowConfig{AutoMerge: true}
header := table.Row{"Field", "Value"}
tags := []string{}
for _, tag := range pool.Tags {
tags = append(tags, tag.Name)
}
var belongsTo string
var level string
if pool.RepoID != "" && pool.RepoName != "" {
belongsTo = pool.RepoName
level = "repo"
} else if pool.OrgID != "" && pool.OrgName != "" {
belongsTo = pool.OrgName
level = "org"
} else if pool.EnterpriseID != "" && pool.EnterpriseName != "" {
belongsTo = pool.EnterpriseName
level = "enterprise"
}
t.AppendHeader(header)
t.AppendRow(table.Row{"ID", pool.ID})
t.AppendRow(table.Row{"Provider Name", pool.ProviderName})
t.AppendRow(table.Row{"Image", pool.Image})
t.AppendRow(table.Row{"Flavor", pool.Flavor})
t.AppendRow(table.Row{"OS Type", pool.OSType})
t.AppendRow(table.Row{"OS Architecture", pool.OSArch})
t.AppendRow(table.Row{"Max Runners", pool.MaxRunners})
t.AppendRow(table.Row{"Min Idle Runners", pool.MinIdleRunners})
t.AppendRow(table.Row{"Runner Bootstrap Timeout", pool.RunnerBootstrapTimeout})
t.AppendRow(table.Row{"Tags", strings.Join(tags, ", ")})
t.AppendRow(table.Row{"Belongs to", belongsTo})
t.AppendRow(table.Row{"Level", level})
t.AppendRow(table.Row{"Enabled", pool.Enabled})
t.AppendRow(table.Row{"Runner Prefix", pool.GetRunnerPrefix()})
t.AppendRow(table.Row{"Extra specs", string(pool.ExtraSpecs)})
if len(pool.Instances) > 0 {
for _, instance := range pool.Instances {
t.AppendRow(table.Row{"Instances", fmt.Sprintf("%s (%s)", instance.Name, instance.ID)}, rowConfigAutoMerge)
}
}
t.SetColumnConfigs([]table.ColumnConfig{
{Number: 1, AutoMerge: true},
{Number: 2, AutoMerge: true},
})
fmt.Println(t.Render())
}

View file

@ -1,66 +0,0 @@
// Copyright 2022 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package cmd
import (
"fmt"
"github.com/spf13/cobra"
)
// repoPoolCmd represents the pool command
var repoInstancesCmd = &cobra.Command{
Use: "runner",
SilenceUsage: true,
Short: "List runners",
Long: `List runners from all pools defined in this repository.`,
Run: nil,
}
var repoRunnerListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List repository runners",
Long: `List all runners for a given repository.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if needsInit {
return errNeedsInitError
}
if len(args) == 0 {
return fmt.Errorf("requires a repository ID")
}
if len(args) > 1 {
return fmt.Errorf("too many arguments")
}
instances, err := cli.ListRepoInstances(args[0])
if err != nil {
return err
}
formatInstances(instances)
return nil
},
}
func init() {
repoInstancesCmd.AddCommand(
repoRunnerListCmd,
)
repositoryCmd.AddCommand(repoInstancesCmd)
}

View file

@ -1,335 +0,0 @@
// Copyright 2022 Cloudbase Solutions SRL
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package cmd
import (
"fmt"
"garm/config"
"garm/params"
"strings"
"github.com/jedib0t/go-pretty/v6/table"
"github.com/spf13/cobra"
)
var (
poolProvider string
poolMaxRunners uint
poolMinIdleRunners uint
poolRunnerPrefix string
poolImage string
poolFlavor string
poolOSType string
poolOSArch string
poolTags string
poolEnabled bool
poolRunnerBootstrapTimeout uint
)
// repoPoolCmd represents the pool command
var repoPoolCmd = &cobra.Command{
Use: "pool",
SilenceUsage: true,
Aliases: []string{"pools"},
Short: "Manage pools",
Long: `Manage pools for a repository.
Repositories and organizations can define multiple pools with different
characteristics, which in turn will spawn github self hosted runners on
compute instances that reflect those characteristics.
For example, one pool can define a runner with tags "GPU,ML" which will
spin up instances with access to a GPU, on the desired provider.`,
Run: nil,
}
var repoPoolAddCmd = &cobra.Command{
Use: "add",
Aliases: []string{"create"},
Short: "Add pool",
Long: `Add a new pool to a repository.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if needsInit {
return errNeedsInitError
}
if len(args) == 0 {
return fmt.Errorf("requires a repository ID")
}
if len(args) > 1 {
return fmt.Errorf("too many arguments")
}
tags := strings.Split(poolTags, ",")
newPoolParams := params.CreatePoolParams{
ProviderName: poolProvider,
MaxRunners: poolMaxRunners,
RunnerPrefix: params.RunnerPrefix{
Prefix: poolRunnerPrefix,
},
MinIdleRunners: poolMinIdleRunners,
Image: poolImage,
Flavor: poolFlavor,
OSType: config.OSType(poolOSType),
OSArch: config.OSArch(poolOSArch),
Tags: tags,
Enabled: poolEnabled,
}
if err := newPoolParams.Validate(); err != nil {
return err
}
pool, err := cli.CreateRepoPool(args[0], newPoolParams)
if err != nil {
return err
}
formatOnePool(pool)
return nil
},
}
var repoPoolShowCmd = &cobra.Command{
Use: "show",
Short: "Show details for one pool",
Long: `Displays detailed information about a single pool.`,
RunE: func(cmd *cobra.Command, args []string) error {
if needsInit {
return errNeedsInitError
}
if len(args) < 2 || len(args) > 2 {
return fmt.Errorf("command requires repoID and poolID")
}
pool, err := cli.GetRepoPool(args[0], args[1])
if err != nil {
return err
}
formatOnePool(pool)
return nil
},
}
var repoPoolDeleteCmd = &cobra.Command{
Use: "delete",
Aliases: []string{"remove", "rm", "del"},
Short: "Removes one pool",
Long: `Delete one repository pool from the manager.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if needsInit {
return errNeedsInitError
}
if len(args) < 2 || len(args) > 2 {
return fmt.Errorf("command requires repoID and poolID")
}
if err := cli.DeleteRepoPool(args[0], args[1]); err != nil {
return err
}
return nil
},
}
var repoPoolUpdateCmd = &cobra.Command{
Use: "update",
Short: "Update one pool",
Long: `Updates pool characteristics.
This command updates the pool characteristics. Runners already created prior to updating
the pool, will not be recreated. IF they no longer suit your needs, you will need to
explicitly remove them using the runner delete command.
`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if needsInit {
return errNeedsInitError
}
if len(args) < 2 || len(args) > 2 {
return fmt.Errorf("command requires repoID and poolID")
}
poolUpdateParams := params.UpdatePoolParams{}
if cmd.Flags().Changed("image") {
poolUpdateParams.Image = poolImage
}
if cmd.Flags().Changed("flavor") {
poolUpdateParams.Flavor = poolFlavor
}
if cmd.Flags().Changed("tags") {
poolUpdateParams.Tags = strings.Split(poolTags, ",")
}
if cmd.Flags().Changed("os-type") {
poolUpdateParams.OSType = config.OSType(poolOSType)
}
if cmd.Flags().Changed("os-arch") {
poolUpdateParams.OSArch = config.OSArch(poolOSArch)
}
if cmd.Flags().Changed("runner-prefix") {
poolUpdateParams.RunnerPrefix = params.RunnerPrefix{
Prefix: poolRunnerPrefix,
}
}
if cmd.Flags().Changed("max-runners") {
poolUpdateParams.MaxRunners = &poolMaxRunners
}
if cmd.Flags().Changed("min-idle-runners") {
poolUpdateParams.MinIdleRunners = &poolMinIdleRunners
}
if cmd.Flags().Changed("enabled") {
poolUpdateParams.Enabled = &poolEnabled
}
pool, err := cli.UpdateRepoPool(args[0], args[1], poolUpdateParams)
if err != nil {
return err
}
formatOnePool(pool)
return nil
},
}
func init() {
repoPoolAddCmd.Flags().StringVar(&poolProvider, "provider-name", "", "The name of the provider where runners will be created.")
repoPoolAddCmd.Flags().StringVar(&poolImage, "image", "", "The provider-specific image name to use for runners in this pool.")
repoPoolAddCmd.Flags().StringVar(&poolFlavor, "flavor", "", "The flavor to use for this runner.")
repoPoolAddCmd.Flags().StringVar(&poolTags, "tags", "", "A comma separated list of tags to assign to this runner.")
repoPoolAddCmd.Flags().StringVar(&poolOSType, "os-type", "linux", "Operating system type (windows, linux, etc).")
repoPoolAddCmd.Flags().StringVar(&poolOSArch, "os-arch", "amd64", "Operating system architecture (amd64, arm, etc).")
repoPoolAddCmd.Flags().StringVar(&poolRunnerPrefix, "runner-prefix", "", "The name prefix to use for runners in this pool.")
repoPoolAddCmd.Flags().UintVar(&poolMaxRunners, "max-runners", 5, "The maximum number of runner this pool will create.")
repoPoolAddCmd.Flags().UintVar(&poolMinIdleRunners, "min-idle-runners", 1, "Attempt to maintain a minimum of idle self-hosted runners of this type.")
repoPoolAddCmd.Flags().BoolVar(&poolEnabled, "enabled", false, "Enable this pool.")
repoPoolAddCmd.MarkFlagRequired("provider-name") //nolint
repoPoolAddCmd.MarkFlagRequired("image") //nolint
repoPoolAddCmd.MarkFlagRequired("flavor") //nolint
repoPoolAddCmd.MarkFlagRequired("tags") //nolint
repoPoolUpdateCmd.Flags().StringVar(&poolImage, "image", "", "The provider-specific image name to use for runners in this pool.")
repoPoolUpdateCmd.Flags().StringVar(&poolFlavor, "flavor", "", "The flavor to use for this runner.")
repoPoolUpdateCmd.Flags().StringVar(&poolTags, "tags", "", "A comma separated list of tags to assign to this runner.")
repoPoolUpdateCmd.Flags().StringVar(&poolOSType, "os-type", "linux", "Operating system type (windows, linux, etc).")
repoPoolUpdateCmd.Flags().StringVar(&poolOSArch, "os-arch", "amd64", "Operating system architecture (amd64, arm, etc).")
repoPoolUpdateCmd.Flags().StringVar(&poolRunnerPrefix, "runner-prefix", "", "The name prefix to use for runners in this pool.")
repoPoolUpdateCmd.Flags().UintVar(&poolMaxRunners, "max-runners", 5, "The maximum number of runner this pool will create.")
repoPoolUpdateCmd.Flags().UintVar(&poolMinIdleRunners, "min-idle-runners", 1, "Attempt to maintain a minimum of idle self-hosted runners of this type.")
repoPoolUpdateCmd.Flags().BoolVar(&poolEnabled, "enabled", false, "Enable this pool.")
repoPoolCmd.AddCommand(
poolListCmd,
repoPoolAddCmd,
repoPoolShowCmd,
repoPoolDeleteCmd,
repoPoolUpdateCmd,
)
repositoryCmd.AddCommand(repoPoolCmd)
}
func formatPools(pools []params.Pool) {
t := table.NewWriter()
header := table.Row{"ID", "Image", "Flavor", "Tags", "Belongs to", "Level", "Enabled", "Runner Prefix"}
t.AppendHeader(header)
for _, pool := range pools {
tags := []string{}
for _, tag := range pool.Tags {
tags = append(tags, tag.Name)
}
var belongsTo string
var level string
if pool.RepoID != "" && pool.RepoName != "" {
belongsTo = pool.RepoName
level = "repo"
} else if pool.OrgID != "" && pool.OrgName != "" {
belongsTo = pool.OrgName
level = "org"
} else if pool.EnterpriseID != "" && pool.EnterpriseName != "" {
belongsTo = pool.EnterpriseName
level = "enterprise"
}
t.AppendRow(table.Row{pool.ID, pool.Image, pool.Flavor, strings.Join(tags, " "), belongsTo, level, pool.Enabled, pool.GetRunnerPrefix()})
t.AppendSeparator()
}
fmt.Println(t.Render())
}
func formatOnePool(pool params.Pool) {
t := table.NewWriter()
rowConfigAutoMerge := table.RowConfig{AutoMerge: true}
header := table.Row{"Field", "Value"}
tags := []string{}
for _, tag := range pool.Tags {
tags = append(tags, tag.Name)
}
var belongsTo string
var level string
if pool.RepoID != "" && pool.RepoName != "" {
belongsTo = pool.RepoName
level = "repo"
} else if pool.OrgID != "" && pool.OrgName != "" {
belongsTo = pool.OrgName
level = "org"
} else if pool.EnterpriseID != "" && pool.EnterpriseName != "" {
belongsTo = pool.EnterpriseName
level = "enterprise"
}
t.AppendHeader(header)
t.AppendRow(table.Row{"ID", pool.ID})
t.AppendRow(table.Row{"Provider Name", pool.ProviderName})
t.AppendRow(table.Row{"Image", pool.Image})
t.AppendRow(table.Row{"Flavor", pool.Flavor})
t.AppendRow(table.Row{"OS Type", pool.OSType})
t.AppendRow(table.Row{"OS Architecture", pool.OSArch})
t.AppendRow(table.Row{"Max Runners", pool.MaxRunners})
t.AppendRow(table.Row{"Min Idle Runners", pool.MinIdleRunners})
t.AppendRow(table.Row{"Runner Bootstrap Timeout", pool.RunnerBootstrapTimeout})
t.AppendRow(table.Row{"Tags", strings.Join(tags, ", ")})
t.AppendRow(table.Row{"Belongs to", belongsTo})
t.AppendRow(table.Row{"Level", level})
t.AppendRow(table.Row{"Enabled", pool.Enabled})
t.AppendRow(table.Row{"Runner Prefix", pool.GetRunnerPrefix()})
if len(pool.Instances) > 0 {
for _, instance := range pool.Instances {
t.AppendRow(table.Row{"Instances", fmt.Sprintf("%s (%s)", instance.Name, instance.ID)}, rowConfigAutoMerge)
}
}
t.SetColumnConfigs([]table.ColumnConfig{
{Number: 1, AutoMerge: true},
{Number: 2, AutoMerge: true},
})
fmt.Println(t.Render())
}

View file

@ -9,6 +9,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -152,6 +153,10 @@ func (s *sqlDatabase) CreateEnterprisePool(ctx context.Context, enterpriseID str
RunnerBootstrapTimeout: param.RunnerBootstrapTimeout, RunnerBootstrapTimeout: param.RunnerBootstrapTimeout,
} }
if len(param.ExtraSpecs) > 0 {
newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs)
}
_, err = s.getEnterprisePoolByUniqueFields(ctx, enterpriseID, newPool.ProviderName, newPool.Image, newPool.Flavor) _, err = s.getEnterprisePoolByUniqueFields(ctx, enterpriseID, newPool.ProviderName, newPool.Image, newPool.Flavor)
if err != nil { if err != nil {
if !errors.Is(err, runnerErrors.ErrNotFound) { if !errors.Is(err, runnerErrors.ErrNotFound) {
@ -190,7 +195,7 @@ func (s *sqlDatabase) CreateEnterprisePool(ctx context.Context, enterpriseID str
} }
func (s *sqlDatabase) GetEnterprisePool(ctx context.Context, enterpriseID, poolID string) (params.Pool, error) { func (s *sqlDatabase) GetEnterprisePool(ctx context.Context, enterpriseID, poolID string) (params.Pool, error) {
pool, err := s.getEnterprisePool(ctx, enterpriseID, poolID, "Tags", "Instances") pool, err := s.getEntityPool(ctx, params.EnterprisePool, enterpriseID, poolID, "Tags", "Instances")
if err != nil { if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool") return params.Pool{}, errors.Wrap(err, "fetching pool")
} }
@ -198,7 +203,7 @@ func (s *sqlDatabase) GetEnterprisePool(ctx context.Context, enterpriseID, poolI
} }
func (s *sqlDatabase) DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error { func (s *sqlDatabase) DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error {
pool, err := s.getEnterprisePool(ctx, enterpriseID, poolID) pool, err := s.getEntityPool(ctx, params.EnterprisePool, enterpriseID, poolID)
if err != nil { if err != nil {
return errors.Wrap(err, "looking up enterprise pool") return errors.Wrap(err, "looking up enterprise pool")
} }
@ -210,7 +215,7 @@ func (s *sqlDatabase) DeleteEnterprisePool(ctx context.Context, enterpriseID, po
} }
func (s *sqlDatabase) UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { func (s *sqlDatabase) UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) {
pool, err := s.getEnterprisePool(ctx, enterpriseID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") pool, err := s.getEntityPool(ctx, params.EnterprisePool, enterpriseID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil { if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool") return params.Pool{}, errors.Wrap(err, "fetching pool")
} }
@ -311,15 +316,10 @@ func (s *sqlDatabase) getEnterprisePoolByUniqueFields(ctx context.Context, enter
return pool[0], nil return pool[0], nil
} }
func (s *sqlDatabase) getEnterprisePool(ctx context.Context, enterpriseID, poolID string, preload ...string) (Pool, error) { func (s *sqlDatabase) getEnterprisePools(ctx context.Context, enterpriseID string, preload ...string) ([]Pool, error) {
enterprise, err := s.getEnterpriseByID(ctx, enterpriseID) _, err := s.getEnterpriseByID(ctx, enterpriseID)
if err != nil { if err != nil {
return Pool{}, errors.Wrap(err, "fetching enterprise") return nil, errors.Wrap(err, "fetching enterprise")
}
u, err := uuid.FromString(poolID)
if err != nil {
return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
} }
q := s.conn q := s.conn
@ -329,33 +329,10 @@ func (s *sqlDatabase) getEnterprisePool(ctx context.Context, enterpriseID, poolI
} }
} }
var pool []Pool
err = q.Model(&enterprise).Association("Pools").Find(&pool, "id = ?", u)
if err != nil {
return Pool{}, errors.Wrap(err, "fetching pool")
}
if len(pool) == 0 {
return Pool{}, runnerErrors.ErrNotFound
}
return pool[0], nil
}
func (s *sqlDatabase) getEnterprisePools(ctx context.Context, enterpriseID string, preload ...string) ([]Pool, error) {
enterprise, err := s.getEnterpriseByID(ctx, enterpriseID)
if err != nil {
return nil, errors.Wrap(err, "fetching enterprise")
}
var pools []Pool var pools []Pool
err = q.Model(&Pool{}).Where("enterprise_id = ?", enterpriseID).
q := s.conn.Model(&enterprise) Omit("extra_specs").
if len(preload) > 0 { Find(&pools).Error
for _, item := range preload {
q = q.Preload(item)
}
}
err = q.Association("Pools").Find(&pools)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetching pool") return nil, errors.Wrap(err, "fetching pool")

View file

@ -688,7 +688,7 @@ func (s *EnterpriseTestSuite) TestGetEnterprisePoolInvalidEnterpriseID() {
_, err := s.Store.GetEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id") _, err := s.Store.GetEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: fetching enterprise: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())
} }
func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() { func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() {
@ -701,14 +701,14 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() {
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.GetEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, pool.ID) _, err = s.Store.GetEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, pool.ID)
s.Require().Equal("fetching pool: not found", err.Error()) s.Require().Equal("fetching pool: finding pool: not found", err.Error())
} }
func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolInvalidEnterpriseID() { func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolInvalidEnterpriseID() {
err := s.Store.DeleteEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id") err := s.Store.DeleteEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("looking up enterprise pool: fetching enterprise: parsing id: invalid request", err.Error()) s.Require().Equal("looking up enterprise pool: parsing id: invalid request", err.Error())
} }
func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() { func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() {
@ -718,12 +718,8 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() {
} }
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (id = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Enterprises[0].ID). WithArgs(pool.ID, s.Fixtures.Enterprises[0].ID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND id = ? AND `pools`.`deleted_at` IS NULL")).
WithArgs(s.Fixtures.Enterprises[0].ID, pool.ID).
WillReturnRows(sqlmock.NewRows([]string{"enterprise_id", "id"}).AddRow(s.Fixtures.Enterprises[0].ID, pool.ID)) WillReturnRows(sqlmock.NewRows([]string{"enterprise_id", "id"}).AddRow(s.Fixtures.Enterprises[0].ID, pool.ID))
s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
@ -809,7 +805,7 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolInvalidEnterpriseID() {
_, err := s.Store.UpdateEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id", s.Fixtures.UpdatePoolParams) _, err := s.Store.UpdateEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id", s.Fixtures.UpdatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: fetching enterprise: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())
} }
func TestEnterpriseTestSuite(t *testing.T) { func TestEnterpriseTestSuite(t *testing.T) {

View file

@ -22,6 +22,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -66,6 +67,10 @@ type Pool struct {
OSArch config.OSArch OSArch config.OSArch
Tags []*Tag `gorm:"many2many:pool_tags;"` Tags []*Tag `gorm:"many2many:pool_tags;"`
Enabled bool Enabled bool
// ExtraSpecs is an opaque json that gets sent to the provider
// as part of the bootstrap params for instances. It can contain
// any kind of data needed by providers.
ExtraSpecs datatypes.JSON
RepoID uuid.UUID `gorm:"index"` RepoID uuid.UUID `gorm:"index"`
Repository Repository `gorm:"foreignKey:RepoID"` Repository Repository `gorm:"foreignKey:RepoID"`

View file

@ -24,6 +24,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -169,6 +170,10 @@ func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgId string,
RunnerBootstrapTimeout: param.RunnerBootstrapTimeout, RunnerBootstrapTimeout: param.RunnerBootstrapTimeout,
} }
if len(param.ExtraSpecs) > 0 {
newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs)
}
_, err = s.getOrgPoolByUniqueFields(ctx, orgId, newPool.ProviderName, newPool.Image, newPool.Flavor) _, err = s.getOrgPoolByUniqueFields(ctx, orgId, newPool.ProviderName, newPool.Image, newPool.Flavor)
if err != nil { if err != nil {
if !errors.Is(err, runnerErrors.ErrNotFound) { if !errors.Is(err, runnerErrors.ErrNotFound) {
@ -221,7 +226,7 @@ func (s *sqlDatabase) ListOrgPools(ctx context.Context, orgID string) ([]params.
} }
func (s *sqlDatabase) GetOrganizationPool(ctx context.Context, orgID, poolID string) (params.Pool, error) { func (s *sqlDatabase) GetOrganizationPool(ctx context.Context, orgID, poolID string) (params.Pool, error) {
pool, err := s.getOrgPool(ctx, orgID, poolID, "Tags", "Instances") pool, err := s.getEntityPool(ctx, params.OrganizationPool, orgID, poolID, "Tags", "Instances")
if err != nil { if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool") return params.Pool{}, errors.Wrap(err, "fetching pool")
} }
@ -229,7 +234,7 @@ func (s *sqlDatabase) GetOrganizationPool(ctx context.Context, orgID, poolID str
} }
func (s *sqlDatabase) DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error { func (s *sqlDatabase) DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error {
pool, err := s.getOrgPool(ctx, orgID, poolID) pool, err := s.getEntityPool(ctx, params.OrganizationPool, orgID, poolID)
if err != nil { if err != nil {
return errors.Wrap(err, "looking up org pool") return errors.Wrap(err, "looking up org pool")
} }
@ -263,7 +268,7 @@ func (s *sqlDatabase) ListOrgInstances(ctx context.Context, orgID string) ([]par
} }
func (s *sqlDatabase) UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { func (s *sqlDatabase) UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) {
pool, err := s.getOrgPool(ctx, orgID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") pool, err := s.getEntityPool(ctx, params.OrganizationPool, orgID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil { if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool") return params.Pool{}, errors.Wrap(err, "fetching pool")
} }
@ -295,15 +300,10 @@ func (s *sqlDatabase) getPoolByID(ctx context.Context, poolID string, preload ..
return pool, nil return pool, nil
} }
func (s *sqlDatabase) getOrgPool(ctx context.Context, orgID, poolID string, preload ...string) (Pool, error) { func (s *sqlDatabase) getOrgPools(ctx context.Context, orgID string, preload ...string) ([]Pool, error) {
org, err := s.getOrgByID(ctx, orgID) _, err := s.getOrgByID(ctx, orgID)
if err != nil { if err != nil {
return Pool{}, errors.Wrap(err, "fetching org") return nil, errors.Wrap(err, "fetching org")
}
u, err := uuid.FromString(poolID)
if err != nil {
return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
} }
q := s.conn q := s.conn
@ -313,33 +313,11 @@ func (s *sqlDatabase) getOrgPool(ctx context.Context, orgID, poolID string, prel
} }
} }
var pool []Pool
err = q.Model(&org).Association("Pools").Find(&pool, "id = ?", u)
if err != nil {
return Pool{}, errors.Wrap(err, "fetching pool")
}
if len(pool) == 0 {
return Pool{}, runnerErrors.ErrNotFound
}
return pool[0], nil
}
func (s *sqlDatabase) getOrgPools(ctx context.Context, orgID string, preload ...string) ([]Pool, error) {
org, err := s.getOrgByID(ctx, orgID)
if err != nil {
return nil, errors.Wrap(err, "fetching org")
}
var pools []Pool var pools []Pool
err = q.Model(&Pool{}).
q := s.conn.Model(&org) Where("org_id = ?", orgID).
if len(preload) > 0 { Omit("extra_specs").
for _, item := range preload { Find(&pools).Error
q = q.Preload(item)
}
}
err = q.Association("Pools").Find(&pools)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetching pool") return nil, errors.Wrap(err, "fetching pool")

View file

@ -688,7 +688,7 @@ func (s *OrgTestSuite) TestGetOrganizationPoolInvalidOrgID() {
_, err := s.Store.GetOrganizationPool(context.Background(), "dummy-org-id", "dummy-pool-id") _, err := s.Store.GetOrganizationPool(context.Background(), "dummy-org-id", "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: fetching org: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())
} }
func (s *OrgTestSuite) TestDeleteOrganizationPool() { func (s *OrgTestSuite) TestDeleteOrganizationPool() {
@ -701,14 +701,14 @@ func (s *OrgTestSuite) TestDeleteOrganizationPool() {
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.GetOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, pool.ID) _, err = s.Store.GetOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, pool.ID)
s.Require().Equal("fetching pool: not found", err.Error()) s.Require().Equal("fetching pool: finding pool: not found", err.Error())
} }
func (s *OrgTestSuite) TestDeleteOrganizationPoolInvalidOrgID() { func (s *OrgTestSuite) TestDeleteOrganizationPoolInvalidOrgID() {
err := s.Store.DeleteOrganizationPool(context.Background(), "dummy-org-id", "dummy-pool-id") err := s.Store.DeleteOrganizationPool(context.Background(), "dummy-org-id", "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("looking up org pool: fetching org: parsing id: invalid request", err.Error()) s.Require().Equal("looking up org pool: parsing id: invalid request", err.Error())
} }
func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() { func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() {
@ -718,12 +718,8 @@ func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() {
} }
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (id = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Orgs[0].ID). WithArgs(pool.ID, s.Fixtures.Orgs[0].ID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND id = ? AND `pools`.`deleted_at` IS NULL")).
WithArgs(s.Fixtures.Orgs[0].ID, pool.ID).
WillReturnRows(sqlmock.NewRows([]string{"org_id", "id"}).AddRow(s.Fixtures.Orgs[0].ID, pool.ID)) WillReturnRows(sqlmock.NewRows([]string{"org_id", "id"}).AddRow(s.Fixtures.Orgs[0].ID, pool.ID))
s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
@ -809,7 +805,7 @@ func (s *OrgTestSuite) TestUpdateOrganizationPoolInvalidOrgID() {
_, err := s.Store.UpdateOrganizationPool(context.Background(), "dummy-org-id", "dummy-pool-id", s.Fixtures.UpdatePoolParams) _, err := s.Store.UpdateOrganizationPool(context.Background(), "dummy-org-id", "dummy-pool-id", s.Fixtures.UpdatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: fetching org: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())
} }
func TestOrgTestSuite(t *testing.T) { func TestOrgTestSuite(t *testing.T) {

View file

@ -16,10 +16,14 @@ package sql
import ( import (
"context" "context"
"fmt"
runnerErrors "garm/errors"
"garm/params" "garm/params"
"github.com/pkg/errors" "github.com/pkg/errors"
uuid "github.com/satori/go.uuid"
"gorm.io/gorm"
) )
func (s *sqlDatabase) ListAllPools(ctx context.Context) ([]params.Pool, error) { func (s *sqlDatabase) ListAllPools(ctx context.Context) ([]params.Pool, error) {
@ -30,6 +34,7 @@ func (s *sqlDatabase) ListAllPools(ctx context.Context) ([]params.Pool, error) {
Preload("Organization"). Preload("Organization").
Preload("Repository"). Preload("Repository").
Preload("Enterprise"). Preload("Enterprise").
Omit("extra_specs").
Find(&pools) Find(&pools)
if q.Error != nil { if q.Error != nil {
return nil, errors.Wrap(q.Error, "fetching all pools") return nil, errors.Wrap(q.Error, "fetching all pools")
@ -62,3 +67,48 @@ func (s *sqlDatabase) DeletePoolByID(ctx context.Context, poolID string) error {
return nil return nil
} }
func (s *sqlDatabase) getEntityPool(ctx context.Context, entityType params.PoolType, entityID, poolID string, preload ...string) (Pool, error) {
if entityID == "" {
return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "missing entity id")
}
u, err := uuid.FromString(poolID)
if err != nil {
return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
q := s.conn
if len(preload) > 0 {
for _, item := range preload {
q = q.Preload(item)
}
}
var fieldName string
switch entityType {
case params.RepositoryPool:
fieldName = "repo_id"
case params.OrganizationPool:
fieldName = "org_id"
case params.EnterprisePool:
fieldName = "enterprise_id"
default:
return Pool{}, fmt.Errorf("invalid entityType: %v", entityType)
}
var pool Pool
condition := fmt.Sprintf("id = ? and %s = ?", fieldName)
err = q.Model(&Pool{}).
Where(condition, u, entityID).
First(&pool).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return Pool{}, errors.Wrap(runnerErrors.ErrNotFound, "finding pool")
}
return Pool{}, errors.Wrap(err, "fetching pool")
}
return pool, nil
}

View file

@ -127,7 +127,7 @@ func (s *PoolsTestSuite) TestListAllPools() {
func (s *PoolsTestSuite) TestListAllPoolsDBFetchErr() { func (s *PoolsTestSuite) TestListAllPoolsDBFetchErr() {
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`deleted_at` IS NULL")). ExpectQuery(regexp.QuoteMeta("SELECT `pools`.`id`,`pools`.`created_at`,`pools`.`updated_at`,`pools`.`deleted_at`,`pools`.`provider_name`,`pools`.`runner_prefix`,`pools`.`max_runners`,`pools`.`min_idle_runners`,`pools`.`runner_bootstrap_timeout`,`pools`.`image`,`pools`.`flavor`,`pools`.`os_type`,`pools`.`os_arch`,`pools`.`enabled`,`pools`.`repo_id`,`pools`.`org_id`,`pools`.`enterprise_id` FROM `pools` WHERE `pools`.`deleted_at` IS NULL")).
WillReturnError(fmt.Errorf("mocked fetching all pools error")) WillReturnError(fmt.Errorf("mocked fetching all pools error"))
_, err := s.StoreSQLMocked.ListAllPools(context.Background()) _, err := s.StoreSQLMocked.ListAllPools(context.Background())

View file

@ -24,6 +24,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -169,6 +170,10 @@ func (s *sqlDatabase) CreateRepositoryPool(ctx context.Context, repoId string, p
RunnerBootstrapTimeout: param.RunnerBootstrapTimeout, RunnerBootstrapTimeout: param.RunnerBootstrapTimeout,
} }
if len(param.ExtraSpecs) > 0 {
newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs)
}
_, err = s.getRepoPoolByUniqueFields(ctx, repoId, newPool.ProviderName, newPool.Image, newPool.Flavor) _, err = s.getRepoPoolByUniqueFields(ctx, repoId, newPool.ProviderName, newPool.Image, newPool.Flavor)
if err != nil { if err != nil {
if !errors.Is(err, runnerErrors.ErrNotFound) { if !errors.Is(err, runnerErrors.ErrNotFound) {
@ -221,7 +226,7 @@ func (s *sqlDatabase) ListRepoPools(ctx context.Context, repoID string) ([]param
} }
func (s *sqlDatabase) GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error) { func (s *sqlDatabase) GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error) {
pool, err := s.getRepoPool(ctx, repoID, poolID, "Tags", "Instances") pool, err := s.getEntityPool(ctx, params.RepositoryPool, repoID, poolID, "Tags", "Instances")
if err != nil { if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool") return params.Pool{}, errors.Wrap(err, "fetching pool")
} }
@ -229,7 +234,7 @@ func (s *sqlDatabase) GetRepositoryPool(ctx context.Context, repoID, poolID stri
} }
func (s *sqlDatabase) DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error { func (s *sqlDatabase) DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error {
pool, err := s.getRepoPool(ctx, repoID, poolID) pool, err := s.getEntityPool(ctx, params.RepositoryPool, repoID, poolID)
if err != nil { if err != nil {
return errors.Wrap(err, "looking up repo pool") return errors.Wrap(err, "looking up repo pool")
} }
@ -264,7 +269,7 @@ func (s *sqlDatabase) ListRepoInstances(ctx context.Context, repoID string) ([]p
} }
func (s *sqlDatabase) UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { func (s *sqlDatabase) UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) {
pool, err := s.getRepoPool(ctx, repoID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") pool, err := s.getEntityPool(ctx, params.RepositoryPool, repoID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository")
if err != nil { if err != nil {
return params.Pool{}, errors.Wrap(err, "fetching pool") return params.Pool{}, errors.Wrap(err, "fetching pool")
} }
@ -317,36 +322,6 @@ func (s *sqlDatabase) findPoolByTags(id, poolType string, tags []string) (params
return s.sqlToCommonPool(pool), nil return s.sqlToCommonPool(pool), nil
} }
func (s *sqlDatabase) getRepoPool(ctx context.Context, repoID, poolID string, preload ...string) (Pool, error) {
repo, err := s.getRepoByID(ctx, repoID)
if err != nil {
return Pool{}, errors.Wrap(err, "fetching repo")
}
u, err := uuid.FromString(poolID)
if err != nil {
return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id")
}
q := s.conn
if len(preload) > 0 {
for _, item := range preload {
q = q.Preload(item)
}
}
var pool []Pool
err = q.Model(&repo).Association("Pools").Find(&pool, "id = ?", u)
if err != nil {
return Pool{}, errors.Wrap(err, "fetching pool")
}
if len(pool) == 0 {
return Pool{}, runnerErrors.ErrNotFound
}
return pool[0], nil
}
func (s *sqlDatabase) getRepoPoolByUniqueFields(ctx context.Context, repoID string, provider, image, flavor string) (Pool, error) { func (s *sqlDatabase) getRepoPoolByUniqueFields(ctx context.Context, repoID string, provider, image, flavor string) (Pool, error) {
repo, err := s.getRepoByID(ctx, repoID) repo, err := s.getRepoByID(ctx, repoID)
if err != nil { if err != nil {
@ -367,19 +342,22 @@ func (s *sqlDatabase) getRepoPoolByUniqueFields(ctx context.Context, repoID stri
} }
func (s *sqlDatabase) getRepoPools(ctx context.Context, repoID string, preload ...string) ([]Pool, error) { func (s *sqlDatabase) getRepoPools(ctx context.Context, repoID string, preload ...string) ([]Pool, error) {
repo, err := s.getRepoByID(ctx, repoID) _, err := s.getRepoByID(ctx, repoID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetching repo") return nil, errors.Wrap(err, "fetching repo")
} }
var pools []Pool q := s.conn
q := s.conn.Model(&repo)
if len(preload) > 0 { if len(preload) > 0 {
for _, item := range preload { for _, item := range preload {
q = q.Preload(item) q = q.Preload(item)
} }
} }
err = q.Association("Pools").Find(&pools)
var pools []Pool
err = q.Model(&Pool{}).Where("repo_id = ?", repoID).
Omit("extra_specs").
Find(&pools).Error
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetching pool") return nil, errors.Wrap(err, "fetching pool")
} }

View file

@ -725,7 +725,7 @@ func (s *RepoTestSuite) TestGetRepositoryPoolInvalidRepoID() {
_, err := s.Store.GetRepositoryPool(context.Background(), "dummy-repo-id", "dummy-pool-id") _, err := s.Store.GetRepositoryPool(context.Background(), "dummy-repo-id", "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: fetching repo: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())
} }
func (s *RepoTestSuite) TestDeleteRepositoryPool() { func (s *RepoTestSuite) TestDeleteRepositoryPool() {
@ -738,14 +738,14 @@ func (s *RepoTestSuite) TestDeleteRepositoryPool() {
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Store.GetOrganizationPool(context.Background(), s.Fixtures.Repos[0].ID, pool.ID) _, err = s.Store.GetOrganizationPool(context.Background(), s.Fixtures.Repos[0].ID, pool.ID)
s.Require().Equal("fetching pool: fetching org: not found", err.Error()) s.Require().Equal("fetching pool: finding pool: not found", err.Error())
} }
func (s *RepoTestSuite) TestDeleteRepositoryPoolInvalidRepoID() { func (s *RepoTestSuite) TestDeleteRepositoryPoolInvalidRepoID() {
err := s.Store.DeleteRepositoryPool(context.Background(), "dummy-repo-id", "dummy-pool-id") err := s.Store.DeleteRepositoryPool(context.Background(), "dummy-repo-id", "dummy-pool-id")
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("looking up repo pool: fetching repo: parsing id: invalid request", err.Error()) s.Require().Equal("looking up repo pool: parsing id: invalid request", err.Error())
} }
func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() { func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() {
@ -755,12 +755,8 @@ func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() {
} }
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (id = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")).
WithArgs(s.Fixtures.Repos[0].ID). WithArgs(pool.ID, s.Fixtures.Repos[0].ID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID))
s.Fixtures.SQLMock.
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND id = ? AND `pools`.`deleted_at` IS NULL")).
WithArgs(s.Fixtures.Repos[0].ID, pool.ID).
WillReturnRows(sqlmock.NewRows([]string{"repo_id", "id"}).AddRow(s.Fixtures.Repos[0].ID, pool.ID)) WillReturnRows(sqlmock.NewRows([]string{"repo_id", "id"}).AddRow(s.Fixtures.Repos[0].ID, pool.ID))
s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock.ExpectBegin()
s.Fixtures.SQLMock. s.Fixtures.SQLMock.
@ -845,7 +841,7 @@ func (s *RepoTestSuite) TestUpdateRepositoryPoolInvalidRepoID() {
_, err := s.Store.UpdateRepositoryPool(context.Background(), "dummy-org-id", "dummy-repo-id", s.Fixtures.UpdatePoolParams) _, err := s.Store.UpdateRepositoryPool(context.Background(), "dummy-org-id", "dummy-repo-id", s.Fixtures.UpdatePoolParams)
s.Require().NotNil(err) s.Require().NotNil(err)
s.Require().Equal("fetching pool: fetching repo: parsing id: invalid request", err.Error()) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error())
} }
func TestRepoTestSuite(t *testing.T) { func TestRepoTestSuite(t *testing.T) {

View file

@ -15,6 +15,7 @@
package sql package sql
import ( import (
"encoding/json"
"fmt" "fmt"
"garm/params" "garm/params"
@ -22,6 +23,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -141,6 +143,7 @@ func (s *sqlDatabase) sqlToCommonPool(pool Pool) params.Pool {
Tags: make([]params.Tag, len(pool.Tags)), Tags: make([]params.Tag, len(pool.Tags)),
Instances: make([]params.Instance, len(pool.Instances)), Instances: make([]params.Instance, len(pool.Instances)),
RunnerBootstrapTimeout: pool.RunnerBootstrapTimeout, RunnerBootstrapTimeout: pool.RunnerBootstrapTimeout,
ExtraSpecs: json.RawMessage(pool.ExtraSpecs),
} }
if pool.RepoID != uuid.Nil { if pool.RepoID != uuid.Nil {
@ -270,6 +273,10 @@ func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (para
pool.OSType = param.OSType pool.OSType = param.OSType
} }
if param.ExtraSpecs != nil {
pool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs)
}
if param.RunnerBootstrapTimeout != nil && *param.RunnerBootstrapTimeout > 0 { if param.RunnerBootstrapTimeout != nil && *param.RunnerBootstrapTimeout > 0 {
pool.RunnerBootstrapTimeout = *param.RunnerBootstrapTimeout pool.RunnerBootstrapTimeout = *param.RunnerBootstrapTimeout
} }

13
go.mod
View file

@ -23,15 +23,16 @@ require (
github.com/spf13/cobra v1.4.1-0.20220504202302-9e88759b19cd github.com/spf13/cobra v1.4.1-0.20220504202302-9e88759b19cd
github.com/stretchr/testify v1.8.0 github.com/stretchr/testify v1.8.0
github.com/teris-io/shortid v0.0.0-20220617161101-71ec9f2aa569 github.com/teris-io/shortid v0.0.0-20220617161101-71ec9f2aa569
golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b
golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a
gopkg.in/DATA-DOG/go-sqlmock.v1 v1.3.0 gopkg.in/DATA-DOG/go-sqlmock.v1 v1.3.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.3.3 gorm.io/datatypes v1.1.0
gorm.io/driver/sqlite v1.3.2 gorm.io/driver/mysql v1.4.4
gorm.io/gorm v1.23.4 gorm.io/driver/sqlite v1.4.3
gorm.io/gorm v1.24.2
) )
require ( require (
@ -42,7 +43,7 @@ require (
github.com/felixge/httpsnoop v1.0.1 // indirect github.com/felixge/httpsnoop v1.0.1 // indirect
github.com/flosch/pongo2 v0.0.0-20200913210552-0d938eb266f3 // indirect github.com/flosch/pongo2 v0.0.0-20200913210552-0d938eb266f3 // indirect
github.com/go-macaroon-bakery/macaroonpb v1.0.0 // indirect github.com/go-macaroon-bakery/macaroonpb v1.0.0 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect
@ -55,7 +56,7 @@ require (
github.com/kr/fs v0.1.0 // indirect github.com/kr/fs v0.1.0 // indirect
github.com/kr/pretty v0.3.0 // indirect github.com/kr/pretty v0.3.0 // indirect
github.com/mattn/go-runewidth v0.0.13 // indirect github.com/mattn/go-runewidth v0.0.13 // indirect
github.com/mattn/go-sqlite3 v1.14.12 // indirect github.com/mattn/go-sqlite3 v1.14.15 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/pborman/uuid v1.2.1 // indirect github.com/pborman/uuid v1.2.1 // indirect
github.com/pkg/sftp v1.13.4 // indirect github.com/pkg/sftp v1.13.4 // indirect

41
go.sum
View file

@ -91,12 +91,15 @@ github.com/go-macaroon-bakery/macaroonpb v1.0.0 h1:It9exBaRMZ9iix1iJ6gwzfwsDE6Ex
github.com/go-macaroon-bakery/macaroonpb v1.0.0/go.mod h1:UzrGOcbiwTXISFP2XDLDPjfhMINZa+fX/7A2lMd31zc= github.com/go-macaroon-bakery/macaroonpb v1.0.0/go.mod h1:UzrGOcbiwTXISFP2XDLDPjfhMINZa+fX/7A2lMd31zc=
github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY= github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY=
github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I= github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
@ -170,6 +173,14 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
github.com/jackc/pgconn v1.13.0 h1:3L1XMNV2Zvca/8BYhzcRFS70Lr0WlDg16Di6SFGAbys=
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgproto3/v2 v2.3.1 h1:nwj7qwf0S+Q7ISFfBndqeLwSwxs+4DPsbRFjECT1Y4Y=
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg=
github.com/jackc/pgtype v1.12.0 h1:Dlq8Qvcch7kiehm8wPGIW0W3KsCCHJnRacKW0UM8n5w=
github.com/jackc/pgx/v4 v4.17.2 h1:0Ut0rpeKwvIVbMQ1KbMBU4h6wxehBI535LK6Flheh8E=
github.com/jedib0t/go-pretty/v6 v6.3.1 h1:aOXiD9oqiuLH8btPQW6SfgtQN5zwhyfzZls8a6sPJ/I= github.com/jedib0t/go-pretty/v6 v6.3.1 h1:aOXiD9oqiuLH8btPQW6SfgtQN5zwhyfzZls8a6sPJ/I=
github.com/jedib0t/go-pretty/v6 v6.3.1/go.mod h1:FMkOpgGD3EZ91cW8g/96RfxoV7bdeJyzXPYgz1L1ln0= github.com/jedib0t/go-pretty/v6 v6.3.1/go.mod h1:FMkOpgGD3EZ91cW8g/96RfxoV7bdeJyzXPYgz1L1ln0=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
@ -235,10 +246,11 @@ github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYt
github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg= github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg=
github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU= github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU=
github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.12 h1:TJ1bhYJPV44phC+IMu1u2K/i5RriLTPe+yc68XDJ1Z0= github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI=
github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
@ -341,8 +353,8 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 h1:S25/rfnfsMVgORT4/J61MJ7rdyseOZOyvLIrZEZ7s6s= golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b h1:huxqepDufQpLLIRXiVkTvnxrzJlpwmIWAObmcCcUFr0=
golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -647,13 +659,18 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.3.3 h1:jXG9ANrwBc4+bMvBcSl8zCfPBaVoPyBEBshA8dA93X8= gorm.io/datatypes v1.1.0 h1:EVp1Z28N4ACpYFK1nHboEIJGIFfjY7vLeieDk8jSHJA=
gorm.io/driver/mysql v1.3.3/go.mod h1:ChK6AHbHgDCFZyJp0F+BmVGb06PSIoh9uVYKAlRbb2U= gorm.io/datatypes v1.1.0/go.mod h1:SH2K9R+2RMjuX1CkCONrPwoe9JzVv2hkQvEu4bXGojE=
gorm.io/driver/sqlite v1.3.2 h1:nWTy4cE52K6nnMhv23wLmur9Y3qWbZvOBz+V4PrGAxg= gorm.io/driver/mysql v1.4.4 h1:MX0K9Qvy0Na4o7qSC/YI7XxqUw5KDw01umqgID+svdQ=
gorm.io/driver/sqlite v1.3.2/go.mod h1:B+8GyC9K7VgzJAcrcXMRPdnMcck+8FgJynEehEPM16U= gorm.io/driver/mysql v1.4.4/go.mod h1:BCg8cKI+R0j/rZRQxeKis/forqRwRSYOR8OM3Wo6hOM=
gorm.io/gorm v1.23.1/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/driver/postgres v1.4.5 h1:mTeXTTtHAgnS9PgmhN2YeUbazYpLhUI1doLnw42XUZc=
gorm.io/gorm v1.23.4 h1:1BKWM67O6CflSLcwGQR7ccfmC4ebOxQrTfOQGRE9wjg= gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/gorm v1.23.4/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.24.2 h1:9wR6CFD+G8nOusLdvkZelOEhpJVwwHzpQOUM+REd6U0=
gorm.io/gorm v1.24.2/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View file

@ -15,6 +15,7 @@
package params package params
import ( import (
"encoding/json"
"garm/config" "garm/config"
"garm/runner/providers/common" "garm/runner/providers/common"
"time" "time"
@ -128,6 +129,12 @@ type BootstrapInstance struct {
// SSHKeys are the ssh public keys we may want to inject inside the runners, if the // SSHKeys are the ssh public keys we may want to inject inside the runners, if the
// provider supports it. // provider supports it.
SSHKeys []string `json:"ssh-keys"` SSHKeys []string `json:"ssh-keys"`
// ExtraSpecs is an opaque raw json that gets sent to the provider
// as part of the bootstrap params for instances. It can contain
// any kind of data needed by providers. The contents of this field means
// nothing to garm itself. We don't act on the information in this field at
// all. We only validate that it's a proper json.
ExtraSpecs json.RawMessage `json:"extra_specs,omitempty"`
CACertBundle []byte `json:"ca-cert-bundle"` CACertBundle []byte `json:"ca-cert-bundle"`
@ -164,6 +171,12 @@ type Pool struct {
EnterpriseID string `json:"enterprise_id,omitempty"` EnterpriseID string `json:"enterprise_id,omitempty"`
EnterpriseName string `json:"enterprise_name,omitempty"` EnterpriseName string `json:"enterprise_name,omitempty"`
RunnerBootstrapTimeout uint `json:"runner_bootstrap_timeout"` RunnerBootstrapTimeout uint `json:"runner_bootstrap_timeout"`
// ExtraSpecs is an opaque raw json that gets sent to the provider
// as part of the bootstrap params for instances. It can contain
// any kind of data needed by providers. The contents of this field means
// nothing to garm itself. We don't act on the information in this field at
// all. We only validate that it's a proper json.
ExtraSpecs json.RawMessage `json:"extra_specs,omitempty"`
} }
func (p Pool) GetID() string { func (p Pool) GetID() string {

View file

@ -15,6 +15,7 @@
package params package params
import ( import (
"encoding/json"
"fmt" "fmt"
"garm/config" "garm/config"
@ -108,15 +109,16 @@ type NewUserParams struct {
type UpdatePoolParams struct { type UpdatePoolParams struct {
RunnerPrefix RunnerPrefix
Tags []string `json:"tags,omitempty"` Tags []string `json:"tags,omitempty"`
Enabled *bool `json:"enabled,omitempty"` Enabled *bool `json:"enabled,omitempty"`
MaxRunners *uint `json:"max_runners,omitempty"` MaxRunners *uint `json:"max_runners,omitempty"`
MinIdleRunners *uint `json:"min_idle_runners,omitempty"` MinIdleRunners *uint `json:"min_idle_runners,omitempty"`
RunnerBootstrapTimeout *uint `json:"runner_bootstrap_timeout,omitempty"` RunnerBootstrapTimeout *uint `json:"runner_bootstrap_timeout,omitempty"`
Image string `json:"image"` Image string `json:"image"`
Flavor string `json:"flavor"` Flavor string `json:"flavor"`
OSType config.OSType `json:"os_type"` OSType config.OSType `json:"os_type"`
OSArch config.OSArch `json:"os_arch"` OSArch config.OSArch `json:"os_arch"`
ExtraSpecs json.RawMessage `json:"extra_specs,omitempty"`
} }
type CreateInstanceParams struct { type CreateInstanceParams struct {
@ -133,16 +135,17 @@ type CreateInstanceParams struct {
type CreatePoolParams struct { type CreatePoolParams struct {
RunnerPrefix RunnerPrefix
ProviderName string `json:"provider_name"` ProviderName string `json:"provider_name"`
MaxRunners uint `json:"max_runners"` MaxRunners uint `json:"max_runners"`
MinIdleRunners uint `json:"min_idle_runners"` MinIdleRunners uint `json:"min_idle_runners"`
Image string `json:"image"` Image string `json:"image"`
Flavor string `json:"flavor"` Flavor string `json:"flavor"`
OSType config.OSType `json:"os_type"` OSType config.OSType `json:"os_type"`
OSArch config.OSArch `json:"os_arch"` OSArch config.OSArch `json:"os_arch"`
Tags []string `json:"tags"` Tags []string `json:"tags"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
RunnerBootstrapTimeout uint `json:"runner_bootstrap_timeout"` RunnerBootstrapTimeout uint `json:"runner_bootstrap_timeout"`
ExtraSpecs json.RawMessage `json:"extra_specs,omitempty"`
} }
func (p *CreatePoolParams) Validate() error { func (p *CreatePoolParams) Validate() error {

View file

@ -415,7 +415,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() {
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Fixtures.Store.GetEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, pool.ID) _, err = s.Fixtures.Store.GetEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, pool.ID)
s.Require().Equal("fetching pool: not found", err.Error()) s.Require().Equal("fetching pool: finding pool: not found", err.Error())
} }
func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolErrUnauthorized() { func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolErrUnauthorized() {

View file

@ -415,7 +415,7 @@ func (s *OrgTestSuite) TestDeleteOrgPool() {
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Fixtures.Store.GetOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, pool.ID) _, err = s.Fixtures.Store.GetOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, pool.ID)
s.Require().Equal("fetching pool: not found", err.Error()) s.Require().Equal("fetching pool: finding pool: not found", err.Error())
} }
func (s *OrgTestSuite) TestDeleteOrgPoolErrUnauthorized() { func (s *OrgTestSuite) TestDeleteOrgPoolErrUnauthorized() {

View file

@ -592,6 +592,7 @@ func (r *basePoolManager) addInstanceToProvider(instance params.Instance) error
OSArch: pool.OSArch, OSArch: pool.OSArch,
Flavor: pool.Flavor, Flavor: pool.Flavor,
Image: pool.Image, Image: pool.Image,
ExtraSpecs: pool.ExtraSpecs,
Labels: labels, Labels: labels,
PoolID: instance.PoolID, PoolID: instance.PoolID,
CACertBundle: r.credsDetails.CABundle, CACertBundle: r.credsDetails.CABundle,
@ -835,7 +836,7 @@ func (r *basePoolManager) ensureIdleRunnersForOnePool(pool params.Pool) {
} }
for i := 0; i < required; i++ { for i := 0; i < required; i++ {
log.Printf("addind new idle worker to pool %s", pool.ID) log.Printf("adding new idle worker to pool %s", pool.ID)
if err := r.AddRunner(r.ctx, pool.ID); err != nil { if err := r.AddRunner(r.ctx, pool.ID); err != nil {
log.Printf("failed to add new instance for pool %s: %s", pool.ID, err) log.Printf("failed to add new instance for pool %s: %s", pool.ID, err)
} }

View file

@ -418,7 +418,7 @@ func (s *RepoTestSuite) TestDeleteRepoPool() {
s.Require().Nil(err) s.Require().Nil(err)
_, err = s.Fixtures.Store.GetRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, pool.ID) _, err = s.Fixtures.Store.GetRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, pool.ID)
s.Require().Equal("fetching pool: not found", err.Error()) s.Require().Equal("fetching pool: finding pool: not found", err.Error())
} }
func (s *RepoTestSuite) TestDeleteRepoPoolErrUnauthorized() { func (s *RepoTestSuite) TestDeleteRepoPoolErrUnauthorized() {

View file

@ -23,6 +23,7 @@ Asta Xie <xiemengjun at gmail.com>
Bulat Gaifullin <gaifullinbf at gmail.com> Bulat Gaifullin <gaifullinbf at gmail.com>
Caine Jette <jette at alum.mit.edu> Caine Jette <jette at alum.mit.edu>
Carlos Nieto <jose.carlos at menteslibres.net> Carlos Nieto <jose.carlos at menteslibres.net>
Chris Kirkland <chriskirkland at github.com>
Chris Moos <chris at tech9computers.com> Chris Moos <chris at tech9computers.com>
Craig Wilson <craiggwilson at gmail.com> Craig Wilson <craiggwilson at gmail.com>
Daniel Montoya <dsmontoyam at gmail.com> Daniel Montoya <dsmontoyam at gmail.com>
@ -45,6 +46,7 @@ Ilia Cimpoes <ichimpoesh at gmail.com>
INADA Naoki <songofacandy at gmail.com> INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com> Jacek Szwec <szwec.jacek at gmail.com>
James Harr <james.harr at gmail.com> James Harr <james.harr at gmail.com>
Janek Vedock <janekvedock at comcast.net>
Jeff Hodges <jeff at somethingsimilar.com> Jeff Hodges <jeff at somethingsimilar.com>
Jeffrey Charles <jeffreycharles at gmail.com> Jeffrey Charles <jeffreycharles at gmail.com>
Jerome Meyer <jxmeyer at gmail.com> Jerome Meyer <jxmeyer at gmail.com>
@ -59,12 +61,14 @@ Kamil Dziedzic <kamil at klecza.pl>
Kei Kamikawa <x00.x7f.x86 at gmail.com> Kei Kamikawa <x00.x7f.x86 at gmail.com>
Kevin Malachowski <kevin at chowski.com> Kevin Malachowski <kevin at chowski.com>
Kieron Woodhouse <kieron.woodhouse at infosum.com> Kieron Woodhouse <kieron.woodhouse at infosum.com>
Lance Tian <lance6716 at gmail.com>
Lennart Rudolph <lrudolph at hmc.edu> Lennart Rudolph <lrudolph at hmc.edu>
Leonardo YongUk Kim <dalinaum at gmail.com> Leonardo YongUk Kim <dalinaum at gmail.com>
Linh Tran Tuan <linhduonggnu at gmail.com> Linh Tran Tuan <linhduonggnu at gmail.com>
Lion Yang <lion at aosc.xyz> Lion Yang <lion at aosc.xyz>
Luca Looz <luca.looz92 at gmail.com> Luca Looz <luca.looz92 at gmail.com>
Lucas Liu <extrafliu at gmail.com> Lucas Liu <extrafliu at gmail.com>
Lunny Xiao <xiaolunwen at gmail.com>
Luke Scott <luke at webconnex.com> Luke Scott <luke at webconnex.com>
Maciej Zimnoch <maciej.zimnoch at codilime.com> Maciej Zimnoch <maciej.zimnoch at codilime.com>
Michael Woolnough <michael.woolnough at gmail.com> Michael Woolnough <michael.woolnough at gmail.com>
@ -79,6 +83,7 @@ Reed Allman <rdallman10 at gmail.com>
Richard Wilkes <wilkes at me.com> Richard Wilkes <wilkes at me.com>
Robert Russell <robert at rrbrussell.com> Robert Russell <robert at rrbrussell.com>
Runrioter Wung <runrioter at gmail.com> Runrioter Wung <runrioter at gmail.com>
Santhosh Kumar Tekuri <santhosh.tekuri at gmail.com>
Sho Iizuka <sho.i518 at gmail.com> Sho Iizuka <sho.i518 at gmail.com>
Sho Ikeda <suicaicoca at gmail.com> Sho Ikeda <suicaicoca at gmail.com>
Shuode Li <elemount at qq.com> Shuode Li <elemount at qq.com>
@ -99,12 +104,14 @@ Xiuming Chen <cc at cxm.cc>
Xuehong Chan <chanxuehong at gmail.com> Xuehong Chan <chanxuehong at gmail.com>
Zhenye Xie <xiezhenye at gmail.com> Zhenye Xie <xiezhenye at gmail.com>
Zhixin Wen <john.wenzhixin at gmail.com> Zhixin Wen <john.wenzhixin at gmail.com>
Ziheng Lyu <zihenglv at gmail.com>
# Organizations # Organizations
Barracuda Networks, Inc. Barracuda Networks, Inc.
Counting Ltd. Counting Ltd.
DigitalOcean Inc. DigitalOcean Inc.
dyves labs AG
Facebook Inc. Facebook Inc.
GitHub Inc. GitHub Inc.
Google Inc. Google Inc.

View file

@ -1,3 +1,24 @@
## Version 1.7 (2022-11-29)
Changes:
- Drop support of Go 1.12 (#1211)
- Refactoring `(*textRows).readRow` in a more clear way (#1230)
- util: Reduce boundary check in escape functions. (#1316)
- enhancement for mysqlConn handleAuthResult (#1250)
New Features:
- support Is comparison on MySQLError (#1210)
- return unsigned in database type name when necessary (#1238)
- Add API to express like a --ssl-mode=PREFERRED MySQL client (#1370)
- Add SQLState to MySQLError (#1321)
Bugfixes:
- Fix parsing 0 year. (#1257)
## Version 1.6 (2021-04-01) ## Version 1.6 (2021-04-01)
Changes: Changes:

View file

@ -40,7 +40,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
* Optional placeholder interpolation * Optional placeholder interpolation
## Requirements ## Requirements
* Go 1.10 or higher. We aim to support the 3 latest versions of Go. * Go 1.13 or higher. We aim to support the 3 latest versions of Go.
* MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+)
--------------------------------------- ---------------------------------------
@ -85,7 +85,7 @@ db.SetMaxIdleConns(10)
`db.SetMaxOpenConns()` is highly recommended to limit the number of connection used by the application. There is no recommended limit number because it depends on application and MySQL server. `db.SetMaxOpenConns()` is highly recommended to limit the number of connection used by the application. There is no recommended limit number because it depends on application and MySQL server.
`db.SetMaxIdleConns()` is recommended to be set same to (or greater than) `db.SetMaxOpenConns()`. When it is smaller than `SetMaxOpenConns()`, connections can be opened and closed very frequently than you expect. Idle connections can be closed by the `db.SetConnMaxLifetime()`. If you want to close idle connections more rapidly, you can use `db.SetConnMaxIdleTime()` since Go 1.15. `db.SetMaxIdleConns()` is recommended to be set same to `db.SetMaxOpenConns()`. When it is smaller than `SetMaxOpenConns()`, connections can be opened and closed much more frequently than you expect. Idle connections can be closed by the `db.SetConnMaxLifetime()`. If you want to close idle connections more rapidly, you can use `db.SetConnMaxIdleTime()` since Go 1.15.
### DSN (Data Source Name) ### DSN (Data Source Name)
@ -157,6 +157,17 @@ Default: false
`allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. `allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network.
##### `allowFallbackToPlaintext`
```
Type: bool
Valid Values: true, false
Default: false
```
`allowFallbackToPlaintext=true` acts like a `--ssl-mode=PREFERRED` MySQL client as described in [Command Options for Connecting to the Server](https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#option_general_ssl-mode)
##### `allowNativePasswords` ##### `allowNativePasswords`
``` ```
@ -454,7 +465,7 @@ user:password@/
The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively.
## `ColumnType` Support ## `ColumnType` Support
This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `BIGINT`.
## `context.Context` Support ## `context.Context` Support
Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts.

19
vendor/github.com/go-sql-driver/mysql/atomic_bool.go generated vendored Normal file
View file

@ -0,0 +1,19 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build go1.19
// +build go1.19
package mysql
import "sync/atomic"
/******************************************************************************
* Sync utils *
******************************************************************************/
type atomicBool = atomic.Bool

View file

@ -0,0 +1,47 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !go1.19
// +build !go1.19
package mysql
import "sync/atomic"
/******************************************************************************
* Sync utils *
******************************************************************************/
// atomicBool is an implementation of atomic.Bool for older version of Go.
// it is a wrapper around uint32 for usage as a boolean value with
// atomic access.
type atomicBool struct {
_ noCopy
value uint32
}
// Load returns whether the current boolean value is true
func (ab *atomicBool) Load() bool {
return atomic.LoadUint32(&ab.value) > 0
}
// Store sets the value of the bool regardless of the previous value
func (ab *atomicBool) Store(value bool) {
if value {
atomic.StoreUint32(&ab.value, 1)
} else {
atomic.StoreUint32(&ab.value, 0)
}
}
// Swap sets the value of the bool and returns the old value.
func (ab *atomicBool) Swap(value bool) bool {
if value {
return atomic.SwapUint32(&ab.value, 1) > 0
}
return atomic.SwapUint32(&ab.value, 0) > 0
}

View file

@ -33,27 +33,26 @@ var (
// Note: The provided rsa.PublicKey instance is exclusively owned by the driver // Note: The provided rsa.PublicKey instance is exclusively owned by the driver
// after registering it and may not be modified. // after registering it and may not be modified.
// //
// data, err := ioutil.ReadFile("mykey.pem") // data, err := ioutil.ReadFile("mykey.pem")
// if err != nil { // if err != nil {
// log.Fatal(err) // log.Fatal(err)
// } // }
// //
// block, _ := pem.Decode(data) // block, _ := pem.Decode(data)
// if block == nil || block.Type != "PUBLIC KEY" { // if block == nil || block.Type != "PUBLIC KEY" {
// log.Fatal("failed to decode PEM block containing public key") // log.Fatal("failed to decode PEM block containing public key")
// } // }
// //
// pub, err := x509.ParsePKIXPublicKey(block.Bytes) // pub, err := x509.ParsePKIXPublicKey(block.Bytes)
// if err != nil { // if err != nil {
// log.Fatal(err) // log.Fatal(err)
// } // }
//
// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
// mysql.RegisterServerPubKey("mykey", rsaPubKey)
// } else {
// log.Fatal("not a RSA public key")
// }
// //
// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
// mysql.RegisterServerPubKey("mykey", rsaPubKey)
// } else {
// log.Fatal("not a RSA public key")
// }
func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) {
serverPubKeyLock.Lock() serverPubKeyLock.Lock()
if serverPubKeyRegistry == nil { if serverPubKeyRegistry == nil {
@ -274,7 +273,9 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
if len(mc.cfg.Passwd) == 0 { if len(mc.cfg.Passwd) == 0 {
return []byte{0}, nil return []byte{0}, nil
} }
if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // unlike caching_sha2_password, sha256_password does not accept
// cleartext password on unix transport.
if mc.cfg.TLS != nil {
// write cleartext auth packet // write cleartext auth packet
return append([]byte(mc.cfg.Passwd), 0), nil return append([]byte(mc.cfg.Passwd), 0), nil
} }
@ -350,7 +351,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
} }
case cachingSha2PasswordPerformFullAuthentication: case cachingSha2PasswordPerformFullAuthentication:
if mc.cfg.tls != nil || mc.cfg.Net == "unix" { if mc.cfg.TLS != nil || mc.cfg.Net == "unix" {
// write cleartext auth packet // write cleartext auth packet
err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
if err != nil { if err != nil {
@ -365,13 +366,20 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
return err return err
} }
data[4] = cachingSha2PasswordRequestPublicKey data[4] = cachingSha2PasswordRequestPublicKey
mc.writePacket(data) err = mc.writePacket(data)
if err != nil {
return err
}
// parse public key
if data, err = mc.readPacket(); err != nil { if data, err = mc.readPacket(); err != nil {
return err return err
} }
if data[0] != iAuthMoreData {
return fmt.Errorf("unexpect resp from server for caching_sha2_password perform full authentication")
}
// parse public key
block, rest := pem.Decode(data[1:]) block, rest := pem.Decode(data[1:])
if block == nil { if block == nil {
return fmt.Errorf("No Pem data found, data: %s", rest) return fmt.Errorf("No Pem data found, data: %s", rest)
@ -404,6 +412,10 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
return nil // auth successful return nil // auth successful
default: default:
block, _ := pem.Decode(authData) block, _ := pem.Decode(authData)
if block == nil {
return fmt.Errorf("no Pem data found, data: %s", authData)
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes) pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil { if err != nil {
return err return err

View file

@ -13,7 +13,8 @@ const binaryCollation = "binary"
// A list of available collations mapped to the internal ID. // A list of available collations mapped to the internal ID.
// To update this map use the following MySQL query: // To update this map use the following MySQL query:
// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID //
// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID
// //
// Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255. // Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255.
// //

View file

@ -6,6 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file, // License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/. // You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos
// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos // +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos
package mysql package mysql

View file

@ -6,6 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file, // License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/. // You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !linux && !darwin && !dragonfly && !freebsd && !netbsd && !openbsd && !solaris && !illumos
// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos // +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos
package mysql package mysql

View file

@ -104,7 +104,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
} }
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if mc.closed.IsSet() { if mc.closed.Load() {
errLog.Print(ErrInvalidConn) errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
@ -123,7 +123,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
func (mc *mysqlConn) Close() (err error) { func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent // Makes Close idempotent
if !mc.closed.IsSet() { if !mc.closed.Load() {
err = mc.writeCommandPacket(comQuit) err = mc.writeCommandPacket(comQuit)
} }
@ -137,7 +137,7 @@ func (mc *mysqlConn) Close() (err error) {
// is called before auth or on auth failure because MySQL will have already // is called before auth or on auth failure because MySQL will have already
// closed the network connection. // closed the network connection.
func (mc *mysqlConn) cleanup() { func (mc *mysqlConn) cleanup() {
if !mc.closed.TrySet(true) { if mc.closed.Swap(true) {
return return
} }
@ -152,7 +152,7 @@ func (mc *mysqlConn) cleanup() {
} }
func (mc *mysqlConn) error() error { func (mc *mysqlConn) error() error {
if mc.closed.IsSet() { if mc.closed.Load() {
if err := mc.canceled.Value(); err != nil { if err := mc.canceled.Value(); err != nil {
return err return err
} }
@ -162,7 +162,7 @@ func (mc *mysqlConn) error() error {
} }
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.closed.IsSet() { if mc.closed.Load() {
errLog.Print(ErrInvalidConn) errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
@ -295,7 +295,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
} }
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.IsSet() { if mc.closed.Load() {
errLog.Print(ErrInvalidConn) errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
@ -356,7 +356,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
} }
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
if mc.closed.IsSet() { if mc.closed.Load() {
errLog.Print(ErrInvalidConn) errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
@ -450,7 +450,7 @@ func (mc *mysqlConn) finish() {
// Ping implements driver.Pinger interface // Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) { func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.IsSet() { if mc.closed.Load() {
errLog.Print(ErrInvalidConn) errLog.Print(ErrInvalidConn)
return driver.ErrBadConn return driver.ErrBadConn
} }
@ -469,7 +469,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
// BeginTx implements driver.ConnBeginTx interface // BeginTx implements driver.ConnBeginTx interface
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if mc.closed.IsSet() { if mc.closed.Load() {
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
@ -636,7 +636,7 @@ func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
// ResetSession implements driver.SessionResetter. // ResetSession implements driver.SessionResetter.
// (From Go 1.10) // (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error { func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.IsSet() { if mc.closed.Load() {
return driver.ErrBadConn return driver.ErrBadConn
} }
mc.reset = true mc.reset = true
@ -646,5 +646,5 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
// IsValid implements driver.Validator interface // IsValid implements driver.Validator interface
// (From Go 1.15) // (From Go 1.15)
func (mc *mysqlConn) IsValid() bool { func (mc *mysqlConn) IsValid() bool {
return !mc.closed.IsSet() return !mc.closed.Load()
} }

View file

@ -8,10 +8,10 @@
// //
// The driver should be used via the database/sql package: // The driver should be used via the database/sql package:
// //
// import "database/sql" // import "database/sql"
// import _ "github.com/go-sql-driver/mysql" // import _ "github.com/go-sql-driver/mysql"
// //
// db, err := sql.Open("mysql", "user:password@/dbname") // db, err := sql.Open("mysql", "user:password@/dbname")
// //
// See https://github.com/go-sql-driver/mysql#usage for details // See https://github.com/go-sql-driver/mysql#usage for details
package mysql package mysql

View file

@ -46,22 +46,23 @@ type Config struct {
ServerPubKey string // Server public key name ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name TLSConfig string // TLS configuration name
tls *tls.Config // TLS configuration TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout WriteTimeout time.Duration // I/O write timeout
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin AllowCleartextPasswords bool // Allows the cleartext client side plugin
AllowNativePasswords bool // Allows the native password authentication method AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS
AllowOldPasswords bool // Allows the old insecure password method AllowNativePasswords bool // Allows the native password authentication method
CheckConnLiveness bool // Check connections for liveness before using them AllowOldPasswords bool // Allows the old insecure password method
ClientFoundRows bool // Return number of matching rows instead of rows changed CheckConnLiveness bool // Check connections for liveness before using them
ColumnsWithAlias bool // Prepend table alias to column names ClientFoundRows bool // Return number of matching rows instead of rows changed
InterpolateParams bool // Interpolate placeholders into query string ColumnsWithAlias bool // Prepend table alias to column names
MultiStatements bool // Allow multiple statements in one query InterpolateParams bool // Interpolate placeholders into query string
ParseTime bool // Parse time values to time.Time MultiStatements bool // Allow multiple statements in one query
RejectReadOnly bool // Reject read-only connections ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections
} }
// NewConfig creates a new Config and sets default values. // NewConfig creates a new Config and sets default values.
@ -77,8 +78,8 @@ func NewConfig() *Config {
func (cfg *Config) Clone() *Config { func (cfg *Config) Clone() *Config {
cp := *cfg cp := *cfg
if cp.tls != nil { if cp.TLS != nil {
cp.tls = cfg.tls.Clone() cp.TLS = cfg.TLS.Clone()
} }
if len(cp.Params) > 0 { if len(cp.Params) > 0 {
cp.Params = make(map[string]string, len(cfg.Params)) cp.Params = make(map[string]string, len(cfg.Params))
@ -119,24 +120,29 @@ func (cfg *Config) normalize() error {
cfg.Addr = ensureHavePort(cfg.Addr) cfg.Addr = ensureHavePort(cfg.Addr)
} }
switch cfg.TLSConfig { if cfg.TLS == nil {
case "false", "": switch cfg.TLSConfig {
// don't set anything case "false", "":
case "true": // don't set anything
cfg.tls = &tls.Config{} case "true":
case "skip-verify", "preferred": cfg.TLS = &tls.Config{}
cfg.tls = &tls.Config{InsecureSkipVerify: true} case "skip-verify":
default: cfg.TLS = &tls.Config{InsecureSkipVerify: true}
cfg.tls = getTLSConfigClone(cfg.TLSConfig) case "preferred":
if cfg.tls == nil { cfg.TLS = &tls.Config{InsecureSkipVerify: true}
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig) cfg.AllowFallbackToPlaintext = true
default:
cfg.TLS = getTLSConfigClone(cfg.TLSConfig)
if cfg.TLS == nil {
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
}
} }
} }
if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { if cfg.TLS != nil && cfg.TLS.ServerName == "" && !cfg.TLS.InsecureSkipVerify {
host, _, err := net.SplitHostPort(cfg.Addr) host, _, err := net.SplitHostPort(cfg.Addr)
if err == nil { if err == nil {
cfg.tls.ServerName = host cfg.TLS.ServerName = host
} }
} }
@ -204,6 +210,10 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true") writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true")
} }
if cfg.AllowFallbackToPlaintext {
writeDSNParam(&buf, &hasParam, "allowFallbackToPlaintext", "true")
}
if !cfg.AllowNativePasswords { if !cfg.AllowNativePasswords {
writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false") writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false")
} }
@ -391,6 +401,14 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return errors.New("invalid bool value: " + value) return errors.New("invalid bool value: " + value)
} }
// Allow fallback to unencrypted connection if server does not support TLS
case "allowFallbackToPlaintext":
var isBool bool
cfg.AllowFallbackToPlaintext, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Use native password authentication // Use native password authentication
case "allowNativePasswords": case "allowNativePasswords":
var isBool bool var isBool bool
@ -426,7 +444,6 @@ func parseDSNParams(cfg *Config, params string) (err error) {
// Collation // Collation
case "collation": case "collation":
cfg.Collation = value cfg.Collation = value
break
case "columnsWithAlias": case "columnsWithAlias":
var isBool bool var isBool bool

View file

@ -56,10 +56,22 @@ func SetLogger(logger Logger) error {
// MySQLError is an error type which represents a single MySQL error // MySQLError is an error type which represents a single MySQL error
type MySQLError struct { type MySQLError struct {
Number uint16 Number uint16
Message string SQLState [5]byte
Message string
} }
func (me *MySQLError) Error() string { func (me *MySQLError) Error() string {
if me.SQLState != [5]byte{} {
return fmt.Sprintf("Error %d (%s): %s", me.Number, me.SQLState, me.Message)
}
return fmt.Sprintf("Error %d: %s", me.Number, me.Message) return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
} }
func (me *MySQLError) Is(err error) bool {
if merr, ok := err.(*MySQLError); ok {
return merr.Number == me.Number
}
return false
}

View file

@ -41,6 +41,9 @@ func (mf *mysqlField) typeDatabaseName() string {
case fieldTypeJSON: case fieldTypeJSON:
return "JSON" return "JSON"
case fieldTypeLong: case fieldTypeLong:
if mf.flags&flagUnsigned != 0 {
return "UNSIGNED INT"
}
return "INT" return "INT"
case fieldTypeLongBLOB: case fieldTypeLongBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != collations[binaryCollation] {
@ -48,6 +51,9 @@ func (mf *mysqlField) typeDatabaseName() string {
} }
return "LONGBLOB" return "LONGBLOB"
case fieldTypeLongLong: case fieldTypeLongLong:
if mf.flags&flagUnsigned != 0 {
return "UNSIGNED BIGINT"
}
return "BIGINT" return "BIGINT"
case fieldTypeMediumBLOB: case fieldTypeMediumBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != collations[binaryCollation] {
@ -63,6 +69,9 @@ func (mf *mysqlField) typeDatabaseName() string {
case fieldTypeSet: case fieldTypeSet:
return "SET" return "SET"
case fieldTypeShort: case fieldTypeShort:
if mf.flags&flagUnsigned != 0 {
return "UNSIGNED SMALLINT"
}
return "SMALLINT" return "SMALLINT"
case fieldTypeString: case fieldTypeString:
if mf.charSet == collations[binaryCollation] { if mf.charSet == collations[binaryCollation] {
@ -74,6 +83,9 @@ func (mf *mysqlField) typeDatabaseName() string {
case fieldTypeTimestamp: case fieldTypeTimestamp:
return "TIMESTAMP" return "TIMESTAMP"
case fieldTypeTiny: case fieldTypeTiny:
if mf.flags&flagUnsigned != 0 {
return "UNSIGNED TINYINT"
}
return "TINYINT" return "TINYINT"
case fieldTypeTinyBLOB: case fieldTypeTinyBLOB:
if mf.charSet != collations[binaryCollation] { if mf.charSet != collations[binaryCollation] {
@ -106,7 +118,7 @@ var (
scanTypeInt64 = reflect.TypeOf(int64(0)) scanTypeInt64 = reflect.TypeOf(int64(0))
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
scanTypeNullTime = reflect.TypeOf(nullTime{}) scanTypeNullTime = reflect.TypeOf(sql.NullTime{})
scanTypeUint8 = reflect.TypeOf(uint8(0)) scanTypeUint8 = reflect.TypeOf(uint8(0))
scanTypeUint16 = reflect.TypeOf(uint16(0)) scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0)) scanTypeUint32 = reflect.TypeOf(uint32(0))

View file

@ -6,6 +6,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file, // License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/. // You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build gofuzz
// +build gofuzz // +build gofuzz
package mysql package mysql

View file

@ -28,12 +28,11 @@ var (
// Alternatively you can allow the use of all local files with // Alternatively you can allow the use of all local files with
// the DSN parameter 'allowAllFiles=true' // the DSN parameter 'allowAllFiles=true'
// //
// filePath := "/home/gopher/data.csv" // filePath := "/home/gopher/data.csv"
// mysql.RegisterLocalFile(filePath) // mysql.RegisterLocalFile(filePath)
// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo") // err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo")
// if err != nil { // if err != nil {
// ... // ...
//
func RegisterLocalFile(filePath string) { func RegisterLocalFile(filePath string) {
fileRegisterLock.Lock() fileRegisterLock.Lock()
// lazy map init // lazy map init
@ -58,15 +57,14 @@ func DeregisterLocalFile(filePath string) {
// If the handler returns a io.ReadCloser Close() is called when the // If the handler returns a io.ReadCloser Close() is called when the
// request is finished. // request is finished.
// //
// mysql.RegisterReaderHandler("data", func() io.Reader { // mysql.RegisterReaderHandler("data", func() io.Reader {
// var csvReader io.Reader // Some Reader that returns CSV data // var csvReader io.Reader // Some Reader that returns CSV data
// ... // Open Reader here // ... // Open Reader here
// return csvReader // return csvReader
// }) // })
// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo") // err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo")
// if err != nil { // if err != nil {
// ... // ...
//
func RegisterReaderHandler(name string, handler func() io.Reader) { func RegisterReaderHandler(name string, handler func() io.Reader) {
readerRegisterLock.Lock() readerRegisterLock.Lock()
// lazy map init // lazy map init
@ -93,10 +91,12 @@ func deferredClose(err *error, closer io.Closer) {
} }
} }
const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
func (mc *mysqlConn) handleInFileRequest(name string) (err error) { func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
var rdr io.Reader var rdr io.Reader
var data []byte var data []byte
packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP packetSize := defaultPacketSize
if mc.maxWriteSize < packetSize { if mc.maxWriteSize < packetSize {
packetSize = mc.maxWriteSize packetSize = mc.maxWriteSize
} }

View file

@ -9,11 +9,32 @@
package mysql package mysql
import ( import (
"database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"time" "time"
) )
// NullTime represents a time.Time that may be NULL.
// NullTime implements the Scanner interface so
// it can be used as a scan destination:
//
// var nt NullTime
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
// ...
// if nt.Valid {
// // use nt.Time
// } else {
// // NULL value
// }
//
// # This NullTime implementation is not driver-specific
//
// Deprecated: NullTime doesn't honor the loc DSN parameter.
// NullTime.Scan interprets a time as UTC, not the loc DSN parameter.
// Use sql.NullTime instead.
type NullTime sql.NullTime
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
// The value type must be time.Time or string / []byte (formatted time-string), // The value type must be time.Time or string / []byte (formatted time-string),
// otherwise Scan fails. // otherwise Scan fails.

View file

@ -1,40 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build go1.13
package mysql
import (
"database/sql"
)
// NullTime represents a time.Time that may be NULL.
// NullTime implements the Scanner interface so
// it can be used as a scan destination:
//
// var nt NullTime
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
// ...
// if nt.Valid {
// // use nt.Time
// } else {
// // NULL value
// }
//
// This NullTime implementation is not driver-specific
//
// Deprecated: NullTime doesn't honor the loc DSN parameter.
// NullTime.Scan interprets a time as UTC, not the loc DSN parameter.
// Use sql.NullTime instead.
type NullTime sql.NullTime
// for internal use.
// the mysql package uses sql.NullTime if it is available.
// if not, the package uses mysql.NullTime.
type nullTime = sql.NullTime // sql.NullTime is available

View file

@ -1,39 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build !go1.13
package mysql
import (
"time"
)
// NullTime represents a time.Time that may be NULL.
// NullTime implements the Scanner interface so
// it can be used as a scan destination:
//
// var nt NullTime
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
// ...
// if nt.Valid {
// // use nt.Time
// } else {
// // NULL value
// }
//
// This NullTime implementation is not driver-specific
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// for internal use.
// the mysql package uses sql.NullTime if it is available.
// if not, the package uses mysql.NullTime.
type nullTime = NullTime // sql.NullTime is not available

View file

@ -110,14 +110,13 @@ func (mc *mysqlConn) writePacket(data []byte) error {
conn = mc.rawConn conn = mc.rawConn
} }
var err error var err error
// If this connection has a ReadTimeout which we've been setting on if mc.cfg.CheckConnLiveness {
// reads, reset it to its default value before we attempt a non-blocking if mc.cfg.ReadTimeout != 0 {
// read, otherwise the scheduler will just time us out before we can read err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout))
if mc.cfg.ReadTimeout != 0 { }
err = conn.SetReadDeadline(time.Time{}) if err == nil {
} err = connCheck(conn)
if err == nil && mc.cfg.CheckConnLiveness { }
err = connCheck(conn)
} }
if err != nil { if err != nil {
errLog.Print("closing bad idle connection: ", err) errLog.Print("closing bad idle connection: ", err)
@ -223,9 +222,9 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
if mc.flags&clientProtocol41 == 0 { if mc.flags&clientProtocol41 == 0 {
return nil, "", ErrOldProtocol return nil, "", ErrOldProtocol
} }
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
if mc.cfg.TLSConfig == "preferred" { if mc.cfg.AllowFallbackToPlaintext {
mc.cfg.tls = nil mc.cfg.TLS = nil
} else { } else {
return nil, "", ErrNoTLS return nil, "", ErrNoTLS
} }
@ -293,7 +292,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
} }
// To enable TLS / SSL // To enable TLS / SSL
if mc.cfg.tls != nil { if mc.cfg.TLS != nil {
clientFlags |= clientSSL clientFlags |= clientSSL
} }
@ -357,14 +356,14 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// SSL Connection Request Packet // SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if mc.cfg.tls != nil { if mc.cfg.TLS != nil {
// Send TLS / SSL request packet // Send TLS / SSL request packet
if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
return err return err
} }
// Switch to TLS // Switch to TLS
tlsConn := tls.Client(mc.netConn, mc.cfg.tls) tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
return err return err
} }
@ -588,19 +587,20 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
return driver.ErrBadConn return driver.ErrBadConn
} }
me := &MySQLError{Number: errno}
pos := 3 pos := 3
// SQL State [optional: # + 5bytes string] // SQL State [optional: # + 5bytes string]
if data[3] == 0x23 { if data[3] == 0x23 {
//sqlstate := string(data[4 : 4+5]) copy(me.SQLState[:], data[4:4+5])
pos = 9 pos = 9
} }
// Error Message [string] // Error Message [string]
return &MySQLError{ me.Message = string(data[pos:])
Number: errno,
Message: string(data[pos:]), return me
}
} }
func readStatus(b []byte) statusFlag { func readStatus(b []byte) statusFlag {
@ -761,40 +761,40 @@ func (rows *textRows) readRow(dest []driver.Value) error {
} }
// RowSet Packet // RowSet Packet
var n int var (
var isNull bool n int
pos := 0 isNull bool
pos int = 0
)
for i := range dest { for i := range dest {
// Read bytes and convert to string // Read bytes and convert to string
dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
pos += n pos += n
if err == nil {
if !isNull {
if !mc.parseTime {
continue
} else {
switch rows.rs.columns[i].fieldType {
case fieldTypeTimestamp, fieldTypeDateTime,
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
dest[i].([]byte),
mc.cfg.Loc,
)
if err == nil {
continue
}
default:
continue
}
}
} else { if err != nil {
dest[i] = nil return err
continue }
if isNull {
dest[i] = nil
continue
}
if !mc.parseTime {
continue
}
// Parse time field
switch rows.rs.columns[i].fieldType {
case fieldTypeTimestamp,
fieldTypeDateTime,
fieldTypeDate,
fieldTypeNewDate:
if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil {
return err
} }
} }
return err // err != nil
} }
return nil return nil

View file

@ -23,7 +23,7 @@ type mysqlStmt struct {
} }
func (stmt *mysqlStmt) Close() error { func (stmt *mysqlStmt) Close() error {
if stmt.mc == nil || stmt.mc.closed.IsSet() { if stmt.mc == nil || stmt.mc.closed.Load() {
// driver.Stmt.Close can be called more than once, thus this function // driver.Stmt.Close can be called more than once, thus this function
// has to be idempotent. // has to be idempotent.
// See also Issue #450 and golang/go#16019. // See also Issue #450 and golang/go#16019.
@ -50,7 +50,7 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
} }
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.closed.IsSet() { if stmt.mc.closed.Load() {
errLog.Print(ErrInvalidConn) errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
@ -98,7 +98,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
} }
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
if stmt.mc.closed.IsSet() { if stmt.mc.closed.Load() {
errLog.Print(ErrInvalidConn) errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
@ -157,7 +157,7 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
if driver.IsValue(sv) { if driver.IsValue(sv) {
return sv, nil return sv, nil
} }
// A value returend from the Valuer interface can be "a type handled by // A value returned from the Valuer interface can be "a type handled by
// a database driver's NamedValueChecker interface" so we should accept // a database driver's NamedValueChecker interface" so we should accept
// uint64 here as well. // uint64 here as well.
if u, ok := sv.(uint64); ok { if u, ok := sv.(uint64); ok {

View file

@ -13,7 +13,7 @@ type mysqlTx struct {
} }
func (tx *mysqlTx) Commit() (err error) { func (tx *mysqlTx) Commit() (err error) {
if tx.mc == nil || tx.mc.closed.IsSet() { if tx.mc == nil || tx.mc.closed.Load() {
return ErrInvalidConn return ErrInvalidConn
} }
err = tx.mc.exec("COMMIT") err = tx.mc.exec("COMMIT")
@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) {
} }
func (tx *mysqlTx) Rollback() (err error) { func (tx *mysqlTx) Rollback() (err error) {
if tx.mc == nil || tx.mc.closed.IsSet() { if tx.mc == nil || tx.mc.closed.Load() {
return ErrInvalidConn return ErrInvalidConn
} }
err = tx.mc.exec("ROLLBACK") err = tx.mc.exec("ROLLBACK")

View file

@ -35,26 +35,25 @@ var (
// Note: The provided tls.Config is exclusively owned by the driver after // Note: The provided tls.Config is exclusively owned by the driver after
// registering it. // registering it.
// //
// rootCertPool := x509.NewCertPool() // rootCertPool := x509.NewCertPool()
// pem, err := ioutil.ReadFile("/path/ca-cert.pem") // pem, err := ioutil.ReadFile("/path/ca-cert.pem")
// if err != nil { // if err != nil {
// log.Fatal(err) // log.Fatal(err)
// } // }
// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { // if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
// log.Fatal("Failed to append PEM.") // log.Fatal("Failed to append PEM.")
// } // }
// clientCert := make([]tls.Certificate, 0, 1) // clientCert := make([]tls.Certificate, 0, 1)
// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") // certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
// if err != nil { // if err != nil {
// log.Fatal(err) // log.Fatal(err)
// } // }
// clientCert = append(clientCert, certs) // clientCert = append(clientCert, certs)
// mysql.RegisterTLSConfig("custom", &tls.Config{ // mysql.RegisterTLSConfig("custom", &tls.Config{
// RootCAs: rootCertPool, // RootCAs: rootCertPool,
// Certificates: clientCert, // Certificates: clientCert,
// }) // })
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
//
func RegisterTLSConfig(key string, config *tls.Config) error { func RegisterTLSConfig(key string, config *tls.Config) error {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" { if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" {
return fmt.Errorf("key '%s' is reserved", key) return fmt.Errorf("key '%s' is reserved", key)
@ -118,10 +117,6 @@ func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
if err != nil { if err != nil {
return time.Time{}, err return time.Time{}, err
} }
if year <= 0 {
year = 1
}
if b[4] != '-' { if b[4] != '-' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4]) return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4])
} }
@ -130,9 +125,6 @@ func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
if err != nil { if err != nil {
return time.Time{}, err return time.Time{}, err
} }
if m <= 0 {
m = 1
}
month := time.Month(m) month := time.Month(m)
if b[7] != '-' { if b[7] != '-' {
@ -143,9 +135,6 @@ func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
if err != nil { if err != nil {
return time.Time{}, err return time.Time{}, err
} }
if day <= 0 {
day = 1
}
if len(b) == 10 { if len(b) == 10 {
return time.Date(year, month, day, 0, 0, 0, 0, loc), nil return time.Date(year, month, day, 0, 0, 0, 0, loc), nil
} }
@ -199,7 +188,7 @@ func parseByteYear(b []byte) (int, error) {
return 0, err return 0, err
} }
year += v * n year += v * n
n = n / 10 n /= 10
} }
return year, nil return year, nil
} }
@ -542,7 +531,7 @@ func stringToInt(b []byte) int {
return val return val
} }
// returns the string read as a bytes slice, wheter the value is NULL, // returns the string read as a bytes slice, whether the value is NULL,
// the number of bytes read and an error, in case the string is longer than // the number of bytes read and an error, in case the string is longer than
// the input slice // the input slice
func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
@ -652,32 +641,32 @@ func escapeBytesBackslash(buf, v []byte) []byte {
for _, c := range v { for _, c := range v {
switch c { switch c {
case '\x00': case '\x00':
buf[pos] = '\\'
buf[pos+1] = '0' buf[pos+1] = '0'
buf[pos] = '\\'
pos += 2 pos += 2
case '\n': case '\n':
buf[pos] = '\\'
buf[pos+1] = 'n' buf[pos+1] = 'n'
buf[pos] = '\\'
pos += 2 pos += 2
case '\r': case '\r':
buf[pos] = '\\'
buf[pos+1] = 'r' buf[pos+1] = 'r'
buf[pos] = '\\'
pos += 2 pos += 2
case '\x1a': case '\x1a':
buf[pos] = '\\'
buf[pos+1] = 'Z' buf[pos+1] = 'Z'
buf[pos] = '\\'
pos += 2 pos += 2
case '\'': case '\'':
buf[pos] = '\\'
buf[pos+1] = '\'' buf[pos+1] = '\''
buf[pos] = '\\'
pos += 2 pos += 2
case '"': case '"':
buf[pos] = '\\'
buf[pos+1] = '"' buf[pos+1] = '"'
buf[pos] = '\\'
pos += 2 pos += 2
case '\\': case '\\':
buf[pos] = '\\'
buf[pos+1] = '\\' buf[pos+1] = '\\'
buf[pos] = '\\'
pos += 2 pos += 2
default: default:
buf[pos] = c buf[pos] = c
@ -697,32 +686,32 @@ func escapeStringBackslash(buf []byte, v string) []byte {
c := v[i] c := v[i]
switch c { switch c {
case '\x00': case '\x00':
buf[pos] = '\\'
buf[pos+1] = '0' buf[pos+1] = '0'
buf[pos] = '\\'
pos += 2 pos += 2
case '\n': case '\n':
buf[pos] = '\\'
buf[pos+1] = 'n' buf[pos+1] = 'n'
buf[pos] = '\\'
pos += 2 pos += 2
case '\r': case '\r':
buf[pos] = '\\'
buf[pos+1] = 'r' buf[pos+1] = 'r'
buf[pos] = '\\'
pos += 2 pos += 2
case '\x1a': case '\x1a':
buf[pos] = '\\'
buf[pos+1] = 'Z' buf[pos+1] = 'Z'
buf[pos] = '\\'
pos += 2 pos += 2
case '\'': case '\'':
buf[pos] = '\\'
buf[pos+1] = '\'' buf[pos+1] = '\''
buf[pos] = '\\'
pos += 2 pos += 2
case '"': case '"':
buf[pos] = '\\'
buf[pos+1] = '"' buf[pos+1] = '"'
buf[pos] = '\\'
pos += 2 pos += 2
case '\\': case '\\':
buf[pos] = '\\'
buf[pos+1] = '\\' buf[pos+1] = '\\'
buf[pos] = '\\'
pos += 2 pos += 2
default: default:
buf[pos] = c buf[pos] = c
@ -744,8 +733,8 @@ func escapeBytesQuotes(buf, v []byte) []byte {
for _, c := range v { for _, c := range v {
if c == '\'' { if c == '\'' {
buf[pos] = '\''
buf[pos+1] = '\'' buf[pos+1] = '\''
buf[pos] = '\''
pos += 2 pos += 2
} else { } else {
buf[pos] = c buf[pos] = c
@ -764,8 +753,8 @@ func escapeStringQuotes(buf []byte, v string) []byte {
for i := 0; i < len(v); i++ { for i := 0; i < len(v); i++ {
c := v[i] c := v[i]
if c == '\'' { if c == '\'' {
buf[pos] = '\''
buf[pos+1] = '\'' buf[pos+1] = '\''
buf[pos] = '\''
pos += 2 pos += 2
} else { } else {
buf[pos] = c buf[pos] = c
@ -790,39 +779,16 @@ type noCopy struct{}
// Lock is a no-op used by -copylocks checker from `go vet`. // Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock() {} func (*noCopy) Lock() {}
// atomicBool is a wrapper around uint32 for usage as a boolean value with // Unlock is a no-op used by -copylocks checker from `go vet`.
// atomic access. // noCopy should implement sync.Locker from Go 1.11
type atomicBool struct { // https://github.com/golang/go/commit/c2eba53e7f80df21d51285879d51ab81bcfbf6bc
_noCopy noCopy // https://github.com/golang/go/issues/26165
value uint32 func (*noCopy) Unlock() {}
}
// IsSet returns whether the current boolean value is true
func (ab *atomicBool) IsSet() bool {
return atomic.LoadUint32(&ab.value) > 0
}
// Set sets the value of the bool regardless of the previous value
func (ab *atomicBool) Set(value bool) {
if value {
atomic.StoreUint32(&ab.value, 1)
} else {
atomic.StoreUint32(&ab.value, 0)
}
}
// TrySet sets the value of the bool and returns whether the value changed
func (ab *atomicBool) TrySet(value bool) bool {
if value {
return atomic.SwapUint32(&ab.value, 1) == 0
}
return atomic.SwapUint32(&ab.value, 0) > 0
}
// atomicError is a wrapper for atomically accessed error values // atomicError is a wrapper for atomically accessed error values
type atomicError struct { type atomicError struct {
_noCopy noCopy _ noCopy
value atomic.Value value atomic.Value
} }
// Set sets the error value regardless of the previous value. // Set sets the error value regardless of the previous value.

View file

@ -353,6 +353,20 @@ func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error {
return nil return nil
} }
func callbackRetGeneric(ctx *C.sqlite3_context, v reflect.Value) error {
if v.IsNil() {
C.sqlite3_result_null(ctx)
return nil
}
cb, err := callbackRet(v.Elem().Type())
if err != nil {
return err
}
return cb(ctx, v.Elem())
}
func callbackRet(typ reflect.Type) (callbackRetConverter, error) { func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
switch typ.Kind() { switch typ.Kind() {
case reflect.Interface: case reflect.Interface:
@ -360,6 +374,11 @@ func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
if typ.Implements(errorInterface) { if typ.Implements(errorInterface) {
return callbackRetNil, nil return callbackRetNil, nil
} }
if typ.NumMethod() == 0 {
return callbackRetGeneric, nil
}
fallthrough fallthrough
case reflect.Slice: case reflect.Slice:
if typ.Elem().Kind() != reflect.Uint8 { if typ.Elem().Kind() != reflect.Uint8 {

File diff suppressed because it is too large Load diff

View file

@ -147,9 +147,9 @@ extern "C" {
** [sqlite3_libversion_number()], [sqlite3_sourceid()], ** [sqlite3_libversion_number()], [sqlite3_sourceid()],
** [sqlite_version()] and [sqlite_source_id()]. ** [sqlite_version()] and [sqlite_source_id()].
*/ */
#define SQLITE_VERSION "3.38.0" #define SQLITE_VERSION "3.39.2"
#define SQLITE_VERSION_NUMBER 3038000 #define SQLITE_VERSION_NUMBER 3039002
#define SQLITE_SOURCE_ID "2022-02-22 18:58:40 40fa792d359f84c3b9e9d6623743e1a59826274e221df1bde8f47086968a1bab" #define SQLITE_SOURCE_ID "2022-07-21 15:24:47 698edb77537b67c41adc68f9b892db56bcf9a55e00371a61420f3ddd668e6603"
/* /*
** CAPI3REF: Run-Time Library Version Numbers ** CAPI3REF: Run-Time Library Version Numbers
@ -4980,6 +4980,10 @@ SQLITE_API int sqlite3_data_count(sqlite3_stmt *pStmt);
** even empty strings, are always zero-terminated. ^The return ** even empty strings, are always zero-terminated. ^The return
** value from sqlite3_column_blob() for a zero-length BLOB is a NULL pointer. ** value from sqlite3_column_blob() for a zero-length BLOB is a NULL pointer.
** **
** ^Strings returned by sqlite3_column_text16() always have the endianness
** which is native to the platform, regardless of the text encoding set
** for the database.
**
** <b>Warning:</b> ^The object returned by [sqlite3_column_value()] is an ** <b>Warning:</b> ^The object returned by [sqlite3_column_value()] is an
** [unprotected sqlite3_value] object. In a multithreaded environment, ** [unprotected sqlite3_value] object. In a multithreaded environment,
** an unprotected sqlite3_value object may only be used safely with ** an unprotected sqlite3_value object may only be used safely with
@ -4993,7 +4997,7 @@ SQLITE_API int sqlite3_data_count(sqlite3_stmt *pStmt);
** [application-defined SQL functions] or [virtual tables], not within ** [application-defined SQL functions] or [virtual tables], not within
** top-level application code. ** top-level application code.
** **
** The these routines may attempt to convert the datatype of the result. ** These routines may attempt to convert the datatype of the result.
** ^For example, if the internal representation is FLOAT and a text result ** ^For example, if the internal representation is FLOAT and a text result
** is requested, [sqlite3_snprintf()] is used internally to perform the ** is requested, [sqlite3_snprintf()] is used internally to perform the
** conversion automatically. ^(The following table details the conversions ** conversion automatically. ^(The following table details the conversions
@ -5018,7 +5022,7 @@ SQLITE_API int sqlite3_data_count(sqlite3_stmt *pStmt);
** <tr><td> TEXT <td> BLOB <td> No change ** <tr><td> TEXT <td> BLOB <td> No change
** <tr><td> BLOB <td> INTEGER <td> [CAST] to INTEGER ** <tr><td> BLOB <td> INTEGER <td> [CAST] to INTEGER
** <tr><td> BLOB <td> FLOAT <td> [CAST] to REAL ** <tr><td> BLOB <td> FLOAT <td> [CAST] to REAL
** <tr><td> BLOB <td> TEXT <td> Add a zero terminator if needed ** <tr><td> BLOB <td> TEXT <td> [CAST] to TEXT, ensure zero terminator
** </table> ** </table>
** </blockquote>)^ ** </blockquote>)^
** **
@ -5590,7 +5594,8 @@ SQLITE_API unsigned int sqlite3_value_subtype(sqlite3_value*);
** object D and returns a pointer to that copy. ^The [sqlite3_value] returned ** object D and returns a pointer to that copy. ^The [sqlite3_value] returned
** is a [protected sqlite3_value] object even if the input is not. ** is a [protected sqlite3_value] object even if the input is not.
** ^The sqlite3_value_dup(V) interface returns NULL if V is NULL or if a ** ^The sqlite3_value_dup(V) interface returns NULL if V is NULL or if a
** memory allocation fails. ** memory allocation fails. ^If V is a [pointer value], then the result
** of sqlite3_value_dup(V) is a NULL value.
** **
** ^The sqlite3_value_free(V) interface frees an [sqlite3_value] object ** ^The sqlite3_value_free(V) interface frees an [sqlite3_value] object
** previously obtained from [sqlite3_value_dup()]. ^If V is a NULL pointer ** previously obtained from [sqlite3_value_dup()]. ^If V is a NULL pointer
@ -6272,6 +6277,28 @@ SQLITE_API int sqlite3_get_autocommit(sqlite3*);
*/ */
SQLITE_API sqlite3 *sqlite3_db_handle(sqlite3_stmt*); SQLITE_API sqlite3 *sqlite3_db_handle(sqlite3_stmt*);
/*
** CAPI3REF: Return The Schema Name For A Database Connection
** METHOD: sqlite3
**
** ^The sqlite3_db_name(D,N) interface returns a pointer to the schema name
** for the N-th database on database connection D, or a NULL pointer of N is
** out of range. An N value of 0 means the main database file. An N of 1 is
** the "temp" schema. Larger values of N correspond to various ATTACH-ed
** databases.
**
** Space to hold the string that is returned by sqlite3_db_name() is managed
** by SQLite itself. The string might be deallocated by any operation that
** changes the schema, including [ATTACH] or [DETACH] or calls to
** [sqlite3_serialize()] or [sqlite3_deserialize()], even operations that
** occur on a different thread. Applications that need to
** remember the string long-term should make their own copy. Applications that
** are accessing the same database connection simultaneously on multiple
** threads should mutex-protect calls to this API and should make their own
** private copy of the result prior to releasing the mutex.
*/
SQLITE_API const char *sqlite3_db_name(sqlite3 *db, int N);
/* /*
** CAPI3REF: Return The Filename For A Database Connection ** CAPI3REF: Return The Filename For A Database Connection
** METHOD: sqlite3 ** METHOD: sqlite3
@ -9551,8 +9578,8 @@ SQLITE_API SQLITE_EXPERIMENTAL const char *sqlite3_vtab_collation(sqlite3_index_
** of a [virtual table] implementation. The result of calling this ** of a [virtual table] implementation. The result of calling this
** interface from outside of xBestIndex() is undefined and probably harmful. ** interface from outside of xBestIndex() is undefined and probably harmful.
** **
** ^The sqlite3_vtab_distinct() interface returns an integer that is ** ^The sqlite3_vtab_distinct() interface returns an integer between 0 and
** either 0, 1, or 2. The integer returned by sqlite3_vtab_distinct() ** 3. The integer returned by sqlite3_vtab_distinct()
** gives the virtual table additional information about how the query ** gives the virtual table additional information about how the query
** planner wants the output to be ordered. As long as the virtual table ** planner wants the output to be ordered. As long as the virtual table
** can meet the ordering requirements of the query planner, it may set ** can meet the ordering requirements of the query planner, it may set
@ -9584,6 +9611,13 @@ SQLITE_API SQLITE_EXPERIMENTAL const char *sqlite3_vtab_collation(sqlite3_index_
** that have the same value for all columns identified by "aOrderBy". ** that have the same value for all columns identified by "aOrderBy".
** ^However omitting the extra rows is optional. ** ^However omitting the extra rows is optional.
** This mode is used for a DISTINCT query. ** This mode is used for a DISTINCT query.
** <li value="3"><p>
** ^(If the sqlite3_vtab_distinct() interface returns 3, that means
** that the query planner needs only distinct rows but it does need the
** rows to be sorted.)^ ^The virtual table implementation is free to omit
** rows that are identical in all aOrderBy columns, if it wants to, but
** it is not required to omit any rows. This mode is used for queries
** that have both DISTINCT and ORDER BY clauses.
** </ol> ** </ol>
** **
** ^For the purposes of comparing virtual table output values to see if the ** ^For the purposes of comparing virtual table output values to see if the
@ -9768,7 +9802,7 @@ SQLITE_API int sqlite3_vtab_in_next(sqlite3_value *pVal, sqlite3_value **ppOut);
** ^When xBestIndex returns, the sqlite3_value object returned by ** ^When xBestIndex returns, the sqlite3_value object returned by
** sqlite3_vtab_rhs_value() is automatically deallocated. ** sqlite3_vtab_rhs_value() is automatically deallocated.
** **
** The "_rhs_" in the name of this routine is an appreviation for ** The "_rhs_" in the name of this routine is an abbreviation for
** "Right-Hand Side". ** "Right-Hand Side".
*/ */
SQLITE_API int sqlite3_vtab_rhs_value(sqlite3_index_info*, int, sqlite3_value **ppVal); SQLITE_API int sqlite3_vtab_rhs_value(sqlite3_index_info*, int, sqlite3_value **ppVal);

View file

@ -1,13 +0,0 @@
// Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build sqlite_json sqlite_json1 json1
package sqlite3
/*
#cgo CFLAGS: -DSQLITE_ENABLE_JSON1
*/
import "C"

View file

@ -33,7 +33,7 @@ import (
// The callback is passed a SQLitePreUpdateData struct with the data for // The callback is passed a SQLitePreUpdateData struct with the data for
// the update, as well as methods for fetching copies of impacted data. // the update, as well as methods for fetching copies of impacted data.
// //
// If there is an existing update hook for this connection, it will be // If there is an existing preupdate hook for this connection, it will be
// removed. If callback is nil the existing hook (if any) will be removed // removed. If callback is nil the existing hook (if any) will be removed
// without creating a new one. // without creating a new one.
func (c *SQLiteConn) RegisterPreUpdateHook(callback func(SQLitePreUpdateData)) { func (c *SQLiteConn) RegisterPreUpdateHook(callback func(SQLitePreUpdateData)) {

View file

@ -13,7 +13,7 @@ package sqlite3
// The callback is passed a SQLitePreUpdateData struct with the data for // The callback is passed a SQLitePreUpdateData struct with the data for
// the update, as well as methods for fetching copies of impacted data. // the update, as well as methods for fetching copies of impacted data.
// //
// If there is an existing update hook for this connection, it will be // If there is an existing preupdate hook for this connection, it will be
// removed. If callback is nil the existing hook (if any) will be removed // removed. If callback is nil the existing hook (if any) will be removed
// without creating a new one. // without creating a new one.
func (c *SQLiteConn) RegisterPreUpdateHook(callback func(SQLitePreUpdateData)) { func (c *SQLiteConn) RegisterPreUpdateHook(callback func(SQLitePreUpdateData)) {

View file

@ -356,6 +356,12 @@ struct sqlite3_api_routines {
int (*vtab_in)(sqlite3_index_info*,int,int); int (*vtab_in)(sqlite3_index_info*,int,int);
int (*vtab_in_first)(sqlite3_value*,sqlite3_value**); int (*vtab_in_first)(sqlite3_value*,sqlite3_value**);
int (*vtab_in_next)(sqlite3_value*,sqlite3_value**); int (*vtab_in_next)(sqlite3_value*,sqlite3_value**);
/* Version 3.39.0 and later */
int (*deserialize)(sqlite3*,const char*,unsigned char*,
sqlite3_int64,sqlite3_int64,unsigned);
unsigned char *(*serialize)(sqlite3*,const char *,sqlite3_int64*,
unsigned int);
const char *(*db_name)(sqlite3*,int);
}; };
/* /*
@ -674,6 +680,12 @@ typedef int (*sqlite3_loadext_entry)(
#define sqlite3_vtab_in sqlite3_api->vtab_in #define sqlite3_vtab_in sqlite3_api->vtab_in
#define sqlite3_vtab_in_first sqlite3_api->vtab_in_first #define sqlite3_vtab_in_first sqlite3_api->vtab_in_first
#define sqlite3_vtab_in_next sqlite3_api->vtab_in_next #define sqlite3_vtab_in_next sqlite3_api->vtab_in_next
/* Version 3.39.0 and later */
#ifndef SQLITE_OMIT_DESERIALIZE
#define sqlite3_deserialize sqlite3_api->deserialize
#define sqlite3_serialize sqlite3_api->serialize
#endif
#define sqlite3_db_name sqlite3_api->db_name
#endif /* !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION) */ #endif /* !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION) */
#if !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION) #if !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION)

3
vendor/golang.org/x/crypto/AUTHORS generated vendored
View file

@ -1,3 +0,0 @@
# This source code refers to The Go Authors for copyright purposes.
# The master list of authors is in the main Go distribution,
# visible at https://tip.golang.org/AUTHORS.

View file

@ -1,3 +0,0 @@
# This source code was written by the Go contributors.
# The master list of contributors is in the main Go distribution,
# visible at https://tip.golang.org/CONTRIBUTORS.

View file

@ -12,7 +12,7 @@ import (
"errors" "errors"
"math/bits" "math/bits"
"golang.org/x/crypto/internal/subtle" "golang.org/x/crypto/internal/alias"
) )
const ( const (
@ -189,7 +189,7 @@ func (s *Cipher) XORKeyStream(dst, src []byte) {
panic("chacha20: output smaller than input") panic("chacha20: output smaller than input")
} }
dst = dst[:len(src)] dst = dst[:len(src)]
if subtle.InexactOverlap(dst, src) { if alias.InexactOverlap(dst, src) {
panic("chacha20: invalid buffer overlap") panic("chacha20: invalid buffer overlap")
} }

View file

@ -15,6 +15,7 @@ const bufSize = 256
// xorKeyStreamVX is an assembly implementation of XORKeyStream. It must only // xorKeyStreamVX is an assembly implementation of XORKeyStream. It must only
// be called when the vector facility is available. Implementation in asm_s390x.s. // be called when the vector facility is available. Implementation in asm_s390x.s.
//
//go:noescape //go:noescape
func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32) func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32)

View file

@ -9,7 +9,8 @@ package curve25519 // import "golang.org/x/crypto/curve25519"
import ( import (
"crypto/subtle" "crypto/subtle"
"fmt" "errors"
"strconv"
"golang.org/x/crypto/curve25519/internal/field" "golang.org/x/crypto/curve25519/internal/field"
) )
@ -124,10 +125,10 @@ func X25519(scalar, point []byte) ([]byte, error) {
func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) { func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
var in [32]byte var in [32]byte
if l := len(scalar); l != 32 { if l := len(scalar); l != 32 {
return nil, fmt.Errorf("bad scalar length: %d, expected %d", l, 32) return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32")
} }
if l := len(point); l != 32 { if l := len(point); l != 32 {
return nil, fmt.Errorf("bad point length: %d, expected %d", l, 32) return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32")
} }
copy(in[:], scalar) copy(in[:], scalar)
if &point[0] == &Basepoint[0] { if &point[0] == &Basepoint[0] {
@ -138,7 +139,7 @@ func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
copy(base[:], point) copy(base[:], point)
ScalarMult(dst, &in, &base) ScalarMult(dst, &in, &base)
if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 { if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
return nil, fmt.Errorf("bad input point: low order point") return nil, errors.New("bad input point: low order point")
} }
} }
return dst[:], nil return dst[:], nil

View file

@ -1,13 +1,16 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT. // Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego // +build amd64,gc,!purego
package field package field
// feMul sets out = a * b. It works like feMulGeneric. // feMul sets out = a * b. It works like feMulGeneric.
//
//go:noescape //go:noescape
func feMul(out *Element, a *Element, b *Element) func feMul(out *Element, a *Element, b *Element)
// feSquare sets out = a * a. It works like feSquareGeneric. // feSquare sets out = a * a. It works like feSquareGeneric.
//
//go:noescape //go:noescape
func feSquare(out *Element, a *Element) func feSquare(out *Element, a *Element)

View file

@ -5,9 +5,8 @@
//go:build !purego //go:build !purego
// +build !purego // +build !purego
// Package subtle implements functions that are often useful in cryptographic // Package alias implements memory aliasing tests.
// code but require careful thought to use correctly. package alias
package subtle // import "golang.org/x/crypto/internal/subtle"
import "unsafe" import "unsafe"

View file

@ -5,9 +5,8 @@
//go:build purego //go:build purego
// +build purego // +build purego
// Package subtle implements functions that are often useful in cryptographic // Package alias implements memory aliasing tests.
// code but require careful thought to use correctly. package alias
package subtle // import "golang.org/x/crypto/internal/subtle"
// This is the Google App Engine standard variant based on reflect // This is the Google App Engine standard variant based on reflect
// because the unsafe package and cgo are disallowed. // because the unsafe package and cgo are disallowed.

View file

@ -136,7 +136,7 @@ func shiftRightBy2(a uint128) uint128 {
// updateGeneric absorbs msg into the state.h accumulator. For each chunk m of // updateGeneric absorbs msg into the state.h accumulator. For each chunk m of
// 128 bits of message, it computes // 128 bits of message, it computes
// //
// h₊ = (h + m) * r mod 2¹³⁰ - 5 // h₊ = (h + m) * r mod 2¹³⁰ - 5
// //
// If the msg length is not a multiple of TagSize, it assumes the last // If the msg length is not a multiple of TagSize, it assumes the last
// incomplete chunk is the final one. // incomplete chunk is the final one.
@ -278,8 +278,7 @@ const (
// finalize completes the modular reduction of h and computes // finalize completes the modular reduction of h and computes
// //
// out = h + s mod 2¹²⁸ // out = h + s mod 2¹²⁸
//
func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) { func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) {
h0, h1, h2 := h[0], h[1], h[2] h0, h1, h2 := h[0], h[1], h[2]

View file

@ -14,6 +14,7 @@ import (
// updateVX is an assembly implementation of Poly1305 that uses vector // updateVX is an assembly implementation of Poly1305 that uses vector
// instructions. It must only be called if the vector facility (vx) is // instructions. It must only be called if the vector facility (vx) is
// available. // available.
//
//go:noescape //go:noescape
func updateVX(state *macState, msg []byte) func updateVX(state *macState, msg []byte)

View file

@ -35,8 +35,8 @@ This package is interoperable with NaCl: https://nacl.cr.yp.to/secretbox.html.
package secretbox // import "golang.org/x/crypto/nacl/secretbox" package secretbox // import "golang.org/x/crypto/nacl/secretbox"
import ( import (
"golang.org/x/crypto/internal/alias"
"golang.org/x/crypto/internal/poly1305" "golang.org/x/crypto/internal/poly1305"
"golang.org/x/crypto/internal/subtle"
"golang.org/x/crypto/salsa20/salsa" "golang.org/x/crypto/salsa20/salsa"
) )
@ -88,7 +88,7 @@ func Seal(out, message []byte, nonce *[24]byte, key *[32]byte) []byte {
copy(poly1305Key[:], firstBlock[:]) copy(poly1305Key[:], firstBlock[:])
ret, out := sliceForAppend(out, len(message)+poly1305.TagSize) ret, out := sliceForAppend(out, len(message)+poly1305.TagSize)
if subtle.AnyOverlap(out, message) { if alias.AnyOverlap(out, message) {
panic("nacl: invalid buffer overlap") panic("nacl: invalid buffer overlap")
} }
@ -147,7 +147,7 @@ func Open(out, box []byte, nonce *[24]byte, key *[32]byte) ([]byte, bool) {
} }
ret, out := sliceForAppend(out, len(box)-Overhead) ret, out := sliceForAppend(out, len(box)-Overhead)
if subtle.AnyOverlap(out, box) { if alias.AnyOverlap(out, box) {
panic("nacl: invalid buffer overlap") panic("nacl: invalid buffer overlap")
} }

View file

@ -23,12 +23,14 @@ import (
// A Block represents an OpenPGP armored structure. // A Block represents an OpenPGP armored structure.
// //
// The encoded form is: // The encoded form is:
// -----BEGIN Type-----
// Headers
// //
// base64-encoded Bytes // -----BEGIN Type-----
// '=' base64 encoded checksum // Headers
// -----END Type----- //
// base64-encoded Bytes
// '=' base64 encoded checksum
// -----END Type-----
//
// where Headers is a possibly empty sequence of Key: Value lines. // where Headers is a possibly empty sequence of Key: Value lines.
// //
// Since the armored data can be very large, this package presents a streaming // Since the armored data can be very large, this package presents a streaming

View file

@ -96,7 +96,8 @@ func (l *lineBreaker) Close() (err error) {
// trailer. // trailer.
// //
// It's built into a stack of io.Writers: // It's built into a stack of io.Writers:
// encoding -> base64 encoder -> lineBreaker -> out //
// encoding -> base64 encoder -> lineBreaker -> out
type encoding struct { type encoding struct {
out io.Writer out io.Writer
breaker *lineBreaker breaker *lineBreaker

View file

@ -77,8 +77,8 @@ func Encrypt(random io.Reader, pub *PublicKey, msg []byte) (c1, c2 *big.Int, err
// returns the plaintext of the message. An error can result only if the // returns the plaintext of the message. An error can result only if the
// ciphertext is invalid. Users should keep in mind that this is a padding // ciphertext is invalid. Users should keep in mind that this is a padding
// oracle and thus, if exposed to an adaptive chosen ciphertext attack, can // oracle and thus, if exposed to an adaptive chosen ciphertext attack, can
// be used to break the cryptosystem. See ``Chosen Ciphertext Attacks // be used to break the cryptosystem. See Chosen Ciphertext Attacks
// Against Protocols Based on the RSA Encryption Standard PKCS #1'', Daniel // Against Protocols Based on the RSA Encryption Standard PKCS #1, Daniel
// Bleichenbacher, Advances in Cryptology (Crypto '98), // Bleichenbacher, Advances in Cryptology (Crypto '98),
func Decrypt(priv *PrivateKey, c1, c2 *big.Int) (msg []byte, err error) { func Decrypt(priv *PrivateKey, c1, c2 *big.Int) (msg []byte, err error) {
s := new(big.Int).Exp(c1, priv.X, priv.P) s := new(big.Int).Exp(c1, priv.X, priv.P)

View file

@ -7,7 +7,6 @@ package packet
import ( import (
"bytes" "bytes"
"io" "io"
"io/ioutil"
"golang.org/x/crypto/openpgp/errors" "golang.org/x/crypto/openpgp/errors"
) )
@ -26,7 +25,7 @@ type OpaquePacket struct {
} }
func (op *OpaquePacket) parse(r io.Reader) (err error) { func (op *OpaquePacket) parse(r io.Reader) (err error) {
op.Contents, err = ioutil.ReadAll(r) op.Contents, err = io.ReadAll(r)
return return
} }

View file

@ -13,7 +13,6 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1"
"io" "io"
"io/ioutil"
"math/big" "math/big"
"strconv" "strconv"
"time" "time"
@ -133,7 +132,7 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
} }
} }
pk.encryptedData, err = ioutil.ReadAll(r) pk.encryptedData, err = io.ReadAll(r)
if err != nil { if err != nil {
return return
} }

View file

@ -236,7 +236,7 @@ func (w *seMDCWriter) Close() (err error) {
return w.w.Close() return w.w.Close()
} }
// noOpCloser is like an ioutil.NopCloser, but for an io.Writer. // noOpCloser is like an io.NopCloser, but for an io.Writer.
type noOpCloser struct { type noOpCloser struct {
w io.Writer w io.Writer
} }

View file

@ -9,7 +9,6 @@ import (
"image" "image"
"image/jpeg" "image/jpeg"
"io" "io"
"io/ioutil"
) )
const UserAttrImageSubpacket = 1 const UserAttrImageSubpacket = 1
@ -56,7 +55,7 @@ func NewUserAttribute(contents ...*OpaqueSubpacket) *UserAttribute {
func (uat *UserAttribute) parse(r io.Reader) (err error) { func (uat *UserAttribute) parse(r io.Reader) (err error) {
// RFC 4880, section 5.13 // RFC 4880, section 5.13
b, err := ioutil.ReadAll(r) b, err := io.ReadAll(r)
if err != nil { if err != nil {
return return
} }

View file

@ -6,7 +6,6 @@ package packet
import ( import (
"io" "io"
"io/ioutil"
"strings" "strings"
) )
@ -66,7 +65,7 @@ func NewUserId(name, comment, email string) *UserId {
func (uid *UserId) parse(r io.Reader) (err error) { func (uid *UserId) parse(r io.Reader) (err error) {
// RFC 4880, section 5.11 // RFC 4880, section 5.11
b, err := ioutil.ReadAll(r) b, err := io.ReadAll(r)
if err != nil { if err != nil {
return return
} }

View file

@ -402,7 +402,7 @@ func (s signatureWriter) Close() error {
return s.encryptedData.Close() return s.encryptedData.Close()
} }
// noOpCloser is like an ioutil.NopCloser, but for an io.Writer. // noOpCloser is like an io.NopCloser, but for an io.Writer.
// TODO: we have two of these in OpenPGP packages alone. This probably needs // TODO: we have two of these in OpenPGP packages alone. This probably needs
// to be promoted somewhere more common. // to be promoted somewhere more common.
type noOpCloser struct { type noOpCloser struct {

View file

@ -251,7 +251,7 @@ type algorithmOpenSSHCertSigner struct {
// private key is held by signer. It returns an error if the public key in cert // private key is held by signer. It returns an error if the public key in cert
// doesn't match the key used by signer. // doesn't match the key used by signer.
func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
return nil, errors.New("ssh: signer and cert have different public key") return nil, errors.New("ssh: signer and cert have different public key")
} }
@ -460,6 +460,8 @@ func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
// certKeyAlgoNames is a mapping from known certificate algorithm names to the // certKeyAlgoNames is a mapping from known certificate algorithm names to the
// corresponding public key signature algorithm. // corresponding public key signature algorithm.
//
// This map must be kept in sync with the one in agent/client.go.
var certKeyAlgoNames = map[string]string{ var certKeyAlgoNames = map[string]string{
CertAlgoRSAv01: KeyAlgoRSA, CertAlgoRSAv01: KeyAlgoRSA,
CertAlgoRSASHA256v01: KeyAlgoRSASHA256, CertAlgoRSASHA256v01: KeyAlgoRSASHA256,

View file

@ -15,7 +15,6 @@ import (
"fmt" "fmt"
"hash" "hash"
"io" "io"
"io/ioutil"
"golang.org/x/crypto/chacha20" "golang.org/x/crypto/chacha20"
"golang.org/x/crypto/internal/poly1305" "golang.org/x/crypto/internal/poly1305"
@ -497,7 +496,7 @@ func (c *cbcCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error)
// data, to make distinguishing between // data, to make distinguishing between
// failing MAC and failing length check more // failing MAC and failing length check more
// difficult. // difficult.
io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage)) io.CopyN(io.Discard, r, int64(c.oracleCamouflage))
} }
} }
return p, err return p, err
@ -640,7 +639,7 @@ const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com // chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com
// AEAD, which is described here: // AEAD, which is described here:
// //
// https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00 // https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00
// //
// the methods here also implement padding, which RFC4253 Section 6 // the methods here also implement padding, which RFC4253 Section 6
// also requires of stream ciphers. // also requires of stream ciphers.

View file

@ -12,8 +12,9 @@ the multiplexed nature of SSH is exposed to users that wish to support
others. others.
References: References:
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
This package does not fall under the stability promise of the Go language itself, This package does not fall under the stability promise of the Go language itself,
so its API may be changed when pressing needs arise. so its API may be changed when pressing needs arise.

View file

@ -13,7 +13,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"sync" "sync"
) )
@ -124,7 +123,7 @@ type Session struct {
// output and error. // output and error.
// //
// If either is nil, Run connects the corresponding file // If either is nil, Run connects the corresponding file
// descriptor to an instance of ioutil.Discard. There is a // descriptor to an instance of io.Discard. There is a
// fixed amount of buffering that is shared for the two streams. // fixed amount of buffering that is shared for the two streams.
// If either blocks it may eventually cause the remote // If either blocks it may eventually cause the remote
// command to block. // command to block.
@ -506,7 +505,7 @@ func (s *Session) stdout() {
return return
} }
if s.Stdout == nil { if s.Stdout == nil {
s.Stdout = ioutil.Discard s.Stdout = io.Discard
} }
s.copyFuncs = append(s.copyFuncs, func() error { s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stdout, s.ch) _, err := io.Copy(s.Stdout, s.ch)
@ -519,7 +518,7 @@ func (s *Session) stderr() {
return return
} }
if s.Stderr == nil { if s.Stderr == nil {
s.Stderr = ioutil.Discard s.Stderr = io.Discard
} }
s.copyFuncs = append(s.copyFuncs, func() error { s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stderr, s.ch.Stderr()) _, err := io.Copy(s.Stderr, s.ch.Stderr())

21
vendor/gorm.io/datatypes/License generated vendored Normal file
View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

198
vendor/gorm.io/datatypes/README.md generated vendored Normal file
View file

@ -0,0 +1,198 @@
# GORM Data Types
## JSON
sqlite, mysql, postgres supported
```go
import "gorm.io/datatypes"
type UserWithJSON struct {
gorm.Model
Name string
Attributes datatypes.JSON
}
DB.Create(&User{
Name: "json-1",
Attributes: datatypes.JSON([]byte(`{"name": "jinzhu", "age": 18, "tags": ["tag1", "tag2"], "orgs": {"orga": "orga"}}`)),
}
// Check JSON has keys
datatypes.JSONQuery("attributes").HasKey(value, keys...)
db.Find(&user, datatypes.JSONQuery("attributes").HasKey("role"))
db.Find(&user, datatypes.JSONQuery("attributes").HasKey("orgs", "orga"))
// MySQL
// SELECT * FROM `users` WHERE JSON_EXTRACT(`attributes`, '$.role') IS NOT NULL
// SELECT * FROM `users` WHERE JSON_EXTRACT(`attributes`, '$.orgs.orga') IS NOT NULL
// PostgreSQL
// SELECT * FROM "user" WHERE "attributes"::jsonb ? 'role'
// SELECT * FROM "user" WHERE "attributes"::jsonb -> 'orgs' ? 'orga'
// Check JSON extract value from keys equal to value
datatypes.JSONQuery("attributes").Equals(value, keys...)
DB.First(&user, datatypes.JSONQuery("attributes").Equals("jinzhu", "name"))
DB.First(&user, datatypes.JSONQuery("attributes").Equals("orgb", "orgs", "orgb"))
// MySQL
// SELECT * FROM `user` WHERE JSON_EXTRACT(`attributes`, '$.name') = "jinzhu"
// SELECT * FROM `user` WHERE JSON_EXTRACT(`attributes`, '$.orgs.orgb') = "orgb"
// PostgreSQL
// SELECT * FROM "user" WHERE json_extract_path_text("attributes"::json,'name') = 'jinzhu'
// SELECT * FROM "user" WHERE json_extract_path_text("attributes"::json,'orgs','orgb') = 'orgb'
```
NOTE: SQlite need to build with `json1` tag, e.g: `go build --tags json1`, refer https://github.com/mattn/go-sqlite3#usage
## Date
```go
import "gorm.io/datatypes"
type UserWithDate struct {
gorm.Model
Name string
Date datatypes.Date
}
user := UserWithDate{Name: "jinzhu", Date: datatypes.Date(time.Now())}
DB.Create(&user)
// INSERT INTO `user_with_dates` (`name`,`date`) VALUES ("jinzhu","2020-07-17 00:00:00")
DB.First(&result, "name = ? AND date = ?", "jinzhu", datatypes.Date(curTime))
// SELECT * FROM user_with_dates WHERE name = "jinzhu" AND date = "2020-07-17 00:00:00" ORDER BY `user_with_dates`.`id` LIMIT 1
```
## Time
MySQL, PostgreSQL, SQLite, SQLServer are supported.
Time with nanoseconds is supported for some databases which support for time with fractional second scale.
```go
import "gorm.io/datatypes"
type UserWithTime struct {
gorm.Model
Name string
Time datatypes.Time
}
user := UserWithTime{Name: "jinzhu", Time: datatypes.NewTime(1, 2, 3, 0)}
DB.Create(&user)
// INSERT INTO `user_with_times` (`name`,`time`) VALUES ("jinzhu","01:02:03")
DB.First(&result, "name = ? AND time = ?", "jinzhu", datatypes.NewTime(1, 2, 3, 0))
// SELECT * FROM user_with_times WHERE name = "jinzhu" AND time = "01:02:03" ORDER BY `user_with_times`.`id` LIMIT 1
```
NOTE: If the current using database is SQLite, the field column type is defined as `TEXT` type
when GORM AutoMigrate because SQLite doesn't have time type.
## JSON_SET
sqlite, mysql supported
```go
import (
"gorm.io/datatypes"
"gorm.io/gorm"
)
type UserWithJSON struct {
gorm.Model
Name string
Attributes datatypes.JSON
}
DB.Create(&UserWithJSON{
Name: "json-1",
Attributes: datatypes.JSON([]byte(`{"name": "json-1", "age": 18, "tags": ["tag1", "tag2"], "orgs": {"orga": "orga"}}`)),
})
type User struct {
Name string
Age int
}
friend := User{
Name: "Bob",
Age: 21,
}
// Set fields of JSON column
datatypes.JSONSet("attributes").Set("age", 20).Set("tags[0]", "tag2").Set("orgs.orga", "orgb")
DB.Model(&UserWithJSON{}).Where("name = ?", "json-1").UpdateColumn("attributes", datatypes.JSONSet("attributes").Set("age", 20).Set("tags[0]", "tag3").Set("orgs.orga", "orgb"))
DB.Model(&UserWithJSON{}).Where("name = ?", "json-1").UpdateColumn("attributes", datatypes.JSONSet("attributes").Set("phones", []string{"10085", "10086"}))
DB.Model(&UserWithJSON{}).Where("name = ?", "json-1").UpdateColumn("attributes", datatypes.JSONSet("attributes").Set("phones", gorm.Expr("CAST(? AS JSON)", `["10085", "10086"]`)))
DB.Model(&UserWithJSON{}).Where("name = ?", "json-1").UpdateColumn("attributes", datatypes.JSONSet("attributes").Set("friend", friend))
// MySQL
// UPDATE `user_with_jsons` SET `attributes` = JSON_SET(`attributes`, '$.tags[0]', 'tag3', '$.orgs.orga', 'orgb', '$.age', 20) WHERE name = 'json-1'
// UPDATE `user_with_jsons` SET `attributes` = JSON_SET(`attributes`, '$.phones', CAST('["10085", "10086"]' AS JSON)) WHERE name = 'json-1'
// UPDATE `user_with_jsons` SET `attributes` = JSON_SET(`attributes`, '$.phones', CAST('["10085", "10086"]' AS JSON)) WHERE name = 'json-1'
// UPDATE `user_with_jsons` SET `attributes` = JSON_SET(`attributes`, '$.friend', CAST('{"Name": "Bob", "Age": 21}' AS JSON)) WHERE name = 'json-1'
```
NOTE: MariaDB does not support CAST(? AS JSON).
## JSONType[T]
sqlite, mysql, postgres supported
```go
import "gorm.io/datatypes"
type Attribute struct {
Sex int
Age int
Orgs map[string]string
Tags []string
Admin bool
Role string
}
type UserWithJSON struct {
gorm.Model
Name string
Attributes datatypes.JSONType[Attribute]
}
var user = UserWithJSON{
Name: "hello"
Attributes: datatypes.JSONType[Attribute]{
Data: Attribute{
Age: 18,
Sex: 1,
Orgs: map[string]string{"orga": "orga"},
Tags: []string{"tag1", "tag2", "tag3"},
},
},
}
// Create
DB.Create(&user)
// First
var result UserWithJSON
DB.First(&result, user.ID)
// Update
jsonMap = UserWithJSON{
Attributes: datatypes.JSONType[Attribute]{
Data: Attribute{
Age: 18,
Sex: 1,
Orgs: map[string]string{"orga": "orga"},
Tags: []string{"tag1", "tag2", "tag3"},
},
},
}
DB.Model(&user).Updates(jsonMap)
```
NOTE: it's not support json query

42
vendor/gorm.io/datatypes/date.go generated vendored Normal file
View file

@ -0,0 +1,42 @@
package datatypes
import (
"database/sql"
"database/sql/driver"
"time"
)
type Date time.Time
func (date *Date) Scan(value interface{}) (err error) {
nullTime := &sql.NullTime{}
err = nullTime.Scan(value)
*date = Date(nullTime.Time)
return
}
func (date Date) Value() (driver.Value, error) {
y, m, d := time.Time(date).Date()
return time.Date(y, m, d, 0, 0, 0, 0, time.Time(date).Location()), nil
}
// GormDataType gorm common data type
func (date Date) GormDataType() string {
return "date"
}
func (date Date) GobEncode() ([]byte, error) {
return time.Time(date).GobEncode()
}
func (date *Date) GobDecode(b []byte) error {
return (*time.Time)(date).GobDecode(b)
}
func (date Date) MarshalJSON() ([]byte, error) {
return time.Time(date).MarshalJSON()
}
func (date *Date) UnmarshalJSON(b []byte) error {
return (*time.Time)(date).UnmarshalJSON(b)
}

374
vendor/gorm.io/datatypes/json.go generated vendored Normal file
View file

@ -0,0 +1,374 @@
package datatypes
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
// JSON defined JSON data type, need to implements driver.Valuer, sql.Scanner interface
type JSON json.RawMessage
// Value return json value, implement driver.Valuer interface
func (j JSON) Value() (driver.Value, error) {
if len(j) == 0 {
return nil, nil
}
return string(j), nil
}
// Scan scan value into Jsonb, implements sql.Scanner interface
func (j *JSON) Scan(value interface{}) error {
if value == nil {
*j = JSON("null")
return nil
}
var bytes []byte
switch v := value.(type) {
case []byte:
if len(v) > 0 {
bytes = make([]byte, len(v))
copy(bytes, v)
}
case string:
bytes = []byte(v)
default:
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
}
result := json.RawMessage(bytes)
*j = JSON(result)
return nil
}
// MarshalJSON to output non base64 encoded []byte
func (j JSON) MarshalJSON() ([]byte, error) {
return json.RawMessage(j).MarshalJSON()
}
// UnmarshalJSON to deserialize []byte
func (j *JSON) UnmarshalJSON(b []byte) error {
result := json.RawMessage{}
err := result.UnmarshalJSON(b)
*j = JSON(result)
return err
}
func (j JSON) String() string {
return string(j)
}
// GormDataType gorm common data type
func (JSON) GormDataType() string {
return "json"
}
// GormDBDataType gorm db data type
func (JSON) GormDBDataType(db *gorm.DB, field *schema.Field) string {
switch db.Dialector.Name() {
case "sqlite":
return "JSON"
case "mysql":
return "JSON"
case "postgres":
return "JSONB"
}
return ""
}
func (js JSON) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
if len(js) == 0 {
return gorm.Expr("NULL")
}
data, _ := js.MarshalJSON()
switch db.Dialector.Name() {
case "mysql":
if v, ok := db.Dialector.(*mysql.Dialector); ok && !strings.Contains(v.ServerVersion, "MariaDB") {
return gorm.Expr("CAST(? AS JSON)", string(data))
}
}
return gorm.Expr("?", string(data))
}
// JSONQueryExpression json query expression, implements clause.Expression interface to use as querier
type JSONQueryExpression struct {
column string
keys []string
hasKeys bool
equals bool
equalsValue interface{}
extract bool
path string
}
// JSONQuery query column as json
func JSONQuery(column string) *JSONQueryExpression {
return &JSONQueryExpression{column: column}
}
// Extract extract json with path
func (jsonQuery *JSONQueryExpression) Extract(path string) *JSONQueryExpression {
jsonQuery.extract = true
jsonQuery.path = path
return jsonQuery
}
// HasKey returns clause.Expression
func (jsonQuery *JSONQueryExpression) HasKey(keys ...string) *JSONQueryExpression {
jsonQuery.keys = keys
jsonQuery.hasKeys = true
return jsonQuery
}
// Keys returns clause.Expression
func (jsonQuery *JSONQueryExpression) Equals(value interface{}, keys ...string) *JSONQueryExpression {
jsonQuery.keys = keys
jsonQuery.equals = true
jsonQuery.equalsValue = value
return jsonQuery
}
// Build implements clause.Expression
func (jsonQuery *JSONQueryExpression) Build(builder clause.Builder) {
if stmt, ok := builder.(*gorm.Statement); ok {
switch stmt.Dialector.Name() {
case "mysql", "sqlite":
switch {
case jsonQuery.extract:
builder.WriteString("JSON_EXTRACT(")
builder.WriteQuoted(jsonQuery.column)
builder.WriteByte(',')
builder.AddVar(stmt, jsonQuery.path)
builder.WriteString(")")
case jsonQuery.hasKeys:
if len(jsonQuery.keys) > 0 {
builder.WriteString("JSON_EXTRACT(")
builder.WriteQuoted(jsonQuery.column)
builder.WriteByte(',')
builder.AddVar(stmt, jsonQueryJoin(jsonQuery.keys))
builder.WriteString(") IS NOT NULL")
}
case jsonQuery.equals:
if len(jsonQuery.keys) > 0 {
builder.WriteString("JSON_EXTRACT(")
builder.WriteQuoted(jsonQuery.column)
builder.WriteByte(',')
builder.AddVar(stmt, jsonQueryJoin(jsonQuery.keys))
builder.WriteString(") = ")
if value, ok := jsonQuery.equalsValue.(bool); ok {
builder.WriteString(strconv.FormatBool(value))
} else {
stmt.AddVar(builder, jsonQuery.equalsValue)
}
}
}
case "postgres":
switch {
case jsonQuery.hasKeys:
if len(jsonQuery.keys) > 0 {
stmt.WriteQuoted(jsonQuery.column)
stmt.WriteString("::jsonb")
for _, key := range jsonQuery.keys[0 : len(jsonQuery.keys)-1] {
stmt.WriteString(" -> ")
stmt.AddVar(builder, key)
}
stmt.WriteString(" ? ")
stmt.AddVar(builder, jsonQuery.keys[len(jsonQuery.keys)-1])
}
case jsonQuery.equals:
if len(jsonQuery.keys) > 0 {
builder.WriteString(fmt.Sprintf("json_extract_path_text(%v::json,", stmt.Quote(jsonQuery.column)))
for idx, key := range jsonQuery.keys {
if idx > 0 {
builder.WriteByte(',')
}
stmt.AddVar(builder, key)
}
builder.WriteString(") = ")
if _, ok := jsonQuery.equalsValue.(string); ok {
stmt.AddVar(builder, jsonQuery.equalsValue)
} else {
stmt.AddVar(builder, fmt.Sprint(jsonQuery.equalsValue))
}
}
}
}
}
}
// JSONOverlapsExpression JSON_OVERLAPS expression, implements clause.Expression interface to use as querier
type JSONOverlapsExpression struct {
column clause.Expression
val string
}
// JSONOverlaps query column as json
func JSONOverlaps(column clause.Expression, value string) *JSONOverlapsExpression {
return &JSONOverlapsExpression{
column: column,
val: value,
}
}
// Build implements clause.Expression
// only mysql support JSON_OVERLAPS
func (json *JSONOverlapsExpression) Build(builder clause.Builder) {
if stmt, ok := builder.(*gorm.Statement); ok {
switch stmt.Dialector.Name() {
case "mysql":
builder.WriteString("JSON_OVERLAPS(")
json.column.Build(builder)
builder.WriteString(",")
builder.AddVar(stmt, json.val)
builder.WriteString(")")
}
}
}
type columnExpression string
func Column(col string) columnExpression {
return columnExpression(col)
}
func (col columnExpression) Build(builder clause.Builder) {
if stmt, ok := builder.(*gorm.Statement); ok {
switch stmt.Dialector.Name() {
case "mysql", "sqlite", "postgres":
builder.WriteString(stmt.Quote(string(col)))
}
}
}
const prefix = "$."
func jsonQueryJoin(keys []string) string {
if len(keys) == 1 {
return prefix + keys[0]
}
n := len(prefix)
n += len(keys) - 1
for i := 0; i < len(keys); i++ {
n += len(keys[i])
}
var b strings.Builder
b.Grow(n)
b.WriteString(prefix)
b.WriteString(keys[0])
for _, key := range keys[1:] {
b.WriteString(".")
b.WriteString(key)
}
return b.String()
}
// JSONSetExpression json set expression, implements clause.Expression interface to use as updater
type JSONSetExpression struct {
column string
path2value map[string]interface{}
mutex sync.RWMutex
}
// JSONSet update fields of json column
func JSONSet(column string) *JSONSetExpression {
return &JSONSetExpression{column: column, path2value: make(map[string]interface{})}
}
// Set return clause.Expression
func (jsonSet *JSONSetExpression) Set(path string, value interface{}) *JSONSetExpression {
jsonSet.mutex.Lock()
jsonSet.path2value[path] = value
jsonSet.mutex.Unlock()
return jsonSet
}
// Build implements clause.Expression
// only support mysql and sqlite
func (jsonSet *JSONSetExpression) Build(builder clause.Builder) {
if stmt, ok := builder.(*gorm.Statement); ok {
switch stmt.Dialector.Name() {
case "mysql":
var isMariaDB bool
if v, ok := stmt.Dialector.(*mysql.Dialector); ok {
isMariaDB = strings.Contains(v.ServerVersion, "MariaDB")
}
builder.WriteString("JSON_SET(")
builder.WriteQuoted(jsonSet.column)
for path, value := range jsonSet.path2value {
builder.WriteByte(',')
builder.AddVar(stmt, prefix+path)
builder.WriteByte(',')
if _, ok := value.(clause.Expression); ok {
stmt.AddVar(builder, value)
continue
}
rv := reflect.ValueOf(value)
if rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
switch rv.Kind() {
case reflect.Slice, reflect.Array, reflect.Struct, reflect.Map:
b, _ := json.Marshal(value)
if isMariaDB {
stmt.AddVar(builder, string(b))
break
}
stmt.AddVar(builder, gorm.Expr("CAST(? AS JSON)", string(b)))
default:
stmt.AddVar(builder, value)
}
}
builder.WriteString(")")
case "sqlite":
builder.WriteString("JSON_SET(")
builder.WriteQuoted(jsonSet.column)
for path, value := range jsonSet.path2value {
builder.WriteByte(',')
builder.AddVar(stmt, prefix+path)
builder.WriteByte(',')
if _, ok := value.(clause.Expression); ok {
stmt.AddVar(builder, value)
continue
}
rv := reflect.ValueOf(value)
if rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
switch rv.Kind() {
case reflect.Slice, reflect.Array, reflect.Struct, reflect.Map:
b, _ := json.Marshal(value)
stmt.AddVar(builder, gorm.Expr("JSON(?)", string(b)))
default:
stmt.AddVar(builder, value)
}
}
builder.WriteString(")")
}
}
}

96
vendor/gorm.io/datatypes/json_map.go generated vendored Normal file
View file

@ -0,0 +1,96 @@
package datatypes
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
// JSONMap defined JSON data type, need to implements driver.Valuer, sql.Scanner interface
type JSONMap map[string]interface{}
// Value return json value, implement driver.Valuer interface
func (m JSONMap) Value() (driver.Value, error) {
if m == nil {
return nil, nil
}
ba, err := m.MarshalJSON()
return string(ba), err
}
// Scan scan value into Jsonb, implements sql.Scanner interface
func (m *JSONMap) Scan(val interface{}) error {
if val == nil {
*m = make(JSONMap)
return nil
}
var ba []byte
switch v := val.(type) {
case []byte:
ba = v
case string:
ba = []byte(v)
default:
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", val))
}
t := map[string]interface{}{}
err := json.Unmarshal(ba, &t)
*m = t
return err
}
// MarshalJSON to output non base64 encoded []byte
func (m JSONMap) MarshalJSON() ([]byte, error) {
if m == nil {
return []byte("null"), nil
}
t := (map[string]interface{})(m)
return json.Marshal(t)
}
// UnmarshalJSON to deserialize []byte
func (m *JSONMap) UnmarshalJSON(b []byte) error {
t := map[string]interface{}{}
err := json.Unmarshal(b, &t)
*m = JSONMap(t)
return err
}
// GormDataType gorm common data type
func (m JSONMap) GormDataType() string {
return "jsonmap"
}
// GormDBDataType gorm db data type
func (JSONMap) GormDBDataType(db *gorm.DB, field *schema.Field) string {
switch db.Dialector.Name() {
case "sqlite":
return "JSON"
case "mysql":
return "JSON"
case "postgres":
return "JSONB"
case "sqlserver":
return "NVARCHAR(MAX)"
}
return ""
}
func (jm JSONMap) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
data, _ := jm.MarshalJSON()
switch db.Dialector.Name() {
case "mysql":
if v, ok := db.Dialector.(*mysql.Dialector); ok && !strings.Contains(v.ServerVersion, "MariaDB") {
return gorm.Expr("CAST(? AS JSON)", string(data))
}
}
return gorm.Expr("?", string(data))
}

80
vendor/gorm.io/datatypes/json_type.go generated vendored Normal file
View file

@ -0,0 +1,80 @@
package datatypes
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
// JSONType give a generic data type for json encoded data.
type JSONType[T any] struct {
Data T
}
// Value return json value, implement driver.Valuer interface
func (j JSONType[T]) Value() (driver.Value, error) {
return json.Marshal(j.Data)
}
// Scan scan value into JSONType[T], implements sql.Scanner interface
func (j *JSONType[T]) Scan(value interface{}) error {
var bytes []byte
switch v := value.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
}
return json.Unmarshal(bytes, &j.Data)
}
// MarshalJSON to output non base64 encoded []byte
func (j JSONType[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(j.Data)
}
// UnmarshalJSON to deserialize []byte
func (j *JSONType[T]) UnmarshalJSON(b []byte) error {
return json.Unmarshal(b, &j.Data)
}
// GormDataType gorm common data type
func (JSONType[T]) GormDataType() string {
return "json"
}
// GormDBDataType gorm db data type
func (JSONType[T]) GormDBDataType(db *gorm.DB, field *schema.Field) string {
switch db.Dialector.Name() {
case "sqlite":
return "JSON"
case "mysql":
return "JSON"
case "postgres":
return "JSONB"
}
return ""
}
func (js JSONType[T]) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
data, _ := js.MarshalJSON()
switch db.Dialector.Name() {
case "mysql":
if v, ok := db.Dialector.(*mysql.Dialector); ok && !strings.Contains(v.ServerVersion, "MariaDB") {
return gorm.Expr("CAST(? AS JSON)", string(data))
}
}
return gorm.Expr("?", string(data))
}

10
vendor/gorm.io/datatypes/test_all.sh generated vendored Normal file
View file

@ -0,0 +1,10 @@
#!/bin/bash -e
dialects=("postgres" "postgres_simple" "mysql" "mssql" "sqlite")
for dialect in "${dialects[@]}" ; do
if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ]
then
GORM_DIALECT=${dialect} go test --tags "json1"
fi
done

123
vendor/gorm.io/datatypes/time.go generated vendored Normal file
View file

@ -0,0 +1,123 @@
package datatypes
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)
// Time is time data type.
type Time time.Duration
// NewTime is a constructor for Time and returns new Time.
func NewTime(hour, min, sec, nsec int) Time {
return newTime(hour, min, sec, nsec)
}
func newTime(hour, min, sec, nsec int) Time {
return Time(
time.Duration(hour)*time.Hour +
time.Duration(min)*time.Minute +
time.Duration(sec)*time.Second +
time.Duration(nsec)*time.Nanosecond,
)
}
// GormDataType returns gorm common data type. This type is used for the field's column type.
func (Time) GormDataType() string {
return "time"
}
// GormDBDataType returns gorm DB data type based on the current using database.
func (Time) GormDBDataType(db *gorm.DB, field *schema.Field) string {
switch db.Dialector.Name() {
case "mysql":
return "TIME"
case "postgres":
return "TIME"
case "sqlserver":
return "TIME"
case "sqlite":
return "TEXT"
default:
return ""
}
}
// Scan implements sql.Scanner interface and scans value into Time,
func (t *Time) Scan(src interface{}) error {
switch v := src.(type) {
case []byte:
t.setFromString(string(v))
case string:
t.setFromString(v)
case time.Time:
t.setFromTime(v)
default:
return errors.New(fmt.Sprintf("failed to scan value: %v", v))
}
return nil
}
func (t *Time) setFromString(str string) {
var h, m, s, n int
fmt.Sscanf(str, "%02d:%02d:%02d.%09d", &h, &m, &s, &n)
*t = newTime(h, m, s, n)
}
func (t *Time) setFromTime(src time.Time) {
*t = newTime(src.Hour(), src.Minute(), src.Second(), src.Nanosecond())
}
// Value implements driver.Valuer interface and returns string format of Time.
func (t Time) Value() (driver.Value, error) {
return t.String(), nil
}
// String implements fmt.Stringer interface.
func (t Time) String() string {
if nsec := t.nanoseconds(); nsec > 0 {
return fmt.Sprintf("%02d:%02d:%02d.%09d", t.hours(), t.minutes(), t.seconds(), nsec)
} else {
// omit nanoseconds unless any value is specified
return fmt.Sprintf("%02d:%02d:%02d", t.hours(), t.minutes(), t.seconds())
}
}
func (t Time) hours() int {
return int(time.Duration(t).Truncate(time.Hour).Hours())
}
func (t Time) minutes() int {
return int((time.Duration(t) % time.Hour).Truncate(time.Minute).Minutes())
}
func (t Time) seconds() int {
return int((time.Duration(t) % time.Minute).Truncate(time.Second).Seconds())
}
func (t Time) nanoseconds() int {
return int((time.Duration(t) % time.Second).Nanoseconds())
}
// MarshalJSON implements json.Marshaler to convert Time to json serialization.
func (t Time) MarshalJSON() ([]byte, error) {
return json.Marshal(t.String())
}
// UnmarshalJSON implements json.Unmarshaler to deserialize json data.
func (t *Time) UnmarshalJSON(data []byte) error {
// ignore null
if string(data) == "null" {
return nil
}
t.setFromString(strings.Trim(string(data), `"`))
return nil
}

66
vendor/gorm.io/datatypes/url.go generated vendored Normal file
View file

@ -0,0 +1,66 @@
package datatypes
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"net/url"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)
type URL url.URL
func (u URL) Value() (driver.Value, error) {
return u.String(), nil
}
func (u *URL) Scan(value interface{}) error {
var us string
switch v := value.(type) {
case []byte:
us = string(v)
case string:
us = v
default:
return errors.New(fmt.Sprint("Failed to parse URL:", value))
}
uu, err := url.Parse(us)
if err != nil {
return err
}
*u = URL(*uu)
return nil
}
func (URL) GormDataType() string {
return "url"
}
func (URL) GormDBDataType(db *gorm.DB, field *schema.Field) string {
return "TEXT"
}
func (u *URL) String() string {
return (*url.URL)(u).String()
}
func (u URL) MarshalJSON() ([]byte, error) {
return json.Marshal(u.String())
}
func (u *URL) UnmarshalJSON(data []byte) error {
// ignore null
if string(data) == "null" {
return nil
}
uu, err := url.Parse(strings.Trim(string(data), `"'`))
if err != nil {
return err
}
*u = URL(*uu)
return nil
}

6
vendor/gorm.io/driver/mysql/.gitignore generated vendored Normal file
View file

@ -0,0 +1,6 @@
TODO*
documents
coverage.txt
_book
.idea
vendor

View file

@ -11,6 +11,26 @@ import (
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
const indexSql = `
SELECT
TABLE_NAME,
COLUMN_NAME,
INDEX_NAME,
NON_UNIQUE
FROM
information_schema.STATISTICS
WHERE
TABLE_SCHEMA = ?
AND TABLE_NAME = ?
ORDER BY
INDEX_NAME,
SEQ_IN_INDEX`
var typeAliasMap = map[string][]string{
"bool": {"tinyint"},
"tinyint": {"bool"},
}
type Migrator struct { type Migrator struct {
migrator.Migrator migrator.Migrator
Dialector Dialector
@ -143,9 +163,9 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0) columnTypes := make([]gorm.ColumnType, 0)
err := m.RunWithValue(value, func(stmt *gorm.Statement) error { err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
var ( var (
currentDatabase = m.DB.Migrator().CurrentDatabase() currentDatabase, table = m.CurrentSchema(stmt, stmt.Table)
columnTypeSQL = "SELECT column_name, column_default, is_nullable = 'YES', data_type, character_maximum_length, column_type, column_key, extra, column_comment, numeric_precision, numeric_scale " columnTypeSQL = "SELECT column_name, column_default, is_nullable = 'YES', data_type, character_maximum_length, column_type, column_key, extra, column_comment, numeric_precision, numeric_scale "
rows, err = m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() rows, err = m.DB.Session(&gorm.Session{}).Table(table).Limit(1).Rows()
) )
if err != nil { if err != nil {
@ -163,7 +183,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
} }
columnTypeSQL += "FROM information_schema.columns WHERE table_schema = ? AND table_name = ? ORDER BY ORDINAL_POSITION" columnTypeSQL += "FROM information_schema.columns WHERE table_schema = ? AND table_name = ? ORDER BY ORDINAL_POSITION"
columns, rowErr := m.DB.Raw(columnTypeSQL, currentDatabase, stmt.Table).Rows() columns, rowErr := m.DB.Table(table).Raw(columnTypeSQL, currentDatabase, table).Rows()
if rowErr != nil { if rowErr != nil {
return rowErr return rowErr
} }
@ -203,6 +223,13 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
} }
column.DefaultValueValue.String = strings.Trim(column.DefaultValueValue.String, "'") column.DefaultValueValue.String = strings.Trim(column.DefaultValueValue.String, "'")
if m.Dialector.DontSupportNullAsDefaultValue {
// rewrite mariadb default value like other version
if column.DefaultValueValue.Valid && column.DefaultValueValue.String == "NULL" {
column.DefaultValueValue.Valid = false
column.DefaultValueValue.String = ""
}
}
if datetimePrecision.Valid { if datetimePrecision.Valid {
column.DecimalSizeValue = datetimePrecision column.DecimalSizeValue = datetimePrecision
@ -227,7 +254,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
func (m Migrator) CurrentDatabase() (name string) { func (m Migrator) CurrentDatabase() (name string) {
baseName := m.Migrator.CurrentDatabase() baseName := m.Migrator.CurrentDatabase()
m.DB.Raw( m.DB.Raw(
"SELECT SCHEMA_NAME from Information_schema.SCHEMATA where SCHEMA_NAME LIKE ? ORDER BY SCHEMA_NAME=? DESC limit 1", "SELECT SCHEMA_NAME from Information_schema.SCHEMATA where SCHEMA_NAME LIKE ? ORDER BY SCHEMA_NAME=? DESC,SCHEMA_NAME limit 1",
baseName+"%", baseName).Scan(&name) baseName+"%", baseName).Scan(&name)
return return
} }
@ -237,3 +264,66 @@ func (m Migrator) GetTables() (tableList []string, err error) {
Scan(&tableList).Error Scan(&tableList).Error
return return
} }
func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
indexes := make([]gorm.Index, 0)
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
result := make([]*Index, 0)
schema, table := m.CurrentSchema(stmt, stmt.Table)
scanErr := m.DB.Table(table).Raw(indexSql, schema, table).Scan(&result).Error
if scanErr != nil {
return scanErr
}
indexMap := groupByIndexName(result)
for _, idx := range indexMap {
tempIdx := &migrator.Index{
TableName: idx[0].TableName,
NameValue: idx[0].IndexName,
PrimaryKeyValue: sql.NullBool{
Bool: idx[0].IndexName == "PRIMARY",
Valid: true,
},
UniqueValue: sql.NullBool{
Bool: idx[0].NonUnique == 0,
Valid: true,
},
}
for _, x := range idx {
tempIdx.ColumnList = append(tempIdx.ColumnList, x.ColumnName)
}
indexes = append(indexes, tempIdx)
}
return nil
})
return indexes, err
}
// Index table index info
type Index struct {
TableName string `gorm:"column:TABLE_NAME"`
ColumnName string `gorm:"column:COLUMN_NAME"`
IndexName string `gorm:"column:INDEX_NAME"`
NonUnique int32 `gorm:"column:NON_UNIQUE"`
}
func groupByIndexName(indexList []*Index) map[string][]*Index {
columnIndexMap := make(map[string][]*Index, len(indexList))
for _, idx := range indexList {
columnIndexMap[idx.IndexName] = append(columnIndexMap[idx.IndexName], idx)
}
return columnIndexMap
}
func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (string, string) {
if tables := strings.Split(table, `.`); len(tables) == 2 {
return tables[0], tables[1]
}
m.DB = m.DB.Table(table)
return m.CurrentDatabase(), table
}
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
return typeAliasMap[databaseTypeName]
}

177
vendor/gorm.io/driver/mysql/mysql.go generated vendored
View file

@ -5,30 +5,37 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"math" "math"
"regexp"
"strconv"
"strings" "strings"
"time" "time"
_ "github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/callbacks" "gorm.io/gorm/callbacks"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
"gorm.io/gorm/migrator" "gorm.io/gorm/migrator"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
type Config struct { type Config struct {
DriverName string DriverName string
ServerVersion string ServerVersion string
DSN string DSN string
Conn gorm.ConnPool DSNConfig *mysql.Config
SkipInitializeWithVersion bool Conn gorm.ConnPool
DefaultStringSize uint SkipInitializeWithVersion bool
DefaultDatetimePrecision *int DefaultStringSize uint
DisableDatetimePrecision bool DefaultDatetimePrecision *int
DontSupportRenameIndex bool DisableWithReturning bool
DontSupportRenameColumn bool DisableDatetimePrecision bool
DontSupportForShareClause bool DontSupportRenameIndex bool
DontSupportRenameColumn bool
DontSupportForShareClause bool
DontSupportNullAsDefaultValue bool
} }
type Dialector struct { type Dialector struct {
@ -49,7 +56,8 @@ var (
) )
func Open(dsn string) gorm.Dialector { func Open(dsn string) gorm.Dialector {
return &Dialector{Config: &Config{DSN: dsn}} dsnConf, _ := mysql.ParseDSN(dsn)
return &Dialector{Config: &Config{DSN: dsn, DSNConfig: dsnConf}}
} }
func New(config Config) gorm.Dialector { func New(config Config) gorm.Dialector {
@ -69,30 +77,20 @@ func (dialector Dialector) NowFunc(n int) func() time.Time {
} }
func (dialector Dialector) Apply(config *gorm.Config) error { func (dialector Dialector) Apply(config *gorm.Config) error {
if config.NowFunc == nil { if config.NowFunc != nil {
if dialector.DefaultDatetimePrecision == nil { return nil
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
}
// while maintaining the readability of the code, separate the business logic from
// the general part and leave it to the function to do it here.
config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision)
} }
if dialector.DefaultDatetimePrecision == nil {
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
}
// while maintaining the readability of the code, separate the business logic from
// the general part and leave it to the function to do it here.
config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision)
return nil return nil
} }
func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
ctx := context.Background()
// register callbacks
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: CreateClauses,
QueryClauses: QueryClauses,
UpdateClauses: UpdateClauses,
DeleteClauses: DeleteClauses,
})
if dialector.DriverName == "" { if dialector.DriverName == "" {
dialector.DriverName = "mysql" dialector.DriverName = "mysql"
} }
@ -110,8 +108,9 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
} }
} }
withReturning := false
if !dialector.Config.SkipInitializeWithVersion { if !dialector.Config.SkipInitializeWithVersion {
err = db.ConnPool.QueryRowContext(ctx, "SELECT VERSION()").Scan(&dialector.ServerVersion) err = db.ConnPool.QueryRowContext(context.Background(), "SELECT VERSION()").Scan(&dialector.ServerVersion)
if err != nil { if err != nil {
return err return err
} }
@ -120,6 +119,8 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
dialector.Config.DontSupportRenameIndex = true dialector.Config.DontSupportRenameIndex = true
dialector.Config.DontSupportRenameColumn = true dialector.Config.DontSupportRenameColumn = true
dialector.Config.DontSupportForShareClause = true dialector.Config.DontSupportForShareClause = true
dialector.Config.DontSupportNullAsDefaultValue = true
withReturning = checkVersion(dialector.ServerVersion, "10.5")
} else if strings.HasPrefix(dialector.ServerVersion, "5.6.") { } else if strings.HasPrefix(dialector.ServerVersion, "5.6.") {
dialector.Config.DontSupportRenameIndex = true dialector.Config.DontSupportRenameIndex = true
dialector.Config.DontSupportRenameColumn = true dialector.Config.DontSupportRenameColumn = true
@ -135,6 +136,32 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
} }
} }
// register callbacks
callbackConfig := &callbacks.Config{
CreateClauses: CreateClauses,
QueryClauses: QueryClauses,
UpdateClauses: UpdateClauses,
DeleteClauses: DeleteClauses,
}
if !dialector.Config.DisableWithReturning && withReturning {
callbackConfig.LastInsertIDReversed = true
if !utils.Contains(callbackConfig.CreateClauses, "RETURNING") {
callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
}
if !utils.Contains(callbackConfig.UpdateClauses, "RETURNING") {
callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
}
if !utils.Contains(callbackConfig.DeleteClauses, "RETURNING") {
callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
}
}
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
for k, v := range dialector.ClauseBuilders() { for k, v := range dialector.ClauseBuilders() {
db.ClauseBuilders[k] = v db.ClauseBuilders[k] = v
} }
@ -146,7 +173,7 @@ const (
ClauseOnConflict = "ON CONFLICT" ClauseOnConflict = "ON CONFLICT"
// ClauseValues for clause.ClauseBuilder VALUES key // ClauseValues for clause.ClauseBuilder VALUES key
ClauseValues = "VALUES" ClauseValues = "VALUES"
// ClauseValues for clause.ClauseBuilder FOR key // ClauseFor for clause.ClauseBuilder FOR key
ClauseFor = "FOR" ClauseFor = "FOR"
) )
@ -174,6 +201,8 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
if column.Name != "" { if column.Name != "" {
onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}} onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
} }
builder.(*gorm.Statement).AddClause(onConflict)
} }
} }
@ -284,7 +313,23 @@ func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('`') writer.WriteByte('`')
} }
type localTimeInterface interface {
In(loc *time.Location) time.Time
}
func (dialector Dialector) Explain(sql string, vars ...interface{}) string { func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
if dialector.DSNConfig != nil && dialector.DSNConfig.Loc == time.Local {
for i, v := range vars {
if p, ok := v.(localTimeInterface); ok {
func(i int, t localTimeInterface) {
defer func() {
recover()
}()
vars[i] = t.In(time.Local)
}(i, p)
}
}
}
return logger.ExplainSQL(sql, nil, `'`, vars...) return logger.ExplainSQL(sql, nil, `'`, vars...)
} }
@ -302,9 +347,9 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
return dialector.getSchemaTimeType(field) return dialector.getSchemaTimeType(field)
case schema.Bytes: case schema.Bytes:
return dialector.getSchemaBytesType(field) return dialector.getSchemaBytesType(field)
default:
return dialector.getSchemaCustomType(field)
} }
return string(field.DataType)
} }
func (dialector Dialector) getSchemaFloatType(field *schema.Field) string { func (dialector Dialector) getSchemaFloatType(field *schema.Field) string {
@ -345,11 +390,11 @@ func (dialector Dialector) getSchemaStringType(field *schema.Field) string {
} }
func (dialector Dialector) getSchemaTimeType(field *schema.Field) string { func (dialector Dialector) getSchemaTimeType(field *schema.Field) string {
precision := ""
if !dialector.DisableDatetimePrecision && field.Precision == 0 { if !dialector.DisableDatetimePrecision && field.Precision == 0 {
field.Precision = *dialector.DefaultDatetimePrecision field.Precision = *dialector.DefaultDatetimePrecision
} }
var precision string
if field.Precision > 0 { if field.Precision > 0 {
precision = fmt.Sprintf("(%d)", field.Precision) precision = fmt.Sprintf("(%d)", field.Precision)
} }
@ -373,23 +418,37 @@ func (dialector Dialector) getSchemaBytesType(field *schema.Field) string {
} }
func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string { func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string {
sqlType := "bigint" constraint := func(sqlType string) string {
if field.DataType == schema.Uint {
sqlType += " unsigned"
}
if field.NotNull {
sqlType += " NOT NULL"
}
if field.AutoIncrement {
sqlType += " AUTO_INCREMENT"
}
return sqlType
}
switch { switch {
case field.Size <= 8: case field.Size <= 8:
sqlType = "tinyint" return constraint("tinyint")
case field.Size <= 16: case field.Size <= 16:
sqlType = "smallint" return constraint("smallint")
case field.Size <= 24: case field.Size <= 24:
sqlType = "mediumint" return constraint("mediumint")
case field.Size <= 32: case field.Size <= 32:
sqlType = "int" return constraint("int")
default:
return constraint("bigint")
} }
}
if field.DataType == schema.Uint { func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
sqlType += " unsigned" sqlType := string(field.DataType)
}
if field.AutoIncrement { if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), " auto_increment") {
sqlType += " AUTO_INCREMENT" sqlType += " AUTO_INCREMENT"
} }
@ -403,3 +462,31 @@ func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error { func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error
} }
// checkVersion newer or equal returns true, old returns false
func checkVersion(newVersion, oldVersion string) bool {
if newVersion == oldVersion {
return true
}
var (
versionTrimmerRegexp = regexp.MustCompile(`^(\d+).*$`)
newVersions = strings.Split(newVersion, ".")
oldVersions = strings.Split(oldVersion, ".")
)
for idx, nv := range newVersions {
if len(oldVersions) <= idx {
return true
}
nvi, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(nv, "$1"))
ovi, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(oldVersions[idx], "$1"))
if nvi == ovi {
continue
}
return nvi > ovi
}
return false
}

View file

@ -15,3 +15,16 @@ db, err := gorm.Open(sqlite.Open("gorm.db"), &gorm.Config{})
``` ```
Checkout [https://gorm.io](https://gorm.io) for details. Checkout [https://gorm.io](https://gorm.io) for details.
### Pure go Sqlite Driver
checkout [https://github.com/glebarez/sqlite](https://github.com/glebarez/sqlite) for details
```go
import (
"github.com/glebarez/sqlite"
"gorm.io/gorm"
)
db, err := gorm.Open(sqlite.Open("gorm.db"), &gorm.Config{})
```

View file

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"regexp" "regexp"
"strconv"
"strings" "strings"
"gorm.io/gorm/migrator" "gorm.io/gorm/migrator"
@ -12,12 +13,13 @@ import (
var ( var (
sqliteSeparator = "`|\"|'|\t" sqliteSeparator = "`|\"|'|\t"
indexRegexp = regexp.MustCompile(fmt.Sprintf("CREATE(?: UNIQUE)? INDEX [%v]?[\\w\\d]+[%v]? ON (.*)$", sqliteSeparator, sqliteSeparator)) indexRegexp = regexp.MustCompile(fmt.Sprintf("(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\\w\\d-]+[%v]? ON (.*)$", sqliteSeparator, sqliteSeparator))
tableRegexp = regexp.MustCompile(fmt.Sprintf("(?i)(CREATE TABLE [%v]?[\\w\\d]+[%v]?)(?: \\((.*)\\))?", sqliteSeparator, sqliteSeparator)) tableRegexp = regexp.MustCompile(fmt.Sprintf("(?is)(CREATE TABLE [%v]?[\\w\\d-]+[%v]?)(?: \\((.*)\\))?", sqliteSeparator, sqliteSeparator))
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator)) separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
columnsRegexp = regexp.MustCompile(fmt.Sprintf("\\([%v]?([\\w\\d]+)[%v]?(?:,[%v]?([\\w\\d]+)[%v]){0,}\\)", sqliteSeparator, sqliteSeparator, sqliteSeparator, sqliteSeparator)) columnsRegexp = regexp.MustCompile(fmt.Sprintf("\\([%v]?([\\w\\d]+)[%v]?(?:,[%v]?([\\w\\d]+)[%v]){0,}\\)", sqliteSeparator, sqliteSeparator, sqliteSeparator, sqliteSeparator))
columnRegexp = regexp.MustCompile(fmt.Sprintf("^[%v]?([\\w\\d]+)[%v]?\\s+([\\w\\(\\)\\d]+)(.*)$", sqliteSeparator, sqliteSeparator)) columnRegexp = regexp.MustCompile(fmt.Sprintf("^[%v]?([\\w\\d]+)[%v]?\\s+([\\w\\(\\)\\d]+)(.*)$", sqliteSeparator, sqliteSeparator))
defaultValueRegexp = regexp.MustCompile("(?i) DEFAULT \\(?(.+)?\\)?( |COLLATE|GENERATED|$)") defaultValueRegexp = regexp.MustCompile("(?i) DEFAULT \\(?(.+)?\\)?( |COLLATE|GENERATED|$)")
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
) )
type ddl struct { type ddl struct {
@ -37,16 +39,17 @@ func parseDDL(strs ...string) (*ddl, error) {
quote rune quote rune
buf string buf string
) )
ddlBodyRunesLen := len(ddlBodyRunes)
result.head = sections[1] result.head = sections[1]
for idx := 0; idx < len(ddlBodyRunes); idx++ { for idx := 0; idx < ddlBodyRunesLen; idx++ {
var ( var (
next rune = 0 next rune = 0
c = ddlBodyRunes[idx] c = ddlBodyRunes[idx]
) )
if idx+1 < len(ddlBody) { if idx+1 < ddlBodyRunesLen {
next = []rune(ddlBody)[idx+1] next = ddlBodyRunes[idx+1]
} }
if sc := string(c); separatorRegexp.MatchString(sc) { if sc := string(c); separatorRegexp.MatchString(sc) {
@ -115,7 +118,7 @@ func parseDDL(strs ...string) (*ddl, error) {
PrimaryKeyValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Valid: true}, NullableValue: sql.NullBool{Valid: true},
DefaultValueValue: sql.NullString{Valid: true}, DefaultValueValue: sql.NullString{Valid: false},
} }
matchUpper := strings.ToUpper(matches[3]) matchUpper := strings.ToUpper(matches[3])
@ -134,6 +137,14 @@ func parseDDL(strs ...string) (*ddl, error) {
columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true} columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true}
} }
// data type length
matches := regRealDataType.FindAllStringSubmatch(columnType.DataTypeValue.String, -1)
if len(matches) == 1 && len(matches[0]) == 2 {
size, _ := strconv.Atoi(matches[0][1])
columnType.LengthValue = sql.NullInt64{Valid: true, Int64: int64(size)}
columnType.DataTypeValue.String = strings.TrimSuffix(columnType.DataTypeValue.String, matches[0][0])
}
result.columns = append(result.columns, columnType) result.columns = append(result.columns, columnType)
} }
} }
@ -205,7 +216,8 @@ func (d *ddl) getColumns() []string {
fUpper := strings.ToUpper(f) fUpper := strings.ToUpper(f)
if strings.HasPrefix(fUpper, "PRIMARY KEY") || if strings.HasPrefix(fUpper, "PRIMARY KEY") ||
strings.HasPrefix(fUpper, "CHECK") || strings.HasPrefix(fUpper, "CHECK") ||
strings.HasPrefix(fUpper, "CONSTRAINT") { strings.HasPrefix(fUpper, "CONSTRAINT") ||
strings.Contains(fUpper, "GENERATED ALWAYS AS") {
continue continue
} }

View file

@ -80,12 +80,17 @@ func (m Migrator) AlterColumn(value interface{}, name string) error {
return m.RunWithoutForeignKey(func() error { return m.RunWithoutForeignKey(func() error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
if field := stmt.Schema.LookUpField(name); field != nil { if field := stmt.Schema.LookUpField(name); field != nil {
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?,") // lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?,", field.DBName)) createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
if createSQL == rawDDL {
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
}
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
} }

View file

@ -97,12 +97,13 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
}, },
"LIMIT": func(c clause.Clause, builder clause.Builder) { "LIMIT": func(c clause.Clause, builder clause.Builder) {
if limit, ok := c.Expression.(clause.Limit); ok { if limit, ok := c.Expression.(clause.Limit); ok {
if limit.Limit > 0 || limit.Offset > 0 { var lmt = -1
if limit.Limit <= 0 { if limit.Limit != nil && *limit.Limit >= 0 {
limit.Limit = -1 lmt = *limit.Limit
} }
if lmt >= 0 || limit.Offset > 0 {
builder.WriteString("LIMIT ") builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(limit.Limit)) builder.WriteString(strconv.Itoa(lmt))
} }
if limit.Offset > 0 { if limit.Offset > 0 {
builder.WriteString(" OFFSET ") builder.WriteString(" OFFSET ")

3
vendor/gorm.io/gorm/.gitignore generated vendored
View file

@ -3,4 +3,5 @@ documents
coverage.txt coverage.txt
_book _book
.idea .idea
vendor vendor
.vscode

9
vendor/gorm.io/gorm/.golangci.yml generated vendored
View file

@ -9,3 +9,12 @@ linters:
- prealloc - prealloc
- unconvert - unconvert
- unparam - unparam
- goimports
- whitespace
linters-settings:
whitespace:
multi-func: true
goimports:
local-prefixes: gorm.io/gorm

8
vendor/gorm.io/gorm/README.md generated vendored
View file

@ -30,12 +30,18 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Getting Started ## Getting Started
* GORM Guides [https://gorm.io](https://gorm.io) * GORM Guides [https://gorm.io](https://gorm.io)
* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen) * Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
## Contributing ## Contributing
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html) [You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
## Contributors
Thank you for contributing to the GORM framework!
[![Contributors](https://contrib.rocks/image?repo=go-gorm/gorm)](https://github.com/go-gorm/gorm/graphs/contributors)
## License ## License
© Jinzhu, 2013~time.Now © Jinzhu, 2013~time.Now

4
vendor/gorm.io/gorm/association.go generated vendored
View file

@ -507,7 +507,9 @@ func (association *Association) buildCondition() *DB {
joinStmt.AddClause(queryClause) joinStmt.AddClause(queryClause)
} }
joinStmt.Build("WHERE") joinStmt.Build("WHERE")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) if len(joinStmt.SQL.String()) > 0 {
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
} }
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{

8
vendor/gorm.io/gorm/callbacks.go generated vendored
View file

@ -246,7 +246,13 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
sortCallback func(*callback) error sortCallback func(*callback) error
) )
sort.Slice(cs, func(i, j int) bool { sort.Slice(cs, func(i, j int) bool {
return cs[j].before == "*" || cs[j].after == "*" if cs[j].before == "*" && cs[i].before != "*" {
return true
}
if cs[j].after == "*" && cs[i].after != "*" {
return true
}
return false
}) })
for _, c := range cs { for _, c := range cs {

View file

@ -206,9 +206,12 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
} }
} }
cacheKey := utils.ToStringKey(relPrimaryValues) cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
identityMap[cacheKey] = true if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
if isPtr { if isPtr {
elems = reflect.Append(elems, elem) elems = reflect.Append(elems, elem)
} else { } else {
@ -253,6 +256,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
fieldType = reflect.PtrTo(fieldType) fieldType = reflect.PtrTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{} objs := []reflect.Value{}
@ -272,19 +276,34 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
joins = reflect.Append(joins, joinValue) joins = reflect.Append(joins, joinValue)
} }
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) { appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ { for i := 0; i < f.Len(); i++ {
elem := f.Index(i) elem := f.Index(i)
if !isPtr {
objs = append(objs, v) elem = elem.Addr()
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
} }
objs = append(objs, v)
elems = reflect.Append(elems, elem)
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, elem)
}
} }
} }
} }
@ -304,7 +323,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
// optimize elems of reflect value length // optimize elems of reflect value length
if elemLen := elems.Len(); elemLen > 0 { if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v { if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, elems, selectColumns, restricted, nil) saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
} }
for i := 0; i < elemLen; i++ { for i := 0; i < elemLen; i++ {

View file

@ -125,7 +125,7 @@ func checkMissingWhereConditions(db *gorm.DB) {
type visitMap = map[reflect.Value]bool type visitMap = map[reflect.Value]bool
// Check if circular values, return true if loaded // Check if circular values, return true if loaded
func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) { func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
v = v.Elem() v = v.Elem()
} }
@ -134,17 +134,17 @@ func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
loaded = true loaded = true
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
if !loadOrStoreVisitMap(vistMap, v.Index(i)) { if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
loaded = false loaded = false
} }
} }
case reflect.Struct, reflect.Interface: case reflect.Struct, reflect.Interface:
if v.CanAddr() { if v.CanAddr() {
p := v.Addr() p := v.Addr()
if _, ok := (*vistMap)[p]; ok { if _, ok := (*visitMap)[p]; ok {
return true return true
} }
(*vistMap)[p] = true (*visitMap)[p] = true
} }
} }

Some files were not shown because too many files have changed in this diff Show more