-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathstaggered.py
More file actions
3963 lines (3477 loc) · 165 KB
/
staggered.py
File metadata and controls
3963 lines (3477 loc) · 165 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Staggered Difference-in-Differences estimators.
Implements modern methods for DiD with variation in treatment timing,
including the Callaway-Sant'Anna (2021) estimator.
"""
import warnings
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from scipy import linalg as scipy_linalg
from diff_diff.linalg import (
_check_propensity_diagnostics,
_detect_rank_deficiency,
_format_dropped_columns,
solve_logit,
solve_ols,
)
from diff_diff.staggered_aggregation import (
CallawaySantAnnaAggregationMixin,
)
from diff_diff.staggered_bootstrap import (
CallawaySantAnnaBootstrapMixin,
CSBootstrapResults,
)
# Import from split modules
from diff_diff.staggered_results import (
CallawaySantAnnaResults,
GroupTimeEffect,
)
from diff_diff.utils import safe_inference, safe_inference_batch
# Public names of this module. CallawaySantAnnaResults, CSBootstrapResults and
# GroupTimeEffect are imported from their split modules above and re-exported
# here for backward compatibility; CallawaySantAnna is defined in this file.
__all__ = [
"CallawaySantAnna",
"CallawaySantAnnaResults",
"CSBootstrapResults",
"GroupTimeEffect",
]
# Type alias for the dictionary of pre-computed structures built by
# CallawaySantAnna._precompute_structures() (string keys, heterogeneous values).
PrecomputedData = Dict[str, Any]
def _linear_regression(
    X: np.ndarray,
    y: np.ndarray,
    rank_deficient_action: str = "warn",
    weights: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Regress ``y`` on ``X`` plus a constant via the shared ``solve_ols`` backend.

    Parameters
    ----------
    X : np.ndarray
        Regressors with one row per observation. A leading column of ones
        is prepended here, so callers must not supply their own intercept.
    y : np.ndarray
        Outcome values, one per row of ``X``.
    rank_deficient_action : str, default "warn"
        Forwarded unchanged to ``solve_ols``. Controls what happens when the
        design matrix is rank-deficient:
        - "warn": Issue warning and drop linearly dependent columns
        - "error": Raise ValueError
        - "silent": Drop columns silently without warning
    weights : np.ndarray, optional
        Observation weights forwarded to ``solve_ols`` (weighted least
        squares). ``None`` requests an unweighted fit.

    Returns
    -------
    beta : np.ndarray
        Fitted coefficients; the intercept is the first entry.
    residuals : np.ndarray
        Residuals from the fit.
    """
    n_obs = X.shape[0]
    design = np.column_stack([np.ones(n_obs), X])

    # Only coefficients and residuals are consumed, so the backend is asked
    # not to compute a variance-covariance matrix.
    coefficients, resid, _unused_vcov = solve_ols(
        design,
        y,
        return_vcov=False,
        rank_deficient_action=rank_deficient_action,
        weights=weights,
    )
    return coefficients, resid
