Skip to content

Commit

Permalink
feat(mysql): address review comments on validations
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Feb 20, 2025
1 parent 4a1fe95 commit 8a15869
Showing 1 changed file with 90 additions and 68 deletions.
158 changes: 90 additions & 68 deletions flow/connectors/mysql/validate.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,18 @@
package connmysql


import (
"context"
"database/sql"
"errors"
"fmt"
"strings"

"github.com/go-mysql-org/go-mysql/mysql"
_ "github.com/go-sql-driver/mysql"

"github.com/PeerDB-io/peerdb/flow/connectors/utils"
"github.com/PeerDB-io/peerdb/flow/generated/protos"
"github.com/PeerDB-io/peerdb/flow/shared"
)

type MySQLConnector struct {
conn *sql.DB
config *MySQLConfig
}

func (c *MySQLConnector) CheckSourceTables(ctx context.Context, tableNames []*utils.SchemaTable) error {
if c.conn == nil {
return errors.New("check tables: conn is nil")
Expand All @@ -30,7 +22,7 @@ func (c *MySQLConnector) CheckSourceTables(ctx context.Context, tableNames []*ut
query := fmt.Sprintf("SELECT 1 FROM `%s`.`%s` LIMIT 1", parsedTable.Schema, parsedTable.Table)
_, err := c.conn.QueryContext(ctx, query)
if err != nil {
return fmt.Errorf("error checking table %s.%s: %v", parsedTable.Schema, parsedTable.Table, err)
return fmt.Errorf("error checking table %s.%s: %w", parsedTable.Schema, parsedTable.Table, err)
}
}
return nil
Expand All @@ -41,40 +33,41 @@ func (c *MySQLConnector) CheckReplicationPermissions(ctx context.Context) error
return errors.New("check replication permissions: conn is nil")
}

var replicationPrivilege string
err := c.conn.QueryRowContext(ctx, "SHOW GRANTS FOR CURRENT_USER()").Scan(&replicationPrivilege)
rows, err := c.Execute(ctx, "SHOW GRANTS FOR CURRENT_USER()")
if err != nil {
return fmt.Errorf("failed to check replication privileges: %v", err)
return fmt.Errorf("failed to check replication privileges: %w", err)
}

if !strings.Contains(replicationPrivilege, "REPLICATION SLAVE") && !strings.Contains(replicationPrivilege, "REPLICATION CLIENT") {
return errors.New("MySQL user does not have replication privileges")
for _, row := range rows {
if grant, ok := row[0].(string); ok {
if strings.Contains(grant, "REPLICATION SLAVE") || strings.Contains(grant, "REPLICATION CLIENT") {
return nil
}
}
}

return nil
return errors.New("MySQL user does not have replication privileges")
}

func (c *MySQLConnector) CheckReplicationConnectivity(ctx context.Context) error {
if c.conn == nil {
return errors.New("check replication connectivity: conn is nil")
}

var masterLogFile string
var masterLogPos int

err := c.conn.QueryRowContext(ctx, "SHOW MASTER STATUS").Scan(&masterLogFile, &masterLogPos)
rows, err := c.Execute(ctx, "SHOW MASTER STATUS")
if err != nil {
// Handle case where SHOW MASTER STATUS returns no rows (binary logging disabled)
if errors.Is(err, sql.ErrNoRows) {
return errors.New("binary logging is disabled on this MySQL server")
}
return fmt.Errorf("failed to check replication status: %v", err)
return fmt.Errorf("failed to check replication status: %w", err)
}
if len(rows) == 0 {
return errors.New("binary logging is disabled on this MySQL server")
}

masterLogFile, _ := rows[0][0].(string)
masterLogPos, _ := rows[0][1].(int64)

// Additional validation: Check if the values are valid
if masterLogFile == "" || masterLogPos <= 0 {
return errors.New("invalid replication status: missing log file or position")
}
}

return nil
}
Expand All @@ -84,52 +77,83 @@ func (c *MySQLConnector) CheckBinlogSettings(ctx context.Context) error {
return errors.New("check binlog settings: conn is nil")
}

// Check binlog_expire_logs_seconds
var expireSeconds int
err := c.conn.QueryRowContext(ctx, "SELECT @@binlog_expire_logs_seconds").Scan(&expireSeconds)
rows, err := c.Execute(ctx, "SELECT @@binlog_expire_logs_seconds")
if err != nil {
return fmt.Errorf("failed to retrieve binlog_expire_logs_seconds: %v", err)
return fmt.Errorf("failed to retrieve binlog_expire_logs_seconds: %w", err)
}
defer rows.Close()

var expireSeconds int
if rows.Next() {
if err := rows.Scan(&expireSeconds); err != nil {
return fmt.Errorf("failed to scan binlog_expire_logs_seconds: %w", err)
}
}
if expireSeconds <= 86400 {
return errors.New("binlog_expire_logs_seconds is too low. Must be greater than 1 day")
}

// Check binlog_format
var binlogFormat string
err = c.conn.QueryRowContext(ctx, "SELECT @@binlog_format").Scan(&binlogFormat)
if err != nil {
return fmt.Errorf("failed to retrieve binlog_format: %v", err)
rows, err := c.Execute(ctx, "SELECT @@binlog_format")
if err != nil || len(rows) == 0 {
return fmt.Errorf("failed to retrieve binlog_format: %w", err)
}
binlogFormat, _ := rows[0][0].(string)
if binlogFormat != "ROW" {
return errors.New("binlog_format must be set to 'ROW'")
}

// Check binlog_row_metadata
var binlogRowMetadata string
err = c.conn.QueryRowContext(ctx, "SELECT @@binlog_row_metadata").Scan(&binlogRowMetadata)
if err != nil {
return fmt.Errorf("failed to retrieve binlog_row_metadata: %v", err)
}
if binlogRowMetadata != "FULL" {
return errors.New("binlog_row_metadata must be set to 'FULL' for column exclusion support")
}
}

