diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java index 9c7112f1c3..22d0ebce5c 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java @@ -102,7 +102,7 @@ public ExecuteResult execute( } new AvaticaParameterBinder( - preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator()) + preparedStatement, typedValues, ((ArrowFlightConnection) connection).getBufferAllocator()) .bind(typedValues); if (statementHandle.signature == null) { @@ -144,11 +144,13 @@ public ExecuteBatchResult executeBatch( throw new IllegalStateException("Prepared statement not found: " + statementHandle); } - final AvaticaParameterBinder binder = - new AvaticaParameterBinder( - preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator()); - for (int i = 0; i < parameterValuesList.size(); i++) { - binder.bind(parameterValuesList.get(i), i); + if (parameterValuesList.size() > 0) { + final AvaticaParameterBinder binder = + new AvaticaParameterBinder( + preparedStatement, parameterValuesList.get(0), ((ArrowFlightConnection) connection).getBufferAllocator()); + for (int i = 0; i < parameterValuesList.size(); i++) { + binder.bind(parameterValuesList.get(i), i); + } } // Update query diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java index 4c2a9b865f..4e03c0535f 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java @@ -16,7 +16,10 @@ */ package org.apache.arrow.driver.jdbc.utils; +import java.sql.Types; +import java.util.ArrayList; import java.util.List; + import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement; import org.apache.arrow.driver.jdbc.converter.impl.BinaryAvaticaParameterConverter; import org.apache.arrow.driver.jdbc.converter.impl.BoolAvaticaParameterConverter; @@ -42,9 +45,18 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; import org.apache.calcite.avatica.remote.TypedValue; +import static org.apache.arrow.vector.types.FloatingPointPrecision.SINGLE; +import static org.apache.calcite.avatica.ColumnMetaData.Rep.*; + /** * Convert Avatica PreparedStatement parameters from a list of TypedValue to Arrow and bind them to * the VectorSchemaRoot representing the PreparedStatement parameters. @@ -59,13 +71,15 @@ public class AvaticaParameterBinder { * Instantiate a new AvaticaParameterBinder. * * @param preparedStatement The PreparedStatement to bind parameters to. - * @param bufferAllocator The BufferAllocator to use for allocating memory. + * @param bufferAllocator The BufferAllocator to use for allocating memory. */ public AvaticaParameterBinder( - PreparedStatement preparedStatement, BufferAllocator bufferAllocator) { + PreparedStatement preparedStatement, List typedValues, BufferAllocator bufferAllocator) { this.parameters = - VectorSchemaRoot.create(preparedStatement.getParameterSchema(), bufferAllocator); + VectorSchemaRoot.create(makeSchema(typedValues), bufferAllocator); this.preparedStatement = preparedStatement; + + } /** @@ -77,19 +91,90 @@ public void bind(List typedValues) { bind(typedValues, 0); } + private ArrowType getArrowTypeFromTypedValue(TypedValue typedValue) { + switch (typedValue.type) { + case PRIMITIVE_BOOLEAN: + case BOOLEAN: + return new ArrowType.Bool(); + + case PRIMITIVE_BYTE: + case BYTE: + return new ArrowType.Int(8, true); + + case PRIMITIVE_CHAR: + case CHARACTER: + + case STRING: + return new ArrowType.Utf8(); + + case PRIMITIVE_SHORT: + case SHORT: + return new ArrowType.Int(16, true); + + case PRIMITIVE_INT: + case INTEGER: + return new ArrowType.Int(32, true); + + case PRIMITIVE_LONG: + case LONG: + return new ArrowType.Int(64, true); + + case PRIMITIVE_FLOAT: + case FLOAT: + return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); + + case PRIMITIVE_DOUBLE: + case DOUBLE: + return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + + case JAVA_SQL_TIME: + return new ArrowType.Time(TimeUnit.MILLISECOND, 32); + + case JAVA_SQL_TIMESTAMP: + // TODO: figure out TZ + return new ArrowType.Timestamp(TimeUnit.MILLISECOND, null); + + case JAVA_SQL_DATE: + case JAVA_UTIL_DATE: + return new ArrowType.Date(DateUnit.DAY); + + case BYTE_STRING: + return new ArrowType.Binary(); + + case NUMBER: + return new ArrowType.Decimal(38, 0, 128); + + case ARRAY: + return new ArrowType.List(); + + case MULTISET: + case STRUCT: + return new ArrowType.Struct(); + + case OBJECT: + // TODO: figure out how to handle Object. I imagine java.time objects end up here + default: + throw new UnsupportedOperationException("Unsupported TypedValue type: " + typedValue.type); + } + } + + public Schema makeSchema(List typedValues) { + final List parameterFields = new ArrayList<>(typedValues.size()); + for (int i = 0; i < typedValues.size(); i++) { + ArrowType arrowType = getArrowTypeFromTypedValue(typedValues.get(i)); + FieldType fieldType = new FieldType(false, arrowType, null); + parameterFields.add(new Field(null, fieldType, null)); + } + return new Schema(parameterFields); + } + /** * Bind the given Avatica values to the prepared statement at the given index. * * @param typedValues The parameter values. - * @param index index for parameter. + * @param index index for parameter. */ public void bind(List typedValues, int index) { - if (preparedStatement.getParameterSchema().getFields().size() != typedValues.size()) { - throw new IllegalStateException( - String.format( - "Prepared statement has %s parameters, but only received %s", - preparedStatement.getParameterSchema().getFields().size(), typedValues.size())); - } for (int i = 0; i < typedValues.size(); i++) { bind(parameters.getVector(i), typedValues.get(i), index); @@ -104,9 +189,9 @@ public void bind(List typedValues, int index) { /** * Bind a TypedValue to the given index on the FieldVector. * - * @param vector FieldVector to bind to. + * @param vector FieldVector to bind to. * @param typedValue TypedValue to bind to the vector. - * @param index Vector index to bind the value at. + * @param index Vector index to bind the value at. */ private void bind(FieldVector vector, TypedValue typedValue, int index) { try { @@ -144,8 +229,8 @@ public static class BinderVisitor implements ArrowType.ArrowTypeVisitor * Instantiate a new BinderVisitor. * * @param vector FieldVector to bind values to. - * @param value TypedValue to bind. - * @param index Vector index (0-based) to bind the value to. + * @param value TypedValue to bind. + * @param index Vector index (0-based) to bind the value to. */ public BinderVisitor(FieldVector vector, TypedValue value, int index) { this.vector = vector;