-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathutils.py
More file actions
1902 lines (1593 loc) · 61.4 KB
/
utils.py
File metadata and controls
1902 lines (1593 loc) · 61.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Utility functions for difference-in-differences estimation.
"""
import functools
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from scipy import stats

from diff_diff.linalg import compute_robust_vcov as _compute_robust_vcov_linalg
from diff_diff.linalg import solve_ols as _solve_ols_linalg

# Import Rust backend if available (from _backend to avoid circular imports)
from diff_diff._backend import (
    HAS_RUST_BACKEND,
    _rust_project_simplex,
    _rust_synthetic_weights,
    _rust_sdid_unit_weights,
    _rust_compute_time_weights,
    _rust_compute_noise_level,
    _rust_sc_weight_fw,
)
# Numerical constants for optimization algorithms
_OPTIMIZATION_MAX_ITER = 1000  # Maximum iterations for weight optimization
_OPTIMIZATION_TOL = 1e-8  # Convergence tolerance for optimization
_NUMERICAL_EPS = 1e-10  # Small constant to prevent division by zero
# Cache for critical values to avoid repeated scipy calls.
# Keyed by (alpha, df); df is None when the normal approximation is used.
_critical_value_cache: Dict[Tuple[float, Optional[int]], float] = {}
def _get_critical_value(alpha: float, df: Optional[int] = None) -> float:
"""Return cached critical value for (alpha, df) pair."""
key = (alpha, df)
if key not in _critical_value_cache:
if df is not None:
_critical_value_cache[key] = float(stats.t.ppf(1 - alpha / 2, df))
else:
_critical_value_cache[key] = float(stats.norm.ppf(1 - alpha / 2))
return _critical_value_cache[key]
def validate_binary(arr: np.ndarray, name: str) -> None:
    """
    Check that every non-NaN entry of an array is 0 or 1.

    Parameters
    ----------
    arr : np.ndarray
        Array whose values are checked.
    name : str
        Variable name used in the error message.

    Raises
    ------
    ValueError
        If any non-NaN value is neither 0 nor 1.
    """
    # NaNs are ignored; only the distinct observed values are tested.
    observed = np.unique(arr[~np.isnan(arr)])
    is_valid = np.isin(observed, [0, 1])
    if not is_valid.all():
        raise ValueError(f"{name} must be binary (0 or 1). " f"Found values: {observed}")
def compute_robust_se(
    X: np.ndarray, residuals: np.ndarray, cluster_ids: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Heteroskedasticity-robust (HC1) or cluster-robust variance-covariance matrix.

    Thin backwards-compatibility wrapper delegating to the optimized
    implementation in diff_diff.linalg.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    residuals : np.ndarray
        Regression residuals of shape (n,).
    cluster_ids : np.ndarray, optional
        Cluster identifiers; when given, cluster-robust SEs are computed.

    Returns
    -------
    np.ndarray
        Variance-covariance matrix of shape (k, k).
    """
    vcov = _compute_robust_vcov_linalg(X, residuals, cluster_ids)
    return vcov
def compute_confidence_interval(
    estimate: float, se: float, alpha: float = 0.05, df: Optional[int] = None
) -> Tuple[float, float]:
    """
    Build a symmetric two-sided confidence interval around a point estimate.

    Parameters
    ----------
    estimate : float
        Point estimate.
    se : float
        Standard error.
    alpha : float
        Significance level (default 0.05 for a 95% interval).
    df : int, optional
        Degrees of freedom; the normal distribution is used when None.

    Returns
    -------
    tuple
        (lower_bound, upper_bound) of the interval.
    """
    half_width = _get_critical_value(alpha, df) * se
    return (estimate - half_width, estimate + half_width)
def compute_p_value(t_stat: float, df: Optional[int] = None, two_sided: bool = True) -> float:
    """
    P-value for a t-statistic under the t or normal distribution.

    Parameters
    ----------
    t_stat : float
        T-statistic.
    df : int, optional
        Degrees of freedom; the normal distribution is used when None.
    two_sided : bool
        If True (default), doubles the one-sided tail probability.

    Returns
    -------
    float
        P-value.
    """
    abs_stat = np.abs(t_stat)
    # Survival function gives the upper-tail probability.
    if df is None:
        tail = stats.norm.sf(abs_stat)
    else:
        tail = stats.t.sf(abs_stat, df)
    return float(2 * tail if two_sided else tail)
