Skip to content

Commit 6e7df24

Browse files
committed
Weight features based on depth
1 parent 6582e5e commit 6e7df24

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

create_parameter_weights.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,17 @@ def main():
4949
0
5050
) # 1, 1, N_grid, d_features
5151

52+
# Create parameter weights based on depth
5253
w_list = np.ones(len(constants.EXP_PARAM_NAMES_SHORT))
54+
depth_weights = [(200 - depth) for depth in constants.DEPTHS]
55+
depth_weights = [w / sum(depth_weights) for w in depth_weights]
56+
w_dict = dict(zip([round(d) for d in constants.DEPTHS], depth_weights))
57+
for i, par in enumerate(constants.EXP_PARAM_NAMES_SHORT):
58+
if "_" in par:
59+
weight = w_dict[int(par.split("_")[-1])]
60+
else:
61+
weight = 1
62+
w_list[i] = weight
5363
print("Saving parameter weights...")
5464
np.save(
5565
os.path.join(static_dir_path, "parameter_weights.npy"),

0 commit comments

Comments
 (0)