diff --git a/cdc/entry/mounter_test.go b/cdc/entry/mounter_test.go index 7dc12279dcb..8416a462714 100644 --- a/cdc/entry/mounter_test.go +++ b/cdc/entry/mounter_test.go @@ -1133,13 +1133,13 @@ func TestNewDMRowChange(t *testing.T) { cdcTableInfo := model.WrapTableInfo(0, "test", 0, originTI) cols := []*model.Column{ { - Name: "id", Type: 3, Charset: "binary", Flag: 65, Value: 1, Default: nil, + Name: "id", Type: 3, Charset: "binary", Flag: 65, Value: 1, }, { - Name: "a1", Type: 3, Charset: "binary", Flag: 51, Value: 1, Default: nil, + Name: "a1", Type: 3, Charset: "binary", Flag: 51, Value: 1, }, { - Name: "a3", Type: 3, Charset: "binary", Flag: 51, Value: 2, Default: nil, + Name: "a3", Type: 3, Charset: "binary", Flag: 51, Value: 2, }, } recoveredTI := model.BuildTiDBTableInfo(cols, cdcTableInfo.IndexColumnsOffset) diff --git a/cdc/model/codec/codec.go b/cdc/model/codec/codec.go new file mode 100644 index 00000000000..a361faecfaf --- /dev/null +++ b/cdc/model/codec/codec.go @@ -0,0 +1,228 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package codec + +import ( + "encoding/binary" + + timodel "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tiflow/cdc/model" + codecv1 "github.com/pingcap/tiflow/cdc/model/codec/v1" + "github.com/tinylib/msgp/msgp" +) + +const ( + v1HeaderLength int = 4 + versionPrefixLength int = 2 + versionFieldLength int = 2 + + latestVersion uint16 = 2 +) + +// NOTE: why we need this? 
+// +// Before this logic is introduced, redo log is encoded into byte slice without a version field. +// This makes it hard to extend in the future. +// However, in the old format (i.e. v1 format), the first 5 bytes are always same, which can be +// confirmed in v1/codec_gen.go. So we reuse those bytes, and add a version field in them. +var ( + versionPrefix = [versionPrefixLength]byte{0xff, 0xff} +) + +func postUnmarshal(r *model.RedoLog) { + workaroundColumn := func(c *model.Column, redoC *model.RedoColumn) { + c.Flag = model.ColumnFlagType(redoC.Flag) + if redoC.ValueIsEmptyBytes { + c.Value = []byte{} + } else { + c.Value = redoC.Value + } + } + + if r.RedoRow.Row != nil { + row := r.RedoRow.Row + for i, c := range row.Columns { + if c != nil { + workaroundColumn(c, &r.RedoRow.Columns[i]) + } + } + for i, c := range row.PreColumns { + if c != nil { + workaroundColumn(c, &r.RedoRow.PreColumns[i]) + } + } + r.RedoRow.Columns = nil + r.RedoRow.PreColumns = nil + } + if r.RedoDDL.DDL != nil { + r.RedoDDL.DDL.Type = timodel.ActionType(r.RedoDDL.Type) + r.RedoDDL.DDL.TableInfo = &model.TableInfo{ + TableName: r.RedoDDL.TableName, + } + } +} + +func preMarshal(r *model.RedoLog) { + // Workaround empty byte slice for msgp#247 + workaroundColumn := func(redoC *model.RedoColumn) { + switch v := redoC.Value.(type) { + case []byte: + if len(v) == 0 { + redoC.ValueIsEmptyBytes = true + } + } + } + + if r.RedoRow.Row != nil { + row := r.RedoRow.Row + r.RedoRow.Columns = make([]model.RedoColumn, 0, len(row.Columns)) + r.RedoRow.PreColumns = make([]model.RedoColumn, 0, len(row.PreColumns)) + for _, c := range row.Columns { + redoC := model.RedoColumn{} + if c != nil { + redoC.Value = c.Value + redoC.Flag = uint64(c.Flag) + workaroundColumn(&redoC) + } + r.RedoRow.Columns = append(r.RedoRow.Columns, redoC) + } + for _, c := range row.PreColumns { + redoC := model.RedoColumn{} + if c != nil { + redoC.Value = c.Value + redoC.Flag = uint64(c.Flag) + workaroundColumn(&redoC) + } + 
r.RedoRow.PreColumns = append(r.RedoRow.PreColumns, redoC) + } + } + if r.RedoDDL.DDL != nil { + r.RedoDDL.Type = byte(r.RedoDDL.DDL.Type) + if r.RedoDDL.DDL.TableInfo != nil { + r.RedoDDL.TableName = r.RedoDDL.DDL.TableInfo.TableName + } + } +} + +// UnmarshalRedoLog unmarshals a RedoLog from the given byte slice. +func UnmarshalRedoLog(bts []byte) (r *model.RedoLog, o []byte, err error) { + if len(bts) < versionPrefixLength { + err = msgp.ErrShortBytes + return + } + + shouldBeV1 := false + for i := 0; i < versionPrefixLength; i++ { + if bts[i] != versionPrefix[i] { + shouldBeV1 = true + break + } + } + if shouldBeV1 { + var rv1 *codecv1.RedoLog = new(codecv1.RedoLog) + if o, err = rv1.UnmarshalMsg(bts); err != nil { + return + } + codecv1.PostUnmarshal(rv1) + r = redoLogFromV1(rv1) + } else { + bts = bts[versionPrefixLength:] + version, bts := decodeVersion(bts) + if version == latestVersion { + r = new(model.RedoLog) + if o, err = r.UnmarshalMsg(bts); err != nil { + return + } + postUnmarshal(r) + } else { + panic("unsupported codec version") + } + } + return +} + +// MarshalRedoLog marshals a RedoLog into bytes. +func MarshalRedoLog(r *model.RedoLog, b []byte) (o []byte, err error) { + preMarshal(r) + b = append(b, versionPrefix[:]...) + b = binary.BigEndian.AppendUint16(b, latestVersion) + o, err = r.MarshalMsg(b) + return +} + +// MarshalDDLAsRedoLog converts a DDLEvent into RedoLog, and then marshals it. 
+func MarshalDDLAsRedoLog(d *model.DDLEvent, b []byte) (o []byte, err error) { + log := &model.RedoLog{ + RedoDDL: model.RedoDDLEvent{DDL: d}, + Type: model.RedoLogTypeDDL, + } + return MarshalRedoLog(log, b) +} + +func decodeVersion(bts []byte) (uint16, []byte) { + version := binary.BigEndian.Uint16(bts[0:versionFieldLength]) + return version, bts[versionFieldLength:] +} + +func redoLogFromV1(rv1 *codecv1.RedoLog) (r *model.RedoLog) { + r = &model.RedoLog{Type: (model.RedoLogType)(rv1.Type)} + if rv1.RedoRow != nil && rv1.RedoRow.Row != nil { + r.RedoRow.Row = &model.RowChangedEventInRedoLog{ + StartTs: rv1.RedoRow.Row.StartTs, + CommitTs: rv1.RedoRow.Row.CommitTs, + Table: tableNameFromV1(rv1.RedoRow.Row.Table), + Columns: make([]*model.Column, 0, len(rv1.RedoRow.Row.Columns)), + PreColumns: make([]*model.Column, 0, len(rv1.RedoRow.Row.PreColumns)), + IndexColumns: rv1.RedoRow.Row.IndexColumns, + } + for _, c := range rv1.RedoRow.Row.Columns { + r.RedoRow.Row.Columns = append(r.RedoRow.Row.Columns, columnFromV1(c)) + } + for _, c := range rv1.RedoRow.Row.PreColumns { + r.RedoRow.Row.PreColumns = append(r.RedoRow.Row.PreColumns, columnFromV1(c)) + } + } + if rv1.RedoDDL != nil && rv1.RedoDDL.DDL != nil { + r.RedoDDL.DDL = &model.DDLEvent{ + StartTs: rv1.RedoDDL.DDL.StartTs, + CommitTs: rv1.RedoDDL.DDL.CommitTs, + Query: rv1.RedoDDL.DDL.Query, + TableInfo: rv1.RedoDDL.DDL.TableInfo, + PreTableInfo: rv1.RedoDDL.DDL.PreTableInfo, + Type: rv1.RedoDDL.DDL.Type, + } + r.RedoDDL.DDL.Done.Store(rv1.RedoDDL.DDL.Done) + } + return +} + +func tableNameFromV1(t *codecv1.TableName) *model.TableName { + return &model.TableName{ + Schema: t.Schema, + Table: t.Table, + TableID: t.TableID, + IsPartition: t.IsPartition, + } +} + +func columnFromV1(c *codecv1.Column) *model.Column { + return &model.Column{ + Name: c.Name, + Type: c.Type, + Charset: c.Charset, + Flag: c.Flag, + Value: c.Value, + ApproximateBytes: c.ApproximateBytes, + } +} diff --git a/cdc/model/schema_storage.go 
b/cdc/model/schema_storage.go index 61bb78a7445..ec53a5da126 100644 --- a/cdc/model/schema_storage.go +++ b/cdc/model/schema_storage.go @@ -17,11 +17,19 @@ import ( "fmt" "github.com/pingcap/log" +<<<<<<< HEAD "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/types" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/util/rowcodec" +======= + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/types" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/util/rowcodec" +>>>>>>> 600286c56d (sink(ticdc): fix incorrect `default` field (#12038)) "go.uber.org/zap" ) @@ -342,5 +350,74 @@ func (ti *TableInfo) IsIndexUnique(indexInfo *model.IndexInfo) bool { // Clone clones the TableInfo func (ti *TableInfo) Clone() *TableInfo { +<<<<<<< HEAD return WrapTableInfo(ti.SchemaID, ti.TableName.Schema, ti.TableInfoVersion, ti.TableInfo.Clone()) +======= + return WrapTableInfo(ti.SchemaID, ti.TableName.Schema, ti.Version, ti.TableInfo.Clone()) +} + +// GetIndex return the corresponding index by the given name. 
+func (ti *TableInfo) GetIndex(name string) *model.IndexInfo { + for _, index := range ti.Indices { + if index != nil && index.Name.O == name { + return index + } + } + return nil +} + +// IndexByName returns the index columns and offsets of the corresponding index by name +func (ti *TableInfo) IndexByName(name string) ([]string, []int, bool) { + index := ti.GetIndex(name) + if index == nil { + return nil, nil, false + } + names := make([]string, 0, len(index.Columns)) + offset := make([]int, 0, len(index.Columns)) + for _, col := range index.Columns { + names = append(names, col.Name.O) + offset = append(offset, col.Offset) + } + return names, offset, true +} + +// OffsetsByNames returns the column offsets of the corresponding columns by names +// If any column does not exist, return false +func (ti *TableInfo) OffsetsByNames(names []string) ([]int, bool) { + // todo: optimize it + columnOffsets := make(map[string]int, len(ti.Columns)) + for _, col := range ti.Columns { + if col != nil { + columnOffsets[col.Name.O] = col.Offset + } + } + + result := make([]int, 0, len(names)) + for _, col := range names { + offset, ok := columnOffsets[col] + if !ok { + return nil, false + } + result = append(result, offset) + } + + return result, true +} + +// GetPrimaryKeyColumnNames returns the primary key column names +func (ti *TableInfo) GetPrimaryKeyColumnNames() []string { + var result []string + if ti.PKIsHandle { + result = append(result, ti.GetPkColInfo().Name.O) + return result + } + + indexInfo := ti.GetPrimaryKey() + if indexInfo != nil { + for _, col := range indexInfo.Columns { + result = append(result, col.Name.O) + } + } + return result +>>>>>>> 600286c56d (sink(ticdc): fix incorrect `default` field (#12038)) } diff --git a/cdc/model/sink.go b/cdc/model/sink.go index 2e18678ee16..e16b1db8f3b 100644 --- a/cdc/model/sink.go +++ b/cdc/model/sink.go @@ -291,6 +291,57 @@ func (r *RowChangedEvent) IsUpdate() bool { return len(r.PreColumns) != 0 && len(r.Columns) != 0 } 
+<<<<<<< HEAD +======= +func columnData2Column(col *ColumnData, tableInfo *TableInfo) *Column { + colID := col.ColumnID + offset, ok := tableInfo.columnsOffset[colID] + if !ok { + log.Panic("invalid column id", + zap.Int64("columnID", colID), + zap.Any("tableInfo", tableInfo)) + } + colInfo := tableInfo.Columns[offset] + return &Column{ + Name: colInfo.Name.O, + Type: colInfo.GetType(), + Charset: colInfo.GetCharset(), + Collation: colInfo.GetCollate(), + Flag: *tableInfo.ColumnsFlag[colID], + Value: col.Value, + } +} + +func columnDatas2Columns(cols []*ColumnData, tableInfo *TableInfo) []*Column { + if cols == nil { + return nil + } + columns := make([]*Column, len(cols)) + nilColumnNum := 0 + for i, colData := range cols { + if colData == nil { + nilColumnNum++ + continue + } + columns[i] = columnData2Column(colData, tableInfo) + } + log.Debug("meet nil column data", + zap.Any("nilColumnNum", nilColumnNum), + zap.Any("tableInfo", tableInfo)) + return columns +} + +// GetColumns returns the columns of the event +func (r *RowChangedEvent) GetColumns() []*Column { + return columnDatas2Columns(r.Columns, r.TableInfo) +} + +// GetPreColumns returns the pre columns of the event +func (r *RowChangedEvent) GetPreColumns() []*Column { + return columnDatas2Columns(r.PreColumns, r.TableInfo) +} + +>>>>>>> 600286c56d (sink(ticdc): fix incorrect `default` field (#12038)) // PrimaryKeyColumnNames return all primary key's name func (r *RowChangedEvent) PrimaryKeyColumnNames() []string { var result []string @@ -423,12 +474,21 @@ func (r *RowChangedEvent) ApproximateBytes() int { // Column represents a column value in row changed event type Column struct { +<<<<<<< HEAD Name string `json:"name" msg:"name"` Type byte `json:"type" msg:"type"` Charset string `json:"charset" msg:"charset"` Flag ColumnFlagType `json:"flag" msg:"-"` Value interface{} `json:"value" msg:"value"` Default interface{} `json:"default" msg:"-"` +======= + Name string `msg:"name"` + Type byte `msg:"type"` + 
Charset string `msg:"charset"` + Collation string `msg:"collation"` + Flag ColumnFlagType `msg:"-"` + Value interface{} `msg:"-"` +>>>>>>> 600286c56d (sink(ticdc): fix incorrect `default` field (#12038)) // ApproximateBytes is approximate bytes consumed by the column. ApproximateBytes int `json:"-"` @@ -745,3 +805,104 @@ func (t *SingleTableTxn) Append(row *RowChangedEvent) { } t.Rows = append(t.Rows, row) } +<<<<<<< HEAD +======= + +// TopicPartitionKey contains the topic and partition key of the message. +type TopicPartitionKey struct { + Topic string + Partition int32 + PartitionKey string + TotalPartition int32 +} + +// ColumnDataX is like ColumnData, but contains more informations. +// +//msgp:ignore RowChangedEvent +type ColumnDataX struct { + *ColumnData + flag *ColumnFlagType + info *model.ColumnInfo +} + +// GetColumnDataX encapsures ColumnData to ColumnDataX. +func GetColumnDataX(col *ColumnData, tb *TableInfo) ColumnDataX { + x := ColumnDataX{ColumnData: col} + if x.ColumnData != nil { + x.flag = tb.ColumnsFlag[col.ColumnID] + x.info = tb.Columns[tb.columnsOffset[col.ColumnID]] + } + return x +} + +// GetName returns name. +func (x ColumnDataX) GetName() string { + return x.info.Name.O +} + +// GetType returns type. +func (x ColumnDataX) GetType() byte { + return x.info.GetType() +} + +// GetCharset returns charset. +func (x ColumnDataX) GetCharset() string { + return x.info.GetCharset() +} + +// GetCollation returns collation. +func (x ColumnDataX) GetCollation() string { + return x.info.GetCollate() +} + +// GetFlag returns flag. +func (x ColumnDataX) GetFlag() ColumnFlagType { + return *x.flag +} + +// GetDefaultValue return default value. +func (x ColumnDataX) GetDefaultValue() interface{} { + return x.info.GetDefaultValue() +} + +// GetColumnInfo returns column info. +func (x ColumnDataX) GetColumnInfo() *model.ColumnInfo { + return x.info +} + +// Columns2ColumnDataForTest is for tests. 
+func Columns2ColumnDataForTest(columns []*Column) ([]*ColumnData, *TableInfo) { + info := &TableInfo{ + TableInfo: &model.TableInfo{ + Columns: make([]*model.ColumnInfo, len(columns)), + }, + ColumnsFlag: make(map[int64]*ColumnFlagType, len(columns)), + columnsOffset: make(map[int64]int), + } + colDatas := make([]*ColumnData, 0, len(columns)) + + for i, column := range columns { + var columnID int64 = int64(i) + info.columnsOffset[columnID] = i + + info.Columns[i] = &model.ColumnInfo{} + info.Columns[i].Name.O = column.Name + info.Columns[i].SetType(column.Type) + info.Columns[i].SetCharset(column.Charset) + info.Columns[i].SetCollate(column.Collation) + + info.ColumnsFlag[columnID] = new(ColumnFlagType) + *info.ColumnsFlag[columnID] = column.Flag + + colDatas = append(colDatas, &ColumnData{ColumnID: columnID, Value: column.Value}) + } + + return colDatas, info +} + +// Column2ColumnDataXForTest is for tests. +func Column2ColumnDataXForTest(column *Column) ColumnDataX { + datas, info := Columns2ColumnDataForTest([]*Column{column}) + return GetColumnDataX(datas[0], info) +} +>>>>>>> 600286c56d (sink(ticdc): fix incorrect `default` field (#12038)) diff --git a/pkg/sink/cloudstorage/table_definition.go b/pkg/sink/cloudstorage/table_definition.go new file mode 100644 index 00000000000..d87c5304179 --- /dev/null +++ b/pkg/sink/cloudstorage/table_definition.go @@ -0,0 +1,340 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+package cloudstorage + +import ( + "encoding/json" + "sort" + "strconv" + "strings" + + "github.com/pingcap/log" + timodel "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/charset" + pmodel "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/types" + "github.com/pingcap/tiflow/cdc/model" + "github.com/pingcap/tiflow/pkg/errors" + "github.com/pingcap/tiflow/pkg/hash" + "go.uber.org/zap" +) + +const ( + defaultTableDefinitionVersion = 1 + marshalPrefix = "" + marshalIndent = " " +) + +// TableCol denotes the column info for a table definition. +type TableCol struct { + ID string `json:"ColumnId,omitempty"` + Name string `json:"ColumnName" ` + Tp string `json:"ColumnType"` + Default interface{} `json:"ColumnDefault,omitempty"` + Precision string `json:"ColumnPrecision,omitempty"` + Scale string `json:"ColumnScale,omitempty"` + Nullable string `json:"ColumnNullable,omitempty"` + IsPK string `json:"ColumnIsPk,omitempty"` +} + +// FromTiColumnInfo converts from TiDB ColumnInfo to TableCol. 
+func (t *TableCol) FromTiColumnInfo(col *timodel.ColumnInfo, outputColumnID bool) { + defaultFlen, defaultDecimal := mysql.GetDefaultFieldLengthAndDecimal(col.GetType()) + isDecimalNotDefault := col.GetDecimal() != defaultDecimal && + col.GetDecimal() != 0 && + col.GetDecimal() != types.UnspecifiedLength + + displayFlen, displayDecimal := col.GetFlen(), col.GetDecimal() + if displayFlen == types.UnspecifiedLength { + displayFlen = defaultFlen + } + if displayDecimal == types.UnspecifiedLength { + displayDecimal = defaultDecimal + } + + if outputColumnID { + t.ID = strconv.FormatInt(col.ID, 10) + } + t.Name = col.Name.O + t.Tp = strings.ToUpper(types.TypeToStr(col.GetType(), col.GetCharset())) + if mysql.HasUnsignedFlag(col.GetFlag()) { + t.Tp += " UNSIGNED" + } + if mysql.HasPriKeyFlag(col.GetFlag()) { + t.IsPK = "true" + } + if mysql.HasNotNullFlag(col.GetFlag()) { + t.Nullable = "false" + } + t.Default = col.GetDefaultValue() + + switch col.GetType() { + case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDuration: + if isDecimalNotDefault { + t.Scale = strconv.Itoa(displayDecimal) + } + case mysql.TypeDouble, mysql.TypeFloat: + t.Precision = strconv.Itoa(displayFlen) + if isDecimalNotDefault { + t.Scale = strconv.Itoa(displayDecimal) + } + case mysql.TypeNewDecimal: + t.Precision = strconv.Itoa(displayFlen) + t.Scale = strconv.Itoa(displayDecimal) + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, + mysql.TypeBit, mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeBlob, + mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + t.Precision = strconv.Itoa(displayFlen) + case mysql.TypeYear: + t.Precision = strconv.Itoa(displayFlen) + } +} + +// ToTiColumnInfo converts from TableCol to TiDB ColumnInfo. 
+func (t *TableCol) ToTiColumnInfo(colID int64) (*timodel.ColumnInfo, error) { + col := new(timodel.ColumnInfo) + + if t.ID != "" { + var err error + col.ID, err = strconv.ParseInt(t.ID, 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + } + + col.ID = colID + col.Name = pmodel.NewCIStr(t.Name) + tp := types.StrToType(strings.ToLower(strings.TrimSuffix(t.Tp, " UNSIGNED"))) + col.FieldType = *types.NewFieldType(tp) + if strings.Contains(t.Tp, "UNSIGNED") { + col.AddFlag(mysql.UnsignedFlag) + } + if t.IsPK == "true" { + col.AddFlag(mysql.PriKeyFlag) + } + if t.Nullable == "false" { + col.AddFlag(mysql.NotNullFlag) + } + col.DefaultValue = t.Default + if strings.Contains(t.Tp, "BLOB") || strings.Contains(t.Tp, "BINARY") { + col.SetCharset(charset.CharsetBin) + } else { + col.SetCharset(charset.CharsetUTF8MB4) + } + setFlen := func(precision string) error { + if len(precision) > 0 { + flen, err := strconv.Atoi(precision) + if err != nil { + return errors.Trace(err) + } + col.SetFlen(flen) + } + return nil + } + setDecimal := func(scale string) error { + if len(scale) > 0 { + decimal, err := strconv.Atoi(scale) + if err != nil { + return errors.Trace(err) + } + col.SetDecimal(decimal) + } + return nil + } + switch col.GetType() { + case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDuration: + err := setDecimal(t.Scale) + if err != nil { + return nil, errors.Trace(err) + } + case mysql.TypeDouble, mysql.TypeFloat, mysql.TypeNewDecimal: + err := setFlen(t.Precision) + if err != nil { + return nil, errors.Trace(err) + } + err = setDecimal(t.Scale) + if err != nil { + return nil, errors.Trace(err) + } + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, + mysql.TypeBit, mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeBlob, + mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeYear: + err := setFlen(t.Precision) + if err != nil { + return nil, errors.Trace(err) + } + } + + 
return col, nil +} + +// TableDefinition is the detailed table definition used for cloud storage sink. +// TODO: find a better name for this struct. +type TableDefinition struct { + Table string `json:"Table"` + Schema string `json:"Schema"` + Version uint64 `json:"Version"` + TableVersion uint64 `json:"TableVersion"` + Query string `json:"Query"` + Type timodel.ActionType `json:"Type"` + Columns []TableCol `json:"TableColumns"` + TotalColumns int `json:"TableColumnsTotal"` +} + +// tableDefWithoutQuery is the table definition without query, which ignores the +// Query, Type and TableVersion field. +type tableDefWithoutQuery struct { + Table string `json:"Table"` + Schema string `json:"Schema"` + Version uint64 `json:"Version"` + Columns []TableCol `json:"TableColumns"` + TotalColumns int `json:"TableColumnsTotal"` +} + +// FromDDLEvent converts from DDLEvent to TableDefinition. +func (t *TableDefinition) FromDDLEvent(event *model.DDLEvent, outputColumnID bool) { + if event.CommitTs != event.TableInfo.Version { + log.Panic("commit ts and table info version should be equal", + zap.Any("event", event), zap.Any("tableInfo", event.TableInfo), + ) + } + t.FromTableInfo(event.TableInfo, event.TableInfo.Version, outputColumnID) + t.Query = event.Query + t.Type = event.Type +} + +// ToDDLEvent converts from TableDefinition to DDLEvent. +func (t *TableDefinition) ToDDLEvent() (*model.DDLEvent, error) { + tableInfo, err := t.ToTableInfo() + if err != nil { + return nil, err + } + + return &model.DDLEvent{ + TableInfo: tableInfo, + CommitTs: t.TableVersion, + Type: t.Type, + Query: t.Query, + }, nil +} + +// FromTableInfo converts from TableInfo to TableDefinition. 
+func (t *TableDefinition) FromTableInfo( + info *model.TableInfo, tableInfoVersion model.Ts, outputColumnID bool, +) { + t.Version = defaultTableDefinitionVersion + t.TableVersion = tableInfoVersion + + t.Schema = info.TableName.Schema + if info.TableInfo == nil { + return + } + t.Table = info.TableName.Table + t.TotalColumns = len(info.Columns) + for _, col := range info.Columns { + var tableCol TableCol + tableCol.FromTiColumnInfo(col, outputColumnID) + t.Columns = append(t.Columns, tableCol) + } +} + +// ToTableInfo converts from TableDefinition to DDLEvent. +func (t *TableDefinition) ToTableInfo() (*model.TableInfo, error) { + tidbTableInfo := &timodel.TableInfo{ + Name: pmodel.NewCIStr(t.Table), + } + nextMockID := int64(100) // 100 is an arbitrary number + for _, col := range t.Columns { + tiCol, err := col.ToTiColumnInfo(nextMockID) + if err != nil { + return nil, err + } + if mysql.HasPriKeyFlag(tiCol.GetFlag()) { + // use PKIsHandle to make sure that the primary keys can be detected by `WrapTableInfo` + tidbTableInfo.PKIsHandle = true + } + tidbTableInfo.Columns = append(tidbTableInfo.Columns, tiCol) + nextMockID += 1 + } + info := model.WrapTableInfo(100, t.Schema, 100, tidbTableInfo) + + return info, nil +} + +// IsTableSchema returns whether the TableDefinition is a table schema. +func (t *TableDefinition) IsTableSchema() bool { + if len(t.Columns) != t.TotalColumns { + log.Panic("invalid table definition", zap.Any("tableDef", t)) + } + return t.TotalColumns != 0 +} + +// MarshalWithQuery marshals TableDefinition with Query field. +func (t *TableDefinition) MarshalWithQuery() ([]byte, error) { + data, err := json.MarshalIndent(t, marshalPrefix, marshalIndent) + if err != nil { + return nil, errors.WrapError(errors.ErrMarshalFailed, err) + } + return data, nil +} + +// marshalWithoutQuery marshals TableDefinition without Query field. 
+func (t *TableDefinition) marshalWithoutQuery() ([]byte, error) { + // sort columns by name + sortedColumns := make([]TableCol, len(t.Columns)) + copy(sortedColumns, t.Columns) + sort.Slice(sortedColumns, func(i, j int) bool { + return sortedColumns[i].Name < sortedColumns[j].Name + }) + + defWithoutQuery := tableDefWithoutQuery{ + Table: t.Table, + Schema: t.Schema, + Columns: sortedColumns, + TotalColumns: t.TotalColumns, + } + + data, err := json.MarshalIndent(defWithoutQuery, marshalPrefix, marshalIndent) + if err != nil { + return nil, errors.WrapError(errors.ErrMarshalFailed, err) + } + return data, nil +} + +// Sum32 returns the 32-bits hash value of TableDefinition. +func (t *TableDefinition) Sum32(hasher *hash.PositionInertia) (uint32, error) { + if hasher == nil { + hasher = hash.NewPositionInertia() + } + hasher.Reset() + data, err := t.marshalWithoutQuery() + if err != nil { + return 0, err + } + + hasher.Write(data) + return hasher.Sum32(), nil +} + +// GenerateSchemaFilePath generates the schema file path for TableDefinition. +func (t *TableDefinition) GenerateSchemaFilePath() (string, error) { + checksum, err := t.Sum32(nil) + if err != nil { + return "", err + } + if !t.IsTableSchema() && t.Table != "" { + log.Panic("invalid table definition", zap.Any("tableDef", t)) + } + return generateSchemaFilePath(t.Schema, t.Table, t.TableVersion, checksum), nil +} diff --git a/pkg/sink/codec/debezium/codec.go b/pkg/sink/codec/debezium/codec.go new file mode 100644 index 00000000000..164fd06b9c1 --- /dev/null +++ b/pkg/sink/codec/debezium/codec.go @@ -0,0 +1,1699 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package debezium + +import ( + "bytes" + "fmt" + "io" + "strconv" + "strings" + "time" + + "github.com/pingcap/log" + timodel "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tiflow/cdc/model" + cerror "github.com/pingcap/tiflow/pkg/errors" + "github.com/pingcap/tiflow/pkg/sink/codec/common" + "github.com/pingcap/tiflow/pkg/sink/codec/internal" + "github.com/pingcap/tiflow/pkg/util" + "github.com/tikv/client-go/v2/oracle" + "go.uber.org/zap" +) + +type dbzCodec struct { + config *common.Config + clusterID string + nowFunc func() time.Time +} + +func (c *dbzCodec) writeDebeziumFieldValues( + writer *util.JSONWriter, + fieldName string, + cols []*model.ColumnData, + tableInfo *model.TableInfo, +) error { + var err error + colInfos := tableInfo.GetColInfosForRowChangedEvent() + writer.WriteObjectField(fieldName, func() { + for i, col := range cols { + colx := model.GetColumnDataX(col, tableInfo) + err = c.writeDebeziumFieldValue(writer, colx, colInfos[i].Ft) + if err != nil { + log.Error("write Debezium field value meet error", zap.Error(err)) + break + } + } + }) + return err +} + +func (c *dbzCodec) writeDebeziumFieldSchema( + writer *util.JSONWriter, + col model.ColumnDataX, + ft *types.FieldType, +) { + switch col.GetType() { + case mysql.TypeBit: + n := ft.GetFlen() + var v uint64 + var err error + if col.GetDefaultValue() != nil { + val, ok := col.GetDefaultValue().(string) + if !ok { + return + } + v, err = strconv.ParseUint(parseBit(val, n), 2, 64) + if err != nil { + return + } + } + if n == 1 { + 
writer.WriteObjectElement(func() { + writer.WriteStringField("type", "boolean") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + writer.WriteBoolField("default", v != 0) // bool + } + }) + } else { + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "bytes") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("name", "io.debezium.data.Bits") + writer.WriteIntField("version", 1) + writer.WriteObjectField("parameters", func() { + writer.WriteStringField("length", fmt.Sprintf("%d", n)) + }) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + c.writeBinaryField(writer, "default", getBitFromUint64(n, v)) // binary + } + }) + } + + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeTinyBlob, + mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob: + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "string") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + writer.WriteAnyField("default", col.GetDefaultValue()) + } + }) + + case mysql.TypeEnum: + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "string") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("name", "io.debezium.data.Enum") + writer.WriteIntField("version", 1) + writer.WriteObjectField("parameters", func() { + elems := ft.GetElems() + parameters := make([]string, 0, len(elems)) + for _, ele := range elems { + parameters = append(parameters, common.EscapeEnumAndSetOptions(ele)) + } + writer.WriteStringField("allowed", strings.Join(parameters, ",")) + }) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + writer.WriteAnyField("default", 
col.GetDefaultValue()) + } + }) + + case mysql.TypeSet: + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "string") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("name", "io.debezium.data.EnumSet") + writer.WriteIntField("version", 1) + writer.WriteObjectField("parameters", func() { + writer.WriteStringField("allowed", strings.Join(ft.GetElems(), ",")) + }) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + writer.WriteAnyField("default", col.GetDefaultValue()) + } + }) + + case mysql.TypeDate, mysql.TypeNewDate: + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "int32") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("name", "io.debezium.time.Date") + writer.WriteIntField("version", 1) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + v, ok := col.GetDefaultValue().(string) + if !ok { + return + } + t, err := time.Parse("2006-01-02", v) + if err != nil { + // For example, time may be invalid like 1000-00-00 + // return nil, nil + if mysql.HasNotNullFlag(ft.GetFlag()) { + writer.WriteInt64Field("default", 0) + } + return + } + year := t.Year() + if year < 70 { + // treats "0018" as 2018 + t = t.AddDate(2000, 0, 0) + } else if year < 100 { + // treats "0099" as 1999 + t = t.AddDate(1900, 0, 0) + } + writer.WriteInt64Field("default", t.UTC().Unix()/60/60/24) + } + }) + + case mysql.TypeDatetime: + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "int64") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + if ft.GetDecimal() <= 3 { + writer.WriteStringField("name", "io.debezium.time.Timestamp") + } else { + writer.WriteStringField("name", "io.debezium.time.MicroTimestamp") + } + writer.WriteIntField("version", 1) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + v, ok := 
col.GetDefaultValue().(string) + if !ok { + return + } + if v == "CURRENT_TIMESTAMP" { + writer.WriteInt64Field("default", 0) + return + } + t, err := types.StrToDateTime(types.DefaultStmtNoWarningContext, v, ft.GetDecimal()) + if err != nil { + writer.WriteInt64Field("default", 0) + return + } + gt, err := t.GoTime(time.UTC) + if err != nil { + if mysql.HasNotNullFlag(ft.GetFlag()) { + writer.WriteInt64Field("default", 0) + } + return + } + year := gt.Year() + if year < 70 { + // treats "0018" as 2018 + gt = gt.AddDate(2000, 0, 0) + } else if year < 100 { + // treats "0099" as 1999 + gt = gt.AddDate(1900, 0, 0) + } + if ft.GetDecimal() <= 3 { + writer.WriteInt64Field("default", gt.UnixMilli()) + } else { + writer.WriteInt64Field("default", gt.UnixMicro()) + } + } + }) + + case mysql.TypeTimestamp: + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "string") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("name", "io.debezium.time.ZonedTimestamp") + writer.WriteIntField("version", 1) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + v, ok := col.GetDefaultValue().(string) + if !ok { + return + } + if v == "CURRENT_TIMESTAMP" { + if mysql.HasNotNullFlag(ft.GetFlag()) { + writer.WriteStringField("default", "1970-01-01T00:00:00Z") + } + return + } + t, err := types.StrToDateTime(types.DefaultStmtNoWarningContext, v, ft.GetDecimal()) + if err != nil { + writer.WriteInt64Field("default", 0) + return + } + if t.Compare(types.MinTimestamp) < 0 { + if mysql.HasNotNullFlag(ft.GetFlag()) { + writer.WriteStringField("default", "1970-01-01T00:00:00Z") + } + return + } + gt, err := t.GoTime(time.UTC) + if err != nil { + writer.WriteInt64Field("default", 0) + return + } + str := gt.Format("2006-01-02T15:04:05") + fsp := ft.GetDecimal() + if fsp > 0 { + tmp := fmt.Sprintf(".%06d", gt.Nanosecond()/1000) + str = str + tmp[:1+fsp] + } + str += "Z" + 
writer.WriteStringField("default", str) + } + }) + + case mysql.TypeDuration: + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "int64") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("name", "io.debezium.time.MicroTime") + writer.WriteIntField("version", 1) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + v, ok := col.GetDefaultValue().(string) + if !ok { + return + } + d, _, _, err := types.StrToDuration(types.DefaultStmtNoWarningContext.WithLocation(c.config.TimeZone), v, ft.GetDecimal()) + if err != nil { + return + } + writer.WriteInt64Field("default", d.Microseconds()) + } + }) + + case mysql.TypeJSON: + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "string") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("name", "io.debezium.data.Json") + writer.WriteIntField("version", 1) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + writer.WriteAnyField("default", col.GetDefaultValue()) + } + }) + + case mysql.TypeTiny: // TINYINT + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "int16") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + v, ok := col.GetDefaultValue().(string) + if !ok { + return + } + floatV, err := strconv.ParseFloat(v, 64) + if err != nil { + return + } + writer.WriteFloat64Field("default", floatV) + } + }) + + case mysql.TypeShort: // SMALLINT + writer.WriteObjectElement(func() { + if mysql.HasUnsignedFlag(ft.GetFlag()) { + writer.WriteStringField("type", "int32") + } else { + writer.WriteStringField("type", "int16") + } + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + v, ok := 
col.GetDefaultValue().(string) + if !ok { + return + } + floatV, err := strconv.ParseFloat(v, 64) + if err != nil { + return + } + writer.WriteFloat64Field("default", floatV) + } + }) + + case mysql.TypeInt24: // MEDIUMINT + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "int32") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + v, ok := col.GetDefaultValue().(string) + if !ok { + return + } + floatV, err := strconv.ParseFloat(v, 64) + if err != nil { + return + } + writer.WriteFloat64Field("default", floatV) + } + }) + + case mysql.TypeLong: // INT + writer.WriteObjectElement(func() { + if col.GetFlag().IsUnsigned() { + writer.WriteStringField("type", "int64") + } else { + writer.WriteStringField("type", "int32") + } + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + v, ok := col.GetDefaultValue().(string) + if !ok { + return + } + floatV, err := strconv.ParseFloat(v, 64) + if err != nil { + return + } + writer.WriteFloat64Field("default", floatV) + } + }) + + case mysql.TypeLonglong: // BIGINT + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "int64") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + v, ok := col.GetDefaultValue().(string) + if !ok { + return + } + floatV, err := strconv.ParseFloat(v, 64) + if err != nil { + return + } + writer.WriteFloat64Field("default", floatV) + } + }) + + case mysql.TypeFloat: + writer.WriteObjectElement(func() { + if ft.GetDecimal() != -1 { + writer.WriteStringField("type", "double") + } else { + writer.WriteStringField("type", "float") + } + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("field", 
col.GetName()) + if col.GetDefaultValue() != nil { + v, ok := col.GetDefaultValue().(string) + if !ok { + return + } + floatV, err := strconv.ParseFloat(v, 64) + if err != nil { + return + } + writer.WriteFloat64Field("default", floatV) + } + }) + + case mysql.TypeDouble, mysql.TypeNewDecimal: + // https://dev.mysql.com/doc/refman/8.4/en/numeric-types.html + // MySQL also treats REAL as a synonym for DOUBLE PRECISION (a nonstandard variation), unless the REAL_AS_FLOAT SQL mode is enabled. + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "double") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + v, ok := col.GetDefaultValue().(string) + if !ok { + return + } + floatV, err := strconv.ParseFloat(v, 64) + if err != nil { + return + } + writer.WriteFloat64Field("default", floatV) + } + }) + + case mysql.TypeYear: + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "int32") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("name", "io.debezium.time.Year") + writer.WriteIntField("version", 1) + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + v, ok := col.GetDefaultValue().(string) + if !ok { + return + } + floatV, err := strconv.ParseFloat(v, 64) + if err != nil { + return + } + if floatV < 70 { + // treats "DEFAULT 1" as 2001 + floatV += 2000 + } else if floatV < 100 { + // treats "DEFAULT 99" as 1999 + floatV += 1900 + } + writer.WriteFloat64Field("default", floatV) + } + }) + + case mysql.TypeTiDBVectorFloat32: + writer.WriteObjectElement(func() { + writer.WriteStringField("type", "string") + writer.WriteBoolField("optional", !mysql.HasNotNullFlag(ft.GetFlag())) + writer.WriteStringField("name", "io.debezium.data.TiDBVectorFloat32") + writer.WriteStringField("field", col.GetName()) + if col.GetDefaultValue() != nil { + 
writer.WriteAnyField("default", col.GetDefaultValue()) + } + }) + + default: + log.Warn( + "meet unsupported field type", + zap.Any("fieldType", col.GetType()), + zap.Any("column", col.GetName()), + ) + } +} + +// See https://debezium.io/documentation/reference/stable/connectors/mysql.html#mysql-data-types +// +//revive:disable indent-error-flow +func (c *dbzCodec) writeDebeziumFieldValue( + writer *util.JSONWriter, + col model.ColumnDataX, + ft *types.FieldType, +) error { + value := col.Value + if value == nil { + value = col.GetDefaultValue() + } + if value == nil { + writer.WriteNullField(col.GetName()) + return nil + } + switch col.GetType() { + case mysql.TypeBit: + n := ft.GetFlen() + var v uint64 + switch val := value.(type) { + case uint64: + v = val + case string: + hexValue, err := strconv.ParseUint(parseBit(val, n), 2, 64) + if err != nil { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type string for bit column %s, error:%s", + col.GetName(), err.Error()) + } + v = hexValue + default: + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type %T for bit column %s", + col.Value, + col.GetName()) + } + // Debezium behavior: + // BIT(1) → BOOLEAN + // BIT(>1) → BYTES The byte[] contains the bits in little-endian form and is sized to + // contain the specified number of bits. 
+ if n == 1 { + writer.WriteBoolField(col.GetName(), v != 0) + return nil + } else { + c.writeBinaryField(writer, col.GetName(), getBitFromUint64(n, v)) + return nil + } + + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeTinyBlob, + mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob: + isBinary := col.GetFlag().IsBinary() + switch v := value.(type) { + case []byte: + if !isBinary { + writer.WriteStringField(col.GetName(), common.UnsafeBytesToString(v)) + } else { + c.writeBinaryField(writer, col.GetName(), v) + } + case string: + if isBinary { + c.writeBinaryField(writer, col.GetName(), common.UnsafeStringToBytes(v)) + } + writer.WriteStringField(col.GetName(), v) + } + return nil + + case mysql.TypeEnum: + v, ok := value.(uint64) + if !ok { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type %T for enum column %s", + value, + col.GetName()) + } + enumVar, err := types.ParseEnumValue(ft.GetElems(), v) + if err != nil { + // Invalid enum value inserted in non-strict mode. + writer.WriteStringField(col.GetName(), "") + return nil + } + writer.WriteStringField(col.GetName(), enumVar.Name) + return nil + + case mysql.TypeSet: + v, ok := value.(uint64) + if !ok { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type %T for set column %s", + value, + col.GetName()) + } + setVar, err := types.ParseSetValue(ft.GetElems(), v) + if err != nil { + // Invalid enum value inserted in non-strict mode. 
+ writer.WriteStringField(col.GetName(), "") + return nil + } + writer.WriteStringField(col.GetName(), setVar.Name) + return nil + + case mysql.TypeNewDecimal: + v, ok := value.(string) + if !ok { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type %T for decimal column %s", + value, + col.GetName()) + } + floatV, err := strconv.ParseFloat(v, 64) + if err != nil { + return cerror.WrapError( + cerror.ErrDebeziumEncodeFailed, + err) + } + writer.WriteFloat64Field(col.GetName(), floatV) + return nil + + case mysql.TypeDate, mysql.TypeNewDate: + v, ok := value.(string) + if !ok { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type %T for date column %s", + value, + col.GetName()) + } + t, err := time.Parse("2006-01-02", v) + if err != nil { + // For example, time may be invalid like 1000-00-00 + // return nil, nil + if mysql.HasNotNullFlag(ft.GetFlag()) { + writer.WriteInt64Field(col.GetName(), 0) + } else { + writer.WriteNullField(col.GetName()) + } + return nil + } + year := t.Year() + if year < 70 { + // treats "0018" as 2018 + t = t.AddDate(2000, 0, 0) + } else if year < 100 { + // treats "0099" as 1999 + t = t.AddDate(1900, 0, 0) + } + + writer.WriteInt64Field(col.GetName(), t.UTC().Unix()/60/60/24) + return nil + + case mysql.TypeDatetime: + // Debezium behavior from doc: + // > Such columns are converted into epoch milliseconds or microseconds based on the + // > column's precision by using UTC. 
+ v, ok := value.(string) + if !ok { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type %T for datetime column %s", + value, + col.GetName()) + } + if v == "CURRENT_TIMESTAMP" { + writer.WriteInt64Field(col.GetName(), 0) + return nil + } + t, err := types.StrToDateTime(types.DefaultStmtNoWarningContext, v, ft.GetDecimal()) + if err != nil { + return cerror.WrapError( + cerror.ErrDebeziumEncodeFailed, + err) + } + gt, err := t.GoTime(time.UTC) + if err != nil { + if mysql.HasNotNullFlag(ft.GetFlag()) { + writer.WriteInt64Field(col.GetName(), 0) + } else { + writer.WriteNullField(col.GetName()) + } + return nil + } + year := gt.Year() + if year < 70 { + // treats "0018" as 2018 + gt = gt.AddDate(2000, 0, 0) + } else if year < 100 { + // treats "0099" as 1999 + gt = gt.AddDate(1900, 0, 0) + } + if ft.GetDecimal() <= 3 { + writer.WriteInt64Field(col.GetName(), gt.UnixMilli()) + } else { + writer.WriteInt64Field(col.GetName(), gt.UnixMicro()) + } + return nil + + case mysql.TypeTimestamp: + // Debezium behavior from doc: + // > The TIMESTAMP type represents a timestamp without time zone information. + // > It is converted by MySQL from the server (or session's) current time zone into UTC + // > when writing and from UTC into the server (or session's) current time zone when reading + // > back the value. + // > Such columns are converted into an equivalent io.debezium.time.ZonedTimestamp in UTC + // > based on the server (or session's) current time zone. The time zone will be queried from + // > the server by default. If this fails, it must be specified explicitly by the database + // > connectionTimeZone MySQL configuration option. 
+ v, ok := value.(string) + if !ok { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type %T for timestamp column %s", + value, + col.GetName()) + } + if v == "CURRENT_TIMESTAMP" { + if mysql.HasNotNullFlag(ft.GetFlag()) { + writer.WriteStringField(col.GetName(), "1970-01-01T00:00:00Z") + } else { + writer.WriteNullField(col.GetName()) + } + return nil + } + t, err := types.StrToDateTime(types.DefaultStmtNoWarningContext.WithLocation(c.config.TimeZone), v, ft.GetDecimal()) + if err != nil { + return cerror.WrapError( + cerror.ErrDebeziumEncodeFailed, + err) + } + if t.Compare(types.MinTimestamp) < 0 { + if col.Value == nil { + writer.WriteNullField(col.GetName()) + } else { + writer.WriteStringField(col.GetName(), "1970-01-01T00:00:00Z") + } + return nil + } + gt, err := t.GoTime(c.config.TimeZone) + if err != nil { + return cerror.WrapError( + cerror.ErrDebeziumEncodeFailed, + err) + } + str := gt.UTC().Format("2006-01-02T15:04:05") + fsp := ft.GetDecimal() + if fsp > 0 { + tmp := fmt.Sprintf(".%06d", gt.Nanosecond()/1000) + str = str + tmp[:1+fsp] + } + str += "Z" + writer.WriteStringField(col.GetName(), str) + return nil + + case mysql.TypeDuration: + // Debezium behavior from doc: + // > Represents the time value in microseconds and does not include + // > time zone information. MySQL allows M to be in the range of 0-6. 
+ v, ok := value.(string) + if !ok { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type %T for time column %s", + value, + col.GetName()) + } + d, _, _, err := types.StrToDuration(types.DefaultStmtNoWarningContext.WithLocation(c.config.TimeZone), v, ft.GetDecimal()) + if err != nil { + return cerror.WrapError( + cerror.ErrDebeziumEncodeFailed, + err) + } + writer.WriteInt64Field(col.GetName(), d.Microseconds()) + return nil + + case mysql.TypeLonglong, mysql.TypeLong, mysql.TypeInt24, mysql.TypeShort, mysql.TypeTiny: + // Note: Although Debezium's doc claims to use INT32 for INT, but it + // actually uses INT64. Debezium also uses INT32 for SMALLINT. + isUnsigned := col.GetFlag().IsUnsigned() + maxValue := types.GetMaxValue(ft) + minValue := types.GetMinValue(ft) + switch v := value.(type) { + case uint64: + if !isUnsigned { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type %T for unsigned int column %s", + value, + col.GetName()) + } + if ft.GetType() == mysql.TypeLonglong && v == maxValue.GetUint64() || v > maxValue.GetUint64() { + writer.WriteAnyField(col.GetName(), -1) + } else { + writer.WriteInt64Field(col.GetName(), int64(v)) + } + case int64: + if isUnsigned { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type %T for int column %s", + value, + col.GetName()) + } + if v < minValue.GetInt64() || v > maxValue.GetInt64() { + writer.WriteAnyField(col.GetName(), -1) + } else { + writer.WriteInt64Field(col.GetName(), v) + } + case string: + if isUnsigned { + t, err := strconv.ParseUint(v, 10, 64) + if err != nil { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type string for unsigned int column %s", + col.GetName()) + } + if ft.GetType() == mysql.TypeLonglong && t == maxValue.GetUint64() || t > maxValue.GetUint64() { + writer.WriteAnyField(col.GetName(), -1) + } else { + writer.WriteInt64Field(col.GetName(), int64(t)) + } + } 
else { + t, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type string for int column %s", + col.GetName()) + } + if t < minValue.GetInt64() || t > maxValue.GetInt64() { + writer.WriteAnyField(col.GetName(), -1) + } else { + writer.WriteInt64Field(col.GetName(), t) + } + } + } + return nil + + case mysql.TypeDouble, mysql.TypeFloat: + if v, ok := value.(string); ok { + val, err := strconv.ParseFloat(v, 64) + if err != nil { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type string for int column %s", + col.GetName()) + } + writer.WriteFloat64Field(col.GetName(), val) + } else { + writer.WriteAnyField(col.GetName(), value) + } + return nil + + case mysql.TypeTiDBVectorFloat32: + v, ok := value.(types.VectorFloat32) + if !ok { + return cerror.ErrDebeziumEncodeFailed.GenWithStack( + "unexpected column value type %T for unsigned vector column %s", + value, + col.GetName()) + } + writer.WriteStringField(col.GetName(), v.String()) + return nil + } + + writer.WriteAnyField(col.GetName(), value) + return nil +} + +func (c *dbzCodec) writeBinaryField(writer *util.JSONWriter, fieldName string, value []byte) { + // TODO: Deal with different binary output later. 
	writer.WriteBase64StringField(fieldName, value)
}

