Skip to content

Commit 68a89b3

Browse files
committed
Datum based arithmetic
1 parent 07a721f commit 68a89b3

File tree

25 files changed

+270
-3761
lines changed

25 files changed

+270
-3761
lines changed

Cargo.toml

+10-2
Original file line number | Diff line number | Diff line change
@@ -47,10 +47,10 @@ rust-version = "1.64"
4747

4848
[workspace.dependencies]
4949
arrow = { version = "43.0.0", features = ["prettyprint", "dyn_cmp_dict"] }
50-
arrow-flight = { version = "43.0.0", features = ["flight-sql-experimental"] }
50+
arrow-array = { version = "43.0.0", default-features = false, features = ["chrono-tz"] }
5151
arrow-buffer = { version = "43.0.0", default-features = false }
52+
arrow-flight = { version = "43.0.0", features = ["flight-sql-experimental"] }
5253
arrow-schema = { version = "43.0.0", default-features = false }
53-
arrow-array = { version = "43.0.0", default-features = false, features = ["chrono-tz"] }
5454
parquet = { version = "43.0.0", features = ["arrow", "async", "object_store"] }
5555
sqlparser = { version = "0.35", features = ["visitor"] }
5656

@@ -71,3 +71,11 @@ opt-level = 3
7171
overflow-checks = false
7272
panic = 'unwind'
7373
rpath = false
74+
75+
[patch.crates-io]
76+
arrow = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
77+
arrow-array = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
78+
arrow-buffer = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
79+
arrow-flight = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
80+
arrow-schema = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
81+
parquet = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }

datafusion-cli/Cargo.lock

+15-30
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-cli/Cargo.toml

+7
Original file line number | Diff line number | Diff line change
@@ -49,3 +49,10 @@ assert_cmd = "2.0"
4949
ctor = "0.2.0"
5050
predicates = "3.0"
5151
rstest = "0.17"
52+
53+
[patch.crates-io]
54+
arrow = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
55+
arrow-array = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
56+
arrow-buffer = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
57+
arrow-schema = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }
58+
parquet = { git = "https://github.com/tustvold/arrow-rs.git", rev = "9c461f7027871b3d1a1b30de7fd26b3ac01cb096" }

datafusion/common/src/scalar.rs

+59-92
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ use arrow::{
4646
DECIMAL128_MAX_PRECISION,
4747
},
4848
};
49-
use arrow_array::timezone::Tz;
49+
use arrow_array::{timezone::Tz, ArrowNativeTypeOp};
5050
use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime};
5151

5252
// Constants we use throughout this file:
@@ -743,55 +743,21 @@ macro_rules! impl_op {
743743
($LHS:expr, $RHS:expr, -) => {
744744
match ($LHS, $RHS) {
745745
(
746-
ScalarValue::TimestampSecond(Some(ts_lhs), tz_lhs),
747-
ScalarValue::TimestampSecond(Some(ts_rhs), tz_rhs),
748-
) => {
749-
let err = || {
750-
DataFusionError::Execution(
751-
"Overflow while converting seconds to milliseconds".to_string(),
752-
)
753-
};
754-
ts_sub_to_interval::<MILLISECOND_MODE>(
755-
ts_lhs.checked_mul(1_000).ok_or_else(err)?,
756-
ts_rhs.checked_mul(1_000).ok_or_else(err)?,
757-
tz_lhs.as_deref(),
758-
tz_rhs.as_deref(),
759-
)
760-
},
746+
ScalarValue::TimestampSecond(Some(ts_lhs), _),
747+
ScalarValue::TimestampSecond(Some(ts_rhs), _),
748+
) => Ok(ScalarValue::DurationSecond(Some(ts_lhs.sub_checked(*ts_rhs)?))),
761749
(
762-
ScalarValue::TimestampMillisecond(Some(ts_lhs), tz_lhs),
763-
ScalarValue::TimestampMillisecond(Some(ts_rhs), tz_rhs),
764-
) => ts_sub_to_interval::<MILLISECOND_MODE>(
765-
*ts_lhs,
766-
*ts_rhs,
767-
tz_lhs.as_deref(),
768-
tz_rhs.as_deref(),
769-
),
750+
ScalarValue::TimestampMillisecond(Some(ts_lhs), _),
751+
ScalarValue::TimestampMillisecond(Some(ts_rhs), _),
752+
) => Ok(ScalarValue::DurationMillisecond(Some(ts_lhs.sub_checked(*ts_rhs)?))),
770753
(
771-
ScalarValue::TimestampMicrosecond(Some(ts_lhs), tz_lhs),
772-
ScalarValue::TimestampMicrosecond(Some(ts_rhs), tz_rhs),
773-
) => {
774-
let err = || {
775-
DataFusionError::Execution(
776-
"Overflow while converting microseconds to nanoseconds".to_string(),
777-
)
778-
};
779-
ts_sub_to_interval::<NANOSECOND_MODE>(
780-
ts_lhs.checked_mul(1_000).ok_or_else(err)?,
781-
ts_rhs.checked_mul(1_000).ok_or_else(err)?,
782-
tz_lhs.as_deref(),
783-
tz_rhs.as_deref(),
784-
)
785-
},
754+
ScalarValue::TimestampMicrosecond(Some(ts_lhs), _),
755+
ScalarValue::TimestampMicrosecond(Some(ts_rhs), _),
756+
) => Ok(ScalarValue::DurationMicrosecond(Some(ts_lhs.sub_checked(*ts_rhs)?))),
786757
(
787-
ScalarValue::TimestampNanosecond(Some(ts_lhs), tz_lhs),
788-
ScalarValue::TimestampNanosecond(Some(ts_rhs), tz_rhs),
789-
) => ts_sub_to_interval::<NANOSECOND_MODE>(
790-
*ts_lhs,
791-
*ts_rhs,
792-
tz_lhs.as_deref(),
793-
tz_rhs.as_deref(),
794-
),
758+
ScalarValue::TimestampNanosecond(Some(ts_lhs), _),
759+
ScalarValue::TimestampNanosecond(Some(ts_rhs), _),
760+
) => Ok(ScalarValue::DurationNanosecond(Some(ts_lhs.sub_checked(*ts_rhs)?))),
795761
_ => impl_op_arithmetic!($LHS, $RHS, -)
796762
}
797763
};
@@ -1147,49 +1113,6 @@ pub const MDN_MODE: i8 = 2;
11471113

