Explore the search space when initializing the Tuner based on conditions (#584)

* test case added

* code added

* adding docs

* refactor to move the function to base tuner

Co-authored-by: Haifeng Jin <[email protected]>
haifeng-jin authored Aug 15, 2021
1 parent 130015e commit 98d5927
Showing 3 changed files with 114 additions and 9 deletions.
40 changes: 36 additions & 4 deletions keras_tuner/engine/base_tuner.py
@@ -17,6 +17,7 @@
from __future__ import division
from __future__ import print_function

import copy
import os

import tensorflow as tf
@@ -110,12 +111,43 @@ def __init__(
    def _populate_initial_space(self):
        """Populate initial search space for oracle.

        Keep this function as a subroutine for AutoKeras to override. The space
        may not be ready at the initialization of the tuner, but after seeing
        the training data.

        The hypermodel is built multiple times to find all the conditional
        hyperparameters: each iteration generates new hp values to activate
        `conditional_scope`s that were not active in any previous build.
        """
        hp = self.oracle.get_space()

        # Lists of stacks of conditions used while exploring the space.
        scopes_never_active = []
        scopes_once_active = []

        while True:
            self.hypermodel.build(hp)

            # Update the recorded scopes.
            for conditions in hp.active_scopes:
                if conditions not in scopes_once_active:
                    scopes_once_active.append(copy.deepcopy(conditions))
                if conditions in scopes_never_active:
                    scopes_never_active.remove(conditions)

            for conditions in hp.inactive_scopes:
                if conditions not in scopes_once_active:
                    scopes_never_active.append(copy.deepcopy(conditions))

            # All conditional scopes have been activated at least once.
            if len(scopes_never_active) == 0:
                break

            # Generate new values to activate the first still-inactive scope.
            conditions = scopes_never_active[0]
            for condition in conditions:
                hp.values[condition.name] = condition.values[0]

        self.oracle.update_space(hp)

    def search(self, *fit_args, **fit_kwargs):
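Why the loop above terminates and covers every branch: each build records the `conditional_scope`s it entered (active) and skipped (inactive) as stacks of conditions, and forcing the parent hyperparameter of the first never-active stack to one of its accepted values guarantees the next build walks into that branch. Below is a minimal, self-contained sketch of the same bookkeeping; `Parent` and `toy_build` are simplified stand-ins for illustration, not the KerasTuner classes:

import copy


class Parent:
    # Simplified condition: active when the tracked value is in `values`.
    def __init__(self, name, values):
        self.name = name
        self.values = values

    def __eq__(self, other):
        return self.name == other.name and self.values == other.values


def toy_build(hp_values):
    # Toy hypermodel with one "mlp"/"cnn" branch point. Returns the
    # (active, inactive) condition stacks this build produced.
    active, inactive = [], []
    model_type = hp_values.setdefault("model_type", "mlp")
    for branch in ("mlp", "cnn"):
        stack = [Parent("model_type", [branch])]
        (active if branch == model_type else inactive).append(stack)
    return active, inactive


hp_values = {}
scopes_never_active, scopes_once_active = [], []
while True:
    active, inactive = toy_build(hp_values)
    for conditions in active:
        if conditions not in scopes_once_active:
            scopes_once_active.append(copy.deepcopy(conditions))
        if conditions in scopes_never_active:
            scopes_never_active.remove(conditions)
    for conditions in inactive:
        if conditions not in scopes_once_active:
            scopes_never_active.append(copy.deepcopy(conditions))
    if not scopes_never_active:
        break  # every branch has been built at least once
    # Force the first unvisited branch for the next build.
    for condition in scopes_never_active[0]:
        hp_values[condition.name] = condition.values[0]

print(len(scopes_once_active))  # 2: both branches were explored

Two builds suffice in this sketch: the first takes the default "mlp" branch, the second is forced into "cnn". With nested scopes, each forced stack sets every parent value along the path, so deeper leaves are reached the same way.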
29 changes: 24 additions & 5 deletions keras_tuner/engine/hyperparameters.py
@@ -19,6 +19,7 @@

import collections
import contextlib
import copy
import math
import random

@@ -70,6 +71,11 @@ def _check_int(val, arg):
class HyperParameter(object):
    """Hyperparameter base class.

    A `HyperParameter` instance is uniquely identified by its `name` and
    `conditions` attributes. `HyperParameter`s with the same `name` but
    different `conditions` are considered different by the `HyperParameters`
    instance.

    Args:
        name: A string. The name of the parameter. Must be unique for each
            `HyperParameter` instance in the search space.
@@ -557,6 +563,12 @@ def __init__(self):
        # Active values for this `Trial`.
        self.values = {}

        # A list of active `conditional_scope`s in a build,
        # each of which is a list of conditions.
        self.active_scopes = []
        # Similar for inactive `conditional_scope`s.
        self.inactive_scopes = []

    @contextlib.contextmanager
    def name_scope(self, name):
        self._name_scopes.append(name)
@@ -594,12 +606,12 @@ def build(self, hp):
model = Sequential()
model.add(Input(shape=(32, 32, 3)))
model_type = hp.Choice("model_type", ["mlp", "cnn"])
with hp.conditional_scope("model_type", ["mlp"]):
    if model_type == "mlp":
        model.add(Flatten())
        model.add(Dense(32, activation='relu'))
with hp.conditional_scope("model_type", ["cnn"]):
    if model_type == "cnn":
        model.add(Conv2D(64, 3, activation='relu'))
        model.add(GlobalAveragePooling2D())
model.add(Dense(10, activation='softmax'))
@@ -620,7 +632,14 @@ def build(self, hp):
"`HyperParameter` named: " + parent_name + " " "not defined."
)

        condition = conditions_mod.Parent(parent_name, parent_values)
        self._conditions.append(condition)

        if condition.is_active(self.values):
            self.active_scopes.append(copy.deepcopy(self._conditions))
        else:
            self.inactive_scopes.append(copy.deepcopy(self._conditions))

        try:
            yield
        finally:
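With the change above, entering a `conditional_scope` now records a deep copy of the full condition stack in `active_scopes` or `inactive_scopes` on the `HyperParameters` object. A quick sketch of what gets recorded from user code; the counts in the final comment are my expectation under the commit's semantics, not output captured from the source:

import keras_tuner

hp = keras_tuner.HyperParameters()
# `Choice` defaults to its first value, so "mlp" is the active branch.
model_type = hp.Choice("model_type", ["mlp", "cnn"])
with hp.conditional_scope("model_type", ["mlp"]):
    hp.Int("units", 32, 128)
with hp.conditional_scope("model_type", ["cnn"]):
    hp.Int("filters", 16, 64)

print(len(hp.active_scopes), len(hp.inactive_scopes))  # expect: 1 1

Because whole condition stacks are stored, a scope nested inside another (as in the test below) carries every enclosing condition, which is what lets the tuner force a complete path down to an unexplored leaf.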
54 changes: 54 additions & 0 deletions tests/keras_tuner/engine/tuner_correctness_test.py
@@ -400,3 +400,57 @@ def test_save_model_delete_called(tmp_dir):
    tuner = save_model_setup_tuner(tmp_dir)
    tuner.save_model("a", None, step=16)
    assert tuner.was_called


def test_init_build_all_hps_in_all_conditions(tmp_dir):
    class ConditionalHyperModel(MockHyperModel):
        def build(self, hp):
            model_type = hp.Choice("model_type", ["cnn", "mlp"])
            with hp.conditional_scope("model_type", ["cnn"]):
                if model_type == "cnn":
                    sub_cnn = hp.Choice("sub_cnn", ["a", "b"])
                    with hp.conditional_scope("sub_cnn", ["a"]):
                        if sub_cnn == "a":
                            hp.Int("n_filters_a", 2, 4)
                    with hp.conditional_scope("sub_cnn", ["b"]):
                        if sub_cnn == "b":
                            hp.Int("n_filters_b", 6, 8)
            with hp.conditional_scope("model_type", ["mlp"]):
                if model_type == "mlp":
                    sub_mlp = hp.Choice("sub_mlp", ["a", "b"])
                    with hp.conditional_scope("sub_mlp", ["a"]):
                        if sub_mlp == "a":
                            hp.Int("n_units_a", 2, 4)
                    with hp.conditional_scope("sub_mlp", ["b"]):
                        if sub_mlp == "b":
                            hp.Int("n_units_b", 6, 8)
            more_block = hp.Boolean("more_block", default=False)
            with hp.conditional_scope("more_block", [True]):
                if more_block:
                    hp.Int("new_block_hp", 1, 3)
            return super().build(hp)

    def name_in_hp(name, hp):
        return any([name == single_hp.name for single_hp in hp.space])

    class MyTuner(tuner_module.Tuner):
        def _populate_initial_space(self):
            super()._populate_initial_space()
            hp = self.oracle.hyperparameters
            assert name_in_hp("model_type", hp)
            assert name_in_hp("sub_cnn", hp)
            assert name_in_hp("n_filters_a", hp)
            assert name_in_hp("n_filters_b", hp)
            assert name_in_hp("sub_mlp", hp)
            assert name_in_hp("n_units_a", hp)
            assert name_in_hp("n_units_b", hp)
            assert name_in_hp("more_block", hp)
            assert name_in_hp("new_block_hp", hp)

    MyTuner(
        oracle=keras_tuner.tuners.randomsearch.RandomSearchOracle(
            objective="loss", max_trials=2, seed=1337
        ),
        hypermodel=ConditionalHyperModel(),
        directory=tmp_dir,
    )
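The user-visible effect of the commit: hyperparameters from every conditional branch appear in the search space as soon as the tuner is constructed, before any trial runs. A minimal sketch, assuming a KerasTuner build that includes this commit (the model architecture here is arbitrary):

import keras_tuner
from tensorflow import keras


def build_model(hp):
    model = keras.Sequential()
    model.add(keras.Input(shape=(4,)))
    model_type = hp.Choice("model_type", ["mlp", "cnn"])
    with hp.conditional_scope("model_type", ["mlp"]):
        if model_type == "mlp":
            model.add(keras.layers.Dense(hp.Int("units", 8, 32), activation="relu"))
    with hp.conditional_scope("model_type", ["cnn"]):
        if model_type == "cnn":
            model.add(keras.layers.Reshape((2, 2, 1)))
            model.add(keras.layers.Conv2D(hp.Int("filters", 4, 8), 1))
            model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(1))
    model.compile(loss="mse")
    return model


tuner = keras_tuner.RandomSearch(
    build_model, objective="loss", max_trials=1, overwrite=True
)
# "model_type", "units", and "filters" are all listed before any search:
tuner.search_space_summary()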
