diff --git a/README.md b/README.md index 20ec0f2..fdfb7a8 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Features: New in v3: * Supports ILP over HTTP using the same client semantics +* Supports n-dimensional arrays of doubles for QuestDB servers 9.0.0 and up Documentation is available [here](https://pkg.go.dev/github.com/questdb/go-questdb-client/v3). @@ -99,6 +100,95 @@ HTTP is the recommended transport to use. To connect via TCP, set the configurat // ... ``` +## N-dimensional arrays + +QuestDB server version 9.0.0 and newer supports n-dimensional arrays of double precision floating point numbers. +The Go client provides several methods to send arrays to QuestDB: + +### 1D Arrays + +```go +// Send a 1D array of doubles +values1D := []float64{1.1, 2.2, 3.3, 4.4} +err = sender. + Table("measurements"). + Symbol("sensor", "temp_probe_1"). + Float64Array1DColumn("readings", values1D). + AtNow(ctx) +``` + +### 2D Arrays + +```go +// Send a 2D array of doubles (must be rectangular) +values2D := [][]float64{ + {1.1, 2.2, 3.3}, + {4.4, 5.5, 6.6}, + {7.7, 8.8, 9.9}, +} +err = sender. + Table("matrix_data"). + Symbol("experiment", "test_001"). + Float64Array2DColumn("matrix", values2D). + AtNow(ctx) +``` + +### 3D Arrays + +```go +// Send a 3D array of doubles (must be regular cuboid shape) +values3D := [][][]float64{ + {{1.0, 2.0}, {3.0, 4.0}}, + {{5.0, 6.0}, {7.0, 8.0}}, +} +err = sender. + Table("tensor_data"). + Symbol("model", "neural_net_v1"). + Float64Array3DColumn("weights", values3D). + AtNow(ctx) +``` + +### N-dimensional Arrays + +For higher dimensions, use the `NewNDArray` function: + +```go +// Create a 2x3x4 array +arr, err := qdb.NewNDArray[float64](2, 3, 4) +if err != nil { + log.Fatal(err) +} + +// Fill with values +arr.Fill(1.5) + +// Or set individual values +arr.Set([]uint{0, 1, 2}, 42.0) + +err = sender. + Table("ndarray_data"). + Symbol("dataset", "training_batch_1"). + Float64ArrayNDColumn("features", arr). + AtNow(ctx) +``` + +The array data is sent over a new protocol version (2) that is auto-negotiated +when using HTTP(s), or can be specified explicitly via the ``protocol_version=2`` +parameter when using TCP(s). + +We recommend using HTTP(s), but here is an TCP example, should you need it: + +```go +sender, err := qdb.NewLineSender(ctx, + qdb.WithTcp(), + qdb.WithProtocolVersion(qdb.ProtocolVersion2)) +``` + +When using ``protocol_version=2`` (with either TCP(s) or HTTP(s)), the sender +will now also serialize ``float64`` (double-precision) columns as binary. +You might see a performance uplift if this is a dominant data type in your +ingestion workload. + ## Pooled Line Senders **Warning: Experimental feature designed for use with HTTP senders ONLY** diff --git a/buffer.go b/buffer.go index c5df6f2..f4dd1bf 100644 --- a/buffer.go +++ b/buffer.go @@ -26,12 +26,14 @@ package questdb import ( "bytes" + "encoding/binary" "errors" "fmt" "math" "math/big" "strconv" "time" + "unsafe" ) // errInvalidMsg indicates a failed attempt to construct an ILP @@ -39,6 +41,54 @@ import ( // chars found in table or column name. var errInvalidMsg = errors.New("invalid message") +type binaryCode byte + +const ( + arrayCode binaryCode = 14 + float64Code binaryCode = 16 +) + +var isLittleEndian = func() bool { + var i int32 = 0x01020304 + return *(*byte)(unsafe.Pointer(&i)) == 0x04 +}() + +// MaxArrayElements defines the maximum total number of elements of Array +const MaxArrayElements = (1 << 28) - 1 + +// writeFloat64Data optimally writes float64 slice data to buffer +// Uses batch memory copy on little-endian machines for better performance +func (b *buffer) writeFloat64Data(data []float64) { + if isLittleEndian && len(data) > 0 { + b.Write(unsafe.Slice((*byte)(unsafe.Pointer(&data[0])), len(data)*8)) + } else { + bytes := make([]byte, 8) + for _, val := range data { + binary.LittleEndian.PutUint64(bytes[0:], math.Float64bits(val)) + b.Write(bytes) + } + } +} + +func (b *buffer) writeInt32(val int32) { + if isLittleEndian { + // On little-endian machines, we can directly write the uint32 as bytes + b.Write((*[4]byte)(unsafe.Pointer(&val))[:]) + } else { + // On big-endian machines, use the standard conversion + data := make([]byte, 4) + binary.LittleEndian.PutUint32(data, uint32(val)) + b.Write(data) + } +} + +type arrayElemType byte + +const ( + arrayElemDouble arrayElemType = 10 + arrayElemNull = 33 +) + // buffer is a wrapper on top of bytes.Buffer. It extends the // original struct with methods for writing int64 and float64 // numbers without unnecessary allocations. @@ -90,6 +140,12 @@ func (b *buffer) ClearLastErr() { b.lastErr = nil } +func (b *buffer) SetLastErr(err error) { + if b.lastErr == nil { + b.lastErr = err + } +} + func (b *buffer) writeInt(i int64) { // We need up to 20 bytes to fit an int64, including a sign. var a [20]byte @@ -393,8 +449,8 @@ func (b *buffer) resetMsgFlags() { b.hasFields = false } -func (b *buffer) Messages() string { - return b.String() +func (b *buffer) Messages() []byte { + return b.Buffer.Bytes() } func (b *buffer) Table(name string) *buffer { @@ -517,6 +573,206 @@ func (b *buffer) Float64Column(name string, val float64) *buffer { return b } +func (b *buffer) Float64ColumnBinary(name string, val float64) *buffer { + if !b.prepareForField() { + return b + } + b.lastErr = b.writeColumnName(name) + if b.lastErr != nil { + return b + } + b.WriteByte('=') + // binary format flag + b.WriteByte('=') + b.WriteByte(byte(float64Code)) + if isLittleEndian { + b.Write((*[8]byte)(unsafe.Pointer(&val))[:]) + } else { + data := make([]byte, 8) + binary.LittleEndian.PutUint64(data, math.Float64bits(val)) + b.Write(data) + } + b.hasFields = true + return b +} + +func (b *buffer) Float64Array1DColumn(name string, values []float64) *buffer { + if !b.prepareForField() { + return b + } + b.lastErr = b.writeColumnName(name) + if b.lastErr != nil { + return b + } + if values == nil { + b.writeNullArray() + return b + } + + dim1 := len(values) + if dim1 > MaxArrayElements { + b.lastErr = fmt.Errorf("array size %d exceeds maximum limit %d", dim1, MaxArrayElements) + return b + } + b.writeFloat64ArrayHeader(1) + + // Write shape + b.writeInt32(int32(dim1)) + + // Write values + if len(values) > 0 { + b.writeFloat64Data(values) + } + + b.hasFields = true + return b +} + +func (b *buffer) Float64Array2DColumn(name string, values [][]float64) *buffer { + if !b.prepareForField() { + return b + } + b.lastErr = b.writeColumnName(name) + if b.lastErr != nil { + return b + } + + if values == nil { + b.writeNullArray() + return b + } + + // Validate array shape + dim1 := len(values) + var dim2 int + if dim1 > 0 { + dim2 = len(values[0]) + totalElements := product([]uint{uint(dim1), uint(dim2)}) + if totalElements > MaxArrayElements { + b.lastErr = fmt.Errorf("array size %d exceeds maximum limit %d", totalElements, MaxArrayElements) + return b + } + for i, row := range values { + if len(row) != dim2 { + b.lastErr = fmt.Errorf("irregular 2D array shape: row %d has length %d, expected %d", i, len(row), dim2) + return b + } + } + } + + b.writeFloat64ArrayHeader(2) + + // Write shape + b.writeInt32(int32(dim1)) + b.writeInt32(int32(dim2)) + + // Write values + for _, row := range values { + if len(row) > 0 { + b.writeFloat64Data(row) + } + } + + b.hasFields = true + return b +} + +func (b *buffer) Float64Array3DColumn(name string, values [][][]float64) *buffer { + if !b.prepareForField() { + return b + } + b.lastErr = b.writeColumnName(name) + if b.lastErr != nil { + return b + } + + if values == nil { + b.writeNullArray() + return b + } + + // Validate array shape + dim1 := len(values) + var dim2, dim3 int + if dim1 > 0 { + dim2 = len(values[0]) + if dim2 > 0 { + dim3 = len(values[0][0]) + } + totalElements := product([]uint{uint(dim1), uint(dim2), uint(dim3)}) + if totalElements > MaxArrayElements { + b.lastErr = fmt.Errorf("array size %d exceeds maximum limit %d", totalElements, MaxArrayElements) + return b + } + + for i, level1 := range values { + if len(level1) != dim2 { + b.lastErr = fmt.Errorf("irregular 3D array shape: level1[%d] has length %d, expected %d", i, len(level1), dim2) + return b + } + for j, level2 := range level1 { + if len(level2) != dim3 { + b.lastErr = fmt.Errorf("irregular 3D array shape: level2[%d][%d] has length %d, expected %d", i, j, len(level2), dim3) + return b + } + } + } + } + + b.writeFloat64ArrayHeader(3) + + // Write shape + b.writeInt32(int32(dim1)) + b.writeInt32(int32(dim2)) + b.writeInt32(int32(dim3)) + + // Write values + for _, level1 := range values { + for _, level2 := range level1 { + if len(level2) > 0 { + b.writeFloat64Data(level2) + } + } + } + + b.hasFields = true + return b +} + +func (b *buffer) Float64ArrayNDColumn(name string, value *NdArray[float64]) *buffer { + if !b.prepareForField() { + return b + } + b.lastErr = b.writeColumnName(name) + if b.lastErr != nil { + return b + } + + if value == nil { + b.writeNullArray() + return b + } + + shape := value.Shape() + numDims := value.NDims() + // Write nDims + b.writeFloat64ArrayHeader(byte(numDims)) + + // Write shape + for _, dim := range shape { + b.writeInt32(int32(dim)) + } + + // Write data + data := value.Data() + if len(data) > 0 { + b.writeFloat64Data(data) + } + + b.hasFields = true + return b +} + func (b *buffer) StringColumn(name, val string) *buffer { if !b.prepareForField() { return b @@ -591,3 +847,19 @@ func (b *buffer) At(ts time.Time, sendTs bool) error { b.resetMsgFlags() return nil } + +func (b *buffer) writeFloat64ArrayHeader(dims byte) { + b.WriteByte('=') + b.WriteByte('=') + b.WriteByte(byte(arrayCode)) + b.WriteByte(byte(arrayElemDouble)) + b.WriteByte(dims) +} + +func (b *buffer) writeNullArray() { + b.WriteByte('=') + b.WriteByte('=') + b.WriteByte(byte(arrayCode)) + b.WriteByte(byte(arrayElemNull)) + b.hasFields = true +} diff --git a/buffer_test.go b/buffer_test.go index 515d9a9..02413f9 100644 --- a/buffer_test.go +++ b/buffer_test.go @@ -25,6 +25,7 @@ package questdb_test import ( + "encoding/binary" "math" "math/big" "strconv" @@ -46,7 +47,7 @@ func TestValidWrites(t *testing.T) { testCases := []struct { name string writerFn bufWriterFn - expectedLines []string + expectedLines [][]byte }{ { "multiple rows", @@ -61,9 +62,9 @@ func TestValidWrites(t *testing.T) { } return nil }, - []string{ - "my_test_table str_col=\"foo\",long_col=42i", - "my_test_table str_col=\"bar\",long_col=-42i 42000", + [][]byte{ + []byte("my_test_table str_col=\"foo\",long_col=42i"), + []byte("my_test_table str_col=\"bar\",long_col=-42i 42000"), }, }, { @@ -71,8 +72,8 @@ func TestValidWrites(t *testing.T) { func(s *qdb.Buffer) error { return s.Table("таблица").StringColumn("колонка", "значение").At(time.Time{}, false) }, - []string{ - "таблица колонка=\"значение\"", + [][]byte{ + []byte("таблица колонка=\"значение\""), }, }, } @@ -85,8 +86,15 @@ func TestValidWrites(t *testing.T) { assert.NoError(t, err) // Check the buffer - assert.Equal(t, strings.Join(tc.expectedLines, "\n")+"\n", buf.Messages()) - + var expectedLines []byte + for i, line := range tc.expectedLines { + if i != 0 { + expectedLines = append(expectedLines, '\n') + } + expectedLines = append(expectedLines, line...) + } + expectedLines = append(expectedLines, '\n') + assert.Equal(t, expectedLines, buf.Messages()) }) } } @@ -109,8 +117,8 @@ func TestTimestampSerialization(t *testing.T) { assert.NoError(t, err) // Check the buffer - expectedLines := []string{"my_test_table a_col=" + strconv.FormatInt(tc.val.UnixMicro(), 10) + "t"} - assert.Equal(t, strings.Join(expectedLines, "\n")+"\n", buf.Messages()) + expected := []byte("my_test_table a_col=" + strconv.FormatInt(tc.val.UnixMicro(), 10) + "t\n") + assert.Equal(t, expected, buf.Messages()) }) } } @@ -135,8 +143,8 @@ func TestInt64Serialization(t *testing.T) { assert.NoError(t, err) // Check the buffer - expectedLines := []string{"my_test_table a_col=" + strconv.FormatInt(tc.val, 10) + "i"} - assert.Equal(t, strings.Join(expectedLines, "\n")+"\n", buf.Messages()) + expected := []byte("my_test_table a_col=" + strconv.FormatInt(tc.val, 10) + "i\n") + assert.Equal(t, expected, buf.Messages()) }) } } @@ -163,8 +171,8 @@ func TestLong256Column(t *testing.T) { assert.NoError(t, err) // Check the buffer - expectedLines := []string{"my_test_table a_col=" + tc.expected + "i"} - assert.Equal(t, strings.Join(expectedLines, "\n")+"\n", buf.Messages()) + expected := []byte("my_test_table a_col=" + tc.expected + "i\n") + assert.Equal(t, expected, buf.Messages()) }) } } @@ -196,8 +204,8 @@ func TestFloat64Serialization(t *testing.T) { assert.NoError(t, err) // Check the buffer - expectedLines := []string{"my_test_table a_col=" + tc.expected} - assert.Equal(t, strings.Join(expectedLines, "\n")+"\n", buf.Messages()) + expected := []byte("my_test_table a_col=" + tc.expected + "\n") + assert.Equal(t, expected, buf.Messages()) }) } } @@ -428,8 +436,8 @@ func TestInvalidMessageGetsDiscarded(t *testing.T) { assert.Error(t, err) // The second message should be discarded. - expectedLines := []string{testTable + " foo=\"bar\""} - assert.Equal(t, strings.Join(expectedLines, "\n")+"\n", buf.Messages()) + expected := []byte(testTable + " foo=\"bar\"\n") + assert.Equal(t, expected, buf.Messages()) } func TestInvalidTableName(t *testing.T) { @@ -447,3 +455,369 @@ func TestInvalidColumnName(t *testing.T) { assert.ErrorContains(t, err, "column name contains an illegal char") assert.Empty(t, buf.Messages()) } + +func TestFloat64ColumnBinary(t *testing.T) { + testCases := []struct { + name string + val float64 + }{ + {"positive number", 42.3}, + {"negative number", -42.3}, + {"zero", 0.0}, + {"NaN", math.NaN()}, + {"positive infinity", math.Inf(1)}, + {"negative infinity", math.Inf(-1)}, + {"smallest value", math.SmallestNonzeroFloat64}, + {"max value", math.MaxFloat64}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := newTestBuffer() + err := buf.Table(testTable).Float64ColumnBinary("a_col", tc.val).At(time.Time{}, false) + assert.NoError(t, err) + assert.Equal(t, buf.Messages(), float64ToByte(testTable, "a_col", tc.val)) + }) + } +} + +func TestFloat64Array1DColumn(t *testing.T) { + testCases := []struct { + name string + values []float64 + }{ + {"single value", []float64{42.5}}, + {"multiple values", []float64{1.1, 2.2, 3.3}}, + {"empty array", []float64{}}, + {"with special values", []float64{math.NaN(), math.Inf(1), math.Inf(-1), 0.0}}, + {"null array", nil}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, littleEndian := range []bool{true, false} { + qdb.SetLittleEndian(littleEndian) + buf := newTestBuffer() + + err := buf.Table(testTable).Float64Array1DColumn("array_col", tc.values).At(time.Time{}, false) + assert.NoError(t, err) + assert.Equal(t, float641DArrayToByte(testTable, "array_col", tc.values), buf.Messages()) + } + }) + } +} + +func TestFloat64Array2DColumn(t *testing.T) { + testCases := []struct { + name string + values [][]float64 + }{ + {"2x2 array", [][]float64{{1.1, 2.2}, {3.3, 4.4}}}, + {"1x3 array", [][]float64{{1.0, 2.0, 3.0}}}, + {"3x1 array", [][]float64{{1.0}, {2.0}, {3.0}}}, + {"empty array", [][]float64{}}, + {"null array", nil}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, littleEndian := range []bool{true, false} { + qdb.SetLittleEndian(littleEndian) + buf := newTestBuffer() + err := buf.Table(testTable).Float64Array2DColumn("array_col", tc.values).At(time.Time{}, false) + assert.NoError(t, err) + assert.Equal(t, float642DArrayToByte(testTable, "array_col", tc.values), buf.Messages()) + } + }) + } +} + +func TestFloat64Array2DColumnIrregularShape(t *testing.T) { + buf := newTestBuffer() + + irregularArray := [][]float64{{1.0, 2.0}, {3.0, 4.0, 5.0}} + err := buf.Table(testTable).Float64Array2DColumn("array_col", irregularArray).At(time.Time{}, false) + + assert.ErrorContains(t, err, "irregular 2D array shape") + assert.Empty(t, buf.Messages()) +} + +func TestFloat64Array3DColumn(t *testing.T) { + testCases := []struct { + name string + values [][][]float64 + }{ + {"2x2x2 array", [][][]float64{{{1.1, 2.2}, {3.3, 4.4}}, {{5.5, 6.6}, {7.7, 8.8}}}}, + {"1x1x3 array", [][][]float64{{{1.0, 2.0, 3.0}}}}, + {"empty array", [][][]float64{}}, + {"null array", nil}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, littleEndian := range []bool{true, false} { + qdb.SetLittleEndian(littleEndian) + buf := newTestBuffer() + err := buf.Table(testTable).Float64Array3DColumn("array_col", tc.values).At(time.Time{}, false) + assert.NoError(t, err) + assert.Equal(t, float643DArrayToByte(testTable, "array_col", tc.values), buf.Messages()) + } + }) + } +} + +func TestFloat64Array3DColumnIrregularShape(t *testing.T) { + buf := newTestBuffer() + irregularArray := [][][]float64{{{1.0, 2.0}, {3.0}}, {{4.0, 5.0}, {6.0, 7.0}}} + err := buf.Table(testTable).Float64Array3DColumn("array_col", irregularArray).At(time.Time{}, false) + assert.ErrorContains(t, err, "irregular 3D array shape: level2[0][1] has length 1, expected 2") + assert.Empty(t, buf.Messages()) + irregularArray2 := [][][]float64{{{1.0, 2.0}, {3.0, 4.0}}, {{4.0, 5.0}}} + err = buf.Table(testTable).Float64Array3DColumn("array_col", irregularArray2).At(time.Time{}, false) + assert.ErrorContains(t, err, "irregular 3D array shape: level1[1] has length 1, expected 2") + assert.Empty(t, buf.Messages()) +} + +func TestFloat64ArrayNDColumn(t *testing.T) { + testCases := []struct { + name string + shape []uint + data []float64 + }{ + {"1D array", []uint{3}, []float64{1.0, 2.0, 3.0}}, + {"2D array", []uint{2, 2}, []float64{1.0, 2.0, 3.0, 4.0}}, + {"3D array", []uint{2, 1, 2}, []float64{1.0, 2.0, 3.0, 4.0}}, + {"null array", nil, nil}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, littleEndian := range []bool{true, false} { + qdb.SetLittleEndian(littleEndian) + buf := newTestBuffer() + var err error + if tc.data == nil { + err = buf.Table(testTable).Float64ArrayNDColumn("ndarray_col", nil).At(time.Time{}, false) + assert.NoError(t, err) + assert.Equal(t, float64NDArrayToByte(testTable, "ndarray_col", nil), buf.Messages()) + } else { + ndArray, err := qdb.NewNDArray[float64](tc.shape...) + assert.NoError(t, err) + for _, val := range tc.data { + ndArray.Append(val) + } + err = buf.Table(testTable).Float64ArrayNDColumn("ndarray_col", ndArray).At(time.Time{}, false) + assert.NoError(t, err) + assert.Equal(t, float64NDArrayToByte(testTable, "ndarray_col", ndArray), buf.Messages()) + } + } + }) + } +} + +func TestFloat64Array1DColumnExceedsMaxElements(t *testing.T) { + buf := newTestBuffer() + largeSize := qdb.MaxArrayElements + 1 + values := make([]float64, largeSize) + + err := buf.Table(testTable).Float64Array1DColumn("array_col", values).At(time.Time{}, false) + assert.ErrorContains(t, err, "array size 268435456 exceeds maximum limit 268435455") + assert.Empty(t, buf.Messages()) +} + +func TestFloat64Array2DColumnExceedsMaxElements(t *testing.T) { + buf := newTestBuffer() + dim1 := 65536 + dim2 := 4097 + values := make([][]float64, dim1) + for i := range values { + values[i] = make([]float64, dim2) + } + + err := buf.Table(testTable).Float64Array2DColumn("array_col", values).At(time.Time{}, false) + assert.ErrorContains(t, err, "array size 268500992 exceeds maximum limit 268435455") + assert.Empty(t, buf.Messages()) +} + +func TestFloat64Array3DColumnExceedsMaxElements(t *testing.T) { + buf := newTestBuffer() + dim1 := 1024 + dim2 := 1024 + dim3 := 256 + values := make([][][]float64, dim1) + for i := range values { + values[i] = make([][]float64, dim2) + for j := range values[i] { + values[i][j] = make([]float64, dim3) + } + } + + err := buf.Table(testTable).Float64Array3DColumn("array_col", values).At(time.Time{}, false) + assert.ErrorContains(t, err, "array size 268435456 exceeds maximum limit 268435455") + assert.Empty(t, buf.Messages()) +} + +func float64ToByte(table, col string, val float64) []byte { + buf := make([]byte, 0, 128) + buf = append(buf, ([]byte)(table)...) + buf = append(buf, ' ') + buf = append(buf, ([]byte)(col)...) + buf = append(buf, '=') + buf = append(buf, '=') + buf = append(buf, 16) + buf1 := make([]byte, 8) + binary.LittleEndian.PutUint64(buf1, math.Float64bits(val)) + buf = append(buf, buf1...) + buf = append(buf, '\n') + return buf +} + +func nullArrayToByte(table, col string) []byte { + buf := make([]byte, 0, 128) + buf = append(buf, ([]byte)(table)...) + buf = append(buf, ' ') + buf = append(buf, ([]byte)(col)...) + buf = append(buf, '=') + buf = append(buf, '=') + buf = append(buf, 14) + buf = append(buf, 33) + buf = append(buf, '\n') + return buf +} + +func float641DArrayToByte(table, col string, vals []float64) []byte { + if vals == nil { + return nullArrayToByte(table, col) + } + buf := make([]byte, 0, 128) + buf = append(buf, ([]byte)(table)...) + buf = append(buf, ' ') + buf = append(buf, ([]byte)(col)...) + buf = append(buf, '=') + buf = append(buf, '=') + buf = append(buf, 14) + buf = append(buf, 10) + buf = append(buf, 1) + + shapeData := make([]byte, 4) + binary.LittleEndian.PutUint32(shapeData, uint32(len(vals))) + buf = append(buf, shapeData...) + + // Write values + for _, val := range vals { + valData := make([]byte, 8) + binary.LittleEndian.PutUint64(valData, math.Float64bits(val)) + buf = append(buf, valData...) + } + buf = append(buf, '\n') + return buf +} + +func float642DArrayToByte(table, col string, vals [][]float64) []byte { + if vals == nil { + return nullArrayToByte(table, col) + } + buf := make([]byte, 0, 256) + buf = append(buf, ([]byte)(table)...) + buf = append(buf, ' ') + buf = append(buf, ([]byte)(col)...) + buf = append(buf, '=') + buf = append(buf, '=') + buf = append(buf, 14) + buf = append(buf, 10) + buf = append(buf, 2) + + dim1 := len(vals) + var dim2 int + if dim1 > 0 { + dim2 = len(vals[0]) + } + + shapeData := make([]byte, 8) + binary.LittleEndian.PutUint32(shapeData[:4], uint32(dim1)) + binary.LittleEndian.PutUint32(shapeData[4:8], uint32(dim2)) + buf = append(buf, shapeData...) + + for _, row := range vals { + for _, val := range row { + valData := make([]byte, 8) + binary.LittleEndian.PutUint64(valData, math.Float64bits(val)) + buf = append(buf, valData...) + } + } + buf = append(buf, '\n') + return buf +} + +func float643DArrayToByte(table, col string, vals [][][]float64) []byte { + if vals == nil { + return nullArrayToByte(table, col) + } + buf := make([]byte, 0, 512) + buf = append(buf, ([]byte)(table)...) + buf = append(buf, ' ') + buf = append(buf, ([]byte)(col)...) + buf = append(buf, '=') + buf = append(buf, '=') + buf = append(buf, 14) + buf = append(buf, 10) + buf = append(buf, 3) + + dim1 := len(vals) + var dim2, dim3 int + if dim1 > 0 { + dim2 = len(vals[0]) + if dim2 > 0 { + dim3 = len(vals[0][0]) + } + } + + shapeData := make([]byte, 12) + binary.LittleEndian.PutUint32(shapeData[:4], uint32(dim1)) + binary.LittleEndian.PutUint32(shapeData[4:8], uint32(dim2)) + binary.LittleEndian.PutUint32(shapeData[8:12], uint32(dim3)) + buf = append(buf, shapeData...) + + for _, level1 := range vals { + for _, level2 := range level1 { + for _, val := range level2 { + valData := make([]byte, 8) + binary.LittleEndian.PutUint64(valData, math.Float64bits(val)) + buf = append(buf, valData...) + } + } + } + buf = append(buf, '\n') + return buf +} + +func float64NDArrayToByte(table, col string, ndarray *qdb.NdArray[float64]) []byte { + if ndarray == nil { + return nullArrayToByte(table, col) + } + buf := make([]byte, 0, 512) + buf = append(buf, ([]byte)(table)...) + buf = append(buf, ' ') + buf = append(buf, ([]byte)(col)...) + buf = append(buf, '=') + buf = append(buf, '=') + buf = append(buf, 14) + buf = append(buf, 10) + buf = append(buf, byte(ndarray.NDims())) + + shape := ndarray.Shape() + for _, dim := range shape { + shapeData := make([]byte, 4) + binary.LittleEndian.PutUint32(shapeData, uint32(dim)) + buf = append(buf, shapeData...) + } + + data := ndarray.Data() + for _, val := range data { + valData := make([]byte, 8) + binary.LittleEndian.PutUint64(valData, math.Float64bits(val)) + buf = append(buf, valData...) + } + buf = append(buf, '\n') + return buf +} diff --git a/conf_parse.go b/conf_parse.go index 5337a14..7957dc5 100644 --- a/conf_parse.go +++ b/conf_parse.go @@ -162,6 +162,18 @@ func confFromStr(conf string) (*lineSenderConfig, error) { return nil, NewInvalidConfigStrError("tls_roots is not available in the go client") case "tls_roots_password": return nil, NewInvalidConfigStrError("tls_roots_password is not available in the go client") + case "protocol_version": + if v != "auto" { + version, err := strconv.Atoi(v) + if err != nil { + return nil, NewInvalidConfigStrError("invalid %s value, %q is not a valid int", k, v) + } + pVersion := protocolVersion(version) + if pVersion < ProtocolVersion1 || pVersion > ProtocolVersion2 { + return nil, NewInvalidConfigStrError("current client only supports protocol version 1 (text format for all datatypes), 2 (binary format for part datatypes) or explicitly unset") + } + senderConf.protocolVersion = pVersion + } default: return nil, NewInvalidConfigStrError("unsupported option %q", k) } diff --git a/conf_test.go b/conf_test.go index 4aab4a9..37018dd 100644 --- a/conf_test.go +++ b/conf_test.go @@ -201,6 +201,17 @@ func TestParserHappyCases(t *testing.T) { }, }, }, + { + name: "protocol version", + config: fmt.Sprintf("http::addr=%s;protocol_version=1;", addr), + expected: qdb.ConfigData{ + Schema: "http", + KeyValuePairs: map[string]string{ + "addr": addr, + "protocol_version": "1", + }, + }, + }, { name: "equal sign in password", config: fmt.Sprintf("http::addr=%s;username=%s;password=pass=word;", addr, user), @@ -288,15 +299,16 @@ type configTestCase struct { func TestHappyCasesFromConf(t *testing.T) { var ( - addr = "localhost:1111" - user = "test-user" - pass = "test-pass" - token = "test-token" - minThroughput = 999 - requestTimeout = time.Second * 88 - retryTimeout = time.Second * 99 - initBufSize = 256 - maxBufSize = 1024 + addr = "localhost:1111" + user = "test-user" + pass = "test-pass" + token = "test-token" + minThroughput = 999 + requestTimeout = time.Second * 88 + retryTimeout = time.Second * 99 + initBufSize = 256 + maxBufSize = 1024 + protocolVersion = qdb.ProtocolVersion2 ) testCases := []configTestCase{ @@ -382,6 +394,16 @@ func TestHappyCasesFromConf(t *testing.T) { qdb.WithMinThroughput(minThroughput), }, }, + { + name: "protocol_version", + config: fmt.Sprintf("http::addr=%s;protocol_version=%d;", + addr, protocolVersion), + expectedOpts: []qdb.LineSenderOption{ + qdb.WithHttp(), + qdb.WithAddress(addr), + qdb.WithProtocolVersion(protocolVersion), + }, + }, { name: "bearer token", config: fmt.Sprintf("http::addr=%s;token=%s", diff --git a/examples/from-conf/main.go b/examples/from-conf/main.go index 39ae094..08437b1 100644 --- a/examples/from-conf/main.go +++ b/examples/from-conf/main.go @@ -27,12 +27,21 @@ func main() { if err != nil { log.Fatal(err) } + // Prepare array data. + // QuestDB server version 9.0.0 or later is required for array support. + array, err := qdb.NewNDArray[float64](2, 3, 2) + if err != nil { + log.Fatal(err) + } + // Fill array with value 87.2 + array.Fill(87.2) err = sender. Table("trades"). Symbol("symbol", "ETH-USD"). Symbol("side", "sell"). Float64Column("price", 2615.54). Float64Column("amount", 0.00044). + Float64ArrayNDColumn("price_history", array). At(ctx, tradedTs) if err != nil { log.Fatal(err) @@ -42,12 +51,25 @@ func main() { if err != nil { log.Fatal(err) } + + // Reuse array by resetting index and appending new values sequentially + hasMore := true + array.ResetAppendIndex() + val := 200.0 + for hasMore { + hasMore, err = array.Append(val + 1) + val = val + 1 + if err != nil { + log.Fatal(err) + } + } err = sender. Table("trades"). Symbol("symbol", "BTC-USD"). Symbol("side", "sell"). Float64Column("price", 39269.98). Float64Column("amount", 0.001). + Float64ArrayNDColumn("price_history", array). At(ctx, tradedTs) if err != nil { log.Fatal(err) diff --git a/examples/http/basic/main.go b/examples/http/basic/main.go index 4d84f17..83c0370 100644 --- a/examples/http/basic/main.go +++ b/examples/http/basic/main.go @@ -28,17 +28,38 @@ func main() { if err != nil { log.Fatal(err) } + + // Prepare array data. + // QuestDB server version 9.0.0 or later is required for array support. + array, err := qdb.NewNDArray[float64](2, 3, 2) + if err != nil { + log.Fatal(err) + } + // Fill array with value 87.2 + array.Fill(87.2) err = sender. Table("trades"). Symbol("symbol", "ETH-USD"). Symbol("side", "sell"). Float64Column("price", 2615.54). Float64Column("amount", 0.00044). + Float64ArrayNDColumn("price_history", array). At(ctx, tradedTs) if err != nil { log.Fatal(err) } + // Reuse array by resetting index and appending new values sequentially + hasMore := true + array.ResetAppendIndex() + val := 200.0 + for hasMore { + hasMore, err = array.Append(val + 1) + val = val + 1 + if err != nil { + log.Fatal(err) + } + } tradedTs, err = time.Parse(time.RFC3339, "2022-08-06T15:04:06.987654Z") if err != nil { log.Fatal(err) @@ -49,6 +70,7 @@ func main() { Symbol("side", "sell"). Float64Column("price", 39269.98). Float64Column("amount", 0.001). + Float64ArrayNDColumn("price_history", array). At(ctx, tradedTs) if err != nil { log.Fatal(err) diff --git a/examples/tcp/basic/main.go b/examples/tcp/basic/main.go index 716c4de..5fdc632 100644 --- a/examples/tcp/basic/main.go +++ b/examples/tcp/basic/main.go @@ -11,7 +11,7 @@ import ( func main() { ctx := context.TODO() // Connect to QuestDB running on 127.0.0.1:9009 - sender, err := qdb.NewLineSender(ctx, qdb.WithTcp()) + sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) if err != nil { log.Fatal(err) } @@ -23,17 +23,40 @@ func main() { if err != nil { log.Fatal(err) } + + // Prepare array data. + // QuestDB server version 9.0.0 or later is required for array support. + array, err := qdb.NewNDArray[float64](2, 3, 2) + if err != nil { + log.Fatal(err) + } + // Fill array with value 87.2 + array.Fill(87.2) + err = sender. Table("trades"). Symbol("symbol", "ETH-USD"). Symbol("side", "sell"). Float64Column("price", 2615.54). Float64Column("amount", 0.00044). + Float64ArrayNDColumn("price_history", array). At(ctx, tradedTs) if err != nil { log.Fatal(err) } + // Reuse array by resetting index and appending new values sequentially + hasMore := true + array.ResetAppendIndex() + val := 200.0 + for hasMore { + hasMore, err = array.Append(val + 1) + val = val + 1 + if err != nil { + log.Fatal(err) + } + } + tradedTs, err = time.Parse(time.RFC3339, "2022-08-06T15:04:06.987654Z") if err != nil { log.Fatal(err) @@ -44,6 +67,7 @@ func main() { Symbol("side", "sell"). Float64Column("price", 39269.98). Float64Column("amount", 0.001). + Float64ArrayNDColumn("price_history", array). At(ctx, tradedTs) if err != nil { log.Fatal(err) diff --git a/export_test.go b/export_test.go index eebd655..640c473 100644 --- a/export_test.go +++ b/export_test.go @@ -53,40 +53,86 @@ func ConfFromStr(conf string) (*LineSenderConfig, error) { return confFromStr(conf) } -func Messages(s LineSender) string { +func Messages(s LineSender) []byte { + if ps, ok := s.(*pooledSender); ok { + s = ps.wrapped + } if hs, ok := s.(*httpLineSender); ok { return hs.Messages() } + if hs, ok := s.(*httpLineSenderV2); ok { + return hs.Messages() + } if ts, ok := s.(*tcpLineSender); ok { return ts.Messages() } + if ts, ok := s.(*tcpLineSenderV2); ok { + return ts.Messages() + } panic("unexpected struct") } func MsgCount(s LineSender) int { if ps, ok := s.(*pooledSender); ok { - hs, _ := ps.wrapped.(*httpLineSender) - return hs.MsgCount() + s = ps } if hs, ok := s.(*httpLineSender); ok { return hs.MsgCount() } + if hs, ok := s.(*httpLineSenderV2); ok { + return hs.MsgCount() + } if ts, ok := s.(*tcpLineSender); ok { return ts.MsgCount() } + if ts, ok := s.(*tcpLineSenderV2); ok { + return ts.MsgCount() + } panic("unexpected struct") } func BufLen(s LineSender) int { + if ps, ok := s.(*pooledSender); ok { + s = ps + } if hs, ok := s.(*httpLineSender); ok { return hs.BufLen() } + if hs, ok := s.(*httpLineSenderV2); ok { + return hs.BufLen() + } if ts, ok := s.(*tcpLineSender); ok { return ts.BufLen() } + if ts, ok := s.(*tcpLineSenderV2); ok { + return ts.BufLen() + } + panic("unexpected struct") +} + +func ProtocolVersion(s LineSender) protocolVersion { + if ps, ok := s.(*pooledSender); ok { + s = ps + } + if _, ok := s.(*httpLineSender); ok { + return ProtocolVersion1 + } + if _, ok := s.(*httpLineSenderV2); ok { + return ProtocolVersion2 + } + if _, ok := s.(*tcpLineSender); ok { + return ProtocolVersion1 + } + if _, ok := s.(*tcpLineSenderV2); ok { + return ProtocolVersion2 + } panic("unexpected struct") } func NewLineSenderConfig(t SenderType) *LineSenderConfig { return newLineSenderConfig(t) } + +func SetLittleEndian(littleEndian bool) { + isLittleEndian = littleEndian +} diff --git a/http_sender.go b/http_sender.go index 2142961..69f22ff 100644 --- a/http_sender.go +++ b/http_sender.go @@ -29,6 +29,7 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" "io" "math/big" @@ -114,9 +115,12 @@ type httpLineSender struct { globalTransport *globalHttpTransport } -func newHttpLineSender(conf *lineSenderConfig) (*httpLineSender, error) { - var transport *http.Transport +type httpLineSenderV2 struct { + httpLineSender +} +func newHttpLineSender(ctx context.Context, conf *lineSenderConfig) (LineSender, error) { + var transport *http.Transport s := &httpLineSender{ address: conf.address, minThroughputBytesPerSecond: conf.minThroughput, @@ -155,13 +159,29 @@ func newHttpLineSender(conf *lineSenderConfig) (*httpLineSender, error) { s.globalTransport.RegisterClient() } + // auto detect server line protocol version + pVersion := conf.protocolVersion + if pVersion == protocolVersionUnset { + var err error + pVersion, err = s.detectProtocolVersion(ctx, conf) + if err != nil { + return nil, err + } + } + s.uri = "http" if conf.tlsMode != tlsDisabled { s.uri += "s" } s.uri += fmt.Sprintf("://%s/write", s.address) - return s, nil + if pVersion == ProtocolVersion1 { + return s, nil + } else { + return &httpLineSenderV2{ + *s, + }, nil + } } func (s *httpLineSender) Flush(ctx context.Context) error { @@ -282,6 +302,26 @@ func (s *httpLineSender) BoolColumn(name string, val bool) LineSender { return s } +func (s *httpLineSender) Float64Array1DColumn(name string, values []float64) LineSender { + s.buf.SetLastErr(errors.New("current protocol version does not support double-array")) + return s +} + +func (s *httpLineSender) Float64Array2DColumn(name string, values [][]float64) LineSender { + s.buf.SetLastErr(errors.New("current protocol version does not support double-array")) + return s +} + +func (s *httpLineSender) Float64Array3DColumn(name string, values [][][]float64) LineSender { + s.buf.SetLastErr(errors.New("current protocol version does not support double-array")) + return s +} + +func (s *httpLineSender) Float64ArrayNDColumn(name string, values *NdArray[float64]) LineSender { + s.buf.SetLastErr(errors.New("current protocol version does not support double-array")) + return s +} + func (s *httpLineSender) Close(ctx context.Context) error { if s.closed { return errDoubleSenderClose @@ -399,6 +439,87 @@ func (s *httpLineSender) makeRequest(ctx context.Context) (bool, error) { } +func (s *httpLineSender) detectProtocolVersion(ctx context.Context, conf *lineSenderConfig) (protocolVersion, error) { + scheme := "http" + if conf.tlsMode != tlsDisabled { + scheme = "https" + } + settingsUri := fmt.Sprintf("%s://%s/settings", scheme, s.address) + + req, err := http.NewRequest(http.MethodGet, settingsUri, nil) + if err != nil { + return protocolVersionUnset, err + } + + reqCtx, cancel := context.WithTimeout(ctx, s.requestTimeout) + defer cancel() + req = req.WithContext(reqCtx) + + resp, err := s.client.Do(req) + if err != nil { + return protocolVersionUnset, err + } + defer resp.Body.Close() + + switch resp.StatusCode { + case 404: + return ProtocolVersion1, nil + case 200: + return parseServerSettings(resp, conf) + default: + buf, _ := io.ReadAll(resp.Body) + if len(buf) > 1024 { + buf = append(buf[:1024], []byte("...")...) + } + return protocolVersionUnset, fmt.Errorf("failed to detect server line protocol version [http-status=%d, http-message=%s]", + resp.StatusCode, string(buf)) + } +} + +func parseServerSettings(resp *http.Response, conf *lineSenderConfig) (protocolVersion, error) { + buf, err := io.ReadAll(resp.Body) + if err != nil { + return protocolVersionUnset, fmt.Errorf("%d: %s", resp.StatusCode, resp.Status) + } + + var settings struct { + Config struct { + LineProtoSupportVersions []int `json:"line.proto.support.versions"` + MaxFileNameLength int `json:"cairo.max.file.name.length"` + } `json:"config"` + } + + if err := json.Unmarshal(buf, &settings); err != nil { + return ProtocolVersion1, nil + } + + // Update file name limit if provided by server + if settings.Config.MaxFileNameLength != 0 { + conf.fileNameLimit = settings.Config.MaxFileNameLength + } + + // Determine protocol version based on server support + versions := settings.Config.LineProtoSupportVersions + if len(versions) == 0 { + return ProtocolVersion1, nil + } + + hasProtocolVersion1 := false + for _, version := range versions { + if version == 2 { + return ProtocolVersion2, nil + } + if version == 1 { + hasProtocolVersion1 = true + } + } + if hasProtocolVersion1 { + return ProtocolVersion1, nil + } + + return protocolVersionUnset, errors.New("server does not support current client") +} + func isRetryableError(statusCode int) bool { switch statusCode { case 500, // Internal Server Error @@ -416,9 +537,9 @@ func isRetryableError(statusCode int) bool { } } -// Messages returns a copy of accumulated ILP messages that are not +// Messages returns the accumulated ILP messages that are not // flushed to the TCP connection yet. Useful for debugging purposes. -func (s *httpLineSender) Messages() string { +func (s *httpLineSender) Messages() []byte { return s.buf.Messages() } @@ -431,3 +552,63 @@ func (s *httpLineSender) MsgCount() int { func (s *httpLineSender) BufLen() int { return s.buf.Len() } + +func (s *httpLineSenderV2) Table(name string) LineSender { + s.buf.Table(name) + return s +} + +func (s *httpLineSenderV2) Symbol(name, val string) LineSender { + s.buf.Symbol(name, val) + return s +} + +func (s *httpLineSenderV2) Int64Column(name string, val int64) LineSender { + s.buf.Int64Column(name, val) + return s +} + +func (s *httpLineSenderV2) Long256Column(name string, val *big.Int) LineSender { + s.buf.Long256Column(name, val) + return s +} + +func (s *httpLineSenderV2) TimestampColumn(name string, ts time.Time) LineSender { + s.buf.TimestampColumn(name, ts) + return s +} + +func (s *httpLineSenderV2) StringColumn(name, val string) LineSender { + s.buf.StringColumn(name, val) + return s +} + +func (s *httpLineSenderV2) BoolColumn(name string, val bool) LineSender { + s.buf.BoolColumn(name, val) + return s +} + +func (s *httpLineSenderV2) Float64Column(name string, val float64) LineSender { + s.buf.Float64ColumnBinary(name, val) + return s +} + +func (s *httpLineSenderV2) Float64Array1DColumn(name string, values []float64) LineSender { + s.buf.Float64Array1DColumn(name, values) + return s +} + +func (s *httpLineSenderV2) Float64Array2DColumn(name string, values [][]float64) LineSender { + s.buf.Float64Array2DColumn(name, values) + return s +} + +func (s *httpLineSenderV2) Float64Array3DColumn(name string, values [][][]float64) LineSender { + s.buf.Float64Array3DColumn(name, values) + return s +} + +func (s *httpLineSenderV2) Float64ArrayNDColumn(name string, values *NdArray[float64]) LineSender { + s.buf.Float64ArrayNDColumn(name, values) + return s +} diff --git a/http_sender_test.go b/http_sender_test.go index 67d12ad..008ac6b 100644 --- a/http_sender_test.go +++ b/http_sender_test.go @@ -57,27 +57,27 @@ func TestHttpHappyCasesFromConf(t *testing.T) { testCases := []httpConfigTestCase{ { name: "request_timeout and retry_timeout milli conversion", - config: fmt.Sprintf("http::addr=%s;request_timeout=%d;retry_timeout=%d;", + config: fmt.Sprintf("http::addr=%s;request_timeout=%d;retry_timeout=%d;protocol_version=2;", addr, request_timeout.Milliseconds(), retry_timeout.Milliseconds()), }, { name: "pass before user", - config: fmt.Sprintf("http::addr=%s;password=%s;username=%s;", + config: fmt.Sprintf("http::addr=%s;password=%s;username=%s;protocol_version=2;", addr, pass, user), }, { name: "request_min_throughput", - config: fmt.Sprintf("http::addr=%s;request_min_throughput=%d;", + config: fmt.Sprintf("http::addr=%s;request_min_throughput=%d;protocol_version=2;", addr, min_throughput), }, { name: "bearer token", - config: fmt.Sprintf("http::addr=%s;token=%s;", + config: fmt.Sprintf("http::addr=%s;token=%s;protocol_version=2;", addr, token), }, { name: "auto flush", - config: fmt.Sprintf("http::addr=%s;auto_flush_rows=100;auto_flush_interval=1000;", + config: fmt.Sprintf("http::addr=%s;auto_flush_rows=100;auto_flush_interval=1000;protocol_version=2;", addr), }, } @@ -100,11 +100,11 @@ func TestHttpHappyCasesFromEnv(t *testing.T) { testCases := []httpConfigTestCase{ { name: "addr only", - config: fmt.Sprintf("http::addr=%s", addr), + config: fmt.Sprintf("http::addr=%s;protocol_version=1;", addr), }, { name: "auto flush", - config: fmt.Sprintf("http::addr=%s;auto_flush_rows=100;auto_flush_interval=1000;", + config: fmt.Sprintf("http::addr=%s;auto_flush_rows=100;auto_flush_interval=1000;protocol_version=2;", addr), }, } @@ -168,6 +168,11 @@ func TestHttpPathologicalCasesFromConf(t *testing.T) { config: "hTtp::addr=localhost:1234;", expectedErr: "invalid schema", }, + { + name: "protocol version", + config: "http::protocol_version=abc;", + expectedErr: "invalid protocol_version value", + }, } for _, tc := range testCases { @@ -729,13 +734,13 @@ func TestBufferClearAfterFlush(t *testing.T) { func TestCustomTransportAndTlsInit(t *testing.T) { ctx := context.Background() - s1, err := qdb.NewLineSender(ctx, qdb.WithHttp()) + s1, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) assert.NoError(t, err) - s2, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithTls()) + s2, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithTls(), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) assert.NoError(t, err) - s3, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithTlsInsecureSkipVerify()) + s3, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithTlsInsecureSkipVerify(), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) assert.NoError(t, err) transport := http.Transport{} @@ -744,6 +749,7 @@ func TestCustomTransportAndTlsInit(t *testing.T) { qdb.WithHttp(), qdb.WithHttpTransport(&transport), qdb.WithTls(), + qdb.WithProtocolVersion(qdb.ProtocolVersion2), ) assert.NoError(t, err) @@ -763,6 +769,143 @@ func TestCustomTransportAndTlsInit(t *testing.T) { assert.Equal(t, int64(0), qdb.GlobalTransport.ClientCount()) } +func TestAutoDetectProtocolVersionOldServer1(t *testing.T) { + ctx := context.Background() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", nil) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.Equal(t, qdb.ProtocolVersion(sender), qdb.ProtocolVersion1) +} + +func TestAutoDetectProtocolVersionOldServer2(t *testing.T) { + ctx := context.Background() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{}) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.Equal(t, qdb.ProtocolVersion(sender), qdb.ProtocolVersion1) +} + +func TestAutoDetectProtocolVersionOldServer3(t *testing.T) { + ctx := context.Background() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{1}) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.Equal(t, qdb.ProtocolVersion(sender), qdb.ProtocolVersion1) +} + +func TestAutoDetectProtocolVersionNewServer1(t *testing.T) { + ctx := context.Background() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{1, 2}) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.Equal(t, qdb.ProtocolVersion(sender), qdb.ProtocolVersion2) +} + +func TestAutoDetectProtocolVersionNewServer2(t *testing.T) { + ctx := context.Background() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{2}) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.Equal(t, qdb.ProtocolVersion(sender), qdb.ProtocolVersion2) +} + +func TestAutoDetectProtocolVersionNewServer3(t *testing.T) { + ctx := context.Background() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{2, 3}) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.Equal(t, qdb.ProtocolVersion(sender), qdb.ProtocolVersion2) +} + +func TestAutoDetectProtocolVersionNewServer4(t *testing.T) { + ctx := context.Background() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{3}) + assert.NoError(t, err) + defer srv.Close() + _, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.ErrorContains(t, err, "server does not support current client") +} + +func TestAutoDetectProtocolVersionError(t *testing.T) { + ctx := context.Background() + + srv, err := newTestHttpServerWithErrMsg(readAndDiscard, "Internal error") + assert.NoError(t, err) + defer srv.Close() + _, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.ErrorContains(t, err, "failed to detect server line protocol version [http-status=500, http-message={\"code\":\"500\",\"message\":\"Internal error\"}]") +} + +func TestSpecifyProtocolVersion(t *testing.T) { + ctx := context.Background() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{1, 2}) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) + assert.Equal(t, qdb.ProtocolVersion(sender), qdb.ProtocolVersion1) +} + +func TestArrayColumnUnsupportedInHttpProtocolV1(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(50*time.Millisecond)) + defer cancel() + + srv, err := newTestServerWithProtocol(readAndDiscard, "http", []int{1}) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + assert.NoError(t, err) + defer sender.Close(ctx) + + values1D := []float64{1.0, 2.0, 3.0, 4.0, 5.0} + values2D := [][]float64{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}} + values3D := [][][]float64{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}} + arrayND, err := qdb.NewNDArray[float64](2, 2, 1, 2) + assert.NoError(t, err) + arrayND.Fill(11.0) + + err = sender. + Table(testTable). + Float64Array1DColumn("array_1d", values1D). + At(ctx, time.UnixMicro(1)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "current protocol version does not support double-array") + + err = sender. + Table(testTable). + Float64Array2DColumn("array_2d", values2D). + At(ctx, time.UnixMicro(2)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "current protocol version does not support double-array") + + err = sender. + Table(testTable). + Float64Array3DColumn("array_3d", values3D). + At(ctx, time.UnixMicro(3)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "current protocol version does not support double-array") + + err = sender. + Table(testTable). + Float64ArrayNDColumn("array_nd", arrayND). + At(ctx, time.UnixMicro(4)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "current protocol version does not support double-array") +} + func BenchmarkHttpLineSenderBatch1000(b *testing.B) { ctx := context.Background() @@ -773,6 +916,12 @@ func BenchmarkHttpLineSenderBatch1000(b *testing.B) { sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) assert.NoError(b, err) + values1D := []float64{1.0, 2.0, 3.0, 4.0, 5.0} + values2D := [][]float64{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}} + values3D := [][][]float64{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}} + arrayND, _ := qdb.NewNDArray[float64](2, 3) + arrayND.Fill(10.0) + b.ResetTimer() for i := 0; i < b.N; i++ { for j := 0; j < 1000; j++ { @@ -784,6 +933,10 @@ func BenchmarkHttpLineSenderBatch1000(b *testing.B) { StringColumn("str_col", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua"). BoolColumn("bool_col", true). TimestampColumn("timestamp_col", time.UnixMicro(42)). + Float64Array1DColumn("array_1d", values1D). + Float64Array2DColumn("array_2d", values2D). + Float64Array3DColumn("array_3d", values3D). + Float64ArrayNDColumn("array_nd", arrayND). At(ctx, time.UnixMicro(int64(1000*i))) } sender.Flush(ctx) @@ -801,6 +954,12 @@ func BenchmarkHttpLineSenderNoFlush(b *testing.B) { sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) assert.NoError(b, err) + values1D := []float64{1.0, 2.0, 3.0, 4.0, 5.0} + values2D := [][]float64{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}} + values3D := [][][]float64{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}} + arrayND, _ := qdb.NewNDArray[float64](2, 3) + arrayND.Fill(10) + b.ResetTimer() for i := 0; i < b.N; i++ { sender. @@ -811,6 +970,10 @@ func BenchmarkHttpLineSenderNoFlush(b *testing.B) { StringColumn("str_col", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua"). BoolColumn("bool_col", true). TimestampColumn("timestamp_col", time.UnixMicro(42)). + Float64Array1DColumn("array_1d", values1D). + Float64Array2DColumn("array_2d", values2D). + Float64Array3DColumn("array_3d", values3D). + Float64ArrayNDColumn("array_nd", arrayND). At(ctx, time.UnixMicro(int64(1000*i))) } sender.Flush(ctx) diff --git a/integration_test.go b/integration_test.go index 09df41d..a81d745 100644 --- a/integration_test.go +++ b/integration_test.go @@ -27,6 +27,7 @@ package questdb_test import ( "context" "fmt" + "math" "math/big" "path/filepath" "reflect" @@ -132,7 +133,7 @@ func setupQuestDB0(ctx context.Context, auth ilpAuthType, setupProxy bool) (*que return nil, err } req := testcontainers.ContainerRequest{ - Image: "questdb/questdb:7.4.2", + Image: "questdb/questdb:9.0.2", ExposedPorts: []string{"9000/tcp", "9009/tcp"}, WaitingFor: wait.ForHTTP("/").WithPort("9000"), Networks: []string{networkName}, @@ -298,7 +299,7 @@ func (suite *integrationTestSuite) TestE2EValidWrites() { {"double_col", "DOUBLE"}, {"long_col", "LONG"}, {"long256_col", "LONG256"}, - {"str_col", "STRING"}, + {"str_col", "VARCHAR"}, {"bool_col", "BOOLEAN"}, {"timestamp_col", "TIMESTAMP"}, {"timestamp", "TIMESTAMP"}, @@ -312,10 +313,10 @@ func (suite *integrationTestSuite) TestE2EValidWrites() { }, { "escaped chars", - "my-awesome_test 1=2.csv", + "m y-awesome_test 1=2.csv", func(s qdb.LineSender) error { return s. - Table("my-awesome_test 1=2.csv"). + Table("m y-awesome_test 1=2.csv"). Symbol("sym_name 1=2", "value 1,2=3\n4\r5\"6\\7"). StringColumn("str_name 1=2", "value 1,2=3\n4\r5\"6\\7"). At(ctx, time.UnixMicro(1)) @@ -323,7 +324,7 @@ func (suite *integrationTestSuite) TestE2EValidWrites() { tableData{ Columns: []column{ {"sym_name 1=2", "SYMBOL"}, - {"str_name 1=2", "STRING"}, + {"str_name 1=2", "VARCHAR"}, {"timestamp", "TIMESTAMP"}, }, Dataset: [][]interface{}{ @@ -413,44 +414,139 @@ func (suite *integrationTestSuite) TestE2EValidWrites() { Count: 1, }, }, + { + "double array", + testTable, + func(s qdb.LineSender) error { + values1D := []float64{1.0, 2.0, 3.0, 4.0, 5.0, math.NaN()} + values2D := [][]float64{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}, {math.NaN(), math.NaN()}} + values3D := [][][]float64{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, math.NaN()}}} + arrayND, _ := qdb.NewNDArray[float64](2, 2, 1, 2) + arrayND.Fill(11.0) + arrayND.Set(math.NaN(), 1, 1, 0, 1) + + err := s. + Table(testTable). + Float64Array1DColumn("array_1d", values1D). + Float64Array2DColumn("array_2d", values2D). + Float64Array3DColumn("array_3d", values3D). + Float64ArrayNDColumn("array_nd", arrayND). + At(ctx, time.UnixMicro(1)) + if err != nil { + return err + } + // empty array + emptyNdArray, _ := qdb.NewNDArray[float64](2, 2, 0, 2) + err = s. + Table(testTable). + Float64Array1DColumn("array_1d", []float64{}). + Float64Array2DColumn("array_2d", [][]float64{{}}). + Float64Array3DColumn("array_3d", [][][]float64{{{}}}). + Float64ArrayNDColumn("array_nd", emptyNdArray). + At(ctx, time.UnixMicro(2)) + if err != nil { + return err + } + // null array + return s. + Table(testTable). + Float64Array1DColumn("array_1d", nil). + Float64Array2DColumn("array_2d", nil). + Float64Array3DColumn("array_3d", nil). + Float64ArrayNDColumn("array_nd", nil). + At(ctx, time.UnixMicro(3)) + }, + tableData{ + Columns: []column{ + {"array_1d", "ARRAY"}, + {"array_2d", "ARRAY"}, + {"array_3d", "ARRAY"}, + {"array_nd", "ARRAY"}, + {"timestamp", "TIMESTAMP"}, + }, + Dataset: [][]interface{}{ + { + []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5), nil}, + []interface{}{[]interface{}{float64(1), float64(2)}, []interface{}{float64(3), float64(4)}, []interface{}{float64(5), float64(6)}, []interface{}{nil, nil}}, + []interface{}{[]interface{}{[]interface{}{float64(1), float64(2)}, []interface{}{float64(3), float64(4)}}, []interface{}{[]interface{}{float64(5), float64(6)}, []interface{}{float64(7), nil}}}, + []interface{}{[]interface{}{[]interface{}{[]interface{}{float64(11), float64(11)}}, []interface{}{[]interface{}{float64(11), float64(11)}}}, []interface{}{[]interface{}{[]interface{}{float64(11), float64(11)}}, []interface{}{[]interface{}{float64(11), nil}}}}, + "1970-01-01T00:00:00.000001Z"}, + { + []interface{}{}, + []interface{}{}, + []interface{}{}, + []interface{}{}, + "1970-01-01T00:00:00.000002Z"}, + { + nil, + nil, + nil, + nil, + "1970-01-01T00:00:00.000003Z"}, + }, + Count: 3, + }, + }, } for _, tc := range testCases { for _, protocol := range []string{"tcp", "http"} { - suite.T().Run(fmt.Sprintf("%s: %s", tc.name, protocol), func(t *testing.T) { - var ( - sender qdb.LineSender - err error - ) - - questdbC, err := setupQuestDB(ctx, noAuth) - assert.NoError(t, err) - - switch protocol { - case "tcp": - sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress)) - assert.NoError(t, err) - case "http": - sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress)) + for _, pVersion := range []int{0, 1, 2} { + suite.T().Run(fmt.Sprintf("%s: %s", tc.name, protocol), func(t *testing.T) { + var ( + sender qdb.LineSender + err error + ) + + ignoreArray := false + questdbC, err := setupQuestDB(ctx, noAuth) assert.NoError(t, err) - default: - panic(protocol) - } - err = tc.writerFn(sender) - assert.NoError(t, err) + switch protocol { + case "tcp": + if pVersion == 0 { + sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress)) + ignoreArray = true + } else if pVersion == 1 { + sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) + ignoreArray = true + } else if pVersion == 2 { + sender, err = qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) + } + assert.NoError(t, err) + case "http": + if pVersion == 0 { + sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress)) + } else if pVersion == 1 { + sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) + ignoreArray = true + } else if pVersion == 2 { + sender, err = qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(questdbC.httpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) + } + assert.NoError(t, err) + default: + panic(protocol) + } + if ignoreArray && tc.name == "double array" { + return + } + + dropTable(t, tc.tableName, questdbC.httpAddress) + err = tc.writerFn(sender) + assert.NoError(t, err) - err = sender.Flush(ctx) - assert.NoError(t, err) + err = sender.Flush(ctx) + assert.NoError(t, err) - assert.Eventually(t, func() bool { - data := queryTableData(t, tc.tableName, questdbC.httpAddress) - return reflect.DeepEqual(tc.expected, data) - }, eventualDataTimeout, 100*time.Millisecond) + assert.Eventually(t, func() bool { + data := queryTableData(t, tc.tableName, questdbC.httpAddress) + return reflect.DeepEqual(tc.expected, data) + }, eventualDataTimeout, 100*time.Millisecond) - sender.Close(ctx) - questdbC.Stop(ctx) - }) + sender.Close(ctx) + questdbC.Stop(ctx) + }) + } } } } diff --git a/interop_test.go b/interop_test.go index 089c777..88f0ec0 100644 --- a/interop_test.go +++ b/interop_test.go @@ -75,7 +75,7 @@ func TestTcpClientInterop(t *testing.T) { srv, err := newTestTcpServer(sendToBackChannel) assert.NoError(t, err) - sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr())) + sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) assert.NoError(t, err) sender.Table(tc.Table) @@ -129,7 +129,7 @@ func TestHttpClientInterop(t *testing.T) { srv, err := newTestHttpServer(sendToBackChannel) assert.NoError(t, err) - sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr())) + sender, err := qdb.NewLineSender(ctx, qdb.WithHttp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) assert.NoError(t, err) sender.Table(tc.Table) diff --git a/ndarray.go b/ndarray.go new file mode 100644 index 0000000..70040b7 --- /dev/null +++ b/ndarray.go @@ -0,0 +1,221 @@ +package questdb + +import ( + "errors" + "fmt" +) + +const ( + // MaxDimensions defines the maximum dims of NdArray + MaxDimensions = 32 +) + +// NdArrayElementType represents the constraint for numeric types that can be used in NdArray +type NdArrayElementType interface { + ~float64 +} + +// NdArray represents a generic n-dimensional array with shape validation. +// It's designed to be used with the [LineSender.Float64ArrayNDColumn] method for sending +// multi-dimensional arrays to QuestDB via the ILP protocol. +// +// NdArray instances are meant to be reused across multiple calls to the sender +// to avoid memory allocations. Use Append to populate data and +// ResetAppendIndex to reset the array for reuse after sending data. +// +// By default, all values in the array are initialized to zero. +type NdArray[T NdArrayElementType] struct { + data []T + shape []uint + appendIndex uint +} + +// NewNDArray creates a new NdArray with the specified shape. +// All elements are initialized to zero by default. +func NewNDArray[T NdArrayElementType](shape ...uint) (*NdArray[T], error) { + if err := validateShape(shape); err != nil { + return nil, fmt.Errorf("invalid shape: %w", err) + } + totalElements := product(shape) + data := make([]T, totalElements) + shapeSlice := make([]uint, len(shape)) + copy(shapeSlice, shape) + return &NdArray[T]{ + shape: shapeSlice, + data: data, + appendIndex: 0, + }, nil +} + +// Shape returns a copy of the array's shape +func (n *NdArray[T]) Shape() []uint { + shape := make([]uint, len(n.shape)) + copy(shape, n.shape) + return shape +} + +// NDims returns the number of dimensions +func (n *NdArray[T]) NDims() int { + return len(n.shape) +} + +// Size returns the total number of elements +func (n *NdArray[T]) Size() int { + return len(n.data) +} + +// Set sets a value at the specified multi-dimensional position +func (n *NdArray[T]) Set(v T, positions ...uint) error { + if len(positions) != n.NDims() { + return fmt.Errorf("position dimensions (%d) don't match array dimensions (%d)", len(positions), n.NDims()) + } + + index, err := n.positionsToIndex(positions) + if err != nil { + return err + } + + n.data[index] = v + return nil +} + +// Get retrieves a value at the specified multi-dimensional position +func (n *NdArray[T]) Get(positions ...uint) (T, error) { + var zero T + if len(positions) != n.NDims() { + return zero, fmt.Errorf("position dimensions (%d) don't match array dimensions (%d)", len(positions), n.NDims()) + } + + index, err := n.positionsToIndex(positions) + if err != nil { + return zero, err + } + + return n.data[index], nil +} + +// Reshape creates a new NdArray with a different shape but same data +func (n *NdArray[T]) Reshape(newShape ...uint) (*NdArray[T], error) { + if err := validateShape(newShape); err != nil { + return nil, fmt.Errorf("invalid new shape: %v", err) + } + + if uint(len(n.data)) != product(newShape) { + return nil, fmt.Errorf("new shape size (%d) doesn't match data size (%d)", + product(newShape), len(n.data)) + } + + // Create new array sharing the same data + newArray := &NdArray[T]{ + shape: make([]uint, len(newShape)), + data: n.data, // Share the same underlying data + } + copy(newArray.shape, newShape) + + return newArray, nil +} + +// Append adds a value to the array sequentially at the current append index. +// Returns true if there's more space for additional values, false if the array is now full. +// Use ResetAppendIndex() to reuse the array for multiple ILP messages. +// +// Example: +// +// arr, _ := NewNDArray[float64](2, 3) // 2x3 array (6 elements total) +// hasMore, _ := arr.Append(1.0) // hasMore = true, index now at 1 +// hasMore, _ = arr.Append(2.0) // hasMore = true, index now at 2 +// // ... append 4 more values +// hasMore, _ = arr.Append(6.0) // hasMore = false, array is full +// +// // To reuse the array: +// arr.ResetAppendIndex() +// arr.Append(10.0) // overwrites +// ... +func (n *NdArray[T]) Append(val T) (bool, error) { + if n.appendIndex >= uint(len(n.data)) { + return false, errors.New("array is full") + } + n.data[n.appendIndex] = val + n.appendIndex++ + return n.appendIndex < uint(len(n.data)), nil +} + +// ResetAppendIndex resets the append index to 0, allowing the NdArray to be reused +// for multiple append operations. This is useful for reusing arrays across multiple +// messages/rows ingestion without reallocating memory. +// +// Example: +// +// arr, _ := NewNDArray[float64](2) // 1D array with 3 elements +// arr.Append(2.0) +// arr.Append(3.0) // array is now full +// +// // sender.Float64ArrayNDColumn(arr) +// +// arr.ResetAppendIndex() // reset for reuse +// arr.Append(4.0) +// arr.Append(5.0) +func (n *NdArray[T]) ResetAppendIndex() { + n.appendIndex = 0 +} + +// Data returns the underlying data slice +func (n *NdArray[T]) Data() []T { + return n.data +} + +// Fill fills the entire array with the specified value +func (n *NdArray[T]) Fill(value T) { + for i := range n.data { + n.data[i] = value + } + n.appendIndex = uint(len(n.data)) // Mark as full +} + +func (n *NdArray[T]) positionsToIndex(positions []uint) (int, error) { + for i, pos := range positions { + if pos >= n.shape[i] { + return 0, fmt.Errorf("position[%d]=%d is out of bounds for dimension size %d", + i, pos, n.shape[i]) + } + } + + index := 0 + for i, pos := range positions { + index += int(pos) * int(product(n.shape[i+1:])) + } + return index, nil +} + +func validateShape(shape []uint) error { + if len(shape) == 0 { + return errors.New("shape cannot be empty") + } + + if len(shape) > MaxDimensions { + return fmt.Errorf("too many dimensions: %d exceeds maximum of %d", + len(shape), MaxDimensions) + } + + totalElements := product(shape) + if totalElements > MaxArrayElements { + return fmt.Errorf("array too large: %d elements exceeds maximum of %d", + totalElements, MaxArrayElements) + } + + return nil +} + +func product(s []uint) uint { + if len(s) == 0 { + return 1 + } + p := uint(1) + for _, v := range s { + if v != 0 && p > MaxArrayElements/v { + return MaxArrayElements + 1 + } + p *= v + } + return p +} diff --git a/ndarray_test.go b/ndarray_test.go new file mode 100644 index 0000000..acd0fde --- /dev/null +++ b/ndarray_test.go @@ -0,0 +1,320 @@ +/******************************************************************************* + * ___ _ ____ ____ + * / _ \ _ _ ___ ___| |_| _ \| __ ) + * | | | | | | |/ _ \/ __| __| | | | _ \ + * | |_| | |_| | __/\__ \ |_| |_| | |_) | + * \__\_\\__,_|\___||___/\__|____/|____/ + * + * Copyright (c) 2014-2019 Appsicle + * Copyright (c) 2019-2022 QuestDB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ******************************************************************************/ + +package questdb_test + +import ( + "testing" + + qdb "github.com/questdb/go-questdb-client/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNew_ValidShapes(t *testing.T) { + testCases := []struct { + name string + shape []uint + expected []uint + }{ + {"1D array", []uint{5}, []uint{5}}, + {"2D array", []uint{3, 4}, []uint{3, 4}}, + {"3D array", []uint{2, 3, 4}, []uint{2, 3, 4}}, + {"single element", []uint{1}, []uint{1}}, + {"large array", []uint{100, 200}, []uint{100, 200}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + arr, err := qdb.NewNDArray[float64](tc.shape...) + require.NoError(t, err) + assert.Equal(t, tc.expected, arr.Shape()) + assert.Equal(t, len(tc.shape), arr.NDims()) + + expectedSize := uint(1) + for _, dim := range tc.shape { + expectedSize *= dim + } + assert.Equal(t, int(expectedSize), arr.Size()) + }) + } +} + +func TestNew_InvalidShapes(t *testing.T) { + testCases := []struct { + name string + shape []uint + }{ + {"empty shape", []uint{}}, + {"too many dimensions", make([]uint, qdb.MaxDimensions+1)}, + {"too many elements", []uint{qdb.MaxArrayElements + 1}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Initialize shape with 1 for "too many dimensions" test + if tc.name == "too many dimensions" { + for i := range tc.shape { + tc.shape[i] = 1 + } + } + + arr, err := qdb.NewNDArray[float64](tc.shape...) + assert.Error(t, err) + assert.Nil(t, arr) + }) + } +} + +func TestSetGet_ValidPositions(t *testing.T) { + arr, err := qdb.NewNDArray[float64](3, 4) + require.NoError(t, err) + + testCases := []struct { + positions []uint + value float64 + }{ + {[]uint{0, 0}, 1.5}, + {[]uint{2, 3}, 42.0}, + {[]uint{1, 2}, -7.5}, + } + + for _, tc := range testCases { + err := arr.Set(tc.value, tc.positions...) + require.NoError(t, err) + + retrieved, err := arr.Get(tc.positions...) + require.NoError(t, err) + assert.Equal(t, tc.value, retrieved) + } +} + +func TestSetGet_InvalidPositions(t *testing.T) { + arr, err := qdb.NewNDArray[float64](3, 4) + require.NoError(t, err) + + testCases := []struct { + name string + positions []uint + }{ + {"wrong dimensions", []uint{0}}, + {"out of bounds first dim", []uint{3, 0}}, + {"out of bounds second dim", []uint{0, 4}}, + {"too many dimensions", []uint{0, 0, 0}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := arr.Set(1.0, tc.positions...) + assert.Error(t, err) + + _, err = arr.Get(tc.positions...) + assert.Error(t, err) + }) + } +} + +func TestReshape_ValidShapes(t *testing.T) { + arr, err := qdb.NewNDArray[float64](2, 3) + require.NoError(t, err) + + value := 1.0 + for i := uint(0); i < 2; i++ { + for j := uint(0); j < 3; j++ { + err := arr.Set(value, i, j) + require.NoError(t, err) + value++ + } + } + + testCases := []struct { + name string + newShape []uint + }{ + {"1D reshape", []uint{6}}, + {"3D reshape", []uint{1, 2, 3}}, + {"different 2D", []uint{3, 2}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + reshaped, err := arr.Reshape(tc.newShape...) + require.NoError(t, err) + assert.Equal(t, tc.newShape, reshaped.Shape()) + assert.Equal(t, arr.Size(), reshaped.Size()) + assert.Equal(t, arr.Data(), reshaped.Data()) + }) + } +} + +func TestReshape_InvalidShapes(t *testing.T) { + arr, err := qdb.NewNDArray[float64](2, 3) + require.NoError(t, err) + + testCases := []struct { + name string + newShape []uint + }{ + {"wrong size", []uint{5}}, + {"empty shape", []uint{}}, + {"too large1", []uint{qdb.MaxArrayElements + 1}}, + {"too large2", []uint{4294967296, 4294967296}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + reshaped, err := arr.Reshape(tc.newShape...) + assert.Error(t, err) + assert.Nil(t, reshaped) + }) + } +} + +func TestAppend(t *testing.T) { + arr, err := qdb.NewNDArray[float64](2, 2) + require.NoError(t, err) + + values := []float64{1.0, 2.0, 3.0, 4.0} + for i, val := range values { + hasMore, err := arr.Append(val) + require.NoError(t, err) + + if i < len(values)-1 { + assert.True(t, hasMore, "should have more space") + } else { + assert.False(t, hasMore, "should be full") + } + } + + assert.Equal(t, values, arr.Data()) + _, err = arr.Append(5.0) + assert.Error(t, err) + assert.Contains(t, err.Error(), "array is full") +} + +func TestResetAppendIndex(t *testing.T) { + arr, err := qdb.NewNDArray[float64](2, 2) + require.NoError(t, err) + for i := 0; i < 4; i++ { + _, err := arr.Append(float64(i)) + require.NoError(t, err) + } + + // Reset + arr.ResetAppendIndex() + hasMore, err := arr.Append(10.0) + require.NoError(t, err) + assert.True(t, hasMore) + + // The first element was overwritten + val, err := arr.Get(0, 0) + require.NoError(t, err) + assert.Equal(t, 10.0, val) +} + +func TestFill(t *testing.T) { + arr, err := qdb.NewNDArray[float64](2, 3) + require.NoError(t, err) + + fillValue := 42.0 + arr.Fill(fillValue) + + for i := uint(0); i < 2; i++ { + for j := uint(0); j < 3; j++ { + val, err := arr.Get(i, j) + require.NoError(t, err) + assert.Equal(t, fillValue, val) + } + } + + _, err = arr.Append(1.0) + assert.Error(t, err) + assert.Contains(t, err.Error(), "array is full") +} + +func TestShape_ReturnsImmutableCopy(t *testing.T) { + originalShape := []uint{3, 4} + arr, err := qdb.NewNDArray[float64](originalShape...) + require.NoError(t, err) + + shape := arr.Shape() + shape[0] = 999 + actualShape := arr.Shape() + assert.Equal(t, originalShape, actualShape) + assert.NotEqual(t, uint(999), actualShape[0]) +} + +func TestGetData_SharedReference(t *testing.T) { + arr, err := qdb.NewNDArray[float64](2, 2) + require.NoError(t, err) + + data := arr.Data() + data[0] = 42.0 + val, err := arr.Get(0, 0) + require.NoError(t, err) + assert.Equal(t, 42.0, val) +} + +func TestMaxLimits(t *testing.T) { + t.Run("max dimensions", func(t *testing.T) { + shape := make([]uint, qdb.MaxDimensions) + for i := range shape { + shape[i] = 1 + } + + arr, err := qdb.NewNDArray[float64](shape...) + require.NoError(t, err) + assert.Equal(t, qdb.MaxDimensions, arr.NDims()) + }) + + t.Run("max elements", func(t *testing.T) { + arr, err := qdb.NewNDArray[float64](qdb.MaxArrayElements) + require.NoError(t, err) + assert.Equal(t, qdb.MaxArrayElements, arr.Size()) + }) +} + +func TestPositionsToIndex(t *testing.T) { + arr, err := qdb.NewNDArray[float64](3, 4) + require.NoError(t, err) + + testCases := []struct { + positions []uint + value float64 + }{ + {[]uint{0, 0}, 100.0}, + {[]uint{0, 1}, 101.0}, + {[]uint{1, 0}, 102.0}, + {[]uint{2, 3}, 103.0}, + } + + for _, tc := range testCases { + err := arr.Set(tc.value, tc.positions...) + require.NoError(t, err) + + retrieved, err := arr.Get(tc.positions...) + require.NoError(t, err) + assert.Equal(t, tc.value, retrieved) + } +} diff --git a/sender.go b/sender.go index 87b636f..b6f38f2 100644 --- a/sender.go +++ b/sender.go @@ -120,6 +120,64 @@ type LineSender interface { // '-', '*' '%%', '~', or a non-printable char. BoolColumn(name string, val bool) LineSender + // Float64Array1DColumn adds an array of 64-bit floats (double array) to the ILP message. + // + // Column name cannot contain any of the following characters: + // '\n', '\r', '?', '.', ',', "', '"', '\', '/', ':', ')', '(', '+', + // '-', '*' '%%', '~', or a non-printable char. + Float64Array1DColumn(name string, values []float64) LineSender + + // Float64Array2DColumn adds a 2D array of 64-bit floats (double 2D array) to the ILP message. + // + // The values parameter must have a regular (rectangular) shape - all rows must have + // exactly the same length. If the array has irregular shape, this method returns an error. + // + // Example of valid input: + // values := [][]float64{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}} // 3x2 regular shape + // + // Example of invalid input: + // values := [][]float64{{1.0, 2.0}, {3.0}, {4.0, 5.0, 6.0}} // irregular shape - returns error + // + // Column name cannot contain any of the following characters: + // '\n', '\r', '?', '.', ',', "', '"', '\', '/', ':', ')', '(', '+', + // '-', '*' '%%', '~', or a non-printable char. + Float64Array2DColumn(name string, values [][]float64) LineSender + + // Float64Array3DColumn adds a 3D array of 64-bit floats (double 3D array) to the ILP message. + // + // The values parameter must have a regular (cuboid) shape - all dimensions must have + // consistent sizes throughout. If the array has irregular shape, this method returns an error. + // + // Example of valid input: + // values := [][][]float64{ + // {{1.0, 2.0}, {3.0, 4.0}}, // 2x2 matrix + // {{5.0, 6.0}, {7.0, 8.0}}, // 2x2 matrix (same shape) + // } // 2x2x2 regular shape + // + // Example of invalid input: + // values := [][][]float64{ + // {{1.0, 2.0}, {3.0, 4.0}}, // 2x2 matrix + // {{5.0}, {6.0, 7.0, 8.0}}, // irregular matrix - returns error + // } + // + // Column name cannot contain any of the following characters: + // '\n', '\r', '?', '.', ',', "', '"', '\', '/', ':', ')', '(', '+', + // '-', '*' '%%', '~', or a non-printable char. + Float64Array3DColumn(name string, values [][][]float64) LineSender + + // Float64ArrayNDColumn adds an n-dimensional array of 64-bit floats (double n-D array) to the ILP message. + // + // Example usage: + // // Create a 2x3x4 array + // arr, _ := questdb.NewNDArray[float64](2, 3, 4) + // arr.Fill(1.5) + // sender.Float64ArrayNDColumn("ndarray_col", arr) + // + // Column name cannot contain any of the following characters: + // '\n', '\r', '?', '.', ',', "', '"', '\', '/', ':', ')', '(', '+', + // '-', '*' '%%', '~', or a non-printable char. + Float64ArrayNDColumn(name string, values *NdArray[float64]) LineSender + // At sets the designated timestamp value and finalizes the ILP // message. // @@ -189,6 +247,14 @@ const ( tlsInsecureSkipVerify tlsMode = 2 ) +type protocolVersion int64 + +const ( + protocolVersionUnset protocolVersion = 0 + ProtocolVersion1 protocolVersion = 1 + ProtocolVersion2 protocolVersion = 2 +) + type lineSenderConfig struct { senderType senderType address string @@ -213,6 +279,8 @@ type lineSenderConfig struct { // Auto-flush fields autoFlushRows int autoFlushInterval time.Duration + + protocolVersion protocolVersion } // LineSenderOption defines line sender config option. @@ -404,6 +472,21 @@ func WithAutoFlushInterval(interval time.Duration) LineSenderOption { } } +// WithProtocolVersion sets the ingestion protocol version. +// +// - HTTP transport automatically negotiates the protocol version by default(unset, STRONGLY RECOMMENDED). +// You can explicitly configure the protocol version to avoid the slight latency cost at connection time. +// - TCP transport does not negotiate the protocol version and uses [ProtocolVersion1] by +// default. You must explicitly set [ProtocolVersion2] in order to ingest +// arrays. +// +// NOTE: QuestDB server version 9.0.0 or later is required for [ProtocolVersion2]. +func WithProtocolVersion(version protocolVersion) LineSenderOption { + return func(s *lineSenderConfig) { + s.protocolVersion = version + } +} + // LineSenderFromEnv creates a LineSender with a config string defined by the QDB_CLIENT_CONF // environment variable. See LineSenderFromConf for the config string format. // @@ -556,7 +639,7 @@ func newLineSender(ctx context.Context, conf *lineSenderConfig) (LineSender, err if err != nil { return nil, err } - return newHttpLineSender(conf) + return newHttpLineSender(ctx, conf) } return nil, errors.New("sender type is not specified: use WithHttp or WithTcp") } @@ -638,6 +721,10 @@ func validateConf(conf *lineSenderConfig) error { if conf.autoFlushInterval < 0 { return fmt.Errorf("auto flush interval is negative: %d", conf.autoFlushInterval) } + if conf.protocolVersion < protocolVersionUnset || conf.protocolVersion > ProtocolVersion2 { + return errors.New("current client only supports protocol version 1(text format for all datatypes), " + + "2(binary format for part datatypes) or explicitly unset") + } return nil } diff --git a/sender_pool.go b/sender_pool.go index a59a81a..7da8572 100644 --- a/sender_pool.go +++ b/sender_pool.go @@ -181,7 +181,7 @@ func (p *LineSenderPool) Sender(ctx context.Context) (LineSender, error) { return nil, errHttpOnlySender } } - s, err = newHttpLineSender(conf) + s, err = newHttpLineSender(ctx, conf) } if err != nil { @@ -324,6 +324,26 @@ func (ps *pooledSender) BoolColumn(name string, val bool) LineSender { return ps } +func (ps *pooledSender) Float64Array1DColumn(name string, values []float64) LineSender { + ps.wrapped.Float64Array1DColumn(name, values) + return ps +} + +func (ps *pooledSender) Float64Array2DColumn(name string, values [][]float64) LineSender { + ps.wrapped.Float64Array2DColumn(name, values) + return ps +} + +func (ps *pooledSender) Float64Array3DColumn(name string, values [][][]float64) LineSender { + ps.wrapped.Float64Array3DColumn(name, values) + return ps +} + +func (ps *pooledSender) Float64ArrayNDColumn(name string, values *NdArray[float64]) LineSender { + ps.wrapped.Float64ArrayNDColumn(name, values) + return ps +} + func (ps *pooledSender) AtNow(ctx context.Context) error { err := ps.wrapped.AtNow(ctx) if err != nil { diff --git a/tcp_integration_test.go b/tcp_integration_test.go index 2b0f3c3..9ac2cef 100644 --- a/tcp_integration_test.go +++ b/tcp_integration_test.go @@ -59,6 +59,7 @@ func (suite *integrationTestSuite) TestE2EWriteInBatches() { sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress)) assert.NoError(suite.T(), err) defer sender.Close(ctx) + dropTable(suite.T(), testTable, questdbC.httpAddress) for i := 0; i < n; i++ { for j := 0; j < nBatch; j++ { @@ -112,6 +113,7 @@ func (suite *integrationTestSuite) TestE2EImplicitFlush() { sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress), qdb.WithInitBufferSize(bufCap)) assert.NoError(suite.T(), err) defer sender.Close(ctx) + dropTable(suite.T(), testTable, questdbC.httpAddress) for i := 0; i < 10*bufCap; i++ { err = sender. @@ -147,6 +149,7 @@ func (suite *integrationTestSuite) TestE2ESuccessfulAuth() { ) assert.NoError(suite.T(), err) + dropTable(suite.T(), testTable, questdbC.httpAddress) err = sender. Table(testTable). StringColumn("str_col", "foobar"). @@ -169,7 +172,7 @@ func (suite *integrationTestSuite) TestE2ESuccessfulAuth() { expected := tableData{ Columns: []column{ - {"str_col", "STRING"}, + {"str_col", "VARCHAR"}, {"timestamp", "TIMESTAMP"}, }, Dataset: [][]interface{}{ @@ -214,6 +217,7 @@ func (suite *integrationTestSuite) TestE2EFailedAuth() { return } + dropTable(suite.T(), testTable, questdbC.httpAddress) err = sender. Table(testTable). StringColumn("str_col", "barbaz"). @@ -252,6 +256,7 @@ func (suite *integrationTestSuite) TestE2EWritesWithTlsProxy() { ) assert.NoError(suite.T(), err) defer sender.Close(ctx) + dropTable(suite.T(), testTable, questdbC.httpAddress) err = sender. Table(testTable). @@ -270,7 +275,7 @@ func (suite *integrationTestSuite) TestE2EWritesWithTlsProxy() { expected := tableData{ Columns: []column{ - {"str_col", "STRING"}, + {"str_col", "VARCHAR"}, {"timestamp", "TIMESTAMP"}, }, Dataset: [][]interface{}{ @@ -305,6 +310,7 @@ func (suite *integrationTestSuite) TestE2ESuccessfulAuthWithTlsProxy() { qdb.WithTlsInsecureSkipVerify(), ) assert.NoError(suite.T(), err) + dropTable(suite.T(), testTable, questdbC.httpAddress) err = sender. Table(testTable). @@ -328,7 +334,7 @@ func (suite *integrationTestSuite) TestE2ESuccessfulAuthWithTlsProxy() { expected := tableData{ Columns: []column{ - {"str_col", "STRING"}, + {"str_col", "VARCHAR"}, {"timestamp", "TIMESTAMP"}, }, Dataset: [][]interface{}{ @@ -344,6 +350,66 @@ func (suite *integrationTestSuite) TestE2ESuccessfulAuthWithTlsProxy() { }, eventualDataTimeout, 100*time.Millisecond) } +func (suite *integrationTestSuite) TestDoubleArrayColumn() { + if testing.Short() { + suite.T().Skip("skipping integration test") + } + + ctx := context.Background() + questdbC, err := setupQuestDB(ctx, noAuth) + assert.NoError(suite.T(), err) + defer questdbC.Stop(ctx) + + sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(questdbC.ilpAddress), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) + assert.NoError(suite.T(), err) + defer sender.Close(ctx) + dropTable(suite.T(), testTable, questdbC.httpAddress) + + values1D := []float64{1.0, 2.0, 3.0, 4.0, 5.0} + values2D := [][]float64{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}} + values3D := [][][]float64{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}} + arrayND, err := qdb.NewNDArray[float64](2, 2, 1, 2) + assert.NoError(suite.T(), err) + arrayND.Fill(11.0) + + err = sender. + Table(testTable). + Float64Array1DColumn("array_1d", values1D). + Float64Array2DColumn("array_2d", values2D). + Float64Array3DColumn("array_3d", values3D). + Float64ArrayNDColumn("array_nd", arrayND). + At(ctx, time.UnixMicro(1)) + assert.NoError(suite.T(), err) + + err = sender.Flush(ctx) + assert.NoError(suite.T(), err) + + // Expected results + expected := tableData{ + Columns: []column{ + {"array_1d", "ARRAY"}, + {"array_2d", "ARRAY"}, + {"array_3d", "ARRAY"}, + {"array_nd", "ARRAY"}, + {"timestamp", "TIMESTAMP"}, + }, + Dataset: [][]interface{}{ + { + []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5)}, + []interface{}{[]interface{}{float64(1), float64(2)}, []interface{}{float64(3), float64(4)}, []interface{}{float64(5), float64(6)}}, + []interface{}{[]interface{}{[]interface{}{float64(1), float64(2)}, []interface{}{float64(3), float64(4)}}, []interface{}{[]interface{}{float64(5), float64(6)}, []interface{}{float64(7), float64(8)}}}, + []interface{}{[]interface{}{[]interface{}{[]interface{}{float64(11), float64(11)}}, []interface{}{[]interface{}{float64(11), float64(11)}}}, []interface{}{[]interface{}{[]interface{}{float64(11), float64(11)}}, []interface{}{[]interface{}{float64(11), float64(11)}}}}, + "1970-01-01T00:00:00.000001Z"}, + }, + Count: 1, + } + + assert.Eventually(suite.T(), func() bool { + data := queryTableData(suite.T(), testTable, questdbC.httpAddress) + return reflect.DeepEqual(expected, data) + }, eventualDataTimeout, 100*time.Millisecond) +} + type tableData struct { Columns []column `json:"columns"` Dataset [][]interface{} `json:"dataset"` @@ -355,6 +421,23 @@ type column struct { Type string `json:"type"` } +func dropTable(t *testing.T, tableName, address string) { + // We always query data using the QuestDB container over http + address = "http://" + address + u, err := url.Parse(address) + assert.NoError(t, err) + + u.Path += "exec" + params := url.Values{} + params.Add("query", "drop table if exists '"+tableName+"'") + u.RawQuery = params.Encode() + url := fmt.Sprintf("%v", u) + + res, err := http.Get(url) + assert.NoError(t, err) + defer res.Body.Close() +} + func queryTableData(t *testing.T, tableName, address string) tableData { // We always query data using the QuestDB container over http address = "http://" + address diff --git a/tcp_sender.go b/tcp_sender.go index d4bf86d..f46c5de 100644 --- a/tcp_sender.go +++ b/tcp_sender.go @@ -33,6 +33,7 @@ import ( "crypto/rand" "crypto/tls" "encoding/base64" + "errors" "fmt" "math/big" "net" @@ -45,7 +46,11 @@ type tcpLineSender struct { conn net.Conn } -func newTcpLineSender(ctx context.Context, conf *lineSenderConfig) (*tcpLineSender, error) { +type tcpLineSenderV2 struct { + tcpLineSender +} + +func newTcpLineSender(ctx context.Context, conf *lineSenderConfig) (LineSender, error) { var ( d net.Dialer key *ecdsa.PrivateKey @@ -131,7 +136,13 @@ func newTcpLineSender(ctx context.Context, conf *lineSenderConfig) (*tcpLineSend s.conn = conn - return s, nil + if conf.protocolVersion == protocolVersionUnset || conf.protocolVersion == ProtocolVersion1 { + return s, nil + } else { + return &tcpLineSenderV2{ + *s, + }, nil + } } func (s *tcpLineSender) Close(_ context.Context) error { @@ -183,6 +194,26 @@ func (s *tcpLineSender) BoolColumn(name string, val bool) LineSender { return s } +func (s *tcpLineSender) Float64Array1DColumn(name string, values []float64) LineSender { + s.buf.SetLastErr(errors.New("current protocol version does not support double-array")) + return s +} + +func (s *tcpLineSender) Float64Array2DColumn(name string, values [][]float64) LineSender { + s.buf.SetLastErr(errors.New("current protocol version does not support double-array")) + return s +} + +func (s *tcpLineSender) Float64Array3DColumn(name string, values [][][]float64) LineSender { + s.buf.SetLastErr(errors.New("current protocol version does not support double-array")) + return s +} + +func (s *tcpLineSender) Float64ArrayNDColumn(name string, values *NdArray[float64]) LineSender { + s.buf.SetLastErr(errors.New("current protocol version does not support double-array")) + return s +} + func (s *tcpLineSender) Flush(ctx context.Context) error { err := s.buf.LastErr() s.buf.ClearLastErr() @@ -240,9 +271,9 @@ func (s *tcpLineSender) At(ctx context.Context, ts time.Time) error { return nil } -// Messages returns a copy of accumulated ILP messages that are not +// Messages returns the accumulated ILP messages that are not // flushed to the TCP connection yet. Useful for debugging purposes. -func (s *tcpLineSender) Messages() string { +func (s *tcpLineSender) Messages() []byte { return s.buf.Messages() } @@ -255,3 +286,63 @@ func (s *tcpLineSender) MsgCount() int { func (s *tcpLineSender) BufLen() int { return s.buf.Len() } + +func (s *tcpLineSenderV2) Table(name string) LineSender { + s.buf.Table(name) + return s +} + +func (s *tcpLineSenderV2) Symbol(name, val string) LineSender { + s.buf.Symbol(name, val) + return s +} + +func (s *tcpLineSenderV2) Int64Column(name string, val int64) LineSender { + s.buf.Int64Column(name, val) + return s +} + +func (s *tcpLineSenderV2) Long256Column(name string, val *big.Int) LineSender { + s.buf.Long256Column(name, val) + return s +} + +func (s *tcpLineSenderV2) TimestampColumn(name string, ts time.Time) LineSender { + s.buf.TimestampColumn(name, ts) + return s +} + +func (s *tcpLineSenderV2) StringColumn(name, val string) LineSender { + s.buf.StringColumn(name, val) + return s +} + +func (s *tcpLineSenderV2) BoolColumn(name string, val bool) LineSender { + s.buf.BoolColumn(name, val) + return s +} + +func (s *tcpLineSenderV2) Float64Column(name string, val float64) LineSender { + s.buf.Float64ColumnBinary(name, val) + return s +} + +func (s *tcpLineSenderV2) Float64Array1DColumn(name string, values []float64) LineSender { + s.buf.Float64Array1DColumn(name, values) + return s +} + +func (s *tcpLineSenderV2) Float64Array2DColumn(name string, values [][]float64) LineSender { + s.buf.Float64Array2DColumn(name, values) + return s +} + +func (s *tcpLineSenderV2) Float64Array3DColumn(name string, values [][][]float64) LineSender { + s.buf.Float64Array3DColumn(name, values) + return s +} + +func (s *tcpLineSenderV2) Float64ArrayNDColumn(name string, values *NdArray[float64]) LineSender { + s.buf.Float64ArrayNDColumn(name, values) + return s +} diff --git a/tcp_sender_test.go b/tcp_sender_test.go index d738d93..688494e 100644 --- a/tcp_sender_test.go +++ b/tcp_sender_test.go @@ -67,6 +67,10 @@ func TestTcpHappyCasesFromConf(t *testing.T) { config: fmt.Sprintf("tcp::addr=%s;init_buf_size=%d;", addr, initBufSize), }, + { + name: "protocol_version", + config: fmt.Sprintf("tcp::addr=%s;protocol_version=2;", addr), + }, } for _, tc := range testCases { @@ -100,6 +104,10 @@ func TestTcpHappyCasesFromEnv(t *testing.T) { config: fmt.Sprintf("tcp::addr=%s;init_buf_size=%d;", addr, initBufSize), }, + { + name: "protocol_version", + config: fmt.Sprintf("tcp::addr=%s;protocol_version=2;", addr), + }, } for _, tc := range testCases { @@ -133,6 +141,11 @@ func TestTcpPathologicalCasesFromEnv(t *testing.T) { config: "tcp::auto_flush_rows=5;", expectedErr: "autoFlushRows setting is not available", }, + { + name: "protocol_version", + config: "tcp::protocol_version=3;", + expectedErr: "current client only supports protocol version 1 (text format for all datatypes), 2 (binary format for part datatypes) or explicitly unset", + }, } for _, tc := range testCases { @@ -197,6 +210,11 @@ func TestTcpPathologicalCasesFromConf(t *testing.T) { config: "tCp::addr=localhost:1234;", expectedErr: "invalid schema", }, + { + name: "protocol version", + config: "tcp::protocol_version=abc;", + expectedErr: "invalid protocol_version value", + }, } for _, tc := range testCases { @@ -301,6 +319,53 @@ func TestErrorOnContextDeadline(t *testing.T) { t.Fail() } +func TestArrayColumnUnsupportedInTCPProtocolV1(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(50*time.Millisecond)) + defer cancel() + + srv, err := newTestTcpServer(readAndDiscard) + assert.NoError(t, err) + defer srv.Close() + sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion1)) + assert.NoError(t, err) + defer sender.Close(ctx) + + values1D := []float64{1.0, 2.0, 3.0, 4.0, 5.0} + values2D := [][]float64{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}} + values3D := [][][]float64{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}} + arrayND, err := qdb.NewNDArray[float64](2, 2, 1, 2) + assert.NoError(t, err) + arrayND.Fill(11.0) + + err = sender. + Table(testTable). + Float64Array1DColumn("array_1d", values1D). + At(ctx, time.UnixMicro(1)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "current protocol version does not support double-array") + + err = sender. + Table(testTable). + Float64Array2DColumn("array_2d", values2D). + At(ctx, time.UnixMicro(2)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "current protocol version does not support double-array") + + err = sender. + Table(testTable). + Float64Array3DColumn("array_3d", values3D). + At(ctx, time.UnixMicro(3)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "current protocol version does not support double-array") + + err = sender. + Table(testTable). + Float64ArrayNDColumn("array_nd", arrayND). + At(ctx, time.UnixMicro(4)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "current protocol version does not support double-array") +} + func BenchmarkLineSenderBatch1000(b *testing.B) { ctx := context.Background() @@ -308,10 +373,17 @@ func BenchmarkLineSenderBatch1000(b *testing.B) { assert.NoError(b, err) defer srv.Close() - sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr())) + sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) assert.NoError(b, err) defer sender.Close(ctx) + // Prepare test array data + values1D := []float64{1.0, 2.0, 3.0, 4.0, 5.0} + values2D := [][]float64{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}} + values3D := [][][]float64{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}} + arrayND, _ := qdb.NewNDArray[float64](2, 3) + arrayND.Fill(1.5) + b.ResetTimer() for i := 0; i < b.N; i++ { for j := 0; j < 1000; j++ { @@ -323,6 +395,10 @@ func BenchmarkLineSenderBatch1000(b *testing.B) { StringColumn("str_col", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua"). BoolColumn("bool_col", true). TimestampColumn("timestamp_col", time.UnixMicro(42)). + Float64Array1DColumn("array_1d", values1D). + Float64Array2DColumn("array_2d", values2D). + Float64Array3DColumn("array_3d", values3D). + Float64ArrayNDColumn("array_nd", arrayND). At(ctx, time.UnixMicro(int64(1000*i))) } sender.Flush(ctx) @@ -336,10 +412,16 @@ func BenchmarkLineSenderNoFlush(b *testing.B) { assert.NoError(b, err) defer srv.Close() - sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr())) + sender, err := qdb.NewLineSender(ctx, qdb.WithTcp(), qdb.WithAddress(srv.Addr()), qdb.WithProtocolVersion(qdb.ProtocolVersion2)) assert.NoError(b, err) defer sender.Close(ctx) + values1D := []float64{1.0, 2.0, 3.0, 4.0, 5.0} + values2D := [][]float64{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}} + values3D := [][][]float64{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}} + arrayND, _ := qdb.NewNDArray[float64](2, 3) + arrayND.Fill(1.5) + b.ResetTimer() for i := 0; i < b.N; i++ { sender. @@ -350,6 +432,10 @@ func BenchmarkLineSenderNoFlush(b *testing.B) { StringColumn("str_col", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua"). BoolColumn("bool_col", true). TimestampColumn("timestamp_col", time.UnixMicro(42)). + Float64Array1DColumn("array_1d", values1D). + Float64Array2DColumn("array_2d", values2D). + Float64Array3DColumn("array_3d", values3D). + Float64ArrayNDColumn("array_nd", arrayND). At(ctx, time.UnixMicro(int64(1000*i))) } sender.Flush(ctx) diff --git a/utils_test.go b/utils_test.go index 8ecef81..d37cf5d 100644 --- a/utils_test.go +++ b/utils_test.go @@ -53,12 +53,14 @@ const ( ) type testServer struct { - addr string - tcpListener net.Listener - serverType serverType - BackCh chan string - closeCh chan struct{} - wg sync.WaitGroup + addr string + tcpListener net.Listener + serverType serverType + protocolVersions []int + BackCh chan string + closeCh chan struct{} + wg sync.WaitGroup + settingsReqErrMsg string } func (t *testServer) Addr() string { @@ -66,24 +68,42 @@ func (t *testServer) Addr() string { } func newTestTcpServer(serverType serverType) (*testServer, error) { - return newTestServerWithProtocol(serverType, "tcp") + return newTestServerWithProtocol(serverType, "tcp", []int{}) } func newTestHttpServer(serverType serverType) (*testServer, error) { - return newTestServerWithProtocol(serverType, "http") + return newTestServerWithProtocol(serverType, "http", []int{1, 2}) } -func newTestServerWithProtocol(serverType serverType, protocol string) (*testServer, error) { +func newTestHttpServerWithErrMsg(serverType serverType, errMsg string) (*testServer, error) { tcp, err := net.Listen("tcp", "127.0.0.1:") if err != nil { return nil, err } s := &testServer{ - addr: tcp.Addr().String(), - tcpListener: tcp, - serverType: serverType, - BackCh: make(chan string, 1000), - closeCh: make(chan struct{}), + addr: tcp.Addr().String(), + tcpListener: tcp, + serverType: serverType, + BackCh: make(chan string, 1000), + closeCh: make(chan struct{}), + settingsReqErrMsg: errMsg, + } + go s.serveHttp() + return s, nil +} + +func newTestServerWithProtocol(serverType serverType, protocol string, protocolVersions []int) (*testServer, error) { + tcp, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + return nil, err + } + s := &testServer{ + addr: tcp.Addr().String(), + tcpListener: tcp, + serverType: serverType, + BackCh: make(chan string, 1000), + closeCh: make(chan struct{}), + protocolVersions: protocolVersions, } switch protocol { @@ -193,6 +213,58 @@ func (s *testServer) serveHttp() { var ( err error ) + if r.Method == "GET" && r.URL.Path == "/settings" { + if len(s.settingsReqErrMsg) != 0 { + w.WriteHeader(http.StatusInternalServerError) + data, err := json.Marshal(map[string]interface{}{ + "code": "500", + "message": s.settingsReqErrMsg, + }) + if err != nil { + panic(err) + } + w.Write(data) + return + } else if s.protocolVersions == nil { + w.WriteHeader(http.StatusNotFound) + data, err := json.Marshal(map[string]interface{}{ + "code": "404", + "message": "Not Found", + }) + if err != nil { + panic(err) + } + w.Write(data) + return + } + w.Header().Set("Content-Type", "application/json") + var data []byte + if len(s.protocolVersions) == 0 { + data, err = json.Marshal(map[string]interface{}{ + "version": "8.1.2", + }) + } else { + data, err = json.Marshal(map[string]interface{}{ + "config": map[string]interface{}{ + "release.type": "OSS", + "release.version": "[DEVELOPMENT]", + "http.settings.readonly": false, + "line.proto.support.versions": s.protocolVersions, + "ilp.proto.transports": []string{"tcp", "http"}, + "posthog.enabled": false, + "posthog.api.key": nil, + "cairo.max.file.name.length": 256, + }, + "preferences.version": 0, + "preferences": map[string]interface{}{}, + }) + } + if err != nil { + panic(err) + } + w.Write(data) + return + } switch s.serverType { case failFirstThenSendToBackChannel: