-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathran_env_wrapper.py
More file actions
executable file
·133 lines (96 loc) · 4.21 KB
/
Copy pathran_env_wrapper.py
File metadata and controls
executable file
·133 lines (96 loc) · 4.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import importlib
import multiprocessing
import os
from typing import Any, Dict, Optional
import numpy as np
import pandas as pd
import tensorflow as tf
from tf_agents.environments import gym_wrapper
from tf_agents.environments import batched_py_environment
from tf_agents.environments import parallel_py_environment
from tf_agents.environments import tf_py_environment
from ran_env import RanEnv
# Most recently prepared (or cache-hit) bundle; reassigned by
# prepare_data_bundle() on every call.
global_bundle = None
# Memoized bundles, keyed by the string produced by _bundle_key() so the CSV
# is parsed at most once per distinct key.
_bundle_cache: Dict[str, dict] = {}
def _load_default_config():
    """Import and return the default configuration module.

    The module name is taken from the ``ADVORAN_CONFIG_MODULE`` environment
    variable; when that variable is not set, the module named ``config`` is
    imported instead.
    """
    try:
        target_module = os.environ["ADVORAN_CONFIG_MODULE"]
    except KeyError:
        target_module = "config"
    return importlib.import_module(target_module)
def _resolve_config(config_obj=None):
    """Return ``config_obj`` unchanged, or the default config when it is ``None``.

    Only ``None`` triggers the fallback; other falsy values are returned as-is.
    """
    if config_obj is None:
        return _load_default_config()
    return config_obj
def _bundle_key(cfg) -> str:
    """Build the cache key for a config: ``"<dataset_path>|<metric,metric,...>"``.

    Missing attributes fall back to an empty dataset path and an empty
    metric list, so a bare config yields ``"|"``.
    """
    dataset_part = str(getattr(cfg, "dataset_path", ""))
    metric_names = getattr(cfg, "metric_list_autoencoder", [])
    metric_part = ",".join(metric_names)
    return dataset_part + "|" + metric_part
def prepare_data_bundle(config_obj=None):
    """Load the dataset CSV and group its rows for per-slice lookup, with caching.

    The result is memoized in ``_bundle_cache`` under ``_bundle_key(cfg)``;
    a cache hit skips the CSV read entirely. Every call (hit or miss) also
    stores the returned bundle in the module-level ``global_bundle``.

    Args:
        config_obj: Configuration object providing ``dataset_path`` and
            ``metric_list_autoencoder``. ``None`` resolves to the default
            config.

    Returns:
        A dict with two entries:
          * ``"data"``: ``{slice_id: {(slice_prb, scheduling_policy): array}}``
            for slice ids 0, 1 and 2, where each array is the float32 matrix
            of that group's rows with columns ordered as metrics, the two
            context columns, then ``reward``.
          * ``"num_metrics"``: number of metric columns.

    Raises:
        FileNotFoundError: If ``cfg.dataset_path`` does not exist.
    """
    global global_bundle

    cfg = _resolve_config(config_obj)
    cache_key = _bundle_key(cfg)

    # Fast path: the same dataset/metric combination was bundled before.
    try:
        global_bundle = _bundle_cache[cache_key]
    except KeyError:
        pass
    else:
        return global_bundle

    dataset_file = cfg.dataset_path
    print(f"Wrapper: Loading and Bundling CSV from {dataset_file}...")
    if not os.path.exists(dataset_file):
        raise FileNotFoundError(f"Dataset not found at {dataset_file}")
    frame = pd.read_csv(dataset_file)

    metric_names = cfg.metric_list_autoencoder

    # Datasets without a reward column get a constant zero reward.
    if "reward" not in frame.columns:
        frame["reward"] = 0.0

    # Use the normalized context columns only when both are present;
    # otherwise fall back to the raw ones.
    normalized_context = ["slice_prb_norm", "scheduling_policy_norm"]
    if all(col in frame.columns for col in normalized_context):
        context_names = normalized_context
    else:
        context_names = ["slice_prb", "scheduling_policy"]

    column_order = metric_names + context_names + ["reward"]

    per_slice = {}
    for slice_no in range(3):
        slice_rows = frame[frame["slice_id"] == slice_no]
        # Groups are always keyed by the raw (slice_prb, scheduling_policy)
        # values, even when the feature columns use the normalized variants.
        per_slice[slice_no] = {
            (prb, sched): chunk[column_order].values.astype(np.float32)
            for (prb, sched), chunk in slice_rows.groupby(
                ["slice_prb", "scheduling_policy"]
            )
        }

    bundle = {"data": per_slice, "num_metrics": len(metric_names)}
    _bundle_cache[cache_key] = bundle
    global_bundle = bundle
    print("Wrapper: Data Bundle prepared successfully.")
    return bundle
def create_gym_env(config_obj=None, data_bundle=None):
    """Create one RanEnv instance.

    Args:
        config_obj: Configuration object; ``None`` resolves to the default
            config. Must provide ``encoder_path``, ``num_steps_per_episode``
            and ``du_prb``.
        data_bundle: Pre-built bundle from ``prepare_data_bundle``. When
            ``None``, the bundle is prepared (or fetched from cache) here.

    Returns:
        A ``RanEnv`` built from the bundle and config. ``encoder_path`` is
        passed as ``None`` when the configured path is ``None`` or does not
        exist on disk.
    """
    cfg = _resolve_config(config_obj)
    bundle = data_bundle if data_bundle is not None else prepare_data_bundle(cfg)

    # os.path.exists() raises TypeError for None, so an unset encoder path is
    # handled explicitly and treated the same as a path that is not on disk.
    encoder_path = cfg.encoder_path
    if encoder_path is None or not os.path.exists(encoder_path):
        encoder_path = None

    return RanEnv(
        data_bundle=bundle,
        config_obj=cfg,
        encoder_path=encoder_path,
        max_steps=cfg.num_steps_per_episode,
        n_samples_per_slice=10,
        du_prb=cfg.du_prb,
    )
def create_wrapped_env(config_obj=None, data_bundle=None):
    """Build one environment via ``create_gym_env`` and wrap it in a ``GymWrapper``.

    Both arguments are forwarded unchanged to ``create_gym_env``.
    """
    raw_env = create_gym_env(config_obj=config_obj, data_bundle=data_bundle)
    return gym_wrapper.GymWrapper(raw_env)
def get_training_env(config_obj=None, num_parallel_override: Optional[int] = None):
    """Create the parallel training environment wrapped as a TFPyEnvironment.

    Args:
        config_obj: Configuration object; ``None`` resolves to the default
            config. ``num_parallel_environments`` is read from it when
            present (defaulting to 1).
        num_parallel_override: When given, replaces the configured number of
            parallel environments.

    Returns:
        A ``TFPyEnvironment`` around a ``ParallelPyEnvironment``. The number
        of environments is the requested count clamped to the range
        ``[1, max(1, cpu_count - 2)]``.
    """
    cfg = _resolve_config(config_obj)
    bundle = prepare_data_bundle(cfg)

    # Leave two cores free, but always allow at least one worker.
    worker_cap = max(1, multiprocessing.cpu_count() - 2)

    configured = int(getattr(cfg, "num_parallel_environments", 1))
    if num_parallel_override is None:
        wanted = configured
    else:
        wanted = int(num_parallel_override)
    env_count = max(1, min(wanted, worker_cap))

    print(f"Wrapper: Spawning {env_count} parallel environments...")

    def _build_env(cfg=cfg, bundle=bundle):
        # Config and bundle are bound as defaults so each constructor carries
        # its own references when handed to the parallel environment.
        return create_wrapped_env(config_obj=cfg, data_bundle=bundle)

    constructors = [_build_env for _ in range(env_count)]
    parallel_env = parallel_py_environment.ParallelPyEnvironment(constructors)
    return tf_py_environment.TFPyEnvironment(parallel_env)
def get_eval_env(config_obj=None):
    """Create a single evaluation environment wrapped as a TFPyEnvironment.

    One wrapped environment is built from the (possibly cached) data bundle,
    placed in a ``BatchedPyEnvironment`` with multithreading disabled, and
    returned inside a ``TFPyEnvironment`` created with ``isolation=False``.

    Args:
        config_obj: Configuration object; ``None`` resolves to the default
            config.
    """
    cfg = _resolve_config(config_obj)
    bundle = prepare_data_bundle(cfg)

    print("Wrapper: Creating Single Evaluation Environment...")

    env_list = [create_wrapped_env(config_obj=cfg, data_bundle=bundle)]
    single_batch_env = batched_py_environment.BatchedPyEnvironment(
        env_list, multithreading=False
    )
    return tf_py_environment.TFPyEnvironment(single_batch_env, isolation=False)