def _safe_inv(
A: np.ndarray,
tracker: Optional[list] = None,
) -> np.ndarray:
"""Invert a square matrix with lstsq fallback for near-singular cases.
Parameters
----------
A : np.ndarray
Square matrix to invert.
tracker : list, optional
When provided, one condition-number sample of ``A`` is appended on
every LinAlgError fallback. ``CallawaySantAnna.fit()`` initializes
a list and emits a single aggregate `UserWarning` after the fit
finishes, rather than surfacing a separate warning per fallback.
Sibling of finding #17 in the Phase 2 silent-failures audit.
"""
try:
return np.linalg.solve(A, np.eye(A.shape[0]))
except np.linalg.LinAlgError:
if tracker is not None:
with np.errstate(invalid="ignore", over="ignore"):
tracker.append(float(np.linalg.cond(A)))
return np.linalg.lstsq(A, np.eye(A.shape[0]), rcond=None)[0]
class CallawaySantAnna(
CallawaySantAnnaBootstrapMixin,
CallawaySantAnnaAggregationMixin,
):
"""
Callaway-Sant'Anna (2021) estimator for staggered Difference-in-Differences.
This estimator handles DiD designs with variation in treatment timing
(staggered adoption) and heterogeneous treatment effects. It avoids the
bias of traditional two-way fixed effects (TWFE) estimators by:
1. Computing group-time average treatment effects ATT(g,t) for each
cohort g (units first treated in period g) and time t.
2. Aggregating these to summary measures (overall ATT, event study, etc.)
using appropriate weights.
Parameters
----------
control_group : str, default="never_treated"
Which units to use as controls:
- "never_treated": Use only never-treated units (recommended)
- "not_yet_treated": Use never-treated and not-yet-treated units
anticipation : int, default=0
Number of periods before treatment where effects may occur.
Set to > 0 if treatment effects can begin before the official
treatment date.
estimation_method : str, default="dr"
Estimation method:
- "dr": Doubly robust (recommended)
- "ipw": Inverse probability weighting
- "reg": Outcome regression
alpha : float, default=0.05
Significance level for confidence intervals.
cluster : str, optional
Column name for cluster-robust standard errors.
Defaults to unit-level clustering.
n_bootstrap : int, default=0
Number of bootstrap iterations for inference.
If 0, uses analytical standard errors.
Recommended: 999 or more for reliable inference.
.. note:: Memory Usage
The bootstrap stores all weights in memory as a (n_bootstrap, n_units)
float64 array. For large datasets, this can be significant:
- 1K bootstrap × 10K units = ~80 MB
- 10K bootstrap × 100K units = ~8 GB
Consider reducing n_bootstrap if memory is constrained.
bootstrap_weights : str, default="rademacher"
Type of weights for multiplier bootstrap:
- "rademacher": +1/-1 with equal probability (standard choice)
- "mammen": Two-point distribution (asymptotically valid, matches skewness)
- "webb": Six-point distribution (recommended when n_clusters < 20)
seed : int, optional
Random seed for reproducibility.
rank_deficient_action : str, default="warn"
Action when design matrix is rank-deficient (linearly dependent columns):
- "warn": Issue warning and drop linearly dependent columns (default)
- "error": Raise ValueError
- "silent": Drop columns silently without warning
base_period : str, default="varying"
Method for selecting the base (reference) period for computing
ATT(g,t). Options:
- "varying": For pre-treatment periods (t < g - anticipation), use
t-1 as base (consecutive comparisons). For post-treatment, use
g-1-anticipation. Requires t-1 to exist in data.
- "universal": Always use g-1-anticipation as base period.
Both produce identical post-treatment effects. Matches R's
did::att_gt() base_period parameter.
cband : bool, default=True
Whether to compute simultaneous confidence bands (sup-t) for
event study aggregation. Requires ``n_bootstrap > 0``.
When True, results include ``cband_crit_value`` and per-event-time
``cband_conf_int`` entries controlling family-wise error rate.
pscore_trim : float, default=0.01
Trimming bound for propensity scores. Scores are clipped to
``[pscore_trim, 1 - pscore_trim]`` before weight computation
in IPW and DR estimation. Must be in ``(0, 0.5)``.
panel : bool, default=True
Whether the data is a balanced/unbalanced panel (units observed
across multiple time periods). Set to ``False`` for stationary
repeated cross-sections where each observation has a unique unit
ID and units do not repeat across periods. Requires that the
cross-sectional samples are drawn from the same population in
each period (stationarity). Uses cross-sectional DRDID
(Sant'Anna & Zhao 2020, Section 4) with per-observation influence
functions.
epv_threshold : float, default=10
Events Per Variable threshold for propensity score logit.
When the ratio of minority-class observations to predictor
variables (excluding intercept) falls below this value, a
warning is emitted (or ``ValueError`` raised if
``rank_deficient_action="error"``). Based on Peduzzi et al.
(1996). Only applies to IPW and DR estimation methods.
Use ``diagnose_propensity()`` for a pre-estimation check across
all cohorts.
pscore_fallback : str, default="error"
Action when propensity score estimation fails entirely
(``LinAlgError`` or ``ValueError`` from IRLS):
- "error": Raise the exception (default). Ensures the user is
aware of estimation failures.
- "unconditional": Fall back to unconditional propensity
with a warning. For IPW, this drops all covariates. For DR,
the propensity model becomes unconditional but outcome
regression still uses covariates.
When ``rank_deficient_action="error"``, errors are always
re-raised regardless of this setting.
Attributes
----------
results_ : CallawaySantAnnaResults
Estimation results after calling fit().
is_fitted_ : bool
Whether the model has been fitted.
Examples
--------
Basic usage:
>>> import pandas as pd
>>> from diff_diff import CallawaySantAnna
>>>
>>> # Panel data with staggered treatment
>>> # 'first_treat' = period when unit was first treated (0 if never treated)
>>> data = pd.DataFrame({
... 'unit': [...],
... 'time': [...],
... 'outcome': [...],
... 'first_treat': [...] # 0 for never-treated, else first treatment period
... })
>>>
>>> cs = CallawaySantAnna()
>>> results = cs.fit(data, outcome='outcome', unit='unit',
... time='time', first_treat='first_treat')
>>>
>>> results.print_summary()
With event study aggregation:
>>> cs = CallawaySantAnna()
>>> results = cs.fit(data, outcome='outcome', unit='unit',
... time='time', first_treat='first_treat',
... aggregate='event_study')
>>>
>>> # Plot event study
>>> from diff_diff import plot_event_study
>>> plot_event_study(results)
With covariate adjustment (conditional parallel trends):
>>> # When parallel trends only holds conditional on covariates
>>> cs = CallawaySantAnna(estimation_method='dr') # doubly robust
>>> results = cs.fit(data, outcome='outcome', unit='unit',
... time='time', first_treat='first_treat',
... covariates=['age', 'income'])
>>>
>>> # DR is recommended: consistent if either outcome model
>>> # or propensity model is correctly specified
Notes
-----
The key innovation of Callaway & Sant'Anna (2021) is the disaggregated
approach: instead of estimating a single treatment effect, they estimate
ATT(g,t) for each cohort-time pair. This avoids the "forbidden comparison"
problem where already-treated units act as controls.
The ATT(g,t) is identified under parallel trends conditional on covariates:
E[Y(0)_t - Y(0)_g-1 | G=g] = E[Y(0)_t - Y(0)_g-1 | C=1]
where G=g indicates treatment cohort g and C=1 indicates control units.
This uses g-1 as the base period, which applies to post-treatment (t >= g).
With base_period="varying" (default), pre-treatment uses t-1 as base for
consecutive comparisons useful in parallel trends diagnostics.
References
----------
Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-Differences with
multiple time periods. Journal of Econometrics, 225(2), 200-230.
"""
def __init__(
    self,
    control_group: str = "never_treated",
    anticipation: int = 0,
    estimation_method: str = "dr",
    alpha: float = 0.05,
    cluster: Optional[str] = None,
    n_bootstrap: int = 0,
    bootstrap_weights: Optional[str] = None,
    seed: Optional[int] = None,
    rank_deficient_action: str = "warn",
    base_period: str = "varying",
    cband: bool = True,
    pscore_trim: float = 0.01,
    panel: bool = True,
    epv_threshold: float = 10,
    pscore_fallback: str = "error",
):
    """Validate the configuration options and store them on the instance.

    Every argument is stored as an attribute of the same name.
    ``bootstrap_weights=None`` is resolved to ``"rademacher"`` before
    validation and storage. ``is_fitted_`` is initialised to ``False`` and
    ``results_`` to ``None``.

    Raises
    ------
    ValueError
        If ``control_group``, ``estimation_method``, ``pscore_fallback``,
        ``bootstrap_weights``, ``rank_deficient_action`` or ``base_period``
        is not one of its accepted strings, if ``pscore_trim`` is outside
        the open interval ``(0, 0.5)``, or if ``epv_threshold`` is not
        strictly positive.
    """
    # NOTE: the checks run in a fixed order so that, when several options
    # are invalid, the same one is always reported first.
    if control_group not in ("never_treated", "not_yet_treated"):
        raise ValueError(
            f"control_group must be 'never_treated' or 'not_yet_treated', "
            f"got '{control_group}'"
        )
    if estimation_method not in ("dr", "ipw", "reg"):
        raise ValueError(
            f"estimation_method must be 'dr', 'ipw', or 'reg', got '{estimation_method}'"
        )
    if not (0 < pscore_trim < 0.5):
        raise ValueError(f"pscore_trim must be in (0, 0.5), got {pscore_trim}")
    if epv_threshold <= 0:
        raise ValueError(f"epv_threshold must be > 0, got {epv_threshold}")
    if pscore_fallback not in ("error", "unconditional"):
        raise ValueError(
            f"pscore_fallback must be 'error' or 'unconditional', got '{pscore_fallback}'"
        )

    # None means "use the default multiplier distribution".
    if bootstrap_weights is None:
        bootstrap_weights = "rademacher"
    if bootstrap_weights not in ("rademacher", "mammen", "webb"):
        raise ValueError(
            f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
            f"got '{bootstrap_weights}'"
        )
    if rank_deficient_action not in ("warn", "error", "silent"):
        raise ValueError(
            f"rank_deficient_action must be 'warn', 'error', or 'silent', "
            f"got '{rank_deficient_action}'"
        )
    if base_period not in ("varying", "universal"):
        raise ValueError(f"base_period must be 'varying' or 'universal', got '{base_period}'")

    self.control_group = control_group
    self.anticipation = anticipation
    self.estimation_method = estimation_method
    self.alpha = alpha
    self.cluster = cluster
    self.n_bootstrap = n_bootstrap
    self.bootstrap_weights = bootstrap_weights
    self.seed = seed
    self.rank_deficient_action = rank_deficient_action
    self.base_period = base_period
    self.cband = cband
    self.pscore_trim = pscore_trim
    self.panel = panel
    self.epv_threshold = epv_threshold
    self.pscore_fallback = pscore_fallback

    # Fit state; no estimation has run yet.
    self.is_fitted_ = False
    self.results_: Optional[CallawaySantAnnaResults] = None