// Check binlog_row_metadata
// rows, err := c.Execute(ctx, "SELECT @@binlog_row_metadata")
// if err != nil {
// return fmt.Errorf("failed to retrieve binlog_row_metadata: %w", err)
// }
// defer rows.Close()

// var binlogRowMetadata string
// if rows.Next() {
// if err := rows.Scan(&binlogRowMetadata); err != nil {
// return fmt.Errorf("failed to scan binlog_row_metadata: %w", err)
// }
// }

// if binlogRowMetadata != "FULL" {
// return errors.New("binlog_row_metadata must be set to 'FULL' for column exclusion support")
// }

// Check binlog_row_image
var binlogRowImage string
err = c.conn.QueryRowContext(ctx, "SELECT @@binlog_row_image").Scan(&binlogRowImage)
// rows, err := c.Execute(ctx, "SELECT @@binlog_row_image")
// if err != nil {
// return fmt.Errorf("failed to retrieve binlog_row_image: %w", err)
// }
// defer rows.Close()

// var binlogRowImage string
// if rows.Next() {
// if err := rows.Scan(&binlogRowImage); err != nil {
// return fmt.Errorf("failed to scan binlog_row_image: %w", err)
// }
// }

// if binlogRowImage != "FULL" {
// return errors.New("binlog_row_image must be set to 'FULL' (equivalent to PostgreSQL's REPLICA IDENTITY FULL)")
// }


// Check binlog_row_value_options
rows, err := c.Execute(ctx, "SELECT @@binlog_row_value_options")
if err != nil {
return fmt.Errorf("failed to retrieve binlog_row_image: %v", err)
}
if binlogRowImage != "FULL" {
return errors.New("binlog_row_image must be set to 'FULL' (equivalent to PostgreSQL's REPLICA IDENTITY FULL)")
return fmt.Errorf("failed to retrieve binlog_row_value_options: %w", err)
}
defer rows.Close()

// Check binlog_row_value_options
var binlogRowValueOptions string
err = c.conn.QueryRowContext(ctx, "SELECT @@binlog_row_value_options").Scan(&binlogRowValueOptions)
if err != nil {
return fmt.Errorf("failed to retrieve binlog_row_value_options: %v", err)
if rows.Next() {
if err := rows.Scan(&binlogRowValueOptions); err != nil {
return fmt.Errorf("failed to scan binlog_row_value_options: %w", err)
}
}

if binlogRowValueOptions != "" {
return errors.New("binlog_row_value_options must be disabled to prevent JSON change deltas")
}
Expand All @@ -142,25 +166,25 @@ func (c *MySQLConnector) ValidateMirrorSource(ctx context.Context, cfg *protos.F
for _, tableMapping := range cfg.TableMappings {
parsedTable, parseErr := utils.ParseSchemaTable(tableMapping.SourceTableIdentifier)
if parseErr != nil {
return fmt.Errorf("invalid source table identifier: %s", parseErr)
return fmt.Errorf("invalid source table identifier: %w", parseErr)
}
sourceTables = append(sourceTables, parsedTable)
}

if err := c.CheckReplicationConnectivity(ctx); err != nil {
return fmt.Errorf("unable to establish replication connectivity: %v", err)
return fmt.Errorf("unable to establish replication connectivity: %w", err)
}

if err := c.CheckReplicationPermissions(ctx); err != nil {
return fmt.Errorf("failed to check replication permissions: %v", err)
return fmt.Errorf("failed to check replication permissions: %w", err)
}

if err := c.CheckSourceTables(ctx, sourceTables); err != nil {
return fmt.Errorf("provided source tables invalidated: %v", err)
return fmt.Errorf("provided source tables invalidated: %w", err)
}

if err := c.CheckBinlogSettings(ctx); err != nil {
return fmt.Errorf("binlog configuration error: %v", err)
return fmt.Errorf("binlog configuration error: %w", err)
}

return nil
Expand All @@ -184,20 +208,18 @@ func (c *MySqlConnector) ValidateCheck(ctx context.Context) error {
}

if err := c.CheckReplicationConnectivity(ctx); err != nil {
return fmt.Errorf("unable to establish replication connectivity: %v", err)
return fmt.Errorf("unable to establish replication connectivity: %w", err)
}

if err := c.CheckReplicationPermissions(ctx); err != nil {
return fmt.Errorf("failed to check replication permissions: %v", err)
return fmt.Errorf("failed to check replication permissions: %w", err)
}

if err := c.CheckBinlogSettings(ctx); err != nil {
return fmt.Errorf("binlog configuration error: %v", err)
if c.config.Flavor == protos.MySqlFlavor_MYSQL_MYSQL {
if err := c.CheckBinlogSettings(ctx); err != nil {
return fmt.Errorf("binlog configuration error: %w", err)
}
}

return nil
}

func (c *MySqlConnector) ValidateMirrorSource(ctx context.Context, cfg *protos.FlowConnectionConfigs) error {
return nil
}

0 comments on commit 8a15869

Please sign in to comment.