black format
gao-jianming committed Apr 1, 2024
1 parent 6654f30 commit 50684f6
Showing 22 changed files with 1,322 additions and 1,083 deletions.
24 changes: 19 additions & 5 deletions curriculum_generation/curriculum_tutorial.py
@@ -41,6 +41,7 @@
##############################################################################
# Create training tasks using custom evaluation functions


def PracticeEating(gs, subject):
"""The progress, the max of which is 1, should
* increase small for each eating
@@ -55,12 +56,15 @@ def PracticeEating(gs, subject):
        progress += 0.3
    return norm(progress)  # norm is a helper function to normalize the value to [0, 1]


curriculum.append(TaskSpec(eval_fn=PracticeEating, eval_fn_kwargs={}))


# You can also use pre-built eval functions to define your own eval functions
def PracticeInventoryManagement(gs, subject, space, num_tick):
    return norm(InventorySpaceGE(gs, subject, space) * TickGE(gs, subject, num_tick))


for space in [2, 4, 8]:
    curriculum.append(
        TaskSpec(
@@ -74,6 +78,7 @@ def PracticeInventoryManagement(gs, subject, space, num_tick):
# Import the custom curriculum
print("------------------------------------------------------------")
import curriculum_tutorial  # which is this file

CURRICULUM = curriculum_tutorial.curriculum
print("The number of training tasks in the curriculum:", len(CURRICULUM))

@@ -96,6 +101,7 @@ def PracticeInventoryManagement(gs, subject, space, num_tick):
CURRICULUM_FILE_PATH = "custom_curriculum_with_embedding.pkl"
with open(CURRICULUM_FILE_PATH, "wb") as f:
    import dill

    dill.dump(CURRICULUM, f)
print("All training tasks are picklable.")

@@ -105,6 +111,7 @@ def PracticeInventoryManagement(gs, subject, space, num_tick):
print("------------------------------------------------------------")
print("Generating the task spec with embedding file ...")
from task_encoder import TaskEncoder

LLM_CHECKPOINT = "Salesforce/codegen25-7b-instruct"

# Get the task embeddings for the training tasks and save to file
@@ -117,8 +124,11 @@ def PracticeInventoryManagement(gs, subject, space, num_tick):
# These lines are the same as the RL track. If these don't run, please see train.py
from reinforcement_learning import config
from train import setup_env

args = config.create_config(config.Config)
args.tasks_path = (
    CURRICULUM_FILE_PATH  # This is the curriculum file saved by the task encoder
)

# Remove the lines below if you want to use the default training config
local_mode = True
@@ -149,12 +159,16 @@ def PracticeInventoryManagement(gs, subject, space, num_tick):
    reward_signal_count = []
    for sub_list in infos[key]:
        for prog, rcnt in sub_list:
            completed.append(
                int(prog >= 1)
            )  # progress >= 1 is considered task complete
            max_progress.append(prog)
            reward_signal_count.append(rcnt)
    print(
        f"{key} -- task tried: {len(completed)}, completed: {sum(completed)}, "
        + f"avg max progress: {sum(max_progress)/len(max_progress):.3f}, "
        + f"avg reward signal count: {sum(reward_signal_count)/len(reward_signal_count):.3f}"
    )

print("------------------------------------------------------------")
print("The tutorial is done.")
117 changes: 59 additions & 58 deletions curriculum_generation/elm.py
@@ -109,12 +109,13 @@ def extract_task_fn(result_str, fn_name):
    split = result_str.split("\n")
    fn_str = []
    for line in split[::-1]:
        if line.startswith(f"def {fn_name}("):
            fn_str.append(line)
            break
        fn_str.append(line)
    return "\n".join(fn_str[::-1])


def sample_parameter(key, type_hint):
"""
Generates sample parameter values based on the key and type_hint provided.
@@ -225,57 +226,58 @@ def run_env():
    # Return True if at least one task ran successfully.
    return num_success > 0


def generate_task_spec(result_str, fn_name, num_sample=3):
    """
    Generates a list of TaskSpec objects from the task function string provided during the class instantiation.
    Each TaskSpec is an instantiation of the task function with sampled parameters.
    Args:
        result_str: The string representation of the task function.
        fn_name: The name of the task function.
        num_sample: The number of TaskSpecs to generate. Defaults to 3.
    Returns:
        A list of valid TaskSpec objects. If the task function string is invalid or no valid TaskSpecs
        can be generated, an empty list is returned.
    """
    task_spec = []
    task_fn_str = extract_task_fn(result_str, fn_name)
    import_str = (
        "from nmmo.task.game_state import GameState\n"
        + "from nmmo.task.group import Group\n"
        + "from nmmo.task.base_predicates import *\n\n"
    )

    locals_dict = {}
    try:
        # NOTE: this is a security vulnerability
        # TODO: make this secure
        exec(import_str + task_fn_str, globals(), locals_dict)
    except:
        # return empty task spec for invalid function
        print("Invalid python function generated ...")
        return task_spec
    task_fn = locals_dict[fn_name]
    fn_params = inspect.signature(task_fn).parameters

    included_kwargs = set()
    for _ in range(num_sample):
        task_fn_kwargs = {}
        for key, param in fn_params.items():
            if key in ["gs", "subject"]:
                continue
            type_hint = param.annotation.__name__
            task_fn_kwargs[key] = sample_parameter(key, type_hint)
        args_vals = tuple(task_fn_kwargs.values())
        if args_vals not in included_kwargs:
            task_spec.append(
                ts.TaskSpec(eval_fn=task_fn, eval_fn_kwargs=task_fn_kwargs)
            )
            included_kwargs.add(args_vals)

    return task_spec


class NMMOTaskFn(Genotype):
@@ -293,9 +295,7 @@ def __init__(self, program_str: str, fn_name: str, module: ModuleType):
        self._fitness = -np.inf
        self._fn_name = fn_name
        self.program_str = extract_task_fn(program_str, self._fn_name)
        self.valid = is_task_spec_valid(generate_task_spec(program_str, self._fn_name))

        self.PREBUILT_TASK_FN = {
            name: fn
@@ -556,13 +556,14 @@ def fitness(self, x: NMMOTaskFn) -> float:
        return -np.inf

    def get_rng_state(self) -> Optional[np.random._generator.Generator]:
        # warnings.warn("WARNING: rng state not used in this environment")
        return None

    def set_rng_state(self, rng_state: Optional[np.random._generator.Generator]):
        # warnings.warn("WARNING: rng state not used in this environment")
        pass


def entropy(task):
"""
Computes the entropy of the task string by counting the frequency of each word.
@@ -595,4 +596,4 @@ def calculate_length(task):
        * (dmax - dmin)
        + dmin
    )
    return math.ceil(scale(len(task), 100, 9000, 0, 10))
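For orientation, a minimal sketch of the scale helper implied by the fragment above — the name matches the call site, but the signature and body are assumptions reconstructed from the visible "* (dmax - dmin) + dmin" tail, not the actual definition in elm.py:

# Hypothetical reconstruction: linearly map value from [smin, smax] to [dmin, dmax].
def scale(value, smin, smax, dmin, dmax):
    return ((value - smin) / (smax - smin)) * (dmax - dmin) + dmin

# calculate_length thus maps a task string's length from [100, 9000] onto [0, 10]
# and rounds up; e.g., a 4550-character task scales to 5.0 and ceils to 5.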