Skip to content
This repository was archived by the owner on May 15, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 23 additions & 29 deletions examples/dcgan/dcgan_fuzzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import numpy as np
import tensorflow as tf
from lib.fuzz_utils import build_fetch_function
from lib.corpus import InputCorpus
from lib.corpus import seed_corpus_from_numpy_arrays
from lib.coverage_functions import raw_logit_coverage_function
from lib.coverage_functions import raw_coverage_function
from lib.fuzzer import Fuzzer
from lib.mutation_functions import do_basic_mutations
from lib.sample_functions import uniform_sample_function
Expand All @@ -40,7 +38,7 @@


def metadata_function(metadata_batches):
"""Gets the metadata."""
"""Gets the metadata, for computing the objective function."""
loss_batch, grad_batch = metadata_batches
metadata_list = [
{"loss": loss_batch[idx], "grad": grad_batch[idx]}
Expand All @@ -64,49 +62,45 @@ def objective_function(corpus_element):
return False


def mutation_function(elt):
"""Mutates the element in question."""
return do_basic_mutations(
elt, FLAGS.mutations_per_corpus_item, a_min=-1000, a_max=1000)


# pylint: disable=too-many-locals
def main(_):
"""Configures and runs the fuzzer."""

# Log more
tf.logging.set_verbosity(tf.logging.INFO)

coverage_function = raw_logit_coverage_function
# Set up initial seed inputs
target_seed = np.random.uniform(low=0.0, high=1.0, size=(1,))
numpy_arrays = [[target_seed]]
seed_inputs = [[target_seed]]

# Specify input, coverage, and metadata tensors
targets_tensor = tf.placeholder(tf.float32, [64, 1])
coverage_tensor = tf.identity(targets_tensor)
loss_batch_tensor, _ = binary_cross_entropy_with_logits(
tf.zeros_like(targets_tensor), tf.nn.sigmoid(targets_tensor)
)
grads_tensor = tf.gradients(loss_batch_tensor, targets_tensor)[0]
tensor_map = {
"input": [targets_tensor],
"coverage": [coverage_tensor],
"metadata": [loss_batch_tensor, grads_tensor],
}

# Construct and run fuzzer
with tf.Session() as sess:

fetch_function = build_fetch_function(sess, tensor_map)
size = FLAGS.mutations_per_corpus_item
mutation_function = lambda elt: do_basic_mutations(
elt, size, a_min=-1000, a_max=1000
)
seed_corpus = seed_corpus_from_numpy_arrays(
numpy_arrays, coverage_function, metadata_function, fetch_function
)
corpus = InputCorpus(
seed_corpus, uniform_sample_function, FLAGS.ann_threshold, "kdtree"
)
fuzzer = Fuzzer(
corpus,
coverage_function,
metadata_function,
objective_function,
mutation_function,
fetch_function,
sess=sess,
seed_inputs=seed_inputs,
input_tensors=[targets_tensor],
coverage_tensors=[coverage_tensor],
metadata_tensors=[loss_batch_tensor, grads_tensor],
coverage_function=raw_coverage_function,
metadata_function=metadata_function,
objective_function=objective_function,
mutation_function=mutation_function,
sample_function=uniform_sample_function,
threshold=FLAGS.ann_threshold
)
result = fuzzer.loop(FLAGS.total_inputs_to_fuzz)
if result is not None:
Expand Down
56 changes: 31 additions & 25 deletions examples/nans/nan_fuzzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
import numpy as np
import tensorflow as tf
from lib import fuzz_utils
from lib.corpus import InputCorpus
from lib.corpus import seed_corpus_from_numpy_arrays
from lib.coverage_functions import all_logit_coverage_function
from lib.coverage_functions import sum_coverage_function
from lib.fuzzer import Fuzzer
from lib.mutation_functions import do_basic_mutations
from lib.sample_functions import recent_sample_function
Expand All @@ -49,6 +47,10 @@
FLAGS = tf.flags.FLAGS


if FLAGS.checkpoint_dir is None:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something weird happens to me when this line runs.
I get an error from absl telling me that alsologtostderr is not a recognized FLAG.
It seems totally unrelated but when I comment out this check everything works fine?
Would you just remove this check for now?

raise ValueError('checkpoint_dir flag must be specified')


def metadata_function(metadata_batches):
"""Gets the metadata."""
metadata_list = [
Expand All @@ -68,6 +70,12 @@ def objective_function(corpus_element):
return True


def mutation_function(elt):
"""Mutates the element in question."""
return do_basic_mutations(
elt, FLAGS.mutations_per_corpus_item)


def main(_):
"""Constructs the fuzzer and performs fuzzing."""

Expand All @@ -78,36 +86,34 @@ def main(_):
random.seed(FLAGS.seed)
np.random.seed(FLAGS.seed)

coverage_function = all_logit_coverage_function
# Set up seed images
image, label = fuzz_utils.basic_mnist_input_corpus(
choose_randomly=FLAGS.random_seed_corpus
)
numpy_arrays = [[image, label]]
seed_inputs = [[image, label]]

with tf.Session() as sess:
# Specify input, coverage, and metadata tensors
input_tensors, coverage_tensors, metadata_tensors = \
fuzz_utils.get_tensors_from_checkpoint(
sess, FLAGS.checkpoint_dir
)

tensor_map = fuzz_utils.get_tensors_from_checkpoint(
sess, FLAGS.checkpoint_dir
)

fetch_function = fuzz_utils.build_fetch_function(sess, tensor_map)

size = FLAGS.mutations_per_corpus_item
mutation_function = lambda elt: do_basic_mutations(elt, size)
seed_corpus = seed_corpus_from_numpy_arrays(
numpy_arrays, coverage_function, metadata_function, fetch_function
)
corpus = InputCorpus(
seed_corpus, recent_sample_function, FLAGS.ann_threshold, "kdtree"
)
# Construct and run fuzzer
fuzzer = Fuzzer(
corpus,
coverage_function,
metadata_function,
objective_function,
mutation_function,
fetch_function,
sess=sess,
seed_inputs=seed_inputs,
input_tensors=input_tensors,
coverage_tensors=coverage_tensors,
metadata_tensors=metadata_tensors,
coverage_function=sum_coverage_function,
metadata_function=metadata_function,
objective_function=objective_function,
mutation_function=mutation_function,
sample_function=recent_sample_function,
threshold=FLAGS.ann_threshold,
)

result = fuzzer.loop(FLAGS.total_inputs_to_fuzz)
if result is not None:
tf.logging.info("Fuzzing succeeded.")
Expand Down
4 changes: 2 additions & 2 deletions examples/quantize/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ casts those variables to use 16 bits.
To train this model, execute something like this:

```
python examples/quantize/quantized_model.py --checkpoint_dir='/tmp/quantized_checkpoints_2' --training_steps=10000
python examples/quantize/quantized_model.py --checkpoint_dir='/tmp/quantized_checkpoints' --training_steps=10000
```

To fuzz the trained model, execute something like this:

```
python examples/quantize/quantized_fuzzer.py --checkpoint_dir=/tmp/quantized_checkpoints_2 --total_inputs_to_fuzz=1000000 --mutations_per_corpus_item=100 --alsologtostderr --output_path=/cns/ok-d/home/augustusodena/fuzzer/plots/quantized_image.png --ann_threshold=1.0 --perturbation_constraint=1.0 --strategy=ann
python examples/quantize/quantized_fuzzer.py --checkpoint_dir=/tmp/quantized_checkpoints --total_inputs_to_fuzz=1000000 --mutations_per_corpus_item=100 --alsologtostderr --output_path=/cns/ok-d/home/augustusodena/fuzzer/plots/quantized_image.png --ann_threshold=1.0 --perturbation_constraint=1.0 --strategy=ann
```
75 changes: 40 additions & 35 deletions examples/quantize/quantized_fuzzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import numpy as np
import tensorflow as tf
from lib import fuzz_utils
from lib.corpus import InputCorpus
from lib.corpus import seed_corpus_from_numpy_arrays
from lib.coverage_functions import raw_logit_coverage_function
from lib.coverage_functions import sum_coverage_function
from lib.fuzzer import Fuzzer
from lib.mutation_functions import do_basic_mutations
from lib.sample_functions import recent_sample_function
Expand Down Expand Up @@ -53,20 +51,27 @@
FLAGS = tf.flags.FLAGS


if FLAGS.checkpoint_dir is None:
raise ValueError('checkpoint_dir flag must be specified')



def metadata_function(metadata_batches):
"""Gets the metadata."""
"""Gets the metadata, for computing the objective function."""
logit_32_batch = metadata_batches[0]
logit_16_batch = metadata_batches[1]
metadata_list = []
for idx in range(logit_16_batch.shape[0]):
metadata_list.append((logit_32_batch[idx], logit_16_batch[idx]))
metadata_list.append({
"logits_32": logit_32_batch[idx],
"logits_16": logit_16_batch[idx]})
return metadata_list


def objective_function(corpus_element):
"""Checks if the element is misclassified."""
logits_32 = corpus_element.metadata[0]
logits_16 = corpus_element.metadata[1]
logits_32 = corpus_element.metadata["logits_32"]
logits_16 = corpus_element.metadata["logits_16"]
prediction_16 = np.argmax(logits_16)
prediction_32 = np.argmax(logits_32)
if prediction_16 == prediction_32:
Expand All @@ -80,56 +85,56 @@ def objective_function(corpus_element):
return True


def mutation_function(elt):
"""Mutates the element in question."""
return do_basic_mutations(
elt, FLAGS.mutations_per_corpus_item, FLAGS.perturbation_constraint)


# pylint: disable=too-many-locals
def main(_):
"""Constructs the fuzzer and fuzzes."""

# Log more
tf.logging.set_verbosity(tf.logging.INFO)

coverage_function = raw_logit_coverage_function
# Set up initial seed inputs
image, label = fuzz_utils.basic_mnist_input_corpus(
choose_randomly=FLAGS.random_seed_corpus
)
numpy_arrays = [[image, label]]
seed_inputs = [[image, label]]
image_copy = image[:]

with tf.Session() as sess:
# Specify input, coverage, and metadata tensors
input_tensors, coverage_tensors, metadata_tensors = \
fuzz_utils.get_tensors_from_checkpoint(
sess, FLAGS.checkpoint_dir
)

tensor_map = fuzz_utils.get_tensors_from_checkpoint(
sess, FLAGS.checkpoint_dir
)

fetch_function = fuzz_utils.build_fetch_function(sess, tensor_map)

size = FLAGS.mutations_per_corpus_item

def mutation_function(elt):
"""Mutates the element in question."""
return do_basic_mutations(elt, size, FLAGS.perturbation_constraint)

seed_corpus = seed_corpus_from_numpy_arrays(
numpy_arrays, coverage_function, metadata_function, fetch_function
)
corpus = InputCorpus(
seed_corpus, recent_sample_function, FLAGS.ann_threshold, "kdtree"
)
# Construct and run fuzzer
fuzzer = Fuzzer(
corpus,
coverage_function,
metadata_function,
objective_function,
mutation_function,
fetch_function,
sess=sess,
seed_inputs=seed_inputs,
input_tensors=input_tensors,
coverage_tensors=coverage_tensors,
metadata_tensors=metadata_tensors,
coverage_function=sum_coverage_function,
metadata_function=metadata_function,
objective_function=objective_function,
mutation_function=mutation_function,
sample_function=recent_sample_function,
threshold=FLAGS.ann_threshold,
)
result = fuzzer.loop(FLAGS.total_inputs_to_fuzz)

if result is not None:
# Double check that there is persistent disagreement
for idx in range(10):
logits, quantized_logits = sess.run(
[tensor_map["coverage"][0], tensor_map["coverage"][1]],
[coverage_tensors[0], coverage_tensors[1]],
feed_dict={
tensor_map["input"][0]: np.expand_dims(
input_tensors[0]: np.expand_dims(
result.data[0], 0
)
},
Expand Down
2 changes: 1 addition & 1 deletion examples/quantize/quantized_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

tf.flags.DEFINE_string(
"checkpoint_dir",
"/tmp/nanfuzzer",
"/tmp/quantizefuzzer",
"The overall dir in which we store experiments",
)
tf.flags.DEFINE_string(
Expand Down
Loading