"""
causers - High-performance statistical operations for Polars DataFrames
A Python package with Rust backend for fast statistical computations
on Polars DataFrames.
"""
# ============================================================
# IMPORTS
# ============================================================
from typing import List as _List, Optional as _Optional, Union as _Union
import math as _math
import warnings as _warnings
import polars as _pl
from polars.exceptions import ColumnNotFoundError as _ColumnNotFoundError
# ============================================================
# CONSTANTS
# ============================================================
# Thresholds that drive the diagnostic warnings emitted by the Python wrappers.
# NOTE(review): only the first two are consumed in the part of the module
# reviewed here; the remaining three are presumably used by estimators defined
# further down — confirm.
_CLUSTER_BALANCE_THRESHOLD = 0.5  # Warn if a single cluster holds >50% of observations
_SMALL_CLUSTER_THRESHOLD = 42  # Recommend bootstrap when the number of clusters is < 42
_WEIGHT_CONCENTRATION_THRESHOLD = 0.5  # Warn if a single unit/period weight is > 50%
_MIN_RELIABLE_BOOTSTRAP_ITERATIONS = 100  # Minimum iterations for reliable SE estimates
_POOR_FIT_RMSE_RATIO = 0.1  # Warn if RMSE is > 10% of the outcome's standard deviation
# ============================================================
# EXPORTS
# ============================================================
__version__ = "0.8.0"
# Import the Rust extension module
from causers._causers import (
LinearRegressionResult,
LogisticRegressionResult,
SyntheticDIDResult,
SyntheticControlResult,
DMLResult,
TwoStageLSResult,
linear_regression as _linear_regression_rust,
logistic_regression as _logistic_regression_rust,
synthetic_did_impl as _synthetic_did_impl,
synthetic_control_impl as _synthetic_control_impl,
dml_impl as _dml_impl,
two_stage_least_squares as _two_stage_least_squares_rust,
balance_check_impl as _balance_check_impl,
BalanceResult as _BalanceResult,
)
# Public API of the package: the names exported by ``from causers import *``.
__all__ = [
# Version
"__version__",
# Result classes (imported above from the Rust extension module)
"LinearRegressionResult",
"LogisticRegressionResult",
"SyntheticDIDResult",
"SyntheticControlResult",
"DMLResult",
"TwoStageLSResult",
# Estimators (Python wrappers around the Rust implementations)
"linear_regression",
"logistic_regression",
"synthetic_did",
"synthetic_control",
"dml",
"two_stage_least_squares",
# Balance checking
# NOTE(review): neither name below is defined in the portion of the file
# reviewed here (only ``_balance_check_impl`` and ``_BalanceResult`` are
# imported) — confirm both are defined further down in this module.
"balance_check",
"BalanceCheckResult",
# Utilities
"about",
]
# ============================================================
# PRIVATE HELPER FUNCTIONS
# ============================================================
def _convert_dataframe_if_pandas(
    df: _Union[_pl.DataFrame, "pd.DataFrame"],
    required_cols: _List[str],
    int_columns: _Optional[_List[str]] = None,
) -> _pl.DataFrame:
    """Return ``df`` as a Polars DataFrame, converting pandas input when needed.

    Input that is not detected as pandas is handed back untouched. A pandas
    DataFrame is first validated against ``required_cols`` and then converted.

    Parameters
    ----------
    df : Union[pl.DataFrame, pd.DataFrame]
        Input DataFrame.
    required_cols : List[str]
        Columns that must be present.
    int_columns : Optional[List[str]]
        Columns to treat as integer type during conversion.

    Returns
    -------
    pl.DataFrame
        Polars DataFrame (converted if the input was pandas).
    """
    # Function-scope import of the compatibility helpers.
    # NOTE(review): presumably deferred so pandas support stays optional — confirm.
    from causers._pandas_compat import (
        detect_dataframe_type,
        validate_pandas_dataframe,
        convert_pandas_to_polars,
    )

    # Anything that is not pandas passes straight through.
    if detect_dataframe_type(df) != "pandas":
        return df

    validate_pandas_dataframe(df, required_cols)
    return convert_pandas_to_polars(df, required_cols, int_columns=int_columns)
def _warn_if_float_cluster(df: _pl.DataFrame, cluster: _Optional[str]) -> None:
    """Warn when the cluster column has a floating-point dtype.

    Does nothing when ``cluster`` is None or when the column's dtype cannot be
    looked up; a missing column is left for the Rust layer to report.

    Parameters
    ----------
    df : pl.DataFrame
        Data holding the cluster column.
    cluster : Optional[str]
        Name of the cluster column, or None when no clustering is requested.
    """
    if cluster is None:
        return

    try:
        cluster_is_float = df[cluster].dtype in (_pl.Float32, _pl.Float64)
    except (KeyError, AttributeError, _ColumnNotFoundError):
        # Let Rust layer handle column not found errors
        return

    if cluster_is_float:
        # stacklevel=3 attributes the warning to the caller of the public
        # estimator that invoked this helper.
        _warnings.warn(
            f"Cluster column '{cluster}' is float; will be cast to string for grouping.",
            UserWarning,
            stacklevel=3,
        )
def _check_cluster_balance(
    df: _pl.DataFrame, cluster_col: str
) -> _Optional[str]:
    """
    Check whether a single cluster dominates the sample.

    A cluster is flagged when its observation count exceeds
    ``int(len(df) * _CLUSTER_BALANCE_THRESHOLD)``.

    Parameters
    ----------
    df : pl.DataFrame
        Data holding the cluster column.
    cluster_col : str
        Name of the cluster column.

    Returns
    -------
    Optional[str]
        None if the clusters are balanced or the check cannot be performed;
        otherwise the warning message for the first cluster found above the
        threshold.
    """
    try:
        # Count under a neutral series name: ``value_counts()`` appends a count
        # column and polars rejects the call with a DuplicateError when the
        # series itself already carries that name (e.g. a cluster column called
        # "count"). That error is not caught below, so without the alias this
        # purely diagnostic check could abort an otherwise successful fit.
        value_counts = df[cluster_col].alias("__causers_cluster__").value_counts()
        total = len(df)
        threshold = int(total * _CLUSTER_BALANCE_THRESHOLD)
        # Each row is (cluster value, number of observations).
        for cluster_val, count in value_counts.iter_rows():
            if count > threshold:
                # Floor percentage; total > 0 here because a row exists.
                pct = (count * 100) // total
                return (
                    f"Cluster '{cluster_val}' contains {pct}% of observations ({count}/{total}). "
                    f"Clustered standard errors may be unreliable with such imbalanced clusters."
                )
        return None
    except (KeyError, AttributeError, _ColumnNotFoundError):
        # If we can't check balance, don't warn
        return None
# ============================================================
# UTILITIES & METADATA
# ============================================================
[docs]
def about() -> None:
    """Print information about the causers package.

    Writes the package version, a short description and the feature list to
    standard output, one ``print`` call per line.
    """
    info_lines = (
        f"causers version {__version__}",
        "High-performance statistical operations for Polars DataFrames",
        "Powered by Rust via PyO3/maturin",
        "",
        "Features:",
        " - Linear regression with HC3 robust standard errors",
        " - Fixed effects estimation (one-way and two-way)",
        " - Logistic regression with Newton-Raphson MLE",
        " - Cluster-robust standard errors (analytical and bootstrap)",
        " - Wild cluster bootstrap for small cluster counts (linear)",
        " - Score bootstrap for small cluster counts (logistic)",
        " - Synthetic Difference-in-Differences (SDID)",
        " - Synthetic Control (SC) with multiple method variants",
        " - Double Machine Learning (DML) for causal inference",
    )
    for text in info_lines:
        print(text)