def safe_inference(effect, se, alpha=0.05, df=None):
    """Compute t_stat, p_value, conf_int with NaN-safe gating.

    Whenever the standard error is non-finite, zero, or negative — or the
    degrees of freedom are non-positive — every inference field is NaN so
    that no misleading statistics are reported.

    Accepts scalar inputs only (not numpy arrays); all existing call sites
    operate on scalars inside loops.

    Parameters
    ----------
    effect : float
        Point estimate (treatment effect or coefficient).
    se : float
        Standard error of the estimate.
    alpha : float, optional
        Significance level for the confidence interval (default 0.05).
    df : int, optional
        Degrees of freedom. If None, uses the normal distribution.

    Returns
    -------
    tuple
        (t_stat, p_value, (ci_lower, ci_upper)); all NaN when inference
        is undefined.
    """
    undefined = (np.nan, np.nan, (np.nan, np.nan))
    # Inference requires a strictly positive, finite SE.
    if not np.isfinite(se) or se <= 0:
        return undefined
    # Non-positive df (e.g., rank-deficient replicate design) is undefined.
    if df is not None and df <= 0:
        return undefined
    t_stat = effect / se
    return (
        t_stat,
        compute_p_value(t_stat, df=df),
        compute_confidence_interval(effect, se, alpha, df=df),
    )
def safe_inference_batch(effects, ses, alpha=0.05, df=None):
    """Vectorized batch inference for arrays of effects and SEs.

    Entries with non-finite or non-positive SEs — and every entry when
    df is non-positive — come back as NaN.

    Parameters
    ----------
    effects : np.ndarray
        Array of point estimates.
    ses : np.ndarray
        Array of standard errors.
    alpha : float, optional
        Significance level (default 0.05).
    df : int, optional
        Degrees of freedom. If None, uses the normal distribution.

    Returns
    -------
    t_stats : np.ndarray
    p_values : np.ndarray
    ci_lowers : np.ndarray
    ci_uppers : np.ndarray
    """
    effects = np.asarray(effects, dtype=float)
    ses = np.asarray(ses, dtype=float)
    size = len(effects)
    t_stats = np.full(size, np.nan)
    p_values = np.full(size, np.nan)
    ci_lowers = np.full(size, np.nan)
    ci_uppers = np.full(size, np.nan)
    # Undefined df (e.g., rank-deficient replicate design) → all NaN.
    if df is not None and df <= 0:
        return t_stats, p_values, ci_lowers, ci_uppers
    ok = np.isfinite(ses) & (ses > 0)
    if ok.any():
        t_ok = effects[ok] / ses[ok]
        t_stats[ok] = t_ok
        if df is not None:
            p_values[ok] = 2.0 * stats.t.sf(np.abs(t_ok), df)
        else:
            p_values[ok] = 2.0 * stats.norm.sf(np.abs(t_ok))
        half_width = _get_critical_value(alpha, df) * ses[ok]
        ci_lowers[ok] = effects[ok] - half_width
        ci_uppers[ok] = effects[ok] + half_width
    return t_stats, p_values, ci_lowers, ci_uppers
