diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8e30742 --- /dev/null +++ b/.gitignore @@ -0,0 +1,29 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# IDE +.idea/ +.idea + +# Mac +.DS_Store +*/.DS_Store +!sqldialects.xml + +# Dependencies +vendor + +# local files +/tmp + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out +out/* +.editorconfig diff --git a/go.mod b/go.mod index 27c0648..23de6fd 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/bitcomplete/sqltestutil +module github.com/mohammad-ahmadi-de/sqltestutil go 1.18 diff --git a/migration.go b/migration.go index 45a3a78..7ef088d 100644 --- a/migration.go +++ b/migration.go @@ -2,11 +2,10 @@ package sqltestutil import ( "context" + "database/sql/driver" "io/ioutil" "path/filepath" "sort" - - "github.com/jmoiron/sqlx" ) // RunMigrations reads all of the files matching *.up.sql in migrationDir and @@ -19,18 +18,30 @@ import ( // // Note that this function does not check whether the migration has already been // run. Its primary purpose is to initialize a test database. -func RunMigrations(ctx context.Context, db sqlx.ExecerContext, migrationDir string) error { +func RunMigrations(ctx context.Context, db driver.ExecerContext, migrationDir string, files ...string) error { filenames, err := filepath.Glob(filepath.Join(migrationDir, "*.up.sql")) if err != nil { return err } + var filter map[string]struct{} = nil + if len(files) > 0 { + filter = make(map[string]struct{}) + for i := range files { + filter[files[i]] = struct{}{} + } + } sort.Strings(filenames) for _, filename := range filenames { + if len(files) > 0 { + if _, exist := filter[filepath.Base(filename)]; !exist { + continue + } + } data, err := ioutil.ReadFile(filename) if err != nil { return err } - _, err = db.ExecContext(ctx, string(data)) + _, err = db.ExecContext(ctx, string(data), nil) if err != nil { return err } diff --git a/option.go b/option.go new file mode 100644 index 0000000..25c59c4 --- /dev/null +++ b/option.go @@ -0,0 +1,55 @@ +package sqltestutil + +import "fmt" + +type Option func(*PostgresContainer) + +func WithPassword(password string) Option { + return func(container *PostgresContainer) { + if len(password) == 0 { + panic("sqltestutil: password option can not be empty") + } + container.password = password + } +} +func WithUser(user string) Option { + return func(container *PostgresContainer) { + if len(user) == 0 { + panic("sqltestutil: user option can not be empty") + } + container.user = user + } +} +func WithPort(port uint16) Option { + return func(container *PostgresContainer) { + if port <= 1000 { + panic("sqltestutil: port option can not be less than 1000") + } + container.port = fmt.Sprint(port) + } +} + +func WithVersion(version string) Option { + return func(container *PostgresContainer) { + if len(version) == 0 { + panic("sqltestutil: version option can not be empty") + } + container.version = version + } +} +func WithDBName(dbName string) Option { + return func(container *PostgresContainer) { + if len(dbName) == 0 { + panic("sqltestutil: dbName option can not be empty") + } + container.dbName = dbName + } +} +func WithContainerName(containerName string) Option { + return func(container *PostgresContainer) { + if len(containerName) == 0 { + panic("sqltestutil: containerName option can not be empty") + } + container.containerName = containerName + } +} diff --git a/postgres_container.go b/postgres_container.go index 5f5473f..285409a 100644 --- a/postgres_container.go +++ b/postgres_container.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "errors" "fmt" + "github.com/docker/docker/api/types/filters" "io" "math/big" "net" @@ -19,9 +20,13 @@ import ( // PostgresContainer is a Docker container running Postgres. It can be used to // cheaply start a throwaway Postgres instance for testing. type PostgresContainer struct { - id string - password string - port string + id string + password string + user string + port string + dbName string + version string + containerName string } // StartPostgresContainer starts a new Postgres Docker container. The version @@ -58,20 +63,65 @@ type PostgresContainer struct { // } // // func TestExampleTestSuite(t *testing.T) { -// pg, _ := sqltestutil.StartPostgresContainer(context.Background(), "12") +// pg, _ := sqltestutil.StartPostgresContainer(context.Background(), WithVersion("12")) // defer pg.Shutdown(ctx) // suite.Run(t, &ExampleTestSuite{}) // } // // [1]: https://github.com/golang/go/issues/37206 // [2]: https://github.com/stretchr/testify -func StartPostgresContainer(ctx context.Context, version string) (*PostgresContainer, error) { +func StartPostgresContainer(ctx context.Context, options ...Option) (*PostgresContainer, error) { cli, err := client.NewClientWithOpts(client.FromEnv) if err != nil { panic(err) } defer cli.Close() - image := "postgres:" + version + + containerObj := &PostgresContainer{} + // + // apply options, if any. + // + for i := range options { + options[i](containerObj) + } + // + // set default values + // + if len(containerObj.password) == 0 { + password, err := randomPassword() + if err != nil { + return nil, err + } + containerObj.password = password + } + if len(containerObj.port) == 0 { + port, err := randomPort() + if err != nil { + return nil, err + } + containerObj.port = port + } + if len(containerObj.user) == 0 { + containerObj.user = "pgtest" + } + if len(containerObj.dbName) == 0 { + containerObj.dbName = "pgtest" + } + if len(containerObj.version) == 0 { + containerObj.version = "12" + } + if len(containerObj.containerName) == 0 { + containerObj.containerName = "sqltestutil" + } + // + // remove leaked containers + // + err = containerObj.fixContainerLeak(ctx) + if err != nil { + return nil, err + } + + image := "postgres:" + containerObj.version _, _, err = cli.ImageInspectWithRaw(ctx, image) if err != nil { _, notFound := err.(interface { @@ -91,20 +141,12 @@ func StartPostgresContainer(ctx context.Context, version string) (*PostgresConta } } - password, err := randomPassword() - if err != nil { - return nil, err - } - port, err := randomPort() - if err != nil { - return nil, err - } createResp, err := cli.ContainerCreate(ctx, &container.Config{ Image: image, Env: []string{ - "POSTGRES_DB=pgtest", - "POSTGRES_PASSWORD=" + password, - "POSTGRES_USER=pgtest", + "POSTGRES_DB=" + containerObj.dbName, + "POSTGRES_PASSWORD=" + containerObj.password, + "POSTGRES_USER=" + containerObj.user, }, Healthcheck: &container.HealthConfig{ Test: []string{"CMD-SHELL", "pg_isready -U pgtest"}, @@ -115,10 +157,10 @@ func StartPostgresContainer(ctx context.Context, version string) (*PostgresConta }, &container.HostConfig{ PortBindings: nat.PortMap{ "5432/tcp": []nat.PortBinding{ - {HostPort: port}, + {HostPort: containerObj.port}, }, }, - }, nil, nil, "") + }, nil, nil, containerObj.containerName) if err != nil { return nil, err } @@ -160,17 +202,38 @@ HealthCheck: time.Sleep(500 * time.Millisecond) } } - return &PostgresContainer{ - id: createResp.ID, - password: password, - port: port, - }, nil + containerObj.id = createResp.ID + + return containerObj, nil +} +func (c *PostgresContainer) fixContainerLeak(ctx context.Context) error { + cli, err := client.NewClientWithOpts(client.FromEnv) + if err != nil { + return err + } + defer cli.Close() + + data, err := cli.ContainerList(ctx, types.ContainerListOptions{All: true, Filters: filters.NewArgs(filters.Arg("name", c.containerName))}) + if err != nil { + return err + } + for i := range data { + err = cli.ContainerStop(ctx, data[i].ID, nil) + if err != nil { + return err + } + err = cli.ContainerRemove(ctx, data[i].ID, types.ContainerRemoveOptions{}) + if err != nil { + return err + } + } + return nil } // ConnectionString returns a connection URL string that can be used to connect // to the running Postgres container. func (c *PostgresContainer) ConnectionString() string { - return fmt.Sprintf("postgres://pgtest:%s@127.0.0.1:%s/pgtest", c.password, c.port) + return fmt.Sprintf("postgres://%s:%s@127.0.0.1:%s/%s", c.user, c.password, c.port, c.dbName) } // Shutdown cleans up the Postgres container by stopping and removing it. This diff --git a/postgres_container_test.go b/postgres_container_test.go new file mode 100644 index 0000000..44eeba4 --- /dev/null +++ b/postgres_container_test.go @@ -0,0 +1,15 @@ +package sqltestutil_test + +import ( + "context" + "github.com/mohammad-ahmadi-de/sqltestutil" + "testing" +) + +func TestStartPostgresContainer(t *testing.T) { + c, err := sqltestutil.StartPostgresContainer(context.Background(), sqltestutil.WithPort(5321)) + if err != nil { + t.Fatal(err) + } + c.Shutdown(context.Background()) +}