# ============================================================
# CORE ESTIMATORS
# ============================================================
[docs]
def linear_regression(
    df: _Union[_pl.DataFrame, "pd.DataFrame"],
    x_cols: _Union[str, _List[str]],
    y_col: str,
    include_intercept: bool = True,
    cluster: _Optional[str] = None,
    bootstrap: bool = False,
    bootstrap_iterations: int = 1000,
    seed: _Optional[int] = None,
    bootstrap_method: str = "rademacher",
    fixed_effects: _Optional[_Union[str, _List[str]]] = None,
) -> LinearRegressionResult:
    """
    Fit an ordinary least squares (OLS) regression on DataFrame columns.

    Accepts Polars or pandas input and one or several covariates. For
    multiple covariates the estimate is obtained with matrix operations:

        β = (X'X)^-1 X'y

    Fixed effects (entity/time) can optionally be absorbed through
    within-transformation (demeaning) before the regression is run. When
    fixed effects are absorbed, the intercept is implicitly absorbed as well
    and is not returned.

    Parameters
    ----------
    df : pl.DataFrame or pd.DataFrame
        The DataFrame containing the data. Both Polars and pandas are
        accepted. For pandas DataFrames with Arrow-backed columns
        (pd.ArrowDtype), data is extracted via zero-copy where possible.
    x_cols : str or List[str]
        Name(s) of the independent variable column(s). Either:

        - a single column name as a string, or
        - a list of column names for multiple covariates.
    y_col : str
        Name of the dependent variable column.
    include_intercept : bool, default=True
        Whether to include an intercept term. If False, the regression line
        is forced through the origin. Note: when ``fixed_effects`` is
        specified the intercept is implicitly absorbed regardless of this
        setting.
    cluster : str, optional
        Column name for cluster identifiers. When given, cluster-robust
        standard errors are computed instead of HC3. Integer, string and
        categorical columns are supported.
    bootstrap : bool, default=False
        If True and ``cluster`` is specified, use wild cluster bootstrap for
        the standard errors. Requires ``cluster``. Recommended when the
        number of clusters is less than 42.
    bootstrap_iterations : int, default=1000
        Number of bootstrap replications when ``bootstrap=True``.
    seed : int, optional
        Random seed for reproducible bootstrap results. When None, a random
        seed is used, so results may differ between calls.
    bootstrap_method : str, default "rademacher"
        Weight distribution for the wild bootstrap (matched
        case-insensitively). Options:

        - "rademacher": standard Rademacher weights (±1 with equal probability)
        - "webb": Webb's 6-point distribution for improved small-sample
          performance

        Only used when ``bootstrap=True`` and ``cluster`` is specified.
    fixed_effects : str or List[str], optional
        Column name(s) for fixed effects to absorb. Supports 1 or 2 columns:

        - One column: one-way fixed effects (e.g., entity or time)
        - Two columns: two-way fixed effects (e.g., entity + time)

        Fixed effect columns must not overlap with ``x_cols`` or ``y_col``.
        Columns can be integer, string, or categorical type.

    Returns
    -------
    LinearRegressionResult
        Result object with the following attributes:

        - coefficients : List[float]
            Regression coefficients for each x variable
        - intercept : float or None
            Intercept term (None if include_intercept=False or fixed_effects
            used)
        - r_squared : float
            Coefficient of determination (R²) using original y
        - n_samples : int
            Number of samples used in the regression
        - slope : float or None
            For single covariate only, same as coefficients[0]
        - standard_errors : List[float]
            Robust standard errors for each coefficient. HC3 by default, or
            cluster-robust SE if cluster is specified.
        - intercept_se : float or None
            Robust standard error for intercept (None if
            include_intercept=False or fixed_effects used)
        - n_clusters : int or None
            Number of unique clusters (None if cluster not specified)
        - cluster_se_type : str or None
            Type of clustered SE: "analytical", "bootstrap_rademacher", or
            "bootstrap_webb" (None if not clustered)
        - bootstrap_iterations_used : int or None
            Number of bootstrap iterations (None if not bootstrap)
        - fixed_effects_absorbed : List[int] or None
            Number of groups absorbed for each fixed effect (None if no FE)
        - fixed_effects_names : List[str] or None
            Names of the fixed effect columns absorbed (None if no FE)
        - within_r_squared : float or None
            Within R² computed on demeaned data (None if no FE)

    Raises
    ------
    ValueError
        - If x_cols is empty or columns don't exist
        - If cluster column contains null values
        - If bootstrap=True without cluster specified
        - If fewer than 2 clusters detected
        - If single-observation clusters exist (analytical mode only)
        - If numerical instability detected (condition number > 1e10)
        - If bootstrap_iterations < 1
        - If fixed_effects has more than 2 columns
        - If fixed_effects column overlaps with x_cols or y_col
        - If fixed_effects column contains null values
        - If fixed_effects column has only one unique value
        - If covariate becomes collinear after FE demeaning

    Warns
    -----
    UserWarning
        - When fewer than 42 clusters with bootstrap=False: recommends using
          wild cluster bootstrap for more accurate inference.
        - When cluster column has float dtype: implicit cast to string.
        - When fixed_effects column has float dtype: implicit cast warning.
        - When a single cluster holds more than half of the observations:
          clustered standard errors may be unreliable.

    Examples
    --------
    Single covariate regression:

    >>> import polars as pl
    >>> import causers
    >>> df = pl.DataFrame({"x": [1, 2, 3, 4, 5], "y": [2, 4, 6, 8, 10]})
    >>> result = causers.linear_regression(df, "x", "y")
    >>> print(f"y = {result.slope:.2f}x + {result.intercept:.2f}")
    y = 2.00x + 0.00

    Accessing standard errors:

    >>> df = pl.DataFrame({"x": [1, 2, 3, 4, 5], "y": [2.1, 3.9, 6.2, 7.8, 10.1]})
    >>> result = causers.linear_regression(df, "x", "y")
    >>> print(f"Coefficient: {result.coefficients[0]:.4f} ± {result.standard_errors[0]:.4f}")
    Coefficient: 1.9900 ± 0.0682
    >>> print(f"Intercept: {result.intercept:.4f} ± {result.intercept_se:.4f}")
    Intercept: 0.0500 ± 0.1896

    Multiple covariate regression:

    >>> df = pl.DataFrame({
    ...     "x1": [1, 2, 3, 4, 5],
    ...     "x2": [1, 1, 2, 2, 3],
    ...     "y": [6, 8, 13, 15, 20]
    ... })
    >>> result = causers.linear_regression(df, ["x1", "x2"], "y")
    >>> print(f"Coefficients: {result.coefficients}")
    Coefficients: [2.0, 3.0]

    Clustered standard errors (analytical):

    >>> df = pl.DataFrame({
    ...     "x": [1, 2, 3, 4, 5, 6],
    ...     "y": [2, 4, 5, 8, 9, 12],
    ...     "firm_id": [1, 1, 2, 2, 3, 3]
    ... })
    >>> result = causers.linear_regression(df, "x", "y", cluster="firm_id")
    >>> print(f"Clustered SE: {result.standard_errors[0]:.4f} (G={result.n_clusters})")
    Clustered SE: ... (G=3)

    Wild cluster bootstrap (recommended for <42 clusters):

    >>> result = causers.linear_regression(
    ...     df, "x", "y",
    ...     cluster="firm_id", bootstrap=True, seed=42
    ... )
    >>> print(f"Bootstrap SE: {result.standard_errors[0]:.4f}")
    Bootstrap SE: ...

    Notes
    -----
    Standard errors are computed using:

    - **HC3 (default)**: Heteroskedasticity-consistent standard errors when
      no cluster is specified. Provides robust inference when error variance
      may not be constant (MacKinnon & White, 1985).
    - **Analytical clustered SE**: When cluster is specified and
      bootstrap=False. Uses the sandwich estimator with small-sample
      adjustment (G/(G-1) × (N-1)/(N-k)). Accounts for within-cluster
      correlation.
    - **Wild cluster bootstrap SE**: When cluster and bootstrap=True. Uses
      Rademacher weights (±1 with equal probability) by default, or Webb
      weights when ``bootstrap_method="webb"``, and is recommended when the
      number of clusters is small (G < 42).

    The 42-cluster threshold is based on asymptotic theory and simulation
    evidence that analytical clustered SE can be unreliable with few clusters.

    References
    ----------
    Cameron, A. C., & Miller, D. L. (2015). A Practitioner's Guide to
    Cluster-Robust Inference. Journal of Human Resources, 50(2), 317-372.

    MacKinnon, J. G., & Webb, M. D. (2018). The wild bootstrap for few
    (treated) clusters. The Econometrics Journal, 21(2), 114-135.

    See Also
    --------
    LinearRegressionResult : Result class with coefficient estimates and diagnostics.
    """
    # A bare string denotes a single covariate; work with a list either way.
    x_cols_list = [x_cols] if isinstance(x_cols, str) else list(x_cols)

    # Same normalisation for the fixed-effect columns; None means "no FE".
    fe_cols_list: _Optional[_List[str]]
    if fixed_effects is None:
        fe_cols_list = None
    elif isinstance(fixed_effects, str):
        fe_cols_list = [fixed_effects]
    else:
        fe_cols_list = list(fixed_effects)

    # Collect every referenced column so pandas input can be validated and
    # converted in one pass.
    needed_cols = [*x_cols_list, y_col]
    if cluster:
        needed_cols.append(cluster)
    if fe_cols_list:
        needed_cols.extend(fe_cols_list)
    df = _convert_dataframe_if_pandas(df, needed_cols)

    # Float-typed grouping columns get an explicit heads-up about the cast.
    _warn_if_float_cluster(df, cluster)
    for fe_col in fe_cols_list or ():
        try:
            fe_is_float = df[fe_col].dtype in (_pl.Float32, _pl.Float64)
        except (KeyError, AttributeError, _ColumnNotFoundError):
            # Let the Rust layer handle column not found errors
            continue
        if fe_is_float:
            _warnings.warn(
                f"Fixed effect column '{fe_col}' is float; will be cast to integer for grouping.",
                UserWarning,
                stacklevel=2,
            )

    # Delegate the estimation to the Rust implementation; the bootstrap method
    # is lower-cased so that matching is case-insensitive.
    result = _linear_regression_rust(
        df,
        x_cols_list,
        y_col,
        include_intercept,
        cluster,
        bootstrap,
        bootstrap_iterations,
        seed,
        bootstrap_method.lower(),
        fe_cols_list,
    )

    # Analytical clustered SEs with few clusters: point the user to bootstrap.
    n_clusters = result.n_clusters
    if (
        n_clusters is not None
        and not bootstrap
        and n_clusters < _SMALL_CLUSTER_THRESHOLD
    ):
        _warnings.warn(
            f"Only {n_clusters} clusters detected. Wild cluster bootstrap "
            f"(bootstrap=True) is recommended when clusters < {_SMALL_CLUSTER_THRESHOLD}.",
            UserWarning,
            stacklevel=2,
        )

    # Flag a dominant cluster, if any.
    imbalance_msg = None if cluster is None else _check_cluster_balance(df, cluster)
    if imbalance_msg is not None:
        _warnings.warn(imbalance_msg, UserWarning, stacklevel=2)

    return result
