-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathhad_pretests.py
More file actions
4951 lines (4560 loc) · 220 KB
/
had_pretests.py
File metadata and controls
4951 lines (4560 loc) · 220 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Pre-test diagnostics for the HeterogeneousAdoptionDiD estimator.
Paper Section 4 (de Chaisemartin, Ciccia, D'Haultfoeuille, Knau 2026,
arXiv:2405.04465v6) prescribes a four-step pre-testing workflow for TWFE
validity in HADs. This module ships the tests and the composite workflow.
The numbers in the function list below are only list indices; they do not
correspond to the paper's workflow steps.

Single-horizon tests:

1. :func:`qug_test` - order-statistic ratio test of the support infimum
   ``H_0: d_lower = 0`` (paper Theorem 4). Closed-form, tuning-free.
2. :func:`stute_test` - Cramer-von Mises cusum test of linearity of
   ``E[ΔY | D_2]`` with a Mammen (1993) wild bootstrap p-value (paper
   Appendix D).
3. :func:`yatchew_hr_test` - heteroskedasticity-robust variance-ratio
   specification test (paper Theorem 7 / Equation 29). Feasible at
   ``G >= 100k``. Two nulls via the keyword-only ``null=`` argument:
   ``"linearity"`` (default; paper Theorem 7, fits ``Y ~ 1 + D``) and
   ``"mean_independence"`` (R-parity extension mirroring R
   ``YatchewTest::yatchew_test(order=0)``; fits ``Y ~ 1``). The
   downstream variance-ratio machinery is shared between the two
   modes — only the residual definition differs.

Joint / multi-period tests (Phase 3 follow-up):

4. :func:`stute_joint_pretest` - residuals-in core that generalizes the
   single-horizon Stute CvM to K horizons with shared-η wild bootstrap
   and sum-of-CvMs aggregation (Delgado 1993; Escanciano 2006).
5. :func:`joint_pretrends_test` - data-in wrapper for the mean-
   independence null (paper step 2 pre-trends across pre-period
   placebos, Section 4.2 footnote 6 + Section 4.3 paragraph 1).
6. :func:`joint_homogeneity_test` - data-in wrapper for the linearity
   null across post-periods (paper Section 4.3 joint extension,
   page 32).

Composite workflow:

:func:`did_had_pretest_workflow` has two dispatch modes:

- ``aggregate="overall"`` (default, two-period panel): runs paper steps
  1 and 3 via :func:`qug_test` + :func:`stute_test` +
  :func:`yatchew_hr_test`. Paper step 2 is NOT run on this path (a
  two-period panel has no pre-period placebo); the verdict explicitly
  flags the Assumption 7 gap via the ``"paper step 2 deferred"`` caveat
  so callers do not get an unconditional "TWFE safe" signal.
- ``aggregate="event_study"`` (multi-period panel, >= 3 periods): runs
  QUG at ``F`` + joint pre-trends Stute across earlier pre-periods +
  joint homogeneity-linearity Stute across post-periods. Closes the
  paper step-2 gap and does NOT emit the step-2-deferred caveat in the
  verdict when at least one earlier pre-period is available. The
  step-3 alternative (Yatchew-HR linearity) is subsumed by joint Stute
  on this path; the paper does not derive a joint Yatchew variant, so
  users who need Yatchew robustness under multi-period data can call
  :func:`yatchew_hr_test` on each ``(base, post)`` pair manually.

