diff --git a/examples/dcgan/dcgan_fuzzer.py b/examples/dcgan/dcgan_fuzzer.py index f31749c..3d1db98 100644 --- a/examples/dcgan/dcgan_fuzzer.py +++ b/examples/dcgan/dcgan_fuzzer.py @@ -19,9 +19,7 @@ import numpy as np import tensorflow as tf from lib.fuzz_utils import build_fetch_function -from lib.corpus import InputCorpus -from lib.corpus import seed_corpus_from_numpy_arrays -from lib.coverage_functions import raw_logit_coverage_function +from lib.coverage_functions import raw_coverage_function from lib.fuzzer import Fuzzer from lib.mutation_functions import do_basic_mutations from lib.sample_functions import uniform_sample_function @@ -40,7 +38,7 @@ def metadata_function(metadata_batches): - """Gets the metadata.""" + """Gets the metadata, for computing the objective function.""" loss_batch, grad_batch = metadata_batches metadata_list = [ {"loss": loss_batch[idx], "grad": grad_batch[idx]} @@ -64,6 +62,12 @@ def objective_function(corpus_element): return False +def mutation_function(elt): + """Mutates the element in question.""" + return do_basic_mutations( + elt, FLAGS.mutations_per_corpus_item, a_min=-1000, a_max=1000) + + # pylint: disable=too-many-locals def main(_): """Configures and runs the fuzzer.""" @@ -71,42 +75,32 @@ def main(_): # Log more tf.logging.set_verbosity(tf.logging.INFO) - coverage_function = raw_logit_coverage_function + # Set up initial seed inputs target_seed = np.random.uniform(low=0.0, high=1.0, size=(1,)) - numpy_arrays = [[target_seed]] + seed_inputs = [[target_seed]] + # Specify input, coverage, and metadata tensors targets_tensor = tf.placeholder(tf.float32, [64, 1]) coverage_tensor = tf.identity(targets_tensor) loss_batch_tensor, _ = binary_cross_entropy_with_logits( tf.zeros_like(targets_tensor), tf.nn.sigmoid(targets_tensor) ) grads_tensor = tf.gradients(loss_batch_tensor, targets_tensor)[0] - tensor_map = { - "input": [targets_tensor], - "coverage": [coverage_tensor], - "metadata": [loss_batch_tensor, grads_tensor], - } + # Construct and run fuzzer with tf.Session() as sess: - - fetch_function = build_fetch_function(sess, tensor_map) - size = FLAGS.mutations_per_corpus_item - mutation_function = lambda elt: do_basic_mutations( - elt, size, a_min=-1000, a_max=1000 - ) - seed_corpus = seed_corpus_from_numpy_arrays( - numpy_arrays, coverage_function, metadata_function, fetch_function - ) - corpus = InputCorpus( - seed_corpus, uniform_sample_function, FLAGS.ann_threshold, "kdtree" - ) fuzzer = Fuzzer( - corpus, - coverage_function, - metadata_function, - objective_function, - mutation_function, - fetch_function, + sess=sess, + seed_inputs=seed_inputs, + input_tensors=[targets_tensor], + coverage_tensors=[coverage_tensor], + metadata_tensors=[loss_batch_tensor, grads_tensor], + coverage_function=raw_coverage_function, + metadata_function=metadata_function, + objective_function=objective_function, + mutation_function=mutation_function, + sample_function=uniform_sample_function, + threshold=FLAGS.ann_threshold ) result = fuzzer.loop(FLAGS.total_inputs_to_fuzz) if result is not None: diff --git a/examples/nans/nan_fuzzer.py b/examples/nans/nan_fuzzer.py index a2c86fc..99265d5 100644 --- a/examples/nans/nan_fuzzer.py +++ b/examples/nans/nan_fuzzer.py @@ -20,9 +20,7 @@ import numpy as np import tensorflow as tf from lib import fuzz_utils -from lib.corpus import InputCorpus -from lib.corpus import seed_corpus_from_numpy_arrays -from lib.coverage_functions import all_logit_coverage_function +from lib.coverage_functions import sum_coverage_function from lib.fuzzer import Fuzzer from lib.mutation_functions import do_basic_mutations from lib.sample_functions import recent_sample_function @@ -49,6 +47,10 @@ FLAGS = tf.flags.FLAGS +if FLAGS.checkpoint_dir is None: + raise ValueError('checkpoint_dir flag must be specified') + + def metadata_function(metadata_batches): """Gets the metadata.""" metadata_list = [ @@ -68,6 +70,12 @@ def objective_function(corpus_element): return True +def mutation_function(elt): + """Mutates the element in question.""" + return do_basic_mutations( + elt, FLAGS.mutations_per_corpus_item) + + def main(_): """Constructs the fuzzer and performs fuzzing.""" @@ -78,36 +86,34 @@ def main(_): random.seed(FLAGS.seed) np.random.seed(FLAGS.seed) - coverage_function = all_logit_coverage_function + # Set up seed images image, label = fuzz_utils.basic_mnist_input_corpus( choose_randomly=FLAGS.random_seed_corpus ) - numpy_arrays = [[image, label]] + seed_inputs = [[image, label]] with tf.Session() as sess: + # Specify input, coverage, and metadata tensors + input_tensors, coverage_tensors, metadata_tensors = \ + fuzz_utils.get_tensors_from_checkpoint( + sess, FLAGS.checkpoint_dir + ) - tensor_map = fuzz_utils.get_tensors_from_checkpoint( - sess, FLAGS.checkpoint_dir - ) - - fetch_function = fuzz_utils.build_fetch_function(sess, tensor_map) - - size = FLAGS.mutations_per_corpus_item - mutation_function = lambda elt: do_basic_mutations(elt, size) - seed_corpus = seed_corpus_from_numpy_arrays( - numpy_arrays, coverage_function, metadata_function, fetch_function - ) - corpus = InputCorpus( - seed_corpus, recent_sample_function, FLAGS.ann_threshold, "kdtree" - ) + # Construct and run fuzzer fuzzer = Fuzzer( - corpus, - coverage_function, - metadata_function, - objective_function, - mutation_function, - fetch_function, + sess=sess, + seed_inputs=seed_inputs, + input_tensors=input_tensors, + coverage_tensors=coverage_tensors, + metadata_tensors=metadata_tensors, + coverage_function=sum_coverage_function, + metadata_function=metadata_function, + objective_function=objective_function, + mutation_function=mutation_function, + sample_function=recent_sample_function, + threshold=FLAGS.ann_threshold, ) + result = fuzzer.loop(FLAGS.total_inputs_to_fuzz) if result is not None: tf.logging.info("Fuzzing succeeded.") diff --git a/examples/quantize/README.md b/examples/quantize/README.md index 8e70228..96798c0 100644 --- a/examples/quantize/README.md +++ b/examples/quantize/README.md @@ -6,11 +6,11 @@ casts those variables to use 16 bits. To train this model, execute something like this: ``` -python examples/quantize/quantized_model.py --checkpoint_dir='/tmp/quantized_checkpoints_2' --training_steps=10000 +python examples/quantize/quantized_model.py --checkpoint_dir='/tmp/quantized_checkpoints' --training_steps=10000 ``` To fuzz the trained model, execute something like this: ``` -python examples/quantize/quantized_fuzzer.py --checkpoint_dir=/tmp/quantized_checkpoints_2 --total_inputs_to_fuzz=1000000 --mutations_per_corpus_item=100 --alsologtostderr --output_path=/cns/ok-d/home/augustusodena/fuzzer/plots/quantized_image.png --ann_threshold=1.0 --perturbation_constraint=1.0 --strategy=ann +python examples/quantize/quantized_fuzzer.py --checkpoint_dir=/tmp/quantized_checkpoints --total_inputs_to_fuzz=1000000 --mutations_per_corpus_item=100 --alsologtostderr --output_path=/cns/ok-d/home/augustusodena/fuzzer/plots/quantized_image.png --ann_threshold=1.0 --perturbation_constraint=1.0 --strategy=ann ``` diff --git a/examples/quantize/quantized_fuzzer.py b/examples/quantize/quantized_fuzzer.py index 9c8e9c4..6ed190b 100644 --- a/examples/quantize/quantized_fuzzer.py +++ b/examples/quantize/quantized_fuzzer.py @@ -19,9 +19,7 @@ import numpy as np import tensorflow as tf from lib import fuzz_utils -from lib.corpus import InputCorpus -from lib.corpus import seed_corpus_from_numpy_arrays -from lib.coverage_functions import raw_logit_coverage_function +from lib.coverage_functions import sum_coverage_function from lib.fuzzer import Fuzzer from lib.mutation_functions import do_basic_mutations from lib.sample_functions import recent_sample_function @@ -53,20 +51,27 @@ FLAGS = tf.flags.FLAGS +if FLAGS.checkpoint_dir is None: + raise ValueError('checkpoint_dir flag must be specified') + + + def metadata_function(metadata_batches): - """Gets the metadata.""" + """Gets the metadata, for computing the objective function.""" logit_32_batch = metadata_batches[0] logit_16_batch = metadata_batches[1] metadata_list = [] for idx in range(logit_16_batch.shape[0]): - metadata_list.append((logit_32_batch[idx], logit_16_batch[idx])) + metadata_list.append({ + "logits_32": logit_32_batch[idx], + "logits_16": logit_16_batch[idx]}) return metadata_list def objective_function(corpus_element): """Checks if the element is misclassified.""" - logits_32 = corpus_element.metadata[0] - logits_16 = corpus_element.metadata[1] + logits_32 = corpus_element.metadata["logits_32"] + logits_16 = corpus_element.metadata["logits_16"] prediction_16 = np.argmax(logits_16) prediction_32 = np.argmax(logits_32) if prediction_16 == prediction_32: @@ -80,6 +85,12 @@ def objective_function(corpus_element): return True +def mutation_function(elt): + """Mutates the element in question.""" + return do_basic_mutations( + elt, FLAGS.mutations_per_corpus_item, FLAGS.perturbation_constraint) + + # pylint: disable=too-many-locals def main(_): """Constructs the fuzzer and fuzzes.""" @@ -87,49 +98,43 @@ def main(_): # Log more tf.logging.set_verbosity(tf.logging.INFO) - coverage_function = raw_logit_coverage_function + # Set up initial seed inputs image, label = fuzz_utils.basic_mnist_input_corpus( choose_randomly=FLAGS.random_seed_corpus ) - numpy_arrays = [[image, label]] + seed_inputs = [[image, label]] image_copy = image[:] with tf.Session() as sess: + # Specify input, coverage, and metadata tensors + input_tensors, coverage_tensors, metadata_tensors = \ + fuzz_utils.get_tensors_from_checkpoint( + sess, FLAGS.checkpoint_dir + ) - tensor_map = fuzz_utils.get_tensors_from_checkpoint( - sess, FLAGS.checkpoint_dir - ) - - fetch_function = fuzz_utils.build_fetch_function(sess, tensor_map) - - size = FLAGS.mutations_per_corpus_item - - def mutation_function(elt): - """Mutates the element in question.""" - return do_basic_mutations(elt, size, FLAGS.perturbation_constraint) - - seed_corpus = seed_corpus_from_numpy_arrays( - numpy_arrays, coverage_function, metadata_function, fetch_function - ) - corpus = InputCorpus( - seed_corpus, recent_sample_function, FLAGS.ann_threshold, "kdtree" - ) + # Construct and run fuzzer fuzzer = Fuzzer( - corpus, - coverage_function, - metadata_function, - objective_function, - mutation_function, - fetch_function, + sess=sess, + seed_inputs=seed_inputs, + input_tensors=input_tensors, + coverage_tensors=coverage_tensors, + metadata_tensors=metadata_tensors, + coverage_function=sum_coverage_function, + metadata_function=metadata_function, + objective_function=objective_function, + mutation_function=mutation_function, + sample_function=recent_sample_function, + threshold=FLAGS.ann_threshold, ) result = fuzzer.loop(FLAGS.total_inputs_to_fuzz) + if result is not None: # Double check that there is persistent disagreement for idx in range(10): logits, quantized_logits = sess.run( - [tensor_map["coverage"][0], tensor_map["coverage"][1]], + [coverage_tensors[0], coverage_tensors[1]], feed_dict={ - tensor_map["input"][0]: np.expand_dims( + input_tensors[0]: np.expand_dims( result.data[0], 0 ) }, diff --git a/examples/quantize/quantized_model.py b/examples/quantize/quantized_model.py index 68b909b..e43d387 100644 --- a/examples/quantize/quantized_model.py +++ b/examples/quantize/quantized_model.py @@ -24,7 +24,7 @@ tf.flags.DEFINE_string( "checkpoint_dir", - "/tmp/nanfuzzer", + "/tmp/quantizefuzzer", "The overall dir in which we store experiments", ) tf.flags.DEFINE_string( diff --git a/examples/tfgrad/cumprod.py b/examples/tfgrad/cumprod.py new file mode 100644 index 0000000..853176f --- /dev/null +++ b/examples/tfgrad/cumprod.py @@ -0,0 +1,117 @@ +# Copyright 2018 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fuzz a tf op to get a NaN.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import random +import numpy as np +import tensorflow as tf +from lib import fuzz_utils +from lib.coverage_functions import raw_coverage_function +from lib.fuzzer import Fuzzer +from lib.mutation_functions import do_basic_mutations +from lib.sample_functions import recent_sample_function + +tf.flags.DEFINE_integer( + "total_inputs_to_fuzz", 100, "Loops over the whole corpus." +) +tf.flags.DEFINE_integer( + "mutations_per_corpus_item", 100, "Number of times to mutate corpus item." +) +tf.flags.DEFINE_float( + "ann_threshold", + 1.0, + "Distance below which we consider something new coverage.", +) +tf.flags.DEFINE_integer("seed", None, "Random seed for both python and numpy.") +tf.flags.DEFINE_boolean( + "random_seed_corpus", False, "Whether to choose a random seed corpus." +) +FLAGS = tf.flags.FLAGS + + +def metadata_function(metadata_batches): + """Gets the metadata.""" + metadata_list = [metadata_batches[0][i] for i in range(len(metadata_batches[0]))] + return metadata_list + + +def objective_function(corpus_element): + """Checks if the metadata is inf or NaN.""" + metadata = corpus_element.metadata + if all([np.isfinite(d).all() for d in metadata]): + return False + + tf.logging.info("Objective function satisfied: non-finite element found.") + return True + + +def mutation_function(elt): + """Mutates the element in question.""" + return do_basic_mutations( + elt, FLAGS.mutations_per_corpus_item, a_min=None, a_max=None) + + +def main(_): + """Constructs the fuzzer and performs fuzzing.""" + + # Log more + tf.logging.set_verbosity(tf.logging.INFO) + # Set the seeds! + if FLAGS.seed: + random.seed(FLAGS.seed) + np.random.seed(FLAGS.seed) + + # Set up seed inputs + sz = 16 +# target_seed = np.random.uniform(low=0.0, high=1.0, size=(sz,)) + target_seed = np.ones(sz, dtype=np.uint32)*4 + seed_inputs = [[target_seed]] + + # Specify input, coverage, and metadata tensors + input_tensor = tf.placeholder(tf.int32, [None, sz]) + op_tensor = tf.cumprod(input_tensor) + grad_tensors = tf.gradients(op_tensor, input_tensor) + + with tf.Session() as sess: + # Construct and run fuzzer + fuzzer = Fuzzer( + sess=sess, + seed_inputs=seed_inputs, + input_tensors=[input_tensor], + coverage_tensors=grad_tensors, + metadata_tensors=grad_tensors, + coverage_function=raw_coverage_function, + metadata_function=metadata_function, + objective_function=objective_function, + mutation_function=mutation_function, + sample_function=recent_sample_function, + threshold=FLAGS.ann_threshold, + ) + + result = fuzzer.loop(FLAGS.total_inputs_to_fuzz) + if result is not None: + tf.logging.info("Fuzzing succeeded.") + tf.logging.info( + "Generations to make satisfying element: %s.", + result.oldest_ancestor()[1], + ) + else: + tf.logging.info("Fuzzing failed to satisfy objective function.") + + +if __name__ == "__main__": + tf.app.run() diff --git a/lib/corpus.py b/lib/corpus.py index b25a14c..1eed502 100644 --- a/lib/corpus.py +++ b/lib/corpus.py @@ -33,7 +33,7 @@ def __init__(self, data, metadata, coverage, parent): Args: data: a list of numpy arrays representing the mutated data. - metadata: arbitrary python object to be used by the fuzzer for e.g. + metadata: arbitrary python object to be used by the fuzzer for computing the objective function during the fuzzing loop. coverage: an arbitrary hashable python object that guides fuzzing process. parent: a reference to the CorpusElement this element is a mutation of. @@ -60,7 +60,9 @@ def oldest_ancestor(self): def seed_corpus_from_numpy_arrays( numpy_arrays, coverage_function, metadata_function, fetch_function ): - """Constructs a seed_corpus given numpy_arrays. + """Constructs a starting corpus, given numpy_arrays of inputs, + by passing them through the system to produce outputs (i.e. coverage + representations) We only use the first element of the batch that we fetch, because we're only trying to create one corpus element, and we may end up @@ -69,8 +71,8 @@ def seed_corpus_from_numpy_arrays( Args: numpy_arrays: multiple lists of input_arrays, each list with as many arrays as there are input tensors. - coverage_function: a function that does CorpusElement -> Coverage. - metadata_function: a function that does CorpusElement -> Metadata. + coverage_function: a function that does coverage batches -> coverage object. + metadata_function: a function that does metadata batches -> metadata object. fetch_function: grabs output from tensorflow runtime. Returns: List of CorpusElements. diff --git a/lib/coverage_functions.py b/lib/coverage_functions.py index 5e4d4b5..025b3cc 100644 --- a/lib/coverage_functions.py +++ b/lib/coverage_functions.py @@ -19,17 +19,16 @@ import numpy as np -def all_logit_coverage_function(coverage_batches): - """Computes coverage based on the sum of the absolute values of the logits. +def sum_coverage_function(coverage_batches): + """Computes coverage as the sum of the absolute values of coverage_batches. Args: coverage_batches: Numpy arrays containing coverage information pulled from - a call to sess.run. In this case, we assume that these correspond to a - batch of logits. + a call to sess.run. Returns: - A python integer corresponding to the sum of the absolute values of the - logits. + A list of python integers corresponding to the sum of the absolute + values of the entries in coverage_batches. """ coverage_batch = coverage_batches[0] coverage_list = [] @@ -40,18 +39,17 @@ def all_logit_coverage_function(coverage_batches): return coverage_list -def raw_logit_coverage_function(coverage_batches): - """The coverage in this case is just the actual logits. +def raw_coverage_function(coverage_batches): + """The coverage in this case is just the actual values of coverage_batches. This coverage function is intended for use with a nearest neighbor method. Args: coverage_batches: Numpy arrays containing coverage information pulled from - a call to sess.run. In this case, we assume that these correspond to a - batch of logits. + a call to sess.run. Returns: - A numpy array of logits. + A list of numpy arrays corresponding to the entries in coverage_batches. """ # For our purpose, we only need the first coverage element coverage_batch = coverage_batches[0] diff --git a/lib/fuzz_utils.py b/lib/fuzz_utils.py index eba425f..57eb147 100644 --- a/lib/fuzz_utils.py +++ b/lib/fuzz_utils.py @@ -150,12 +150,7 @@ def get_tensors_from_checkpoint(sess, checkpoint_dir): coverage_tensors = tf.get_collection("coverage_tensors") metadata_tensors = tf.get_collection("metadata_tensors") - tensor_map = { - "input": input_tensors, - "coverage": coverage_tensors, - "metadata": metadata_tensors, - } - return tensor_map + return input_tensors, coverage_tensors, metadata_tensors def fetch_function( @@ -183,16 +178,16 @@ def fetch_function( return coverage_batches, metadata_batches -def build_fetch_function(sess, tensor_map): +def build_fetch_function(sess, input_tensors, coverage_tensors, metadata_tensors): """Constructs fetch function given session and tensors.""" def func(input_batches): """The fetch function.""" return fetch_function( sess, - tensor_map["input"], - tensor_map["coverage"], - tensor_map["metadata"], + input_tensors, + coverage_tensors, + metadata_tensors, input_batches, ) diff --git a/lib/fuzzer.py b/lib/fuzzer.py index c05369c..7935a52 100644 --- a/lib/fuzzer.py +++ b/lib/fuzzer.py @@ -17,6 +17,9 @@ from __future__ import division from __future__ import print_function from lib.corpus import CorpusElement +from lib.corpus import InputCorpus +from lib.corpus import seed_corpus_from_numpy_arrays +from lib.fuzz_utils import build_fetch_function import tensorflow as tf @@ -27,33 +30,59 @@ class Fuzzer(object): def __init__( self, - corpus, + sess, + seed_inputs, + input_tensors, + coverage_tensors, + metadata_tensors, coverage_function, metadata_function, objective_function, mutation_function, - fetch_function, + sample_function, + threshold, + algorithm="kdtree", ): """Init the class. Args: - corpus: An InputCorpus object. - coverage_function: a function that does CorpusElement -> Coverage. - metadata_function: a function that does CorpusElement -> Metadata. + sess: a TF session + seed_inputs: np arrays of initial inputs, to seed the corpus with. + input_tensors: TF tensors to which we feed batches of input. + coverage_tensors: TF tensors we fetch to get coverage batches. + metadata_tensors: TF tensors we fetch to get metadata batches. + coverage_function: a function that does coverage batches -> coverage object. + metadata_function: a function that does metadata batches -> metadata object. objective_function: a function that checks if a CorpusElement satisifies the fuzzing objective (e.g. find a NaN, find a misclassification, etc). - mutation_function: a function that does CorpusElement -> Metadata. + mutation_function: a function that does CorpusElement -> mutated data. fetch_function: grabs numpy arrays from the TF runtime using the relevant - tensors. + tensors, to produce coverage_batches and metadata_batches Returns: Initialized object. """ - self.corpus = corpus self.coverage_function = coverage_function self.metadata_function = metadata_function self.objective_function = objective_function self.mutation_function = mutation_function - self.fetch_function = fetch_function + + # create a single fetch function (to sess.run the tensors) + self.fetch_function = build_fetch_function( + sess, + input_tensors, + coverage_tensors, + metadata_tensors + ) + + # set up seed corpus + seed_corpus = seed_corpus_from_numpy_arrays( + seed_inputs, + self.coverage_function, self.metadata_function, self.fetch_function + ) + self.corpus = InputCorpus( + seed_corpus, sample_function, threshold, algorithm + ) + def loop(self, iterations): """Fuzzes a machine learning model in a loop, making *iterations* steps.""" @@ -87,6 +116,11 @@ def loop(self, iterations): parent, ) if self.objective_function(new_element): + tf.logging.info( + "OBJECTIVE SATISFIED: coverage: %s, metadata: %s", + new_element.coverage, + new_element.metadata, + ) return new_element self.corpus.maybe_add_to_corpus(new_element) diff --git a/lib/mutation_functions.py b/lib/mutation_functions.py index ea47150..c02b985 100644 --- a/lib/mutation_functions.py +++ b/lib/mutation_functions.py @@ -32,6 +32,7 @@ def do_basic_mutations( mutations_count: Integer representing number of mutations to do in parallel. constraint: If not None, a constraint on the norm of the total mutation. + a_min, a_max: Constraints on the values of the mutated input Returns: A list of batches, the first of which is mutated images and the second of @@ -45,10 +46,16 @@ def do_basic_mutations( image_batch = np.tile(image, [mutations_count, 1, 1, 1]) else: image = corpus_element.data[0] - image_batch = np.tile(image, [mutations_count] + list(image.shape)) + image_batch = np.tile( + image, + [mutations_count] + list(np.ones_like(image.shape))) - sigma = 0.2 - noise = np.random.normal(size=image_batch.shape, scale=sigma) + if np.issubdtype(image_batch.dtype, np.floating): + sigma = 0.2 + noise = np.random.normal(size=image_batch.shape, scale=sigma) + elif np.issubdtype(image_batch.dtype, np.integer): + sigma = 10 + noise = np.round(np.random.normal(size=image_batch.shape, scale=sigma)) if constraint is not None: # (image - original_image) is a single image. it gets broadcast into a batch @@ -65,9 +72,10 @@ def do_basic_mutations( else: mutated_image_batch = noise + image_batch - mutated_image_batch = np.clip( - mutated_image_batch, a_min=a_min, a_max=a_max - ) + if a_min is not None or a_max is not None: + mutated_image_batch = np.clip( + mutated_image_batch, a_min=a_min, a_max=a_max + ) if len(corpus_element.data) > 1: label_batch = np.tile(label, [mutations_count])