[docs]
def logistic_regression(
    df: _Union[_pl.DataFrame, "pd.DataFrame"],
    x_cols: _Union[str, _List[str]],
    y_col: str,
    include_intercept: bool = True,
    cluster: _Optional[str] = None,
    bootstrap: bool = False,
    bootstrap_iterations: int = 1000,
    seed: _Optional[int] = None,
    bootstrap_method: str = "rademacher",
    fixed_effects: _Optional[_Union[str, _List[str]]] = None,
) -> LogisticRegressionResult:
    """
    Fit a logistic regression on a binary outcome with robust standard errors.

    The model is estimated by Maximum Likelihood (MLE) using Newton-Raphson
    optimization. The result carries coefficient estimates (log-odds), robust
    standard errors, and diagnostic information.

    Fixed effects can optionally be absorbed with the Mundlak (1978)
    approach, which adds group means of the covariates as additional
    regressors. This is suitable for nonlinear models where standard
    within-transformation is not valid.

    Parameters
    ----------
    df : pl.DataFrame or pd.DataFrame
        The DataFrame containing the data. Both Polars and pandas are
        accepted. For pandas DataFrames with Arrow-backed columns
        (pd.ArrowDtype), data is extracted via zero-copy where possible.
    x_cols : str or List[str]
        Name(s) of the independent variable column(s). Either:

        - a single column name as a string, or
        - a list of column names for multiple covariates.
    y_col : str
        Name of the binary outcome column (must contain only 0 and 1).
    include_intercept : bool, default=True
        Whether to include an intercept term in the regression.
    cluster : str, optional
        Column name for cluster identifiers. When given, cluster-robust
        standard errors are computed using the score-based approach.
        Integer, string and categorical columns are supported.
    bootstrap : bool, default=False
        If True and ``cluster`` is specified, use score bootstrap for the
        standard errors. Requires ``cluster``. Recommended when the number
        of clusters is less than 42.
    bootstrap_iterations : int, default=1000
        Number of bootstrap replications when ``bootstrap=True``.
    seed : int, optional
        Random seed for reproducible bootstrap results. When None, a random
        seed is used, so results may differ between calls.
    bootstrap_method : str, default "rademacher"
        Weight distribution for the score bootstrap (matched
        case-insensitively). Options:

        - "rademacher": standard Rademacher weights (±1 with equal probability)
        - "webb": Webb's 6-point distribution for improved small-sample
          performance

        Only used when ``bootstrap=True`` and ``cluster`` is specified.
    fixed_effects : str or List[str], optional
        Column name(s) for fixed effects to absorb using the Mundlak
        approach. Supports 1 or 2 columns:

        - One column: one-way fixed effects (e.g., entity or time)
        - Two columns: two-way fixed effects (e.g., entity + time)

        The Mundlak approach adds group means of covariates as additional
        regressors, which is appropriate for nonlinear models like logistic
        regression. Fixed effect columns must not overlap with ``x_cols`` or
        ``y_col``. Columns can be integer, string, or categorical type.

    Returns
    -------
    LogisticRegressionResult
        Result object with the following attributes:

        - coefficients : List[float]
            Coefficient estimates for x variables (log-odds scale). When
            fixed_effects is used, only the coefficients for the original K
            covariates are returned (Mundlak terms are filtered out).
        - intercept : float or None
            Intercept term (None if include_intercept=False)
        - standard_errors : List[float]
            Robust standard errors for each coefficient. HC3 by default, or
            clustered SE if cluster is specified.
        - intercept_se : float or None
            Robust standard error for intercept (None if
            include_intercept=False)
        - n_samples : int
            Number of observations used
        - n_clusters : int or None
            Number of unique clusters (None if cluster not specified)
        - cluster_se_type : str or None
            Type of clustered SE: "analytical", "bootstrap_rademacher", or
            "bootstrap_webb" (None if not clustered)
        - bootstrap_iterations_used : int or None
            Number of bootstrap iterations (None if not bootstrap)
        - converged : bool
            Whether the MLE optimizer converged
        - iterations : int
            Number of Newton-Raphson iterations used
        - log_likelihood : float
            Log-likelihood at the MLE solution
        - pseudo_r_squared : float
            McFadden's pseudo R² = 1 - (LL_model / LL_null)
        - fixed_effects_absorbed : List[int] or None
            Number of groups absorbed for each fixed effect (None if no FE)
        - fixed_effects_names : List[str] or None
            Names of the fixed effect columns absorbed (None if no FE)
        - within_pseudo_r_squared : float or None
            Within pseudo R² for FE model (None if no FE)

    Raises
    ------
    ValueError
        - If y_col contains values other than 0 and 1
        - If y_col contains only 0s or only 1s
        - If x_cols is empty or columns don't exist
        - If cluster column contains null values
        - If bootstrap=True without cluster specified
        - If fewer than 2 clusters detected
        - If perfect separation is detected
        - If Hessian is singular (collinearity)
        - If convergence fails after max iterations
        - If numerical instability detected (condition number > 1e10)
        - If bootstrap_iterations < 1
        - If fixed_effects has more than 2 columns
        - If fixed_effects column overlaps with x_cols or y_col
        - If fixed_effects column contains null values
        - If fixed_effects column has only one unique value
        - If augmented design matrix is collinear after adding Mundlak terms

    Warns
    -----
    UserWarning
        - When fewer than 42 clusters with bootstrap=False: recommends using
          score bootstrap for more accurate inference.
        - When cluster column has float dtype: implicit cast to string.
        - When fixed_effects column has float dtype: implicit cast warning.

    Examples
    --------
    Simple logistic regression:

    >>> import polars as pl
    >>> import causers
    >>> df = pl.DataFrame({
    ...     "x": [0.5, 1.0, 1.5, 2.0, 2.5, 3.0],
    ...     "y": [0, 0, 0, 1, 1, 1]
    ... })
    >>> result = causers.logistic_regression(df, "x", "y")
    >>> print(f"Coefficient: {result.coefficients[0]:.4f}")
    Coefficient: ...

    Accessing convergence information:

    >>> if result.converged:
    ...     print(f"Converged in {result.iterations} iterations")
    ...     print(f"Log-likelihood: {result.log_likelihood:.2f}")
    ...     print(f"McFadden R²: {result.pseudo_r_squared:.3f}")
    Converged in ... iterations
    Log-likelihood: ...
    McFadden R²: ...

    Multiple covariates:

    >>> df = pl.DataFrame({
    ...     "x1": [1, 2, 3, 4, 5, 6],
    ...     "x2": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    ...     "y": [0, 0, 0, 1, 1, 1]
    ... })
    >>> result = causers.logistic_regression(df, ["x1", "x2"], "y")
    >>> print(f"Coefficients: {result.coefficients}")
    Coefficients: [...]

    Clustered standard errors:

    >>> df = pl.DataFrame({
    ...     "x": [1, 2, 3, 4, 5, 6],
    ...     "y": [0, 0, 1, 0, 1, 1],
    ...     "firm_id": [1, 1, 2, 2, 3, 3]
    ... })
    >>> result = causers.logistic_regression(df, "x", "y", cluster="firm_id")
    >>> print(f"Clustered SE: {result.standard_errors[0]:.4f} (G={result.n_clusters})")
    Clustered SE: ... (G=3)

    Score bootstrap (recommended for <42 clusters):

    >>> result = causers.logistic_regression(
    ...     df, "x", "y",
    ...     cluster="firm_id", bootstrap=True, seed=42
    ... )
    >>> print(f"Bootstrap SE: {result.standard_errors[0]:.4f}")
    Bootstrap SE: ...

    Fixed effects (entity FE using Mundlak approach):

    >>> df = pl.DataFrame({
    ...     "x": [1.0, 2.0, 3.0, 4.0, 1.5, 2.5, 3.5, 4.5],
    ...     "y": [0, 0, 1, 1, 0, 1, 0, 1],
    ...     "entity_id": [1, 1, 1, 1, 2, 2, 2, 2]
    ... })
    >>> result = causers.logistic_regression(df, "x", "y", fixed_effects="entity_id")
    >>> print(f"Coefficient (FE): {result.coefficients[0]:.4f}")
    Coefficient (FE): ...
    >>> print(f"FE absorbed: {result.fixed_effects_absorbed}")
    FE absorbed: [2]

    Notes
    -----
    The logistic regression model is:

        P(y=1|x) = 1 / (1 + exp(-x'β))

    Coefficients are on the log-odds scale. To convert to odds ratios, use
    exp(coefficient).

    Standard errors are computed using:

    - **HC3 (default)**: Heteroskedasticity-consistent standard errors
      adapted for logistic regression, using weighted leverages.
    - **Analytical clustered SE**: When cluster is specified and
      bootstrap=False. Uses the sandwich estimator with cluster-level scores.
    - **Score bootstrap SE**: When cluster and bootstrap=True. Uses
      Rademacher weights (±1 with equal probability) following
      Kline & Santos (2012). Recommended for small cluster counts (G < 42).

    The optimizer uses Newton-Raphson with step halving for stability,
    converging when gradient norm < 1e-8 or after 35 iterations.

    **Fixed Effects (Mundlak Approach)**

    For nonlinear models like logistic regression, standard
    within-transformation (demeaning) is not valid. The Mundlak (1978)
    approach provides an alternative:

    1. Compute group means of all covariates: X̄_g for each FE group g
    2. Augment the design matrix: [X | X̄_g1 | X̄_g2 | ...]
    3. Run standard logistic regression on the augmented model
    4. Return only the coefficients for the original X variables

    This approach is equivalent to correlated random effects (CRE) and
    produces consistent estimates under the same assumptions as fixed effects.

    References
    ----------
    Kline, P., & Santos, A. (2012). "A Score Based Approach to Wild Bootstrap
    Inference." Journal of Econometric Methods, 1(1), 23-41.
    https://doi.org/10.1515/2156-6674.1006

    MacKinnon, J. G., & White, H. (1985). "Some heteroskedasticity-consistent
    covariance matrix estimators with improved finite sample properties."
    Journal of Econometrics, 29(3), 305-325.

    Mundlak, Y. (1978). "On the Pooling of Time Series and Cross Section Data."
    Econometrica, 46(1), 69-85.

    See Also
    --------
    LogisticRegressionResult : Result class with coefficient estimates and diagnostics.
    linear_regression : For continuous outcome regression with FE support.
    """
    # A bare string denotes a single covariate; work with a list either way.
    x_cols_list = [x_cols] if isinstance(x_cols, str) else list(x_cols)

    # Same normalisation for the fixed-effect columns; None means "no FE".
    fe_cols_list: _Optional[_List[str]]
    if fixed_effects is None:
        fe_cols_list = None
    elif isinstance(fixed_effects, str):
        fe_cols_list = [fixed_effects]
    else:
        fe_cols_list = list(fixed_effects)

    # Collect every referenced column so pandas input can be validated and
    # converted in one pass.
    needed_cols = [*x_cols_list, y_col]
    if cluster:
        needed_cols.append(cluster)
    if fe_cols_list:
        needed_cols.extend(fe_cols_list)
    df = _convert_dataframe_if_pandas(df, needed_cols)

    # Float-typed grouping columns get an explicit heads-up about the cast.
    _warn_if_float_cluster(df, cluster)
    for fe_col in fe_cols_list or ():
        try:
            fe_is_float = df[fe_col].dtype in (_pl.Float32, _pl.Float64)
        except (KeyError, AttributeError, _ColumnNotFoundError):
            # Let the Rust layer handle column not found errors
            continue
        if fe_is_float:
            _warnings.warn(
                f"Fixed effect column '{fe_col}' is float; will be cast to integer for grouping.",
                UserWarning,
                stacklevel=2,
            )

    # Delegate the estimation to the Rust implementation; the bootstrap method
    # is lower-cased so that matching is case-insensitive.
    result = _logistic_regression_rust(
        df,
        x_cols_list,
        y_col,
        include_intercept,
        cluster,
        bootstrap,
        bootstrap_iterations,
        seed,
        bootstrap_method.lower(),
        fe_cols_list,
    )

    # Analytical clustered SEs with few clusters: point the user to bootstrap.
    # NOTE(review): unlike linear_regression, no cluster-imbalance warning is
    # emitted here — confirm that the asymmetry is intentional.
    n_clusters = result.n_clusters
    if (
        n_clusters is not None
        and not bootstrap
        and n_clusters < _SMALL_CLUSTER_THRESHOLD
    ):
        _warnings.warn(
            f"Only {n_clusters} clusters detected. Score bootstrap "
            f"(bootstrap=True) is recommended when clusters < {_SMALL_CLUSTER_THRESHOLD}.",
            UserWarning,
            stacklevel=2,
        )

    return result
