Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding spec.metadata.pemContents to MySQL Binding #3620

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
9 changes: 9 additions & 0 deletions bindings/mysql/metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ metadata:
description: "Path to the PEM file. Used with SSL connection"
example: '"path/to/pem/file"'
type: string
- name: pemContents
required: false
description: "Base64-encoded PEM file contents. Used with SSL connection. Supersedes pemPath if both provided."
example: '"-----BEGIN CERTIFICATE-----
MIIFaDCCBFCgAwIBAgISESHkvZFwK9Qz0KsXD3x8p44aMA0GCSqGSIb3DQEBCwUA
...
bml6YXRpb252YWxzaGEyZzIuY3JsMIGgBggrBgEFBQcBAQSBkzCBkDBNBggrBgEF
-----END CERTIFICATE-----"'
type: string
- name: maxIdleConns
required: false
description: "The max idle connections. Integer greater than 0"
Expand Down
107 changes: 89 additions & 18 deletions bindings/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ import (
"database/sql"
"database/sql/driver"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"os"
"reflect"
"strconv"
"strings"
"sync/atomic"
"time"

Expand Down Expand Up @@ -80,6 +82,10 @@ type mysqlMetadata struct {
// PemPath is the path to the pem file to connect to MySQL over SSL.
PemPath string `mapstructure:"pemPath"`

// PemContents is the contents of the pem file to connect to MySQL over SSL.
// PemContents supersedes PemPath if both are provided.
PemContents string `mapstructure:"pemContents"`

// MaxIdleConns is the maximum number of connections in the idle connection pool.
MaxIdleConns int `mapstructure:"maxIdleConns"`

Expand Down Expand Up @@ -117,7 +123,34 @@ func (m *Mysql) Init(ctx context.Context, md bindings.Metadata) error {
return errors.New("missing MySql connection string")
}

m.db, err = initDB(meta.URL, meta.PemPath)
var pemContents []byte

// meta.PemContents supersedes meta.PemPath if both are provided.
if meta.PemContents != "" {
// Reformat the PEM to standard format
meta.PemContents = reformatPEM(meta.PemContents)
pemContents = []byte(meta.PemContents)
} else if meta.PemPath != "" {
pemContents, err = os.ReadFile(meta.PemPath)
if err != nil {
return fmt.Errorf("unable to read PEM file: %w", err)
}
}

// Decode PEM contents and parse certificate to ensure it's valid.
if len(pemContents) != 0 {
block, _ := pem.Decode(pemContents)
if block == nil {
return errors.New("failed to decode PEM")
}

_, err = x509.ParseCertificate(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse PEM contents: %w", err)
}
}

m.db, err = initDB(meta.URL, pemContents)
if err != nil {
return err
}
Expand Down Expand Up @@ -234,7 +267,9 @@ func (m *Mysql) Close() error {
}

if m.db != nil {
m.db.Close()
if err := m.db.Close(); err != nil {
m.logger.Warnf("error closing DB: %v", err)
}
m.db = nil
}

Expand All @@ -246,7 +281,12 @@ func (m *Mysql) query(ctx context.Context, sql string, params ...any) ([]byte, e
if err != nil {
return nil, fmt.Errorf("error executing query: %w", err)
}
defer rows.Close()

defer func() {
if err = rows.Close(); err != nil {
m.logger.Warnf("error closing rows: %v", err)
}
}()

result, err := m.jsonify(rows)
if err != nil {
Expand All @@ -265,26 +305,22 @@ func (m *Mysql) exec(ctx context.Context, sql string, params ...any) (int64, err
return res.RowsAffected()
}

func initDB(url, pemPath string) (*sql.DB, error) {
conf, err := mysql.ParseDSN(url)
if err != nil {
return nil, fmt.Errorf("illegal Data Source Name (DSN) specified by %s", connectionURLKey)
}

if pemPath != "" {
var pem []byte
func initDB(url string, pemContents []byte) (*sql.DB, error) {
// We need to register the custom TLS config before parsing the DSN if the user
// has provided a PEM file. DSN parsing will fail if the user has provided a PEM
// file, but the custom TLS config requested (i.e., "custom") is not registered.
if len(pemContents) != 0 {
// Create an empty root cert pool. We will append the PEM contents to this pool.
rootCertPool := x509.NewCertPool()
pem, err = os.ReadFile(pemPath)
if err != nil {
return nil, fmt.Errorf("error reading PEM file from %s: %w", pemPath, err)
}

ok := rootCertPool.AppendCertsFromPEM(pem)
ok := rootCertPool.AppendCertsFromPEM(pemContents)
if !ok {
return nil, errors.New("failed to append PEM")
}

err = mysql.RegisterTLSConfig("custom", &tls.Config{
// Register TLS config with the name "custom". The url must end with &tls=custom
// to use this custom TLS config.
err := mysql.RegisterTLSConfig("custom", &tls.Config{
RootCAs: rootCertPool,
MinVersion: tls.VersionTLS12,
})
Expand All @@ -293,6 +329,12 @@ func initDB(url, pemPath string) (*sql.DB, error) {
}
}

// Parse the DSN to get the connection configuration.
conf, err := mysql.ParseDSN(url)
if err != nil {
return nil, fmt.Errorf("illegal Data Source Name (DSN) specified by %s", connectionURLKey)
}

// Required to correctly parse time columns
// See: https://stackoverflow.com/a/45040724
conf.ParseTime = true
Expand All @@ -306,6 +348,32 @@ func initDB(url, pemPath string) (*sql.DB, error) {
return db, nil
}

// Helper function to reformat a single-line PEM into standard PEM format
func reformatPEM(pemStr string) string {
// Ensure headers and footers are on their own lines
pemStr = strings.ReplaceAll(pemStr, "-----BEGIN CERTIFICATE-----", "\n-----BEGIN CERTIFICATE-----\n")
pemStr = strings.ReplaceAll(pemStr, "-----END CERTIFICATE-----", "\n-----END CERTIFICATE-----")

// Split into base64-encoded content and reformat into 64-character lines
lines := strings.Split(pemStr, "\n")
if len(lines) >= 3 {
encodedContent := lines[1]
lines[1] = strings.Join(chunkString(encodedContent, 64), "\n")
}
return strings.Join(lines, "\n")
}

// Helper function to split a string into chunks of a given size
func chunkString(s string, chunkSize int) []string {
var chunks []string
for len(s) > chunkSize {
chunks = append(chunks, s[:chunkSize])
s = s[chunkSize:]
}
chunks = append(chunks, s)
return chunks
}

func (m *Mysql) jsonify(rows *sql.Rows) ([]byte, error) {
columnTypes, err := rows.ColumnTypes()
if err != nil {
Expand Down Expand Up @@ -373,6 +441,9 @@ func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []any) map[string]
// GetComponentMetadata returns the metadata of the component.
func (m *Mysql) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
metadataStruct := mysqlMetadata{}
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.BindingType)
if err := metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.BindingType); err != nil {
m.logger.Warnf("error retrieving metadata info: %v", err)
}

return
}
6 changes: 5 additions & 1 deletion bindings/mysql/mysql_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ func TestMysqlIntegration(t *testing.T) {
err := b.Init(context.Background(), m)
require.NoError(t, err)

defer b.Close()
defer func() {
if err = b.Close(); err != nil {
t.Errorf("failed to close database: %s", err)
}
}()

t.Run("Invoke create table", func(t *testing.T) {
res, err := b.Invoke(context.Background(), &bindings.InvokeRequest{
Expand Down
Loading