Skip to content

Commit

Permalink
branch-3.0: [Fix](bug) Percentile* func core when percent args is neg…
Browse files Browse the repository at this point in the history
…ative number #47068 (#47219)

Cherry-picked from #47068

Co-authored-by: HappenLee <[email protected]>
  • Loading branch information
github-actions[bot] and HappenLee authored Feb 7, 2025
1 parent 814e4d7 commit 2c155a4
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 69 deletions.
56 changes: 36 additions & 20 deletions be/src/vec/aggregate_functions/aggregate_function_percentile.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,20 @@ namespace doris::vectorized {
class Arena;
class BufferReadable;

inline void check_quantile(double quantile) {
if (quantile < 0 || quantile > 1) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"quantile in func percentile should in [0, 1], but real data is:" +
std::to_string(quantile));
}
}

struct PercentileApproxState {
static constexpr double INIT_QUANTILE = -1.0;
PercentileApproxState() = default;
~PercentileApproxState() = default;

void init(double compression = 10000) {
void init(double quantile, double compression = 10000) {
if (!init_flag) {
//https://doris.apache.org/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile_approx.html#description
//The compression parameter setting range is [2048, 10000].
Expand All @@ -66,6 +74,8 @@ struct PercentileApproxState {
compression = 10000;
}
digest = TDigest::create_unique(compression);
check_quantile(quantile);
target_quantile = quantile;
compressions = compression;
init_flag = true;
}
Expand Down Expand Up @@ -126,18 +136,14 @@ struct PercentileApproxState {
}
}

void add(double source, double quantile) {
digest->add(source);
target_quantile = quantile;
}
void add(double source) { digest->add(source); }

void add_with_weight(double source, double weight, double quantile) {
void add_with_weight(double source, double weight) {
// the weight should be positive num, as have check the value valid use DCHECK_GT(c._weight, 0);
if (weight <= 0) {
return;
}
digest->add(source, weight);
target_quantile = quantile;
}

void reset() {
Expand Down Expand Up @@ -192,8 +198,8 @@ class AggregateFunctionPercentileApproxTwoParams : public AggregateFunctionPerce
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0]);
const auto& quantile =
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
this->data(place).init();
this->data(place).add(sources.get_element(row_num), quantile.get_element(row_num));
this->data(place).init(quantile.get_element(0));
this->data(place).add(sources.get_element(row_num));
}

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }
Expand Down Expand Up @@ -223,8 +229,8 @@ class AggregateFunctionPercentileApproxThreeParams : public AggregateFunctionPer
const auto& compression =
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2]);

this->data(place).init(compression.get_element(row_num));
this->data(place).add(sources.get_element(row_num), quantile.get_element(row_num));
this->data(place).init(quantile.get_element(0), compression.get_element(0));
this->data(place).add(sources.get_element(row_num));
}

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }
Expand Down Expand Up @@ -256,9 +262,9 @@ class AggregateFunctionPercentileApproxWeightedThreeParams
const auto& quantile =
assert_cast<const ColumnVector<Float64>&, TypeCheckOnRelease::DISABLE>(*columns[2]);

this->data(place).init();
this->data(place).add_with_weight(sources.get_element(row_num), weight.get_element(row_num),
quantile.get_element(row_num));
this->data(place).init(quantile.get_element(0));
this->data(place).add_with_weight(sources.get_element(row_num),
weight.get_element(row_num));
}

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }
Expand Down Expand Up @@ -291,9 +297,9 @@ class AggregateFunctionPercentileApproxWeightedFourParams
const auto& compression =
assert_cast<const ColumnVector<Float64>&, TypeCheckOnRelease::DISABLE>(*columns[3]);

this->data(place).init(compression.get_element(row_num));
this->data(place).add_with_weight(sources.get_element(row_num), weight.get_element(row_num),
quantile.get_element(row_num));
this->data(place).init(quantile.get_element(0), compression.get_element(0));
this->data(place).add_with_weight(sources.get_element(row_num),
weight.get_element(row_num));
}

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }
Expand Down Expand Up @@ -351,12 +357,19 @@ struct PercentileState {
}
}