[docs]
def synthetic_did(
    df: _Union[_pl.DataFrame, "pd.DataFrame"],
    unit_col: str,
    time_col: str,
    outcome_col: str,
    treatment_col: str,
    bootstrap_iterations: int = 200,
    seed: _Optional[int] = None,
) -> SyntheticDIDResult:
    """
    Compute Synthetic Difference-in-Differences (SDID) estimator.

    Implements the SDID estimator from Arkhangelsky et al. (2021), which combines
    synthetic control weighting with difference-in-differences to estimate the
    Average Treatment Effect on the Treated (ATT).

    The estimator uses two-stage optimization:

    1. **Unit weights**: Find control unit weights that match pre-treatment trends
    2. **Time weights**: Find pre-period weights that predict post-period outcomes

    Standard errors are computed via placebo bootstrap.

    Parameters
    ----------
    df : pl.DataFrame or pd.DataFrame
        Panel data in long format with one row per unit-time observation.
        Must be a balanced panel (all units observed exactly once in all
        time periods). Accepts both Polars and pandas DataFrames.
    unit_col : str
        Column name identifying unique units (e.g., "state", "firm_id").
        Must be integer or string type.
    time_col : str
        Column name identifying time periods (e.g., "year", "quarter").
        Must be integer or string type.
    outcome_col : str
        Column name for the outcome variable. Must be numeric.
    treatment_col : str
        Column name for treatment indicator. Must contain only 0 and 1 values.
        Value of 1 indicates the unit is treated in that period.
    bootstrap_iterations : int, default=200
        Number of placebo bootstrap iterations for standard error estimation.
        Must be at least 0; a value of 0 requests no bootstrap (ATT only).
        Values < 100 will emit a warning.
    seed : int, optional
        Random seed for reproducibility. If None, uses system time.

    Returns
    -------
    SyntheticDIDResult
        Result object with the following attributes:

        - att : float
            Average Treatment Effect on the Treated
        - standard_error : float
            Bootstrap standard error of the ATT
        - unit_weights : List[float]
            Weights assigned to each control unit (sums to 1)
        - time_weights : List[float]
            Weights assigned to each pre-treatment period (sums to 1)
        - n_units_control : int
            Number of control units
        - n_units_treated : int
            Number of treated units
        - n_periods_pre : int
            Number of pre-treatment periods
        - n_periods_post : int
            Number of post-treatment periods
        - solver_iterations : Tuple[int, int]
            Number of iterations for (unit_weights, time_weights) optimization
        - solver_converged : bool
            Whether the Frank-Wolfe solver converged
        - pre_treatment_fit : float
            RMSE of pre-treatment fit (lower is better)
        - bootstrap_iterations_used : int
            Number of successful bootstrap iterations

    Raises
    ------
    ValueError
        - If DataFrame is empty
        - If any specified column doesn't exist
        - If unit_col or time_col is float type
        - If outcome_col is not numeric
        - If outcome_col contains null values
        - If treatment_col contains values other than 0 and 1
        - If bootstrap_iterations < 0
        - If fewer than 2 control units found
        - If fewer than 2 pre-treatment periods found
        - If no treated units found
        - If no post-treatment periods found
        - If panel is not balanced (wrong row count, or duplicated
          (unit, time) observations)

    Warns
    -----
    UserWarning
        - If any unit weight > 0.5 (weight concentration on single unit)
        - If any time weight > 0.5 (weight concentration on single period)
        - If bootstrap_iterations < 100 (may be unreliable)

    Examples
    --------
    Basic usage with panel data:

    >>> import polars as pl
    >>> import causers
    >>> df = pl.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 2, 3, 3, 3],
    ...     'time': [1, 2, 3, 1, 2, 3, 1, 2, 3],
    ...     'y': [1.0, 2.0, 5.0, 1.5, 2.5, 3.0, 1.2, 2.2, 2.8],
    ...     'treated': [0, 0, 1, 0, 0, 0, 0, 0, 0]
    ... })
    >>> result = causers.synthetic_did(df, 'unit', 'time', 'y', 'treated', seed=42)
    >>> print(f"ATT: {result.att:.4f} ± {result.standard_error:.4f}")
    ATT: ... ± ...

    Accessing weights and diagnostics:

    >>> print(f"Control unit weights: {result.unit_weights}")
    Control unit weights: [...]
    >>> print(f"Pre-treatment fit RMSE: {result.pre_treatment_fit:.4f}")
    Pre-treatment fit RMSE: ...

    Notes
    -----
    **Panel Structure Detection**

    The function automatically detects:

    - **Control units**: Units where treatment=0 in all periods
    - **Treated units**: Units where treatment=1 in at least one period
    - **Pre-periods**: Periods where all observations have treatment=0
    - **Post-periods**: Periods where at least one treated unit has treatment=1

    **Algorithm**

    The SDID estimator is:

    .. math::

        \\hat{\\tau}_{sdid} = (\\bar{Y}_{tr,post} - \\bar{Y}_{synth,post})
        - \\sum_t \\lambda_t (\\bar{Y}_{tr,t} - \\bar{Y}_{synth,t})

    where :math:`\\bar{Y}_{synth,t} = \\sum_i \\omega_i Y_{i,t}` uses optimized
    unit weights :math:`\\omega` on control units.

    **Standard Errors**

    Standard errors are computed via placebo bootstrap:

    1. Randomly select a control unit as "placebo treated"
    2. Re-run SDID with this unit treated
    3. Repeat for bootstrap_iterations
    4. SE = standard deviation of placebo ATTs

    References
    ----------
    Arkhangelsky, D., Athey, S., Hirshberg, D. A., Imbens, G. W., & Wager, S. (2021).
    Synthetic difference-in-differences. *American Economic Review*, 111(12), 4088-4118.

    See Also
    --------
    SyntheticDIDResult : Result class with ATT and diagnostics.
    linear_regression : For standard regression analysis.
    """
    # ========================================================================
    # pandas Conversion
    # ========================================================================
    # Detect and convert pandas DataFrames if needed
    all_cols = [unit_col, time_col, outcome_col, treatment_col]
    df = _convert_dataframe_if_pandas(df, all_cols, int_columns=[unit_col, time_col])

    # ========================================================================
    # Input Validation
    # ========================================================================
    # Check DataFrame is not empty
    if len(df) == 0:
        raise ValueError("Cannot perform SDID on empty DataFrame")

    # Check all required columns exist
    for col_name in all_cols:
        if col_name not in df.columns:
            raise ValueError(f"Column '{col_name}' not found in DataFrame")

    # Identifier columns must not be float
    float_types = (_pl.Float32, _pl.Float64)
    if df[unit_col].dtype in float_types:
        raise ValueError("unit_col must be integer or string, not float")
    if df[time_col].dtype in float_types:
        raise ValueError("time_col must be integer or string, not float")

    # Check outcome_col is numeric
    numeric_types = (
        _pl.Float32, _pl.Float64,
        _pl.Int8, _pl.Int16, _pl.Int32, _pl.Int64,
        _pl.UInt8, _pl.UInt16, _pl.UInt32, _pl.UInt64,
    )
    if df[outcome_col].dtype not in numeric_types:
        raise ValueError("outcome_col must be numeric")

    # Check for nulls in outcome_col
    if df[outcome_col].null_count() > 0:
        raise ValueError(f"outcome_col '{outcome_col}' contains null values")

    # Check treatment_col contains only 0 and 1 (a null shows up as None in
    # the unique values and is rejected by the same membership test)
    valid_treatment_values = {0, 1, 0.0, 1.0}
    for val in df[treatment_col].unique().to_list():
        if val not in valid_treatment_values:
            raise ValueError("treatment_col must contain only 0 and 1 values")

    # Check bootstrap_iterations >= 0 (0 = no bootstrap, ATT only)
    if bootstrap_iterations < 0:
        raise ValueError("bootstrap_iterations must be at least 0")

    # ========================================================================
    # Panel Structure Detection
    # ========================================================================
    # Get unique units and periods
    unique_units = df[unit_col].unique().sort().to_list()
    unique_periods = df[time_col].unique().sort().to_list()
    n_units = len(unique_units)
    n_periods = len(unique_periods)

    # Create mappings from unit/period values to indices
    unit_to_idx = {unit: idx for idx, unit in enumerate(unique_units)}
    period_to_idx = {period: idx for idx, period in enumerate(unique_periods)}

    # Validate balanced panel
    expected_rows = n_units * n_periods
    if len(df) != expected_rows:
        raise ValueError(
            f"Panel is not balanced: expected {expected_rows} rows "
            f"({n_units} units × {n_periods} periods), found {len(df)}"
        )
    # The row count alone cannot see a duplicated (unit, time) row that masks
    # a missing one; such a panel would silently misalign the flat outcome
    # matrix built below, so require every (unit, time) pair to be unique.
    n_unique_cells = df.n_unique(subset=[unit_col, time_col])
    if n_unique_cells != expected_rows:
        raise ValueError(
            f"Panel is not balanced: expected {expected_rows} unique (unit, time) "
            f"observations, found {n_unique_cells} (duplicate rows present)"
        )

    # Identify control vs treated units
    # Control units: treatment=0 in ALL periods
    # Treated units: treatment=1 in at least one period
    unit_max_treatment = (
        df.group_by(unit_col)
        .agg(_pl.col(treatment_col).max().alias("max_treatment"))
    )
    control_units = []
    treated_units = []
    for unit_val, max_treat in unit_max_treatment.iter_rows():
        if max_treat == 0:
            control_units.append(unit_val)
        else:
            treated_units.append(unit_val)

    # group_by does not guarantee row order; sort for deterministic ordering
    control_units = sorted(control_units, key=lambda x: unit_to_idx[x])
    treated_units = sorted(treated_units, key=lambda x: unit_to_idx[x])

    # Validate sufficient control units
    n_control = len(control_units)
    if n_control < 2:
        raise ValueError(f"At least 2 control units required; found {n_control}")

    # Validate at least 1 treated unit
    n_treated = len(treated_units)
    if n_treated == 0:
        raise ValueError("No treated units found in data")

    # Identify pre vs post periods
    # Pre-periods: all observations have treatment=0
    # Post-periods: at least one treated unit has treatment=1
    period_max_treatment = (
        df.group_by(time_col)
        .agg(_pl.col(treatment_col).max().alias("max_treatment"))
    )
    pre_periods = []
    post_periods = []
    for period_val, max_treat in period_max_treatment.iter_rows():
        if max_treat == 0:
            pre_periods.append(period_val)
        else:
            post_periods.append(period_val)

    # Sort for deterministic ordering
    pre_periods = sorted(pre_periods, key=lambda x: period_to_idx[x])
    post_periods = sorted(post_periods, key=lambda x: period_to_idx[x])

    # Validate sufficient pre-periods
    n_pre = len(pre_periods)
    if n_pre < 2:
        raise ValueError(f"At least 2 pre-treatment periods required; found {n_pre}")

    # Validate at least 1 post-period
    n_post = len(post_periods)
    if n_post == 0:
        raise ValueError("No post-treatment periods found")

    # ========================================================================
    # Extract Outcome Matrix (Row-Major Order)
    # ========================================================================
    # Sorting by (unit, time) puts the outcomes in row-major order:
    # [unit_0_period_0, unit_0_period_1, ..., unit_N-1_period_T-1]
    # so a single column extraction replaces a per-unit filter loop.
    df_sorted = df.sort([unit_col, time_col])
    outcomes = df_sorted[outcome_col].cast(_pl.Float64).to_list()

    # Create index arrays
    control_indices = [unit_to_idx[u] for u in control_units]
    treated_indices = [unit_to_idx[u] for u in treated_units]
    pre_period_indices = [period_to_idx[p] for p in pre_periods]
    post_period_indices = [period_to_idx[p] for p in post_periods]

    # ========================================================================
    # Call Rust Implementation
    # ========================================================================
    result = _synthetic_did_impl(
        outcomes=outcomes,
        n_units=n_units,
        n_periods=n_periods,
        control_indices=control_indices,
        treated_indices=treated_indices,
        pre_period_indices=pre_period_indices,
        post_period_indices=post_period_indices,
        bootstrap_iterations=bootstrap_iterations,
        seed=seed,
    )

    # ========================================================================
    # Post-Processing Warnings
    # ========================================================================
    # Check for unit weight concentration
    if result.unit_weights:
        max_unit_weight = max(result.unit_weights)
        if max_unit_weight > _WEIGHT_CONCENTRATION_THRESHOLD:
            max_idx = result.unit_weights.index(max_unit_weight)
            _warnings.warn(
                f"Unit weight concentration: control unit at index {max_idx} has "
                f"weight {max_unit_weight:.2%}. Results may be sensitive to this unit.",
                UserWarning,
                stacklevel=2
            )

    # Check for time weight concentration
    if result.time_weights:
        max_time_weight = max(result.time_weights)
        if max_time_weight > _WEIGHT_CONCENTRATION_THRESHOLD:
            max_idx = result.time_weights.index(max_time_weight)
            _warnings.warn(
                f"Time weight concentration: pre-period at index {max_idx} has "
                f"weight {max_time_weight:.2%}. Results may be sensitive to this period.",
                UserWarning,
                stacklevel=2
            )

    # Check for low bootstrap iterations
    if bootstrap_iterations < _MIN_RELIABLE_BOOTSTRAP_ITERATIONS:
        _warnings.warn(
            f"bootstrap_iterations={bootstrap_iterations} is less than {_MIN_RELIABLE_BOOTSTRAP_ITERATIONS}. "
            f"Standard error estimates may be unreliable.",
            UserWarning,
            stacklevel=2
        )

    return result