// writeSourceSchema emits the Kafka Connect schema object describing the
// Debezium "source" struct (io.debezium.connector.mysql.Source): the metadata
// fields (version, connector, ts_ms, snapshot, binlog position, ...) attached
// to every change event. The statement order here defines the field order in
// the emitted schema.
func (c *dbzCodec) writeSourceSchema(writer *util.JSONWriter) {
	writer.WriteObjectElement(func() {
		writer.WriteStringField("type", "struct")
		writer.WriteArrayField("fields", func() {
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "string")
				writer.WriteBoolField("optional", false)
				writer.WriteStringField("field", "version")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "string")
				writer.WriteBoolField("optional", false)
				writer.WriteStringField("field", "connector")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "string")
				writer.WriteBoolField("optional", false)
				writer.WriteStringField("field", "name")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "int64")
				writer.WriteBoolField("optional", false)
				writer.WriteStringField("field", "ts_ms")
			})
			// "snapshot" is a Debezium enum of true,last,false,incremental.
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "string")
				writer.WriteBoolField("optional", true)
				writer.WriteStringField("name", "io.debezium.data.Enum")
				writer.WriteIntField("version", 1)
				writer.WriteObjectField("parameters", func() {
					writer.WriteStringField("allowed", "true,last,false,incremental")
				})
				writer.WriteStringField("default", "false")
				writer.WriteStringField("field", "snapshot")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "string")
				writer.WriteBoolField("optional", false)
				writer.WriteStringField("field", "db")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "string")
				writer.WriteBoolField("optional", true)
				writer.WriteStringField("field", "sequence")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "string")
				writer.WriteBoolField("optional", true)
				writer.WriteStringField("field", "table")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "int64")
				writer.WriteBoolField("optional", false)
				writer.WriteStringField("field", "server_id")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "string")
				writer.WriteBoolField("optional", true)
				writer.WriteStringField("field", "gtid")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "string")
				writer.WriteBoolField("optional", false)
				writer.WriteStringField("field", "file")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "int64")
				writer.WriteBoolField("optional", false)
				writer.WriteStringField("field", "pos")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "int32")
				writer.WriteBoolField("optional", false)
				writer.WriteStringField("field", "row")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "int64")
				writer.WriteBoolField("optional", true)
				writer.WriteStringField("field", "thread")
			})
			writer.WriteObjectElement(func() {
				writer.WriteStringField("type", "string")
				writer.WriteBoolField("optional", true)
				writer.WriteStringField("field", "query")
			})
		})
		writer.WriteBoolField("optional", false)
		writer.WriteStringField("name", "io.debezium.connector.mysql.Source")
		writer.WriteStringField("field", "source")
	})
}