# =============================================================================
# Wild Cluster Bootstrap
# =============================================================================
@dataclass
class WildBootstrapResults:
    """
    Container for wild cluster bootstrap inference output.

    Attributes
    ----------
    se : float
        Bootstrap standard error of the coefficient.
    p_value : float
        Two-sided bootstrap p-value.
    t_stat_original : float
        T-statistic computed from the original (non-bootstrap) data.
    ci_lower : float
        Lower bound of the confidence interval.
    ci_upper : float
        Upper bound of the confidence interval.
    n_clusters : int
        Number of clusters in the data.
    n_bootstrap : int
        Number of bootstrap replications performed.
    weight_type : str
        Bootstrap weight distribution ("rademacher", "webb", or "mammen").
    alpha : float
        Significance level used for the confidence interval.
    bootstrap_distribution : np.ndarray, optional
        Full bootstrap coefficient distribution (when requested).

    References
    ----------
    Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
    Bootstrap-Based Improvements for Inference with Clustered Errors.
    The Review of Economics and Statistics, 90(3), 414-427.
    """

    se: float
    p_value: float
    t_stat_original: float
    ci_lower: float
    ci_upper: float
    n_clusters: int
    n_bootstrap: int
    weight_type: str
    alpha: float = 0.05
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)

    def summary(self) -> str:
        """Return a human-readable multi-line summary of the results."""
        ci_level = int((1 - self.alpha) * 100)
        rows = (
            "Wild Cluster Bootstrap Results",
            "=" * 40,
            f"Bootstrap SE: {self.se:.6f}",
            f"Bootstrap p-value: {self.p_value:.4f}",
            f"Original t-stat: {self.t_stat_original:.4f}",
            f"CI ({ci_level}%): [{self.ci_lower:.6f}, {self.ci_upper:.6f}]",
            f"Number of clusters: {self.n_clusters}",
            f"Bootstrap reps: {self.n_bootstrap}",
            f"Weight type: {self.weight_type}",
        )
        return "\n".join(rows)

    def print_summary(self) -> None:
        """Write summary() to stdout."""
        print(self.summary())
