Skip to content

Commit

Permalink
feat(catalog): Add pagination for list table operation across differe…
Browse files Browse the repository at this point in the history
…nt catalog types (#306)

## **Goal**
Support pagination for the `ListTables` method, similar to the `ListViews`
operation in PR #290. This requires changing the method signature in
`catalog.go`, with cascading changes in `glue.go` and `sql.go`. I'm not sure
how pagination should work in `sql.go`, so I wrote a wrapper that converts the
existing function to the `iter.Seq2` type.

## **TODO**:
- [ ] Get alignment on the `ListTables` wrapper in `sql.go`
- [ ] Add more tests for `sql.go` and `glue.go` once the first point is
aligned

---------

Signed-off-by: dttung2905 <[email protected]>
  • Loading branch information
dttung2905 authored Feb 20, 2025
1 parent 281e62d commit 6c4e87b
Show file tree
Hide file tree
Showing 8 changed files with 508 additions and 72 deletions.
3 changes: 2 additions & 1 deletion catalog/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"context"
"errors"
"fmt"
"iter"
"maps"
"strings"

Expand Down Expand Up @@ -79,7 +80,7 @@ type Catalog interface {
CommitTable(context.Context, *table.Table, []table.Requirement, []table.Update) (table.Metadata, string, error)
// ListTables returns an iterator over the table identifiers in the catalog, with the yielded
// identifiers containing the information required to load the table via that catalog.
ListTables(ctx context.Context, namespace table.Identifier) ([]table.Identifier, error)
ListTables(ctx context.Context, namespace table.Identifier) iter.Seq2[table.Identifier, error]
// LoadTable loads a table from the catalog and returns a Table with the metadata.
LoadTable(ctx context.Context, identifier table.Identifier, props iceberg.Properties) (*table.Table, error)
// DropTable tells the catalog to drop the table entirely.
Expand Down
46 changes: 23 additions & 23 deletions catalog/glue/glue.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"context"
"errors"
"fmt"
"iter"
"strconv"
_ "unsafe"

Expand Down Expand Up @@ -155,33 +156,34 @@ func NewCatalog(opts ...Option) *Catalog {
// ListTables returns an iterator over the Iceberg tables in the given Glue database.
//
// The namespace should just contain the Glue database name.
func (c *Catalog) ListTables(ctx context.Context, namespace table.Identifier) ([]table.Identifier, error) {
database, err := identifierToGlueDatabase(namespace)
if err != nil {
return nil, err
}

params := &glue.GetTablesInput{CatalogId: c.catalogId, DatabaseName: aws.String(database)}

var icebergTables []table.Identifier

for {
tblsRes, err := c.glueSvc.GetTables(ctx, params)
func (c *Catalog) ListTables(ctx context.Context, namespace table.Identifier) iter.Seq2[table.Identifier, error] {
return func(yield func(table.Identifier, error) bool) {
database, err := identifierToGlueDatabase(namespace)
if err != nil {
return nil, fmt.Errorf("failed to list tables in namespace %s: %w", database, err)
yield(table.Identifier{}, err)
return
}

icebergTables = append(icebergTables,
filterTableListByType(database, tblsRes.TableList, glueTypeIceberg)...)
paginator := glue.NewGetTablesPaginator(c.glueSvc, &glue.GetTablesInput{
CatalogId: c.catalogId,
DatabaseName: aws.String(database),
})

for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
yield(table.Identifier{}, fmt.Errorf("failed to list tables in namespace %s: %w", database, err))
return
}

if tblsRes.NextToken == nil {
break
icebergTables := filterTableListByType(database, page.TableList, glueTypeIceberg)
for _, tbl := range icebergTables {
if !yield(tbl, nil) {
return
}
}
}

params.NextToken = tblsRes.NextToken
}

return icebergTables, nil
}

// LoadTable loads a table from the catalog table details.
Expand Down Expand Up @@ -543,14 +545,12 @@ func DatabaseIdentifier(database string) table.Identifier {

// filterTableListByType picks out the Glue tables whose table-type parameter
// (the value stored under tableTypePropsKey) equals tableType and returns an
// identifier for each of them, built from database and the table name.
//
// The input order is preserved. The result is nil when nothing matches.
func filterTableListByType(database string, tableList []types.Table, tableType string) []table.Identifier {
	var matches []table.Identifier
	for i := range tableList {
		entry := &tableList[i]
		// A table without the parameter reads as "" from the map and is
		// therefore only kept when tableType itself is "".
		if entry.Parameters[tableTypePropsKey] == tableType {
			matches = append(matches, TableIdentifier(database, aws.ToString(entry.Name)))
		}
	}
	return matches
}

Expand Down
179 changes: 168 additions & 11 deletions catalog/glue/glue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ package glue
import (
"context"
"errors"
"fmt"
"os"
"testing"

"github.com/apache/iceberg-go/catalog"
"github.com/apache/iceberg-go/table"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/glue"
Expand Down Expand Up @@ -81,13 +83,42 @@ func (m *mockGlueClient) UpdateDatabase(ctx context.Context, params *glue.Update
return args.Get(0).(*glue.UpdateDatabaseOutput), args.Error(1)
}

var testIcebergGlueTable = types.Table{
var testIcebergGlueTable1 = types.Table{
Name: aws.String("test_table"),
Parameters: map[string]string{
tableTypePropsKey: "ICEBERG",
metadataLocationPropsKey: "s3://test-bucket/test_table/metadata/abc123-123.metadata.json",
},
}
// testIcebergGlueTable2 through testIcebergGlueTable5 are additional Glue
// table fixtures tagged with the "ICEBERG" table type. The pagination test
// spreads them across several mocked GetTables pages. Each has a distinct
// name and metadata file.

// testIcebergGlueTable2 is returned on the first mocked page of the pagination test.
var testIcebergGlueTable2 = types.Table{
Name: aws.String("test_table2"),
Parameters: map[string]string{
tableTypePropsKey: "ICEBERG",
metadataLocationPropsKey: "s3://test-bucket/test_table/metadata/abc456-456.metadata.json",
},
}

// testIcebergGlueTable3 is returned on the second mocked page of the pagination test.
var testIcebergGlueTable3 = types.Table{
Name: aws.String("test_table3"),
Parameters: map[string]string{
tableTypePropsKey: "ICEBERG",
metadataLocationPropsKey: "s3://test-bucket/test_table/metadata/abc789-789.metadata.json",
},
}
// testIcebergGlueTable4 is returned on the second mocked page of the pagination test.
var testIcebergGlueTable4 = types.Table{
Name: aws.String("test_table4"),
Parameters: map[string]string{
tableTypePropsKey: "ICEBERG",
metadataLocationPropsKey: "s3://test-bucket/test_table/metadata/abc123-789.metadata.json",
},
}
// testIcebergGlueTable5 is returned on the third (last) mocked page of the pagination test.
var testIcebergGlueTable5 = types.Table{
Name: aws.String("test_table5"),
Parameters: map[string]string{
tableTypePropsKey: "ICEBERG",
metadataLocationPropsKey: "s3://test-bucket/test_table/metadata/abc12345-789.metadata.json",
},
}

var testNonIcebergGlueTable = types.Table{
Name: aws.String("other_table"),
Expand All @@ -104,7 +135,7 @@ func TestGlueGetTable(t *testing.T) {
mockGlueSvc.On("GetTable", mock.Anything, &glue.GetTableInput{
DatabaseName: aws.String("test_database"),
Name: aws.String("test_table"),
}, mock.Anything).Return(&glue.GetTableOutput{Table: &testIcebergGlueTable}, nil)
}, mock.Anything).Return(&glue.GetTableOutput{Table: &testIcebergGlueTable1}, nil)

glueCatalog := &Catalog{
glueSvc: mockGlueSvc,
Expand All @@ -123,19 +154,136 @@ func TestGlueListTables(t *testing.T) {
mockGlueSvc.On("GetTables", mock.Anything, &glue.GetTablesInput{
DatabaseName: aws.String("test_database"),
}, mock.Anything).Return(&glue.GetTablesOutput{
TableList: []types.Table{testIcebergGlueTable, testNonIcebergGlueTable},
TableList: []types.Table{testIcebergGlueTable1, testNonIcebergGlueTable},
}, nil).Once()

glueCatalog := &Catalog{
glueSvc: mockGlueSvc,
}

tables, err := glueCatalog.ListTables(context.TODO(), DatabaseIdentifier("test_database"))
assert.NoError(err)
assert.Len(tables, 1)
assert.Equal([]string{"test_database", "test_table"}, tables[0])
var lastErr error
tbls := make([]table.Identifier, 0)
iter := glueCatalog.ListTables(context.TODO(), DatabaseIdentifier("test_database"))

for tbl, err := range iter {
tbls = append(tbls, tbl)
if err != nil {
lastErr = err
}
}
assert.NoError(lastErr)
assert.Len(tbls, 1)
assert.Equal([]string{"test_database", "test_table"}, tbls[0])
}

// TestGlueListTablesPagination verifies that ListTables walks every page of
// a paginated GetTables response. The mock serves three pages chained by
// NextToken ("token1", then "token2", then none). The test expects the five
// Iceberg tables in page order and expects the non-Iceberg table on the last
// page to be filtered out.
func TestGlueListTablesPagination(t *testing.T) {
	assert := require.New(t)

	mockGlueSvc := &mockGlueClient{}

	// First page: requested without a token, answered with "token1".
	mockGlueSvc.On("GetTables", mock.Anything, &glue.GetTablesInput{
		DatabaseName: aws.String("test_database"),
	}, mock.Anything).Return(&glue.GetTablesOutput{
		TableList: []types.Table{
			testIcebergGlueTable1,
			testIcebergGlueTable2,
		},
		NextToken: aws.String("token1"),
	}, nil).Once()

	// Second page: requested with "token1", answered with "token2".
	mockGlueSvc.On("GetTables", mock.Anything, &glue.GetTablesInput{
		DatabaseName: aws.String("test_database"),
		NextToken:    aws.String("token1"),
	}, mock.Anything).Return(&glue.GetTablesOutput{
		TableList: []types.Table{
			testIcebergGlueTable3,
			testIcebergGlueTable4,
		},
		NextToken: aws.String("token2"),
	}, nil).Once()

	// Third page: requested with "token2"; no NextToken, so it is the last.
	mockGlueSvc.On("GetTables", mock.Anything, &glue.GetTablesInput{
		DatabaseName: aws.String("test_database"),
		NextToken:    aws.String("token2"),
	}, mock.Anything).Return(&glue.GetTablesOutput{
		TableList: []types.Table{
			testIcebergGlueTable5,
			testNonIcebergGlueTable,
		},
	}, nil).Once()

	glueCatalog := &Catalog{
		glueSvc: mockGlueSvc,
	}

	var lastErr error
	tbls := make([]table.Identifier, 0)
	iter := glueCatalog.ListTables(context.TODO(), DatabaseIdentifier("test_database"))

	for tbl, err := range iter {
		// Check the error before collecting: an error is yielded together
		// with an empty identifier, which must not be counted as a table.
		if err != nil {
			lastErr = err
			break
		}
		tbls = append(tbls, tbl)
	}

	assert.NoError(lastErr)
	assert.Len(tbls, 5) // Only Iceberg tables should be included
	assert.Equal([]string{"test_database", "test_table"}, tbls[0])
	assert.Equal([]string{"test_database", "test_table2"}, tbls[1])
	assert.Equal([]string{"test_database", "test_table3"}, tbls[2])
	assert.Equal([]string{"test_database", "test_table4"}, tbls[3])
	assert.Equal([]string{"test_database", "test_table5"}, tbls[4])

	// Every expected GetTables call (one per page) must have been made.
	mockGlueSvc.AssertExpectations(t)
}

// TestGlueListTablesError checks that a failure while fetching a later page
// is surfaced through the ListTables iterator, after the tables from the
// earlier, successful page have already been yielded.
func TestGlueListTablesError(t *testing.T) {
	assert := require.New(t)

	const dbName = "test_database"
	svc := &mockGlueClient{}

	// Page one is served normally and points at a follow-up page.
	firstPage := &glue.GetTablesOutput{
		TableList: []types.Table{
			testIcebergGlueTable1,
		},
		NextToken: aws.String("token1"),
	}
	svc.On("GetTables", mock.Anything, &glue.GetTablesInput{
		DatabaseName: aws.String(dbName),
	}, mock.Anything).Return(firstPage, nil).Once()

	// Requesting the follow-up page fails.
	svc.On("GetTables", mock.Anything, &glue.GetTablesInput{
		DatabaseName: aws.String(dbName),
		NextToken:    aws.String("token1"),
	}, mock.Anything).Return(&glue.GetTablesOutput{}, fmt.Errorf("token expired")).Once()

	cat := &Catalog{glueSvc: svc}

	var (
		listErr error
		found   = make([]table.Identifier, 0)
	)
	for ident, err := range cat.ListTables(context.TODO(), DatabaseIdentifier(dbName)) {
		// Stop at the first error; only identifiers yielded before it count.
		if err != nil {
			listErr = err
			break
		}
		found = append(found, ident)
	}

	assert.Error(listErr)
	assert.Contains(listErr.Error(), "token expired")
	assert.Len(found, 1)
	assert.Equal([]string{"test_database", "test_table"}, found[0])

	// Both mocked GetTables calls must have been consumed.
	svc.AssertExpectations(t)
}
func TestGlueListNamespaces(t *testing.T) {
assert := require.New(t)

Expand Down Expand Up @@ -175,7 +323,7 @@ func TestGlueDropTable(t *testing.T) {
DatabaseName: aws.String("test_database"),
Name: aws.String("test_table"),
}, mock.Anything).Return(&glue.GetTableOutput{
Table: &testIcebergGlueTable,
Table: &testIcebergGlueTable1,
}, nil).Once()

mockGlueSvc.On("DeleteTable", mock.Anything, &glue.DeleteTableInput{
Expand Down Expand Up @@ -534,9 +682,18 @@ func TestGlueListTablesIntegration(t *testing.T) {

catalog := NewCatalog(WithAwsConfig(awscfg))

tables, err := catalog.ListTables(context.TODO(), DatabaseIdentifier(os.Getenv("TEST_DATABASE_NAME")))
assert.NoError(err)
assert.Equal([]string{os.Getenv("TEST_DATABASE_NAME"), os.Getenv("TEST_TABLE_NAME")}, tables[1])
iter := catalog.ListTables(context.TODO(), DatabaseIdentifier(os.Getenv("TEST_DATABASE_NAME")))
var lastErr error
tbls := make([]table.Identifier, 0)
for tbl, err := range iter {
tbls = append(tbls, tbl)
if err != nil {
lastErr = err
}
}

assert.NoError(lastErr)
assert.Equal([]string{os.Getenv("TEST_DATABASE_NAME"), os.Getenv("TEST_TABLE_NAME")}, tbls[1])
}

func TestGlueLoadTableIntegration(t *testing.T) {
Expand Down
Loading

0 comments on commit 6c4e87b

Please sign in to comment.