// EncodeKey encode RowChangedEvent into key message
func (c *dbzCodec) EncodeKey(
	e *model.RowChangedEvent,
	dest io.Writer,
) error {
	// schema field describes the structure of the primary key, or the unique key if the table does not have a primary key, for the table that was changed.
	// see https://debezium.io/documentation/reference/stable/connectors/mysql.html#mysql-events
	colDataXs, colInfos := e.HandleKeyColDataXInfos()
	jWriter := util.BorrowJSONWriter(dest)
	defer util.ReturnJSONWriter(jWriter)

	var err error
	jWriter.WriteObject(func() {
		jWriter.WriteObjectField("payload", func() {
			// NOTE(review): err keeps only the last iteration's result; earlier
			// failures are overwritten — confirm this is intended.
			for i, col := range colDataXs {
				err = c.writeDebeziumFieldValue(jWriter, col, colInfos[i].Ft)
			}
		})
		if !c.config.DebeziumDisableSchema {
			jWriter.WriteObjectField("schema", func() {
				jWriter.WriteStringField("type", "struct")
				jWriter.WriteStringField("name",
					fmt.Sprintf("%s.Key", getSchemaTopicName(c.clusterID, e.TableInfo.GetSchemaName(), e.TableInfo.GetTableName())))
				jWriter.WriteBoolField("optional", false)
				jWriter.WriteArrayField("fields", func() {
					for i, col := range colDataXs {
						c.writeDebeziumFieldSchema(jWriter, col, colInfos[i].Ft)
					}
				})
			})
		}
	})
	return err
}

// EncodeValue encode RowChangedEvent into value message
func (c *dbzCodec) EncodeValue(
	e *model.RowChangedEvent,
	dest io.Writer,
) error {
	jWriter := util.BorrowJSONWriter(dest)
	defer util.ReturnJSONWriter(jWriter)

	commitTime := oracle.GetTimeFromTS(e.CommitTs)

	var err error

	jWriter.WriteObject(func() {
		jWriter.WriteObjectField("payload", func() {
			jWriter.WriteObjectField("source", func() {
				jWriter.WriteStringField("version", "2.4.0.Final")
				jWriter.WriteStringField("connector", "TiCDC")
				jWriter.WriteStringField("name", c.clusterID)
				// ts_ms: In the source object, ts_ms indicates the time that the change was made in the database.
				// https://debezium.io/documentation/reference/stable/connectors/mysql.html#mysql-create-events
				jWriter.WriteInt64Field("ts_ms", commitTime.UnixMilli())
				// snapshot field is a string of true,last,false,incremental
				jWriter.WriteStringField("snapshot", "false")
				jWriter.WriteStringField("db", e.TableInfo.GetSchemaName())
				jWriter.WriteStringField("table", e.TableInfo.GetTableName())
				// Binlog coordinates have no TiDB equivalent; emit zero values.
				jWriter.WriteInt64Field("server_id", 0)
				jWriter.WriteNullField("gtid")
				jWriter.WriteStringField("file", "")
				jWriter.WriteInt64Field("pos", 0)
				jWriter.WriteInt64Field("row", 0)
				jWriter.WriteInt64Field("thread", 0)
				jWriter.WriteNullField("query")

				// The followings are TiDB extended fields
				jWriter.WriteUint64Field("commit_ts", e.CommitTs)
				jWriter.WriteStringField("cluster_id", c.clusterID)
			})

			// ts_ms: displays the time at which the connector processed the event
			// https://debezium.io/documentation/reference/stable/connectors/mysql.html#mysql-create-events
			jWriter.WriteInt64Field("ts_ms", c.nowFunc().UnixMilli())
			jWriter.WriteNullField("transaction")
			if e.IsInsert() {
				// op: Mandatory string that describes the type of operation that caused the connector to generate the event.
				// Valid values are:
				// c = create
				// u = update
				// d = delete
				// r = read (applies to only snapshots)
				// https://debezium.io/documentation/reference/stable/connectors/mysql.html#mysql-create-events
				jWriter.WriteStringField("op", "c")

				// before: An optional field that specifies the state of the row before the event occurred.
				// When the op field is c for create, the before field is null since this change event is for new content.
				// In a delete event value, the before field contains the values that were in the row before
				// it was deleted with the database commit.
				jWriter.WriteNullField("before")

				// after: An optional field that specifies the state of the row after the event occurred.
				// In a delete event value, the after field is null, signifying that the row no longer exists.
				err = c.writeDebeziumFieldValues(jWriter, "after", e.Columns, e.TableInfo)
			} else if e.IsDelete() {
				jWriter.WriteStringField("op", "d")
				jWriter.WriteNullField("after")
				err = c.writeDebeziumFieldValues(jWriter, "before", e.PreColumns, e.TableInfo)
			} else if e.IsUpdate() {
				jWriter.WriteStringField("op", "u")
				if c.config.DebeziumOutputOldValue {
					err = c.writeDebeziumFieldValues(jWriter, "before", e.PreColumns, e.TableInfo)
				}
				if err == nil {
					err = c.writeDebeziumFieldValues(jWriter, "after", e.Columns, e.TableInfo)
				}
			}
		})

		if !c.config.DebeziumDisableSchema {
			jWriter.WriteObjectField("schema", func() {
				jWriter.WriteStringField("type", "struct")
				jWriter.WriteBoolField("optional", false)
				jWriter.WriteStringField("name",
					fmt.Sprintf("%s.Envelope", getSchemaTopicName(c.clusterID, e.TableInfo.GetSchemaName(), e.TableInfo.GetTableName())))
				jWriter.WriteIntField("version", 1)
				jWriter.WriteArrayField("fields", func() {
					// schema is the same for `before` and `after`. So we build a new buffer to
					// build the JSON, so that content can be reused.
					var fieldsJSON string
					{
						fieldsBuf := &bytes.Buffer{}
						fieldsWriter := util.BorrowJSONWriter(fieldsBuf)
						var validCols []*model.ColumnData
						if e.IsInsert() {
							validCols = e.Columns
						} else if e.IsDelete() {
							validCols = e.PreColumns
						} else if e.IsUpdate() {
							validCols = e.Columns
						}
						colInfos := e.TableInfo.GetColInfosForRowChangedEvent()
						for i, col := range validCols {
							colx := model.GetColumnDataX(col, e.TableInfo)
							c.writeDebeziumFieldSchema(fieldsWriter, colx, colInfos[i].Ft)
						}
						// Virtual columns are not in validCols; schema them separately.
						if e.TableInfo.HasVirtualColumns() {
							for _, colInfo := range e.TableInfo.Columns {
								if model.IsColCDCVisible(colInfo) {
									continue
								}
								data := &model.ColumnData{ColumnID: colInfo.ID}
								colx := model.GetColumnDataX(data, e.TableInfo)
								c.writeDebeziumFieldSchema(fieldsWriter, colx, &colInfo.FieldType)
							}
						}
						util.ReturnJSONWriter(fieldsWriter)
						fieldsJSON = fieldsBuf.String()
					}
					jWriter.WriteObjectElement(func() {
						jWriter.WriteStringField("type", "struct")
						jWriter.WriteBoolField("optional", true)
						jWriter.WriteStringField("name",
							fmt.Sprintf("%s.Value", getSchemaTopicName(c.clusterID, e.TableInfo.GetSchemaName(), e.TableInfo.GetTableName())))
						jWriter.WriteStringField("field", "before")
						jWriter.WriteArrayField("fields", func() {
							jWriter.WriteRaw(fieldsJSON)
						})
					})
					jWriter.WriteObjectElement(func() {
						jWriter.WriteStringField("type", "struct")
						jWriter.WriteBoolField("optional", true)
						jWriter.WriteStringField("name",
							fmt.Sprintf("%s.Value", getSchemaTopicName(c.clusterID, e.TableInfo.GetSchemaName(), e.TableInfo.GetTableName())))
						jWriter.WriteStringField("field", "after")
						jWriter.WriteArrayField("fields", func() {
							jWriter.WriteRaw(fieldsJSON)
						})
					})
					c.writeSourceSchema(jWriter)
					jWriter.WriteObjectElement(func() {
						jWriter.WriteStringField("type", "string")
						jWriter.WriteBoolField("optional", false)
						jWriter.WriteStringField("field", "op")
					})
					jWriter.WriteObjectElement(func() {
						jWriter.WriteStringField("type", "int64")
						jWriter.WriteBoolField("optional", true)
						jWriter.WriteStringField("field", "ts_ms")
					})
					jWriter.WriteObjectElement(func() {
						jWriter.WriteStringField("type", "struct")
						jWriter.WriteArrayField("fields", func() {
							jWriter.WriteObjectElement(func() {
								jWriter.WriteStringField("type", "string")
								jWriter.WriteBoolField("optional", false)
								jWriter.WriteStringField("field", "id")
							})
							jWriter.WriteObjectElement(func() {
								jWriter.WriteStringField("type", "int64")
								jWriter.WriteBoolField("optional", false)
								jWriter.WriteStringField("field", "total_order")
							})
							jWriter.WriteObjectElement(func() {
								jWriter.WriteStringField("type", "int64")
								jWriter.WriteBoolField("optional", false)
								jWriter.WriteStringField("field", "data_collection_order")
							})
						})
						jWriter.WriteBoolField("optional", true)
						jWriter.WriteStringField("name", "event.block")
						jWriter.WriteIntField("version", 1)
						jWriter.WriteStringField("field", "transaction")
					})
				})
			})
		}
	})
	return err
}