void add(T source, const PaddedPODArray<Float64>& quantiles, int arg_size) {
void add(T source, const PaddedPODArray<Float64>& quantiles, const NullMap& null_maps,
int arg_size) {
if (!inited_flag) {
vec_counts.resize(arg_size);
vec_quantile.resize(arg_size, -1);
inited_flag = true;
for (int i = 0; i < arg_size; ++i) {
// throw Exception func call percentile_array(id, [1,0,null])
if (null_maps[i]) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"quantiles in func percentile_array should not have null");
}
check_quantile(quantiles[i]);
vec_quantile[i] = quantiles[i];
}
}
Expand Down Expand Up @@ -429,7 +442,7 @@ class AggregateFunctionPercentile final
const auto& quantile =
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
AggregateFunctionPercentile::data(place).add(sources.get_data()[row_num],
quantile.get_data(), 1);
quantile.get_data(), NullMap(1, 0), 1);
}

void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
Expand Down Expand Up @@ -490,14 +503,17 @@ class AggregateFunctionPercentileArray final
const auto& quantile_array =
assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[1]);
const auto& offset_column_data = quantile_array.get_offsets();
const auto& null_maps = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
quantile_array.get_data())
.get_null_map_data();
const auto& nested_column = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
quantile_array.get_data())
.get_nested_column();
const auto& nested_column_data =
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(nested_column);

