Skip to content

Commit 1c9474b

Browse files
Gabe Lyons, mistercrunch
Gabe Lyons
authored and mistercrunch committed
treating floats like doubles for druid versions lower than 0.11.0 (apache#5030)
1 parent 9f66dae commit 1c9474b

File tree

3 files changed

+30
-13
lines changed

3 files changed

+30
-13
lines changed

.pylintrc

+1-1
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ ignore-mixin-members=yes
277277
# (useful for modules/projects where namespaces are manipulated during runtime
278278
# and thus existing member attributes cannot be deduced by static analysis. It
279279
# supports qualified module names, as well as Unix pattern matching.
280-
ignored-modules=numpy,pandas,alembic.op,sqlalchemy,alembic.context,flask_appbuilder.security.sqla.PermissionView.role,flask_appbuilder.Model.metadata,flask_appbuilder.Base.metadata
280+
ignored-modules=numpy,pandas,alembic.op,sqlalchemy,alembic.context,flask_appbuilder.security.sqla.PermissionView.role,flask_appbuilder.Model.metadata,flask_appbuilder.Base.metadata,distutils
281281

282282
# List of class names for which member attributes should not be checked (useful
283283
# for classes with dynamically set attributes). This supports the use of

superset/connectors/druid/models.py

+27-10
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from collections import OrderedDict
1010
from copy import deepcopy
1111
from datetime import datetime, timedelta
12+
from distutils.version import LooseVersion
1213
import json
1314
import logging
1415
from multiprocessing.pool import ThreadPool
@@ -899,8 +900,8 @@ def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dic
899900
missing_postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
900901
post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj)
901902

902-
@classmethod
903-
def metrics_and_post_aggs(cls, metrics, metrics_dict):
903+
@staticmethod
904+
def metrics_and_post_aggs(metrics, metrics_dict, druid_version=None):
904905
# Separate metrics into those that are aggregations
905906
# and those that are post aggregations
906907
saved_agg_names = set()
@@ -920,9 +921,13 @@ def metrics_and_post_aggs(cls, metrics, metrics_dict):
920921
for postagg_name in postagg_names:
921922
postagg = metrics_dict[postagg_name]
922923
visited_postaggs.add(postagg_name)
923-
cls.resolve_postagg(
924+
DruidDatasource.resolve_postagg(
924925
postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict)
925-
aggs = cls.get_aggregations(metrics_dict, saved_agg_names, adhoc_agg_configs)
926+
aggs = DruidDatasource.get_aggregations(
927+
metrics_dict,
928+
saved_agg_names,
929+
adhoc_agg_configs,
930+
)
926931
return aggs, post_aggs
927932

928933
def values_for_column(self,
@@ -997,11 +1002,12 @@ def _add_filter_from_pre_query_data(self, df, dimensions, dim_filter):
9971002

9981003
@staticmethod
9991004
def druid_type_from_adhoc_metric(adhoc_metric):
1000-
column_type = adhoc_metric.get('column').get('type').lower()
1001-
aggregate = adhoc_metric.get('aggregate').lower()
1002-
if (aggregate == 'count'):
1005+
column_type = adhoc_metric['column']['type'].lower()
1006+
aggregate = adhoc_metric['aggregate'].lower()
1007+
1008+
if aggregate == 'count':
10031009
return 'count'
1004-
if (aggregate == 'count_distinct'):
1010+
if aggregate == 'count_distinct':
10051011
return 'cardinality'
10061012
else:
10071013
return column_type + aggregate.capitalize()
@@ -1132,6 +1138,17 @@ def run_query( # noqa / druid
11321138
metrics_dict = {m.metric_name: m for m in self.metrics}
11331139
columns_dict = {c.column_name: c for c in self.columns}
11341140

1141+
if (
1142+
self.cluster and
1143+
LooseVersion(self.cluster.get_druid_version()) < LooseVersion('0.11.0')
1144+
):
1145+
for metric in metrics:
1146+
if (
1147+
utils.is_adhoc_metric(metric) and
1148+
metric['column']['type'].upper() == 'FLOAT'
1149+
):
1150+
metric['column']['type'] = 'DOUBLE'
1151+
11351152
aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs(
11361153
metrics,
11371154
metrics_dict)
@@ -1187,7 +1204,7 @@ def run_query( # noqa / druid
11871204
pre_qry = deepcopy(qry)
11881205
if timeseries_limit_metric:
11891206
order_by = timeseries_limit_metric
1190-
aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
1207+
aggs_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs(
11911208
[timeseries_limit_metric],
11921209
metrics_dict)
11931210
if phase == 1:
@@ -1256,7 +1273,7 @@ def run_query( # noqa / druid
12561273

12571274
if timeseries_limit_metric:
12581275
order_by = timeseries_limit_metric
1259-
aggs_dict, post_aggs_dict = self.metrics_and_post_aggs(
1276+
aggs_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs(
12601277
[timeseries_limit_metric],
12611278
metrics_dict)
12621279
if phase == 1:

tests/druid_tests.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ def __reduce__(self):
7676
},
7777
]
7878

79+
DruidCluster.get_druid_version = lambda _: '0.9.1'
80+
7981

8082
class DruidTests(SupersetTestCase):
8183

@@ -114,7 +116,6 @@ def get_cluster(self, PyDruid):
114116

115117
db.session.add(cluster)
116118
cluster.get_datasources = PickableMock(return_value=['test_datasource'])
117-
cluster.get_druid_version = PickableMock(return_value='0.9.1')
118119

119120
return cluster
120121

@@ -324,7 +325,6 @@ def test_sync_druid_perm(self, PyDruid):
324325
cluster.get_datasources = PickableMock(
325326
return_value=['test_datasource'],
326327
)
327-
cluster.get_druid_version = PickableMock(return_value='0.9.1')
328328

329329
cluster.refresh_datasources()
330330
cluster.datasources[0].merge_flag = True

0 commit comments

Comments
 (0)