2 changes: 1 addition & 1 deletion .github/workflows/pr.yml
@@ -34,7 +34,7 @@ jobs:
cache: true

- name: Install golangci-lint
run: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
run: go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.62.2

- name: Run golangci-lint
run: make lint
15 changes: 2 additions & 13 deletions .golangci.yml
@@ -351,23 +351,12 @@ issues:

exclude-dirs:
- .jetgen/
- terraform/
exclude-files:
- "_test.go"

exclude-rules:
- source: "(noinspection|TODO)"
linters: [ godot ]
- source: "//noinspection"
linters: [ gocritic ]
- path: "usr/local/Cellar/go/"
linters:
- typecheck
- path: "go/pkg/mod/github.com/go-jet/jet"
linters:
- typecheck
- path: "go/pkg/mod/github.com/stretchr/testify"
linters:
- typecheck
- path: "go/pkg/mod/github.com/lib/pq"
linters:
- typecheck
linters: [ gocritic ]
2 changes: 1 addition & 1 deletion Makefile
@@ -11,7 +11,7 @@ test:
go test ./... -v

lint:
golangci-lint run ./...
golangci-lint run --config .golangci.yml --verbose

# ---------------- Golang Utils End -----------------------------------------

53 changes: 37 additions & 16 deletions cmd/lambda/main.go
@@ -7,6 +7,7 @@
package main

import (
"database/sql"
"log/slog"
"os"

@@ -19,44 +20,64 @@ import (
_ "github.com/lib/pq"
)

func init() {
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo}))
func initDB() (*sql.DB, error) {
db, err := store.NewDB()
if err != nil {
return nil, err
}
return db, nil
}

func initLogger() *slog.Logger {
var logger *slog.Logger
if os.Getenv("DEBUG") == "true" {
logger = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
} else {
logger = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo}))
}

slog.SetDefault(logger)
return logger
}

func main() {
logger := initLogger()
err := godotenv.Load()
if err != nil {
logger.Info("no .env file")
}
}

func main() {
slog.Info("connecting to database")
db, err := store.NewDB()
logger.Info("connecting to database")
db, err := initDB()
if err != nil {
slog.Error("failed to connect to database", "error", err)
logger.Error("failed to connect to database", "error", err)
os.Exit(1)
}
defer db.Close()
slog.Info("connected to database")
logger.Info("connected to database")

cs := store.NewCategoryStore(db)

if os.Getenv("ENV") == "dev" {
const (
trainRatio = 60
validateRatio = 20
testRatio = 20
)
tCfg := transform.Config{
Version: "v1",
Shuffle: true,
TrainRatio: 60,
ValidateRatio: 20,
TestRatio: 20,
TrainRatio: trainRatio,
ValidateRatio: validateRatio,
TestRatio: testRatio,
}
t := transform.NewTransform(cs, tCfg)
t := transform.NewTransform(logger, cs, tCfg)
if err = t.GenerateDataset(); err != nil {
slog.Error("failed to generate dataset", "error", err)
logger.Error("failed to generate dataset", "error", err)
os.Exit(1)
}
} else {
lh := lambdahandler.NewHandler(cs)
lh := lambdahandler.NewHandler(logger, cs)
lambda.Start(lh.HandleSQSEvent)
}

}
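
For context, here is a standalone sketch of the logger bootstrap this change introduces. The helper name `newLogger` is illustrative (the PR calls it `initLogger`); the idea is human-readable text output at debug level when `DEBUG=true`, structured JSON at info level otherwise, with `slog.SetDefault` keeping any package-level `slog` calls consistent with the injected logger:

```go
package main

import (
	"log/slog"
	"os"
)

// newLogger mirrors the bootstrap above: text output for local debugging,
// JSON output (easier to query in CloudWatch) everywhere else.
func newLogger() *slog.Logger {
	if os.Getenv("DEBUG") == "true" {
		return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
	}
	return slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo}))
}

func main() {
	logger := newLogger()
	// Keep package-level slog.Info/Error calls consistent with the injected logger.
	slog.SetDefault(logger)
	logger.Info("service starting", "env", os.Getenv("ENV"))
}
```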
22 changes: 12 additions & 10 deletions internal/lambdahandler/lambdahandler.go
@@ -1,7 +1,9 @@
// Package lambdahandler provides handlers for AWS Lambda functions, specifically designed to process SQS events.
// It integrates with the `store` and `transform` packages to generate datasets based on configurations received through SQS messages.
// The package handles unmarshalling of SQS messages into configuration objects, triggers dataset generation, and manages errors during these processes.
// It's designed to be used in a serverless architecture where an AWS Lambda function is triggered by SQS events to perform data transformation tasks.
// It integrates with the `store` and `transform` packages to generate datasets
// based on configurations received through SQS messages. The package handles unmarshalling of SQS messages
// into configuration objects, triggers dataset generation, and manages errors during these processes.
// It's designed to be used in a serverless architecture where an AWS Lambda function is triggered by
// SQS events to perform data transformation tasks.
package lambdahandler

import (
@@ -29,9 +31,9 @@ type Handler struct {
// NewHandler creates a new instance of Handler.
// It takes a CategoryStore instance as a dependency and initializes the logger with a component tag.
// Returns a pointer to the created Handler.
func NewHandler(cs *store.CategoryStore) *Handler {
func NewHandler(l *slog.Logger, cs *store.CategoryStore) *Handler {
return &Handler{
log: slog.With("component", "lambda"),
log: l.With("component", "lambda"),
cs: cs,
}
}
@@ -43,19 +45,19 @@ func NewHandler(cs *store.CategoryStore) *Handler {
// Returns an error if unmarshalling the configuration or generating the dataset fails.
func (h *Handler) HandleSQSEvent(ctx context.Context, sqsEvent events.SQSEvent) error {
for _, record := range sqsEvent.Records {
h.log.Info("processing message", "message_id", record.MessageId)
h.log.InfoContext(ctx, "processing message", "message_id", record.MessageId)
var cfg transform.Config
if err := json.Unmarshal([]byte(record.Body), &cfg); err != nil {
h.log.Error("failed to unmarshal config", "error", err)
h.log.ErrorContext(ctx, "failed to unmarshal config", "error", err)
return ErrUnmarshalConfig
}

t := transform.NewTransform(h.cs, cfg)
t := transform.NewTransform(h.log, h.cs, cfg)
if err := t.GenerateDataset(); err != nil {
h.log.Error("failed to generate dataset", "error", err)
h.log.ErrorContext(ctx, "failed to generate dataset", "error", err)
return err
}
h.log.Info("dataset generated successfully", "version", cfg.Version)
h.log.InfoContext(ctx, "dataset generated successfully", "version", cfg.Version)
}
return nil
}
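
Since the handler now drives dataset generation from SQS, a sketch of the message a producer would enqueue may help. It is not taken from the repository; only the `version` and `shuffle` JSON keys are visible in this diff, so the split-ratio keys are omitted rather than guessed:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/aws/aws-lambda-go/events"
)

func main() {
	// Each SQS record body is expected to be a JSON-encoded transform.Config.
	body, err := json.Marshal(map[string]any{
		"version": "v1",
		"shuffle": true,
	})
	if err != nil {
		panic(err)
	}

	// Shape of the event HandleSQSEvent receives from Lambda.
	event := events.SQSEvent{
		Records: []events.SQSMessage{
			{MessageId: "local-test-1", Body: string(body)},
		},
	}
	fmt.Printf("%+v\n", event)
}
```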
7 changes: 4 additions & 3 deletions internal/store/category.go
@@ -1,15 +1,16 @@
// Package store provides data access objects and methods for interacting with the database.
// It includes functionalities for fetching and manipulating category data used in machine learning dataset generation.
//
//nolint:typecheck
package store

import (
"database/sql"

//nolint:revive,stylecheck // simulate SQL
. "github.com/go-jet/jet/v2/postgres"

"github.com/opplieam/bb-transform/.jetgen/postgres/public/model"
//nolint:revive,stylecheck // simulate SQL
. "github.com/opplieam/bb-transform/.jetgen/postgres/public/table"

"github.com/opplieam/bb-transform/internal/transform"
)

3 changes: 2 additions & 1 deletion internal/store/db.go
@@ -13,11 +13,12 @@ const (
DBMaxIdleConn int = 25
DBMaxIdleTime string = "15m"
DBDSN string = "BUYBETTER_DEV_SUPABASE_DSN"
DBCtxTimeout = 5 * time.Second
)

// NewDB creates a new database connection and configures it with given parameters.
func NewDB() (*sql.DB, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), DBCtxTimeout)
defer cancel()

// Use IPV4 for AWS lambda
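
The new `DBCtxTimeout` constant replaces the inline `5*time.Second`. Below is a minimal sketch of the pattern it supports; it assumes `NewDB` uses the timeout to bound an initial `PingContext`, which this diff does not show, and it is not the repository's implementation:

```go
package main

import (
	"context"
	"database/sql"
	"fmt"
	"os"
	"time"

	_ "github.com/lib/pq"
)

const dbCtxTimeout = 5 * time.Second

// openDB sketches the likely shape of NewDB: open a pool, then bound the
// initial connectivity check with the named timeout instead of a magic number.
func openDB() (*sql.DB, error) {
	db, err := sql.Open("postgres", os.Getenv("BUYBETTER_DEV_SUPABASE_DSN"))
	if err != nil {
		return nil, err
	}

	ctx, cancel := context.WithTimeout(context.Background(), dbCtxTimeout)
	defer cancel()

	if err := db.PingContext(ctx); err != nil {
		db.Close()
		return nil, fmt.Errorf("ping database: %w", err)
	}
	return db, nil
}

func main() {
	db, err := openDB()
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	defer db.Close() // the caller, not openDB, owns the connection's lifetime
	fmt.Println("connected to database")
}
```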
33 changes: 20 additions & 13 deletions internal/transform/transform.go
@@ -1,6 +1,6 @@
// Package transform provides functionalities to transform raw category data into a structured dataset suitable for machine learning.
// It handles shuffling, splitting into train/validate/test sets, and database interactions.
// It utilizes a CategoryStorer interface to abstract the underlying data storage mechanism.
// Package transform provides functionalities to transform raw category data into a structured dataset
// suitable for machine learning. It handles shuffling, splitting into train/validate/test sets,
// and database interactions. It utilizes a CategoryStorer interface to abstract the underlying data storage mechanism.
// The package prepares data specifically for training, validating, and testing machine learning models,
// enabling the development of models that can predict or classify categories based on input features.
package transform
@@ -15,7 +15,8 @@ import (

// Config holds the configuration parameters for the transformation process.
// It includes settings for the dataset version, whether to shuffle the data for randomness,
// and the ratios for splitting the data into train, validate, and test sets, which are crucial for model training and evaluation.
// and the ratios for splitting the data into train, validate, and test sets,
// which are crucial for model training and evaluation.
type Config struct {
Version string `json:"version"`
Shuffle bool `json:"shuffle"`
@@ -54,9 +55,9 @@ type Transform struct {
// NewTransform creates a new Transform instance, responsible for orchestrating the dataset generation process.
// It initializes the logger with a "transform" component tag for tracking,
// sets the CategoryStorer for data access, and configures the transformation process using the provided Config.
func NewTransform(cs CategoryStorer, cfg Config) *Transform {
func NewTransform(l *slog.Logger, cs CategoryStorer, cfg Config) *Transform {
return &Transform{
log: slog.With("component", "transform"),
log: l.With("component", "transform"),
catStore: cs,
config: cfg,
}
@@ -65,10 +66,12 @@ func NewTransform(cs CategoryStorer, cfg Config) *Transform {
// GenerateDataset generates a dataset specifically designed for training and evaluating machine learning models.
// It retrieves original and matched categories from the CategoryStorer,
// optionally shuffles the matched categories to ensure a random distribution of data,
// splits the data into train, validate, and test sets according to the configured ratios, which is crucial for model training and performance assessment,
// splits the data into train, validate, and test sets according to the configured ratios,
// which is crucial for model training and performance assessment,
// cleans up previous datasets with the same version from the CategoryStorer to avoid data conflicts,
// and inserts the newly generated dataset into the CategoryStorer.
// The resulting dataset contains features (L1-L8) and labels (FullPathOut, NameOut) that can be used to train a model to predict category paths or names.
// The resulting dataset contains features (L1-L8) and labels (FullPathOut, NameOut)
// that can be used to train a model to predict category paths or names.
// Returns an error if any of the steps fail, using specific error variables for clarity.
func (t *Transform) GenerateDataset() error {
oCat, err := t.catStore.OriginalCategory()
@@ -87,25 +90,29 @@ func (t *Transform) GenerateDataset() error {
t.log.Info("shuffle matched category")

src := rand.NewSource(time.Now().UnixNano())
//nolint:gosec // No need to use secure random number generator
rng := rand.New(src)
rng.Shuffle(len(mCat), func(i, j int) {
mCat[i], mCat[j] = mCat[j], mCat[i]
})
}

const percentage = 100

// Calculate the number of samples for each label
totalSamples := len(mCat)
numTrain := int(float64(totalSamples) * (float64(t.config.TrainRatio) / 100))
numValidate := int(float64(totalSamples) * (float64(t.config.ValidateRatio) / 100))
numTrain := int(float64(totalSamples) * (float64(t.config.TrainRatio) / percentage))
numValidate := int(float64(totalSamples) * (float64(t.config.ValidateRatio) / percentage))

var dataset []model.CategoryDataset
var label string
for i, v := range mCat {
if i < numTrain {
switch {
case i < numTrain:
label = "train"
} else if i < numTrain+numValidate {
case i < numTrain+numValidate:
label = "validate"
} else {
default:
label = "test"
}

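
For readers who want to see the split logic end to end, here is a self-contained sketch of the same shuffle-then-label flow over a plain slice of IDs. It mirrors the `percentage` constant and the new `switch`, but it is not the repository's `Transform` type:

```go
package main

import (
	"fmt"
	"math/rand"
	"time"
)

// splitLabels shuffles the input (optionally) and assigns "train", "validate",
// or "test" to each ID according to the configured percentages.
func splitLabels(ids []int, trainRatio, validateRatio int, shuffle bool) map[int]string {
	const percentage = 100

	if shuffle {
		//nolint:gosec // randomness here is for data splitting, not security
		rng := rand.New(rand.NewSource(time.Now().UnixNano()))
		rng.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] })
	}

	total := len(ids)
	numTrain := int(float64(total) * (float64(trainRatio) / percentage))
	numValidate := int(float64(total) * (float64(validateRatio) / percentage))

	labels := make(map[int]string, total)
	for i, id := range ids {
		switch {
		case i < numTrain:
			labels[id] = "train"
		case i < numTrain+numValidate:
			labels[id] = "validate"
		default:
			labels[id] = "test"
		}
	}
	return labels
}

func main() {
	ids := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
	fmt.Println(splitLabels(ids, 60, 20, true))
}
```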
7 changes: 6 additions & 1 deletion internal/transform/transform_test.go
@@ -2,6 +2,8 @@ package transform

import (
"errors"
"log/slog"
"os"
"testing"

"github.com/opplieam/bb-transform/.jetgen/postgres/public/model"
@@ -107,10 +109,13 @@ func TestGenerateDataset(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
slog.SetDefault(logger)

mockStorer := NewMockCategoryStorer(t)
tt.mockBehavior(mockStorer)

tr := NewTransform(mockStorer, tt.cfg)
tr := NewTransform(logger, mockStorer, tt.cfg)
err := tr.GenerateDataset()

if tt.wantErr {
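
One possible follow-up, sketched below under the assumption that quieter test output is desirable: write test logs to `io.Discard` by default, or route them through `t.Log` when debugging. The `newTestLogger` helper is hypothetical and not part of this PR:

```go
package transform

import (
	"io"
	"log/slog"
	"testing"
)

// newTestLogger keeps `go test ./...` output quiet by default and routes
// handler logs through t.Log when verbose output is wanted for debugging.
func newTestLogger(t *testing.T, verbose bool) *slog.Logger {
	t.Helper()
	if !verbose {
		return slog.New(slog.NewTextHandler(io.Discard, nil))
	}
	return slog.New(slog.NewTextHandler(testWriter{t}, nil))
}

// testWriter adapts *testing.T to io.Writer so log lines appear with test output.
type testWriter struct{ t *testing.T }

func (w testWriter) Write(p []byte) (int, error) {
	w.t.Log(string(p))
	return len(p), nil
}
```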
Empty file added readme.md
Empty file.