From 67f354d4ebc5532810417a494a408959827ece64 Mon Sep 17 00:00:00 2001
From: Tural Neymanov
Date: Wed, 27 Mar 2019 17:49:18 -0400
Subject: [PATCH] Add threshold logic to adjust flags for certain input sizes.

- Extracts manually supplied flags from argv.
- Uses the supplied args and input size estimates to check whether flags
  should be enabled or not.
---
 gcp_variant_transforms/libs/optimize_flags.py | 157 ++++++++++++
 .../libs/optimize_flags_test.py               | 234 ++++++++++++++++++
 gcp_variant_transforms/pipeline_common.py     |  16 ++
 .../pipeline_common_test.py                   |  21 ++
 gcp_variant_transforms/vcf_to_bq.py           |   3 +
 5 files changed, 431 insertions(+)
 create mode 100644 gcp_variant_transforms/libs/optimize_flags.py
 create mode 100644 gcp_variant_transforms/libs/optimize_flags_test.py

diff --git a/gcp_variant_transforms/libs/optimize_flags.py b/gcp_variant_transforms/libs/optimize_flags.py
new file mode 100644
index 000000000..9c7f35a10
--- /dev/null
+++ b/gcp_variant_transforms/libs/optimize_flags.py
@@ -0,0 +1,157 @@
+# Copyright 2019 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities to optimize default flag values based on the provided input size.
+
+If any of the flags were manually supplied during the command's invocation,
+they will not be overridden.
+
+Flag adjustment uses five signals extracted from the input:
+  - estimated total number of variants.
+  - estimated total number of samples.
+  - estimated number of records (variant data for a sample).
+  - total size of the input.
+  - number of supplied files.
+"""
+
+import operator
+
+from apache_beam.runners import runner  # pylint: disable=unused-import
+
+
+class Dimensions(object):
+  """Contains dimensions of the input data and the manually supplied args."""
+
+  def __init__(self,
+               line_count=None,  # type: int
+               sample_count=None,  # type: int
+               record_count=None,  # type: int
+               files_size=None,  # type: int
+               file_count=None,  # type: int
+               supplied_args=None  # type: List[str]
+              ):
+    # type: (...) -> None
+    self.line_count = line_count
+    self.sample_count = sample_count
+    self.record_count = record_count
+    self.files_size = files_size
+    self.file_count = file_count
+    self.supplied_args = supplied_args
+
+
+class Threshold(Dimensions):
+  """Describes the limits the input needs to pass to enable a certain flag.
+
+  Unlike a Dimensions object, it should not have supplied_args set, and not
+  all dimensions need to be defined.
+ """ + def __init__(self, + flag_name, # type: str + line_count=None, # type: int + sample_count=None, # type: int + record_count=None, # type: int + files_size=None, # type: int + file_count=None # type: int + ): + super(Threshold, self).__init__(line_count, + sample_count, + record_count, + files_size, + file_count) + self.flag_name = flag_name + + def not_supplied(self, state): + # type(Dimensions) -> bool + """Verify that flag was not manually supplied.""" + return self.flag_name not in state.supplied_args + + def hard_pass(self, state, cond=operator.gt): + # type(Dimensions, Callable) -> bool + """Verifies that all of set dimensions of the threshold are satisfied.""" + return self.not_supplied(state) and ( + (not self.line_count or cond(state.line_count, self.line_count)) and + (not self.sample_count or + cond(state.sample_count, self.sample_count)) and + (not self.record_count or + cond(state.record_count, self.record_count)) and + (not self.files_size or cond(state.files_size, self.files_size)) and + (not self.file_count or cond(state.file_count, self.file_count))) + + def soft_pass(self, state, cond=operator.gt): + # type(Dimensions, Callable) -> bool + """Verifies that at least one of the set dimensions is satisfied.""" + return self.not_supplied(state) and ( + (self.line_count and cond(state.line_count, self.line_count)) or + (self.sample_count and cond(state.sample_count, self.sample_count)) or + (self.record_count and cond(state.record_count, self.record_count)) or + (self.files_size and cond(state.files_size, self.files_size)) or + (self.file_count and cond(state.file_count, self.file_count))) + + +OPTIMIZE_FOR_LARGE_INPUTS_TS = Threshold( + 'optimize_for_large_inputs', + record_count=3000000000, + file_count=50000) +INFER_HEADERS_TS = Threshold( + 'infer_headers', + record_count=5000000000 +) +INFER_ANNOTATION_TYPES_TS = Threshold( + 'infer_annotation_types', + record_count=5000000000 +) +NUM_BIGQUERY_WRITE_SHARDS_TS = Threshold( + 'num_bigquery_write_shards', + record_count=1000000000, + files_size=500000000000 +) +NUM_WORKERS_TS = Threshold( + 'num_workers', + record_count=1000000000 +) +SHARD_VARIANTS_TS = Threshold( + 'shard_variants', + record_count=1000000000, +) + +def _optimize_known_args(known_args, input_dimensions): + if OPTIMIZE_FOR_LARGE_INPUTS_TS.soft_pass(input_dimensions): + known_args.optimize_for_large_inputs = True + if INFER_HEADERS_TS.soft_pass(input_dimensions, operator.le): + known_args.infer_headers = True + if NUM_BIGQUERY_WRITE_SHARDS_TS.soft_pass(input_dimensions): + known_args.num_bigquery_write_shards = 20 + if INFER_ANNOTATION_TYPES_TS.soft_pass(input_dimensions, operator.le): + known_args.infer_annotation_types = True + if SHARD_VARIANTS_TS.soft_pass(input_dimensions, operator.le): + known_args.shard_variants = False + +def _optimize_pipeline_args(pipeline_args, known_args, input_dimensions): + if NUM_WORKERS_TS.hard_pass(input_dimensions): + pipeline_args.num_workers = 100 + if (known_args.run_annotation_pipeline and + NUM_WORKERS_TS.not_supplied(input_dimensions)): + pipeline_args.num_workers = 400 + +def optimize_flags(supplied_args, known_args, pipeline_args): + # type(Namespace, List[str]) -> None + input_dimensions = Dimensions(line_count=known_args.estimated_line_count, + sample_count=known_args.estimated_sample_count, + record_count=known_args.estimated_record_count, + files_size=known_args.files_size, + file_count=known_args.file_count, + supplied_args=supplied_args) + + _optimize_known_args(known_args, input_dimensions) + 
diff --git a/gcp_variant_transforms/libs/optimize_flags_test.py b/gcp_variant_transforms/libs/optimize_flags_test.py
new file mode 100644
index 000000000..8937976e8
--- /dev/null
+++ b/gcp_variant_transforms/libs/optimize_flags_test.py
@@ -0,0 +1,234 @@
+# Copyright 2019 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for optimize_flags module."""
+
+import argparse
+import unittest
+
+from apache_beam.options import pipeline_options
+
+from gcp_variant_transforms.libs import optimize_flags
+from gcp_variant_transforms.options import variant_transform_options
+
+TOOL_OPTIONS = [
+    variant_transform_options.VcfReadOptions,
+    variant_transform_options.AvroWriteOptions,
+    variant_transform_options.BigQueryWriteOptions,
+    variant_transform_options.AnnotationOptions,
+    variant_transform_options.FilterOptions,
+    variant_transform_options.MergeOptions,
+    variant_transform_options.PartitionOptions,
+    variant_transform_options.ExperimentalOptions]
+
+PIPELINE_OPTIONS = [
+    pipeline_options.WorkerOptions]
+
+
+def add_defaults(known_args):
+  known_args.run_annotation_pipeline = False
+
+
+def make_known_args_with_default_values(options):
+  parser = argparse.ArgumentParser()
+  parser.register('type', 'bool', lambda v: v.lower() == 'true')
+  _ = [option().add_arguments(parser) for option in options]
+  known_args, unknown_known_args = parser.parse_known_args([])
+
+  parser = argparse.ArgumentParser()
+  for cls in pipeline_options.PipelineOptions.__subclasses__():
+    if '_add_argparse_args' in cls.__dict__:
+      cls._add_argparse_args(parser)
+  pipeline_args, unknown_pipeline_args = parser.parse_known_args([])
+  assert not unknown_known_args
+  assert not unknown_pipeline_args
+  return known_args, pipeline_args
+
+
+class OptimizeFlagsTest(unittest.TestCase):
+
+  known_args = pipeline_args = supplied_args = None
+
+  def setUp(self):
+    self.known_args, self.pipeline_args = (
+        make_known_args_with_default_values(TOOL_OPTIONS))
+    self.supplied_args = []
+
+  def _set_up_dimensions(self,
+                         line_count,
+                         sample_count,
+                         record_count,
+                         files_size,
+                         file_count):
+    self.known_args.estimated_line_count = line_count
+    self.known_args.estimated_sample_count = sample_count
+    self.known_args.estimated_record_count = record_count
+    self.known_args.files_size = files_size
+    self.known_args.file_count = file_count
+
+  def _run_tests(self):
+    optimize_flags.optimize_flags(
+        self.supplied_args, self.known_args, self.pipeline_args)
+
+  def test_optimize_for_large_inputs_passes_records(self):
+    self._set_up_dimensions(1, 1, 3000000001, 1, 1)
+    self.known_args.optimize_for_large_inputs = False
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.optimize_for_large_inputs, True)
+
+  def test_optimize_for_large_inputs_passes_files(self):
+    self._set_up_dimensions(1, 1, 3000000000, 1, 50001)
+    self.known_args.optimize_for_large_inputs = False
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.optimize_for_large_inputs, True)
+
+  def test_optimize_for_large_inputs_fails(self):
+    self._set_up_dimensions(1, 1, 3000000000, 1, 50000)
+    self.known_args.optimize_for_large_inputs = False
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.optimize_for_large_inputs, False)
+
+  def test_optimize_for_large_inputs_supplied(self):
+    self._set_up_dimensions(1, 1, 3000000001, 1, 50001)
+    self.supplied_args = ['optimize_for_large_inputs']
+    self.known_args.optimize_for_large_inputs = False
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.optimize_for_large_inputs, False)
+
+  def test_infer_headers_passes(self):
+    self._set_up_dimensions(1, 1, 5000000000, 1, 1)
+    self.known_args.infer_headers = False
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.infer_headers, True)
+
+  def test_infer_headers_fails(self):
+    self._set_up_dimensions(1, 1, 5000000001, 1, 1)
+    self.known_args.infer_headers = False
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.infer_headers, False)
+
+  def test_infer_headers_supplied(self):
+    self._set_up_dimensions(1, 1, 5000000000, 1, 1)
+    self.supplied_args = ['infer_headers']
+    self.known_args.infer_headers = False
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.infer_headers, False)
+
+  def test_num_bigquery_write_shards_passes_records(self):
+    self._set_up_dimensions(1, 1, 1000000001, 500000000000, 1)
+    self.known_args.num_bigquery_write_shards = 1
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.num_bigquery_write_shards, 20)
+
+  def test_num_bigquery_write_shards_passes_size(self):
+    self._set_up_dimensions(1, 1, 1000000000, 500000000001, 1)
+    self.known_args.num_bigquery_write_shards = 1
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.num_bigquery_write_shards, 20)
+
+  def test_num_bigquery_write_shards_fails(self):
+    self._set_up_dimensions(1, 1, 1000000000, 500000000000, 1)
+    self.known_args.num_bigquery_write_shards = 1
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.num_bigquery_write_shards, 1)
+
+  def test_num_bigquery_write_shards_supplied(self):
+    self._set_up_dimensions(1, 1, 1000000001, 500000000000, 1)
+    self.supplied_args = ['num_bigquery_write_shards']
+    self.known_args.num_bigquery_write_shards = 1
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.num_bigquery_write_shards, 1)
+
+  def test_num_workers_passes_records(self):
+    self._set_up_dimensions(1, 1, 1000000001, 1, 1)
+    self.known_args.run_annotation_pipeline = False
+    self.pipeline_args.num_workers = 1
+
+    self._run_tests()
+
+    self.assertEqual(self.pipeline_args.num_workers, 100)
+
+  def test_num_workers_passes_annotation(self):
+    self._set_up_dimensions(1, 1, 1000000001, 1, 1)
+    self.known_args.run_annotation_pipeline = True
+    self.pipeline_args.num_workers = 1
+
+    self._run_tests()
+
+    self.assertEqual(self.pipeline_args.num_workers, 400)
+
+  def test_num_workers_fails(self):
+    self._set_up_dimensions(1, 1, 1000000000, 1, 1)
+    self.known_args.run_annotation_pipeline = False
+    self.pipeline_args.num_workers = 1
+
+    self._run_tests()
+
+    self.assertEqual(self.pipeline_args.num_workers, 1)
+
+  def test_num_workers_supplied(self):
+    self._set_up_dimensions(1, 1, 1000000001, 1, 1)
+    self.supplied_args = ['num_workers']
+    self.known_args.run_annotation_pipeline = True
+    self.pipeline_args.num_workers = 1
+
+    self._run_tests()
+
+    self.assertEqual(self.pipeline_args.num_workers, 1)
+
+  def test_shard_variants_passes(self):
+    self._set_up_dimensions(1, 1, 1000000000, 1, 1)
+    self.known_args.shard_variants = True
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.shard_variants, False)
+
+  def test_shard_variants_fails(self):
+    self._set_up_dimensions(1, 1, 1000000001, 1, 1)
+    self.known_args.shard_variants = True
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.shard_variants, True)
+
+  def test_shard_variants_supplied(self):
+    self._set_up_dimensions(1, 1, 1000000000, 1, 1)
+    self.supplied_args = ['shard_variants']
+    self.known_args.shard_variants = True
+
+    self._run_tests()
+
+    self.assertEqual(self.known_args.shard_variants, True)
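
One subtlety the tests above exercise: infer_headers, infer_annotation_types,
and shard_variants are evaluated with operator.le, so those thresholds pass
for inputs at or below the limit rather than above it. A small sketch of that
inversion, with made-up values mirroring test_infer_headers_passes and
test_infer_headers_fails:

    import operator

    from gcp_variant_transforms.libs import optimize_flags

    at_limit = optimize_flags.Dimensions(
        line_count=1, sample_count=1, record_count=5000000000,
        files_size=1, file_count=1, supplied_args=[])
    # Exactly at the 5B-record limit: le passes, so infer_headers is enabled.
    assert optimize_flags.INFER_HEADERS_TS.soft_pass(at_limit, operator.le)

    over_limit = optimize_flags.Dimensions(
        line_count=1, sample_count=1, record_count=5000000001,
        files_size=1, file_count=1, supplied_args=[])
    # One record over the limit: le fails and the default value is kept.
    assert not optimize_flags.INFER_HEADERS_TS.soft_pass(
        over_limit, operator.le)
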
diff --git a/gcp_variant_transforms/pipeline_common.py b/gcp_variant_transforms/pipeline_common.py
index e4850502c..38174b0c6 100644
--- a/gcp_variant_transforms/pipeline_common.py
+++ b/gcp_variant_transforms/pipeline_common.py
@@ -22,6 +22,7 @@
 import argparse
 import enum
 import os
+import sys
 import uuid
 from datetime import datetime
@@ -76,6 +77,21 @@
   return known_args, pipeline_args
 
 
+def extract_supplied_args(args):
+  # type: (List[str]) -> List[str]
+  """Extracts all manually supplied arguments from argv.
+
+  Finds args that start with '--', then drops the '--' prefix and, if
+  present, the '=value' suffix.
+  """
+  supplied_args = []
+  for arg in args or sys.argv:
+    if arg.startswith('--'):
+      flag = arg[2:(arg + '=').find('=')]
+      supplied_args.append('infer_headers' if
+                           flag == 'infer_undefined_headers' else flag)
+
+  return supplied_args
+
+
 def _get_all_patterns(input_pattern, input_file):
   # type: (str, str) -> List[str]
   patterns = [input_pattern] if input_pattern else _get_file_names(input_file)
diff --git a/gcp_variant_transforms/pipeline_common_test.py b/gcp_variant_transforms/pipeline_common_test.py
index 0256b21be..c2df82bcf 100644
--- a/gcp_variant_transforms/pipeline_common_test.py
+++ b/gcp_variant_transforms/pipeline_common_test.py
@@ -15,6 +15,7 @@
 """Tests for pipeline_common script."""
 
 import collections
+import sys
 import unittest
 
 from apache_beam.io.filesystems import FileSystems
@@ -158,6 +159,26 @@ def test_get_splittable_bgzf(self):
         pipeline_common.get_splittable_bgzf(['no index file']), [])
 
+  def test_extract_supplied_args_from_sys(self):
+    sysargs = ['vcf_to_bq', '--input_file', 'file', '--output_table=output',
+               '--append', '--infer_headers', 'True']
+    known_args = []
+
+    with mock.patch.object(sys, 'argv', sysargs):
+      self.assertItemsEqual(
+          pipeline_common.extract_supplied_args(known_args),
+          ['input_file', 'output_table', 'append', 'infer_headers'])
+
+  def test_extract_supplied_args_from_supplied(self):
+    sysargs = ['vcf_to_bq', '--input_file']
+    known_args = ['vcf_to_bq', '--input_file', 'file', '--output_table=output',
+                  '--append', '--infer_undefined_headers', 'True']
+
+    with mock.patch.object(sys, 'argv', sysargs):
+      self.assertItemsEqual(
+          pipeline_common.extract_supplied_args(known_args),
+          ['input_file', 'output_table', 'append', 'infer_headers'])
+
 
 class PipelineCommonWithFileTest(unittest.TestCase):
   """Tests cases for the `pipeline_common` script with file input."""
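
A quick illustration of extract_supplied_args behavior (argv values made up):
flags are recognized by their '--' prefix, any '=value' suffix is stripped,
bare positional values are skipped, and the infer_undefined_headers spelling
is reported as infer_headers:

    from gcp_variant_transforms import pipeline_common

    argv = ['vcf_to_bq', '--input_file', 'file', '--output_table=out',
            '--append', '--infer_undefined_headers', 'True']
    print(pipeline_common.extract_supplied_args(argv))
    # ['input_file', 'output_table', 'append', 'infer_headers']

Note the `args or sys.argv` fallback in the implementation: passing an empty
list (as the first test above does) falls back to the process's own argv.
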
diff --git a/gcp_variant_transforms/vcf_to_bq.py b/gcp_variant_transforms/vcf_to_bq.py
index e95ab3c70..065ec131e 100644
--- a/gcp_variant_transforms/vcf_to_bq.py
+++ b/gcp_variant_transforms/vcf_to_bq.py
@@ -48,6 +48,7 @@
 from gcp_variant_transforms import pipeline_common
 from gcp_variant_transforms.beam_io import vcfio
 from gcp_variant_transforms.libs import metrics_util
+from gcp_variant_transforms.libs import optimize_flags
 from gcp_variant_transforms.libs import processed_variant
 from gcp_variant_transforms.libs import vcf_header_parser
 from gcp_variant_transforms.libs import variant_partition
@@ -410,7 +411,9 @@ def run(argv=None):
                                           _COMMAND_LINE_OPTIONS)
 
   if known_args.auto_flags_experiment:
+    supplied_args = pipeline_common.extract_supplied_args(argv)
     _get_input_dimensions(known_args, pipeline_args)
+    optimize_flags.optimize_flags(supplied_args, known_args, pipeline_args)
 
   annotated_vcf_pattern = _run_annotation_pipeline(known_args, pipeline_args)
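
Putting the pieces together, a rough sketch of what the auto_flags_experiment
path in run() now does. The Namespace fields below are hypothetical stand-ins
for what parse_args and _get_input_dimensions populate; the numbers are made
up:

    import argparse

    from gcp_variant_transforms import pipeline_common
    from gcp_variant_transforms.libs import optimize_flags

    known_args = argparse.Namespace(
        estimated_line_count=10, estimated_sample_count=10,
        estimated_record_count=4000000000, files_size=10, file_count=10,
        optimize_for_large_inputs=False, infer_headers=False,
        infer_annotation_types=False, num_bigquery_write_shards=1,
        shard_variants=True, run_annotation_pipeline=False)
    pipeline_args = argparse.Namespace(num_workers=1)

    supplied = pipeline_common.extract_supplied_args(
        ['vcf_to_bq', '--input_pattern=gs://bucket/*.vcf'])
    optimize_flags.optimize_flags(supplied, known_args, pipeline_args)
    # With 4B estimated records and none of the tuned flags supplied:
    # optimize_for_large_inputs, infer_headers and infer_annotation_types
    # flip to True, num_bigquery_write_shards becomes 20, num_workers
    # becomes 100, and shard_variants stays True (4B is above its 1B
    # le-limit).
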