// EncodeDDLEvent encode DDLEvent into debezium change event
func (c *dbzCodec) EncodeDDLEvent(
	e *model.DDLEvent,
	keyDest io.Writer,
	dest io.Writer,
) error {
	keyJWriter := util.BorrowJSONWriter(keyDest)
	jWriter := util.BorrowJSONWriter(dest)
	defer util.ReturnJSONWriter(keyJWriter)
	defer util.ReturnJSONWriter(jWriter)

	commitTime := oracle.GetTimeFromTS(e.CommitTs)
	var changeType string
	// refer to: https://docs.pingcap.com/tidb/dev/mysql-compatibility#ddl-operations
	switch e.Type {
	case timodel.ActionCreateSchema,
		timodel.ActionCreateTable,
		timodel.ActionCreateView:
		changeType = "CREATE"
	case timodel.ActionAddColumn,
		timodel.ActionModifyColumn,
		timodel.ActionDropColumn,
		timodel.ActionMultiSchemaChange,
		timodel.ActionAddTablePartition,
		timodel.ActionRemovePartitioning,
		timodel.ActionReorganizePartition,
		timodel.ActionExchangeTablePartition,
		timodel.ActionAlterTablePartitioning,
		timodel.ActionTruncateTablePartition,
timodel.ActionDropTablePartition, + timodel.ActionRebaseAutoID, + timodel.ActionSetDefaultValue, + timodel.ActionModifyTableComment, + timodel.ActionModifyTableCharsetAndCollate, + timodel.ActionModifySchemaCharsetAndCollate, + timodel.ActionAddIndex, + timodel.ActionAlterIndexVisibility, + timodel.ActionRenameIndex, + timodel.ActionRenameTable, + timodel.ActionRecoverTable, + timodel.ActionAddPrimaryKey, + timodel.ActionDropPrimaryKey, + timodel.ActionAlterTTLInfo, + timodel.ActionAlterTTLRemove: + changeType = "ALTER" + case timodel.ActionDropSchema, + timodel.ActionDropTable, + timodel.ActionDropIndex, + timodel.ActionDropView: + changeType = "DROP" + default: + return cerror.ErrDDLUnsupportType.GenWithStackByArgs(e.Type, e.Query) + } + + var err error + dbName, tableName := getDBTableName(e) + // message key + keyJWriter.WriteObject(func() { + keyJWriter.WriteObjectField("payload", func() { + if e.Type == timodel.ActionDropTable { + keyJWriter.WriteStringField("databaseName", e.PreTableInfo.GetSchemaName()) + } else { + keyJWriter.WriteStringField("databaseName", dbName) + } + }) + if !c.config.DebeziumDisableSchema { + keyJWriter.WriteObjectField("schema", func() { + keyJWriter.WriteStringField("type", "struct") + keyJWriter.WriteStringField("name", "io.debezium.connector.mysql.SchemaChangeKey") + keyJWriter.WriteBoolField("optional", false) + keyJWriter.WriteIntField("version", 1) + keyJWriter.WriteArrayField("fields", func() { + keyJWriter.WriteObjectElement(func() { + keyJWriter.WriteStringField("field", "databaseName") + keyJWriter.WriteBoolField("optional", false) + keyJWriter.WriteStringField("type", "string") + }) + }) + }) + } + }) + + // message value + jWriter.WriteObject(func() { + jWriter.WriteObjectField("payload", func() { + jWriter.WriteObjectField("source", func() { + jWriter.WriteStringField("version", "2.4.0.Final") + jWriter.WriteStringField("connector", "TiCDC") + jWriter.WriteStringField("name", c.clusterID) + 
jWriter.WriteInt64Field("ts_ms", commitTime.UnixMilli()) + jWriter.WriteStringField("snapshot", "false") + if e.TableInfo == nil { + jWriter.WriteStringField("db", "") + jWriter.WriteStringField("table", "") + } else { + jWriter.WriteStringField("db", dbName) + jWriter.WriteStringField("table", tableName) + } + jWriter.WriteInt64Field("server_id", 0) + jWriter.WriteNullField("gtid") + jWriter.WriteStringField("file", "") + jWriter.WriteInt64Field("pos", 0) + jWriter.WriteInt64Field("row", 0) + jWriter.WriteInt64Field("thread", 0) + jWriter.WriteNullField("query") + + // The followings are TiDB extended fields + jWriter.WriteUint64Field("commit_ts", e.CommitTs) + jWriter.WriteStringField("cluster_id", c.clusterID) + }) + jWriter.WriteInt64Field("ts_ms", c.nowFunc().UnixMilli()) + + if e.Type == timodel.ActionDropTable { + jWriter.WriteStringField("databaseName", e.PreTableInfo.GetSchemaName()) + } else { + jWriter.WriteStringField("databaseName", dbName) + } + jWriter.WriteNullField("schemaName") + jWriter.WriteStringField("ddl", e.Query) + jWriter.WriteArrayField("tableChanges", func() { + // return early if there is no table changes + if tableName == "" { + return + } + jWriter.WriteObjectElement(func() { + // Describes the kind of change. The value is one of the following: + // CREATE: Table created. + // ALTER: Table modified. + // DROP: Table deleted. + jWriter.WriteStringField("type", changeType) + // In the case of a table rename, this identifier is a concatenation of , table names. 
+ if e.Type == timodel.ActionRenameTable { + jWriter.WriteStringField("id", fmt.Sprintf("\"%s\".\"%s\",\"%s\".\"%s\"", + e.PreTableInfo.GetSchemaName(), + e.PreTableInfo.GetTableName(), + dbName, + tableName)) + } else { + jWriter.WriteStringField("id", fmt.Sprintf("\"%s\".\"%s\"", + dbName, + tableName)) + } + // return early if there is no table info + if e.Type == timodel.ActionDropTable { + jWriter.WriteNullField("table") + return + } + jWriter.WriteObjectField("table", func() { + jWriter.WriteStringField("defaultCharsetName", e.TableInfo.Charset) + jWriter.WriteArrayField("primaryKeyColumnNames", func() { + for _, pk := range e.TableInfo.GetPrimaryKeyColumnNames() { + jWriter.WriteStringElement(pk) + } + }) + jWriter.WriteArrayField("columns", func() { + parseColumns(e.Query, e.TableInfo.Columns) + for pos, col := range e.TableInfo.Columns { + if col.Hidden { + continue + } + jWriter.WriteObjectElement(func() { + flag := col.GetFlag() + jdbcType := internal.MySQLType2JdbcType(col.GetType(), mysql.HasBinaryFlag(flag)) + expression, name := getExpressionAndName(col.FieldType) + jWriter.WriteStringField("name", col.Name.O) + jWriter.WriteIntField("jdbcType", int(jdbcType)) + jWriter.WriteNullField("nativeType") + if col.Comment != "" { + jWriter.WriteStringField("comment", col.Comment) + } else { + jWriter.WriteNullField("comment") + } + if col.DefaultValue == nil { + jWriter.WriteNullField("defaultValueExpression") + } else { + v, ok := col.DefaultValue.(string) + if ok { + if strings.ToUpper(v) == "CURRENT_TIMESTAMP" { + // https://debezium.io/documentation/reference/3.0/connectors/mysql.html#mysql-temporal-types + jWriter.WriteAnyField("defaultValueExpression", "1970-01-01 00:00:00") + } else if v == "" { + jWriter.WriteNullField("defaultValueExpression") + } else if col.DefaultValueBit != nil { + jWriter.WriteStringField("defaultValueExpression", parseBit(v, col.GetFlen())) + } else { + jWriter.WriteStringField("defaultValueExpression", v) + } + } else { + 
jWriter.WriteAnyField("defaultValueExpression", col.DefaultValue) + } + } + elems := col.GetElems() + if len(elems) != 0 { + // Format is ENUM ('e1', 'e2') or SET ('e1', 'e2') + jWriter.WriteArrayField("enumValues", func() { + for _, ele := range elems { + jWriter.WriteStringElement(fmt.Sprintf("'%s'", ele)) + } + }) + } else { + jWriter.WriteNullField("enumValues") + } + + jWriter.WriteStringField("typeName", name) + jWriter.WriteStringField("typeExpression", expression) + + charsetName := getCharset(col.FieldType) + if charsetName != "" { + jWriter.WriteStringField("charsetName", charsetName) + } else { + jWriter.WriteNullField("charsetName") + } + + length := getLen(col.FieldType) + if length != -1 { + jWriter.WriteIntField("length", length) + } else { + jWriter.WriteNullField("length") + } + + scale := getScale(col.FieldType) + if scale != -1 { + jWriter.WriteFloat64Field("scale", scale) + } else { + jWriter.WriteNullField("scale") + } + jWriter.WriteIntField("position", pos+1) + jWriter.WriteBoolField("optional", !mysql.HasNotNullFlag(flag)) + + updateNowWithTimestamp := mysql.HasOnUpdateNowFlag(flag) && jdbcType == internal.JavaSQLTypeTIMESTAMP_WITH_TIMEZONE + autoIncrementFlag := mysql.HasAutoIncrementFlag(flag) || updateNowWithTimestamp + + jWriter.WriteBoolField("autoIncremented", autoIncrementFlag) + jWriter.WriteBoolField("generated", autoIncrementFlag) + }) + } + }) + jWriter.WriteNullField("comment") + }) + }) + }) + }) + + if !c.config.DebeziumDisableSchema { + jWriter.WriteObjectField("schema", func() { + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("type", "struct") + jWriter.WriteIntField("version", 1) + jWriter.WriteStringField("name", "io.debezium.connector.mysql.SchemaChangeValue") + jWriter.WriteArrayField("fields", func() { + c.writeSourceSchema(jWriter) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "ts_ms") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("type", 
"int64") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "databaseName") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "string") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "schemaName") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "string") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "ddl") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "string") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "tableChanges") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("type", "array") + jWriter.WriteObjectField("items", func() { + jWriter.WriteStringField("name", "io.debezium.connector.schema.Change") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("type", "struct") + jWriter.WriteIntField("version", 1) + jWriter.WriteArrayField("fields", func() { + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "type") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("type", "string") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "id") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("type", "string") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "table") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "struct") + jWriter.WriteStringField("name", "io.debezium.connector.schema.Table") + jWriter.WriteIntField("version", 1) + jWriter.WriteArrayField("fields", func() { + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "defaultCharsetName") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "string") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "primaryKeyColumnNames") + 
jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "array") + jWriter.WriteObjectField("items", func() { + jWriter.WriteStringField("type", "string") + jWriter.WriteBoolField("optional", false) + }) + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "columns") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("type", "array") + jWriter.WriteObjectField("items", func() { + jWriter.WriteStringField("name", "io.debezium.connector.schema.Column") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("type", "struct") + jWriter.WriteIntField("version", 1) + jWriter.WriteArrayField("fields", func() { + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "name") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("type", "string") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "jdbcType") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("type", "int32") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "nativeType") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "int32") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "typeName") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("type", "string") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "typeExpression") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "string") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "charsetName") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "string") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "length") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "int32") + }) + jWriter.WriteObjectElement(func() { + 
jWriter.WriteStringField("field", "scale") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "int32") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "position") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("type", "int32") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "optional") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "boolean") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "autoIncremented") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "boolean") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "generated") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "boolean") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "comment") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "string") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "defaultValueExpression") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "string") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "enumValues") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "array") + jWriter.WriteObjectField("items", func() { + jWriter.WriteStringField("type", "string") + jWriter.WriteBoolField("optional", false) + }) + }) + }) + }) + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("field", "comment") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("type", "string") + }) + }) + }) + }) + }) + }) + }) + }) + } + }) + return err +} + +// EncodeCheckpointEvent encode checkpointTs into debezium change event +func (c *dbzCodec) EncodeCheckpointEvent( + ts uint64, + keyDest io.Writer, + dest io.Writer, +) error { + keyJWriter 
:= util.BorrowJSONWriter(keyDest) + jWriter := util.BorrowJSONWriter(dest) + defer util.ReturnJSONWriter(keyJWriter) + defer util.ReturnJSONWriter(jWriter) + commitTime := oracle.GetTimeFromTS(ts) + var err error + // message key + keyJWriter.WriteObject(func() { + keyJWriter.WriteObjectField("payload", func() {}) + if !c.config.DebeziumDisableSchema { + keyJWriter.WriteObjectField("schema", func() { + keyJWriter.WriteStringField("type", "struct") + keyJWriter.WriteStringField("name", + fmt.Sprintf("%s.%s.Key", common.SanitizeName(c.clusterID), "watermark")) + keyJWriter.WriteBoolField("optional", false) + keyJWriter.WriteArrayField("fields", func() { + }) + }) + } + }) + // message value + jWriter.WriteObject(func() { + jWriter.WriteObjectField("payload", func() { + jWriter.WriteObjectField("source", func() { + jWriter.WriteStringField("version", "2.4.0.Final") + jWriter.WriteStringField("connector", "TiCDC") + jWriter.WriteStringField("name", c.clusterID) + // ts_ms: In the source object, ts_ms indicates the time that the change was made in the database. 
+ // https://debezium.io/documentation/reference/stable/connectors/mysql.html#mysql-create-events + jWriter.WriteInt64Field("ts_ms", commitTime.UnixMilli()) + // snapshot field is a string of true,last,false,incremental + jWriter.WriteStringField("snapshot", "false") + jWriter.WriteStringField("db", "") + jWriter.WriteStringField("table", "") + jWriter.WriteInt64Field("server_id", 0) + jWriter.WriteNullField("gtid") + jWriter.WriteStringField("file", "") + jWriter.WriteInt64Field("pos", 0) + jWriter.WriteInt64Field("row", 0) + jWriter.WriteInt64Field("thread", 0) + jWriter.WriteNullField("query") + + // The followings are TiDB extended fields + jWriter.WriteUint64Field("commit_ts", ts) + jWriter.WriteStringField("cluster_id", c.clusterID) + }) + + // ts_ms: displays the time at which the connector processed the event + // https://debezium.io/documentation/reference/stable/connectors/mysql.html#mysql-create-events + jWriter.WriteInt64Field("ts_ms", c.nowFunc().UnixMilli()) + jWriter.WriteNullField("transaction") + jWriter.WriteStringField("op", "m") + }) + + if !c.config.DebeziumDisableSchema { + jWriter.WriteObjectField("schema", func() { + jWriter.WriteStringField("type", "struct") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("name", + fmt.Sprintf("%s.%s.Envelope", common.SanitizeName(c.clusterID), "watermark")) + jWriter.WriteIntField("version", 1) + jWriter.WriteArrayField("fields", func() { + c.writeSourceSchema(jWriter) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("type", "string") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("field", "op") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("type", "int64") + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("field", "ts_ms") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("type", "struct") + jWriter.WriteArrayField("fields", func() { + jWriter.WriteObjectElement(func() { + 
jWriter.WriteStringField("type", "string") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("field", "id") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("type", "int64") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("field", "total_order") + }) + jWriter.WriteObjectElement(func() { + jWriter.WriteStringField("type", "int64") + jWriter.WriteBoolField("optional", false) + jWriter.WriteStringField("field", "data_collection_order") + }) + }) + jWriter.WriteBoolField("optional", true) + jWriter.WriteStringField("name", "event.block") + jWriter.WriteIntField("version", 1) + jWriter.WriteStringField("field", "transaction") + }) + }) + }) + } + }) + return err +} diff --git a/pkg/sink/codec/debezium/helper.go b/pkg/sink/codec/debezium/helper.go new file mode 100644 index 00000000000..9029f49a4b9 --- /dev/null +++ b/pkg/sink/codec/debezium/helper.go @@ -0,0 +1,252 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package debezium + +import ( + "encoding/binary" + "fmt" + "strings" + + "github.com/pingcap/log" + timodel "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + pmodel "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/types" + driver "github.com/pingcap/tidb/pkg/types/parser_driver" + "github.com/pingcap/tiflow/cdc/model" + "github.com/pingcap/tiflow/pkg/sink/codec/common" + "go.uber.org/zap" +) + +type visiter struct { + columnsMap map[pmodel.CIStr]*timodel.ColumnInfo +} + +func (v *visiter) Enter(n ast.Node) (node ast.Node, skipChildren bool) { + return n, false +} + +func (v *visiter) Leave(n ast.Node) (node ast.Node, ok bool) { + switch col := n.(type) { + case *ast.ColumnDef: + c := v.columnsMap[col.Name.Name] + if col.Tp != nil { + parseType(c, col) + } + c.Comment = "" // disable comment + } + return n, true +} + +func extractValue(expr ast.ExprNode) any { + switch v := expr.(type) { + case *driver.ValueExpr: + return fmt.Sprintf("%v", v.GetValue()) + case *ast.FuncCallExpr: + return v.FnName.String() + } + return nil +} + +func parseType(c *timodel.ColumnInfo, col *ast.ColumnDef) { + ft := col.Tp + switch ft.GetType() { + case mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp, mysql.TypeYear: + if ft.GetType() == mysql.TypeYear { + c.SetFlen(ft.GetFlen()) + } else { + c.SetDecimal(ft.GetDecimal()) + } + parseOptions(col.Options, c) + default: + } +} + +func parseOptions(options []*ast.ColumnOption, c *timodel.ColumnInfo) { + for _, option := range options { + switch option.Tp { + case ast.ColumnOptionDefaultValue: + defaultValue := extractValue(option.Expr) + if defaultValue == nil { + continue + } + if err := c.SetDefaultValue(defaultValue); err != nil { + log.Error("failed to set default value") + } + } + } +} + +func parseColumns(sql string, columns []*timodel.ColumnInfo) { + p := parser.New() + stmt, 
err := p.ParseOneStmt(sql, mysql.DefaultCharset, mysql.DefaultCollationName) + if err != nil { + log.Error("format query parse one stmt failed", zap.Error(err)) + } + + columnsMap := make(map[pmodel.CIStr]*timodel.ColumnInfo, len(columns)) + for _, col := range columns { + columnsMap[col.Name] = col + } + stmt.Accept(&visiter{columnsMap: columnsMap}) +} + +func parseBit(s string, n int) string { + var result string + if len(s) > 0 { + // Leading zeros may be omitted + result = fmt.Sprintf("%0*b", n%8, s[0]) + } + for i := 1; i < len(s); i++ { + result += fmt.Sprintf("%08b", s[i]) + } + return result +} + +func getCharset(ft types.FieldType) string { + if ft.GetCharset() == "binary" { + return "" + } + switch ft.GetType() { + case mysql.TypeTimestamp, mysql.TypeDuration, mysql.TypeNewDecimal, mysql.TypeString, mysql.TypeVarchar, + mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeEnum, mysql.TypeSet: + return ft.GetCharset() + } + return "" +} + +func getLen(ft types.FieldType) int { + defaultFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(ft.GetType()) + decimal := ft.GetDecimal() + flen := ft.GetFlen() + switch ft.GetType() { + case mysql.TypeTimestamp, mysql.TypeDuration, mysql.TypeDatetime: + return decimal + case mysql.TypeBit, mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeTiDBVectorFloat32, + mysql.TypeLonglong, mysql.TypeFloat, mysql.TypeDouble: + if flen != defaultFlen { + return flen + } + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong: + if mysql.HasUnsignedFlag(ft.GetFlag()) { + defaultFlen -= 1 + } + if ft.GetType() == mysql.TypeTiny && mysql.HasZerofillFlag(ft.GetFlag()) { + defaultFlen += 1 + } + if flen != defaultFlen { + return flen + } + case mysql.TypeYear, mysql.TypeNewDecimal: + return flen + case mysql.TypeSet: + return 2*len(ft.GetElems()) - 1 + case mysql.TypeEnum: + return 1 + } + return -1 +} + +func getScale(ft types.FieldType) float64 { + switch 
ft.GetType() { + case mysql.TypeNewDecimal, mysql.TypeFloat, mysql.TypeDouble: + return float64(ft.GetDecimal()) + } + return -1 +} + +func getSuffix(ft types.FieldType) string { + suffix := "" + decimal := ft.GetDecimal() + flen := ft.GetFlen() + defaultFlen, defaultDecimal := mysql.GetDefaultFieldLengthAndDecimal(ft.GetType()) + isDecimalNotDefault := decimal != defaultDecimal && decimal != 0 && decimal != -1 + + // displayFlen and displayDecimal are flen and decimal values with `-1` substituted with default value. + displayFlen, displayDecimal := flen, decimal + if displayFlen == -1 { + displayFlen = defaultFlen + } + if displayDecimal == -1 { + displayDecimal = defaultDecimal + } + + switch ft.GetType() { + case mysql.TypeDouble: + // 1. flen Not Default, decimal Not Default -> Valid + // 2. flen Not Default, decimal Default (-1) -> Invalid + // 3. flen Default, decimal Not Default -> Valid + // 4. flen Default, decimal Default -> Valid (hide)W + if isDecimalNotDefault { + suffix = fmt.Sprintf("(%d,%d)", displayFlen, displayDecimal) + } + case mysql.TypeNewDecimal: + suffix = fmt.Sprintf("(%d,%d)", displayFlen, displayDecimal) + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString: + if !mysql.HasBinaryFlag(ft.GetFlag()) && displayFlen != 1 { + suffix = fmt.Sprintf("(%d)", displayFlen) + } + case mysql.TypeYear: + suffix = fmt.Sprintf("(%d)", flen) + case mysql.TypeTiDBVectorFloat32: + if flen != -1 { + suffix = fmt.Sprintf("(%d)", flen) + } + case mysql.TypeNull: + suffix = "(0)" + } + return suffix +} + +func getExpressionAndName(ft types.FieldType) (string, string) { + prefix := strings.ToUpper(types.TypeToStr(ft.GetType(), ft.GetCharset())) + switch ft.GetType() { + case mysql.TypeYear, mysql.TypeBit, mysql.TypeVarchar, mysql.TypeString, mysql.TypeNewDecimal: + return prefix, prefix + } + cs := prefix + getSuffix(ft) + var suf string + if mysql.HasZerofillFlag(ft.GetFlag()) { + suf = " UNSIGNED ZEROFILL" + } else if 
mysql.HasUnsignedFlag(ft.GetFlag()) { + suf = " UNSIGNED" + } + return cs + suf, prefix + suf +} + +func getBitFromUint64(n int, v uint64) []byte { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], v) + numBytes := n / 8 + if n%8 != 0 { + numBytes += 1 + } + return buf[:numBytes] +} + +func getDBTableName(e *model.DDLEvent) (string, string) { + if e.TableInfo == nil { + return "", "" + } + return e.TableInfo.GetSchemaName(), e.TableInfo.GetTableName() +} + +func getSchemaTopicName(namespace string, schema string, table string) string { + return fmt.Sprintf("%s.%s.%s", + common.SanitizeName(namespace), + common.SanitizeName(schema), + common.SanitizeTopicName(table)) +} diff --git a/pkg/sink/codec/debezium/helper_test.go b/pkg/sink/codec/debezium/helper_test.go new file mode 100644 index 00000000000..3bd8461c40e --- /dev/null +++ b/pkg/sink/codec/debezium/helper_test.go @@ -0,0 +1,67 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package debezium + +import ( + "testing" + + timodel "github.com/pingcap/tidb/pkg/meta/model" + pmodel "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/types" + "github.com/stretchr/testify/require" +) + +func TestGetColumns(t *testing.T) { + sql := "CREATE TABLE test (id INT PRIMARY KEY, val1 datetime default current_timestamp, val2 time(2) default 0, val3 timestamp(3) default now(), val4 YEAR(4) default 1970 comment 'first');" + columnInfos := []*timodel.ColumnInfo{ + { + Name: pmodel.NewCIStr("id"), + FieldType: *types.NewFieldType(mysql.TypeLong), + }, + { + Name: pmodel.NewCIStr("val1"), + FieldType: *types.NewFieldType(mysql.TypeDatetime), + }, + { + Name: pmodel.NewCIStr("val2"), + FieldType: *types.NewFieldType(mysql.TypeDuration), + }, + { + Name: pmodel.NewCIStr("val3"), + FieldType: *types.NewFieldType(mysql.TypeTimestamp), + }, + { + Name: pmodel.NewCIStr("val4"), + FieldType: *types.NewFieldType(mysql.TypeYear), + }, + } + parseColumns(sql, columnInfos) + require.Equal(t, columnInfos[1].GetDefaultValue(), "CURRENT_TIMESTAMP") + require.Equal(t, columnInfos[2].GetDecimal(), 2) + require.Equal(t, columnInfos[2].GetDefaultValue(), "0") + require.Equal(t, columnInfos[3].GetDecimal(), 3) + require.Equal(t, columnInfos[3].GetDefaultValue(), "CURRENT_TIMESTAMP") + require.Equal(t, columnInfos[4].GetFlen(), 4) + require.Equal(t, columnInfos[4].GetDefaultValue(), "1970") + require.Equal(t, columnInfos[4].Comment, "") +} + +func TestGetSchemaTopicName(t *testing.T) { + namespace := "default" + schema := "1A.B" + table := "columnNameWith中文" + name := getSchemaTopicName(namespace, schema, table) + require.Equal(t, name, "default._1A_B.columnNameWith__") +} diff --git a/pkg/sink/codec/simple/avro.go b/pkg/sink/codec/simple/avro.go new file mode 100644 index 00000000000..8fd93b61a6e --- /dev/null +++ b/pkg/sink/codec/simple/avro.go @@ -0,0 +1,571 @@ +// Copyright 2023 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package simple + +import ( + "sort" + "sync" + "time" + + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/types" + "github.com/pingcap/tiflow/cdc/model" +) + +func newTableSchemaMap(tableInfo *model.TableInfo) interface{} { + pkInIndexes := false + indexesSchema := make([]interface{}, 0, len(tableInfo.Indices)) + for _, idx := range tableInfo.Indices { + index := map[string]interface{}{ + "name": idx.Name.O, + "unique": idx.Unique, + "primary": idx.Primary, + "nullable": false, + } + columns := make([]string, 0, len(idx.Columns)) + for _, col := range idx.Columns { + columns = append(columns, col.Name.O) + colInfo := tableInfo.Columns[col.Offset] + // An index is not null when all columns of are not null + if !mysql.HasNotNullFlag(colInfo.GetFlag()) { + index["nullable"] = true + } + } + index["columns"] = columns + if idx.Primary { + pkInIndexes = true + } + indexesSchema = append(indexesSchema, index) + } + + // sometimes the primary key is not in the index, we need to find it manually. 
+ if !pkInIndexes { + pkColumns := tableInfo.GetPrimaryKeyColumnNames() + if len(pkColumns) != 0 { + index := map[string]interface{}{ + "name": "primary", + "nullable": false, + "primary": true, + "unique": true, + "columns": pkColumns, + } + indexesSchema = append(indexesSchema, index) + } + } + + sort.SliceStable(tableInfo.Columns, func(i, j int) bool { + return tableInfo.Columns[i].ID < tableInfo.Columns[j].ID + }) + + columnsSchema := make([]interface{}, 0, len(tableInfo.Columns)) + for _, col := range tableInfo.Columns { + mysqlType := map[string]interface{}{ + "mysqlType": types.TypeToStr(col.GetType(), col.GetCharset()), + "charset": col.GetCharset(), + "collate": col.GetCollate(), + "length": col.GetFlen(), + } + + switch col.GetType() { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, + mysql.TypeFloat, mysql.TypeDouble, mysql.TypeBit, mysql.TypeYear: + mysqlType["unsigned"] = map[string]interface{}{ + "boolean": mysql.HasUnsignedFlag(col.GetFlag()), + } + mysqlType["zerofill"] = map[string]interface{}{ + "boolean": mysql.HasZerofillFlag(col.GetFlag()), + } + case mysql.TypeEnum, mysql.TypeSet: + mysqlType["elements"] = map[string]interface{}{ + "array": col.GetElems(), + } + case mysql.TypeNewDecimal: + mysqlType["decimal"] = map[string]interface{}{ + "int": col.GetDecimal(), + } + mysqlType["unsigned"] = map[string]interface{}{ + "boolean": mysql.HasUnsignedFlag(col.GetFlag()), + } + mysqlType["zerofill"] = map[string]interface{}{ + "boolean": mysql.HasZerofillFlag(col.GetFlag()), + } + default: + } + + column := map[string]interface{}{ + "name": col.Name.O, + "dataType": mysqlType, + "nullable": !mysql.HasNotNullFlag(col.GetFlag()), + "default": nil, + } + defaultValue := col.GetDefaultValue() + if defaultValue != nil { + // according to TiDB source code, the default value is converted to string if not nil. 
+ column["default"] = map[string]interface{}{ + "string": defaultValue, + } + } + + columnsSchema = append(columnsSchema, column) + } + + result := map[string]interface{}{ + "database": tableInfo.TableName.Schema, + "table": tableInfo.TableName.Table, + "tableID": tableInfo.ID, + "version": int64(tableInfo.UpdateTS), + "columns": columnsSchema, + "indexes": indexesSchema, + } + + return result +} + +func newResolvedMessageMap(ts uint64) map[string]interface{} { + watermark := map[string]interface{}{ + "version": defaultVersion, + "type": string(MessageTypeWatermark), + "commitTs": int64(ts), + "buildTs": time.Now().UnixMilli(), + } + watermark = map[string]interface{}{ + "com.pingcap.simple.avro.Watermark": watermark, + } + + payload := map[string]interface{}{ + "type": string(MessageTypeWatermark), + "payload": watermark, + } + + return map[string]interface{}{ + "com.pingcap.simple.avro.Message": payload, + } +} + +func newBootstrapMessageMap(tableInfo *model.TableInfo) map[string]interface{} { + m := map[string]interface{}{ + "version": defaultVersion, + "type": string(MessageTypeBootstrap), + "tableSchema": newTableSchemaMap(tableInfo), + "buildTs": time.Now().UnixMilli(), + } + + m = map[string]interface{}{ + "com.pingcap.simple.avro.Bootstrap": m, + } + + payload := map[string]interface{}{ + "type": string(MessageTypeBootstrap), + "payload": m, + } + + return map[string]interface{}{ + "com.pingcap.simple.avro.Message": payload, + } +} + +func newDDLMessageMap(ddl *model.DDLEvent) map[string]interface{} { + result := map[string]interface{}{ + "version": defaultVersion, + "type": string(getDDLType(ddl.Type)), + "sql": ddl.Query, + "commitTs": int64(ddl.CommitTs), + "buildTs": time.Now().UnixMilli(), + } + + if ddl.TableInfo != nil && ddl.TableInfo.TableInfo != nil { + tableSchema := newTableSchemaMap(ddl.TableInfo) + result["tableSchema"] = map[string]interface{}{ + "com.pingcap.simple.avro.TableSchema": tableSchema, + } + } + if ddl.PreTableInfo != nil && 
ddl.PreTableInfo.TableInfo != nil { + tableSchema := newTableSchemaMap(ddl.PreTableInfo) + result["preTableSchema"] = map[string]interface{}{ + "com.pingcap.simple.avro.TableSchema": tableSchema, + } + } + + result = map[string]interface{}{ + "com.pingcap.simple.avro.DDL": result, + } + payload := map[string]interface{}{ + "type": string(MessageTypeDDL), + "payload": result, + } + return map[string]interface{}{ + "com.pingcap.simple.avro.Message": payload, + } +} + +var ( + // genericMapPool return holder for each column and checksum + genericMapPool = sync.Pool{ + New: func() any { + return make(map[string]interface{}) + }, + } + // rowMapPool return map for each row + rowMapPool = sync.Pool{ + New: func() any { + return make(map[string]interface{}) + }, + } + + dmlMessagePayloadPool = sync.Pool{ + New: func() any { + return make(map[string]interface{}) + }, + } + + // dmlMessagePool return a map for the dml message + dmlMessagePool = sync.Pool{ + New: func() any { + return make(map[string]interface{}) + }, + } + + messageHolderPool = sync.Pool{ + New: func() any { + return make(map[string]interface{}) + }, + } +) + +func (a *avroMarshaller) newDMLMessageMap( + event *model.RowChangedEvent, + onlyHandleKey bool, + claimCheckFileName string, +) map[string]interface{} { + dmlMessagePayload := dmlMessagePayloadPool.Get().(map[string]interface{}) + dmlMessagePayload["version"] = defaultVersion + dmlMessagePayload["database"] = event.TableInfo.GetSchemaName() + dmlMessagePayload["table"] = event.TableInfo.GetTableName() + dmlMessagePayload["tableID"] = event.GetTableID() + dmlMessagePayload["commitTs"] = int64(event.CommitTs) + dmlMessagePayload["buildTs"] = time.Now().UnixMilli() + dmlMessagePayload["schemaVersion"] = int64(event.TableInfo.UpdateTS) + + if !a.config.LargeMessageHandle.Disabled() && onlyHandleKey { + dmlMessagePayload["handleKeyOnly"] = map[string]interface{}{ + "boolean": true, + } + } + + if a.config.LargeMessageHandle.EnableClaimCheck() && 
claimCheckFileName != "" { + dmlMessagePayload["claimCheckLocation"] = map[string]interface{}{ + "string": claimCheckFileName, + } + } + + if a.config.EnableRowChecksum && event.Checksum != nil { + cc := map[string]interface{}{ + "version": event.Checksum.Version, + "corrupted": event.Checksum.Corrupted, + "current": int64(event.Checksum.Current), + "previous": int64(event.Checksum.Previous), + } + + holder := genericMapPool.Get().(map[string]interface{}) + holder["com.pingcap.simple.avro.Checksum"] = cc + dmlMessagePayload["checksum"] = holder + } + + if event.IsInsert() { + data := a.collectColumns(event.Columns, event.TableInfo, onlyHandleKey) + dmlMessagePayload["data"] = data + dmlMessagePayload["type"] = string(DMLTypeInsert) + } else if event.IsDelete() { + old := a.collectColumns(event.PreColumns, event.TableInfo, onlyHandleKey) + dmlMessagePayload["old"] = old + dmlMessagePayload["type"] = string(DMLTypeDelete) + } else if event.IsUpdate() { + data := a.collectColumns(event.Columns, event.TableInfo, onlyHandleKey) + dmlMessagePayload["data"] = data + old := a.collectColumns(event.PreColumns, event.TableInfo, onlyHandleKey) + dmlMessagePayload["old"] = old + dmlMessagePayload["type"] = string(DMLTypeUpdate) + } + + dmlMessagePayload = map[string]interface{}{ + "com.pingcap.simple.avro.DML": dmlMessagePayload, + } + + dmlMessage := dmlMessagePool.Get().(map[string]interface{}) + dmlMessage["type"] = string(MessageTypeDML) + dmlMessage["payload"] = dmlMessagePayload + + messageHolder := messageHolderPool.Get().(map[string]interface{}) + messageHolder["com.pingcap.simple.avro.Message"] = dmlMessage + + return messageHolder +} + +func recycleMap(m map[string]interface{}) { + dmlMessage := m["com.pingcap.simple.avro.Message"].(map[string]interface{}) + dml := dmlMessage["payload"].(map[string]interface{})["com.pingcap.simple.avro.DML"].(map[string]interface{}) + + checksum := dml["checksum"] + if checksum != nil { + checksum := checksum.(map[string]interface{}) 
+ clear(checksum) + genericMapPool.Put(checksum) + } + + dataMap := dml["data"] + if dataMap != nil { + dataMap := dataMap.(map[string]interface{})["map"].(map[string]interface{}) + for _, col := range dataMap { + colMap := col.(map[string]interface{}) + clear(colMap) + genericMapPool.Put(col) + } + clear(dataMap) + rowMapPool.Put(dataMap) + } + + oldDataMap := dml["old"] + if oldDataMap != nil { + oldDataMap := oldDataMap.(map[string]interface{})["map"].(map[string]interface{}) + for _, col := range oldDataMap { + colMap := col.(map[string]interface{}) + clear(colMap) + genericMapPool.Put(col) + } + clear(oldDataMap) + rowMapPool.Put(oldDataMap) + } + + clear(dml) + dmlMessagePayloadPool.Put(dml) + + clear(dmlMessage) + dmlMessagePool.Put(dmlMessage) + + clear(m) + messageHolderPool.Put(m) +} + +func (a *avroMarshaller) collectColumns( + columns []*model.ColumnData, tableInfo *model.TableInfo, onlyHandleKey bool, +) map[string]interface{} { + result := rowMapPool.Get().(map[string]interface{}) + for _, col := range columns { + if col != nil { + colFlag := tableInfo.ForceGetColumnFlagType(col.ColumnID) + if onlyHandleKey && !colFlag.IsHandleKey() { + continue + } + colInfo := tableInfo.ForceGetColumnInfo(col.ColumnID) + value, avroType := a.encodeValue4Avro(col.Value, &colInfo.FieldType) + holder := genericMapPool.Get().(map[string]interface{}) + holder[avroType] = value + result[colInfo.Name.O] = holder + } + } + return map[string]interface{}{ + "map": result, + } +} + +func newTableSchemaFromAvroNative(native map[string]interface{}) *TableSchema { + rawColumns := native["columns"].([]interface{}) + columns := make([]*columnSchema, 0, len(rawColumns)) + for _, raw := range rawColumns { + raw := raw.(map[string]interface{}) + rawDataType := raw["dataType"].(map[string]interface{}) + + var ( + decimal int + elements []string + unsigned bool + zerofill bool + ) + + if rawDataType["elements"] != nil { + rawElements := 
rawDataType["elements"].(map[string]interface{})["array"].([]interface{}) + for _, rawElement := range rawElements { + elements = append(elements, rawElement.(string)) + } + } + if rawDataType["decimal"] != nil { + decimal = int(rawDataType["decimal"].(map[string]interface{})["int"].(int32)) + } + if rawDataType["unsigned"] != nil { + unsigned = rawDataType["unsigned"].(map[string]interface{})["boolean"].(bool) + } + if rawDataType["zerofill"] != nil { + zerofill = rawDataType["zerofill"].(map[string]interface{})["boolean"].(bool) + } + + dt := dataType{ + MySQLType: rawDataType["mysqlType"].(string), + Charset: rawDataType["charset"].(string), + Collate: rawDataType["collate"].(string), + Length: int(rawDataType["length"].(int64)), + Decimal: decimal, + Elements: elements, + Unsigned: unsigned, + Zerofill: zerofill, + } + + var defaultValue interface{} + rawDefault := raw["default"] + switch v := rawDefault.(type) { + case nil: + case map[string]interface{}: + defaultValue = v["string"].(string) + } + + column := &columnSchema{ + Name: raw["name"].(string), + Nullable: raw["nullable"].(bool), + Default: defaultValue, + DataType: dt, + } + columns = append(columns, column) + } + + rawIndexes := native["indexes"].([]interface{}) + indexes := make([]*IndexSchema, 0, len(rawIndexes)) + for _, raw := range rawIndexes { + raw := raw.(map[string]interface{}) + rawColumns := raw["columns"].([]interface{}) + keyColumns := make([]string, 0, len(rawColumns)) + for _, rawColumn := range rawColumns { + keyColumns = append(keyColumns, rawColumn.(string)) + } + index := &IndexSchema{ + Name: raw["name"].(string), + Unique: raw["unique"].(bool), + Primary: raw["primary"].(bool), + Nullable: raw["nullable"].(bool), + Columns: keyColumns, + } + indexes = append(indexes, index) + } + return &TableSchema{ + Schema: native["database"].(string), + Table: native["table"].(string), + TableID: native["tableID"].(int64), + Version: uint64(native["version"].(int64)), + Columns: columns, + 
Indexes: indexes, + } +} + +func newMessageFromAvroNative(native interface{}, m *message) { + rawValues := native.(map[string]interface{})["com.pingcap.simple.avro.Message"].(map[string]interface{}) + rawPayload := rawValues["payload"].(map[string]interface{}) + + rawMessage := rawPayload["com.pingcap.simple.avro.Watermark"] + if rawMessage != nil { + rawValues = rawMessage.(map[string]interface{}) + m.Version = int(rawValues["version"].(int32)) + m.Type = MessageTypeWatermark + m.CommitTs = uint64(rawValues["commitTs"].(int64)) + m.BuildTs = rawValues["buildTs"].(int64) + return + } + + rawMessage = rawPayload["com.pingcap.simple.avro.Bootstrap"] + if rawMessage != nil { + rawValues = rawMessage.(map[string]interface{}) + m.Version = int(rawValues["version"].(int32)) + m.Type = MessageTypeBootstrap + m.BuildTs = rawValues["buildTs"].(int64) + m.TableSchema = newTableSchemaFromAvroNative(rawValues["tableSchema"].(map[string]interface{})) + return + } + + rawMessage = rawPayload["com.pingcap.simple.avro.DDL"] + if rawMessage != nil { + rawValues = rawMessage.(map[string]interface{}) + m.Version = int(rawValues["version"].(int32)) + m.Type = MessageType(rawValues["type"].(string)) + m.SQL = rawValues["sql"].(string) + m.CommitTs = uint64(rawValues["commitTs"].(int64)) + m.BuildTs = rawValues["buildTs"].(int64) + + rawTableSchemaValues := rawValues["tableSchema"] + if rawTableSchemaValues != nil { + rawTableSchema := rawTableSchemaValues.(map[string]interface{}) + rawTableSchema = rawTableSchema["com.pingcap.simple.avro.TableSchema"].(map[string]interface{}) + m.TableSchema = newTableSchemaFromAvroNative(rawTableSchema) + } + + rawPreTableSchemaValue := rawValues["preTableSchema"] + if rawPreTableSchemaValue != nil { + rawPreTableSchema := rawPreTableSchemaValue.(map[string]interface{}) + rawPreTableSchema = rawPreTableSchema["com.pingcap.simple.avro.TableSchema"].(map[string]interface{}) + m.PreTableSchema = newTableSchemaFromAvroNative(rawPreTableSchema) + } + 
return + } + + rawValues = rawPayload["com.pingcap.simple.avro.DML"].(map[string]interface{}) + m.Type = MessageType(rawValues["type"].(string)) + m.Version = int(rawValues["version"].(int32)) + m.CommitTs = uint64(rawValues["commitTs"].(int64)) + m.BuildTs = rawValues["buildTs"].(int64) + m.Schema = rawValues["database"].(string) + m.Table = rawValues["table"].(string) + m.TableID = rawValues["tableID"].(int64) + m.SchemaVersion = uint64(rawValues["schemaVersion"].(int64)) + + if rawValues["handleKeyOnly"] != nil { + m.HandleKeyOnly = rawValues["handleKeyOnly"].(map[string]interface{})["boolean"].(bool) + } + if rawValues["claimCheckLocation"] != nil { + m.ClaimCheckLocation = rawValues["claimCheckLocation"].(map[string]interface{})["string"].(string) + } + + m.Checksum = newChecksum(rawValues) + m.Data = newDataMap(rawValues["data"]) + m.Old = newDataMap(rawValues["old"]) +} + +func newChecksum(raw map[string]interface{}) *checksum { + rawValue := raw["checksum"] + if rawValue == nil { + return nil + } + rawChecksum := rawValue.(map[string]interface{}) + rawChecksum = rawChecksum["com.pingcap.simple.avro.Checksum"].(map[string]interface{}) + return &checksum{ + Version: int(rawChecksum["version"].(int32)), + Corrupted: rawChecksum["corrupted"].(bool), + Current: uint32(rawChecksum["current"].(int64)), + Previous: uint32(rawChecksum["previous"].(int64)), + } +} + +func newDataMap(rawValues interface{}) map[string]interface{} { + if rawValues == nil { + return nil + } + data := make(map[string]interface{}) + rawDataMap := rawValues.(map[string]interface{})["map"].(map[string]interface{}) + for key, value := range rawDataMap { + if value == nil { + data[key] = nil + continue + } + valueMap := value.(map[string]interface{}) + for _, v := range valueMap { + data[key] = v + } + } + return data +} diff --git a/pkg/sink/codec/simple/encoder_test.go b/pkg/sink/codec/simple/encoder_test.go new file mode 100644 index 00000000000..d76c6ec25e7 --- /dev/null +++ 
b/pkg/sink/codec/simple/encoder_test.go @@ -0,0 +1,2015 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package simple + +import ( + "context" + "database/sql/driver" + "fmt" + "math/rand" + "sort" + "strconv" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/golang/mock/gomock" + timodel "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tiflow/cdc/entry" + "github.com/pingcap/tiflow/cdc/model" + "github.com/pingcap/tiflow/pkg/compression" + "github.com/pingcap/tiflow/pkg/config" + "github.com/pingcap/tiflow/pkg/errors" + "github.com/pingcap/tiflow/pkg/integrity" + "github.com/pingcap/tiflow/pkg/sink/codec/common" + mock_simple "github.com/pingcap/tiflow/pkg/sink/codec/simple/mock" + "github.com/pingcap/tiflow/pkg/sink/codec/utils" + "github.com/stretchr/testify/require" +) + +func TestEncodeCheckpoint(t *testing.T) { + t.Parallel() + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + + for _, compressionType := range []string{ + compression.None, + compression.Snappy, + compression.LZ4, + } { + codecConfig.LargeMessageHandle.LargeMessageHandleCompression = compressionType + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + checkpoint := 
446266400629063682 + m, err := enc.EncodeCheckpointEvent(uint64(checkpoint)) + require.NoError(t, err) + + dec, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeResolved, messageType) + require.NotEqual(t, 0, dec.msg.BuildTs) + + ts, err := dec.NextResolvedEvent() + require.NoError(t, err) + require.Equal(t, uint64(checkpoint), ts) + } + } +} + +func TestEncodeDMLEnableChecksum(t *testing.T) { + replicaConfig := config.GetDefaultReplicaConfig() + replicaConfig.Integrity.IntegrityCheckLevel = integrity.CheckLevelCorrectness + createTableDDL, _, updateEvent, _ := utils.NewLargeEvent4Test(t, replicaConfig) + rand.New(rand.NewSource(time.Now().Unix())).Shuffle(len(createTableDDL.TableInfo.Columns), func(i, j int) { + createTableDDL.TableInfo.Columns[i], createTableDDL.TableInfo.Columns[j] = createTableDDL.TableInfo.Columns[j], createTableDDL.TableInfo.Columns[i] + }) + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + codecConfig.EnableRowChecksum = true + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + for _, compressionType := range []string{ + compression.None, + compression.Snappy, + compression.LZ4, + } { + codecConfig.LargeMessageHandle.LargeMessageHandleCompression = compressionType + + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + dec, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + m, err := enc.EncodeDDLEvent(createTableDDL) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, 
model.MessageTypeDDL, messageType) + + decodedDDL, err := dec.NextDDLEvent() + require.NoError(t, err) + + originFlags := createTableDDL.TableInfo.ColumnsFlag + obtainedFlags := decodedDDL.TableInfo.ColumnsFlag + + for colID, expected := range originFlags { + name := createTableDDL.TableInfo.ForceGetColumnName(colID) + actualID := decodedDDL.TableInfo.ForceGetColumnIDByName(name) + actual := obtainedFlags[actualID] + require.Equal(t, expected, actual) + } + + err = enc.AppendRowChangedEvent(ctx, "", updateEvent, func() {}) + require.NoError(t, err) + + messages := enc.Build() + require.Len(t, messages, 1) + + err = dec.AddKeyValue(messages[0].Key, messages[0].Value) + require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, messageType) + + decodedRow, err := dec.NextRowChangedEvent() + require.NoError(t, err) + require.Equal(t, updateEvent.Checksum.Current, decodedRow.Checksum.Current) + require.Equal(t, updateEvent.Checksum.Previous, decodedRow.Checksum.Previous) + require.False(t, decodedRow.Checksum.Corrupted) + } + } + + // tamper the checksum, to test error case + updateEvent.Checksum.Current = 1 + updateEvent.Checksum.Previous = 2 + + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + dec, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + m, err := enc.EncodeDDLEvent(createTableDDL) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, messageType) + + _, err = dec.NextDDLEvent() + require.NoError(t, err) + + err = enc.AppendRowChangedEvent(ctx, "", updateEvent, func() {}) + require.NoError(t, err) + + messages := enc.Build() + require.Len(t, messages, 1) + + err = dec.AddKeyValue(messages[0].Key, messages[0].Value) + 
require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, messageType) + + decodedRow, err := dec.NextRowChangedEvent() + require.Error(t, err) + require.Nil(t, decodedRow) +} + +func TestE2EPartitionTable(t *testing.T) { + helper := entry.NewSchemaTestHelper(t) + defer helper.Close() + + helper.Tk().MustExec("use test") + + createPartitionTableDDL := helper.DDL2Event(`create table test.t(a int primary key, b int) partition by range (a) ( + partition p0 values less than (10), + partition p1 values less than (20), + partition p2 values less than MAXVALUE)`) + require.NotNil(t, createPartitionTableDDL) + + insertEvent := helper.DML2Event(`insert into test.t values (1, 1)`, "test", "t", "p0") + require.NotNil(t, insertEvent) + + insertEvent1 := helper.DML2Event(`insert into test.t values (11, 11)`, "test", "t", "p1") + require.NotNil(t, insertEvent1) + + insertEvent2 := helper.DML2Event(`insert into test.t values (21, 21)`, "test", "t", "p2") + require.NotNil(t, insertEvent2) + + events := []*model.RowChangedEvent{insertEvent, insertEvent1, insertEvent2} + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + builder, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := builder.Build() + + decoder, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + message, err := enc.EncodeDDLEvent(createPartitionTableDDL) + require.NoError(t, err) + + err = decoder.AddKeyValue(message.Key, message.Value) + require.NoError(t, err) + tp, hasNext, err := decoder.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, tp) + + decodedDDL, err := decoder.NextDDLEvent() + require.NoError(t, err) + 
require.NotNil(t, decodedDDL) + + for _, event := range events { + err = enc.AppendRowChangedEvent(ctx, "", event, nil) + require.NoError(t, err) + message := enc.Build()[0] + + err = decoder.AddKeyValue(message.Key, message.Value) + require.NoError(t, err) + tp, hasNext, err := decoder.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, tp) + + decodedEvent, err := decoder.NextRowChangedEvent() + require.NoError(t, err) + // table id should be set to the partition table id, the PhysicalTableID + require.Equal(t, decodedEvent.GetTableID(), event.GetTableID()) + } + } +} + +func TestEncodeDDLSequence(t *testing.T) { + helper := entry.NewSchemaTestHelper(t) + defer helper.Close() + + dropDBEvent := helper.DDL2Event(`DROP DATABASE IF EXISTS test`) + createDBDDLEvent := helper.DDL2Event(`CREATE DATABASE IF NOT EXISTS test`) + helper.Tk().MustExec("use test") + + createTableDDLEvent := helper.DDL2Event("CREATE TABLE `TBL1` (`id` INT PRIMARY KEY AUTO_INCREMENT,`value` VARCHAR(255),`payload` VARCHAR(2000),`a` INT)") + + addColumnDDLEvent := helper.DDL2Event("ALTER TABLE `TBL1` ADD COLUMN `nn` INT") + + dropColumnDDLEvent := helper.DDL2Event("ALTER TABLE `TBL1` DROP COLUMN `nn`") + + changeColumnDDLEvent := helper.DDL2Event("ALTER TABLE `TBL1` CHANGE COLUMN `value` `value2` VARCHAR(512)") + + modifyColumnDDLEvent := helper.DDL2Event("ALTER TABLE `TBL1` MODIFY COLUMN `value2` VARCHAR(512) FIRST") + + setDefaultDDLEvent := helper.DDL2Event("ALTER TABLE `TBL1` ALTER COLUMN `payload` SET DEFAULT _UTF8MB4'a'") + + dropDefaultDDLEvent := helper.DDL2Event("ALTER TABLE `TBL1` ALTER COLUMN `payload` DROP DEFAULT") + + autoIncrementDDLEvent := helper.DDL2Event("ALTER TABLE `TBL1` AUTO_INCREMENT = 5") + + modifyColumnNullDDLEvent := helper.DDL2Event("ALTER TABLE `TBL1` MODIFY COLUMN `a` INT NULL") + + modifyColumnNotNullDDLEvent := helper.DDL2Event("ALTER TABLE `TBL1` MODIFY COLUMN `a` INT NOT NULL") + + addIndexDDLEvent := 
helper.DDL2Event("CREATE INDEX `idx_a` ON `TBL1` (`a`)") + + renameIndexDDLEvent := helper.DDL2Event("ALTER TABLE `TBL1` RENAME INDEX `idx_a` TO `new_idx_a`") + + indexVisibilityDDLEvent := helper.DDL2Event("ALTER TABLE TBL1 ALTER INDEX `new_idx_a` INVISIBLE") + + dropIndexDDLEvent := helper.DDL2Event("DROP INDEX `new_idx_a` ON `TBL1`") + + truncateTableDDLEvent := helper.DDL2Event("TRUNCATE TABLE TBL1") + + multiSchemaChangeDDLEvent := helper.DDL2Event("ALTER TABLE TBL1 ADD COLUMN `new_col` INT, ADD INDEX `idx_new_col` (`a`)") + + multiSchemaChangeDropDDLEvent := helper.DDL2Event("ALTER TABLE TBL1 DROP COLUMN `new_col`, DROP INDEX `idx_new_col`") + + renameTableDDLEvent := helper.DDL2Event("RENAME TABLE TBL1 TO TBL2") + + helper.Tk().MustExec("set @@tidb_allow_remove_auto_inc = 1") + renameColumnDDLEvent := helper.DDL2Event("ALTER TABLE TBL2 CHANGE COLUMN `id` `id2` INT") + + partitionTableDDLEvent := helper.DDL2Event("ALTER TABLE TBL2 PARTITION BY RANGE (id2) (PARTITION p0 VALUES LESS THAN (10), PARTITION p1 VALUES LESS THAN (20))") + + addPartitionDDLEvent := helper.DDL2Event("ALTER TABLE TBL2 ADD PARTITION (PARTITION p2 VALUES LESS THAN (30))") + + dropPartitionDDLEvent := helper.DDL2Event("ALTER TABLE TBL2 DROP PARTITION p2") + + truncatePartitionDDLevent := helper.DDL2Event("ALTER TABLE TBL2 TRUNCATE PARTITION p1") + + reorganizePartitionDDLEvent := helper.DDL2Event("ALTER TABLE TBL2 REORGANIZE PARTITION p1 INTO (PARTITION p3 VALUES LESS THAN (40))") + + removePartitionDDLEvent := helper.DDL2Event("ALTER TABLE TBL2 REMOVE PARTITIONING") + + alterCharsetCollateDDLEvent := helper.DDL2Event("ALTER TABLE TBL2 CHARACTER SET = utf8mb4 COLLATE = utf8mb4_bin") + + dropTableDDLEvent := helper.DDL2Event("DROP TABLE TBL2") + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + for 
_, compressionType := range []string{ + compression.None, + compression.Snappy, + compression.LZ4, + } { + codecConfig.LargeMessageHandle.LargeMessageHandleCompression = compressionType + + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + + enc := b.Build() + dec, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + m, err := enc.EncodeDDLEvent(dropDBEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, messageType) + require.Equal(t, DDLTypeQuery, dec.msg.Type) + + _, err = dec.NextDDLEvent() + require.NoError(t, err) + + m, err = enc.EncodeDDLEvent(createDBDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, messageType) + require.Equal(t, DDLTypeQuery, dec.msg.Type) + + _, err = dec.NextDDLEvent() + require.NoError(t, err) + + m, err = enc.EncodeDDLEvent(createTableDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, messageType) + require.Equal(t, DDLTypeCreate, dec.msg.Type) + + event, err := dec.NextDDLEvent() + require.NoError(t, err) + require.Len(t, event.TableInfo.Indices, 1) + require.Len(t, event.TableInfo.Columns, 4) + + m, err = enc.EncodeDDLEvent(addColumnDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Len(t, 
event.TableInfo.Indices, 1) + require.Len(t, event.TableInfo.Columns, 5) + + m, err = enc.EncodeDDLEvent(dropColumnDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Len(t, event.TableInfo.Indices, 1) + require.Len(t, event.TableInfo.Columns, 4) + + m, err = enc.EncodeDDLEvent(changeColumnDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Len(t, event.TableInfo.Indices, 1) + require.Len(t, event.TableInfo.Columns, 4) + + m, err = enc.EncodeDDLEvent(modifyColumnDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices), string(format), compressionType) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(setDefaultDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + for _, col := range event.TableInfo.Columns { + if col.Name.O == "payload" { + require.Equal(t, "a", col.DefaultValue) + } + } + + m, err = enc.EncodeDDLEvent(dropDefaultDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + 
require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + for _, col := range event.TableInfo.Columns { + if col.Name.O == "payload" { + require.Nil(t, col.DefaultValue) + } + } + + m, err = enc.EncodeDDLEvent(autoIncrementDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(modifyColumnNullDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + for _, col := range event.TableInfo.Columns { + if col.Name.O == "a" { + require.True(t, !mysql.HasNotNullFlag(col.GetFlag())) + } + } + + m, err = enc.EncodeDDLEvent(modifyColumnNotNullDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + for _, col := range event.TableInfo.Columns { + if col.Name.O == "a" { + require.True(t, mysql.HasNotNullFlag(col.GetFlag())) + } + } + + m, err = 
enc.EncodeDDLEvent(addIndexDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeCIndex, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 2, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(renameIndexDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 2, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + hasNewIndex := false + noOldIndex := true + for _, index := range event.TableInfo.Indices { + if index.Name.O == "new_idx_a" { + hasNewIndex = true + } + if index.Name.O == "idx_a" { + noOldIndex = false + } + } + require.True(t, hasNewIndex) + require.True(t, noOldIndex) + + m, err = enc.EncodeDDLEvent(indexVisibilityDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 2, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(dropIndexDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeDIndex, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(truncateTableDDLEvent) + 
require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeTruncate, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(multiSchemaChangeDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 2, len(event.TableInfo.Indices)) + require.Equal(t, 5, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(multiSchemaChangeDropDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(renameTableDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeRename, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(renameColumnDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + 
require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(partitionTableDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(addPartitionDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(dropPartitionDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(truncatePartitionDDLevent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(reorganizePartitionDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err 
= dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(removePartitionDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(alterCharsetCollateDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeAlter, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + + m, err = enc.EncodeDDLEvent(dropTableDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + _, _, err = dec.HasNext() + require.NoError(t, err) + require.Equal(t, DDLTypeErase, dec.msg.Type) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, 1, len(event.TableInfo.Indices)) + require.Equal(t, 4, len(event.TableInfo.Columns)) + } + } +} + +func TestEncodeDDLEvent(t *testing.T) { + replicaConfig := config.GetDefaultReplicaConfig() + replicaConfig.Integrity.IntegrityCheckLevel = integrity.CheckLevelCorrectness + helper := entry.NewSchemaTestHelperWithReplicaConfig(t, replicaConfig) + defer helper.Close() + + createTableSQL := `create table test.t(id int primary key, name varchar(255) not null, gender enum('male', 'female'), email varchar(255) null, key idx_name_email(name, email))` + createTableDDLEvent := helper.DDL2Event(createTableSQL) + + insertEvent := helper.DML2Event(`insert into test.t values (1, 
"jack", "male", "jack@abc.com")`, "test", "t") + + renameTableDDLEvent := helper.DDL2Event(`rename table test.t to test.abc`) + + insertEvent2 := helper.DML2Event(`insert into test.abc values (2, "anna", "female", "anna@abc.com")`, "test", "abc") + helper.Tk().MustExec("drop table test.abc") + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + codecConfig.EnableRowChecksum = true + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + for _, compressionType := range []string{ + compression.None, + compression.Snappy, + compression.LZ4, + } { + codecConfig.LargeMessageHandle.LargeMessageHandleCompression = compressionType + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + dec, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + m, err := enc.EncodeDDLEvent(createTableDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, messageType) + require.NotEqual(t, 0, dec.msg.BuildTs) + require.True(t, dec.msg.TableSchema.Indexes[0].Nullable) + + columnSchemas := dec.msg.TableSchema.Columns + sortedColumns := make([]*timodel.ColumnInfo, len(createTableDDLEvent.TableInfo.Columns)) + copy(sortedColumns, createTableDDLEvent.TableInfo.Columns) + sort.Slice(sortedColumns, func(i, j int) bool { + return sortedColumns[i].ID < sortedColumns[j].ID + }) + + for idx, column := range sortedColumns { + require.Equal(t, column.Name.O, columnSchemas[idx].Name) + } + + event, err := dec.NextDDLEvent() + + require.NoError(t, err) + require.Equal(t, createTableDDLEvent.TableInfo.TableName.TableID, event.TableInfo.TableName.TableID) + require.Equal(t, createTableDDLEvent.CommitTs, event.CommitTs) + + // because we 
don't set startTs in the encoded message, + // so the startTs is equal to commitTs + + require.Equal(t, createTableDDLEvent.CommitTs, event.StartTs) + require.Equal(t, createTableDDLEvent.Query, event.Query) + require.Equal(t, len(createTableDDLEvent.TableInfo.Columns), len(event.TableInfo.Columns)) + require.Equal(t, 2, len(event.TableInfo.Indices)) + require.Nil(t, event.PreTableInfo) + + item := dec.memo.Read(createTableDDLEvent.TableInfo.TableName.Schema, + createTableDDLEvent.TableInfo.TableName.Table, createTableDDLEvent.TableInfo.UpdateTS) + require.NotNil(t, item) + + err = enc.AppendRowChangedEvent(ctx, "", insertEvent, func() {}) + require.NoError(t, err) + + messages := enc.Build() + require.Len(t, messages, 1) + + err = dec.AddKeyValue(messages[0].Key, messages[0].Value) + require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, messageType) + require.NotEqual(t, 0, dec.msg.BuildTs) + + decodedRow, err := dec.NextRowChangedEvent() + require.NoError(t, err) + require.Equal(t, decodedRow.CommitTs, insertEvent.CommitTs) + require.Equal(t, decodedRow.TableInfo.GetSchemaName(), insertEvent.TableInfo.GetSchemaName()) + require.Equal(t, decodedRow.TableInfo.GetTableName(), insertEvent.TableInfo.GetTableName()) + require.Nil(t, decodedRow.PreColumns) + + m, err = enc.EncodeDDLEvent(renameTableDDLEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, messageType) + require.NotEqual(t, 0, dec.msg.BuildTs) + + event, err = dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, renameTableDDLEvent.TableInfo.TableName.TableID, event.TableInfo.TableName.TableID) + require.Equal(t, renameTableDDLEvent.CommitTs, event.CommitTs) + // because we don't set 
startTs in the encoded message, + // so the startTs is equal to commitTs + require.Equal(t, renameTableDDLEvent.CommitTs, event.StartTs) + require.Equal(t, renameTableDDLEvent.Query, event.Query) + require.Equal(t, len(renameTableDDLEvent.TableInfo.Columns), len(event.TableInfo.Columns)) + require.Equal(t, len(renameTableDDLEvent.TableInfo.Indices)+1, len(event.TableInfo.Indices)) + require.NotNil(t, event.PreTableInfo) + + item = dec.memo.Read(renameTableDDLEvent.TableInfo.TableName.Schema, + renameTableDDLEvent.TableInfo.TableName.Table, renameTableDDLEvent.TableInfo.UpdateTS) + require.NotNil(t, item) + + err = enc.AppendRowChangedEvent(context.Background(), "", insertEvent2, func() {}) + require.NoError(t, err) + + messages = enc.Build() + require.Len(t, messages, 1) + + err = dec.AddKeyValue(messages[0].Key, messages[0].Value) + require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, messageType) + require.NotEqual(t, 0, dec.msg.BuildTs) + + decodedRow, err = dec.NextRowChangedEvent() + require.NoError(t, err) + require.Equal(t, insertEvent2.CommitTs, decodedRow.CommitTs) + require.Equal(t, insertEvent2.TableInfo.GetSchemaName(), decodedRow.TableInfo.GetSchemaName()) + require.Equal(t, insertEvent2.TableInfo.GetTableName(), decodedRow.TableInfo.GetTableName()) + require.Nil(t, decodedRow.PreColumns) + } + } +} + +func TestColumnFlags(t *testing.T) { + helper := entry.NewSchemaTestHelper(t) + defer helper.Close() + + createTableDDL := `create table test.t( + a bigint(20) unsigned not null, + b bigint(20) default null, + c varbinary(767) default null, + d int(11) unsigned not null, + e int(11) default null, + primary key (a), + key idx_c(c), + key idx_b(b), + unique key idx_de(d, e))` + createTableDDLEvent := helper.DDL2Event(createTableDDL) + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + for _, format := range 
[]common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + m, err := enc.EncodeDDLEvent(createTableDDLEvent) + require.NoError(t, err) + + dec, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, messageType) + + decodedDDLEvent, err := dec.NextDDLEvent() + require.NoError(t, err) + + originFlags := createTableDDLEvent.TableInfo.ColumnsFlag + obtainedFlags := decodedDDLEvent.TableInfo.ColumnsFlag + + for colID, expected := range originFlags { + name := createTableDDLEvent.TableInfo.ForceGetColumnName(colID) + actualID := decodedDDLEvent.TableInfo.ForceGetColumnIDByName(name) + actual := obtainedFlags[actualID] + require.Equal(t, expected, actual) + } + } +} + +func TestEncodeIntegerTypes(t *testing.T) { + replicaConfig := config.GetDefaultReplicaConfig() + replicaConfig.Integrity.IntegrityCheckLevel = integrity.CheckLevelCorrectness + helper := entry.NewSchemaTestHelperWithReplicaConfig(t, replicaConfig) + defer helper.Close() + + createTableDDL := `create table test.t( + id int primary key auto_increment, + a tinyint, b tinyint unsigned, + c smallint, d smallint unsigned, + e mediumint, f mediumint unsigned, + g int, h int unsigned, + i bigint, j bigint unsigned)` + ddlEvent := helper.DDL2Event(createTableDDL) + + sql := `insert into test.t values( + 1, + -128, 0, + -32768, 0, + -8388608, 0, + -2147483648, 0, + -9223372036854775808, 0)` + minValues := helper.DML2Event(sql, "test", "t") + + sql = `insert into test.t values ( + 2, + 127, 255, + 32767, 65535, + 8388607, 16777215, + 2147483647, 4294967295, + 9223372036854775807, 18446744073709551615)` + maxValues := helper.DML2Event(sql, 
"test", "t") + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + codecConfig.EnableRowChecksum = true + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + m, err := enc.EncodeDDLEvent(ddlEvent) + require.NoError(t, err) + + dec, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, messageType) + + _, err = dec.NextDDLEvent() + require.NoError(t, err) + + for _, event := range []*model.RowChangedEvent{ + minValues, + maxValues, + } { + err = enc.AppendRowChangedEvent(ctx, "", event, func() {}) + require.NoError(t, err) + + messages := enc.Build() + err = dec.AddKeyValue(messages[0].Key, messages[0].Value) + require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, messageType) + + decodedRow, err := dec.NextRowChangedEvent() + require.NoError(t, err) + require.Equal(t, decodedRow.CommitTs, event.CommitTs) + + decodedColumns := make(map[string]*model.ColumnData, len(decodedRow.Columns)) + for _, column := range decodedRow.Columns { + colName := decodedRow.TableInfo.ForceGetColumnName(column.ColumnID) + decodedColumns[colName] = column + } + + for _, expected := range event.Columns { + colName := event.TableInfo.ForceGetColumnName(expected.ColumnID) + decoded, ok := decodedColumns[colName] + require.True(t, ok) + require.EqualValues(t, expected.Value, decoded.Value) + } + } + } +} + +func TestEncoderOtherTypes(t *testing.T) { + helper := entry.NewSchemaTestHelper(t) + defer helper.Close() + + sql := `create table 
test.t( + a int primary key auto_increment, + b enum('a', 'b', 'c'), + c set('a', 'b', 'c'), + d bit(64), + e json)` + ddlEvent := helper.DDL2Event(sql) + + sql = `insert into test.t() values (1, 'a', 'a,b', b'1000001', '{ + "key1": "value1", + "key2": "value2" + }');` + row := helper.DML2Event(sql, "test", "t") + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + m, err := enc.EncodeDDLEvent(ddlEvent) + require.NoError(t, err) + + dec, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, messageType) + + _, err = dec.NextDDLEvent() + require.NoError(t, err) + + err = enc.AppendRowChangedEvent(ctx, "", row, func() {}) + require.NoError(t, err) + + messages := enc.Build() + require.Len(t, messages, 1) + + err = dec.AddKeyValue(messages[0].Key, messages[0].Value) + require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, messageType) + + decodedRow, err := dec.NextRowChangedEvent() + require.NoError(t, err) + + decodedColumns := make(map[string]*model.ColumnData, len(decodedRow.Columns)) + for _, column := range decodedRow.Columns { + colName := decodedRow.TableInfo.ForceGetColumnName(column.ColumnID) + decodedColumns[colName] = column + } + for _, expected := range row.Columns { + colName := row.TableInfo.ForceGetColumnName(expected.ColumnID) + decoded, ok := decodedColumns[colName] + require.True(t, ok) + require.EqualValues(t, expected.Value, decoded.Value) + } + 
} +} + +func TestE2EPartitionTableDMLBeforeDDL(t *testing.T) { + helper := entry.NewSchemaTestHelper(t) + defer helper.Close() + + helper.Tk().MustExec("use test") + + createPartitionTableDDL := helper.DDL2Event(`create table test.t(a int primary key, b int) partition by range (a) ( + partition p0 values less than (10), + partition p1 values less than (20), + partition p2 values less than MAXVALUE)`) + require.NotNil(t, createPartitionTableDDL) + + insertEvent := helper.DML2Event(`insert into test.t values (1, 1)`, "test", "t", "p0") + require.NotNil(t, insertEvent) + + insertEvent1 := helper.DML2Event(`insert into test.t values (11, 11)`, "test", "t", "p1") + require.NotNil(t, insertEvent1) + + insertEvent2 := helper.DML2Event(`insert into test.t values (21, 21)`, "test", "t", "p2") + require.NotNil(t, insertEvent2) + + events := []*model.RowChangedEvent{insertEvent, insertEvent1, insertEvent2} + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + builder, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + + enc := builder.Build() + + decoder, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + codecConfig.EncodingFormat = format + for _, event := range events { + err = enc.AppendRowChangedEvent(ctx, "", event, nil) + require.NoError(t, err) + message := enc.Build()[0] + + err = decoder.AddKeyValue(message.Key, message.Value) + require.NoError(t, err) + tp, hasNext, err := decoder.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, tp) + + decodedEvent, err := decoder.NextRowChangedEvent() + require.NoError(t, err) + require.Nil(t, decodedEvent) + } + + message, err := enc.EncodeDDLEvent(createPartitionTableDDL) + require.NoError(t, err) + + err = decoder.AddKeyValue(message.Key, 
message.Value) + require.NoError(t, err) + tp, hasNext, err := decoder.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, tp) + + decodedDDL, err := decoder.NextDDLEvent() + require.NoError(t, err) + require.NotNil(t, decodedDDL) + + cachedEvents := decoder.GetCachedEvents() + for idx, decodedRow := range cachedEvents { + require.NotNil(t, decodedRow) + require.NotNil(t, decodedRow.TableInfo) + require.Equal(t, decodedRow.GetTableID(), events[idx].GetTableID()) + } + } +} + +func TestEncodeDMLBeforeDDL(t *testing.T) { + helper := entry.NewSchemaTestHelper(t) + defer helper.Close() + + sql := `create table test.t(a int primary key, b int)` + ddlEvent := helper.DDL2Event(sql) + + sql = `insert into test.t values (1, 2)` + row := helper.DML2Event(sql, "test", "t") + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + err = enc.AppendRowChangedEvent(ctx, "", row, func() {}) + require.NoError(t, err) + + messages := enc.Build() + require.Len(t, messages, 1) + + dec, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + err = dec.AddKeyValue(messages[0].Key, messages[0].Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, messageType) + + decodedRow, err := dec.NextRowChangedEvent() + require.NoError(t, err) + require.Nil(t, decodedRow) + + m, err := enc.EncodeDDLEvent(ddlEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, messageType) + + event, err := dec.NextDDLEvent() + require.NoError(t, err) + require.NotNil(t, event) + + cachedEvents := dec.GetCachedEvents() + for _, 
decodedRow = range cachedEvents { + require.NotNil(t, decodedRow) + require.NotNil(t, decodedRow.TableInfo) + require.Equal(t, decodedRow.TableInfo.ID, event.TableInfo.ID) + } +} + +func TestEncodeBootstrapEvent(t *testing.T) { + helper := entry.NewSchemaTestHelper(t) + defer helper.Close() + + sql := `create table test.t( + id int, + name varchar(255) not null, + age int, + email varchar(255) not null, + primary key(id, name), + key idx_name_email(name, email))` + ddlEvent := helper.DDL2Event(sql) + ddlEvent.IsBootstrap = true + + sql = `insert into test.t values (1, "jack", 23, "jack@abc.com")` + row := helper.DML2Event(sql, "test", "t") + + helper.Tk().MustExec("drop table test.t") + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + for _, compressionType := range []string{ + compression.None, + compression.Snappy, + compression.LZ4, + } { + codecConfig.LargeMessageHandle.LargeMessageHandleCompression = compressionType + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + m, err := enc.EncodeDDLEvent(ddlEvent) + require.NoError(t, err) + + dec, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, messageType) + require.NotEqual(t, 0, dec.msg.BuildTs) + + event, err := dec.NextDDLEvent() + require.NoError(t, err) + require.Equal(t, ddlEvent.TableInfo.TableName.TableID, event.TableInfo.TableName.TableID) + // Bootstrap event doesn't have query + require.Equal(t, "", event.Query) + require.Equal(t, len(ddlEvent.TableInfo.Columns), len(event.TableInfo.Columns)) + require.Equal(t, len(ddlEvent.TableInfo.Indices), 
len(event.TableInfo.Indices)) + + item := dec.memo.Read(ddlEvent.TableInfo.TableName.Schema, + ddlEvent.TableInfo.TableName.Table, ddlEvent.TableInfo.UpdateTS) + require.NotNil(t, item) + + err = enc.AppendRowChangedEvent(context.Background(), "", row, func() {}) + require.NoError(t, err) + + messages := enc.Build() + require.Len(t, messages, 1) + + err = dec.AddKeyValue(messages[0].Key, messages[0].Value) + require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, messageType) + require.NotEqual(t, 0, dec.msg.BuildTs) + + decodedRow, err := dec.NextRowChangedEvent() + require.NoError(t, err) + require.Equal(t, decodedRow.CommitTs, row.CommitTs) + require.Equal(t, decodedRow.TableInfo.GetSchemaName(), row.TableInfo.GetSchemaName()) + require.Equal(t, decodedRow.TableInfo.GetTableName(), row.TableInfo.GetTableName()) + require.Nil(t, decodedRow.PreColumns) + } + } +} + +func TestEncodeLargeEventsNormal(t *testing.T) { + ddlEvent, insertEvent, updateEvent, deleteEvent := utils.NewLargeEvent4Test(t, config.GetDefaultReplicaConfig()) + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + for _, compressionType := range []string{ + compression.None, + compression.Snappy, + compression.LZ4, + } { + codecConfig.LargeMessageHandle.LargeMessageHandleCompression = compressionType + + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + dec, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + m, err := enc.EncodeDDLEvent(ddlEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + 
require.Equal(t, model.MessageTypeDDL, messageType) + + obtainedDDL, err := dec.NextDDLEvent() + require.NoError(t, err) + require.NotNil(t, obtainedDDL) + + obtainedDefaultValues := make(map[string]interface{}, len(obtainedDDL.TableInfo.Columns)) + for _, col := range obtainedDDL.TableInfo.Columns { + obtainedDefaultValues[col.Name.O] = col.GetDefaultValue() + switch col.GetType() { + case mysql.TypeFloat, mysql.TypeDouble: + require.Equal(t, 0, col.GetDecimal()) + default: + } + } + for _, col := range ddlEvent.TableInfo.Columns { + expected := col.GetDefaultValue() + obtained := obtainedDefaultValues[col.Name.O] + require.Equal(t, expected, obtained) + } + + for _, event := range []*model.RowChangedEvent{insertEvent, updateEvent, deleteEvent} { + err = enc.AppendRowChangedEvent(ctx, "", event, func() {}) + require.NoError(t, err) + + messages := enc.Build() + require.Len(t, messages, 1) + + err = dec.AddKeyValue(messages[0].Key, messages[0].Value) + require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, messageType) + + if event.IsDelete() { + require.Equal(t, dec.msg.Type, DMLTypeDelete) + } else if event.IsUpdate() { + require.Equal(t, dec.msg.Type, DMLTypeUpdate) + } else { + require.Equal(t, dec.msg.Type, DMLTypeInsert) + } + + decodedRow, err := dec.NextRowChangedEvent() + require.NoError(t, err) + + require.Equal(t, decodedRow.CommitTs, event.CommitTs) + require.Equal(t, decodedRow.TableInfo.GetSchemaName(), event.TableInfo.GetSchemaName()) + require.Equal(t, decodedRow.TableInfo.GetTableName(), event.TableInfo.GetTableName()) + require.Equal(t, decodedRow.GetTableID(), event.GetTableID()) + + decodedColumns := make(map[string]*model.ColumnData, len(decodedRow.Columns)) + for _, column := range decodedRow.Columns { + colName := decodedRow.TableInfo.ForceGetColumnName(column.ColumnID) + decodedColumns[colName] = column + } + for _, col := range 
event.Columns { + colName := event.TableInfo.ForceGetColumnName(col.ColumnID) + decoded, ok := decodedColumns[colName] + require.True(t, ok) + switch v := col.Value.(type) { + case types.VectorFloat32: + require.EqualValues(t, v.String(), decoded.Value) + default: + require.EqualValues(t, v, decoded.Value) + } + } + + decodedPreviousColumns := make(map[string]*model.ColumnData, len(decodedRow.PreColumns)) + for _, column := range decodedRow.PreColumns { + colName := decodedRow.TableInfo.ForceGetColumnName(column.ColumnID) + decodedPreviousColumns[colName] = column + } + for _, col := range event.PreColumns { + colName := event.TableInfo.ForceGetColumnName(col.ColumnID) + decoded, ok := decodedPreviousColumns[colName] + require.True(t, ok) + switch v := col.Value.(type) { + case types.VectorFloat32: + require.EqualValues(t, v.String(), decoded.Value) + default: + require.EqualValues(t, v, decoded.Value) + } + } + } + } + } +} + +func TestDDLMessageTooLarge(t *testing.T) { + ddlEvent, _, _, _ := utils.NewLargeEvent4Test(t, config.GetDefaultReplicaConfig()) + + codecConfig := common.NewConfig(config.ProtocolSimple) + codecConfig.MaxMessageBytes = 100 + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + b, err := NewBuilder(context.Background(), codecConfig) + require.NoError(t, err) + enc := b.Build() + + _, err = enc.EncodeDDLEvent(ddlEvent) + require.ErrorIs(t, err, errors.ErrMessageTooLarge) + } +} + +func TestDMLMessageTooLarge(t *testing.T) { + _, insertEvent, _, _ := utils.NewLargeEvent4Test(t, config.GetDefaultReplicaConfig()) + + codecConfig := common.NewConfig(config.ProtocolSimple) + codecConfig.MaxMessageBytes = 50 + + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + + for _, handle := range []string{ + config.LargeMessageHandleOptionNone, + 
config.LargeMessageHandleOptionHandleKeyOnly, + config.LargeMessageHandleOptionClaimCheck, + } { + codecConfig.LargeMessageHandle.LargeMessageHandleOption = handle + if handle == config.LargeMessageHandleOptionClaimCheck { + codecConfig.LargeMessageHandle.ClaimCheckStorageURI = "file:///tmp/simple-claim-check" + } + b, err := NewBuilder(context.Background(), codecConfig) + require.NoError(t, err) + enc := b.Build() + + err = enc.AppendRowChangedEvent(context.Background(), "", insertEvent, func() {}) + require.ErrorIs(t, err, errors.ErrMessageTooLarge, string(format), handle) + } + } +} + +func TestLargerMessageHandleClaimCheck(t *testing.T) { + ddlEvent, _, updateEvent, _ := utils.NewLargeEvent4Test(t, config.GetDefaultReplicaConfig()) + + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + codecConfig.LargeMessageHandle.LargeMessageHandleOption = config.LargeMessageHandleOptionClaimCheck + + codecConfig.LargeMessageHandle.ClaimCheckStorageURI = "unsupported:///" + b, err := NewBuilder(ctx, codecConfig) + require.Error(t, err) + require.Nil(t, b) + + badDec, err := NewDecoder(ctx, codecConfig, nil) + require.Error(t, err) + require.Nil(t, badDec) + + codecConfig.LargeMessageHandle.ClaimCheckStorageURI = "file:///tmp/simple-claim-check" + for _, rawValue := range []bool{false, true} { + codecConfig.LargeMessageHandle.ClaimCheckRawValue = rawValue + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatAvro, + common.EncodingFormatJSON, + } { + codecConfig.EncodingFormat = format + for _, compressionType := range []string{ + compression.None, + compression.Snappy, + compression.LZ4, + } { + codecConfig.MaxMessageBytes = config.DefaultMaxMessageBytes + codecConfig.LargeMessageHandle.LargeMessageHandleCompression = compressionType + + b, err = NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + m, err := enc.EncodeDDLEvent(ddlEvent) + require.NoError(t, err) + + dec, err := 
NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeDDL, messageType) + + _, err = dec.NextDDLEvent() + require.NoError(t, err) + + enc.(*encoder).config.MaxMessageBytes = 500 + err = enc.AppendRowChangedEvent(ctx, "", updateEvent, func() {}) + require.NoError(t, err) + + claimCheckLocationM := enc.Build()[0] + + dec.config.MaxMessageBytes = 500 + err = dec.AddKeyValue(claimCheckLocationM.Key, claimCheckLocationM.Value) + require.NoError(t, err) + + messageType, hasNext, err = dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, messageType) + require.NotEqual(t, "", dec.msg.ClaimCheckLocation) + + decodedRow, err := dec.NextRowChangedEvent() + require.NoError(t, err) + + require.Equal(t, decodedRow.CommitTs, updateEvent.CommitTs) + require.Equal(t, decodedRow.TableInfo.GetSchemaName(), updateEvent.TableInfo.GetSchemaName()) + require.Equal(t, decodedRow.TableInfo.GetTableName(), updateEvent.TableInfo.GetTableName()) + + decodedColumns := make(map[string]*model.ColumnData, len(decodedRow.Columns)) + for _, column := range decodedRow.Columns { + colName := decodedRow.TableInfo.ForceGetColumnName(column.ColumnID) + decodedColumns[colName] = column + } + for _, col := range updateEvent.Columns { + colName := updateEvent.TableInfo.ForceGetColumnName(col.ColumnID) + decoded, ok := decodedColumns[colName] + require.True(t, ok) + switch v := col.Value.(type) { + case types.VectorFloat32: + require.EqualValues(t, v.String(), decoded.Value, colName) + default: + require.EqualValues(t, v, decoded.Value, colName) + } + } + + for _, column := range decodedRow.PreColumns { + colName := decodedRow.TableInfo.ForceGetColumnName(column.ColumnID) + decodedColumns[colName] = column + } + for _, col := range 
updateEvent.PreColumns { + colName := updateEvent.TableInfo.ForceGetColumnName(col.ColumnID) + decoded, ok := decodedColumns[colName] + require.True(t, ok) + switch v := col.Value.(type) { + case types.VectorFloat32: + require.EqualValues(t, v.String(), decoded.Value, colName) + default: + require.EqualValues(t, v, decoded.Value, colName) + } + } + } + } + } +} + +func TestLargeMessageHandleKeyOnly(t *testing.T) { + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + mock.MatchExpectationsInOrder(false) + require.NoError(t, err) + + ddlEvent, insertEvent, updateEvent, deleteEvent := utils.NewLargeEvent4Test(t, config.GetDefaultReplicaConfig()) + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + codecConfig.LargeMessageHandle.LargeMessageHandleOption = config.LargeMessageHandleOptionHandleKeyOnly + + badDec, err := NewDecoder(ctx, codecConfig, nil) + require.Error(t, err) + require.Nil(t, badDec) + + events := []*model.RowChangedEvent{ + insertEvent, + updateEvent, + deleteEvent, + } + + for _, format := range []common.EncodingFormatType{ + common.EncodingFormatJSON, + common.EncodingFormatAvro, + } { + codecConfig.EncodingFormat = format + for _, compressionType := range []string{ + compression.None, + compression.Snappy, + compression.LZ4, + } { + codecConfig.MaxMessageBytes = config.DefaultMaxMessageBytes + codecConfig.LargeMessageHandle.LargeMessageHandleCompression = compressionType + + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + dec, err := NewDecoder(ctx, codecConfig, db) + require.NoError(t, err) + + enc.(*encoder).config.MaxMessageBytes = 500 + dec.config.MaxMessageBytes = 500 + for _, event = range events { + err = enc.AppendRowChangedEvent(ctx, "", event, func() {}) + require.NoError(t, err) + + messages := enc.Build() + require.Len(t, messages, 1) + + err = dec.AddKeyValue(messages[0].Key, messages[0].Value) + require.NoError(t, err) + + 
messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, model.MessageTypeRow, messageType) + require.True(t, dec.msg.HandleKeyOnly) + + obtainedValues := make(map[string]interface{}, len(dec.msg.Data)) + for name, value := range dec.msg.Data { + obtainedValues[name] = value + } + for _, col := range event.Columns { + colName := event.TableInfo.ForceGetColumnName(col.ColumnID) + colFlag := event.TableInfo.ForceGetColumnFlagType(col.ColumnID) + if colFlag.IsHandleKey() { + require.Contains(t, dec.msg.Data, colName) + obtained := obtainedValues[colName] + switch v := obtained.(type) { + case string: + var err error + obtained, err = strconv.ParseInt(v, 10, 64) + require.NoError(t, err) + } + require.EqualValues(t, col.Value, obtained) + } else { + require.NotContains(t, dec.msg.Data, colName) + } + } + + clear(obtainedValues) + for name, value := range dec.msg.Old { + obtainedValues[name] = value + } + for _, col := range event.PreColumns { + colName := event.TableInfo.ForceGetColumnName(col.ColumnID) + colFlag := event.TableInfo.ForceGetColumnFlagType(col.ColumnID) + if colFlag.IsHandleKey() { + require.Contains(t, dec.msg.Old, colName) + obtained := obtainedValues[colName] + switch v := obtained.(type) { + case string: + var err error + obtained, err = strconv.ParseInt(v, 10, 64) + require.NoError(t, err) + } + require.EqualValues(t, col.Value, obtained) + } else { + require.NotContains(t, dec.msg.Data, colName) + } + } + + decodedRow, err := dec.NextRowChangedEvent() + require.NoError(t, err) + require.Nil(t, decodedRow) + } + + enc.(*encoder).config.MaxMessageBytes = config.DefaultMaxMessageBytes + dec.config.MaxMessageBytes = config.DefaultMaxMessageBytes + m, err := enc.EncodeDDLEvent(ddlEvent) + require.NoError(t, err) + + err = dec.AddKeyValue(m.Key, m.Value) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.NoError(t, err) + require.True(t, hasNext) + require.Equal(t, 
model.MessageTypeDDL, messageType) + + for _, event = range events { + mock.ExpectQuery("SELECT @@global.time_zone"). + WillReturnRows(mock.NewRows([]string{""}).AddRow("SYSTEM")) + + query := fmt.Sprintf("set @@tidb_snapshot=%v", event.CommitTs) + mock.ExpectExec(query).WillReturnResult(driver.ResultNoRows) + + query = fmt.Sprintf("set @@tidb_snapshot=%v", event.CommitTs-1) + mock.ExpectExec(query).WillReturnResult(driver.ResultNoRows) + + names, values := utils.LargeColumnKeyValues() + mock.ExpectQuery("select * from test.t where tu1 = 127"). + WillReturnRows(mock.NewRows(names).AddRow(values...)) + + mock.ExpectQuery("select * from test.t where tu1 = 127"). + WillReturnRows(mock.NewRows(names).AddRow(values...)) + + } + _, err = dec.NextDDLEvent() + require.NoError(t, err) + + decodedRows := dec.GetCachedEvents() + for idx, decodedRow := range decodedRows { + event := events[idx] + + require.Equal(t, decodedRow.CommitTs, event.CommitTs) + require.Equal(t, decodedRow.TableInfo.GetSchemaName(), event.TableInfo.GetSchemaName()) + require.Equal(t, decodedRow.TableInfo.GetTableName(), event.TableInfo.GetTableName()) + + decodedColumns := make(map[string]*model.ColumnData, len(decodedRow.Columns)) + for _, column := range decodedRow.Columns { + colName := decodedRow.TableInfo.ForceGetColumnName(column.ColumnID) + decodedColumns[colName] = column + } + for _, col := range event.Columns { + colName := event.TableInfo.ForceGetColumnName(col.ColumnID) + decoded, ok := decodedColumns[colName] + require.True(t, ok) + colInfo := event.TableInfo.ForceGetColumnFlagType(col.ColumnID) + if colInfo.IsBinary() { + switch v := col.Value.(type) { + case []byte: + length := len(decoded.Value.([]uint8)) + require.Equal(t, v[:length], decoded.Value, colName) + case types.VectorFloat32: + require.Equal(t, v.String(), decoded.Value, colName) + default: + require.Equal(t, col.Value, decoded.Value, colName) + } + } else { + switch v := col.Value.(type) { + case []byte: + require.Equal(t, 
string(v), decoded.Value, colName) + default: + require.Equal(t, v, decoded.Value, colName) + } + } + } + + clear(decodedColumns) + for _, column := range decodedRow.PreColumns { + colName := decodedRow.TableInfo.ForceGetColumnName(column.ColumnID) + decodedColumns[colName] = column + } + for _, col := range event.PreColumns { + colName := event.TableInfo.ForceGetColumnName(col.ColumnID) + decoded, ok := decodedColumns[colName] + require.True(t, ok) + colInfo := event.TableInfo.ForceGetColumnFlagType(col.ColumnID) + if colInfo.IsBinary() { + switch v := col.Value.(type) { + case []byte: + length := len(decoded.Value.([]uint8)) + require.Equal(t, v[:length], decoded.Value, colName) + case types.VectorFloat32: + require.Equal(t, v.String(), decoded.Value, colName) + default: + require.Equal(t, col.Value, decoded.Value, colName) + } + } else { + switch v := col.Value.(type) { + case types.VectorFloat32: + require.Equal(t, v.String(), decoded.Value, colName) + case []byte: + require.Equal(t, string(v), decoded.Value, colName) + default: + require.Equal(t, v, decoded.Value, colName) + } + } + } + } + } + } +} + +func TestDecoder(t *testing.T) { + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + decoder, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + require.NotNil(t, decoder) + + messageType, hasNext, err := decoder.HasNext() + require.NoError(t, err) + require.False(t, hasNext) + require.Equal(t, model.MessageTypeUnknown, messageType) + + ddl, err := decoder.NextDDLEvent() + require.ErrorIs(t, err, errors.ErrCodecDecode) + require.Nil(t, ddl) + + decoder.msg = new(message) + checkpoint, err := decoder.NextResolvedEvent() + require.ErrorIs(t, err, errors.ErrCodecDecode) + require.Equal(t, uint64(0), checkpoint) + + event, err := decoder.NextRowChangedEvent() + require.ErrorIs(t, err, errors.ErrCodecDecode) + require.Nil(t, event) + + decoder.value = []byte("invalid") + err = decoder.AddKeyValue(nil, nil) + 
require.ErrorIs(t, err, errors.ErrCodecDecode) +} + +func TestMarshallerError(t *testing.T) { + ctx := context.Background() + codecConfig := common.NewConfig(config.ProtocolSimple) + + b, err := NewBuilder(ctx, codecConfig) + require.NoError(t, err) + enc := b.Build() + + mockMarshaller := mock_simple.NewMockmarshaller(gomock.NewController(t)) + enc.(*encoder).marshaller = mockMarshaller + + mockMarshaller.EXPECT().MarshalCheckpoint(gomock.Any()).Return(nil, errors.ErrEncodeFailed) + _, err = enc.EncodeCheckpointEvent(123) + require.ErrorIs(t, err, errors.ErrEncodeFailed) + + mockMarshaller.EXPECT().MarshalDDLEvent(gomock.Any()).Return(nil, errors.ErrEncodeFailed) + _, err = enc.EncodeDDLEvent(&model.DDLEvent{}) + require.ErrorIs(t, err, errors.ErrEncodeFailed) + + mockMarshaller.EXPECT().MarshalRowChangedEvent(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.ErrEncodeFailed) + err = enc.AppendRowChangedEvent(ctx, "", &model.RowChangedEvent{}, func() {}) + require.ErrorIs(t, err, errors.ErrEncodeFailed) + + dec, err := NewDecoder(ctx, codecConfig, nil) + require.NoError(t, err) + dec.marshaller = mockMarshaller + + mockMarshaller.EXPECT().Unmarshal(gomock.Any(), gomock.Any()).Return(errors.ErrDecodeFailed) + err = dec.AddKeyValue([]byte("key"), []byte("value")) + require.NoError(t, err) + + messageType, hasNext, err := dec.HasNext() + require.ErrorIs(t, err, errors.ErrDecodeFailed) + require.False(t, hasNext) + require.Equal(t, model.MessageTypeUnknown, messageType) +} diff --git a/pkg/sink/codec/simple/message.go b/pkg/sink/codec/simple/message.go new file mode 100644 index 00000000000..6ab764864ac --- /dev/null +++ b/pkg/sink/codec/simple/message.go @@ -0,0 +1,824 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package simple + +import ( + "database/sql" + "encoding/base64" + "fmt" + "sort" + "strconv" + "time" + + "github.com/pingcap/log" + timodel "github.com/pingcap/tidb/pkg/meta/model" + pmodel "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/types" + tiTypes "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tiflow/cdc/model" + cerror "github.com/pingcap/tiflow/pkg/errors" + "github.com/pingcap/tiflow/pkg/integrity" + "github.com/pingcap/tiflow/pkg/sink/codec/common" + "github.com/pingcap/tiflow/pkg/sink/codec/utils" + "go.uber.org/zap" +) + +const ( + defaultVersion = 1 +) + +// MessageType is the type of the message. +type MessageType string + +const ( + // MessageTypeWatermark is the type of the watermark event. + MessageTypeWatermark MessageType = "WATERMARK" + // MessageTypeBootstrap is the type of the bootstrap event. + MessageTypeBootstrap MessageType = "BOOTSTRAP" + // MessageTypeDDL is the type of the ddl event. + MessageTypeDDL MessageType = "DDL" + // MessageTypeDML is the type of the row event. + MessageTypeDML MessageType = "DML" +) + +// DML Message types +const ( + // DMLTypeInsert is the type of the insert event. + DMLTypeInsert MessageType = "INSERT" + // DMLTypeUpdate is the type of the update event. + DMLTypeUpdate MessageType = "UPDATE" + // DMLTypeDelete is the type of the delete event. 
+ DMLTypeDelete MessageType = "DELETE" +) + +// DDL message types +const ( + DDLTypeCreate MessageType = "CREATE" + DDLTypeRename MessageType = "RENAME" + DDLTypeCIndex MessageType = "CINDEX" + DDLTypeDIndex MessageType = "DINDEX" + DDLTypeErase MessageType = "ERASE" + DDLTypeTruncate MessageType = "TRUNCATE" + DDLTypeAlter MessageType = "ALTER" + DDLTypeQuery MessageType = "QUERY" +) + +func getDDLType(t timodel.ActionType) MessageType { + switch t { + case timodel.ActionCreateTable: + return DDLTypeCreate + case timodel.ActionRenameTable, timodel.ActionRenameTables: + return DDLTypeRename + case timodel.ActionAddIndex, timodel.ActionAddForeignKey, timodel.ActionAddPrimaryKey: + return DDLTypeCIndex + case timodel.ActionDropIndex, timodel.ActionDropForeignKey, timodel.ActionDropPrimaryKey: + return DDLTypeDIndex + case timodel.ActionDropTable: + return DDLTypeErase + case timodel.ActionTruncateTable: + return DDLTypeTruncate + case timodel.ActionAddColumn, timodel.ActionDropColumn, timodel.ActionModifyColumn, timodel.ActionRebaseAutoID, + timodel.ActionSetDefaultValue, timodel.ActionModifyTableComment, timodel.ActionRenameIndex, timodel.ActionAddTablePartition, + timodel.ActionDropTablePartition, timodel.ActionModifyTableCharsetAndCollate, timodel.ActionTruncateTablePartition, + timodel.ActionAlterIndexVisibility, timodel.ActionMultiSchemaChange, timodel.ActionReorganizePartition, + timodel.ActionAlterTablePartitioning, timodel.ActionRemovePartitioning: + return DDLTypeAlter + default: + return DDLTypeQuery + } +} + +// columnSchema is the schema of the column. 
+type columnSchema struct {
+	Name     string      `json:"name"`
+	DataType dataType    `json:"dataType"`
+	Nullable bool        `json:"nullable"`
+	Default  interface{} `json:"default"`
+}
+
+type dataType struct {
+	// MySQLType represents the basic MySQL type
+	MySQLType string `json:"mysqlType"`
+
+	Charset string `json:"charset"`
+	Collate string `json:"collate"`
+
+	// Length represents the size in bytes of the field
+	Length int `json:"length,omitempty"`
+	// Decimal represents the decimal length of the field
+	Decimal int `json:"decimal,omitempty"`
+	// Elements represents the element list for enum and set types.
+	Elements []string `json:"elements,omitempty"`
+
+	Unsigned bool `json:"unsigned,omitempty"`
+	Zerofill bool `json:"zerofill,omitempty"`
+}
+
+// newColumnSchema converts from TiDB ColumnInfo to columnSchema.
+func newColumnSchema(col *timodel.ColumnInfo) *columnSchema {
+	tp := dataType{
+		MySQLType: types.TypeToStr(col.GetType(), col.GetCharset()),
+		Charset:   col.GetCharset(),
+		Collate:   col.GetCollate(),
+		Length:    col.GetFlen(),
+		Elements:  col.GetElems(),
+		Unsigned:  mysql.HasUnsignedFlag(col.GetFlag()),
+		Zerofill:  mysql.HasZerofillFlag(col.GetFlag()),
+	}
+
+	switch col.GetType() {
+	// The decimal of Float and Double is always -1; do not encode it into the schema.
+	case mysql.TypeFloat, mysql.TypeDouble:
+	default:
+		tp.Decimal = col.GetDecimal()
+	}
+
+	defaultValue := col.GetDefaultValue()
+	if defaultValue != nil && col.GetType() == mysql.TypeBit {
+		defaultValue = common.MustBinaryLiteralToInt([]byte(defaultValue.(string)))
+	}
+	return &columnSchema{
+		Name:     col.Name.O,
+		DataType: tp,
+		Nullable: !mysql.HasNotNullFlag(col.GetFlag()),
+		Default:  defaultValue,
+	}
+}
+
+// newTiColumnInfo uses columnSchema and IndexSchema to construct a TiDB column info.
+func newTiColumnInfo( + column *columnSchema, colID int64, indexes []*IndexSchema, +) *timodel.ColumnInfo { + col := new(timodel.ColumnInfo) + col.ID = colID + col.Name = pmodel.NewCIStr(column.Name) + + col.FieldType = *types.NewFieldType(types.StrToType(column.DataType.MySQLType)) + col.SetCharset(column.DataType.Charset) + col.SetCollate(column.DataType.Collate) + if column.DataType.Unsigned { + col.AddFlag(mysql.UnsignedFlag) + } + if column.DataType.Zerofill { + col.AddFlag(mysql.ZerofillFlag) + } + col.SetFlen(column.DataType.Length) + col.SetDecimal(column.DataType.Decimal) + col.SetElems(column.DataType.Elements) + + if utils.IsBinaryMySQLType(column.DataType.MySQLType) { + col.AddFlag(mysql.BinaryFlag) + } + + if !column.Nullable { + col.AddFlag(mysql.NotNullFlag) + } + + defaultValue := column.Default + if defaultValue != nil && col.GetType() == mysql.TypeBit { + switch v := defaultValue.(type) { + case float64: + byteSize := (col.GetFlen() + 7) >> 3 + defaultValue = tiTypes.NewBinaryLiteralFromUint(uint64(v), byteSize) + defaultValue = defaultValue.(tiTypes.BinaryLiteral).ToString() + default: + } + } + + for _, index := range indexes { + for _, name := range index.Columns { + if name == column.Name { + if index.Primary { + col.AddFlag(mysql.PriKeyFlag) + } else if index.Unique { + col.AddFlag(mysql.UniqueKeyFlag) + } else { + col.AddFlag(mysql.MultipleKeyFlag) + } + } + } + } + + err := col.SetDefaultValue(defaultValue) + if err != nil { + log.Panic("set default value failed", zap.Any("column", col), zap.Any("default", defaultValue)) + } + return col +} + +// IndexSchema is the schema of the index. +type IndexSchema struct { + Name string `json:"name"` + Unique bool `json:"unique"` + Primary bool `json:"primary"` + Nullable bool `json:"nullable"` + Columns []string `json:"columns"` +} + +// newIndexSchema converts from TiDB IndexInfo to IndexSchema. 
+func newIndexSchema(index *timodel.IndexInfo, columns []*timodel.ColumnInfo) *IndexSchema {
+	indexSchema := &IndexSchema{
+		Name:    index.Name.O,
+		Unique:  index.Unique,
+		Primary: index.Primary,
+	}
+	for _, col := range index.Columns {
+		indexSchema.Columns = append(indexSchema.Columns, col.Name.O)
+		colInfo := columns[col.Offset]
+		// An index is not null only when all of its columns are not null.
+		if !mysql.HasNotNullFlag(colInfo.GetFlag()) {
+			indexSchema.Nullable = true
+		}
+	}
+	return indexSchema
+}
+
+// newTiIndexInfo converts an IndexSchema to a TiDB index info.
+func newTiIndexInfo(indexSchema *IndexSchema, columns []*timodel.ColumnInfo, indexID int64) *timodel.IndexInfo {
+	indexColumns := make([]*timodel.IndexColumn, len(indexSchema.Columns))
+	for i, col := range indexSchema.Columns {
+		var offset int
+		for idx, column := range columns {
+			if column.Name.O == col {
+				offset = idx
+				break
+			}
+		}
+		indexColumns[i] = &timodel.IndexColumn{
+			Name:   pmodel.NewCIStr(col),
+			Offset: offset,
+		}
+	}
+
+	return &timodel.IndexInfo{
+		ID:      indexID,
+		Name:    pmodel.NewCIStr(indexSchema.Name),
+		Columns: indexColumns,
+		Unique:  indexSchema.Unique,
+		Primary: indexSchema.Primary,
+	}
+}
+
+// TableSchema is the schema of the table.
+type TableSchema struct {
+	Schema  string          `json:"schema"`
+	Table   string          `json:"table"`
+	TableID int64           `json:"tableID"`
+	Version uint64          `json:"version"`
+	Columns []*columnSchema `json:"columns"`
+	Indexes []*IndexSchema  `json:"indexes"`
+}
+
+func newTableSchema(tableInfo *model.TableInfo) *TableSchema {
+	pkInIndexes := false
+	indexes := make([]*IndexSchema, 0, len(tableInfo.Indices))
+	for _, idx := range tableInfo.Indices {
+		index := newIndexSchema(idx, tableInfo.Columns)
+		if index.Primary {
+			pkInIndexes = true
+		}
+		indexes = append(indexes, index)
+	}
+
+	// Sometimes the primary key is not in the indexes; we need to find it manually.
+ if !pkInIndexes { + pkColumns := tableInfo.GetPrimaryKeyColumnNames() + if len(pkColumns) != 0 { + index := &IndexSchema{ + Name: "primary", + Nullable: false, + Primary: true, + Unique: true, + Columns: pkColumns, + } + indexes = append(indexes, index) + } + } + + sort.SliceStable(tableInfo.Columns, func(i, j int) bool { + return tableInfo.Columns[i].ID < tableInfo.Columns[j].ID + }) + + columns := make([]*columnSchema, 0, len(tableInfo.Columns)) + for _, col := range tableInfo.Columns { + colSchema := newColumnSchema(col) + columns = append(columns, colSchema) + } + + return &TableSchema{ + Schema: tableInfo.TableName.Schema, + Table: tableInfo.TableName.Table, + TableID: tableInfo.ID, + Version: tableInfo.UpdateTS, + Columns: columns, + Indexes: indexes, + } +} + +// newTableInfo converts from TableSchema to TableInfo. +func newTableInfo(m *TableSchema) *model.TableInfo { + var ( + database string + schemaVersion uint64 + ) + + tidbTableInfo := &timodel.TableInfo{} + if m != nil { + database = m.Schema + schemaVersion = m.Version + + tidbTableInfo.ID = m.TableID + tidbTableInfo.Name = pmodel.NewCIStr(m.Table) + tidbTableInfo.UpdateTS = m.Version + + nextMockID := int64(100) + for _, col := range m.Columns { + tiCol := newTiColumnInfo(col, nextMockID, m.Indexes) + nextMockID += 100 + tidbTableInfo.Columns = append(tidbTableInfo.Columns, tiCol) + } + + mockIndexID := int64(1) + for _, idx := range m.Indexes { + index := newTiIndexInfo(idx, tidbTableInfo.Columns, mockIndexID) + tidbTableInfo.Indices = append(tidbTableInfo.Indices, index) + mockIndexID += 1 + } + } + return model.WrapTableInfo(100, database, schemaVersion, tidbTableInfo) +} + +// newDDLEvent converts from message to DDLEvent. 
+func newDDLEvent(msg *message) *model.DDLEvent { + var ( + tableInfo *model.TableInfo + preTableInfo *model.TableInfo + ) + + tableInfo = newTableInfo(msg.TableSchema) + if msg.PreTableSchema != nil { + preTableInfo = newTableInfo(msg.PreTableSchema) + } + return &model.DDLEvent{ + StartTs: msg.CommitTs, + CommitTs: msg.CommitTs, + TableInfo: tableInfo, + PreTableInfo: preTableInfo, + Query: msg.SQL, + } +} + +// buildRowChangedEvent converts from message to RowChangedEvent. +func buildRowChangedEvent( + msg *message, tableInfo *model.TableInfo, enableRowChecksum bool, db *sql.DB, +) (*model.RowChangedEvent, error) { + result := &model.RowChangedEvent{ + CommitTs: msg.CommitTs, + PhysicalTableID: msg.TableID, + TableInfo: tableInfo, + Columns: decodeColumns(msg.Data, tableInfo), + PreColumns: decodeColumns(msg.Old, tableInfo), + } + + if enableRowChecksum && msg.Checksum != nil { + result.Checksum = &integrity.Checksum{ + Current: msg.Checksum.Current, + Previous: msg.Checksum.Previous, + Corrupted: msg.Checksum.Corrupted, + Version: msg.Checksum.Version, + } + + err := common.VerifyChecksum(result, db) + if err != nil || msg.Checksum.Corrupted { + log.Warn("consumer detect checksum corrupted", + zap.String("schema", msg.Schema), zap.String("table", msg.Table), zap.Error(err)) + return nil, cerror.ErrDecodeFailed.GenWithStackByArgs("checksum corrupted") + + } + } + + for _, col := range result.Columns { + adjustTimestampValue(col, tableInfo.ForceGetColumnInfo(col.ColumnID).FieldType) + } + for _, col := range result.PreColumns { + adjustTimestampValue(col, tableInfo.ForceGetColumnInfo(col.ColumnID).FieldType) + } + + return result, nil +} + +func adjustTimestampValue(column *model.ColumnData, flag types.FieldType) { + if flag.GetType() != mysql.TypeTimestamp { + return + } + if column.Value != nil { + var ts string + switch v := column.Value.(type) { + case map[string]string: + ts = v["value"] + case map[string]interface{}: + ts = v["value"].(string) + } + 
column.Value = ts + } +} + +func decodeColumns( + rawData map[string]interface{}, tableInfo *model.TableInfo, +) []*model.ColumnData { + if rawData == nil { + return nil + } + result := make([]*model.ColumnData, 0, len(tableInfo.Columns)) + for _, info := range tableInfo.Columns { + value, ok := rawData[info.Name.O] + if !ok { + log.Warn("cannot found the value for the column, "+ + "it must be a generated column and TiCDC does not replicate generated column value", + zap.String("column", info.Name.O)) + continue + } + columnID := tableInfo.ForceGetColumnIDByName(info.Name.O) + col := decodeColumn(value, columnID, &info.FieldType) + if col == nil { + log.Panic("cannot decode column", + zap.String("name", info.Name.O), zap.Any("data", value)) + } + + result = append(result, col) + } + return result +} + +type checksum struct { + Version int `json:"version"` + Corrupted bool `json:"corrupted"` + Current uint32 `json:"current"` + Previous uint32 `json:"previous"` +} + +type message struct { + Version int `json:"version"` + // Schema and Table is empty for the resolved ts event. + Schema string `json:"database,omitempty"` + Table string `json:"table,omitempty"` + TableID int64 `json:"tableID,omitempty"` + Type MessageType `json:"type"` + // SQL is only for the DDL event. + SQL string `json:"sql,omitempty"` + CommitTs uint64 `json:"commitTs"` + BuildTs int64 `json:"buildTs"` + // SchemaVersion is for the DML event. + SchemaVersion uint64 `json:"schemaVersion,omitempty"` + + // ClaimCheckLocation is only for the DML event. + ClaimCheckLocation string `json:"claimCheckLocation,omitempty"` + // HandleKeyOnly is only for the DML event. + HandleKeyOnly bool `json:"handleKeyOnly,omitempty"` + + // E2E checksum related fields, only set when enable checksum functionality. + Checksum *checksum `json:"checksum,omitempty"` + + // Data is available for the Insert and Update event. 
+ Data map[string]interface{} `json:"data,omitempty"` + // Old is available for the Update and Delete event. + Old map[string]interface{} `json:"old,omitempty"` + // TableSchema is for the DDL and Bootstrap event. + TableSchema *TableSchema `json:"tableSchema,omitempty"` + // PreTableSchema holds schema information before the DDL executed. + PreTableSchema *TableSchema `json:"preTableSchema,omitempty"` +} + +func newResolvedMessage(ts uint64) *message { + return &message{ + Version: defaultVersion, + Type: MessageTypeWatermark, + CommitTs: ts, + BuildTs: time.Now().UnixMilli(), + } +} + +func newBootstrapMessage(tableInfo *model.TableInfo) *message { + schema := newTableSchema(tableInfo) + msg := &message{ + Version: defaultVersion, + Type: MessageTypeBootstrap, + BuildTs: time.Now().UnixMilli(), + TableSchema: schema, + } + return msg +} + +func newDDLMessage(ddl *model.DDLEvent) *message { + var ( + schema *TableSchema + preSchema *TableSchema + ) + // the tableInfo maybe nil if the DDL is `drop database` + if ddl.TableInfo != nil && ddl.TableInfo.TableInfo != nil { + schema = newTableSchema(ddl.TableInfo) + } + // `PreTableInfo` may not exist for some DDL, such as `create table` + if ddl.PreTableInfo != nil && ddl.PreTableInfo.TableInfo != nil { + preSchema = newTableSchema(ddl.PreTableInfo) + } + msg := &message{ + Version: defaultVersion, + Type: getDDLType(ddl.Type), + CommitTs: ddl.CommitTs, + BuildTs: time.Now().UnixMilli(), + SQL: ddl.Query, + TableSchema: schema, + PreTableSchema: preSchema, + } + return msg +} + +func (a *jsonMarshaller) newDMLMessage( + event *model.RowChangedEvent, + onlyHandleKey bool, claimCheckFileName string, +) *message { + m := &message{ + Version: defaultVersion, + Schema: event.TableInfo.GetSchemaName(), + Table: event.TableInfo.GetTableName(), + TableID: event.GetTableID(), + CommitTs: event.CommitTs, + BuildTs: time.Now().UnixMilli(), + SchemaVersion: event.TableInfo.UpdateTS, + HandleKeyOnly: onlyHandleKey, + 
ClaimCheckLocation: claimCheckFileName, + } + if event.IsInsert() { + m.Type = DMLTypeInsert + m.Data = a.formatColumns(event.Columns, event.TableInfo, onlyHandleKey) + } else if event.IsDelete() { + m.Type = DMLTypeDelete + m.Old = a.formatColumns(event.PreColumns, event.TableInfo, onlyHandleKey) + } else if event.IsUpdate() { + m.Type = DMLTypeUpdate + m.Data = a.formatColumns(event.Columns, event.TableInfo, onlyHandleKey) + m.Old = a.formatColumns(event.PreColumns, event.TableInfo, onlyHandleKey) + } + if a.config.EnableRowChecksum && event.Checksum != nil { + m.Checksum = &checksum{ + Version: event.Checksum.Version, + Corrupted: event.Checksum.Corrupted, + Current: event.Checksum.Current, + Previous: event.Checksum.Previous, + } + } + + return m +} + +func (a *jsonMarshaller) formatColumns( + columns []*model.ColumnData, tableInfo *model.TableInfo, onlyHandleKey bool, +) map[string]interface{} { + result := make(map[string]interface{}, len(columns)) + colInfos := tableInfo.GetColInfosForRowChangedEvent() + for i, col := range columns { + if col != nil { + flag := tableInfo.ForceGetColumnFlagType(col.ColumnID) + if onlyHandleKey && !flag.IsHandleKey() { + continue + } + value := encodeValue(col.Value, colInfos[i].Ft, a.config.TimeZone.String()) + result[tableInfo.ForceGetColumnName(col.ColumnID)] = value + } + } + return result +} + +func (a *avroMarshaller) encodeValue4Avro( + value interface{}, ft *types.FieldType, +) (interface{}, string) { + if value == nil { + return nil, "null" + } + + switch ft.GetType() { + case mysql.TypeTimestamp: + return map[string]interface{}{ + "location": a.config.TimeZone.String(), + "value": value.(string), + }, "com.pingcap.simple.avro.Timestamp" + case mysql.TypeLonglong: + if mysql.HasUnsignedFlag(ft.GetFlag()) { + return map[string]interface{}{ + "value": int64(value.(uint64)), + }, "com.pingcap.simple.avro.UnsignedBigint" + } + } + + switch v := value.(type) { + case uint64: + return int64(v), "long" + case int64: + return 
v, "long" + case []byte: + if mysql.HasBinaryFlag(ft.GetFlag()) { + return v, "bytes" + } + return string(v), "string" + case float32: + return v, "float" + case float64: + return v, "double" + case string: + return v, "string" + case tiTypes.VectorFloat32: + return v.String(), "string" + default: + log.Panic("unexpected type for avro value", zap.Any("value", value)) + } + return value, "" +} + +func encodeValue( + value interface{}, ft *types.FieldType, location string, +) interface{} { + if value == nil { + return nil + } + var err error + switch ft.GetType() { + case mysql.TypeBit: + switch v := value.(type) { + case []uint8: + value = common.MustBinaryLiteralToInt(v) + default: + } + case mysql.TypeTimestamp: + var ts string + switch v := value.(type) { + case string: + ts = v + // the timestamp value maybe []uint8 if it's queried from upstream TiDB. + case []uint8: + ts = string(v) + } + return map[string]string{ + "location": location, + "value": ts, + } + case mysql.TypeEnum: + switch v := value.(type) { + case []uint8: + data := string(v) + var enum tiTypes.Enum + enum, err = tiTypes.ParseEnumName(ft.GetElems(), data, ft.GetCollate()) + value = enum.Value + } + case mysql.TypeSet: + switch v := value.(type) { + case []uint8: + data := string(v) + var set tiTypes.Set + set, err = tiTypes.ParseSetName(ft.GetElems(), data, ft.GetCollate()) + value = set.Value + } + default: + } + + if err != nil { + log.Panic("parse enum / set name failed", + zap.Any("elems", ft.GetElems()), zap.Any("name", value), zap.Error(err)) + } + + var result string + switch v := value.(type) { + case int64: + result = strconv.FormatInt(v, 10) + case uint64: + result = strconv.FormatUint(v, 10) + case float32: + result = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + result = strconv.FormatFloat(v, 'f', -1, 64) + case string: + result = v + case []byte: + if mysql.HasBinaryFlag(ft.GetFlag()) { + result = base64.StdEncoding.EncodeToString(v) + } else { + result = 
string(v) + } + case tiTypes.VectorFloat32: + result = v.String() + default: + result = fmt.Sprintf("%v", v) + } + + return result +} + +func decodeColumn(value interface{}, id int64, fieldType *types.FieldType) *model.ColumnData { + result := &model.ColumnData{ + ColumnID: id, + Value: value, + } + if value == nil { + return result + } + + var err error + if mysql.HasBinaryFlag(fieldType.GetFlag()) { + switch v := value.(type) { + case string: + value, err = base64.StdEncoding.DecodeString(v) + if err != nil { + return nil + } + default: + } + result.Value = value + return result + } + + switch fieldType.GetType() { + case mysql.TypeBit, mysql.TypeSet: + switch v := value.(type) { + // avro encoding, set is encoded as `int64`, bit encoded as `string` + // json encoding, set is encoded as `string`, bit encoded as `string` + case string: + value, err = strconv.ParseUint(v, 10, 64) + case int64: + value = uint64(v) + } + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeLong, mysql.TypeInt24: + switch v := value.(type) { + case string: + if mysql.HasUnsignedFlag(fieldType.GetFlag()) { + value, err = strconv.ParseUint(v, 10, 64) + } else { + value, err = strconv.ParseInt(v, 10, 64) + } + default: + value = v + } + case mysql.TypeYear: + switch v := value.(type) { + case string: + value, err = strconv.ParseInt(v, 10, 64) + default: + value = v + } + case mysql.TypeLonglong: + switch v := value.(type) { + case string: + if mysql.HasUnsignedFlag(fieldType.GetFlag()) { + value, err = strconv.ParseUint(v, 10, 64) + } else { + value, err = strconv.ParseInt(v, 10, 64) + } + case map[string]interface{}: + value = uint64(v["value"].(int64)) + default: + value = v + } + case mysql.TypeFloat: + switch v := value.(type) { + case string: + var val float64 + val, err = strconv.ParseFloat(v, 32) + value = float32(val) + default: + value = v + } + case mysql.TypeDouble: + switch v := value.(type) { + case string: + value, err = strconv.ParseFloat(v, 64) + default: + value = v + } + 
case mysql.TypeEnum: + // avro encoding, enum is encoded as `int64`, use it directly. + // json encoding, enum is encoded as `string` + switch v := value.(type) { + case string: + value, err = strconv.ParseUint(v, 10, 64) + } + default: + } + + if err != nil { + return nil + } + + result.Value = value + return result +}