def _generate_rademacher_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
"""
Generate Rademacher weights: +1 or -1 with probability 0.5.
Parameters
----------
n_clusters : int
Number of clusters.
rng : np.random.Generator
Random number generator.
Returns
-------
np.ndarray
Array of Rademacher weights.
"""
return np.asarray(rng.choice([-1.0, 1.0], size=n_clusters))
def _generate_webb_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
"""
Generate Webb's 6-point distribution weights.
Values: {-sqrt(3/2), -sqrt(2/2), -sqrt(1/2), sqrt(1/2), sqrt(2/2), sqrt(3/2)}
with equal probabilities (1/6 each), giving E[w]=0 and Var(w)=1.0.
This distribution is recommended for very few clusters (G < 10) as it
provides better finite-sample properties than Rademacher weights.
Parameters
----------
n_clusters : int
Number of clusters.
rng : np.random.Generator
Random number generator.
Returns
-------
np.ndarray
Array of Webb weights.
References
----------
Webb, M. D. (2014). Reworking wild bootstrap based inference for
clustered errors. Queen's Economics Department Working Paper No. 1315.
Note: Uses equal probabilities (1/6 each) matching R's `did` package,
which gives unit variance for consistency with other weight distributions.
"""
values = np.array(
[
-np.sqrt(3 / 2),
-np.sqrt(2 / 2),
-np.sqrt(1 / 2),
np.sqrt(1 / 2),
np.sqrt(2 / 2),
np.sqrt(3 / 2),
]
)
# Equal probabilities (1/6 each) matching R's did package, giving Var(w) = 1.0
return np.asarray(rng.choice(values, size=n_clusters))
def _generate_mammen_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
"""
Generate Mammen's two-point distribution weights.
Values: {-(sqrt(5)-1)/2, (sqrt(5)+1)/2}
with probabilities {(sqrt(5)+1)/(2*sqrt(5)), (sqrt(5)-1)/(2*sqrt(5))}.
This distribution satisfies E[v]=0, E[v^2]=1, E[v^3]=1, which provides
asymptotic refinement for skewed error distributions.
Parameters
----------
n_clusters : int
Number of clusters.
rng : np.random.Generator
Random number generator.
Returns
-------
np.ndarray
Array of Mammen weights.
References
----------
Mammen, E. (1993). Bootstrap and Wild Bootstrap for High Dimensional
Linear Models. The Annals of Statistics, 21(1), 255-285.
"""
sqrt5 = np.sqrt(5)
# Values from Mammen (1993)
val1 = -(sqrt5 - 1) / 2 # approximately -0.618
val2 = (sqrt5 + 1) / 2 # approximately 1.618 (golden ratio)
# Probability of val1
p1 = (sqrt5 + 1) / (2 * sqrt5) # approximately 0.724
return np.asarray(rng.choice([val1, val2], size=n_clusters, p=[p1, 1 - p1]))
def wild_bootstrap_se(
    X: np.ndarray,
    y: np.ndarray,
    residuals: np.ndarray,
    cluster_ids: np.ndarray,
    coefficient_index: int,
    n_bootstrap: int = 999,
    weight_type: str = "rademacher",
    null_hypothesis: float = 0.0,
    alpha: float = 0.05,
    seed: Optional[int] = None,
    return_distribution: bool = False,
) -> WildBootstrapResults:
    """
    Compute wild cluster bootstrap standard errors and p-values.

    Implements the Wild Cluster Residual (WCR) bootstrap procedure from
    Cameron, Gelbach, and Miller (2008). Uses the restricted residuals
    approach (imposing H0: coefficient = null_hypothesis) for more accurate
    p-value computation.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    y : np.ndarray
        Outcome vector of shape (n,).
    residuals : np.ndarray
        OLS residuals from unrestricted regression, shape (n,).
        Kept for backward compatibility; the procedure re-estimates both
        the unrestricted and restricted models internally and does not
        read this argument.
    cluster_ids : np.ndarray
        Cluster identifiers of shape (n,).
    coefficient_index : int
        Index of the coefficient for which to compute bootstrap inference.
        For DiD, this is typically 3 (the treatment*post interaction term).
    n_bootstrap : int, default=999
        Number of bootstrap replications. Odd numbers are recommended for
        exact p-value computation.
    weight_type : str, default="rademacher"
        Type of bootstrap weights:
        - "rademacher": +1 or -1 with equal probability (standard choice)
        - "webb": 6-point distribution (recommended for <10 clusters)
        - "mammen": Two-point distribution with skewness correction
    null_hypothesis : float, default=0.0
        Value of the null hypothesis for p-value computation.
    alpha : float, default=0.05
        Significance level for confidence interval.
    seed : int, optional
        Random seed for reproducibility. If None (default), results
        will vary between runs.
    return_distribution : bool, default=False
        If True, include full bootstrap distribution in results.

    Returns
    -------
    WildBootstrapResults
        Dataclass containing bootstrap SE, p-value, confidence interval,
        and other inference results.

    Raises
    ------
    ValueError
        If weight_type is not recognized or if there are fewer than 2 clusters.

    Warns
    -----
    UserWarning
        If the number of clusters is less than 5, as bootstrap inference
        may be unreliable.

    References
    ----------
    Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
    Bootstrap-Based Improvements for Inference with Clustered Errors.
    The Review of Economics and Statistics, 90(3), 414-427.
    MacKinnon, J. G., & Webb, M. D. (2018). The wild bootstrap for
    few (treated) clusters. The Econometrics Journal, 21(2), 114-135.
    """
    # Validate inputs
    valid_weight_types = ["rademacher", "webb", "mammen"]
    if weight_type not in valid_weight_types:
        raise ValueError(f"weight_type must be one of {valid_weight_types}, got '{weight_type}'")
    unique_clusters = np.unique(cluster_ids)
    n_clusters = len(unique_clusters)
    if n_clusters < 2:
        raise ValueError(f"Wild cluster bootstrap requires at least 2 clusters, got {n_clusters}")
    if n_clusters < 5:
        warnings.warn(
            f"Only {n_clusters} clusters detected. Wild bootstrap inference may be "
            "unreliable with fewer than 5 clusters. Consider using Webb weights "
            "(weight_type='webb') for improved finite-sample properties.",
            UserWarning,
        )
    # Initialize RNG
    rng = np.random.default_rng(seed)
    # Select weight generator
    weight_generators = {
        "rademacher": _generate_rademacher_weights,
        "webb": _generate_webb_weights,
        "mammen": _generate_mammen_weights,
    }
    generate_weights = weight_generators[weight_type]
    # Step 1: Compute original coefficient and cluster-robust SE
    beta_hat, _, vcov_original = _solve_ols_linalg(X, y, cluster_ids=cluster_ids, return_vcov=True)
    original_coef = beta_hat[coefficient_index]
    assert vcov_original is not None
    se_original = np.sqrt(vcov_original[coefficient_index, coefficient_index])
    # NOTE(review): assumes se_original > 0; a degenerate zero-variance
    # design would make this t-statistic non-finite.
    t_stat_original = (original_coef - null_hypothesis) / se_original
    # Step 2: Impose null hypothesis (restricted estimation)
    # Create restricted y: y_restricted = y - X[:, coef_index] * null_hypothesis
    # This imposes the null that the coefficient equals null_hypothesis.
    y_restricted = y - X[:, coefficient_index] * null_hypothesis
    # Re-estimate with the restricted dependent variable; for the WCR
    # bootstrap all columns are kept and the null is imposed via residuals.
    beta_restricted, residuals_restricted, _ = _solve_ols_linalg(X, y_restricted, return_vcov=False)
    # Map each observation to its cluster's position in (sorted) unique_clusters
    # once, so every replication assigns weights with one vectorized gather
    # instead of a Python loop over clusters.
    cluster_codes = np.searchsorted(unique_clusters, cluster_ids)
    # Step 3: Bootstrap loop
    bootstrap_t_stats = np.zeros(n_bootstrap)
    bootstrap_coefs = np.zeros(n_bootstrap)
    for b in range(n_bootstrap):
        # Generate cluster-level weights and broadcast them to observations
        cluster_weights = generate_weights(n_clusters, rng)
        obs_weights = cluster_weights[cluster_codes]
        # Construct bootstrap sample: y* = X @ beta_restricted + e_restricted * weights
        y_star = np.dot(X, beta_restricted) + residuals_restricted * obs_weights
        # Estimate bootstrap coefficients with cluster-robust SE
        beta_star, _, vcov_star = _solve_ols_linalg(
            X, y_star, cluster_ids=cluster_ids, return_vcov=True
        )
        bootstrap_coefs[b] = beta_star[coefficient_index]
        assert vcov_star is not None
        se_star = np.sqrt(vcov_star[coefficient_index, coefficient_index])
        # Compute bootstrap t-statistic (under null hypothesis)
        if se_star > 0:
            bootstrap_t_stats[b] = (beta_star[coefficient_index] - null_hypothesis) / se_star
        else:
            bootstrap_t_stats[b] = 0.0
    # Step 4: Compute bootstrap p-value
    # P-value is proportion of |t*| >= |t_original|
    p_value = np.mean(np.abs(bootstrap_t_stats) >= np.abs(t_stat_original))
    # Ensure p-value is at least 1/(n_bootstrap+1) to avoid exact zero
    p_value = float(max(float(p_value), 1 / (n_bootstrap + 1)))
    # Step 5: Compute bootstrap SE and confidence interval
    # SE from standard deviation of bootstrap coefficient distribution
    se_bootstrap = float(np.std(bootstrap_coefs, ddof=1))
    # Percentile confidence interval from bootstrap distribution
    lower_percentile = alpha / 2 * 100
    upper_percentile = (1 - alpha / 2) * 100
    ci_lower = float(np.percentile(bootstrap_coefs, lower_percentile))
    ci_upper = float(np.percentile(bootstrap_coefs, upper_percentile))
    return WildBootstrapResults(
        se=se_bootstrap,
        p_value=p_value,
        t_stat_original=t_stat_original,
        ci_lower=ci_lower,
        ci_upper=ci_upper,
        n_clusters=n_clusters,
        n_bootstrap=n_bootstrap,
        weight_type=weight_type,
        alpha=alpha,
        bootstrap_distribution=bootstrap_coefs if return_distribution else None,
    )
