@@ -96,8 +96,8 @@ impl ScalarUDFImpl for PowUdf {
96
96
// function, but we check again to make sure
97
97
assert_eq ! ( args. len( ) , 2 ) ;
98
98
let ( base, exp) = ( & args[ 0 ] , & args[ 1 ] ) ;
99
- assert_eq ! ( base. data_type( ) , DataType :: Float64 ) ;
100
- assert_eq ! ( exp. data_type( ) , DataType :: Float64 ) ;
99
+ assert_eq ! ( base. data_type( ) , & DataType :: Float64 ) ;
100
+ assert_eq ! ( exp. data_type( ) , & DataType :: Float64 ) ;
101
101
102
102
match ( base, exp) {
103
103
// For demonstration purposes we also implement the scalar / scalar
@@ -108,28 +108,31 @@ impl ScalarUDFImpl for PowUdf {
108
108
// the DataFusion expression simplification logic will often invoke
109
109
// this path once during planning, and simply use the result during
110
110
// execution.
111
- (
112
- ColumnarValue :: Scalar ( ScalarValue :: Float64 ( base) ) ,
113
- ColumnarValue :: Scalar ( ScalarValue :: Float64 ( exp) ) ,
114
- ) => {
115
- // compute the output. Note DataFusion treats `None` as NULL.
116
- let res = match ( base, exp) {
117
- ( Some ( base) , Some ( exp) ) => Some ( base. powf ( * exp) ) ,
118
- // one or both arguments were NULL
119
- _ => None ,
120
- } ;
121
- Ok ( ColumnarValue :: Scalar ( ScalarValue :: from ( res) ) )
111
+ ( ColumnarValue :: Scalar ( base) , ColumnarValue :: Scalar ( exp) ) => {
112
+ match ( base. value ( ) , exp. value ( ) ) {
113
+ ( ScalarValue :: Float64 ( base) , ScalarValue :: Float64 ( exp) ) => {
114
+ // compute the output. Note DataFusion treats `None` as NULL.
115
+ let res = match ( base, exp) {
116
+ ( Some ( base) , Some ( exp) ) => Some ( base. powf ( * exp) ) ,
117
+ // one or both arguments were NULL
118
+ _ => None ,
119
+ } ;
120
+ Ok ( ColumnarValue :: from ( ScalarValue :: from ( res) ) )
121
+ }
122
+ _ => {
123
+ internal_err ! ( "Invalid argument types to pow function" )
124
+ }
125
+ }
122
126
}
123
127
// special case if the exponent is a constant
124
- (
125
- ColumnarValue :: Array ( base_array) ,
126
- ColumnarValue :: Scalar ( ScalarValue :: Float64 ( exp) ) ,
127
- ) => {
128
- let result_array = match exp {
128
+ ( ColumnarValue :: Array ( base_array) , ColumnarValue :: Scalar ( exp) ) => {
129
+ let result_array = match exp. value ( ) {
129
130
// a ^ null = null
130
- None => new_null_array ( base_array. data_type ( ) , base_array. len ( ) ) ,
131
+ ScalarValue :: Float64 ( None ) => {
132
+ new_null_array ( base_array. data_type ( ) , base_array. len ( ) )
133
+ }
131
134
// a ^ exp
132
- Some ( exp) => {
135
+ ScalarValue :: Float64 ( Some ( exp) ) => {
133
136
// DataFusion has ensured both arguments are Float64:
134
137
let base_array = base_array. as_primitive :: < Float64Type > ( ) ;
135
138
// calculate the result for every row. The `unary`
@@ -139,24 +142,25 @@ impl ScalarUDFImpl for PowUdf {
139
142
compute:: unary ( base_array, |base| base. powf ( * exp) ) ;
140
143
Arc :: new ( res)
141
144
}
145
+ _ => return internal_err ! ( "Invalid argument types to pow function" ) ,
142
146
} ;
143
147
Ok ( ColumnarValue :: Array ( result_array) )
144
148
}
145
149
146
150
// special case if the base is a constant (note this code is quite
147
151
// similar to the previous case, so we omit comments)
148
- (
149
- ColumnarValue :: Scalar ( ScalarValue :: Float64 ( base) ) ,
150
- ColumnarValue :: Array ( exp_array) ,
151
- ) => {
152
- let res = match base {
153
- None => new_null_array ( exp_array. data_type ( ) , exp_array. len ( ) ) ,
154
- Some ( base) => {
152
+ ( ColumnarValue :: Scalar ( base) , ColumnarValue :: Array ( exp_array) ) => {
153
+ let res = match base. value ( ) {
154
+ ScalarValue :: Float64 ( None ) => {
155
+ new_null_array ( exp_array. data_type ( ) , exp_array. len ( ) )
156
+ }
157
+ ScalarValue :: Float64 ( Some ( base) ) => {
155
158
let exp_array = exp_array. as_primitive :: < Float64Type > ( ) ;
156
159
let res: Float64Array =
157
160
compute:: unary ( exp_array, |exp| base. powf ( exp) ) ;
158
161
Arc :: new ( res)
159
162
}
163
+ _ => return internal_err ! ( "Invalid argument types to pow function" ) ,
160
164
} ;
161
165
Ok ( ColumnarValue :: Array ( res) )
162
166
}
@@ -169,10 +173,6 @@ impl ScalarUDFImpl for PowUdf {
169
173
) ?;
170
174
Ok ( ColumnarValue :: Array ( Arc :: new ( res) ) )
171
175
}
172
- // if the types were not float, it is a bug in DataFusion
173
- _ => {
174
- internal_err ! ( "Invalid argument types to pow function" )
175
- }
176
176
}
177
177
}
178
178
0 commit comments