-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathcontinuous_did.py
More file actions
1626 lines (1456 loc) · 66.9 KB
/
continuous_did.py
File metadata and controls
1626 lines (1456 loc) · 66.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Continuous Difference-in-Differences estimator.
Implements Callaway, Goodman-Bacon & Sant'Anna (2024),
"Difference-in-Differences with a Continuous Treatment" (NBER WP 32117).
Estimates dose-response curves ATT(d) and ACRT(d), as well as summary
parameters ATT^{glob} and ACRT^{glob}, with optional multiplier bootstrap
inference.
"""
import warnings
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from diff_diff.bootstrap_utils import (
compute_effect_bootstrap_stats,
generate_bootstrap_weights_batch,
)
from diff_diff.continuous_did_bspline import (
bspline_derivative_design_matrix,
bspline_design_matrix,
build_bspline_basis,
default_dose_grid,
)
from diff_diff.continuous_did_results import (
ContinuousDiDResults,
DoseResponseCurve,
)
from diff_diff.linalg import solve_ols
from diff_diff.survey import (
ResolvedSurveyDesign,
_resolve_survey_for_fit,
_validate_unit_constant_survey,
compute_survey_vcov,
)
from diff_diff.utils import safe_inference
__all__ = ["ContinuousDiD", "ContinuousDiDResults", "DoseResponseCurve"]
class ContinuousDiD:
"""
Continuous Difference-in-Differences estimator.
Implements the methodology from Callaway, Goodman-Bacon & Sant'Anna (2024)
for estimating dose-response curves when treatment has a continuous intensity.
Parameters
----------
degree : int, default=3
B-spline degree (3 = cubic).
num_knots : int, default=0
Number of interior knots for the B-spline basis.
dvals : array-like, optional
Custom dose evaluation grid. If None, uses quantile-based default.
control_group : str, default="never_treated"
``"never_treated"`` or ``"not_yet_treated"``.
anticipation : int, default=0
Number of periods of treatment anticipation.
base_period : str, default="varying"
``"varying"`` or ``"universal"``.
alpha : float, default=0.05
Significance level for confidence intervals.
n_bootstrap : int, default=0
Number of multiplier bootstrap iterations. 0 for analytical SEs only.
bootstrap_weights : str, default="rademacher"
Bootstrap weight type: ``"rademacher"``, ``"mammen"``, or ``"webb"``.
seed : int, optional
Random seed for reproducibility.
rank_deficient_action : str, default="warn"
Action for rank-deficient B-spline OLS: ``"warn"``, ``"error"``, or ``"silent"``.
Examples
--------
>>> from diff_diff import ContinuousDiD, generate_continuous_did_data
>>> data = generate_continuous_did_data(n_units=200, seed=42)
>>> est = ContinuousDiD(n_bootstrap=199, seed=42)
>>> results = est.fit(data, outcome="outcome", unit="unit",
... time="period", first_treat="first_treat",
... dose="dose", aggregate="dose")
>>> results.overall_att # doctest: +SKIP
"""
_VALID_CONTROL_GROUPS = {"never_treated", "not_yet_treated"}
_VALID_BASE_PERIODS = {"varying", "universal"}
def __init__(
self,
degree: int = 3,
num_knots: int = 0,
dvals: Optional[np.ndarray] = None,
control_group: str = "never_treated",
anticipation: int = 0,
base_period: str = "varying",
alpha: float = 0.05,
n_bootstrap: int = 0,
bootstrap_weights: str = "rademacher",
seed: Optional[int] = None,
rank_deficient_action: str = "warn",
):
self.degree = degree
self.num_knots = num_knots
self.dvals = np.asarray(dvals, dtype=float) if dvals is not None else None
self.control_group = control_group
self.anticipation = anticipation
self.base_period = base_period
self.alpha = alpha
self.n_bootstrap = n_bootstrap
self.bootstrap_weights = bootstrap_weights
self.seed = seed
self.rank_deficient_action = rank_deficient_action
self._validate_constrained_params()
def _validate_constrained_params(self) -> None:
"""Validate control_group and base_period values."""
if self.control_group not in self._VALID_CONTROL_GROUPS:
raise ValueError(
f"Invalid control_group: '{self.control_group}'. "
f"Must be one of {self._VALID_CONTROL_GROUPS}."
)
if self.base_period not in self._VALID_BASE_PERIODS:
raise ValueError(
f"Invalid base_period: '{self.base_period}'. "
f"Must be one of {self._VALID_BASE_PERIODS}."
)
def get_params(self) -> Dict[str, Any]:
"""Return estimator parameters as a dictionary."""
return {
"degree": self.degree,
"num_knots": self.num_knots,
"dvals": self.dvals,
"control_group": self.control_group,
"anticipation": self.anticipation,
"base_period": self.base_period,
"alpha": self.alpha,
"n_bootstrap": self.n_bootstrap,
"bootstrap_weights": self.bootstrap_weights,
"seed": self.seed,
"rank_deficient_action": self.rank_deficient_action,
}
def set_params(self, **params) -> "ContinuousDiD":
"""Set estimator parameters and return self."""
for key, value in params.items():
if not hasattr(self, key):
raise ValueError(f"Invalid parameter: {key}")
setattr(self, key, value)
self._validate_constrained_params()
return self
# ------------------------------------------------------------------
# Main fit
# ------------------------------------------------------------------
def fit(
    self,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    dose: str,
    aggregate: Optional[str] = None,
    survey_design: object = None,
) -> ContinuousDiDResults:
    """
    Fit the continuous DiD estimator.

    Runs six stages: (1) validate and clean the panel, (2) precompute
    wide-format lookup structures, (3) estimate a dose-response for every
    (group, time) cell, (4) aggregate post-treatment cells with
    group-proportional weights, (5) attach bootstrap or analytical
    standard errors, and (6) assemble the results object.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Outcome column name.
    unit : str
        Unit identifier column.
    time : str
        Time period column.
    first_treat : str
        First treatment period column (0 or inf for never-treated).
    dose : str
        Continuous dose column.
    aggregate : str, optional
        ``"dose"`` for dose-response aggregation, ``"eventstudy"`` for
        binarized event study.
    survey_design : SurveyDesign, optional
        Survey design specification for design-based inference.
        Supports weighted estimation and Taylor series linearization
        variance with strata, PSU, and FPC.

    Returns
    -------
    ContinuousDiDResults

    Raises
    ------
    ValueError
        On an invalid ``aggregate``, missing columns, time-varying dose,
        negative treated doses, an unbalanced panel, missing treated or
        control units, or when no valid (g,t) cells remain.
    """
    # 1. Validate & prepare
    _VALID_AGGREGATES = (None, "dose", "eventstudy")
    if aggregate not in _VALID_AGGREGATES:
        raise ValueError(
            f"Invalid aggregate: '{aggregate}'. " f"Must be one of {_VALID_AGGREGATES}."
        )
    # Resolve survey design if provided
    resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
        _resolve_survey_for_fit(survey_design, data, "analytical")
    )
    # Validate within-unit constancy for panel survey designs
    if resolved_survey is not None:
        _validate_unit_constant_survey(data, unit, survey_design)
    # Bootstrap + survey supported via PSU-level multiplier bootstrap.
    df = data.copy()
    for col in [outcome, unit, time, first_treat, dose]:
        if col not in df.columns:
            raise ValueError(f"Column '{col}' not found in data.")
    # Verify dose is time-invariant
    dose_nunique = df.groupby(unit)[dose].nunique()
    if dose_nunique.max() > 1:
        bad_units = dose_nunique[dose_nunique > 1].index.tolist()
        raise ValueError(
            f"Dose must be time-invariant. Units with varying dose: {bad_units[:5]}"
        )
    # Normalize first_treat: inf → 0 (both encodings mean never-treated)
    df[first_treat] = df[first_treat].replace([np.inf, float("inf")], 0)
    # Drop units with positive first_treat but zero dose (R convention)
    unit_info = df.groupby(unit).first()[[first_treat, dose]]
    drop_units = unit_info[(unit_info[first_treat] > 0) & (unit_info[dose] == 0)].index
    if len(drop_units) > 0:
        warnings.warn(
            f"Dropping {len(drop_units)} units with positive first_treat but zero dose.",
            UserWarning,
            stacklevel=2,
        )
        df = df[~df[unit].isin(drop_units)]
    # Validate no negative doses among treated units
    treated_doses = df.loc[df[first_treat] > 0, dose]
    if (treated_doses < 0).any():
        n_neg = int((treated_doses < 0).sum())
        raise ValueError(
            f"Found {n_neg} treated unit(s) with negative dose. "
            f"Dose must be strictly positive for treated units (D > 0)."
        )
    # Detect discrete (integer-valued) dose among treated units; warn only,
    # estimation proceeds with the B-spline anyway.
    unit_doses = df.loc[df[first_treat] > 0].groupby(unit)[dose].first()
    unique_pos_doses = unit_doses[unit_doses > 0].unique()
    is_integer = len(unique_pos_doses) > 0 and np.allclose(
        unique_pos_doses, np.round(unique_pos_doses)
    )
    if is_integer:
        warnings.warn(
            f"Dose appears discrete ({len(unique_pos_doses)} unique integer values). "
            "B-spline smoothing may be inappropriate for discrete treatments. "
            "Consider a saturated regression approach (not yet implemented).",
            UserWarning,
            stacklevel=2,
        )
    # Force dose=0 for never-treated units with nonzero dose
    never_treated_mask = df[first_treat] == 0
    if (df.loc[never_treated_mask, dose] != 0).any():
        df.loc[never_treated_mask, dose] = 0.0
    # Verify balanced panel: every unit must appear in every period
    all_periods = set(df[time].unique())
    unit_periods = df.groupby(unit)[time].apply(set)
    is_unbalanced = unit_periods.apply(lambda s: s != all_periods)
    if is_unbalanced.any():
        n_bad = int(is_unbalanced.sum())
        raise ValueError(
            "Unbalanced panel detected. ContinuousDiD requires a balanced panel. "
            f"{n_bad} unit(s) have missing periods."
        )
    # Identify groups and time periods
    unit_cohort = df.groupby(unit)[first_treat].first()
    treatment_groups = sorted([g for g in unit_cohort.unique() if g > 0])
    time_periods = sorted(df[time].unique())
    if len(treatment_groups) == 0:
        raise ValueError("No treated units found (all first_treat == 0).")
    n_control = int((unit_cohort == 0).sum())
    if self.control_group == "never_treated" and n_control == 0:
        raise ValueError(
            "No never-treated units found. Use control_group='not_yet_treated' "
            "or add never-treated units."
        )
    if self.control_group == "not_yet_treated" and n_control == 0:
        raise ValueError(
            "No never-treated (D=0) units found. With control_group='not_yet_treated', "
            "dose-response curve identification requires P(D=0) > 0 "
            "(Remark 3.1 in Callaway et al. is not yet implemented). "
            "Add never-treated units or use a dataset with D=0 observations."
        )
    # Re-resolve survey design on filtered df if rows were dropped
    # (survey arrays must align with df, not the original data)
    if resolved_survey is not None and len(df) < len(data):
        resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
            _resolve_survey_for_fit(survey_design, df, "analytical")
        )
    # 2. Precompute structures
    precomp = self._precompute_structures(
        df,
        outcome,
        unit,
        time,
        first_treat,
        dose,
        time_periods,
        survey_weights=survey_weights,
    )
    # Compute dvals (evaluation grid): user-supplied grid wins over default
    all_treated_doses = precomp["dose_vector"][precomp["dose_vector"] > 0]
    if self.dvals is not None:
        dvals = self.dvals
    else:
        dvals = default_dose_grid(all_treated_doses)
    # Build B-spline knots from all treated doses
    knots, degree = build_bspline_basis(
        all_treated_doses, degree=self.degree, num_knots=self.num_knots
    )
    # 3. Iterate over (g,t) cells; None results mean the cell was skipped
    gt_results = {}
    gt_bootstrap_info = {}
    for g in treatment_groups:
        for t in time_periods:
            result = self._compute_dose_response_gt(
                precomp,
                g,
                t,
                knots,
                degree,
                dvals,
                survey_weights=precomp.get("unit_survey_weights"),
                resolved_survey=resolved_survey,
            )
            if result is not None:
                gt_results[(g, t)] = result
                gt_bootstrap_info[(g, t)] = result.get("_bootstrap_info", {})
    # Filter out NaN cells (e.g., from zero effective survey mass)
    gt_results = {
        gt: r for gt, r in gt_results.items()
        if np.isfinite(r.get("att_glob", np.nan))
    }
    if len(gt_results) == 0:
        raise ValueError("No valid (g,t) cells computed.")
    # 4. Aggregate
    post_gt = {(g, t): r for (g, t), r in gt_results.items() if t >= g - self.anticipation}
    # Dose-response aggregation
    n_grid = len(dvals)
    # NaN-initialized SE/CI fields (used when post_gt is empty or as defaults)
    att_d_se = np.full(n_grid, np.nan)
    att_d_ci_lower = np.full(n_grid, np.nan)
    att_d_ci_upper = np.full(n_grid, np.nan)
    acrt_d_se = np.full(n_grid, np.nan)
    acrt_d_ci_lower = np.full(n_grid, np.nan)
    acrt_d_ci_upper = np.full(n_grid, np.nan)
    overall_att_se = np.nan
    overall_att_t = np.nan
    overall_att_p = np.nan
    overall_att_ci = (np.nan, np.nan)
    overall_acrt_se = np.nan
    overall_acrt_t = np.nan
    overall_acrt_p = np.nan
    overall_acrt_ci = (np.nan, np.nan)
    att_d_p = None
    acrt_d_p = None
    # Event study aggregation (binarized) — runs on ALL (g,t) cells
    event_study_effects = None
    if aggregate == "eventstudy":
        event_study_effects = self._aggregate_event_study(
            gt_results,
            gt_bootstrap_info=gt_bootstrap_info,
            unit_survey_weights=precomp.get("unit_survey_weights"),
            unit_cohorts=precomp["unit_cohorts"],
            anticipation=self.anticipation,
        )
    _survey_df = None  # Set by analytical branch when survey is active
    if len(post_gt) == 0:
        warnings.warn(
            "No post-treatment (g,t) cells available for aggregation. "
            "This can occur when all treatments start after the last observed "
            "period or all cells were skipped due to insufficient data.",
            UserWarning,
            stacklevel=2,
        )
        # Leave SEs/CIs at their NaN defaults; no inference is possible.
        overall_att = np.nan
        overall_acrt = np.nan
        agg_att_d = np.full(n_grid, np.nan)
        agg_acrt_d = np.full(n_grid, np.nan)
    else:
        # Compute cell weights: group-proportional (matching R's contdid convention).
        # Each group g gets weight proportional to its number of treated units.
        # When survey weights present, use sum(w_g) / sum(w) instead of n_g / N.
        # Within each group, weight is divided equally among post-treatment cells.
        group_n_treated = {}
        group_n_post_cells = {}
        unit_sw = precomp.get("unit_survey_weights")
        for (g, t), r in post_gt.items():
            if g not in group_n_treated:
                if unit_sw is not None:
                    # Survey-weighted group size: sum of weights for treated units in g
                    g_mask = precomp["unit_cohorts"] == g
                    group_n_treated[g] = float(np.sum(unit_sw[g_mask]))
                else:
                    group_n_treated[g] = float(r["n_treated"])
                group_n_post_cells[g] = 0
            group_n_post_cells[g] += 1
        total_treated = sum(group_n_treated.values())
        cell_weights = {}
        if total_treated > 0:
            for (g, t), r in post_gt.items():
                pg = group_n_treated[g] / total_treated
                cell_weights[(g, t)] = pg / group_n_post_cells[g]
        # Weighted average of cell-level curves and summary parameters.
        agg_att_d = np.zeros(n_grid)
        agg_acrt_d = np.zeros(n_grid)
        overall_att = 0.0
        overall_acrt = 0.0
        for gt, w in cell_weights.items():
            r = post_gt[gt]
            agg_att_d += w * r["att_d"]
            agg_acrt_d += w * r["acrt_d"]
            overall_att += w * r["att_glob"]
            overall_acrt += w * r["acrt_glob"]
        # 5. Bootstrap / Analytical SE (only reachable when post-treatment
        # cells exist, since both branches consume cell_weights)
        if self.n_bootstrap > 0:
            boot_result = self._run_bootstrap(
                precomp,
                gt_results,
                gt_bootstrap_info,
                post_gt,
                cell_weights,
                knots,
                degree,
                dvals,
                overall_att,
                overall_acrt,
                agg_att_d,
                agg_acrt_d,
                event_study_effects,
                resolved_survey=resolved_survey,
            )
            att_d_se = boot_result["att_d_se"]
            att_d_ci_lower = boot_result["att_d_ci_lower"]
            att_d_ci_upper = boot_result["att_d_ci_upper"]
            acrt_d_se = boot_result["acrt_d_se"]
            acrt_d_ci_lower = boot_result["acrt_d_ci_lower"]
            acrt_d_ci_upper = boot_result["acrt_d_ci_upper"]
            att_d_p = boot_result["att_d_p"]
            acrt_d_p = boot_result["acrt_d_p"]
            overall_att_se = boot_result["overall_att_se"]
            # t-stats are recomputed from point estimate and bootstrap SE;
            # p-values / CIs come straight from the bootstrap distribution.
            overall_att_t = safe_inference(overall_att, overall_att_se, self.alpha)[0]
            overall_att_p = boot_result["overall_att_p"]
            overall_att_ci = boot_result["overall_att_ci"]
            overall_acrt_se = boot_result["overall_acrt_se"]
            overall_acrt_t = safe_inference(overall_acrt, overall_acrt_se, self.alpha)[0]
            overall_acrt_p = boot_result["overall_acrt_p"]
            overall_acrt_ci = boot_result["overall_acrt_ci"]
            if event_study_effects is not None:
                # Overwrite event-study inference fields in place with the
                # bootstrap versions, where available.
                for e, info in event_study_effects.items():
                    if e in boot_result.get("es_se", {}):
                        info["se"] = boot_result["es_se"][e]
                        info["t_stat"] = safe_inference(
                            info["effect"], info["se"], self.alpha
                        )[0]
                        info["p_value"] = boot_result["es_p"][e]
                        info["conf_int"] = boot_result["es_ci"][e]
        else:
            # Analytical SEs via influence functions
            analytic = self._compute_analytical_se(
                precomp,
                gt_results,
                gt_bootstrap_info,
                post_gt,
                cell_weights,
                knots,
                degree,
                dvals,
                agg_att_d,
                agg_acrt_d,
                resolved_survey=resolved_survey,
            )
            att_d_se = analytic["att_d_se"]
            acrt_d_se = analytic["acrt_d_se"]
            overall_att_se = analytic["overall_att_se"]
            overall_acrt_se = analytic["overall_acrt_se"]
            # Survey df for t-distribution inference (unit-level, not panel-level)
            _survey_df = analytic.get("df_survey")
            # Guard: replicate design with undefined df → NaN inference
            if (_survey_df is None and resolved_survey is not None
                    and hasattr(resolved_survey, 'uses_replicate_variance')
                    and resolved_survey.uses_replicate_variance):
                _survey_df = 0
            # Recompute survey_metadata from unit-level design so reported
            # effective_n/n_psu/df_survey match the inference actually run
            _unit_resolved = analytic.get("unit_resolved")
            if _unit_resolved is not None:
                from diff_diff.survey import compute_survey_metadata
                raw_w_unit = _unit_resolved.weights
                survey_metadata = compute_survey_metadata(_unit_resolved, raw_w_unit)
            # Propagate replicate df override to survey_metadata for display
            # (but not the df=0 sentinel — keep metadata as None for undefined df)
            if (_survey_df is not None and _survey_df != 0
                    and survey_metadata is not None):
                if survey_metadata.df_survey != _survey_df:
                    survey_metadata.df_survey = _survey_df
            overall_att_t, overall_att_p, overall_att_ci = safe_inference(
                overall_att, overall_att_se, self.alpha, df=_survey_df
            )
            overall_acrt_t, overall_acrt_p, overall_acrt_ci = safe_inference(
                overall_acrt, overall_acrt_se, self.alpha, df=_survey_df
            )
            # Per-grid-point inference for dose-response
            for idx in range(n_grid):
                _, _, ci = safe_inference(
                    agg_att_d[idx], att_d_se[idx], self.alpha, df=_survey_df
                )
                att_d_ci_lower[idx] = ci[0]
                att_d_ci_upper[idx] = ci[1]
                _, _, ci = safe_inference(
                    agg_acrt_d[idx], acrt_d_se[idx], self.alpha, df=_survey_df
                )
                acrt_d_ci_lower[idx] = ci[0]
                acrt_d_ci_upper[idx] = ci[1]
            # Event study analytical SEs
            if event_study_effects is not None:
                n_units = precomp["n_units"]
                unit_sw = precomp.get("unit_survey_weights")
                # Build unit-level ResolvedSurveyDesign once (reused per bin)
                unit_resolved_es = None
                if resolved_survey is not None:
                    row_idx = precomp["unit_first_panel_row"]
                    uw = (
                        precomp.get("unit_survey_weights")
                        if precomp.get("unit_survey_weights") is not None
                        else np.ones(n_units)
                    )
                    us = (
                        resolved_survey.strata[row_idx]
                        if resolved_survey.strata is not None
                        else None
                    )
                    up = (
                        resolved_survey.psu[row_idx]
                        if resolved_survey.psu is not None
                        else None
                    )
                    uf = (
                        resolved_survey.fpc[row_idx]
                        if resolved_survey.fpc is not None
                        else None
                    )
                    n_strata_u = len(np.unique(us)) if us is not None else 0
                    n_psu_u = len(np.unique(up)) if up is not None else 0
                    unit_resolved_es = resolved_survey.subset_to_units(
                        row_idx, uw, us, up, uf, n_strata_u, n_psu_u,
                    )
                for e_val, info_e in event_study_effects.items():
                    # Collect (g,t) cells for this event-time bin
                    e_gts = [gt for gt in gt_results if gt[1] - gt[0] == e_val]
                    if not e_gts:
                        continue
                    # Weights within this bin: survey-weighted mass or n_treated
                    if unit_sw is not None:
                        unit_cohorts = precomp["unit_cohorts"]
                        ns = np.array(
                            [float(np.sum(unit_sw[unit_cohorts == gt[0]])) for gt in e_gts],
                            dtype=float,
                        )
                    else:
                        ns = np.array(
                            [gt_results[gt]["n_treated"] for gt in e_gts],
                            dtype=float,
                        )
                    total_n = ns.sum()
                    if total_n == 0:
                        continue
                    ws = ns / total_n
                    # Build per-unit IF for this event-time bin
                    if_es = np.zeros(n_units)
                    for idx_cell, gt in enumerate(e_gts):
                        b_info = gt_bootstrap_info.get(gt, {})
                        if not b_info:
                            continue
                        w = ws[idx_cell]
                        treated_idx = b_info["treated_indices"]
                        control_idx = b_info["control_indices"]
                        n_t = b_info["n_treated"]
                        n_c = b_info["n_control"]
                        # Use survey-weighted masses when available
                        if "w_treated" in b_info:
                            n_t = b_info["w_treated"]
                            n_c = b_info["w_control"]
                        n_total_gt = n_t + n_c
                        p_1 = n_t / n_total_gt
                        p_0 = n_c / n_total_gt
                        att_glob_gt = b_info["att_glob"]
                        mu_0 = b_info["mu_0"]
                        delta_y_treated = b_info["delta_y_treated"]
                        ee_control = b_info["ee_control"]
                        sw_treated = b_info.get("w_treated_arr")
                        # Treated-unit scores: demeaned outcome change minus
                        # the cell's ATT; optionally survey-weighted.
                        for k, uid in enumerate(treated_idx):
                            score_k = delta_y_treated[k] - att_glob_gt - mu_0
                            if sw_treated is not None:
                                score_k = sw_treated[k] * score_k
                            if_es[uid] += w * score_k / p_1 / n_total_gt
                        for k, uid in enumerate(control_idx):
                            if_es[uid] -= w * ee_control[k] / p_0 / n_total_gt
                    # Compute SE: survey-aware TSL or standard sqrt(sum(IF^2))
                    if unit_resolved_es is not None:
                        if unit_resolved_es.uses_replicate_variance:
                            from diff_diff.survey import compute_replicate_if_variance
                            # Score-scale: psi = w * if_es (matches TSL bread)
                            psi_es = unit_resolved_es.weights * if_es
                            variance, _nv = compute_replicate_if_variance(psi_es, unit_resolved_es)
                            es_se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
                        else:
                            X_ones_es = np.ones((n_units, 1))
                            tsl_scale_es = float(unit_resolved_es.weights.sum())
                            if_es_tsl = if_es * tsl_scale_es
                            vcov_es = compute_survey_vcov(X_ones_es, if_es_tsl, unit_resolved_es)
                            es_se = float(np.sqrt(np.abs(vcov_es[0, 0])))
                    else:
                        es_se = float(np.sqrt(np.sum(if_es**2)))
                    t_stat, p_val, ci_es = safe_inference(
                        info_e["effect"], es_se, self.alpha, df=_survey_df
                    )
                    info_e["se"] = es_se
                    info_e["t_stat"] = t_stat
                    info_e["p_value"] = p_val
                    info_e["conf_int"] = ci_es
    # 6. Assemble results
    dose_response_att = DoseResponseCurve(
        dose_grid=dvals,
        effects=agg_att_d,
        se=att_d_se,
        conf_int_lower=att_d_ci_lower,
        conf_int_upper=att_d_ci_upper,
        target="att",
        p_value=att_d_p,
        n_bootstrap=self.n_bootstrap,
        df_survey=_survey_df,
    )
    dose_response_acrt = DoseResponseCurve(
        dose_grid=dvals,
        effects=agg_acrt_d,
        se=acrt_d_se,
        conf_int_lower=acrt_d_ci_lower,
        conf_int_upper=acrt_d_ci_upper,
        target="acrt",
        p_value=acrt_d_p,
        n_bootstrap=self.n_bootstrap,
        df_survey=_survey_df,
    )
    # Strip bootstrap internals (underscore-prefixed keys) from gt_results
    clean_gt = {}
    for gt, r in gt_results.items():
        clean_gt[gt] = {k: v for k, v in r.items() if not k.startswith("_")}
    return ContinuousDiDResults(
        dose_response_att=dose_response_att,
        dose_response_acrt=dose_response_acrt,
        overall_att=overall_att,
        overall_att_se=overall_att_se,
        overall_att_t_stat=overall_att_t,
        overall_att_p_value=overall_att_p,
        overall_att_conf_int=overall_att_ci,
        overall_acrt=overall_acrt,
        overall_acrt_se=overall_acrt_se,
        overall_acrt_t_stat=overall_acrt_t,
        overall_acrt_p_value=overall_acrt_p,
        overall_acrt_conf_int=overall_acrt_ci,
        group_time_effects=clean_gt,
        dose_grid=dvals,
        groups=treatment_groups,
        time_periods=time_periods,
        n_obs=len(df),
        n_treated_units=int((unit_cohort > 0).sum()),
        n_control_units=n_control,
        alpha=self.alpha,
        control_group=self.control_group,
        degree=self.degree,
        num_knots=self.num_knots,
        base_period=self.base_period,
        anticipation=self.anticipation,
        n_bootstrap=self.n_bootstrap,
        bootstrap_weights=self.bootstrap_weights,
        seed=self.seed,
        rank_deficient_action=self.rank_deficient_action,
        event_study_effects=event_study_effects,
        survey_metadata=survey_metadata,
    )
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _precompute_structures(
self,
df: pd.DataFrame,
outcome: str,
unit: str,
time: str,
first_treat: str,
dose: str,
time_periods: List[Any],
survey_weights: Optional[np.ndarray] = None,
) -> Dict[str, Any]:
"""Pivot to wide format and build lookup structures."""
all_units = sorted(df[unit].unique())
unit_to_idx = {u: i for i, u in enumerate(all_units)}
n_units = len(all_units)
n_periods = len(time_periods)
period_to_col = {t: j for j, t in enumerate(time_periods)}
# Outcome matrix: (n_units, n_periods)
outcome_matrix = np.full((n_units, n_periods), np.nan)
for _, row in df.iterrows():
i = unit_to_idx[row[unit]]
j = period_to_col[row[time]]
outcome_matrix[i, j] = row[outcome]
# Per-unit cohort and dose
unit_cohorts = np.zeros(n_units, dtype=float)
dose_vector = np.zeros(n_units, dtype=float)
unit_first = df.groupby(unit).first()
for u in all_units:
i = unit_to_idx[u]
unit_cohorts[i] = unit_first.loc[u, first_treat]
dose_vector[i] = unit_first.loc[u, dose]
# Build unit-to-first-panel-row mapping (for subsetting panel-level arrays)
# This maps each unit index to the positional index of its first row in df.
unit_first_panel_row = np.zeros(n_units, dtype=int)
seen_units: set = set()
for pos_idx, (_, row) in enumerate(df.iterrows()):
u = row[unit]
if u not in seen_units:
seen_units.add(u)
unit_first_panel_row[unit_to_idx[u]] = pos_idx
# Per-unit survey weights (take first obs per unit from panel data)
unit_survey_weights = None
if survey_weights is not None:
unit_survey_weights = survey_weights[unit_first_panel_row]
# Cohort masks
cohort_masks = {}
unique_cohorts = np.unique(unit_cohorts)
for c in unique_cohorts:
cohort_masks[c] = unit_cohorts == c
never_treated_mask = unit_cohorts == 0
return {
"all_units": all_units,
"unit_to_idx": unit_to_idx,
"outcome_matrix": outcome_matrix,
"period_to_col": period_to_col,
"unit_cohorts": unit_cohorts,
"dose_vector": dose_vector,
"cohort_masks": cohort_masks,
"never_treated_mask": never_treated_mask,
"time_periods": time_periods,
"n_units": n_units,
"unit_survey_weights": unit_survey_weights,
"unit_first_panel_row": unit_first_panel_row,
}
def _compute_dose_response_gt(
self,
precomp: Dict[str, Any],
g: Any,
t: Any,
knots: np.ndarray,
degree: int,
dvals: np.ndarray,
survey_weights: Optional[np.ndarray] = None,
resolved_survey: object = None,
) -> Optional[Dict[str, Any]]:
"""Compute dose-response for a single (g,t) cell."""
period_to_col = precomp["period_to_col"]
outcome_matrix = precomp["outcome_matrix"]
unit_cohorts = precomp["unit_cohorts"]
dose_vector = precomp["dose_vector"]
never_treated_mask = precomp["never_treated_mask"]
time_periods = precomp["time_periods"]
# Base period selection
is_post = t >= g - self.anticipation
if self.base_period == "varying":
if is_post:
base_t = g - 1 - self.anticipation
else:
# Pre-treatment: use t-1
t_idx = time_periods.index(t)
if t_idx == 0:
return None # No prior period
base_t = time_periods[t_idx - 1]
else:
# Universal base period
base_t = g - 1 - self.anticipation
if base_t not in period_to_col or t not in period_to_col:
return None
col_t = period_to_col[t]
col_base = period_to_col[base_t]
# Treated units: first_treat == g and dose > 0
treated_mask = (unit_cohorts == g) & (dose_vector > 0)
n_treated = int(np.sum(treated_mask))
if n_treated == 0:
return None
# Control units
if self.control_group == "never_treated":
control_mask = never_treated_mask
else:
# Not-yet-treated: never-treated + first_treat > t
control_mask = never_treated_mask | (
(unit_cohorts > t + self.anticipation) & (unit_cohorts != g)
)
n_control = int(np.sum(control_mask))
if n_control == 0:
warnings.warn(
f"No control units for (g={g}, t={t}). Skipping.",
UserWarning,
stacklevel=3,
)
return None
# Outcome changes
delta_y_treated = (
outcome_matrix[treated_mask, col_t] - outcome_matrix[treated_mask, col_base]
)
delta_y_control = (
outcome_matrix[control_mask, col_t] - outcome_matrix[control_mask, col_base]
)
# Subset survey weights to the cell
w_treated = None
w_control = None
if survey_weights is not None:
w_treated = survey_weights[treated_mask]
w_control = survey_weights[control_mask]
# Guard against zero effective mass (e.g., after subpopulation)
if np.sum(w_treated) <= 0 or np.sum(w_control) <= 0:
return {
"att_glob": np.nan, "acrt_glob": np.nan,
"n_treated": 0, "n_control": 0,
"att_d": np.full(len(dvals), np.nan),
"acrt_d": np.full(len(dvals), np.nan),
}
# Control counterfactual (weighted mean when survey weights present)
if w_control is not None:
mu_0 = float(np.average(delta_y_control, weights=w_control))
else:
mu_0 = float(np.mean(delta_y_control))
# Demean
delta_tilde_y = delta_y_treated - mu_0
# Treated doses
treated_doses = dose_vector[treated_mask]
# B-spline OLS
Psi = bspline_design_matrix(treated_doses, knots, degree, include_intercept=True)
n_basis = Psi.shape[1]
# Check for all-same dose
if np.all(treated_doses == treated_doses[0]):
warnings.warn(
f"All treated doses identical in (g={g}, t={t}). " "ACRT(d) will be 0 everywhere.",
UserWarning,
stacklevel=3,
)
# Skip if not enough treated units for OLS (need n > K for residual df)
# When survey weights are present, use positive-weight count as
# the effective sample size — subpopulation() can zero weights
# leaving rows present but the weighted regression underidentified.
n_eff = int(np.count_nonzero(w_treated > 0)) if w_treated is not None else n_treated
if n_eff <= n_basis:
label = "positive-weight treated units" if w_treated is not None else "treated units"
warnings.warn(
f"Not enough {label} ({n_eff}) for {n_basis} basis functions "
f"in (g={g}, t={t}). Skipping cell.",
UserWarning,
stacklevel=3,
)
return None
# OLS or WLS regression
if w_treated is not None:
# WLS: apply sqrt(w) transformation
sqrt_w = np.sqrt(w_treated)
Psi_w = Psi * sqrt_w[:, np.newaxis]
delta_tilde_y_w = delta_tilde_y * sqrt_w
beta_hat, _, _ = solve_ols(
Psi_w,
delta_tilde_y_w,
return_vcov=False,
rank_deficient_action=self.rank_deficient_action,
)
# Residuals on original scale (for influence functions)
beta_pred_tmp = np.where(np.isnan(beta_hat), 0.0, beta_hat)
residuals = delta_tilde_y - Psi @ beta_pred_tmp
else:
beta_hat, residuals, _ = solve_ols(
Psi,
delta_tilde_y,
return_vcov=False,
rank_deficient_action=self.rank_deficient_action,
)
# For prediction: zero out NaN (dropped rank-deficient columns).
# solve_ols sets dropped-column coefficients to NaN (R convention);
# zeroing them produces correct predictions: ATT(d) = intercept
# (constant), ACRT(d) = 0 (derivative of intercept is 0).
beta_pred = np.where(np.isnan(beta_hat), 0.0, beta_hat)
# Evaluate ATT(d) and ACRT(d) at dvals
Psi_eval = bspline_design_matrix(dvals, knots, degree, include_intercept=True)
dPsi_eval = bspline_derivative_design_matrix(dvals, knots, degree, include_intercept=True)
att_d = Psi_eval @ beta_pred
acrt_d = dPsi_eval @ beta_pred
# Summary parameters
if w_treated is not None:
att_glob = float(np.average(delta_y_treated, weights=w_treated) - mu_0)
else:
att_glob = float(np.mean(delta_y_treated) - mu_0)
# ACRT^{glob}: plug-in average of ACRT(D_i) for treated
dPsi_treated = bspline_derivative_design_matrix(
treated_doses, knots, degree, include_intercept=True
)
if w_treated is not None:
acrt_glob = float(np.average(dPsi_treated @ beta_pred, weights=w_treated))
else:
acrt_glob = float(np.mean(dPsi_treated @ beta_pred))
# Store bootstrap info for influence function computation
# bread = (Psi'WPsi / n_treated)^{-1} when survey, (Psi'Psi / n_treated)^{-1} otherwise
if w_treated is not None:
w_treated_sum = float(np.sum(w_treated))
PtWP = Psi.T @ (Psi * w_treated[:, np.newaxis])
# Normalize bread by weighted mass (not raw count) for consistency
# with downstream IF score denominators that also use weighted mass
try:
bread = np.linalg.inv(PtWP / w_treated_sum)