def diagnose_propensity(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Check Events Per Variable (EPV) across all cohorts without estimation.

    Examines the data to identify cohorts where propensity score logit may
    be unreliable due to too few events per covariate. Based on Peduzzi
    et al. (1996).

    This is a raw-count heuristic: it counts all units in each cohort and
    all never-treated units, without filtering for missing outcomes, zero
    survey weights, or period-specific validity. The actual fit-time EPV
    (stored in ``results.epv_diagnostics``) may therefore be lower. Use
    this method as a quick pre-check; rely on ``results.epv_diagnostics``
    for authoritative per-cell EPV.

    Parameters
    ----------
    df, outcome, unit, time, first_treat, covariates
        Same arguments as ``fit()``. ``df`` is not modified; if
        ``first_treat`` contains ``inf`` a copy is recoded instead.

    Returns
    -------
    pd.DataFrame
        Per-cohort EPV diagnostics with columns: group, n_treated,
        n_control, n_covariates, n_params, epv, status. ``epv`` is
        ``min(n_treated, n_control) / n_params`` rounded to one decimal;
        ``status`` is "ok" when the unrounded EPV is at least
        ``self.epv_threshold``, "low" when it is at least 2, and
        "critical" otherwise. The frame is empty (but always carries
        these columns) when ``estimation_method == "reg"``, when no
        covariates are given, or when the data contains no treated
        cohort.

    Raises
    ------
    NotImplementedError
        If ``self.panel`` is False or ``self.control_group`` is
        ``"not_yet_treated"``.

    Warns
    -----
    UserWarning
        If any ``first_treat`` value is ``inf``; such units are recoded
        to 0 (never-treated).
    """
    if not self.panel:
        raise NotImplementedError(
            "diagnose_propensity() is not yet supported for repeated "
            "cross-section data (panel=False). Use fit() with covariates "
            "and check results.epv_diagnostics instead."
        )
    if self.control_group == "not_yet_treated":
        raise NotImplementedError(
            "diagnose_propensity() is not yet supported for "
            "control_group='not_yet_treated' because the control set "
            "varies per (g, t) cell. Use fit() with covariates and "
            "check results.epv_diagnostics instead."
        )

    result_columns = [
        "group",
        "n_treated",
        "n_control",
        "n_covariates",
        "n_params",
        "epv",
        "status",
    ]

    # No propensity model is fitted for outcome regression or without
    # covariates, so there is nothing to diagnose.
    if self.estimation_method == "reg" or not covariates:
        return pd.DataFrame(columns=result_columns)

    # Normalize np.inf -> 0 for the never-treated encoding. The caller's
    # frame is only copied when a recode is actually required.
    inf_mask = df[first_treat].isin([np.inf])
    if inf_mask.any():
        n_inf_units = df.loc[inf_mask, unit].nunique()
        warnings.warn(
            f"{n_inf_units} unit(s) have first_treat=inf; recoding to 0 "
            f"(never-treated). Use first_treat=0 to suppress this warning.",
            UserWarning,
            stacklevel=2,
        )
        df = df.copy()
        df[first_treat] = df[first_treat].replace([np.inf], 0)

    time_periods = sorted(df[time].unique())
    treatment_groups = sorted(g for g in df[first_treat].unique() if g > 0)

    precomputed = self._precompute_structures(
        df,
        outcome,
        unit,
        time,
        first_treat,
        covariates,
        time_periods=time_periods,
        treatment_groups=treatment_groups,
    )
    cohort_masks = precomputed["cohort_masks"]

    # Only control_group="never_treated" reaches this point (the other
    # setting raised above), so the control count is the same for every
    # cohort and is computed once.
    n_control = int(np.sum(precomputed["never_treated_mask"]))

    n_covariates = len(covariates)
    # Predictor count, excluding intercept (Peduzzi convention). It is at
    # least 1 here because the empty-covariates case returned above.
    n_params = n_covariates

    rows = []
    for g in sorted(cohort_masks):
        n_treated = int(np.sum(cohort_masks[g]))
        # "Events" are the minority class of the treated-vs-control logit.
        n_events = min(n_treated, n_control)
        epv = n_events / n_params
        if epv >= self.epv_threshold:
            status = "ok"
        elif epv >= 2:
            status = "low"
        else:
            status = "critical"
        rows.append(
            {
                "group": g,
                "n_treated": n_treated,
                "n_control": n_control,
                "n_covariates": n_covariates,
                "n_params": n_params,
                "epv": round(epv, 1),
                "status": status,
            }
        )

    # Passing ``columns`` keeps the schema stable even when ``rows`` is
    # empty (no treated cohorts in the data).
    return pd.DataFrame(rows, columns=result_columns)
@staticmethod
def _collapse_survey_to_unit_level(resolved_survey, df, unit_col, all_units):
    """Delegate to ``diff_diff.survey.collapse_survey_to_unit_level``.

    The four arguments are passed through positionally and the helper's
    result is returned unchanged. Per the helper's stated purpose (its
    implementation is not visible here), the result is a unit-level
    survey design with one row per unit, aligned to the ordering of
    ``all_units``, for influence-function-based variance on panel data;
    it assumes survey design columns are constant within a unit.
    """
    # Function-scope import: the survey module is only loaded when a
    # survey design is actually collapsed.
    from diff_diff.survey import collapse_survey_to_unit_level as _collapse_to_units

    unit_level_design = _collapse_to_units(resolved_survey, df, unit_col, all_units)
    return unit_level_design
def _precompute_structures(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]],
    time_periods: List[Any],
    treatment_groups: List[Any],
    resolved_survey=None,
) -> PrecomputedData:
    """
    Build the array-based lookup structures used to compute ATT(g,t).

    The long-format frame is reshaped once so that later per-(g, t) work
    can index NumPy arrays instead of filtering the DataFrame repeatedly.
    All unit-indexed arrays share the row order of ``all_units`` (the
    group keys of ``df.groupby(unit)``).

    Parameters
    ----------
    df : pd.DataFrame
        Long-format data. Must have at most one row per (unit, time)
        pair, since the outcome is reshaped with ``DataFrame.pivot``.
    outcome, unit, time, first_treat : str
        Column names in ``df``.
    covariates : list of str, optional
        Covariate column names; when empty or None no covariate arrays
        are built.
    time_periods : list
        Periods for which covariate matrices are extracted.
    treatment_groups : list
        Cohort values for which membership masks are built.
    resolved_survey : optional
        Survey design object exposing ``weights`` aligned to the rows of
        ``df``; when given, per-unit weights and a unit-level design are
        added to the result.

    Returns
    -------
    PrecomputedData
        Dictionary holding the unit ordering and index map, the cohort of
        each unit, the units x periods outcome matrix with its
        period-to-column map, cohort and never-treated boolean masks, the
        per-period covariate matrices (or None), a balanced-panel flag,
        and the survey-related entries (None without a survey design).
    """
    # One cohort value per unit (first non-missing first_treat).
    cohort_by_unit = df.groupby(unit)[first_treat].first()
    all_units = cohort_by_unit.index.values
    unit_cohorts = cohort_by_unit.values

    # Position of every unit in the arrays below.
    unit_to_idx = dict(zip(all_units, range(len(all_units))))

    # Units x periods outcome table. Reindexing guarantees a row for every
    # unit; missing (unit, period) cells are NaN.
    wide_outcomes = df.pivot(index=unit, columns=time, values=outcome).reindex(all_units)
    outcome_matrix = wide_outcomes.values
    period_to_col = {period: col for col, period in enumerate(wide_outcomes.columns)}

    cohort_masks = {g: unit_cohorts == g for g in treatment_groups}

    # Never-treated units are coded 0. The np.inf comparison is defensive:
    # the caller visible in this file recodes inf to 0 before calling.
    never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf)

    # Covariates are kept per period (units x covariates) so that each
    # comparison can read them from whichever period it needs.
    if covariates:
        covariate_by_period = {
            t: df[df[time] == t].set_index(unit).reindex(all_units)[covariates].values
            for t in time_periods
        }
    else:
        covariate_by_period = None

    # Balanced means no unit is missing an outcome in any period.
    is_balanced = not np.isnan(outcome_matrix).any()

    if resolved_survey is None:
        survey_weights_arr = None
        resolved_survey_unit = None
    else:
        # One weight per unit: the first row-level weight within each unit.
        row_weights = pd.Series(resolved_survey.weights, index=df.index)
        survey_weights_arr = row_weights.groupby(df[unit]).first().reindex(all_units).values
        resolved_survey_unit = self._collapse_survey_to_unit_level(
            resolved_survey, df, unit, all_units
        )
    df_survey = None if resolved_survey_unit is None else resolved_survey_unit.df_survey

    structures: PrecomputedData = {
        "all_units": all_units,
        "unit_to_idx": unit_to_idx,
        "unit_cohorts": unit_cohorts,
        "outcome_matrix": outcome_matrix,
        "period_to_col": period_to_col,
        "cohort_masks": cohort_masks,
        "never_treated_mask": never_treated_mask,
        "covariate_by_period": covariate_by_period,
        "time_periods": time_periods,
        "is_balanced": is_balanced,
        "is_panel": True,
        "canonical_size": len(all_units),
        "survey_weights": survey_weights_arr,
        "resolved_survey": resolved_survey,
        "resolved_survey_unit": resolved_survey_unit,
        "df_survey": df_survey,
    }
    return structures
def _compute_att_gt_fast(
    self,
    precomputed: PrecomputedData,
    g: Any,
    t: Any,
    covariates: Optional[List[str]],
    pscore_cache: Optional[Dict] = None,
    cho_cache: Optional[Dict] = None,
    epv_diagnostics: Optional[Dict] = None,
) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]], Optional[float]]:
    """
    Estimate a single ATT(g,t) from the pre-pivoted arrays.

    Works on the NumPy structures in ``precomputed`` (vectorized masks and
    column slices) rather than filtering a DataFrame per cell, then hands
    the outcome changes to the estimator selected by
    ``self.estimation_method``.

    Parameters
    ----------
    precomputed : PrecomputedData
        Structures produced by ``_precompute_structures``.
    g, t
        Cohort (first-treatment period) and evaluation period.
    covariates : list of str, optional
        Covariate names; when falsy the estimators run unconditionally.
    pscore_cache, cho_cache : dict, optional
        Caches passed through to the IPW/DR estimators together with the
        keys computed here.
    epv_diagnostics : dict, optional
        When given, any EPV diagnostics reported by the IPW/DR estimator
        are stored under the key ``(g, t)``.

    Returns
    -------
    att_gt : float or None
        ``None`` when the cell cannot be estimated (see below).
    se_gt : float
    n_treated : int
    n_control : int
    inf_func_info : dict or None
        Positions and unit ids of treated/control units plus their
        influence-function values.
    survey_weight_sum : float or None
        Sum of survey weights for treated units (for aggregation weighting).

    Notes
    -----
    ``(None, 0.0, 0, 0, None, None)`` is returned when the base period or
    ``t`` is absent from the data, when no treated or no control unit has
    outcomes in both periods, or when either group has non-positive total
    survey weight.
    """
    col_of_period = precomputed["period_to_col"]
    outcomes = precomputed["outcome_matrix"]
    cohort_masks = precomputed["cohort_masks"]
    never_treated = precomputed["never_treated_mask"]
    cohort_of_unit = precomputed["unit_cohorts"]
    covariate_by_period = precomputed["covariate_by_period"]

    not_estimable = (None, 0.0, 0, 0, None, None)

    # Base period: with "varying", pre-treatment cells compare consecutive
    # periods (t - 1); every other case uses g - 1 - anticipation.
    if self.base_period != "universal" and t < g - self.anticipation:
        base_period_val = t - 1
    else:
        base_period_val = g - 1 - self.anticipation

    # Both periods must exist in the data; there is deliberately no
    # fallback to another base period.
    if base_period_val not in col_of_period or t not in col_of_period:
        return not_estimable

    base_col = col_of_period[base_period_val]
    post_col = col_of_period[t]

    treated_mask = cohort_masks[g]
    if self.control_group == "never_treated":
        control_mask = never_treated
    else:
        # Not-yet-treated controls must be untreated at BOTH t and the
        # base period, i.e. at whichever is later; otherwise one of their
        # two outcomes is contaminated by treatment.
        nyt_threshold = max(t, base_period_val) + self.anticipation
        control_mask = never_treated | (
            (cohort_of_unit > nyt_threshold) & (cohort_of_unit != g)
        )

    y_base = outcomes[:, base_col]
    y_post = outcomes[:, post_col]
    outcome_change = y_post - y_base

    # Keep only units observed in both periods.
    valid_mask = ~np.isnan(y_base) & ~np.isnan(y_post)
    treated_valid = treated_mask & valid_mask
    control_valid = control_mask & valid_mask
    n_treated = int(np.sum(treated_valid))
    n_control = int(np.sum(control_valid))
    if n_treated == 0 or n_control == 0:
        return not_estimable

    treated_change = outcome_change[treated_valid]
    control_change = outcome_change[control_valid]

    survey_w = precomputed.get("survey_weights")
    if survey_w is None:
        sw_treated = None
        sw_control = None
    else:
        sw_treated = survey_w[treated_valid]
        sw_control = survey_w[control_valid]
        # Guard against zero effective mass after subpopulation filtering.
        if np.sum(sw_treated) <= 0:
            return not_estimable
        if np.sum(sw_control) <= 0:
            return not_estimable

    # Covariates are read from the base period of the comparison.
    X_treated = None
    X_control = None
    if covariates and covariate_by_period is not None:
        base_covariates = covariate_by_period[base_period_val]
        X_treated = base_covariates[treated_valid]
        X_control = base_covariates[control_valid]
        if np.any(np.isnan(X_treated)) or np.any(np.isnan(X_control)):
            warnings.warn(
                f"Missing values in covariates for group {g}, time {t}. "
                "Falling back to unconditional estimation.",
                UserWarning,
                stacklevel=3,
            )
            X_treated = None
            X_control = None

    # With a balanced panel and never-treated controls the cache keys omit
    # t; in every other case the key is specific to the (g, base, t) cell.
    key_ignores_t = (
        bool(precomputed.get("is_balanced", False)) and self.control_group == "never_treated"
    )
    cell_key = (g, base_period_val, t)

    pscore_key = None
    if pscore_cache is not None and X_treated is not None:
        pscore_key = (g, base_period_val) if key_ignores_t else cell_key

    cho_key = None
    if cho_cache is not None and X_control is not None:
        cho_key = base_period_val if key_ignores_t else cell_key

    if self.estimation_method == "reg":
        att_gt, se_gt, inf_func = self._outcome_regression(
            treated_change,
            control_change,
            X_treated,
            X_control,
            sw_treated=sw_treated,
            sw_control=sw_control,
        )
    else:
        # IPW and DR both take the stacked (treated, then control) weights
        # and report EPV diagnostics through an output dict.
        if sw_treated is not None:
            sw_all = np.concatenate([sw_treated, sw_control])
        else:
            sw_all = None
        epv_diag: dict = {}
        if self.estimation_method == "ipw":
            att_gt, se_gt, inf_func = self._ipw_estimation(
                treated_change,
                control_change,
                n_treated,
                n_control,
                X_treated,
                X_control,
                pscore_cache=pscore_cache,
                pscore_key=pscore_key,
                sw_treated=sw_treated,
                sw_control=sw_control,
                sw_all=sw_all,
                context_label=f"cohort g={g}",
                epv_diagnostics_out=epv_diag,
            )
        else:  # doubly robust
            att_gt, se_gt, inf_func = self._doubly_robust(
                treated_change,
                control_change,
                X_treated,
                X_control,
                pscore_cache=pscore_cache,
                pscore_key=pscore_key,
                cho_cache=cho_cache,
                cho_key=cho_key,
                sw_treated=sw_treated,
                sw_control=sw_control,
                sw_all=sw_all,
                context_label=f"cohort g={g}",
                epv_diagnostics_out=epv_diag,
            )
        if epv_diagnostics is not None and epv_diag:
            epv_diagnostics[(g, t)] = epv_diag

    # Influence-function values are stored alongside integer positions
    # into precomputed['all_units'] so downstream code can index arrays
    # directly instead of looking units up one by one. The estimators
    # return treated values first, then control values.
    all_units = precomputed["all_units"]
    treated_positions = np.where(treated_valid)[0]
    control_positions = np.where(control_valid)[0]
    inf_func_info = {
        "treated_idx": treated_positions,
        "control_idx": control_positions,
        "treated_units": all_units[treated_positions],
        "control_units": all_units[control_positions],
        "treated_inf": inf_func[:n_treated],
        "control_inf": inf_func[n_treated:],
    }

    sw_sum = None if sw_treated is None else float(np.sum(sw_treated))
    return att_gt, se_gt, n_treated, n_control, inf_func_info, sw_sum
def _compute_all_att_gt_vectorized(
self,
precomputed: PrecomputedData,
treatment_groups: List[Any],
time_periods: List[Any],
min_period: Any,
) -> Tuple[Dict, Dict, Dict]:
"""
Vectorized computation of all ATT(g,t) for the no-covariates regression case.
This inlines the simple difference-in-means path from _outcome_regression()
and eliminates per-(g,t) Python function call overhead.
Returns
-------
group_time_effects : dict
Mapping (g, t) -> effect dict.
influence_func_info : dict
Mapping (g, t) -> influence function info dict.
"""
period_to_col = precomputed["period_to_col"]
outcome_matrix = precomputed["outcome_matrix"]
cohort_masks = precomputed["cohort_masks"]
never_treated_mask = precomputed["never_treated_mask"]
unit_cohorts = precomputed["unit_cohorts"]
survey_w = precomputed.get("survey_weights")
group_time_effects = {}
influence_func_info = {}
skipped_missing_period: List[Tuple] = []
skipped_empty_cell: List[Tuple] = []
# Collect all valid (g, t, base_col, post_col) tuples
tasks = []
for g in treatment_groups:
if self.base_period == "universal":
universal_base = g - 1 - self.anticipation
valid_periods = [t for t in time_periods if t != universal_base]
else:
valid_periods = [
t for t in time_periods if t >= g - self.anticipation or t > min_period
]
for t in valid_periods:
# Base period selection
if self.base_period == "universal":
base_period_val = g - 1 - self.anticipation
else:
if t < g - self.anticipation:
base_period_val = t - 1
else:
base_period_val = g - 1 - self.anticipation
if base_period_val not in period_to_col or t not in period_to_col:
skipped_missing_period.append((g, t))
continue
tasks.append(
(g, t, period_to_col[base_period_val], period_to_col[t], base_period_val)
)
# Process all tasks
atts = []
ses = []
task_keys = []
for g, t, base_col, post_col, base_period_val in tasks:
treated_mask = cohort_masks[g]
if self.control_group == "never_treated":
control_mask = never_treated_mask
else:
# Controls must be untreated at both t and base_period_val
nyt_threshold = max(t, base_period_val) + self.anticipation
control_mask = never_treated_mask | (
(unit_cohorts > nyt_threshold) & (unit_cohorts != g)
)
y_base = outcome_matrix[:, base_col]
y_post = outcome_matrix[:, post_col]
outcome_change = y_post - y_base
valid_mask = ~(np.isnan(y_base) | np.isnan(y_post))
treated_valid = treated_mask & valid_mask
control_valid = control_mask & valid_mask
n_treated = np.sum(treated_valid)
n_control = np.sum(control_valid)
if n_treated == 0 or n_control == 0:
skipped_empty_cell.append((g, t))
continue
treated_change = outcome_change[treated_valid]
control_change = outcome_change[control_valid]
n_t = int(n_treated)
n_c = int(n_control)
# Inline no-covariates regression (difference in means)
if survey_w is not None:
sw_t = survey_w[treated_valid]
sw_c = survey_w[control_valid]
# Guard against zero effective mass
if np.sum(sw_t) <= 0 or np.sum(sw_c) <= 0:
skipped_empty_cell.append((g, t))
continue
sw_t_norm = sw_t / np.sum(sw_t)
sw_c_norm = sw_c / np.sum(sw_c)
mu_t = float(np.sum(sw_t_norm * treated_change))
mu_c = float(np.sum(sw_c_norm * control_change))
att = mu_t - mu_c
# Influence function (survey-weighted)
inf_treated = sw_t_norm * (treated_change - mu_t)
inf_control = -sw_c_norm * (control_change - mu_c)
# SE derived from IF: sum(IF_i^2)
se = (
float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2)))
if (n_t > 0 and n_c > 0)
else 0.0
)
sw_sum = float(np.sum(sw_t))
else:
att = float(np.mean(treated_change) - np.mean(control_change))
var_t = float(np.var(treated_change, ddof=1)) if n_t > 1 else 0.0
var_c = float(np.var(control_change, ddof=1)) if n_c > 1 else 0.0
se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0
# Influence function
inf_treated = (treated_change - np.mean(treated_change)) / n_t
inf_control = -(control_change - np.mean(control_change)) / n_c
sw_sum = None
gte_entry = {
"effect": att,
"se": se,
# t_stat, p_value, conf_int filled by batch inference below
"t_stat": np.nan,
"p_value": np.nan,
"conf_int": (np.nan, np.nan),
"n_treated": n_t,
"n_control": n_c,
}
if sw_sum is not None:
gte_entry["survey_weight_sum"] = sw_sum
group_time_effects[(g, t)] = gte_entry
all_units = precomputed["all_units"]