Skip to content

Commit

Permalink
google lint on tdc/utils
Browse files Browse the repository at this point in the history
  • Loading branch information
amva13 committed Mar 5, 2024
1 parent 750face commit b2ebf3b
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 163 deletions.
66 changes: 36 additions & 30 deletions tdc/utils/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def convert_y_unit(y, from_, to_):
if from_ == "nM":
y = y
elif from_ == "p":
y = (10 ** (-y) - 1e-10) / 1e-9
y = (10**(-y) - 1e-10) / 1e-9

if to_ == "p":
y = -np.log10(y * 1e-9 + 1e-10)
Expand All @@ -31,9 +31,12 @@ def convert_y_unit(y, from_, to_):
return y


def label_transform(
y, binary, threshold, convert_to_log, verbose=True, order="descending"
):
def label_transform(y,
binary,
threshold,
convert_to_log,
verbose=True,
order="descending"):
"""label transformation helper function
Args:
Expand Down Expand Up @@ -62,7 +65,8 @@ def label_transform(
elif order == "ascending":
y = np.array([1 if i else 0 for i in np.array(y) > threshold])
else:
raise ValueError("Please select order from 'descending or ascending!")
raise ValueError(
"Please select order from 'descending or ascending!")
else:
if (len(np.unique(y)) > 2) and convert_to_log:
if verbose:
Expand Down Expand Up @@ -144,16 +148,16 @@ def label_dist(y, name=None):
median = np.median(y)
mean = np.mean(y)

f, (ax_box, ax_hist) = plt.subplots(
2, sharex=True, gridspec_kw={"height_ratios": (0.15, 1)}
)
f, (ax_box,
ax_hist) = plt.subplots(2,
sharex=True,
gridspec_kw={"height_ratios": (0.15, 1)})

if name is None:
sns.boxplot(y, ax=ax_box).set_title("Label Distribution")
else:
sns.boxplot(y, ax=ax_box).set_title(
"Label Distribution of " + str(name) + " Dataset"
)
sns.boxplot(y, ax=ax_box).set_title("Label Distribution of " +
str(name) + " Dataset")
ax_box.axvline(median, color="b", linestyle="--")
ax_box.axvline(mean, color="g", linestyle="--")

Expand Down Expand Up @@ -191,7 +195,8 @@ def NegSample(df, column_names, frac, two_types):
pos_set = set([tuple([i[0], i[1]]) for i in pos])
np.random.seed(1234)
samples = np.random.choice(df_unique, size=(x, 2), replace=True)
neg_set = set([tuple([i[0], i[1]]) for i in samples if i[0] != i[1]]) - pos_set
neg_set = set([tuple([i[0], i[1]]) for i in samples if i[0] != i[1]
]) - pos_set

while len(neg_set) < x:
sample = np.random.choice(df_unique, 2, replace=False)
Expand All @@ -208,10 +213,13 @@ def NegSample(df, column_names, frac, two_types):
neg_list_val.append([i[0], id2seq[i[0]], i[1], id2seq[i[1]], 0])

df = df.append(
pd.DataFrame(neg_list_val).rename(
columns={0: id1, 1: x1, 2: id2, 3: x2, 4: "Y"}
)
).reset_index(drop=True)
pd.DataFrame(neg_list_val).rename(columns={
0: id1,
1: x1,
2: id2,
3: x2,
4: "Y"
})).reset_index(drop=True)
return df
else:
df_unique_id1 = np.unique(df[id1].values.reshape(-1))
Expand All @@ -224,16 +232,11 @@ def NegSample(df, column_names, frac, two_types):
sample_id1 = np.random.choice(df_unique_id1, size=len(df), replace=True)
sample_id2 = np.random.choice(df_unique_id2, size=len(df), replace=True)

neg_set = (
set(
[
tuple([sample_id1[i], sample_id2[i]])
for i in range(len(df))
if sample_id1[i] != sample_id2[i]
]
)
- pos_set
)
neg_set = (set([
tuple([sample_id1[i], sample_id2[i]])
for i in range(len(df))
if sample_id1[i] != sample_id2[i]
]) - pos_set)

while len(neg_set) < len(df):
sample_id1 = np.random.choice(df_unique_id1, size=1, replace=True)
Expand All @@ -252,8 +255,11 @@ def NegSample(df, column_names, frac, two_types):
neg_list_val.append([i[0], id2seq1[i[0]], i[1], id2seq2[i[1]], 0])

df = df.append(
pd.DataFrame(neg_list_val).rename(
columns={0: id1, 1: x1, 2: id2, 3: x2, 4: "Y"}
)
).reset_index(drop=True)
pd.DataFrame(neg_list_val).rename(columns={
0: id1,
1: x1,
2: id2,
3: x2,
4: "Y"
})).reset_index(drop=True)
return df
3 changes: 0 additions & 3 deletions tdc/utils/label_name_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,11 +636,9 @@
"Tanguay_ZF_120hpf_YSE_up",
]


QM7_targets = ["Y"]
# QM7_targets = ["E_PBE0", "E_max_EINDO", "I_max_ZINDO", "HOMO_ZINDO", "LUMO_ZINDO", "E_1st_ZINDO", "IP_ZINDO", "EA_ZINDO", "HOMO_PBE0", "LUMO_PBE0", "HOMO_GW", "LUMO_GW", "alpha_PBE0", "alpha_SCS"]


#### qm7b: 14 labels
QM7b_targets = [
"AE_PBE0",
Expand Down Expand Up @@ -683,7 +681,6 @@
"f1-CAM",
]


# QM9_targets = [
# "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "cv", "u0", "u298",
# "h298", "g298"
Expand Down
Loading

0 comments on commit b2ebf3b

Please sign in to comment.