-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathimputation_results.py
More file actions
470 lines (417 loc) · 16.3 KB
/
imputation_results.py
File metadata and controls
470 lines (417 loc) · 16.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
"""
Result containers for the Imputation DiD estimator.
This module contains ImputationBootstrapResults and ImputationDiDResults
dataclasses. Extracted from imputation.py for module size management.
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from diff_diff.results import _format_survey_block, _get_significance_stars
# Public API of this module: the two result containers defined below.
__all__ = [
"ImputationBootstrapResults",
"ImputationDiDResults",
]
@dataclass
class ImputationBootstrapResults:
    """
    Container for bootstrap-based inference produced by ImputationDiD.

    The bootstrap is a library extension: Borusyak et al. (2024) propose only
    analytical inference via the conservative variance estimator. It is
    provided for consistency with CallawaySantAnna and SunAbraham.

    Attributes
    ----------
    n_bootstrap : int
        Number of bootstrap iterations.
    weight_type : str
        Type of bootstrap weights: "rademacher", "mammen", or "webb".
    alpha : float
        Significance level used for confidence intervals.
    overall_att_se : float
        Bootstrap standard error for overall ATT.
    overall_att_ci : tuple
        Bootstrap confidence interval for overall ATT.
    overall_att_p_value : float
        Bootstrap p-value for overall ATT.
    event_study_ses : dict, optional
        Bootstrap SEs for event study effects.
    event_study_cis : dict, optional
        Bootstrap CIs for event study effects.
    event_study_p_values : dict, optional
        Bootstrap p-values for event study effects.
    group_ses : dict, optional
        Bootstrap SEs for group effects.
    group_cis : dict, optional
        Bootstrap CIs for group effects.
    group_p_values : dict, optional
        Bootstrap p-values for group effects.
    bootstrap_distribution : np.ndarray, optional
        Full bootstrap distribution of overall ATT. Excluded from ``repr``.
    """

    # --- Bootstrap configuration (required) ---
    n_bootstrap: int
    weight_type: str
    alpha: float

    # --- Overall ATT inference (required) ---
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float

    # --- Event-study inference, keyed by relative period (optional) ---
    event_study_ses: Optional[Dict[int, float]] = None
    event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None
    event_study_p_values: Optional[Dict[int, float]] = None

    # --- Group (cohort) inference, keyed by cohort (optional) ---
    group_ses: Optional[Dict[Any, float]] = None
    group_cis: Optional[Dict[Any, Tuple[float, float]]] = None
    group_p_values: Optional[Dict[Any, float]] = None

    # Raw draws are kept out of the repr so printing the object stays short.
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
@dataclass
class ImputationDiDResults:
    """
    Results from Borusyak-Jaravel-Spiess (2024) imputation DiD estimation.

    Attributes
    ----------
    treatment_effects : pd.DataFrame
        Unit-level treatment effects with columns: unit, time, tau_hat, weight.
    overall_att : float
        Overall average treatment effect on the treated.
    overall_se : float
        Standard error of overall ATT.
    overall_t_stat : float
        T-statistic for overall ATT.
    overall_p_value : float
        P-value for overall ATT.
    overall_conf_int : tuple
        Confidence interval for overall ATT.
    event_study_effects : dict, optional
        Dictionary mapping relative time h to effect dict with keys:
        'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_obs'.
    group_effects : dict, optional
        Dictionary mapping cohort g to effect dict.
    groups : list
        List of treatment cohorts.
    time_periods : list
        List of all time periods.
    n_obs : int
        Total number of observations.
    n_treated_obs : int
        Number of treated observations (:math:`|\\Omega_1|`).
    n_untreated_obs : int
        Number of untreated observations (:math:`|\\Omega_0|`).
    n_treated_units : int
        Number of ever-treated units.
    n_control_units : int
        Number of units contributing to Omega_0.
    alpha : float
        Significance level used.
    anticipation : int
        Stored alongside the results (default 0) and not read by any method
        of this class; presumably the anticipation setting used at fit time.
    pretrend_results : dict, optional
        Populated by pretrend_test().
    bootstrap_results : ImputationBootstrapResults, optional
        Bootstrap inference results.
    survey_metadata : object, optional
        Survey design metadata (SurveyMetadata instance from diff_diff.survey).
        When set, ``summary()`` includes a survey block.
    """

    treatment_effects: pd.DataFrame
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    event_study_effects: Optional[Dict[int, Dict[str, Any]]]
    group_effects: Optional[Dict[Any, Dict[str, Any]]]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_obs: int
    n_untreated_obs: int
    n_treated_units: int
    n_control_units: int
    alpha: float = 0.05
    anticipation: int = 0
    pretrend_results: Optional[Dict[str, Any]] = field(default=None, repr=False)
    # String forward reference: resolved lazily, so defining this class does
    # not depend on where ImputationBootstrapResults is defined.
    bootstrap_results: Optional["ImputationBootstrapResults"] = field(default=None, repr=False)
    # Internal: stores data needed for pretrend_test()
    _estimator_ref: Optional[Any] = field(default=None, repr=False)
    # Survey design metadata (SurveyMetadata instance from diff_diff.survey)
    survey_metadata: Optional[Any] = field(default=None, repr=False)

    # --- Inference-field aliases (balance/external-adapter compatibility) ---
    @property
    def att(self) -> float:
        """Alias for ``overall_att``."""
        return self.overall_att

    @property
    def se(self) -> float:
        """Alias for ``overall_se``."""
        return self.overall_se

    @property
    def conf_int(self) -> Tuple[float, float]:
        """Alias for ``overall_conf_int``."""
        return self.overall_conf_int

    @property
    def p_value(self) -> float:
        """Alias for ``overall_p_value``."""
        return self.overall_p_value

    @property
    def t_stat(self) -> float:
        """Alias for ``overall_t_stat``."""
        return self.overall_t_stat

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.overall_p_value)
        return (
            f"ImputationDiDResults(ATT={self.overall_att:.4f}{sig}, "
            f"SE={self.overall_se:.4f}, "
            f"n_groups={len(self.groups)}, "
            f"n_treated_obs={self.n_treated_obs})"
        )

    @property
    def coef_var(self) -> float:
        """Coefficient of variation: SE / abs(overall ATT). NaN when ATT is 0 or SE non-finite."""
        if not (np.isfinite(self.overall_se) and self.overall_se >= 0):
            return np.nan
        if not np.isfinite(self.overall_att) or self.overall_att == 0:
            return np.nan
        return self.overall_se / abs(self.overall_att)

    # --- Private formatting helpers shared by summary() / to_dataframe() ---
    @staticmethod
    def _fmt_stat(value: float, spec: str) -> str:
        """Format ``value`` with ``spec``, or a 10-wide 'NaN' cell when it is not finite."""
        return format(value, spec) if np.isfinite(value) else f"{'NaN':>10}"

    @staticmethod
    def _table_header(title: str, first_column: str, width: int = 85) -> List[str]:
        """Return the ruled, centred title and column-heading lines for one results table."""
        rule = "-" * width
        return [
            rule,
            title.center(width),
            rule,
            f"{first_column:<15} {'Estimate':>12} {'Std. Err.':>12} "
            f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
            rule,
        ]

    @classmethod
    def _effect_row(cls, label: Any, eff: Dict[str, Any]) -> str:
        """Format one event-study or cohort row; an all-NaN row when the effect is NaN."""
        if np.isnan(eff["effect"]):
            return f"{label:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}"
        stars = _get_significance_stars(eff["p_value"])
        t_cell = cls._fmt_stat(eff["t_stat"], ">10.3f")
        p_cell = cls._fmt_stat(eff["p_value"], ">10.4f")
        return (
            f"{label:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
            f"{t_cell} {p_cell} {stars:>6}"
        )

    @staticmethod
    def _effects_to_frame(effects: Dict[Any, Dict[str, Any]], key_column: str) -> pd.DataFrame:
        """Flatten an effects dict into one row per key, ordered by key."""
        columns = [
            key_column,
            "effect",
            "se",
            "t_stat",
            "p_value",
            "conf_int_lower",
            "conf_int_upper",
            "n_obs",
        ]
        rows = []
        # Keys are unique, so ordering by key alone matches ordering the items.
        for key in sorted(effects):
            data = effects[key]
            rows.append(
                {
                    key_column: key,
                    "effect": data["effect"],
                    "se": data["se"],
                    "t_stat": data["t_stat"],
                    "p_value": data["p_value"],
                    "conf_int_lower": data["conf_int"][0],
                    "conf_int_upper": data["conf_int"][1],
                    "n_obs": data.get("n_obs", np.nan),
                }
            )
        # Explicit columns keep the schema stable even when ``effects`` is empty.
        return pd.DataFrame(rows, columns=columns)

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate formatted summary of estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level used to label the confidence interval.
            Defaults to alpha used in estimation. Only ``None`` triggers the
            default; any explicit value is used as given.

        Returns
        -------
        str
            Formatted summary.
        """
        # ``is None`` rather than ``or`` so an explicit falsy alpha is honoured.
        level = self.alpha if alpha is None else alpha
        # NOTE(review): the interval printed below is the stored
        # ``overall_conf_int``; a different ``alpha`` only changes its label,
        # not its bounds -- confirm this is intended.
        conf_level = int((1 - level) * 100)

        width = 85
        heavy_rule = "=" * width
        light_rule = "-" * width

        lines = [
            heavy_rule,
            "Imputation DiD Estimator Results (Borusyak et al. 2024)".center(width),
            heavy_rule,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated observations:':<30} {self.n_treated_obs:>10}",
            f"{'Untreated observations:':<30} {self.n_untreated_obs:>10}",
            f"{'Treated units:':<30} {self.n_treated_units:>10}",
            f"{'Control units:':<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            "",
        ]

        # Survey design info
        if self.survey_metadata is not None:
            lines.extend(_format_survey_block(self.survey_metadata, width))

        # Overall ATT
        lines.extend(
            self._table_header("Overall Average Treatment Effect on the Treated", "Parameter")
        )
        t_str = self._fmt_stat(self.overall_t_stat, ">10.3f")
        p_str = self._fmt_stat(self.overall_p_value, ">10.4f")
        sig = _get_significance_stars(self.overall_p_value)
        ci_low, ci_high = self.overall_conf_int[0], self.overall_conf_int[1]
        lines.extend(
            [
                f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
                f"{t_str} {p_str} {sig:>6}",
                light_rule,
                "",
                f"{conf_level}% Confidence Interval: " f"[{ci_low:.4f}, {ci_high:.4f}]",
            ]
        )
        cv = self.coef_var
        if np.isfinite(cv):
            lines.append(f"{'CV (SE/abs(ATT)):':<25} {cv:>10.4f}")
        lines.append("")

        # Event study effects
        if self.event_study_effects:
            lines.extend(self._table_header("Event Study (Dynamic) Effects", "Rel. Period"))
            for h in sorted(self.event_study_effects):
                eff = self.event_study_effects[h]
                if eff.get("n_obs", 1) == 0:
                    # Reference period marker
                    lines.append(
                        f"[ref: {h}]" f"{'0.0000':>17} {'---':>12} {'---':>10} {'---':>10} {'':>6}"
                    )
                else:
                    lines.append(self._effect_row(h, eff))
            lines.extend([light_rule, ""])

        # Group effects
        if self.group_effects:
            lines.extend(self._table_header("Group (Cohort) Effects", "Cohort"))
            for g in sorted(self.group_effects):
                lines.append(self._effect_row(g, self.group_effects[g]))
            lines.extend([light_rule, ""])

        # Pre-trend test
        if self.pretrend_results is not None:
            pt = self.pretrend_results
            lines.extend(
                [
                    light_rule,
                    "Pre-Trend Test (Equation 9)".center(width),
                    light_rule,
                    f"{'F-statistic:':<30} {pt['f_stat']:>10.3f}",
                    f"{'P-value:':<30} {pt['p_value']:>10.4f}",
                    f"{'Degrees of freedom:':<30} {pt['df']:>10}",
                    f"{'Number of leads:':<30} {pt['n_leads']:>10}",
                    light_rule,
                    "",
                ]
            )

        lines.extend(
            [
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                heavy_rule,
            ]
        )
        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print summary to stdout."""
        print(self.summary(alpha))

    def to_dataframe(self, level: str = "observation") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="observation"
            Level of aggregation:
            - "observation": Unit-level treatment effects
            - "event_study": Event study effects by relative time
            - "group": Group (cohort) effects

        Returns
        -------
        pd.DataFrame
            Results as DataFrame. For "observation" this is a copy of
            ``treatment_effects``. For "event_study" and "group" the frame
            always has the columns <key>, effect, se, t_stat, p_value,
            conf_int_lower, conf_int_upper, n_obs, even when there are no rows.

        Raises
        ------
        ValueError
            If ``level`` is unknown, or the requested effects are ``None``.
        """
        if level == "observation":
            return self.treatment_effects.copy()

        if level == "event_study":
            if self.event_study_effects is None:
                raise ValueError(
                    "Event study effects not computed. "
                    "Use aggregate='event_study' or aggregate='all'."
                )
            return self._effects_to_frame(self.event_study_effects, "relative_period")

        if level == "group":
            if self.group_effects is None:
                raise ValueError(
                    "Group effects not computed. " "Use aggregate='group' or aggregate='all'."
                )
            return self._effects_to_frame(self.group_effects, "group")

        raise ValueError(
            f"Unknown level: {level}. Use 'observation', 'event_study', or 'group'."
        )

    def pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]:
        """
        Run a pre-trend test (Equation 9 of Borusyak et al. 2024).

        Adds pre-treatment lead indicators to the Step 1 OLS and tests
        their joint significance via a Wald F-test (cluster-robust, or
        design-based survey VCV when survey_design was provided at fit).
        The computation is delegated to the stored estimator reference, and
        the result is also cached on ``pretrend_results``.

        Parameters
        ----------
        n_leads : int, optional
            Number of pre-treatment leads to include. If None, uses all
            available pre-treatment periods minus one (for the reference period).

        Returns
        -------
        dict
            Dictionary with keys: 'f_stat', 'p_value', 'df', 'n_leads',
            'lead_coefficients'.

        Raises
        ------
        RuntimeError
            If no internal estimator reference is attached to these results.
        """
        if self._estimator_ref is None:
            raise RuntimeError(
                "Pre-trend test requires internal estimator reference. "
                "Re-fit the model to use this method."
            )
        result = self._estimator_ref._pretrend_test(n_leads=n_leads)
        self.pretrend_results = result
        return result

    @property
    def is_significant(self) -> bool:
        """Check if overall ATT is significant."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for overall ATT."""
        return _get_significance_stars(self.overall_p_value)