def check_parallel_trends(
    data: pd.DataFrame,
    outcome: str,
    time: str,
    treatment_group: str,
    pre_periods: Optional[List[Any]] = None,
) -> Dict[str, Any]:
    """
    Perform a simple check for parallel trends assumption.

    This computes the trend (slope) in the outcome variable for both
    treatment and control groups during pre-treatment periods.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    outcome : str
        Name of outcome variable column.
    time : str
        Name of time period column.
    treatment_group : str
        Name of treatment group indicator column.
    pre_periods : list, optional
        List of pre-treatment time periods. If None, infers from data.

    Returns
    -------
    dict
        Dictionary with trend statistics and test results.
    """
    if pre_periods is None:
        # Assume treatment happens at median time period
        all_periods = sorted(data[time].unique())
        mid_point = len(all_periods) // 2
        pre_periods = all_periods[:mid_point]
    pre_data = data[data[time].isin(pre_periods)]
    # Compute trends for each group
    treated_data = pre_data[pre_data[treatment_group] == 1]
    control_data = pre_data[pre_data[treatment_group] == 0]

    # Simple linear regression for trends
    def compute_trend(group_data: pd.DataFrame) -> Tuple[float, float]:
        """Return (slope, se_slope) of the OLS fit of outcome on time."""
        time_values = group_data[time].values
        outcome_values = group_data[outcome].values
        n = len(time_values)
        # Guard BEFORE calling .min(): previously an empty group raised
        # ValueError from min() on a zero-length array.
        if n < 2:
            return np.nan, np.nan
        # Normalize time to start at 0
        time_norm = time_values - time_values.min()
        mean_t = np.mean(time_norm)
        mean_y = np.mean(outcome_values)
        # Check for zero variance in time (all same time period)
        time_var = np.sum((time_norm - mean_t) ** 2)
        if time_var == 0:
            return np.nan, np.nan
        slope = np.sum((time_norm - mean_t) * (outcome_values - mean_y)) / time_var
        if n < 3:
            # With exactly 2 points the fit is exact and residual df = n - 2 = 0,
            # so the SE is undefined; previously this produced a 0/0
            # RuntimeWarning and NaN. Return the same NaN without the warning.
            return slope, np.nan
        # Compute standard error of slope
        y_hat = mean_y + slope * (time_norm - mean_t)
        residuals = outcome_values - y_hat
        mse = np.sum(residuals**2) / (n - 2)
        se_slope = np.sqrt(mse / time_var)
        return slope, se_slope

    treated_slope, treated_se = compute_trend(treated_data)
    control_slope, control_se = compute_trend(control_data)
    # Test for difference in trends
    slope_diff = treated_slope - control_slope
    se_diff = np.sqrt(treated_se**2 + control_se**2)
    t_stat, p_value, _ = safe_inference(slope_diff, se_diff)
    return {
        "treated_trend": treated_slope,
        "treated_trend_se": treated_se,
        "control_trend": control_slope,
        "control_trend_se": control_se,
        "trend_difference": slope_diff,
        "trend_difference_se": se_diff,
        "t_statistic": t_stat,
        "p_value": p_value,
        "parallel_trends_plausible": p_value > 0.05 if not np.isnan(p_value) else None,
    }
