From a5ca89c05a352a9441cf842118b15b4afbc3b345 Mon Sep 17 00:00:00 2001 From: Junwang Zhao Date: Sat, 28 Jun 2025 21:08:37 +0800 Subject: [PATCH] feat: implement transform ResultType --- src/iceberg/transform_function.cc | 99 ++++++++++++++++++++++++++++--- test/transform_test.cc | 77 ++++++++++++++++++++++++ 2 files changed, 169 insertions(+), 7 deletions(-) diff --git a/src/iceberg/transform_function.cc b/src/iceberg/transform_function.cc index 6aa49bff..32b1a1f8 100644 --- a/src/iceberg/transform_function.cc +++ b/src/iceberg/transform_function.cc @@ -48,7 +48,27 @@ Result BucketTransform::Transform(const ArrowArray& input) { } Result> BucketTransform::ResultType() const { - return NotImplemented("BucketTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for bucket transform"); + } + switch (src_type->type_id()) { + case TypeId::kInt: + case TypeId::kLong: + case TypeId::kDecimal: + case TypeId::kDate: + case TypeId::kTime: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + case TypeId::kString: + case TypeId::kUuid: + case TypeId::kFixed: + case TypeId::kBinary: + return std::make_shared(); + default: + return NotSupported("{} is not a valid input type for bucket transform", + src_type->ToString()); + } } TruncateTransform::TruncateTransform(std::shared_ptr const& source_type, @@ -60,7 +80,21 @@ Result TruncateTransform::Transform(const ArrowArray& input) { } Result> TruncateTransform::ResultType() const { - return NotImplemented("TruncateTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for truncate transform"); + } + switch (src_type->type_id()) { + case TypeId::kInt: + case TypeId::kLong: + case TypeId::kDecimal: + case TypeId::kString: + case TypeId::kBinary: + return src_type; + default: + return NotSupported("{} is not a valid input type for truncate transform", + src_type->ToString()); + } } YearTransform::YearTransform(std::shared_ptr const& source_type) @@ -71,7 +105,19 @@ Result YearTransform::Transform(const ArrowArray& input) { } Result> YearTransform::ResultType() const { - return NotImplemented("YearTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for year transform"); + } + switch (src_type->type_id()) { + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + return std::make_shared(); + default: + return NotSupported("{} is not a valid input type for year transform", + src_type->ToString()); + } } MonthTransform::MonthTransform(std::shared_ptr const& source_type) @@ -82,7 +128,19 @@ Result MonthTransform::Transform(const ArrowArray& input) { } Result> MonthTransform::ResultType() const { - return NotImplemented("MonthTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for month transform"); + } + switch (src_type->type_id()) { + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + return std::make_shared(); + default: + return NotSupported("{} is not a valid input type for month transform", + src_type->ToString()); + } } DayTransform::DayTransform(std::shared_ptr const& source_type) @@ -93,7 +151,19 @@ Result DayTransform::Transform(const ArrowArray& input) { } Result> DayTransform::ResultType() const { - return NotImplemented("DayTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for day transform"); + } + switch (src_type->type_id()) { + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + return std::make_shared(); + default: + return NotSupported("{} is not a valid input type for day transform", + src_type->ToString()); + } } HourTransform::HourTransform(std::shared_ptr const& source_type) @@ -104,7 +174,18 @@ Result HourTransform::Transform(const ArrowArray& input) { } Result> HourTransform::ResultType() const { - return NotImplemented("HourTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for hour transform"); + } + switch (src_type->type_id()) { + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + return std::make_shared(); + default: + return NotSupported("{} is not a valid input type for hour transform", + src_type->ToString()); + } } VoidTransform::VoidTransform(std::shared_ptr const& source_type) @@ -115,7 +196,11 @@ Result VoidTransform::Transform(const ArrowArray& input) { } Result> VoidTransform::ResultType() const { - return NotImplemented("VoidTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for void transform"); + } + return src_type; } } // namespace iceberg diff --git a/test/transform_test.cc b/test/transform_test.cc index f4c68a69..4c7e4da8 100644 --- a/test/transform_test.cc +++ b/test/transform_test.cc @@ -117,4 +117,81 @@ TEST(TransformFromStringTest, NegativeCases) { } } +TEST(TransformResultTypeTest, PositiveCases) { + struct Case { + std::string str; + std::shared_ptr source_type; + std::shared_ptr expected_result_type; + }; + + const std::vector cases = { + {.str = "identity", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "year", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "month", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "day", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "hour", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "void", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "bucket[16]", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "truncate[32]", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + }; + + for (const auto& c : cases) { + auto result = TransformFromString(c.str); + ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str; + + const auto& transform = result.value(); + const auto transformPtr = transform->Bind(c.source_type); + ASSERT_TRUE(transformPtr.has_value()) << "Failed to bind: " << c.str; + + auto result_type = transformPtr.value()->ResultType(); + ASSERT_TRUE(result_type.has_value()) << "Failed to get result type for: " << c.str; + EXPECT_EQ(result_type.value()->type_id(), c.expected_result_type->type_id()) + << "Unexpected result type for: " << c.str; + } +} + +TEST(TransformResultTypeTest, NegativeCases) { + struct Case { + std::string str; + std::shared_ptr source_type; + }; + + const std::vector cases = { + {.str = "identity", .source_type = nullptr}, + {.str = "year", .source_type = std::make_shared()}, + {.str = "month", .source_type = std::make_shared()}, + {.str = "day", .source_type = std::make_shared()}, + {.str = "hour", .source_type = std::make_shared()}, + {.str = "void", .source_type = nullptr}, + {.str = "bucket[16]", .source_type = std::make_shared()}, + {.str = "truncate[32]", .source_type = std::make_shared()}}; + + for (const auto& c : cases) { + auto result = TransformFromString(c.str); + ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str; + + const auto& transform = result.value(); + auto transformPtr = transform->Bind(c.source_type); + + auto result_type = transformPtr.value()->ResultType(); + ASSERT_THAT(result_type, IsError(ErrorKind::kNotSupported)); + } +} + } // namespace iceberg