11481114
pub const MILLISECOND_MODE: bool = false;
11491115
pub const NANOSECOND_MODE: bool = true;
1150-
/// This function computes subtracts `rhs_ts` from `lhs_ts`, taking timezones
1151-
/// into account when given. Units of the resulting interval is specified by
1152-
/// the constant `TIME_MODE`.
1153-
/// The default behavior of Datafusion is the following:
1154-
/// - When subtracting timestamps at seconds/milliseconds precision, the output
1155-
/// interval will have the type [`IntervalDayTimeType`].
1156-
/// - When subtracting timestamps at microseconds/nanoseconds precision, the
1157-
/// output interval will have the type [`IntervalMonthDayNanoType`].
1158-
fn ts_sub_to_interval<const TIME_MODE: bool>(
1159-
lhs_ts: i64,
1160-
rhs_ts: i64,
1161-
lhs_tz: Option<&str>,
1162-
rhs_tz: Option<&str>,
1163-
) -> Result<ScalarValue> {
1164-
let parsed_lhs_tz = parse_timezones(lhs_tz)?;
1165-
let parsed_rhs_tz = parse_timezones(rhs_tz)?;
1166-
1167-
let (naive_lhs, naive_rhs) =
1168-
calculate_naives::<TIME_MODE>(lhs_ts, parsed_lhs_tz, rhs_ts, parsed_rhs_tz)?;
1169-
let delta_secs = naive_lhs.signed_duration_since(naive_rhs);
1170-
1171-
match TIME_MODE {
1172-
MILLISECOND_MODE => {
1173-
let as_millisecs = delta_secs.num_milliseconds();
1174-
Ok(ScalarValue::new_interval_dt(
1175-
(as_millisecs / MILLISECS_IN_ONE_DAY) as i32,
1176-
(as_millisecs % MILLISECS_IN_ONE_DAY) as i32,
1177-
))
1178-
}
1179-
NANOSECOND_MODE => {
1180-
let as_nanosecs = delta_secs.num_nanoseconds().ok_or_else(|| {
1181-
DataFusionError::Execution(String::from(
1182-
"Can not compute timestamp differences with nanosecond precision",
1183-
))
1184-
})?;
1185-
Ok(ScalarValue::new_interval_mdn(
1186-
0,
1187-
(as_nanosecs / NANOSECS_IN_ONE_DAY) as i32,
1188-
as_nanosecs % NANOSECS_IN_ONE_DAY,
1189-
))
1190-
}
1191-
}
1192-
}
11931116