AggregateFunctionPercentileArray::data(place).add(
sources.get_int(row_num), nested_column_data.get_data(),
sources.get_int(row_num), nested_column_data.get_data(), null_maps,
offset_column_data.data()[row_num] - offset_column_data[(ssize_t)row_num - 1]);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,6 @@ class AggregateFunctionSimpleFactory {
std::unordered_map<std::string, std::string> function_alias;

public:
void register_nullable_function_combinator(const Creator& creator) {
for (const auto& entity : aggregate_functions) {
if (nullable_aggregate_functions.find(entity.first) ==
nullable_aggregate_functions.end()) {
nullable_aggregate_functions[entity.first] = creator;
}
}
}

static bool is_foreach(const std::string& name) {
constexpr std::string_view suffix = "_foreach";
if (name.length() < suffix.length()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
Expand Down Expand Up @@ -69,6 +70,14 @@ public PercentileArray(boolean distinct, Expression arg0, Expression arg1) {
super("percentile_array", distinct, arg0, arg1);
}

@Override
public void checkLegalityBeforeTypeCoercion() {
if (!getArgument(1).isConstant()) {
throw new AnalysisException(
"percentile_array requires second parameter must be a constant : " + this.toSql());
}
}

/**
* withDistinctAndChildren.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,27 @@

import com.google.common.collect.ImmutableList;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

/**
* combinator foreach
*/
public class ForEachCombinator extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, Combinator {

public static final Set<String> UNSUPPORTED_AGGREGATE_FUNCTION = Collections.unmodifiableSet(new HashSet<String>() {
{
add("percentile");
add("percentile_array");
add("percentile_approx");
add("percentile_approx_weighted");
}
});

private final AggregateFunction nested;

/**
Expand All @@ -48,10 +60,27 @@ public ForEachCombinator(List<Expression> arguments, AggregateFunction nested) {
this(arguments, false, nested);
}

/**
* Constructs a new instance of {@code ForEachCombinator}.
*
* <p>This constructor initializes a combinator that will iterate over each item in the input list
* and apply the nested aggregate function.
* If the provided aggregate function name is within the list of unsupported functions,
* an {@link UnsupportedOperationException} will be thrown.
*
* @param arguments A list of {@code Expression} objects that serve as parameters to the aggregate function.
* @param alwaysNullable A boolean flag indicating whether this combinator should always return a nullable result.
* @param nested The nested aggregate function to apply to each element. It must not be {@code null}.
* @throws NullPointerException If the provided nested aggregate function is {@code null}.
* @throws UnsupportedOperationException If nested aggregate function is one of the unsupported aggregate functions
*/
public ForEachCombinator(List<Expression> arguments, boolean alwaysNullable, AggregateFunction nested) {
super(nested.getName() + AggCombinerFunctionBuilder.FOREACH_SUFFIX, false, alwaysNullable, arguments);

this.nested = Objects.requireNonNull(nested, "nested can not be null");
if (UNSUPPORTED_AGGREGATE_FUNCTION.contains(nested.getName().toLowerCase())) {
throw new UnsupportedOperationException("Unsupport the func:" + nested.getName() + " use in foreach");
}
}

public static ForEachCombinator create(AggregateFunction nested) {
Expand Down
9 changes: 0 additions & 9 deletions regression-test/data/function_p0/test_agg_foreach.out
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,6 @@
-- !sql --
["{"num_buckets":3,"buckets":[{"lower":"1","upper":"1","ndv":1,"count":1,"pre_sum":0},{"lower":"20","upper":"20","ndv":1,"count":1,"pre_sum":1},{"lower":"100","upper":"100","ndv":1,"count":1,"pre_sum":2}]}", "{"num_buckets":1,"buckets":[{"lower":"2","upper":"2","ndv":1,"count":2,"pre_sum":0}]}", "{"num_buckets":1,"buckets":[{"lower":"3","upper":"3","ndv":1,"count":1,"pre_sum":0}]}"]

-- !sql --
[100, 2, 3]

-- !sql --
[[1], [2, 2, 2], [3]]

-- !sql --
[0, 0, 0]

-- !sql --
[0, 2, 3] [117, 2, 3] [113, 0, 3]

Expand Down
9 changes: 0 additions & 9 deletions regression-test/data/function_p0/test_agg_foreach_notnull.out
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,6 @@
-- !sql --
["{"num_buckets":3,"buckets":[{"lower":"1","upper":"1","ndv":1,"count":1,"pre_sum":0},{"lower":"20","upper":"20","ndv":1,"count":1,"pre_sum":1},{"lower":"100","upper":"100","ndv":1,"count":1,"pre_sum":2}]}", "{"num_buckets":1,"buckets":[{"lower":"2","upper":"2","ndv":1,"count":2,"pre_sum":0}]}", "{"num_buckets":1,"buckets":[{"lower":"3","upper":"3","ndv":1,"count":1,"pre_sum":0}]}"]

-- !sql --
[100, 2, 3]

-- !sql --
[[1], [2, 2, 2], [3]]

-- !sql --
[0, 0, 0]

-- !sql --
[0, 2, 3] [117, 2, 3] [113, 0, 3]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ beijing chengdu shanghai
5 29.1
6 101.1

-- !select22_1_1 --
1 \N
2 \N
3 \N
5 \N
6 \N

-- !select23 --
1 10.0
2 224.5
Expand Down Expand Up @@ -192,6 +199,13 @@ beijing chengdu shanghai
5 29.0
6 101.0

-- !select28_1 --
1 \N
2 \N
3 \N
5 \N
6 \N

-- !select29 --
1 0.0
2 216.5
Expand Down
26 changes: 16 additions & 10 deletions regression-test/suites/function_p0/test_agg_foreach.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,24 @@ suite("test_agg_foreach") {
select histogram_foreach(a) from foreach_table;
"""

qt_sql """
select PERCENTILE_foreach(a,a) from foreach_table;
"""
try {
sql "select PERCENTILE_foreach(a,a) from foreach_table;"
} catch (Exception ex) {
assert("${ex}".contains("Unsupport the func"))
}

qt_sql """
select PERCENTILE_ARRAY_foreach(a,b) from foreach_table where id = 1;
"""

qt_sql """
select PERCENTILE_APPROX_foreach(a,a) from foreach_table;
"""
try {
sql "select PERCENTILE_ARRAY_foreach(a,b) from foreach_table where id = 1;"
} catch (Exception ex) {
assert("${ex}".contains("Unsupport the func"))
}

try {
sql "select PERCENTILE_APPROX_foreach(a,a) from foreach_table;"
} catch (Exception ex) {
assert("${ex}".contains("Unsupport the func"))
}

qt_sql """
select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a), GROUP_BIT_XOR_foreach(a) from foreach_table;
Expand Down
30 changes: 18 additions & 12 deletions regression-test/suites/function_p0/test_agg_foreach_notnull.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,26 @@ suite("test_agg_foreach_not_null") {
qt_sql """
select histogram_foreach(a) from foreach_table_not_null;
"""

qt_sql """
select PERCENTILE_foreach(a,a) from foreach_table_not_null;
"""

qt_sql """
select PERCENTILE_ARRAY_foreach(a,b) from foreach_table_not_null where id = 1;
"""

qt_sql """

select PERCENTILE_APPROX_foreach(a,a) from foreach_table_not_null;
"""
try {
sql "select PERCENTILE_foreach(a,a) from foreach_table_not_null;"
} catch (Exception ex) {
assert("${ex}".contains("Unsupport the func"))
}


try {
sql "select PERCENTILE_ARRAY_foreach(a,b) from foreach_table_not_null where id = 1;"
} catch (Exception ex) {
assert("${ex}".contains("Unsupport the func"))
}

try {
sql "select PERCENTILE_APPROX_foreach(a,a) from foreach_table_not_null;"
} catch (Exception ex) {
assert("${ex}".contains("Unsupport the func"))
}

qt_sql """
select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a), GROUP_BIT_XOR_foreach(a) from foreach_table_not_null;
"""
Expand Down
Loading

0 comments on commit 2c155a4

Please sign in to comment.