(Step 4 in the paper's workflow is the decision itself - "use TWFE
if none of the tests rejects" - not a separate test.)

Eq. 17 / Eq. 18 linear-trend detrending (paper Section 5.2 Pierce-Schott
application) shipped in PR #389 (Phase 4 R-parity) as the
``trends_lin: bool = False`` keyword-only kwarg on
:func:`joint_pretrends_test`, :func:`joint_homogeneity_test`, AND
:meth:`HeterogeneousAdoptionDiD.fit` (event-study path). Mirrors R
``DIDHAD::did_had(..., trends_lin=TRUE)``. The survey-weighted variant is
not yet derived from the paper and raises ``NotImplementedError``;
tracked in ``TODO.md`` if user demand emerges. See
``docs/methodology/REGISTRY.md`` for the full algorithm narrative,
invariants, and deviation notes.
"""
from __future__ import annotations
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional
import numpy as np
import pandas as pd
from scipy import stats
from diff_diff.bootstrap_utils import (
apply_stratum_centering,
generate_survey_multiplier_weights_batch,
)
from diff_diff.had import (
_aggregate_first_difference,
_aggregate_unit_resolved_survey,
_aggregate_unit_weights,
_json_safe_scalar,
_validate_had_panel,
_validate_had_panel_event_study,
)
from diff_diff.survey import (
HAD_DEPRECATION_MSG_SURVEY_KWARG,
HAD_DEPRECATION_MSG_WEIGHTS_KWARG_ARRAY_IN,
HAD_DEPRECATION_MSG_WEIGHTS_KWARG_DATA_IN,
HAD_DUAL_KNOB_MUTEX_MSG_ARRAY_IN,
HAD_DUAL_KNOB_MUTEX_MSG_DATA_IN,
SurveyDesign,
make_pweight_design,
)
from diff_diff.utils import _generate_mammen_weights
__all__ = [
"QUGTestResults",
"StuteTestResults",
"YatchewTestResults",
"StuteJointResult",
"HADPretestReport",
"qug_test",
"stute_test",
"yatchew_hr_test",
"stute_joint_pretest",
"joint_pretrends_test",
"joint_homogeneity_test",
"did_had_pretest_workflow",
]
_MIN_G_QUG = 2
_MIN_G_STUTE = 10
_MIN_G_YATCHEW = 3
_MIN_N_BOOTSTRAP = 99
_STUTE_LARGE_G_THRESHOLD = 100_000
# Scale-invariant tolerance for detecting a numerically exact linear OLS fit.
# The ratio SSR / TSS = sum(eps^2) / sum((dy - dybar)^2) equals 1 - R^2
# and is BOTH TRANSLATION-INVARIANT (centering absorbs additive shifts)
# and SCALE-INVARIANT (the ratio is dimensionless under multiplicative
# rescaling of dy). Under exact Assumption 8, residuals are mathematically
# zero; in practice FP round-off leaves eps on the order of machine-epsilon
# (~1e-16). Squared that is ~1e-32. The threshold ~1e-24 leaves ~10^8
# accumulated FP operations of margin so genuinely-noisy data is never
# mis-classified.
#
# IMPORTANT: the comparison is purely ``eps^2 <= tol * dy_centered^2`` with
# NO additive floor (e.g. ``max(dy_centered^2, 1.0)`` would break scale
# invariance - scaling dy by 1e-12 would make dy_centered^2 ~ 1e-24 but
# the floor would hold the threshold at 1.0, firing the shortcut on
# noisy data that should not trigger it). The ``dy_centered^2 == 0``
# edge case (constant dy) is handled by a separate branch above the
# relative comparison, so the relative form is only applied when the
# denominator is genuinely positive.
_EXACT_LINEAR_RELATIVE_TOL = 1e-24
# =============================================================================
# Result dataclasses
# =============================================================================
@dataclass
class QUGTestResults:
    """Outcome of :func:`qug_test` (paper Theorem 4).

    The test targets ``H_0: d_lower = 0`` through the order-statistic
    ratio ``T = D_{(1)} / (D_{(2)} - D_{(1)})`` and rejects once ``T`` is
    larger than ``1/alpha - 1``. Asymptotically under the null, ``T`` is
    distributed as the ratio of two independent Exp(1) variables, whose
    CDF is ``F(t) = t / (1 + t)``; the p-value is therefore
    ``1 / (1 + T)``.

    Attributes
    ----------
    t_stat : float
        The ratio ``D_{(1)} / (D_{(2)} - D_{(1)})``. NaN if fewer than two
        non-zero observations survive, or if the two smallest doses tie.
    p_value : float
        Null p-value ``1 / (1 + t_stat)``; NaN whenever ``t_stat`` is NaN.
    reject : bool
        ``True`` exactly when ``t_stat > critical_value``; ``False`` for a
        NaN statistic.
    alpha : float
        Significance level.
    critical_value : float
        Decision threshold ``1 / alpha - 1``. Filled in even when the
        statistic is NaN so the threshold can always be inspected.
    n_obs : int
        Observation count after restricting to ``d > 0``.
    n_excluded_zero : int
        Number of zero-dose observations that were dropped.
    d_order_1 : float
        Smallest positive dose ``D_{(1)}``; NaN when ``n_obs < 2``.
    d_order_2 : float
        Second-smallest positive dose ``D_{(2)}``; NaN when ``n_obs < 2``.
    """

    t_stat: float
    p_value: float
    reject: bool
    alpha: float
    critical_value: float
    n_obs: int
    n_excluded_zero: int
    d_order_1: float
    d_order_2: float

    def __repr__(self) -> str:
        """Compact one-line representation with the headline numbers."""
        fields_text = ", ".join(
            [
                f"t_stat={self.t_stat:.4f}",
                f"p_value={self.p_value:.4f}",
                f"reject={self.reject}",
                f"alpha={self.alpha}",
                f"n_obs={self.n_obs}",
            ]
        )
        return f"QUGTestResults({fields_text})"

    def summary(self) -> str:
        """Render the result as a 64-column text table."""
        rule = "=" * 64
        # (label, right-aligned 20-wide cell) pairs, in display order.
        rows = [
            ("Statistic T:", format(self.t_stat, ">20.4f")),
            ("p-value:", format(self.p_value, ">20.4f")),
            ("Critical value (1/alpha-1):", format(self.critical_value, ">20.4f")),
            ("Reject H_0:", format(str(self.reject), ">20")),
            ("alpha:", format(self.alpha, ">20.4f")),
            ("Observations:", format(self.n_obs, ">20")),
            ("Excluded (d == 0):", format(self.n_excluded_zero, ">20")),
            ("D_(1):", format(self.d_order_1, ">20.4f")),
            ("D_(2):", format(self.d_order_2, ">20.4f")),
        ]
        table = [rule, "QUG null test (H_0: d_lower = 0)".center(64), rule]
        table.extend(f"{label:<30} {cell}" for label, cell in rows)
        table.append(rule)
        return "\n".join(table)

    def print_summary(self) -> None:
        """Write :meth:`summary` to stdout."""
        report = self.summary()
        print(report)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result into a JSON-safe ``dict``."""
        safe = _json_safe_scalar
        return dict(
            test="qug",
            t_stat=safe(self.t_stat),
            p_value=safe(self.p_value),
            reject=bool(self.reject),
            alpha=float(self.alpha),
            critical_value=safe(self.critical_value),
            n_obs=int(self.n_obs),
            n_excluded_zero=int(self.n_excluded_zero),
            d_order_1=safe(self.d_order_1),
            d_order_2=safe(self.d_order_2),
        )

    def to_dataframe(self) -> pd.DataFrame:
        """Wrap :meth:`to_dict` in a single-row DataFrame."""
        record = self.to_dict()
        return pd.DataFrame([record])
@dataclass
class StuteTestResults:
    """Outcome of :func:`stute_test` (paper Appendix D).

    The null is that ``E[ΔY | D_2]`` is linear in ``D_2`` (paper
    Assumption 8). It is rejected when the sorted-residual CvM statistic
    ``S = (1/G^2) Σ (Σ_{h=1}^g eps_{(h)})^2`` lies above the ``1 - alpha``
    quantile of its Mammen wild bootstrap distribution.

    Attributes
    ----------
    cvm_stat : float
        The CvM statistic; NaN for ``G < 10``, where the statistic is not
        well-calibrated.
    p_value : float
        Bootstrap p-value ``(1 + sum(S_b >= S)) / (B + 1)``; NaN when the
        statistic is NaN.
    reject : bool
        ``True`` exactly when ``p_value <= alpha``; ``False`` on NaN.
    alpha : float
        Significance level.
    n_bootstrap : int
        Count of Mammen wild bootstrap replications.
    n_obs : int
        Number of observations.
    seed : int or None
        Seed handed to ``np.random.default_rng``; ``None`` if unseeded.
    """

    cvm_stat: float
    p_value: float
    reject: bool
    alpha: float
    n_bootstrap: int
    n_obs: int
    seed: Optional[int]

    def __repr__(self) -> str:
        """Compact one-line representation with the headline numbers."""
        fields_text = ", ".join(
            [
                f"cvm_stat={self.cvm_stat:.4f}",
                f"p_value={self.p_value:.4f}",
                f"reject={self.reject}",
                f"alpha={self.alpha}",
                f"n_bootstrap={self.n_bootstrap}",
                f"n_obs={self.n_obs}",
            ]
        )
        return f"StuteTestResults({fields_text})"

    def summary(self) -> str:
        """Render the result as a 64-column text table."""
        rule = "=" * 64
        # (label, right-aligned 20-wide cell) pairs, in display order.
        rows = [
            ("CvM statistic:", format(self.cvm_stat, ">20.4f")),
            ("Bootstrap p-value:", format(self.p_value, ">20.4f")),
            ("Reject H_0:", format(str(self.reject), ">20")),
            ("alpha:", format(self.alpha, ">20.4f")),
            ("Bootstrap replications:", format(self.n_bootstrap, ">20")),
            ("Observations:", format(self.n_obs, ">20")),
            ("Seed:", format(str(self.seed), ">20")),
        ]
        table = [rule, "Stute CvM linearity test (H_0: linear E[dY|D])".center(64), rule]
        table.extend(f"{label:<30} {cell}" for label, cell in rows)
        table.append(rule)
        return "\n".join(table)

    def print_summary(self) -> None:
        """Write :meth:`summary` to stdout."""
        report = self.summary()
        print(report)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result into a JSON-safe ``dict``."""
        safe = _json_safe_scalar
        return dict(
            test="stute",
            cvm_stat=safe(self.cvm_stat),
            p_value=safe(self.p_value),
            reject=bool(self.reject),
            alpha=float(self.alpha),
            n_bootstrap=int(self.n_bootstrap),
            n_obs=int(self.n_obs),
            seed=None if self.seed is None else int(self.seed),
        )

    def to_dataframe(self) -> pd.DataFrame:
        """Wrap :meth:`to_dict` in a single-row DataFrame."""
        record = self.to_dict()
        return pd.DataFrame([record])
@dataclass
class YatchewTestResults:
    """Outcome of :func:`yatchew_hr_test` (paper Theorem 7 / Equation 29).

    A heteroskedasticity-robust specification test built on Yatchew's
    difference-based variance estimator. The ``null=`` argument of
    :func:`yatchew_hr_test` selects one of two nulls, echoed here in
    ``null_form``:

    * ``"linearity"`` (default; paper Theorem 7, the same null as
      :func:`stute_test`): residuals come from OLS ``dy ~ 1 + d``.
    * ``"mean_independence"`` (R-parity extension mirroring R
      ``YatchewTest::yatchew_test(order=0)``): residuals come from the
      intercept-only OLS ``dy ~ 1``.

    In both modes ``T_hr = sqrt(G) * (sigma2_lin - sigma2_diff) / sigma2_W``
    is asymptotically N(0, 1) under H_0, and rejection uses the one-sided
    standard-normal critical value. The modes differ only in the residual
    definition (and hence ``sigma2_lin``); the ``sigma2_diff`` /
    ``sigma2_W`` / sort-by-``d`` machinery is common to both.

    Attributes
    ----------
    t_stat_hr : float
        Statistic ``T_hr`` of paper Equation 29; NaN when ``G < 3``.
    p_value : float
        ``1 - Phi(T_hr)``; NaN when the statistic is NaN.
    reject : bool
        ``True`` exactly when ``T_hr >= critical_value``; ``False`` on NaN.
    alpha : float
        Significance level.
    critical_value : float
        One-sided standard-normal critical value ``z_{1 - alpha}``.
    sigma2_lin : float
        Residual variance under the selected null. For
        ``null_form="linearity"`` it is the residual variance of the OLS
        of ``dy`` on ``d``; for ``null_form="mean_independence"`` it is
        ``(1/G) * sum((dy - mean(dy))^2)``, the population variance of
        ``dy``.
    sigma2_diff : float
        Yatchew differencing variance
        ``(1 / (2G)) * sum((dy_{(g)} - dy_{(g-1)})^2)``. The divisor is the
        paper-literal ``2G``, NOT ``2(G-1)``.
    sigma2_W : float
        Heteroskedasticity-robust scale
        ``sqrt((1 / (G-1)) * sum(eps_{(g)}^2 * eps_{(g-1)}^2))``.
    n_obs : int
        Number of observations.
    null_form : str
        ``"linearity"`` (default; H_0: ``E[dY|D]`` is linear in ``D``,
        residuals from OLS ``dy ~ 1 + d``) or ``"mean_independence"``
        (H_0: ``E[dY|D] = E[dY]``, residuals from intercept-only OLS
        ``dy ~ 1``). Corresponds to the ``order`` argument of R
        ``YatchewTest::yatchew_test`` (``order=1`` ↔ ``"linearity"``;
        ``order=0`` ↔ ``"mean_independence"``).
    """

    t_stat_hr: float
    p_value: float
    reject: bool
    alpha: float
    critical_value: float
    sigma2_lin: float
    sigma2_diff: float
    sigma2_W: float
    n_obs: int
    null_form: str = "linearity"

    def __repr__(self) -> str:
        """Compact one-line representation with the headline numbers."""
        fields_text = ", ".join(
            [
                f"t_stat_hr={self.t_stat_hr:.4f}",
                f"p_value={self.p_value:.4f}",
                f"reject={self.reject}",
                f"alpha={self.alpha}",
                f"null_form={self.null_form!r}",
                f"n_obs={self.n_obs}",
            ]
        )
        return f"YatchewTestResults({fields_text})"

    def summary(self) -> str:
        """Render the result as a 64-column text table."""
        # The heading names the null; unknown null forms fall back to a
        # generic heading that echoes the raw value.
        if self.null_form == "linearity":
            heading = "Yatchew-HR linearity test (H_0: linear E[dY|D])"
        elif self.null_form == "mean_independence":
            heading = "Yatchew-HR mean-independence test (H_0: E[dY|D] = E[dY])"
        else:
            heading = f"Yatchew-HR test (null_form={self.null_form!r})"
        rule = "=" * 64
        # (label, right-aligned 20-wide cell) pairs, in display order.
        rows = [
            ("T_hr statistic:", format(self.t_stat_hr, ">20.4f")),
            ("p-value:", format(self.p_value, ">20.4f")),
            ("Critical value (1-sided z):", format(self.critical_value, ">20.4f")),
            ("Reject H_0:", format(str(self.reject), ">20")),
            ("alpha:", format(self.alpha, ">20.4f")),
            ("sigma^2_lin (OLS):", format(self.sigma2_lin, ">20.4f")),
            ("sigma^2_diff (Yatchew):", format(self.sigma2_diff, ">20.4f")),
            ("sigma^2_W (HR scale):", format(self.sigma2_W, ">20.4f")),
            ("Observations:", format(self.n_obs, ">20")),
        ]
        table = [rule, heading.center(64), rule]
        table.extend(f"{label:<30} {cell}" for label, cell in rows)
        table.append(rule)
        return "\n".join(table)

    def print_summary(self) -> None:
        """Write :meth:`summary` to stdout."""
        report = self.summary()
        print(report)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result into a JSON-safe ``dict``."""
        safe = _json_safe_scalar
        return dict(
            test="yatchew_hr",
            t_stat_hr=safe(self.t_stat_hr),
            p_value=safe(self.p_value),
            reject=bool(self.reject),
            alpha=float(self.alpha),
            critical_value=safe(self.critical_value),
            sigma2_lin=safe(self.sigma2_lin),
            sigma2_diff=safe(self.sigma2_diff),
            sigma2_W=safe(self.sigma2_W),
            n_obs=int(self.n_obs),
            null_form=str(self.null_form),
        )

    def to_dataframe(self) -> pd.DataFrame:
        """Wrap :meth:`to_dict` in a single-row DataFrame."""
        record = self.to_dict()
        return pd.DataFrame([record])
@dataclass
class StuteJointResult:
    """Result of :func:`stute_joint_pretest` (joint Cramer-von Mises across horizons).

    Aggregates the per-horizon Stute (1997) CvM statistic into a joint
    specification test: ``S_joint = sum_k S_k``, where ``S_k`` is the
    single-horizon CvM on residuals ``eps_{g,k}``. Inference is via
    Mammen (1993) wild bootstrap with a **shared** multiplier ``eta_g``
    across horizons per unit (Delgado-Manteiga 2001; Hlavka-Huskova 2020)
    to preserve the unit-level dependence structure of the vector-valued
    empirical process.

    Two nulls are supported via the thin wrappers
    :func:`joint_pretrends_test` (mean-independence: ``E[Y_t - Y_base | D]
    = mu_t``, design matrix ``[1]``) and :func:`joint_homogeneity_test`
    (linearity: ``E[Y_t - Y_base | D_t] = beta_{0,t} + beta_{fe,t} * D``,
    design matrix ``[1, D]``). Both wrappers accept a ``trends_lin:
    bool = False`` keyword-only flag (PR #392): when ``True``, applies
    paper Eq 17 / Eq 18 linear-trend detrending before the joint CvM
    using per-group slope ``Y[g, F-1] - Y[g, F-2]``.

    Attributes
    ----------
    cvm_stat_joint : float
        Joint statistic ``S_joint = sum_k S_k``. NaN on NaN-propagation.
    p_value : float
        Bootstrap p-value ``(1 + sum(S*_b >= S_joint)) / (B + 1)``. NaN
        when the statistic is NaN. ``1.0`` when the per-horizon exact-
        linear short-circuit fires (all horizons machine-exact linear).
    reject : bool
        ``True`` iff ``p_value <= alpha``. Always ``False`` on NaN.
    alpha : float
        Significance level.
    horizon_labels : list of str
        Horizon identifiers as ``str(t)`` for each period. **String
        identity only** - NOT a chronological ordering key. Callers who
        need chronological order should preserve the original period
        values alongside (a downstream plotter sorting labels
        lexicographically will misorder e.g.
        ``["2003-Q10", "2003-Q2", ...]``).
    per_horizon_stats : dict[str, float]
        ``{label: S_k}`` diagnostic. Per-horizon p-values are NOT
        exposed (decomposing the joint bootstrap into K independent
        loops is a K-fold memory/time cost; deferred). Callers who need
        per-horizon p-values can call :func:`stute_test` separately on
        each (period, residual) pair.
        On NaN-propagation (any horizon has NaN input), this dict is
        preserved with ``{label: np.nan for label in horizon_labels}``,
        NOT an empty dict, NOT a partial dict: the keys carry diagnostic
        value (which horizons were attempted), the NaN values signal
        non-propagation.
    n_bootstrap : int
        Number of wild bootstrap replications.
    n_obs : int
        Number of units ``G``.
    n_horizons : int
        Number of horizons ``K`` aggregated into the joint statistic.
    seed : int or None
        Bootstrap seed; ``None`` when unseeded.
    null_form : str
        ``"mean_independence"`` (from :func:`joint_pretrends_test`) or
        ``"linearity"`` (from :func:`joint_homogeneity_test`).
        ``"custom"`` when called directly via :func:`stute_joint_pretest`
        without a wrapper.
    exact_linear_short_circuited : bool
        ``True`` when every horizon's residual SSR to centered TSS ratio
        is below :data:`_EXACT_LINEAR_RELATIVE_TOL`; bootstrap is
        skipped and ``p_value = 1.0``. The per-horizon check ensures a
        single degenerate horizon does not collapse the joint test when
        other horizons have nontrivial residuals.
    """

    cvm_stat_joint: float
    p_value: float
    reject: bool
    alpha: float
    horizon_labels: list
    per_horizon_stats: Dict[str, float]
    n_bootstrap: int
    n_obs: int
    n_horizons: int
    seed: Optional[int]
    null_form: str
    exact_linear_short_circuited: bool

    def __repr__(self) -> str:
        return (
            f"StuteJointResult(cvm_stat_joint={self.cvm_stat_joint:.4f}, "
            f"p_value={self.p_value:.4f}, reject={self.reject}, "
            f"n_horizons={self.n_horizons}, null_form={self.null_form!r}, "
            f"n_obs={self.n_obs})"
        )

    def summary(self) -> str:
        """Formatted summary table, including one line per horizon."""
        width = 64
        per_horizon_lines = [
            f" {label:<20} {stat:>20.4f}" for label, stat in self.per_horizon_stats.items()
        ]
        # Known null forms get a descriptive label; anything else (e.g.
        # "custom") is shown verbatim.
        null_label = {
            "mean_independence": "mean-independence (pre-trends)",
            "linearity": "linearity (post-homogeneity)",
        }.get(self.null_form, self.null_form)
        lines = [
            "=" * width,
            f"Joint Stute CvM test ({null_label})".center(width),
            "=" * width,
            f"{'Joint CvM statistic:':<30} {self.cvm_stat_joint:>20.4f}",
            f"{'Bootstrap p-value:':<30} {self.p_value:>20.4f}",
            f"{'Reject H_0:':<30} {str(self.reject):>20}",
            f"{'alpha:':<30} {self.alpha:>20.4f}",
            f"{'Bootstrap replications:':<30} {self.n_bootstrap:>20}",
            f"{'Horizons:':<30} {self.n_horizons:>20}",
            f"{'Observations:':<30} {self.n_obs:>20}",
            f"{'Seed:':<30} {str(self.seed):>20}",
            f"{'Exact-linear short-circuit:':<30} " f"{str(self.exact_linear_short_circuited):>20}",
            "-" * width,
            "Per-horizon statistics:",
            *per_horizon_lines,
            "=" * width,
        ]
        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print the summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Return results as a JSON-safe dict (nested fields included)."""
        return {
            "test": "stute_joint",
            "cvm_stat_joint": _json_safe_scalar(self.cvm_stat_joint),
            "p_value": _json_safe_scalar(self.p_value),
            "reject": bool(self.reject),
            "alpha": float(self.alpha),
            "horizon_labels": [str(label) for label in self.horizon_labels],
            "per_horizon_stats": {
                str(k): _json_safe_scalar(v) for k, v in self.per_horizon_stats.items()
            },
            "n_bootstrap": int(self.n_bootstrap),
            "n_obs": int(self.n_obs),
            "n_horizons": int(self.n_horizons),
            "seed": None if self.seed is None else int(self.seed),
            "null_form": str(self.null_form),
            "exact_linear_short_circuited": bool(self.exact_linear_short_circuited),
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Return a one-row DataFrame of the scalar top-level result fields.

        Every scalar entry of :meth:`to_dict` becomes a column. The nested
        ``horizon_labels`` list and ``per_horizon_stats`` dict are left out
        because they do not fit a single flat row; use :meth:`to_dict` for
        those. ``seed`` and ``exact_linear_short_circuited`` were previously
        omitted even though they are scalar top-level fields. They are
        appended after the historical columns, so positional access to the
        earlier columns is unchanged.
        """
        record = self.to_dict()
        # Historical column order first, then the previously-missing scalars.
        columns = (
            "test",
            "cvm_stat_joint",
            "p_value",
            "reject",
            "alpha",
            "n_bootstrap",
            "n_obs",
            "n_horizons",
            "null_form",
            "seed",
            "exact_linear_short_circuited",
        )
        return pd.DataFrame([{column: record[column] for column in columns}])
@dataclass
class HADPretestReport:
"""Composite output of :func:`did_had_pretest_workflow`.
Two dispatch shapes, distinguished by :attr:`aggregate`:
``aggregate="overall"`` (default, two-period panel): bundles paper
steps 1 (QUG) and 3 (linearity via Stute + Yatchew-HR) on a
two-period first-differenced sample. Step 2 (Assumption 7 pre-trends)
is NOT implemented on this path and is explicitly flagged in the
verdict; callers must run pre-trends separately.
``aggregate="event_study"`` (multi-period panel, >= 3 periods):
bundles QUG + joint pre-trends Stute + joint homogeneity-linearity
Stute. The joint Stute variants close the paper step-2 gap; the
event-study verdict does NOT emit the "paper step 2 deferred"
caveat. Step 3 adjudication uses joint Stute only - no joint Yatchew
variant exists because the paper does not derive one; users who need
Yatchew robustness under multi-period data can run
:func:`yatchew_hr_test` on each (base, post) pair manually.
Attributes
----------
qug : QUGTestResults or None
Populated by default; ``None`` only when the workflow runs under
``survey=`` / ``weights=`` (Phase 4.5 C path), where the QUG step
is permanently skipped per Phase 4.5 C0 (extreme-value theory under
complex sampling not a settled toolkit; see :func:`qug_test`).
stute : StuteTestResults or None
Populated when ``aggregate == "overall"``; ``None`` when
``aggregate == "event_study"``.
yatchew : YatchewTestResults or None
Populated when ``aggregate == "overall"``; ``None`` when
``aggregate == "event_study"``.
pretrends_joint : StuteJointResult or None
Populated when ``aggregate == "event_study"`` and at least one
earlier pre-period exists; ``None`` on the overall path or when
only the immediate base pre-period is available.
homogeneity_joint : StuteJointResult or None
Populated when ``aggregate == "event_study"``; ``None`` on the
overall path.
all_pass : bool
On the **unweighted overall path**: same Phase 3 semantics - True
iff QUG is conclusive AND at least one of Stute/Yatchew is
conclusive AND no conclusive test rejects. On the **unweighted
event-study path**: True iff ``np.isfinite(qug.p_value)``,
``pretrends_joint is not None and
np.isfinite(pretrends_joint.p_value)``,
``np.isfinite(homogeneity_joint.p_value)``, AND none of the
three rejects. On the **survey/weights path** (Phase 4.5 C) the
QUG-conclusiveness gate is dropped (``qug=None`` per C0
deferral); the linearity-conditional rule splits by aggregate:
- ``aggregate="overall"`` survey: True iff at least one of
Stute/Yatchew is conclusive AND no conclusive test rejects.
- ``aggregate="event_study"`` survey: True iff
``pretrends_joint`` is non-None and conclusive,
``homogeneity_joint`` is conclusive, AND neither rejects.
(Both joint variants must be conclusive on the event-study
path - same step-2 + step-3 closure as the unweighted
aggregate, just without the QUG step.)
Mirrors Phase 3's ``bool(np.isfinite(p_value))`` convention - no
``.conclusive()`` helper on any result dataclass.
verdict : str
Human-readable classification. Paper rule applies symmetrically:
TWFE is admissible only if NONE of the implemented tests
rejects. Conclusive rejections are the primary verdict;
unresolved steps append as ``"; additional steps unresolved:
..."`` rather than replacing the rejection.
alpha : float
n_obs : int
Unit count. For overall: units after two-period first-difference
aggregation. For event_study: units after balanced-panel
validation and (if applicable) last-cohort auto-filter.
aggregate : str
``"overall"`` or ``"event_study"``. Determines which component
fields are populated and which branch of serialization methods
to render.
"""
qug: Optional[QUGTestResults]
stute: Optional[StuteTestResults]
yatchew: Optional[YatchewTestResults]
all_pass: bool
verdict: str
alpha: float
n_obs: int
pretrends_joint: Optional[StuteJointResult] = None
homogeneity_joint: Optional[StuteJointResult] = None
aggregate: str = "overall"
def __repr__(self) -> str:
# Preserve Phase 3 repr bit-exactly on the overall path. The
# aggregate kwarg is only surfaced on the event-study path so
# downstream consumers comparing repr strings on two-period
# reports see identical output.
if self.aggregate == "event_study":
return (
f"HADPretestReport(aggregate={self.aggregate!r}, "
f"all_pass={self.all_pass}, "
f"verdict={self.verdict!r}, n_obs={self.n_obs})"
)
return (
f"HADPretestReport(all_pass={self.all_pass}, "
f"verdict={self.verdict!r}, n_obs={self.n_obs})"
)
def summary(self) -> str:
"""Formatted summary of all tests and the verdict."""
width = 72
# Preserve Phase 3 summary bit-exactly on the overall path. The
# `aggregate: ...` header line is only rendered on the event-
# study path; two-period reports produce the Phase 3 layout.
# QUG block: rendered when self.qug is populated, else a skip note
# (Phase 4.5 C survey/weights path leaves qug=None; see C0 deferral).
qug_block = (
self.qug.summary()
if self.qug is not None
else "(QUG step skipped - permanently deferred under survey/weights per Phase 4.5 C0)"
)
if self.aggregate == "event_study":
header = [
"=" * width,
"HAD pre-test workflow".center(width),
f"aggregate: {self.aggregate}".center(width),
"=" * width,
qug_block,
"",
]
if self.pretrends_joint is not None:
body = [self.pretrends_joint.summary(), ""]
else:
body = [
"(joint pre-trends skipped - no earlier pre-period)",
"",
]
if self.homogeneity_joint is not None:
body += [self.homogeneity_joint.summary(), ""]
else:
# aggregate == "overall" - Phase 3 layout preserved when qug is
# not None (unweighted path); QUG-skip block on the survey path.
header = [
"=" * width,
"HAD pre-test workflow".center(width),
"=" * width,
qug_block,
"",
]
body = []
if self.stute is not None:
body += [self.stute.summary(), ""]
if self.yatchew is not None:
body += [self.yatchew.summary(), ""]
footer = [
"=" * width,
f"{'All pass:':<30} {str(self.all_pass):>40}",
f"Verdict: {self.verdict}",
"=" * width,
]
return "\n".join(header + body + footer)
def print_summary(self) -> None:
"""Print the summary to stdout."""
print(self.summary())
def to_dict(self) -> Dict[str, Any]:
    """Serialize the report into a nested dict.

    Each sub-result is serialized through its own ``to_dict()``; a
    missing sub-result becomes ``None``. Scalars are coerced to builtin
    ``bool`` / ``str`` / ``float`` / ``int``.

    Key layout depends on ``aggregate``:

    - ``"overall"``: ``qug, stute, yatchew, all_pass, verdict, alpha,
      n_obs``. This is the Phase 3 schema, with no ``aggregate`` key.
    - ``"event_study"``: ``aggregate, qug, pretrends_joint,
      homogeneity_joint, all_pass, verdict, alpha, n_obs``. The
      ``stute`` / ``yatchew`` keys are not emitted at all.
    """

    def _as_dict_or_none(component):
        # Optional sub-result -> its dict form, or None when absent
        # (e.g. qug on the survey/weights path).
        return None if component is None else component.to_dict()

    qug_payload = _as_dict_or_none(self.qug)
    if self.aggregate == "event_study":
        head = {
            "aggregate": str(self.aggregate),
            "qug": qug_payload,
            "pretrends_joint": _as_dict_or_none(self.pretrends_joint),
            "homogeneity_joint": _as_dict_or_none(self.homogeneity_joint),
        }
    else:
        head = {
            "qug": qug_payload,
            "stute": _as_dict_or_none(self.stute),
            "yatchew": _as_dict_or_none(self.yatchew),
        }
    trailer = {
        "all_pass": bool(self.all_pass),
        "verdict": str(self.verdict),
        "alpha": float(self.alpha),
        "n_obs": int(self.n_obs),
    }
    return {**head, **trailer}
def to_dataframe(self) -> pd.DataFrame:
    """Tabulate the report as a tidy 3-row DataFrame, one row per test.

    The columns are the same for both aggregates, in this order:
    ``test, statistic_name, statistic_value, p_value, reject, alpha,
    n_obs``.

    The row labels depend on ``aggregate``:

    - ``"overall"``: ``qug``, ``stute``, ``yatchew_hr``. This is the
      Phase 3 schema.
    - ``"event_study"``: ``qug``, ``pretrends_joint``,
      ``homogeneity_joint``. The last two are built by
      ``_joint_row_or_nan``.

    A component that is ``None`` still yields a row, with NaN statistic
    and p-value, ``reject=False``, and the report-level ``alpha`` /
    ``n_obs``. This keeps the shape at exactly three rows. Examples are
    qug on the survey/weights path, or pretrends_joint with no earlier
    pre-period.
    """
    columns = [
        "test",
        "statistic_name",
        "statistic_value",
        "p_value",
        "reject",
        "alpha",
        "n_obs",
    ]

    def _row(label, stat_name, component, stat_attr):
        # Placeholder row for an absent component: statistics are NaN and
        # alpha / n_obs fall back to the report-level values.
        if component is None:
            return {
                "test": label,
                "statistic_name": stat_name,
                "statistic_value": float("nan"),
                "p_value": float("nan"),
                "reject": False,
                "alpha": float(self.alpha),
                "n_obs": int(self.n_obs),
            }
        return {
            "test": label,
            "statistic_name": stat_name,
            "statistic_value": _json_safe_scalar(getattr(component, stat_attr)),
            "p_value": _json_safe_scalar(component.p_value),
            "reject": bool(component.reject),
            "alpha": float(component.alpha),
            "n_obs": int(component.n_obs),
        }

    records = [_row("qug", "t_stat", self.qug, "t_stat")]
    if self.aggregate == "event_study":
        records.append(self._joint_row_or_nan("pretrends_joint", self.pretrends_joint))
        records.append(self._joint_row_or_nan("homogeneity_joint", self.homogeneity_joint))
    else:
        records.append(_row("stute", "cvm_stat", self.stute, "cvm_stat"))
        records.append(_row("yatchew_hr", "t_stat_hr", self.yatchew, "t_stat_hr"))
    return pd.DataFrame(records).reindex(columns=columns)
def _joint_row_or_nan(
    self, test_label: str, joint: Optional[StuteJointResult]
) -> Dict[str, Any]:
    """Build one ``to_dataframe`` row for a joint-Stute component.

    The row is labelled ``test_label`` with statistic name
    ``cvm_stat_joint``. When ``joint`` is ``None`` (e.g. pretrends_joint
    skipped for lack of an earlier pre-period), the row is a placeholder.
    It has NaN statistic and p-value, ``reject=False``, and the
    report-level ``alpha`` / ``n_obs``, so the frame keeps its 3-row
    shape.
    """
    row = {"test": test_label, "statistic_name": "cvm_stat_joint"}
    if joint is None:
        row.update(
            statistic_value=float("nan"),
            p_value=float("nan"),
            reject=False,
            alpha=float(self.alpha),
            n_obs=int(self.n_obs),
        )
    else:
        row.update(
            statistic_value=_json_safe_scalar(joint.cvm_stat_joint),
            p_value=_json_safe_scalar(joint.p_value),
            reject=bool(joint.reject),
            alpha=float(joint.alpha),
            n_obs=int(joint.n_obs),
        )
    return row
# =============================================================================
# Private helpers
# =============================================================================
def _validate_1d_numeric(arr: np.ndarray, name: str) -> np.ndarray:
"""Return ``arr`` as a 1D float ndarray or raise ``ValueError``."""
a = np.asarray(arr)
if a.ndim != 1:
raise ValueError(f"{name} must be 1-dimensional, got shape {a.shape}.")
a = a.astype(np.float64, copy=False)
if np.isnan(a).any():
raise ValueError(f"{name} contains NaN values.")
if not np.isfinite(a).all():
raise ValueError(f"{name} contains non-finite values (inf).")
return a
def _fit_ols_intercept_slope(d: np.ndarray, dy: np.ndarray) -> "tuple[float, float, np.ndarray]":
"""Fit ``dy = a + b*d + eps`` via closed-form OLS.
Returns ``(a_hat, b_hat, residuals)`` where ``residuals`` has the
same length as ``d`` in the ORIGINAL input order (not sorted).
"""
d_mean = d.mean()
dy_mean = dy.mean()
d_dev = d - d_mean
var_d = np.dot(d_dev, d_dev)
if var_d <= 0.0:
# Degenerate case: all dose values equal. Slope undefined.
# Caller is responsible for gating before we reach here; if we
# do reach here, return (mean(dy), 0, dy - mean(dy)).
return float(dy_mean), 0.0, dy - dy_mean
b_hat = float(np.dot(d_dev, dy - dy_mean) / var_d)
a_hat = float(dy_mean - b_hat * d_mean)
residuals = dy - a_hat - b_hat * d
return a_hat, b_hat, residuals
def _fit_weighted_ols_intercept_slope(
d: np.ndarray, dy: np.ndarray, w: np.ndarray
) -> "tuple[float, float, np.ndarray]":
"""Weighted OLS analog of :func:`_fit_ols_intercept_slope`.
Solves the weighted normal equations for ``dy = a + b*d + eps`` where
each observation has weight ``w_g``. Returns ``(a_hat, b_hat,
residuals)`` with ``residuals`` in the ORIGINAL input order (not
sorted) and on the un-weighted scale (``residuals = dy - a_hat - b_hat * d``,