[docs]
def synthetic_control(
    df: _Union[_pl.DataFrame, "pd.DataFrame"],
    unit_col: str,
    time_col: str,
    outcome_col: str,
    treatment_col: str,
    method: str = "traditional",
    lambda_param: _Optional[float] = None,
    compute_se: bool = True,
    n_placebo: _Optional[int] = None,
    max_iter: int = 1000,
    tol: float = 1e-6,
    seed: _Optional[int] = None,
) -> SyntheticControlResult:
    """
    Compute Synthetic Control (SC) estimator.

    Implements the Synthetic Control method from Abadie et al. (2010, 2015),
    which constructs a weighted combination of control units to create a
    synthetic control that matches the treated unit's pre-treatment outcomes.

    Supports four method variants:

    - **Traditional**: Classic SC with simplex-constrained weights (Abadie et al., 2010)
    - **Penalized**: L2 regularization for more uniform weights
    - **Robust**: De-meaned data for matching dynamics instead of levels
    - **Augmented**: Bias correction via ridge outcome model (Ben-Michael et al., 2021)

    Parameters
    ----------
    df : pl.DataFrame or pd.DataFrame
        Panel data in long format with one row per unit-time observation.
        Must be a balanced panel (all units observed exactly once in all
        time periods). Accepts both Polars and pandas DataFrames.
    unit_col : str
        Column name identifying unique units (e.g., "state", "firm_id").
        Must be integer or string type.
    time_col : str
        Column name identifying time periods (e.g., "year", "quarter").
        Must be integer or string type.
    outcome_col : str
        Column name for the outcome variable. Must be numeric.
    treatment_col : str
        Column name for treatment indicator. Must contain only 0 and 1 values.
        Exactly one unit must be treated (have treatment=1 in post-period).
    method : str, default="traditional"
        Synthetic control method to use (matched case-insensitively). Options:

        - "traditional": Classic SC minimizing pre-treatment MSE
        - "penalized": L2 regularized SC for more uniform weights
        - "robust": De-meaned SC for matching dynamics
        - "augmented": Bias-corrected SC with ridge adjustment
    lambda_param : float, optional
        Regularization parameter for penalized/augmented methods.
        If None, auto-selected via LOOCV for penalized method.
        Must be >= 0 when specified.
    compute_se : bool, default=True
        Whether to compute standard errors via in-space placebo.
    n_placebo : int, optional
        Number of placebo iterations for SE. If None, uses all control units.
    max_iter : int, default=1000
        Maximum iterations for Frank-Wolfe optimizer.
    tol : float, default=1e-6
        Convergence tolerance for optimizer.
    seed : int, optional
        Random seed for reproducibility. If None, uses system time.

    Returns
    -------
    SyntheticControlResult
        Result object with the following attributes:

        - att : float
            Average Treatment Effect on the Treated
        - standard_error : float or None
            In-space placebo standard error (None if compute_se=False)
        - unit_weights : List[float]
            Weights assigned to each control unit (sums to 1)
        - pre_treatment_rmse : float
            Root Mean Squared Error of pre-treatment fit
        - pre_treatment_mse : float
            Mean Squared Error of pre-treatment fit
        - method : str
            Method used ("traditional", "penalized", "robust", "augmented")
        - lambda_used : float or None
            Lambda parameter used (for penalized/augmented methods)
        - n_units_control : int
            Number of control units
        - n_periods_pre : int
            Number of pre-treatment periods
        - n_periods_post : int
            Number of post-treatment periods
        - solver_converged : bool
            Whether the Frank-Wolfe solver converged
        - solver_iterations : int
            Number of optimizer iterations
        - n_placebo_used : int or None
            Number of successful placebo iterations (if SE computed)

    Raises
    ------
    ValueError
        - If DataFrame is empty
        - If any specified column doesn't exist
        - If unit_col or time_col is float type
        - If outcome_col is not numeric
        - If outcome_col contains null values
        - If treatment_col contains values other than 0 and 1
        - If not exactly one treated unit found
        - If fewer than 1 control unit found
        - If fewer than 1 pre-treatment period found
        - If method is "augmented" and fewer than 2 control units or fewer
          than 2 pre-treatment periods are found
        - If no post-treatment periods found
        - If panel is not balanced (wrong row count, or duplicated
          (unit, time) observations)
        - If method is not recognized
        - If lambda_param < 0

    Warns
    -----
    UserWarning
        - If any unit weight > 0.5 (weight concentration on single unit)
        - If pre_treatment_rmse > 0.1 × outcome std (poor pre-treatment fit)

    Examples
    --------
    Basic usage with panel data:

    >>> import polars as pl
    >>> import causers
    >>> df = pl.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 2, 3, 3, 3],
    ...     'time': [1, 2, 3, 1, 2, 3, 1, 2, 3],
    ...     'y': [1.0, 2.0, 8.0, 1.5, 2.5, 3.0, 1.2, 2.2, 2.8],
    ...     'treated': [0, 0, 1, 0, 0, 0, 0, 0, 0]
    ... })
    >>> result = causers.synthetic_control(df, 'unit', 'time', 'y', 'treated', seed=42)
    >>> print(f"ATT: {result.att:.4f}")
    ATT: ...

    Using penalized method with auto lambda:

    >>> result = causers.synthetic_control(
    ...     df, 'unit', 'time', 'y', 'treated',
    ...     method="penalized", seed=42
    ... )
    >>> print(f"Lambda used: {result.lambda_used}")
    Lambda used: ...

    Accessing weights and diagnostics:

    >>> print(f"Control unit weights: {result.unit_weights}")
    Control unit weights: [...]
    >>> print(f"Pre-treatment RMSE: {result.pre_treatment_rmse:.4f}")
    Pre-treatment RMSE: ...

    Without standard errors (faster):

    >>> result = causers.synthetic_control(
    ...     df, 'unit', 'time', 'y', 'treated',
    ...     compute_se=False
    ... )
    >>> print(f"ATT: {result.att:.4f} (SE not computed)")
    ATT: ... (SE not computed)

    Notes
    -----
    **Key Difference from Synthetic DID**

    Synthetic Control requires exactly ONE treated unit, while SDID supports
    multiple treated units. If you have multiple treated units, use
    `synthetic_did()` instead.

    **Panel Structure Detection**

    The function automatically detects:

    - **Control units**: Units where treatment=0 in all periods
    - **Treated unit**: The single unit where treatment=1 in post-period
    - **Pre-periods**: Periods where treatment=0 for all units
    - **Post-periods**: Periods where the treated unit has treatment=1

    **Algorithm**

    The SC estimator finds weights ω such that:

    .. math::

        \\hat{\\omega} = \\arg\\min_{\\omega \\geq 0, \\sum \\omega = 1}
        \\sum_{t \\in \\text{pre}} (Y_{1t} - \\sum_j \\omega_j Y_{jt})^2

    Then the ATT is:

    .. math::

        \\hat{\\tau}_{SC} = \\frac{1}{|\\text{post}|} \\sum_{t \\in \\text{post}}
        (Y_{1t} - \\sum_j \\hat{\\omega}_j Y_{jt})

    **Standard Errors**

    Standard errors are computed via in-space placebo:

    1. For each control unit, treat it as the "placebo treated" unit
    2. Compute SC weights and ATT using remaining controls
    3. SE = standard deviation of placebo ATTs

    References
    ----------
    Abadie, A., Diamond, A., & Hainmueller, J. (2010). Synthetic Control Methods
    for Comparative Case Studies. *Journal of the American Statistical Association*.

    Abadie, A., Diamond, A., & Hainmueller, J. (2015). Comparative Politics and
    the Synthetic Control Method. *American Journal of Political Science*.

    Ben-Michael, E., Feller, A., & Rothstein, J. (2021). The Augmented Synthetic
    Control Method. *Journal of the American Statistical Association*.

    See Also
    --------
    SyntheticControlResult : Result class with ATT and diagnostics.
    synthetic_did : For multiple treated units with DID adjustment.
    """
    # ========================================================================
    # pandas Conversion
    # ========================================================================
    # Detect and convert pandas DataFrames if needed
    all_cols = [unit_col, time_col, outcome_col, treatment_col]
    df = _convert_dataframe_if_pandas(df, all_cols, int_columns=[unit_col, time_col])

    # ========================================================================
    # Input Validation
    # ========================================================================
    # Check DataFrame is not empty
    if len(df) == 0:
        raise ValueError("Cannot perform SC on empty DataFrame")

    # Check all required columns exist
    for col_name in all_cols:
        if col_name not in df.columns:
            raise ValueError(f"Column '{col_name}' not found in DataFrame")

    # Identifier columns must not be float
    float_types = (_pl.Float32, _pl.Float64)
    if df[unit_col].dtype in float_types:
        raise ValueError("unit_col must be integer or string, not float")
    if df[time_col].dtype in float_types:
        raise ValueError("time_col must be integer or string, not float")

    # Check outcome_col is numeric
    numeric_types = (
        _pl.Float32, _pl.Float64,
        _pl.Int8, _pl.Int16, _pl.Int32, _pl.Int64,
        _pl.UInt8, _pl.UInt16, _pl.UInt32, _pl.UInt64,
    )
    if df[outcome_col].dtype not in numeric_types:
        raise ValueError("outcome_col must be numeric")

    # Check for nulls in outcome_col
    if df[outcome_col].null_count() > 0:
        raise ValueError(f"outcome_col '{outcome_col}' contains null values")

    # Check treatment_col contains only 0 and 1 (a null shows up as None in
    # the unique values and is rejected by the same membership test)
    valid_treatment_values = {0, 1, 0.0, 1.0}
    for val in df[treatment_col].unique().to_list():
        if val not in valid_treatment_values:
            raise ValueError("treatment_col must contain only 0 and 1 values")

    # Validate method (case-insensitive)
    valid_methods = {"traditional", "penalized", "robust", "augmented"}
    method_lower = method.lower()
    if method_lower not in valid_methods:
        # sorted() keeps the message stable; raw set order varies between runs
        raise ValueError(
            f"method must be one of {sorted(valid_methods)}, got '{method}'"
        )

    # Validate lambda_param
    if lambda_param is not None and lambda_param < 0:
        raise ValueError(f"lambda_param must be >= 0, got {lambda_param}")

    # ========================================================================
    # Panel Structure Detection
    # ========================================================================
    # Get unique units and periods
    unique_units = df[unit_col].unique().sort().to_list()
    unique_periods = df[time_col].unique().sort().to_list()
    n_units = len(unique_units)
    n_periods = len(unique_periods)

    # Create mappings from unit/period values to indices
    unit_to_idx = {unit: idx for idx, unit in enumerate(unique_units)}
    period_to_idx = {period: idx for idx, period in enumerate(unique_periods)}

    # Validate balanced panel
    expected_rows = n_units * n_periods
    if len(df) != expected_rows:
        raise ValueError(
            f"Panel is not balanced: expected {expected_rows} rows "
            f"({n_units} units × {n_periods} periods), found {len(df)}"
        )
    # The row count alone cannot see a duplicated (unit, time) row that masks
    # a missing one; such a panel would silently misalign the flat outcome
    # matrix built below, so require every (unit, time) pair to be unique.
    n_unique_cells = df.n_unique(subset=[unit_col, time_col])
    if n_unique_cells != expected_rows:
        raise ValueError(
            f"Panel is not balanced: expected {expected_rows} unique (unit, time) "
            f"observations, found {n_unique_cells} (duplicate rows present)"
        )

    # Identify control vs treated units
    # Control units: treatment=0 in ALL periods
    # Treated unit: treatment=1 in at least one period
    unit_max_treatment = (
        df.group_by(unit_col)
        .agg(_pl.col(treatment_col).max().alias("max_treatment"))
    )
    control_units = []
    treated_units = []
    for unit_val, max_treat in unit_max_treatment.iter_rows():
        if max_treat == 0:
            control_units.append(unit_val)
        else:
            treated_units.append(unit_val)

    # group_by does not guarantee row order; sort for deterministic ordering
    control_units = sorted(control_units, key=lambda x: unit_to_idx[x])
    treated_units = sorted(treated_units, key=lambda x: unit_to_idx[x])

    # Validate exactly 1 treated unit (SC requirement)
    n_treated = len(treated_units)
    if n_treated == 0:
        raise ValueError("No treated units found in data")
    if n_treated > 1:
        raise ValueError(
            f"Synthetic Control requires exactly 1 treated unit; found {n_treated}. "
            f"For multiple treated units, use synthetic_did() instead."
        )

    # Validate sufficient control units
    n_control = len(control_units)
    if n_control < 1:
        raise ValueError(f"At least 1 control unit required; found {n_control}")

    # For augmented method, need at least 2 controls
    if method_lower == "augmented" and n_control < 2:
        raise ValueError(
            f"Augmented SC requires at least 2 control units; found {n_control}"
        )

    # Identify pre vs post periods
    # Pre-periods: all observations have treatment=0
    # Post-periods: at least one treated unit has treatment=1
    period_max_treatment = (
        df.group_by(time_col)
        .agg(_pl.col(treatment_col).max().alias("max_treatment"))
    )
    pre_periods = []
    post_periods = []
    for period_val, max_treat in period_max_treatment.iter_rows():
        if max_treat == 0:
            pre_periods.append(period_val)
        else:
            post_periods.append(period_val)

    # Sort for deterministic ordering
    pre_periods = sorted(pre_periods, key=lambda x: period_to_idx[x])
    post_periods = sorted(post_periods, key=lambda x: period_to_idx[x])

    # Validate sufficient pre-periods
    n_pre = len(pre_periods)
    if n_pre < 1:
        raise ValueError(f"At least 1 pre-treatment period required; found {n_pre}")

    # For augmented method, need at least 2 pre-periods
    if method_lower == "augmented" and n_pre < 2:
        raise ValueError(
            f"Augmented SC requires at least 2 pre-treatment periods; found {n_pre}"
        )

    # Validate at least 1 post-period
    n_post = len(post_periods)
    if n_post == 0:
        raise ValueError("No post-treatment periods found")

    # ========================================================================
    # Extract Outcome Matrix (Row-Major Order)
    # ========================================================================
    # Sorting by (unit, time) puts the outcomes in row-major order:
    # [unit_0_period_0, unit_0_period_1, ..., unit_N-1_period_T-1]
    # so a single column extraction replaces a per-unit filter loop.
    df_sorted = df.sort([unit_col, time_col])
    outcomes = df_sorted[outcome_col].cast(_pl.Float64).to_list()

    # Create index arrays
    control_indices = [unit_to_idx[u] for u in control_units]
    treated_index = unit_to_idx[treated_units[0]]  # Single treated unit
    pre_period_indices = [period_to_idx[p] for p in pre_periods]
    post_period_indices = [period_to_idx[p] for p in post_periods]

    # Determine n_placebo (default: all control units)
    actual_n_placebo = n_placebo if n_placebo is not None else n_control

    # ========================================================================
    # Call Rust Implementation
    # ========================================================================
    result = _synthetic_control_impl(
        outcomes=outcomes,
        n_units=n_units,
        n_periods=n_periods,
        control_indices=control_indices,
        treated_index=treated_index,
        pre_period_indices=pre_period_indices,
        post_period_indices=post_period_indices,
        method=method_lower,
        lambda_param=lambda_param,
        compute_se=compute_se,
        n_placebo=actual_n_placebo,
        max_iter=max_iter,
        tol=tol,
        seed=seed,
    )

    # ========================================================================
    # Post-Processing Warnings
    # ========================================================================
    # Check for unit weight concentration
    if result.unit_weights:
        max_unit_weight = max(result.unit_weights)
        if max_unit_weight > _WEIGHT_CONCENTRATION_THRESHOLD:
            max_idx = result.unit_weights.index(max_unit_weight)
            _warnings.warn(
                f"Unit weight concentration: control unit at index {max_idx} has "
                f"weight {max_unit_weight:.2%}. Results may be sensitive to this unit.",
                UserWarning,
                stacklevel=2
            )

    # Check for poor pre-treatment fit, measured relative to the spread of the
    # outcome over the whole panel (skipped when the std is undefined or zero)
    outcome_std = df[outcome_col].std()
    if outcome_std is not None and outcome_std > 0:
        relative_rmse = result.pre_treatment_rmse / outcome_std
        if relative_rmse > _POOR_FIT_RMSE_RATIO:
            _warnings.warn(
                f"Pre-treatment RMSE ({result.pre_treatment_rmse:.4f}) is "
                f"{relative_rmse:.1%} of outcome std ({outcome_std:.4f}). "
                f"Consider using a different method or checking data quality.",
                UserWarning,
                stacklevel=2
            )

    return result
