Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,4 +308,4 @@ If you want to cite CSDID, you can use the following BibTeX entry:
year = {2024},
url = {https://github.com/d2cml-ai/csdid}
}
```
```
6 changes: 3 additions & 3 deletions csdid/aggte_fnc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ def AGGTEobj(overall_att=None,
out1 = np.column_stack((out["overall_att"], out["overall_se"], overall_cband_lower, overall_cband_upper))
out1 = np.round(out1, 4)
overall_sig = (overall_cband_upper < 0) | (overall_cband_lower > 0)
overall_sig[np.isnan(overall_sig)] = False
overall_sig_text = np.where(overall_sig, "*", "")
overall_sig = np.asarray(overall_sig)
overall_sig = np.where(np.isnan(overall_sig), False, overall_sig)
overall_sig_text = np.atleast_1d(np.where(overall_sig, "*", ""))
out1 = np.column_stack((out1, overall_sig_text))

print("\n")
Expand Down Expand Up @@ -193,4 +194,3 @@ def AGGTEobj(overall_att=None,

return out


7 changes: 4 additions & 3 deletions csdid/att_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ def __init__(self, yname, tname, idname, gname, data, control_group = ['nevertre
cband = False, biters = 1000, alp = 0.05
):
dp = pre_process_did(
yname=yname, tname = tname, idname=idname, gname = gname,
data = data, control_group=control_group, anticipation=anticipation,
xformla=xformla, panel=panel, allow_unbalanced_panel=allow_unbalanced_panel, cband=cband, clustervar=None, weights_name=None
yname=yname, tname=tname, idname=idname, gname=gname,
data=data, control_group=control_group, anticipation=anticipation,
xformla=xformla, panel=panel, allow_unbalanced_panel=allow_unbalanced_panel,
cband=cband, clustervar=clustervar, weights_name=weights_name
)

dp['biters'] = biters
Expand Down
70 changes: 35 additions & 35 deletions csdid/attgt_fnc/compute_att_gt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np, pandas as pd
import patsy
from drdid import drdid, reg_did, ipwd_did
from drdid import reg_did
from csdid.attgt_fnc import drdid_trim

from csdid.utils.bmisc import panel2cs2
import warnings
Expand Down Expand Up @@ -37,6 +38,26 @@ def compute_att_gt(dp, est_method = "dr", base_period = 'varying'):

att_est, group, year, post_array = [], [], [], []

def build_covariates(formula, frame):
try:
_, cov = fml(formula, data=frame, return_type='dataframe')
except Exception as e:
try:
cov = patsy.dmatrix(formula, data=frame, return_type='dataframe')
except Exception as e2:
print(f"Warning: Formula processing failed: {e2}")
y_str, x_str = formula.split("~")
xs1 = x_str.split('+')
xs1_col_names = [x.strip() for x in xs1 if x.strip() != '1']
n_dis = len(frame)
ones = np.ones((n_dis, 1))
try:
cov = frame[xs1_col_names].to_numpy()
cov = np.append(cov, ones, axis=1)
except Exception:
cov = ones
return np.array(cov)

def add_att_data(att = 0, pst = 0, inf_f = []):
inf_func.append(inf_f)
att_est.append(att)
Expand Down Expand Up @@ -75,12 +96,6 @@ def add_att_data(att = 0, pst = 0, inf_f = []):
f"There are no pre-treatment periods for the group first treated at {g}. Units from this group are dropped."
)

# If we are in the universal base period
if base_period == 'universal' and pret == tn:
# Normalize results to zero and skip computation
add_att_data(att=0, pst=0, inf_f=np.zeros(len(data)))
continue

# For non-never treated groups, set up the control group indicator 'C'
if not never_treated:
# Units that are either never treated (gname == 0) or treated in the future
Expand Down Expand Up @@ -113,13 +128,12 @@ def add_att_data(att = 0, pst = 0, inf_f = []):

# -----------------------------------------------------------------------------
# Debugging and validation
if base_period == 'universal' and pret == tn:
# Normalize results to zero and break the loop
add_att_data(att=0, pst=post_treat, inf_f=np.zeros(len(data)))
continue

# Post-treatment dummy variable
pret_year = tlist[pret]
post_treat = 1 * (g <= tn)
if base_period == 'universal' and pret_year == tn:
add_att_data(att=0, pst=post_treat, inf_f=np.zeros(n))
continue

# Subset the data for the current and pretreatment periods
disdat = data[(data[tname] == tn) | (data[tname] == tlist[pret])]
Expand All @@ -139,9 +153,9 @@ def add_att_data(att = 0, pst = 0, inf_f = []):
C = disdat.C
w = disdat.w

ypre = disdat.y0 if tn > pret else disdat.y1
ypost = disdat.y0 if tn < pret else disdat.y1
_, covariates = fml(xformla, data = disdat, return_type = 'dataframe')
ypre = disdat.y0 if tn > pret_year else disdat.y1
ypost = disdat.y0 if tn < pret_year else disdat.y1
covariates = build_covariates(xformla, disdat)

G, C, w, ypre = map(np.array, [G, C, w, ypre])
ypost, covariates = map(np.array, [ypost, covariates])
Expand All @@ -151,9 +165,9 @@ def add_att_data(att = 0, pst = 0, inf_f = []):
elif est_method == "reg":
est_att_f = reg_did.reg_did_panel
elif est_method == "ipw":
est_att_f = ipwd_did.std_ipw_did_panel
est_att_f = drdid_trim.std_ipw_did_panel
elif est_method == "dr":
est_att_f = drdid.drdid_panel
est_att_f = drdid_trim.drdid_panel

att_gt, att_inf_func = est_att_f(ypost, ypre, G, i_weights=w, covariates=covariates)

Expand Down Expand Up @@ -224,21 +238,7 @@ def add_att_data(att = 0, pst = 0, inf_f = []):
continue

# return (inf_func)
try:
_, covariates = fml(xformla, data = disdat, return_type = 'dataframe')
covariates = np.array(covariates)
except Exception as e:
print(f"Warning: Formula processing failed: {e}")
y_str, x_str = xformla.split("~")
xs1 = x_str.split('+')
xs1_col_names = [x.strip() for x in xs1 if x.strip() != '1']
n_dis = len(disdat)
ones = np.ones((n_dis, 1))
try:
covariates = disdat[xs1_col_names].to_numpy()
covariates = np.append(covariates, ones, axis=1)
except:
covariates = ones
covariates = build_covariates(xformla, disdat)

#-----------------------------------------------------------------------------
# code for actually computing att(g,t)
Expand All @@ -250,9 +250,9 @@ def add_att_data(att = 0, pst = 0, inf_f = []):
elif est_method == "reg":
est_att_f = reg_did.reg_did_rc
elif est_method == "ipw":
est_att_f = ipwd_did.std_ipw_did_rc
est_att_f = drdid_trim.std_ipw_did_rc
elif est_method == "dr":
est_att_f = drdid.drdid_rc
est_att_f = drdid_trim.drdid_rc

att_gt, att_inf_func = est_att_f(y=Y, post=post, D = G, i_weights=w, covariates=covariates)
# print(att_inf_func)
Expand Down Expand Up @@ -504,4 +504,4 @@ def add_att_data(att = 0, pst = 0, inf_f = []):
# "att" : att_est,
# 'post ': post_array
# }
# return (output, np.array(inf_func))
# return (output, np.array(inf_func))
Loading
Loading