def check_parallel_trends_robust(
    data: pd.DataFrame,
    outcome: str,
    time: str,
    treatment_group: str,
    unit: Optional[str] = None,
    pre_periods: Optional[List[Any]] = None,
    n_permutations: int = 1000,
    seed: Optional[int] = None,
    wasserstein_threshold: float = 0.2,
) -> Dict[str, Any]:
    """
    Perform robust parallel trends testing using distributional comparisons.

    Uses the Wasserstein (Earth Mover's) distance to compare the full
    distribution of outcome changes between treated and control groups,
    with permutation-based inference.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with repeated observations over time.
    outcome : str
        Name of outcome variable column.
    time : str
        Name of time period column.
    treatment_group : str
        Name of treatment group indicator column (0/1).
    unit : str, optional
        Name of unit identifier column. If provided, computes unit-level
        changes. Otherwise uses observation-level data.
    pre_periods : list, optional
        List of pre-treatment time periods. If None, uses first half of periods.
    n_permutations : int, default=1000
        Number of permutations for computing p-value.
    seed : int, optional
        Random seed for reproducibility.
    wasserstein_threshold : float, default=0.2
        Threshold for normalized Wasserstein distance. Values below this
        threshold (combined with p > 0.05) suggest parallel trends are plausible.

    Returns
    -------
    dict
        Dictionary containing:
        - wasserstein_distance: Wasserstein distance between group distributions
        - wasserstein_p_value: Permutation-based p-value
        - ks_statistic: Kolmogorov-Smirnov test statistic
        - ks_p_value: KS test p-value
        - mean_difference: Difference in mean changes
        - variance_ratio: Ratio of variances in changes
        - treated_changes: Array of outcome changes for treated
        - control_changes: Array of outcome changes for control
        - parallel_trends_plausible: Boolean assessment

    Examples
    --------
    >>> results = check_parallel_trends_robust(
    ...     data, outcome='sales', time='year',
    ...     treatment_group='treated', unit='firm_id'
    ... )
    >>> print(f"Wasserstein distance: {results['wasserstein_distance']:.4f}")
    >>> print(f"P-value: {results['wasserstein_p_value']:.4f}")

    Notes
    -----
    The Wasserstein distance (Earth Mover's Distance) measures the minimum
    "cost" of transforming one distribution into another. Unlike simple
    mean comparisons, it captures differences in the entire distribution
    shape, making it more robust to non-normal data and heterogeneous effects.
    A small Wasserstein distance and high p-value suggest the distributions
    of pre-treatment changes are similar, supporting the parallel trends
    assumption.
    """
    # Use local RNG to avoid affecting global random state
    rng = np.random.default_rng(seed)
    # Identify pre-treatment periods; default is the first half of all periods
    if pre_periods is None:
        all_periods = sorted(data[time].unique())
        mid_point = len(all_periods) // 2
        pre_periods = all_periods[:mid_point]
    pre_data = data[data[time].isin(pre_periods)].copy()
    # Compute outcome changes (unit-level diffs when `unit` given,
    # otherwise period-to-period changes in group means)
    treated_changes, control_changes = _compute_outcome_changes(
        pre_data, outcome, time, treatment_group, unit
    )
    # With fewer than 2 changes per group no distributional comparison
    # is possible; return an all-NaN result with an explanatory "error" key.
    if len(treated_changes) < 2 or len(control_changes) < 2:
        return {
            "wasserstein_distance": np.nan,
            "wasserstein_p_value": np.nan,
            "ks_statistic": np.nan,
            "ks_p_value": np.nan,
            "mean_difference": np.nan,
            "variance_ratio": np.nan,
            "treated_changes": treated_changes,
            "control_changes": control_changes,
            "parallel_trends_plausible": None,
            "error": "Insufficient data for comparison",
        }
    # Compute Wasserstein distance
    wasserstein_dist = stats.wasserstein_distance(treated_changes, control_changes)
    # Permutation test: reshuffle group labels to build the null distribution
    # of the Wasserstein distance
    all_changes = np.concatenate([treated_changes, control_changes])
    n_treated = len(treated_changes)
    n_total = len(all_changes)
    permuted_distances = np.zeros(n_permutations)
    for i in range(n_permutations):
        perm_idx = rng.permutation(n_total)
        perm_treated = all_changes[perm_idx[:n_treated]]
        perm_control = all_changes[perm_idx[n_treated:]]
        permuted_distances[i] = stats.wasserstein_distance(perm_treated, perm_control)
    # P-value: proportion of permuted distances >= observed
    wasserstein_p = np.mean(permuted_distances >= wasserstein_dist)
    # Kolmogorov-Smirnov test as a complementary distributional comparison
    ks_stat, ks_p = stats.ks_2samp(treated_changes, control_changes)
    # Additional summary statistics
    mean_diff = np.mean(treated_changes) - np.mean(control_changes)
    var_treated = np.var(treated_changes, ddof=1)
    var_control = np.var(control_changes, ddof=1)
    var_ratio = var_treated / var_control if var_control > 0 else np.nan
    # Normalized Wasserstein (relative to pooled std) makes the
    # threshold comparison scale-free
    pooled_std = np.std(all_changes, ddof=1)
    wasserstein_normalized = wasserstein_dist / pooled_std if pooled_std > 0 else np.nan
    # Assessment: parallel trends plausible if p-value > 0.05
    # and normalized Wasserstein is small (below threshold);
    # a NaN normalized distance falls back to the p-value alone
    plausible = bool(
        wasserstein_p > 0.05
        and (
            wasserstein_normalized < wasserstein_threshold
            if not np.isnan(wasserstein_normalized)
            else True
        )
    )
    return {
        "wasserstein_distance": wasserstein_dist,
        "wasserstein_normalized": wasserstein_normalized,
        "wasserstein_p_value": wasserstein_p,
        "ks_statistic": ks_stat,
        "ks_p_value": ks_p,
        "mean_difference": mean_diff,
        "variance_ratio": var_ratio,
        "n_treated": len(treated_changes),
        "n_control": len(control_changes),
        "treated_changes": treated_changes,
        "control_changes": control_changes,
        "parallel_trends_plausible": plausible,
    }