[docs]
def dml(
    df: _Union[_pl.DataFrame, "pd.DataFrame"],
    y_col: str,
    d_col: str,
    x_cols: _Union[str, _List[str]],
    n_folds: int = 5,
    treatment_type: str = "binary",
    estimate_cate: bool = False,
    alpha: float = 0.05,
    propensity_clip: tuple = (0.01, 0.99),
    cluster: _Optional[str] = None,
    seed: _Optional[int] = None,
) -> DMLResult:
    """
    Estimate causal treatment effects using Double Machine Learning (DML).

    Implements the DML estimator from Chernozhukov et al. (2018) with cross-fitting
    for debiased inference. Uses linear regression for outcome model and
    logistic (binary) or linear (continuous) regression for propensity model.

    Parameters
    ----------
    df : pl.DataFrame or pd.DataFrame
        The DataFrame containing the data. Accepts both Polars and pandas.
    y_col : str
        Name of the outcome variable column
    d_col : str
        Name of the treatment variable column
    x_cols : str or List[str]
        Name(s) of the covariate columns to control for
    n_folds : int, default=5
        Number of cross-fitting folds. Must be >= 2.
    treatment_type : str, default="binary"
        Type of treatment variable (matched case-insensitively):

        - "binary": 0/1 treatment (uses logistic propensity model)
        - "continuous": Continuous treatment (uses linear propensity model)
    estimate_cate : bool, default=False
        Whether to estimate Conditional Average Treatment Effect (CATE)
        coefficients. If True, returns CATE as linear function of X.
    alpha : float, default=0.05
        Significance level for confidence intervals. 0.05 = 95% CI.
    propensity_clip : tuple, default=(0.01, 0.99)
        ``(low, high)`` bounds for propensity score clipping (binary
        treatment only). Extreme propensity scores are clipped to these bounds.
    cluster : str, optional
        Column name for cluster identifiers for cluster-robust standard errors.
        When specified, uses cluster-robust Neyman-orthogonal variance estimation
        with G/(G-1) small-sample adjustment. Requires at least 2 clusters.
    seed : int, optional
        Random seed for reproducible fold assignment. When None,
        uses deterministic row-order assignment.

    Returns
    -------
    DMLResult
        Result object with the following attributes:

        - theta : float
            Average Treatment Effect (ATE) point estimate
        - standard_error : float
            Neyman-orthogonal robust standard error
        - confidence_interval : Tuple[float, float]
            (1-alpha) confidence interval bounds (lower, upper)
        - p_value : float
            Two-sided p-value for H₀: θ = 0
        - n_samples : int
            Number of observations used
        - n_folds : int
            Number of cross-fitting folds used
        - propensity_residual_var : float
            Variance of treatment residuals Var(D̃)
        - outcome_residual_var : float
            Variance of outcome residuals Var(Ỹ)
        - outcome_r_squared : float
            Average R² of outcome nuisance model across folds
        - propensity_r_squared : float
            Average R² (or pseudo-R²) of propensity model across folds
        - n_propensity_clipped : int
            Count of propensity scores clipped to bounds
        - cate_coefficients : Dict[str, float] or None
            CATE coefficients keyed by covariate name (if estimate_cate=True)
        - cate_standard_errors : Dict[str, float] or None
            CATE coefficient standard errors (if estimate_cate=True)

    Raises
    ------
    ValueError
        - If x_cols is empty
        - If n_folds is not 2, 5, or 10
        - If n_folds >= n_samples
        - If treatment_type is not "binary" or "continuous"
        - If binary treatment doesn't contain both 0 and 1
        - If binary treatment contains non-0/1 values
        - If alpha is not in (0, 1)
        - If propensity_clip is not a pair of bounds, or the bounds are invalid
        - If treatment has no variation
        - If treatment is fully explained by covariates
        - If nuisance model fails to converge
        - If covariate matrix is singular in any fold

    Examples
    --------
    Basic ATE estimation with binary treatment:

    >>> import polars as pl
    >>> import causers
    >>> df = pl.DataFrame({
    ...     "y": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
    ...     "d": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
    ...     "x1": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    ...     "x2": [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9]
    ... })
    >>> result = causers.dml(df, "y", "d", ["x1", "x2"], n_folds=2, seed=42)
    >>> print(f"ATE: {result.theta:.4f} ± {result.standard_error:.4f}")
    ATE: ... ± ...

    With CATE estimation:

    >>> result = causers.dml(
    ...     df, "y", "d", ["x1", "x2"],
    ...     estimate_cate=True, seed=42
    ... )
    >>> print(f"CATE coefficients: {result.cate_coefficients}")
    CATE coefficients: {...}

    Continuous treatment (Polars DataFrames do not support item assignment,
    so the new column is added with ``with_columns``):

    >>> df = df.with_columns(
    ...     pl.Series("d_cont", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
    ... )
    >>> result = causers.dml(
    ...     df, "y", "d_cont", ["x1", "x2"],
    ...     treatment_type="continuous", seed=42
    ... )
    >>> print(f"ATE: {result.theta:.4f}")
    ATE: ...

    Accessing diagnostics:

    >>> print(f"Outcome R²: {result.outcome_r_squared:.3f}")
    Outcome R²: ...
    >>> print(f"Propensity R²: {result.propensity_r_squared:.3f}")
    Propensity R²: ...
    >>> print(result.summary())  # Formatted summary
    Double Machine Learning Results
    ...

    Notes
    -----
    **Algorithm Overview**

    The DML estimator uses cross-fitting (sample splitting) to avoid overfitting
    bias in the nuisance models:

    1. Partition data into K folds for cross-fitting
    2. For each fold k:

       - Train outcome model ℓ(X) on observations NOT in fold k
       - Train propensity model m(X) on observations NOT in fold k
       - Predict for observations IN fold k (out-of-fold predictions)
    3. Compute residuals: Ỹ = Y - ℓ̂(X), D̃ = D - m̂(X)
    4. Final-stage regression: θ̂ = (D̃'D̃)⁻¹ D̃'Ỹ
    5. Neyman-orthogonal variance: V̂ = (1/N) × J⁻² × Σψ²ᵢ

    **CATE Estimation**

    When estimate_cate=True, estimates heterogeneous treatment effects as:

        τ(X) = θ₀ + X'γ

    where γ coefficients capture how treatment effect varies with covariates.

    **Standard Errors**

    Uses Neyman-orthogonal variance estimation which provides:

    - Robustness to first-stage nuisance estimation error
    - Valid inference even with moderate sample sizes
    - Automatic bias correction from cross-fitting

    References
    ----------
    Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., Hansen, C.,
    Newey, W., & Robins, J. (2018). Double/debiased machine learning for
    treatment and structural parameters. *The Econometrics Journal*, 21(1), C1-C68.

    See Also
    --------
    DMLResult : Result class with ATE, CATE, and diagnostics.
    linear_regression : For standard regression analysis.
    """
    # Normalize x_cols to always be a list
    if isinstance(x_cols, str):
        x_cols_list = [x_cols]
    else:
        x_cols_list = list(x_cols)

    # Detect and convert pandas DataFrames if needed
    all_cols = [y_col, d_col] + x_cols_list
    if cluster:
        all_cols.append(cluster)
    df = _convert_dataframe_if_pandas(df, all_cols)

    # Normalize treatment_type to lowercase for case-insensitive matching
    treatment_type_normalized = treatment_type.lower()

    # Extract propensity_clip bounds; report a malformed argument with a clear
    # ValueError instead of a bare unpacking error
    try:
        propensity_clip_low, propensity_clip_high = propensity_clip
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"propensity_clip must be a (low, high) pair, got {propensity_clip!r}"
        ) from exc

    # NOTE(review): the docstring states both "n_folds must be >= 2" and
    # "raises if n_folds is not 2, 5, or 10". No n_folds check happens in this
    # wrapper; confirm against the Rust layer which rule is actually enforced.

    # Call the Rust implementation (arguments are passed positionally)
    result = _dml_impl(
        df,
        y_col,
        d_col,
        x_cols_list,
        n_folds,
        treatment_type_normalized,
        estimate_cate,
        alpha,
        propensity_clip_low,
        propensity_clip_high,
        cluster,
        seed,
    )
    return result
