diff --git a/typesense/collections.go b/typesense/collections.go index 279114e..27a34b6 100644 --- a/typesense/collections.go +++ b/typesense/collections.go @@ -9,6 +9,7 @@ import ( // CollectionsInterface is a type for Collections API operations type CollectionsInterface interface { Create(schema *api.CollectionSchema) (*api.CollectionResponse, error) + CreateCollectionFromStruct(structData interface{}) (*api.CollectionResponse, error) Retrieve() ([]*api.CollectionResponse, error) } @@ -39,3 +40,23 @@ func (c *collections) Retrieve() ([]*api.CollectionResponse, error) { } return *response.JSON200, nil } + +// CreateCollectionFromStruct creates a Typesense collection from a Go struct. +func (c *collections) CreateCollectionFromStruct(structData interface{}) (*api.CollectionResponse, error) { + // Generate Typesense schema from the Go struct + schema, err := CreateSchemaFromGoStruct(structData) + if err != nil { + return nil, err + } + + // Use the generated schema to create a collection in Typesense + response, err := c.apiClient.CreateCollectionWithResponse(context.Background(), + api.CreateCollectionJSONRequestBody(*schema)) + if err != nil { + return nil, err + } + if response.JSON201 == nil { + return nil, &HTTPError{Status: response.StatusCode(), Body: response.Body} + } + return response.JSON201, nil +} diff --git a/typesense/collections_test.go b/typesense/collections_test.go index 36a360a..7a15fe8 100644 --- a/typesense/collections_test.go +++ b/typesense/collections_test.go @@ -50,6 +50,45 @@ func createNewCollection(name string) *api.CollectionResponse { } } +type MockStruct struct { + Field1 string `typesense:"string"` +} + +func (m MockStruct) CollectionName() string { + return "custom_collection_name" +} + +func TestCreateSchemaFromGoStruct(t *testing.T) { + mockStruct := MockStruct{Field1: "Test"} + + schema, err := CreateSchemaFromGoStruct(mockStruct) + assert.NoError(t, err) + assert.NotNil(t, schema) + assert.Equal(t, "custom_collection_name", schema.Name) +} + +func TestCreateCollectionFromStruct(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAPIClient := mocks.NewMockAPIClientInterface(ctrl) + client := NewClient(WithAPIClient(mockAPIClient)) + mockStruct := MockStruct{Field1: "Test"} + + schema, _ := CreateSchemaFromGoStruct(mockStruct) + expectedResult := &api.CollectionResponse{Name: "custom_collection_name", NumDocuments: pointer.Int64(0)} + + mockAPIClient.EXPECT(). + CreateCollectionWithResponse(gomock.Not(gomock.Nil()), api.CreateCollectionJSONRequestBody(*schema)). + Return(&api.CreateCollectionResponse{JSON201: expectedResult}, nil). + Times(1) + + result, err := client.Collections().CreateCollectionFromStruct(mockStruct) + + assert.NoError(t, err) + assert.Equal(t, expectedResult, result) +} + func TestCollectionCreate(t *testing.T) { newSchema := createNewSchema("companies") expectedResult := createNewCollection("companies") diff --git a/typesense/struct_parser.go b/typesense/struct_parser.go new file mode 100644 index 0000000..1c973a8 --- /dev/null +++ b/typesense/struct_parser.go @@ -0,0 +1,75 @@ +package typesense + +import ( + "errors" + "reflect" + "strings" + + "github.com/typesense/typesense-go/typesense/api" +) + +// CollectionNamer is an interface that provides a method to get the collection name. +type CollectionNamer interface { + CollectionName() string +} + +// CreateSchemaFromGoStruct takes a Go struct and generates a Typesense CollectionSchema. +// If the struct implements the CollectionNamer interface, its CollectionName method is used to get the collection name. +func CreateSchemaFromGoStruct(structData interface{}) (*api.CollectionSchema, error) { + t := reflect.TypeOf(structData) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + var collectionName string + if namer, ok := structData.(CollectionNamer); ok { + collectionName = namer.CollectionName() + } else { + collectionName = t.Name() + } + + fields := make([]api.Field, 0) + var defaultSortingField *string + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + tagValue, ok := field.Tag.Lookup("typesense") + if !ok || tagValue == "-" { + continue + } + + fieldType := field.Type.String() + if fieldType == "uuid.UUID" { + fieldType = "string" + } + + tagParts := strings.Split(tagValue, ",") + facetValue := false // Default facet value + typesenseField := api.Field{ + Name: tagParts[0], + Type: fieldType, + Facet: &facetValue, // Initially false + } + + for _, tagPart := range tagParts { + tagPartTrimmed := strings.TrimSpace(tagPart) + if tagPartTrimmed == "defaultSort" { + if defaultSortingField != nil { + return nil, errors.New("multiple fields marked with 'defaultSort' tag") + } + defaultSortingField = &field.Name + } else if tagPartTrimmed == "facet" { + facetValue = true + typesenseField.Facet = &facetValue + } + } + + fields = append(fields, typesenseField) + } + + return &api.CollectionSchema{ + Name: collectionName, + Fields: fields, + DefaultSortingField: defaultSortingField, + }, nil +}