def _compute_outcome_changes(
data: pd.DataFrame, outcome: str, time: str, treatment_group: str, unit: Optional[str] = None
) -> Tuple[np.ndarray, np.ndarray]:
"""
Compute period-to-period outcome changes for treated and control groups.
Parameters
----------
data : pd.DataFrame
Panel data.
outcome : str
Outcome variable column.
time : str
Time period column.
treatment_group : str
Treatment group indicator column.
unit : str, optional
Unit identifier column.
Returns
-------
tuple
(treated_changes, control_changes) as numpy arrays.
"""
if unit is not None:
# Unit-level changes: compute change for each unit across periods
data_sorted = data.sort_values([unit, time])
data_sorted["_outcome_change"] = data_sorted.groupby(unit)[outcome].diff()
# Remove NaN from first period of each unit
changes_data = data_sorted.dropna(subset=["_outcome_change"])
treated_changes = changes_data[changes_data[treatment_group] == 1]["_outcome_change"].values
control_changes = changes_data[changes_data[treatment_group] == 0]["_outcome_change"].values
else:
# Aggregate changes: compute mean change per period per group
treated_data = data[data[treatment_group] == 1]
control_data = data[data[treatment_group] == 0]
# Compute period means
treated_means = treated_data.groupby(time)[outcome].mean()
control_means = control_data.groupby(time)[outcome].mean()
# Compute changes between consecutive periods
treated_changes = np.diff(treated_means.values)
control_changes = np.diff(control_means.values)
return treated_changes.astype(float), control_changes.astype(float)
def equivalence_test_trends(
data: pd.DataFrame,
outcome: str,
time: str,
treatment_group: str,
unit: Optional[str] = None,
pre_periods: Optional[List[Any]] = None,
equivalence_margin: Optional[float] = None,
) -> Dict[str, Any]:
"""
Perform equivalence testing (TOST) for parallel trends.
Tests whether the difference in trends is practically equivalent to zero
using Two One-Sided Tests (TOST) procedure.
Parameters
----------
data : pd.DataFrame
Panel data.
outcome : str
Name of outcome variable column.
time : str
Name of time period column.
treatment_group : str
Name of treatment group indicator column.
unit : str, optional
Name of unit identifier column.
pre_periods : list, optional
List of pre-treatment time periods.
equivalence_margin : float, optional
The margin for equivalence (delta). If None, uses 0.5 * pooled SD
of outcome changes as a default.
Returns
-------
dict
Dictionary containing:
- mean_difference: Difference in mean changes
- equivalence_margin: The margin used
- lower_p_value: P-value for lower bound test
- upper_p_value: P-value for upper bound test
- tost_p_value: Maximum of the two p-values
- equivalent: Boolean indicating equivalence at alpha=0.05
"""
# Get pre-treatment periods
if pre_periods is None:
all_periods = sorted(data[time].unique())
mid_point = len(all_periods) // 2
pre_periods = all_periods[:mid_point]
pre_data = data[data[time].isin(pre_periods)].copy()
# Compute outcome changes
treated_changes, control_changes = _compute_outcome_changes(
pre_data, outcome, time, treatment_group, unit
)
# Need at least 2 observations per group to compute variance
# and at least 3 total for meaningful df calculation
if len(treated_changes) < 2 or len(control_changes) < 2:
return {
"mean_difference": np.nan,
"se_difference": np.nan,
"equivalence_margin": np.nan,
"lower_t_stat": np.nan,
"upper_t_stat": np.nan,
"lower_p_value": np.nan,
"upper_p_value": np.nan,
"tost_p_value": np.nan,
"degrees_of_freedom": np.nan,
"equivalent": None,
"error": "Insufficient data (need at least 2 observations per group)",
}