11941117
/// This function parses the timezone from string to Tz.
11951118
/// If it cannot parse or timezone field is [`None`], it returns [`None`].
@@ -1424,6 +1347,14 @@ where
14241347
ScalarValue::IntervalDayTime(Some(i)) => add_day_time(prior, *i, sign),
14251348
ScalarValue::IntervalYearMonth(Some(i)) => shift_months(prior, *i, sign),
14261349
ScalarValue::IntervalMonthDayNano(Some(i)) => add_m_d_nano(prior, *i, sign),
1350+
ScalarValue::DurationSecond(Some(v)) => prior.add(Duration::seconds(*v)),
1351+
ScalarValue::DurationMillisecond(Some(v)) => {
1352+
prior.add(Duration::milliseconds(*v))
1353+
}
1354+
ScalarValue::DurationMicrosecond(Some(v)) => {
1355+
prior.add(Duration::microseconds(*v))
1356+
}
1357+
ScalarValue::DurationNanosecond(Some(v)) => prior.add(Duration::nanoseconds(*v)),
14271358
other => Err(DataFusionError::Execution(format!(
14281359
"DateIntervalExpr does not support non-interval type {other:?}"
14291360
)))?,
@@ -1891,6 +1822,16 @@ impl ScalarValue {
18911822
DataType::Interval(IntervalUnit::MonthDayNano) => {
18921823
ScalarValue::IntervalMonthDayNano(Some(0))
18931824
}
1825+
DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None),
1826+
DataType::Duration(TimeUnit::Millisecond) => {
1827+
ScalarValue::DurationMillisecond(None)
1828+
}
1829+
DataType::Duration(TimeUnit::Microsecond) => {
1830+
ScalarValue::DurationMicrosecond(None)
1831+
}
1832+
DataType::Duration(TimeUnit::Nanosecond) => {
1833+
ScalarValue::DurationNanosecond(None)
1834+
}
18941835
_ => {
18951836
return Err(DataFusionError::NotImplemented(format!(
18961837
"Can't create a zero scalar from data_type \"{datatype:?}\""
@@ -3191,6 +3132,20 @@ impl ScalarValue {
31913132
IntervalMonthDayNano
31923133
)
31933134
}
3135+
3136+
DataType::Duration(TimeUnit::Second) => {
3137+
typed_cast!(array, index, DurationSecondArray, DurationSecond)
3138+
}
3139+
DataType::Duration(TimeUnit::Millisecond) => {
3140+
typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond)
3141+
}
3142+
DataType::Duration(TimeUnit::Microsecond) => {
3143+
typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond)
3144+
}
3145+
DataType::Duration(TimeUnit::Nanosecond) => {
3146+
typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond)
3147+
}
3148+
31943149
other => {
31953150
return Err(DataFusionError::NotImplemented(format!(
31963151
"Can't create a scalar from array of type \"{other:?}\""
@@ -3682,6 +3637,18 @@ impl TryFrom<&DataType> for ScalarValue {
36823637
DataType::Interval(IntervalUnit::MonthDayNano) => {
36833638
ScalarValue::IntervalMonthDayNano(None)
36843639
}
3640+
3641+
DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None),
3642+
DataType::Duration(TimeUnit::Millisecond) => {
3643+
ScalarValue::DurationMillisecond(None)
3644+
}
3645+
DataType::Duration(TimeUnit::Microsecond) => {
3646+
ScalarValue::DurationMicrosecond(None)
3647+
}
3648+
DataType::Duration(TimeUnit::Nanosecond) => {
3649+
ScalarValue::DurationNanosecond(None)
3650+
}
3651+
36853652
DataType::Dictionary(index_type, value_type) => ScalarValue::Dictionary(
36863653
index_type.clone(),
36873654
Box::new(value_type.as_ref().try_into()?),
@@ -3944,7 +3911,7 @@ mod tests {
39443911
use std::sync::Arc;
39453912

39463913
use arrow::compute::kernels;
3947-
use arrow::compute::{self, concat, is_null};
3914+
use arrow::compute::{concat, is_null};
39483915
use arrow::datatypes::ArrowPrimitiveType;
39493916
use arrow::util::pretty::pretty_format_columns;
39503917
use arrow_array::ArrowNumericType;
@@ -4073,7 +4040,7 @@ mod tests {
40734040
let right_array = right.to_array();
40744041
let arrow_left_array = left_array.as_primitive::<T>();
40754042
let arrow_right_array = right_array.as_primitive::<T>();
4076-
let arrow_result = compute::add_checked(arrow_left_array, arrow_right_array);
4043+
let arrow_result = kernels::numeric::add(arrow_left_array, arrow_right_array);
40774044

40784045
assert_eq!(scalar_result.is_ok(), arrow_result.is_ok());
40794046
}

0 commit comments

Comments
 (0)