[docs]
def two_stage_least_squares(
    df: _Union[_pl.DataFrame, "pd.DataFrame"],
    y_col: str,
    d_cols: _Union[str, _List[str]],
    z_cols: _Union[str, _List[str]],
    x_cols: _Optional[_Union[str, _List[str]]] = None,
    include_intercept: bool = True,
    robust: bool = False,
    cluster: _Optional[str] = None,
) -> TwoStageLSResult:
    """
    Estimate causal effects using Two-Stage Least Squares (2SLS) instrumental variables.

    The 2SLS estimator addresses endogeneity problems where treatment variables
    are correlated with the error term. It uses instrumental variables (Z) that
    affect the outcome (Y) only through their effect on the endogenous treatment (D).

    **Algorithm**:

    *First Stage*: Regress each endogenous variable on instruments and exogenous controls:

    .. math::
        D = Z\\pi + X\\delta + \\nu

    *Second Stage*: Regress outcome on predicted endogenous values and controls:

    .. math::
        Y = \\hat{D}\\beta + X\\gamma + \\epsilon

    **CRITICAL**: Standard errors are computed using residuals from the **original D**,
    not the predicted D̂. Using D̂ would understate variance.

    Parameters
    ----------
    df : pl.DataFrame or pd.DataFrame
        The DataFrame containing the data. Accepts both Polars and pandas.
    y_col : str
        Name of the outcome variable column
    d_cols : str or List[str]
        Name(s) of the endogenous treatment variable column(s).
        These are the variables suspected to be correlated with the error term.
    z_cols : str or List[str]
        Name(s) of the excluded instrument column(s).
        Instruments must:

        - Affect Y only through D (exclusion restriction)
        - Be correlated with D (relevance)
        - Be uncorrelated with the error (exogeneity)
    x_cols : str or List[str], optional
        Name(s) of exogenous control variable column(s).
        These variables are included in both stages.
    include_intercept : bool, default=True
        Whether to include an intercept term in both stages.
    robust : bool, default=False
        If True, compute HC3 heteroskedasticity-robust standard errors.
        If False, compute conventional (homoskedastic) standard errors.
    cluster : str, optional
        Column name for cluster identifiers for cluster-robust SE.
        When specified, computes cluster-robust standard errors.

    Returns
    -------
    TwoStageLSResult
        Result object with the following attributes:

        - coefficients : List[float]
            Structural coefficients for endogenous + exogenous variables
            (excluding intercept). Order: d_cols first, then x_cols.
        - standard_errors : List[float]
            Standard errors for all coefficients
        - intercept : float or None
            Intercept term (None if include_intercept=False)
        - intercept_se : float or None
            Standard error for intercept
        - n_samples : int
            Number of observations used
        - n_endogenous : int
            Number of endogenous regressors
        - n_instruments : int
            Number of excluded instruments
        - first_stage_f : List[float]
            F-statistics for each endogenous variable.
            Rule of thumb: F < 10 suggests weak instruments.
            F < 4 raises an error (too weak for reliable inference).
        - first_stage_coefficients : List[List[float]]
            First-stage coefficients for instruments only (per endogenous variable)
        - cragg_donald : float or None
            Cragg-Donald statistic for multiple endogenous variables.
            None for single endogenous (use first_stage_f instead).
        - stock_yogo_critical : float or None
            Stock-Yogo 10% maximal bias critical value for comparison
            with Cragg-Donald statistic
        - r_squared : float
            R² from structural equation (using original D)
        - se_type : str
            Type of standard errors: "conventional", "hc3", or "clustered"
        - n_clusters : int or None
            Number of clusters if clustered SE used

    Raises
    ------
    ValueError
        - If d_cols is empty
        - If z_cols is empty
        - If number of instruments < number of endogenous variables (under-identified)
        - If first-stage F-statistic < 4 for any endogenous variable
        - If first-stage or second-stage design matrix is singular
        - If not enough observations for the number of parameters
        - If any column contains null values
        - If any column has zero variance
        - If cluster column not found when cluster specified

    Warns
    -----
    UserWarning
        - If first-stage F-statistic < 10 (weak instruments warning)
        - If Cragg-Donald < Stock-Yogo critical value
        - If number of instruments > n_samples/10

    Examples
    --------
    Basic IV regression with one endogenous variable:

    >>> import polars as pl
    >>> import causers
    >>> # Angrist-Krueger style: quarter of birth as instrument for education
    >>> df = pl.DataFrame({
    ...     "wage": [2.5, 3.0, 2.8, 3.5, 3.2, 4.0, 3.8, 4.5, 4.2, 5.0],
    ...     "educ": [10, 11, 10, 12, 11, 14, 13, 15, 14, 16],
    ...     "quarter_born": [1, 4, 2, 3, 1, 4, 2, 4, 3, 4],
    ...     "age": [30, 32, 31, 33, 34, 35, 36, 38, 40, 42]
    ... })
    >>> result = causers.two_stage_least_squares(
    ...     df,
    ...     y_col="wage",
    ...     d_cols="educ",
    ...     z_cols="quarter_born",
    ...     x_cols="age"
    ... )
    >>> print(f"Returns to education: {result.coefficients[0]:.4f}")
    Returns to education: ...
    >>> print(f"First-stage F: {result.first_stage_f[0]:.1f}")
    First-stage F: ...

    Multiple endogenous variables:

    >>> # Two endogenous vars with two instruments
    >>> df_multi = df.with_columns(
    ...     pl.Series("year_born", [1990, 1988, 1989, 1987, 1986,
    ...                             1985, 1984, 1982, 1980, 1979])
    ... )
    >>> result = causers.two_stage_least_squares(
    ...     df_multi,
    ...     y_col="wage",
    ...     d_cols=["educ", "age"],
    ...     z_cols=["quarter_born", "year_born"]
    ... )
    >>> print(f"Cragg-Donald: {result.cragg_donald}")
    Cragg-Donald: ...

    Robust standard errors:

    >>> result = causers.two_stage_least_squares(
    ...     df, "wage", "educ", "quarter_born",
    ...     robust=True
    ... )
    >>> print(f"SE type: {result.se_type}")
    SE type: hc3

    Clustered standard errors:

    >>> df = df.with_columns(pl.Series("state", [1, 1, 2, 2, 3, 3, 4, 4, 5, 5]))
    >>> result = causers.two_stage_least_squares(
    ...     df, "wage", "educ", "quarter_born",
    ...     cluster="state"
    ... )
    >>> print(f"N clusters: {result.n_clusters}")
    N clusters: 5

    Notes
    -----
    **Identification Requirements**

    The model is identified when:

    - m ≥ k₁ (number of instruments ≥ number of endogenous variables)
    - m = k₁: exactly identified (use all instruments)
    - m > k₁: over-identified (more instruments than needed)

    **Weak Instrument Detection**

    Weak instruments (those poorly correlated with D) lead to:

    - Biased 2SLS estimates (toward OLS bias)
    - Unreliable inference (undersized confidence intervals)

    This function uses two diagnostics:

    1. **First-stage F-statistic**: For single endogenous variable.
       Rule of thumb: F > 10 is generally considered acceptable.
       F < 4 raises an error as inference is unreliable.
    2. **Cragg-Donald statistic**: For multiple endogenous variables.
       Compare to Stock-Yogo critical values for desired bias/size control.

    **Standard Errors**

    Three options are available:

    - **Conventional**: Assumes homoskedasticity. Use when error variance
      is believed constant across observations.
    - **HC3 Robust**: Heteroskedasticity-consistent standard errors.
      Recommended when error variance may vary with X.
    - **Clustered**: For data with within-cluster correlation
      (e.g., students within schools, observations over time).

    References
    ----------
    Angrist, J. D., & Pischke, J. S. (2009). Mostly Harmless Econometrics:
    An Empiricist's Companion. Princeton University Press.

    Stock, J. H., & Yogo, M. (2005). Testing for Weak Instruments in Linear
    IV Regression. In D. W. K. Andrews & J. H. Stock (Eds.), Identification
    and Inference for Econometric Models (pp. 80-108). Cambridge University Press.

    Staiger, D., & Stock, J. H. (1997). Instrumental Variables Regression with
    Weak Instruments. Econometrica, 65(3), 557-586.

    See Also
    --------
    TwoStageLSResult : Result class with coefficients and diagnostics.
    linear_regression : For standard OLS regression without instruments.
    dml : For causal inference using machine learning methods.
    """

    def _as_list(cols):
        # A bare string is ONE column name, not an iterable of characters.
        return [cols] if isinstance(cols, str) else list(cols)

    # Normalize every column argument to a list (x_cols stays None if omitted)
    d_cols_list = _as_list(d_cols)
    z_cols_list = _as_list(z_cols)
    x_cols_list: _Optional[_List[str]] = (
        None if x_cols is None else _as_list(x_cols)
    )

    # Detect and convert pandas DataFrames if needed; only the columns the
    # estimator touches are handed to the converter.
    all_cols = [y_col, *d_cols_list, *z_cols_list]
    if x_cols_list:
        all_cols.extend(x_cols_list)
    if cluster:
        all_cols.append(cluster)
    df = _convert_dataframe_if_pandas(df, all_cols)

    # Check for float cluster column and emit warning
    _warn_if_float_cluster(df, cluster)

    # Call the Rust implementation
    result = _two_stage_least_squares_rust(
        df,
        y_col,
        d_cols_list,
        z_cols_list,
        x_cols_list,
        include_intercept,
        robust,
        cluster,
    )

    # ========================================================================
    # Emit warnings via warnings.warn()
    # (issued directly from this frame so stacklevel=2 points at the caller)
    # ========================================================================

    # Weak instruments warning (F < 10 but >= 4)
    # Note: F < 4 raises an error in Rust, so we only warn for 4 <= F < 10
    n_named = len(d_cols_list)
    for i, f_stat in enumerate(result.first_stage_f):
        if f_stat < 10.0:
            # Fall back to a positional label if Rust reports more
            # F-statistics than there are named endogenous columns.
            endog_name = d_cols_list[i] if i < n_named else f"D{i}"
            _warnings.warn(
                f"Weak instruments: first-stage F-statistic ({f_stat:.2f}) is below 10 "
                f"for endogenous variable '{endog_name}'",
                UserWarning,
                stacklevel=2,
            )

    # Cragg-Donald below Stock-Yogo critical value (both may be None)
    if (
        result.cragg_donald is not None
        and result.stock_yogo_critical is not None
        and result.cragg_donald < result.stock_yogo_critical
    ):
        _warnings.warn(
            f"Cragg-Donald statistic ({result.cragg_donald:.2f}) is below "
            f"Stock-Yogo 10% critical value ({result.stock_yogo_critical:.2f})",
            UserWarning,
            stacklevel=2,
        )

    # Many instruments relative to sample size. For integer counts,
    # k > n // 10 is equivalent to the documented k > n / 10.
    if result.n_instruments > result.n_samples // 10:
        _warnings.warn(
            f"Large number of instruments ({result.n_instruments}) relative to "
            f"sample size ({result.n_samples}); consider using fewer",
            UserWarning,
            stacklevel=2,
        )

    return result
# ============================================================
# BALANCE CHECK
# ============================================================
[docs]
class BalanceCheckResult:
    """Python wrapper around the Rust BalanceResult with convenience methods.

    This class wraps the native ``BalanceResult`` returned by Rust and adds
    ``summary()``, ``imbalanced()``, and ``to_dataframe()`` helper methods.
    All attributes of the underlying Rust object are accessible directly
    (e.g. ``result.smd``, ``result.n_treated``).

    Attributes
    ----------
    mean_treated : dict[str, float]
        Mean of each covariate in the treatment group.
    mean_control : dict[str, float]
        Mean of each covariate in the control group.
    var_treated : dict[str, float]
        Variance of each covariate in the treatment group.
    var_control : dict[str, float]
        Variance of each covariate in the control group.
    sd_treated : dict[str, float]
        Standard deviation of each covariate in the treatment group.
    sd_control : dict[str, float]
        Standard deviation of each covariate in the control group.
    smd : dict[str, float]
        Standardized Mean Difference for each covariate.
    variance_ratio : dict[str, float]
        Variance ratio (treated / control) for each covariate.
    n_treated : int
        Number of observations in the treatment group.
    n_control : int
        Number of observations in the control group.
    ess_treated : float or None
        Effective sample size for the treatment group (weighted analysis only).
    ess_control : float or None
        Effective sample size for the control group (weighted analysis only).
    covariates : list[str]
        Covariate names (after categorical expansion).
    is_weighted : bool
        Whether weighted statistics were computed.
    """

    def __init__(self, _inner: "_BalanceResult") -> None:
        """Wrap *_inner*, the native result object to delegate to."""
        self._inner = _inner

    def __getattr__(self, name: str):
        """Forward any attribute not found on the wrapper to the wrapped object.

        Raises
        ------
        AttributeError
            If the wrapped object is not set yet, or does not have *name*.
        """
        # __getattr__ only runs after normal lookup failed. If the missing
        # name is `_inner` itself (instance created through __new__, as
        # copy/pickle do before restoring state), delegating via self._inner
        # would re-enter this method forever and end in RecursionError.
        if name == "_inner":
            raise AttributeError(name)
        return getattr(self._inner, name)

    def __repr__(self) -> str:
        """Return the wrapped object's repr."""
        return repr(self._inner)

    def _per_covariate_frame(self, stat_names) -> "_pl.DataFrame":
        """Build a one-row-per-covariate frame from the named stat dicts.

        Each name in *stat_names* is read as a dict attribute of the wrapped
        result; a covariate missing from a dict yields ``None`` in that column.
        """
        inner = self._inner
        tables = [(stat, getattr(inner, stat)) for stat in stat_names]
        rows = []
        for cov in inner.covariates:
            row = {"covariate": cov}
            for stat, table in tables:
                row[stat] = table.get(cov)
            rows.append(row)
        return _pl.DataFrame(rows)

    def summary(self) -> "_pl.DataFrame":
        """Return a Polars DataFrame summarizing balance statistics.

        Columns: covariate, mean_treated, mean_control, sd_treated,
        sd_control, smd, variance_ratio.

        Returns
        -------
        pl.DataFrame
            One row per covariate with key balance statistics.
        """
        return self._per_covariate_frame((
            "mean_treated",
            "mean_control",
            "sd_treated",
            "sd_control",
            "smd",
            "variance_ratio",
        ))

    def imbalanced(self, threshold: float = 0.1) -> _List[str]:
        """Return covariate names with |SMD| exceeding *threshold*.

        Parameters
        ----------
        threshold : float, default 0.1
            Absolute SMD threshold for flagging imbalance.

        Returns
        -------
        list[str]
            Covariate names whose absolute SMD exceeds the threshold.
            Covariates with NaN SMD (e.g. from zero variance in both
            groups) are excluded.
        """
        smd = self._inner.smd
        flagged = []
        for cov in self._inner.covariates:
            # A covariate absent from the SMD dict is treated as balanced.
            value = smd.get(cov, 0.0)
            if not _math.isnan(value) and abs(value) > threshold:
                flagged.append(cov)
        return flagged

    def to_dataframe(self) -> "_pl.DataFrame":
        """Export all per-covariate statistics as a Polars DataFrame.

        Columns: covariate, mean_treated, mean_control, var_treated,
        var_control, sd_treated, sd_control, smd, variance_ratio.

        Returns
        -------
        pl.DataFrame
            One row per covariate with all computed statistics.
        """
        return self._per_covariate_frame((
            "mean_treated",
            "mean_control",
            "var_treated",
            "var_control",
            "sd_treated",
            "sd_control",
            "smd",
            "variance_ratio",
        ))
