diff --git a/arrow-pg/src/datatypes/df.rs b/arrow-pg/src/datatypes/df.rs index 2959128..256b1c4 100644 --- a/arrow-pg/src/datatypes/df.rs +++ b/arrow-pg/src/datatypes/df.rs @@ -2,7 +2,7 @@ use std::iter; use std::sync::Arc; use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; -use datafusion::arrow::datatypes::{DataType, Date32Type}; +use datafusion::arrow::datatypes::{DataType, Date32Type, TimeUnit}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::ParamValues; use datafusion::prelude::*; @@ -79,10 +79,8 @@ where let param_len = portal.parameter_len(); let mut deserialized_params = Vec::with_capacity(param_len); for i in 0..param_len { - let pg_type = get_pg_type( - portal.statement.parameter_types.get(i), - inferenced_types.get(i).and_then(|v| v.to_owned()), - )?; + let inferenced_type = inferenced_types.get(i).and_then(|v| v.to_owned()); + let pg_type = get_pg_type(portal.statement.parameter_types.get(i), inferenced_type)?; match pg_type { // enumerate all supported parameter types and deserialize the // type to ScalarValue @@ -158,9 +156,36 @@ where } Type::TIME => { let value = portal.parameter::(i, &pg_type)?; - deserialized_params.push(ScalarValue::Time64Microsecond(value.map(|t| { - t.num_seconds_from_midnight() as i64 * 1_000_000 + t.nanosecond() as i64 / 1_000 - }))); + + let ns = value.map(|t| { + t.num_seconds_from_midnight() as i64 * 1_000_000_000 + t.nanosecond() as i64 + }); + + let scalar_value = match inferenced_type { + Some(DataType::Time64(TimeUnit::Nanosecond)) => { + ScalarValue::Time64Nanosecond(ns) + } + Some(DataType::Time64(TimeUnit::Microsecond)) => { + ScalarValue::Time64Microsecond(ns.map(|ns| (ns / 1_000) as _)) + } + Some(DataType::Time32(TimeUnit::Millisecond)) => { + ScalarValue::Time32Millisecond(ns.map(|ns| (ns / 1_000_000) as _)) + } + Some(DataType::Time32(TimeUnit::Second)) => { + ScalarValue::Time32Second(ns.map(|ns| (ns / 1_000_000_000) as _)) + } + _ => { + return Err(PgWireError::ApiError( + format!( + "Unable to deserialise time parameter type {:?} to type {:?}", + value, inferenced_type + ) + .into(), + )) + } + }; + + deserialized_params.push(scalar_value); } Type::UUID => { let value = portal.parameter::(i, &pg_type)?; @@ -294,3 +319,64 @@ where Ok(ParamValues::List(deserialized_params)) } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::datatypes::DataType; + use bytes::Bytes; + use datafusion::{common::ParamValues, scalar::ScalarValue}; + use pgwire::{ + api::{portal::Portal, stmt::StoredStatement}, + messages::extendedquery::Bind, + }; + use postgres_types::Type; + + use crate::datatypes::df::deserialize_parameters; + + #[test] + fn test_deserialise_time_params() { + let postgres_types = vec![Type::TIME]; + + let us: i64 = 1_000_000; // 1 second + + let bind = Bind::new( + None, + None, + vec![], + vec![Some(Bytes::from(i64::to_be_bytes(us).to_vec()))], + vec![], + ); + + let stmt = StoredStatement::new("statement_id".into(), "statement", postgres_types); + let portal = Portal::try_new(&bind, Arc::new(stmt)).unwrap(); + + for (arrow_type, expected) in [ + ( + DataType::Time32(arrow::datatypes::TimeUnit::Second), + ScalarValue::Time32Second(Some(1)), + ), + ( + DataType::Time32(arrow::datatypes::TimeUnit::Millisecond), + ScalarValue::Time32Millisecond(Some(1000)), + ), + ( + DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), + ScalarValue::Time64Microsecond(Some(1000000)), + ), + ( + DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond), + ScalarValue::Time64Nanosecond(Some(1000000000)), + ), + ] { + let result = deserialize_parameters(&portal, &[Some(&arrow_type)]).unwrap(); + let ParamValues::List(list) = result else { + panic!("expected list"); + }; + + assert_eq!(list.len(), 1); + assert_eq!(list[0], expected) + } + } +} diff --git a/arrow-pg/src/encoder.rs b/arrow-pg/src/encoder.rs index e35fadd..61bdf96 100644 --- a/arrow-pg/src/encoder.rs +++ b/arrow-pg/src/encoder.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use arrow::{array::*, datatypes::*}; use bytes::BufMut; use bytes::BytesMut; +use chrono::NaiveTime; use chrono::{NaiveDate, NaiveDateTime}; #[cfg(feature = "datafusion")] use datafusion::arrow::{array::*, datatypes::*}; @@ -203,43 +204,43 @@ fn get_date64_value(arr: &Arc, idx: usize) -> Option { .value_as_date(idx) } -fn get_time32_second_value(arr: &Arc, idx: usize) -> Option { +fn get_time32_second_value(arr: &Arc, idx: usize) -> Option { if arr.is_null(idx) { return None; } arr.as_any() .downcast_ref::() .unwrap() - .value_as_datetime(idx) + .value_as_time(idx) } -fn get_time32_millisecond_value(arr: &Arc, idx: usize) -> Option { +fn get_time32_millisecond_value(arr: &Arc, idx: usize) -> Option { if arr.is_null(idx) { return None; } arr.as_any() .downcast_ref::() .unwrap() - .value_as_datetime(idx) + .value_as_time(idx) } -fn get_time64_microsecond_value(arr: &Arc, idx: usize) -> Option { +fn get_time64_microsecond_value(arr: &Arc, idx: usize) -> Option { if arr.is_null(idx) { return None; } arr.as_any() .downcast_ref::() .unwrap() - .value_as_datetime(idx) + .value_as_time(idx) } -fn get_time64_nanosecond_value(arr: &Arc, idx: usize) -> Option { +fn get_time64_nanosecond_value(arr: &Arc, idx: usize) -> Option { if arr.is_null(idx) { return None; } arr.as_any() .downcast_ref::() .unwrap() - .value_as_datetime(idx) + .value_as_time(idx) } fn get_numeric_128_value( @@ -518,4 +519,45 @@ mod tests { assert!(encoder.encoded_value == val); } + + #[test] + fn test_get_time32_second_value() { + let array = Time32SecondArray::from_iter_values([3723_i32]); + let array: Arc = Arc::new(array); + let value = get_time32_second_value(&array, 0); + assert_eq!(value, Some(NaiveTime::from_hms_opt(1, 2, 3)).unwrap()); + } + + #[test] + fn test_get_time32_millisecond_value() { + let array = Time32MillisecondArray::from_iter_values([3723001_i32]); + let array: Arc = Arc::new(array); + let value = get_time32_millisecond_value(&array, 0); + assert_eq!( + value, + Some(NaiveTime::from_hms_milli_opt(1, 2, 3, 1)).unwrap() + ); + } + + #[test] + fn test_get_time64_microsecond_value() { + let array = Time64MicrosecondArray::from_iter_values([3723001001_i64]); + let array: Arc = Arc::new(array); + let value = get_time64_microsecond_value(&array, 0); + assert_eq!( + value, + Some(NaiveTime::from_hms_micro_opt(1, 2, 3, 1001)).unwrap() + ); + } + + #[test] + fn test_get_time64_nanosecond_value() { + let array = Time64NanosecondArray::from_iter_values([3723001001001_i64]); + let array: Arc = Arc::new(array); + let value = get_time64_nanosecond_value(&array, 0); + assert_eq!( + value, + Some(NaiveTime::from_hms_nano_opt(1, 2, 3, 1001001)).unwrap() + ); + } }