[docs]
def balance_check(
    df: _Union[_pl.DataFrame, "pd.DataFrame"],
    treatment_col: str,
    covariate_cols: _Union[str, _List[str]],
    weights: _Optional[str] = None,
    treatment_value=None,
    control_value=None,
    max_categorical_levels: int = 1000,
) -> BalanceCheckResult:
    """Check covariate balance between treatment and control groups.

    Computes standardized mean differences (SMD), variance ratios, and
    group-level summary statistics for each covariate. Supports weighted
    analysis (e.g. inverse-propensity weights) and automatic treatment /
    control value detection for binary indicators.

    Parameters
    ----------
    df : pl.DataFrame or pd.DataFrame
        The DataFrame containing the data. Accepts both Polars and pandas.
    treatment_col : str
        Column that identifies the treatment assignment.
    covariate_cols : str or list[str]
        Column name(s) of the covariates to check for balance.
    weights : str, optional
        Column name containing observation weights (e.g. IPW weights).
        When provided, weighted means / variances are computed.
    treatment_value : int, float, str, or None
        Value in ``treatment_col`` that denotes the *treated* group.
        If ``None``, auto-detected from the column, which must then hold
        exactly two unique non-null values: the larger one is used, unless
        ``control_value`` is given, in which case the value that is not
        ``control_value`` is used.
    control_value : int, float, str, or None
        Value in ``treatment_col`` that denotes the *control* group.
        If ``None``, auto-detected (the value that is not
        ``treatment_value``).
    max_categorical_levels : int, default 1000
        Maximum number of unique levels allowed for categorical columns
        before raising an error.

    Returns
    -------
    BalanceCheckResult
        Result object exposing all per-covariate statistics plus the
        convenience methods ``summary()``, ``imbalanced()``, and
        ``to_dataframe()``.

    Raises
    ------
    ValueError
        - If ``treatment_col`` or any covariate column is missing.
        - If auto-detection finds != 2 unique non-null treatment values.
        - If ``control_value`` is given without ``treatment_value`` and is
          not one of the two values observed in ``treatment_col``.
        - If the treatment column contains only one unique value.
    TypeError
        If ``treatment_value`` or ``control_value`` (given or auto-detected)
        is not an int, float, str, or bool.

    Warns
    -----
    UserWarning
        - Large imbalance: |SMD| > 0.25 for any covariate.
        - Extreme variance ratio: VR < 0.5 or VR > 2.0.
        - Zero variance of a covariate in exactly one of the two groups.
        - NaN (numerically unstable) weighted variance in weighted analysis.
        - Small treatment or control group (n < 10).
        - Low effective sample size (ESS < 10) in weighted analysis.

    Examples
    --------
    Basic balance check with a binary treatment:

    >>> import polars as pl
    >>> import causers
    >>> df = pl.DataFrame({
    ...     "treated": [1, 1, 1, 0, 0, 0],
    ...     "age": [25, 30, 35, 40, 45, 50],
    ...     "income": [50000, 60000, 55000, 48000, 52000, 47000],
    ... })
    >>> result = causers.balance_check(df, "treated", ["age", "income"])
    >>> print(result.smd)
    {...}

    Using the convenience methods:

    >>> summary_df = result.summary()
    >>> imbalanced_covs = result.imbalanced(threshold=0.1)
    >>> full_df = result.to_dataframe()

    Weighted balance check:

    >>> df = df.with_columns(pl.Series("w", [1.0, 1.0, 1.0, 0.5, 0.8, 0.7]))
    >>> result = causers.balance_check(df, "treated", ["age", "income"], weights="w")
    >>> print(result.is_weighted)
    True

    See Also
    --------
    BalanceCheckResult : Result class with balance statistics and helpers.
    """

    def _typed_slots(value, label):
        """Return ``(as_int, as_float, as_str)`` with one slot holding *value*."""
        # bool must be tested before int: bool is an int subclass and is
        # passed to Rust as 0 / 1.
        if isinstance(value, bool):
            return int(value), None, None
        if isinstance(value, int):
            return value, None, None
        if isinstance(value, float):
            return None, value, None
        if isinstance(value, str):
            return None, None, value
        raise TypeError(
            f"{label} must be int, float, or str; got {type(value).__name__}"
        )

    # ------------------------------------------------------------------
    # Normalize covariate_cols
    # ------------------------------------------------------------------
    if isinstance(covariate_cols, str):
        cov_cols_list = [covariate_cols]
    else:
        cov_cols_list = list(covariate_cols)

    # ------------------------------------------------------------------
    # pandas conversion
    # ------------------------------------------------------------------
    required_cols = [treatment_col, *cov_cols_list]
    if weights is not None:
        required_cols.append(weights)
    df = _convert_dataframe_if_pandas(df, required_cols)

    # ------------------------------------------------------------------
    # Auto-detect treatment / control values
    # ------------------------------------------------------------------
    if treatment_value is None:
        unique_vals = df[treatment_col].drop_nulls().unique().sort().to_list()
        if len(unique_vals) != 2:
            raise ValueError(
                f"Cannot auto-detect treatment/control values: "
                f"treatment column '{treatment_col}' has {len(unique_vals)} "
                f"unique non-null value(s); expected exactly 2."
            )
        if control_value is None:
            # Sorted ascending: smaller value is control, larger is treatment.
            control_value, treatment_value = unique_vals
        else:
            # Honour an explicit control_value instead of overwriting it:
            # treatment is whichever observed value is not the control.
            remaining = [v for v in unique_vals if v != control_value]
            if len(remaining) != 1:
                raise ValueError(
                    f"control_value {control_value!r} not found in treatment "
                    f"column '{treatment_col}'; observed values are {unique_vals!r}."
                )
            treatment_value = remaining[0]

    # Route each value to the typed parameter the Rust binding expects
    tv_int, tv_float, tv_str = _typed_slots(treatment_value, "treatment_value")
    if control_value is not None:
        cv_int, cv_float, cv_str = _typed_slots(control_value, "control_value")
    else:
        # All-None control slots are forwarded unchanged to the Rust call.
        cv_int = cv_float = cv_str = None

    # ------------------------------------------------------------------
    # Call the Rust implementation
    # ------------------------------------------------------------------
    inner = _balance_check_impl(
        df,
        treatment_col,
        cov_cols_list,
        weights,
        tv_int,
        tv_float,
        tv_str,
        cv_int,
        cv_float,
        cv_str,
        max_categorical_levels,
    )
    result = BalanceCheckResult(inner)

    # ------------------------------------------------------------------
    # Post-call warnings
    # (issued directly from this frame so stacklevel=2 points at the caller)
    # ------------------------------------------------------------------
    is_weighted = result.is_weighted

    # Per-covariate warnings (single pass)
    for cov in result.covariates:
        smd_val = result.smd.get(cov, 0.0)
        vr = result.variance_ratio.get(cov)
        vt = result.var_treated.get(cov)
        vc = result.var_control.get(cov)

        # Large imbalance (FR-BAL-64)
        if not _math.isnan(smd_val) and abs(smd_val) > 0.25:
            _warnings.warn(
                f"Large imbalance detected for covariate '{cov}': SMD = {smd_val:.4f}",
                UserWarning,
                stacklevel=2,
            )

        # Extreme variance ratio (FR-BAL-65)
        if vr is not None and not _math.isnan(vr) and (vr < 0.5 or vr > 2.0):
            _warnings.warn(
                f"Extreme variance ratio for covariate '{cov}': {vr:.4f}",
                UserWarning,
                stacklevel=2,
            )

        # Zero variance in exactly one group (FR-BAL-62)
        if vt is not None and vc is not None:
            vt_zero = vt == 0.0
            vc_zero = vc == 0.0
            if vt_zero and not vc_zero:
                _warnings.warn(
                    f"Covariate '{cov}' has zero variance in treatment group",
                    UserWarning,
                    stacklevel=2,
                )
            elif vc_zero and not vt_zero:
                _warnings.warn(
                    f"Covariate '{cov}' has zero variance in control group",
                    UserWarning,
                    stacklevel=2,
                )

        # Numerically unstable weighted variance (FR-BAL-69a)
        if is_weighted:
            if vt is not None and _math.isnan(vt):
                _warnings.warn(
                    f"Weighted variance for '{cov}' in treatment group is numerically "
                    f"unstable (ESS \u2248 1); returning NaN",
                    UserWarning,
                    stacklevel=2,
                )
            if vc is not None and _math.isnan(vc):
                _warnings.warn(
                    f"Weighted variance for '{cov}' in control group is numerically "
                    f"unstable (ESS \u2248 1); returning NaN",
                    UserWarning,
                    stacklevel=2,
                )

    # Small samples
    if result.n_treated < 10:
        _warnings.warn(
            f"Small treatment group: n_treated = {result.n_treated}. "
            f"Balance statistics may be unreliable.",
            UserWarning,
            stacklevel=2,
        )
    if result.n_control < 10:
        _warnings.warn(
            f"Small control group: n_control = {result.n_control}. "
            f"Balance statistics may be unreliable.",
            UserWarning,
            stacklevel=2,
        )

    # Low effective sample size (ESS is None for unweighted analysis)
    if result.ess_treated is not None and result.ess_treated < 10:
        _warnings.warn(
            f"Low effective sample size in treatment group: ESS = {result.ess_treated:.2f}. "
            f"Weighted balance statistics may be unreliable.",
            UserWarning,
            stacklevel=2,
        )
    if result.ess_control is not None and result.ess_control < 10:
        _warnings.warn(
            f"Low effective sample size in control group: ESS = {result.ess_control:.2f}. "
            f"Weighted balance statistics may be unreliable.",
            UserWarning,
            stacklevel=2,
        )

    return result