Source code for trancit.utils.core
"""
Core utilities for Dynamic Causal Strength (DCS).
This module provides fundamental utilities for event extraction, statistics computation,
and DeSnap analysis. These functions form the backbone of the DCS pipeline.
"""
import logging
from typing import Dict, Union
import numpy as np
from ..config import DeSnapParams
from ..core.exceptions import ComputationError, ValidationError
from .preprocess import regularize_if_singular
from .residuals import estimate_residuals
logger = logging.getLogger(__name__)
def extract_event_windows(
    signal: np.ndarray, centers: np.ndarray, start_offset: int, window_length: int
) -> np.ndarray:
    """
    Extract fixed-length windows from a 1D signal around given center indices.

    For every center ``c`` the window covers the samples
    ``start .. start + window_length - 1`` with
    ``start = round(c - start_offset)``. A window that would reach outside the
    signal is skipped: a warning is logged and its column stays NaN.

    Parameters
    ----------
    signal : np.ndarray
        1D array representing the signal data.
    centers : np.ndarray
        1D array of center points (indices) around which to extract windows.
    start_offset : int
        Number of samples subtracted from each center to obtain the first
        sample of its window (a negative value starts the window after the
        center).
    window_length : int
        Length of each window to extract. Must be positive.

    Returns
    -------
    np.ndarray
        Float array of shape (window_length, len(centers)). Column ``i`` holds
        the window for ``centers[i]``, or NaN if that window could not be
        extracted.

    Raises
    ------
    ValidationError
        If `signal` or `centers` is not a 1D numpy array, `start_offset` is
        not an int, or `window_length` is not a positive int.

    Notes
    -----
    Out-of-bounds windows never raise; they are reported through the module
    logger and returned as NaN columns.

    Examples
    --------
    >>> import numpy as np
    >>> signal = np.random.randn(1000)
    >>> centers = np.array([100, 200, 300])
    >>> windows = extract_event_windows(
    ...     signal, centers, start_offset=50, window_length=100
    ... )
    >>> print(f"Windows shape: {windows.shape}")  # (100, 3)
    """
    if not isinstance(signal, np.ndarray) or signal.ndim != 1:
        raise ValidationError(
            "signal must be a 1D numpy array",
            "signal_ndim",
            signal.ndim if hasattr(signal, "ndim") else None,
        )
    if not isinstance(centers, np.ndarray) or centers.ndim != 1:
        raise ValidationError(
            "centers must be a 1D numpy array",
            "centers_ndim",
            centers.ndim if hasattr(centers, "ndim") else None,
        )
    if not isinstance(start_offset, int):
        raise ValidationError(
            "start_offset must be an integer", "start_offset_type", type(start_offset)
        )
    if not isinstance(window_length, int) or window_length <= 0:
        raise ValidationError(
            "window_length must be a positive integer", "window_length", window_length
        )

    n_centers = len(centers)
    if n_centers == 0:
        logger.warning("Empty centers array provided")
        return np.empty((window_length, 0))

    n_samples = len(signal)
    event_windows = np.full((window_length, n_centers), np.nan)
    # Count successes explicitly: using NaN as a sentinel would miscount
    # windows whose first sample is a genuine NaN in the signal.
    n_extracted = 0
    for i, center in enumerate(centers):
        try:
            start_idx = int(np.round(center - start_offset))
            end_idx = start_idx + window_length
            # The window is contiguous, so checking its two ends is enough.
            if start_idx < 0 or end_idx > n_samples:
                logger.warning(
                    "Window %d (center=%s) out of bounds: "
                    "indices %d-%d for signal length %d",
                    i,
                    center,
                    start_idx,
                    end_idx - 1,
                    n_samples,
                )
                continue
            event_windows[:, i] = signal[start_idx:end_idx]
            n_extracted += 1
        except (TypeError, ValueError, OverflowError) as e:
            # Best effort: a NaN/inf center or a non-numeric value only
            # invalidates its own window, not the whole extraction.
            logger.error("Failed to extract window %d (center=%s): %s", i, center, e)
            continue

    logger.info("Successfully extracted %d/%d windows", n_extracted, n_centers)
    return event_windows
[docs]
def compute_event_statistics(
    event_data: np.ndarray, model_order: int, epsilon: float = 1e-5
) -> Dict[str, Union[np.ndarray, Dict]]:
    """
    Compute conditional statistics for VAR time series events.

    For every time point the function computes the across-trial mean, the
    covariance of the stacked (current + lagged) variables and the OLS
    coefficients that map the lagged variables onto the current ones.

    Parameters
    ----------
    event_data : np.ndarray
        VAR time series events of shape
        (nvar * (model_order + 1), time_points, trials). The first ``nvar``
        rows are treated as the current variables, the remaining rows as the
        lagged ("past") variables.
    model_order : int
        The model order for the VAR process. Must be positive.
    epsilon : float, optional
        Threshold on ``det(Sigma_past)``: when the determinant does not
        exceed it, the matrix is passed to ``regularize_if_singular`` with
        this same value. Default is 1e-5.

    Returns
    -------
    Dict[str, Union[np.ndarray, Dict]]
        Dictionary containing the conditional statistics:
        - 'mean': Mean over trials (shape: (nvar * (model_order + 1),
          time_points))
        - 'Sigma': Covariance matrices, normalised by the number of trials
          (shape: (time_points, nvar * (model_order + 1),
          nvar * (model_order + 1)))
        - 'OLS': Dictionary with:
            - 'At': OLS coefficients (shape: (time_points, nvar,
              nvar * model_order))
            - 'bt': Residual biases
            - 'Sigma_Et': Residual covariance
            - 'sigma_Et': Residual standard deviation
          The last three are whatever ``estimate_residuals`` returns for
          these statistics.

    Raises
    ------
    ValidationError
        If `event_data` is not a 3D numpy array, `model_order` is not a
        positive int, `epsilon` is not a positive number, the first dimension
        is not a multiple of ``model_order + 1``, or fewer than 2 trials are
        given.
    ComputationError
        If the computation fails. A failure at a specific time point is
        raised with that time index as context and is not re-wrapped.

    Examples
    --------
    >>> import numpy as np
    >>> event_data = np.random.randn(6, 50, 10)  # (nvar * (morder + 1), time, trials)
    >>> stats = compute_event_statistics(event_data, model_order=2)
    >>> print(f"Mean shape: {stats['mean'].shape}")
    >>> print(f"Sigma shape: {stats['Sigma'].shape}")
    """
    if not isinstance(event_data, np.ndarray) or event_data.ndim != 3:
        raise ValidationError(
            "event_data must be a 3D numpy array",
            "event_data_ndim",
            event_data.ndim if hasattr(event_data, "ndim") else None,
        )
    if not isinstance(model_order, int) or model_order <= 0:
        raise ValidationError(
            "model_order must be a positive integer", "model_order", model_order
        )
    if not isinstance(epsilon, (int, float)) or epsilon <= 0:
        raise ValidationError("epsilon must be a positive number", "epsilon", epsilon)

    n_rows, n_time_points, n_trials = event_data.shape
    nvar = n_rows // (model_order + 1)
    if n_rows != nvar * (model_order + 1):
        raise ValidationError(
            f"event_data shape {event_data.shape[0]} is not compatible "
            f"with model_order {model_order}",
            "event_data_shape",
            event_data.shape,
        )
    if n_trials < 2:
        raise ValidationError(
            "At least 2 trials required for statistics computation",
            "n_trials",
            n_trials,
        )

    try:
        stats = {
            "mean": np.mean(event_data, axis=2),
            "Sigma": np.zeros((n_time_points, n_rows, n_rows)),
            "OLS": {"At": np.zeros((n_time_points, nvar, nvar * model_order))},
        }
        for t in range(n_time_points):
            try:
                # Centre the trials at this time point, then form the
                # covariance normalised by the number of trials.
                temp = event_data[:, t, :] - stats["mean"][:, t, np.newaxis]
                stats["Sigma"][t, :, :] = np.dot(temp, temp.T) / n_trials

                # Cross-covariance (current vs. past) and past covariance.
                Sigma_sub_matrix = stats["Sigma"][t, :nvar, nvar:]
                Sigma_past = stats["Sigma"][t, nvar:, nvar:]
                if np.linalg.det(Sigma_past) > epsilon:
                    Sigma_past_inv = np.linalg.inv(Sigma_past)
                    stats["OLS"]["At"][t, :, :] = np.dot(
                        Sigma_sub_matrix, Sigma_past_inv
                    )
                else:
                    logger.warning(
                        f"Singular Sigma_past at time {t}, using regularization"
                    )
                    Sigma_past_regularized = regularize_if_singular(
                        Sigma_past, epsilon=epsilon
                    )
                    stats["OLS"]["At"][t, :, :] = np.linalg.solve(
                        Sigma_past_regularized, Sigma_sub_matrix.T
                    ).T
            except Exception as e:
                logger.error(f"Failed to compute statistics at time {t}: {e}")
                raise ComputationError(
                    f"Statistics computation failed at time {t}",
                    "statistics_computation",
                    (t,),
                ) from e

        stats["OLS"]["bt"], stats["OLS"]["Sigma_Et"], stats["OLS"]["sigma_Et"] = (
            estimate_residuals(stats)
        )
        logger.info(
            f"Successfully computed statistics for {event_data.shape[1]} time points"
        )
        return stats
    except ComputationError:
        # Already specific (carries the failing time index); re-wrapping it
        # in a generic error would discard that context.
        raise
    except Exception as e:
        logger.error(f"Statistics computation failed: {e}")
        raise ComputationError(
            f"Statistics computation failed: {e}", "statistics_computation", ()
        ) from e
def extract_event_snapshots(
    time_series: np.ndarray,
    locations: np.ndarray,
    model_order: int,
    lag_step: int,
    start_offset: int,
    extract_length: int,
) -> np.ndarray:
    """
    Extract event snapshots from time series data.

    For every event location a fixed-length window is cut from each variable
    at lag 0 and at ``model_order`` further lags spaced ``lag_step`` samples
    apart. The windows are stacked lag-major along the first axis: rows
    ``0 .. n_vars - 1`` hold all variables at lag 0, rows
    ``n_vars .. 2 * n_vars - 1`` hold all variables at lag ``lag_step``, and
    so on.

    Parameters
    ----------
    time_series : np.ndarray
        Time series data of shape (n_vars, n_time_points).
    locations : np.ndarray
        1D array of event location indices.
    model_order : int
        Model order (number of lags to include). Must be positive.
    lag_step : int
        Step size between lags, in samples. Must be positive.
    start_offset : int
        Offset from event location to start extraction; it is forwarded
        unchanged to ``extract_event_windows``.
    extract_length : int
        Length of each snapshot to extract. Must be positive.

    Returns
    -------
    np.ndarray
        3D array of shape (n_vars * (model_order + 1), extract_length, n_events)
        containing the extracted snapshots. Windows that fall outside the
        time series are NaN-filled by ``extract_event_windows``.

    Raises
    ------
    ValidationError
        If input parameters are invalid.
    IndexError
        If any event location itself is negative or not smaller than the
        number of time points.
    ComputationError
        If the window extraction fails for any other reason.

    Examples
    --------
    >>> import numpy as np
    >>> time_series = np.random.randn(2, 1000)  # (n_vars, time)
    >>> locations = np.array([100, 200, 300])
    >>> snapshots = extract_event_snapshots(
    ...     time_series, locations, model_order=4, lag_step=1,
    ...     start_offset=50, extract_length=100
    ... )
    >>> print(f"Snapshots shape: {snapshots.shape}")  # (10, 100, 3)
    """
    if not isinstance(time_series, np.ndarray) or time_series.ndim != 2:
        raise ValidationError(
            "time_series must be a 2D numpy array",
            "time_series_ndim",
            time_series.ndim if hasattr(time_series, "ndim") else None,
        )
    if not isinstance(locations, np.ndarray) or locations.ndim != 1:
        raise ValidationError(
            "locations must be a 1D numpy array",
            "locations_ndim",
            locations.ndim if hasattr(locations, "ndim") else None,
        )
    if not isinstance(model_order, int) or model_order <= 0:
        raise ValidationError(
            "model_order must be a positive integer", "model_order", model_order
        )
    if not isinstance(lag_step, int) or lag_step <= 0:
        raise ValidationError(
            "lag_step must be a positive integer", "lag_step", lag_step
        )
    if not isinstance(start_offset, int):
        raise ValidationError(
            "start_offset must be an integer", "start_offset", start_offset
        )
    if not isinstance(extract_length, int) or extract_length <= 0:
        raise ValidationError(
            "extract_length must be a positive integer",
            "extract_length",
            extract_length,
        )

    n_vars, n_time_points = time_series.shape
    n_rows = n_vars * (model_order + 1)
    if len(locations) == 0:
        logger.warning("Empty locations array provided")
        return np.empty((n_rows, extract_length, 0))

    min_location = np.min(locations)
    max_location = np.max(locations)
    if min_location < 0 or max_location >= n_time_points:
        raise IndexError(
            f"Event locations {min_location}-{max_location} out of bounds "
            f"for time series length {n_time_points}"
        )

    try:
        n_events = len(locations)
        snapshots = np.zeros((n_rows, extract_length, n_events))

        # Row n holds variable (n % n_vars) at lag (n // n_vars) * lag_step.
        # The variable index cycles fastest (tile) while the delay stays
        # constant within each group of n_vars rows (repeat); flattening a
        # tiled (n_vars, model_order + 1) array row-wise would instead cycle
        # the delay fastest and pair variables with the wrong lags.
        var_index = np.tile(np.arange(n_vars), model_order + 1)
        delay = np.repeat(np.arange(model_order + 1) * lag_step, n_vars)

        for row, (var, lag) in enumerate(zip(var_index, delay)):
            snapshots[row, :, :] = extract_event_windows(
                time_series[var, :],
                locations - lag,
                start_offset,
                extract_length,
            )

        # An event counts as valid only if none of its rows was NaN-filled.
        valid_snapshots = np.sum(~np.isnan(snapshots[:, 0, :]).any(axis=0))
        logger.info(f"Successfully extracted {valid_snapshots}/{n_events} snapshots")
        return snapshots
    except Exception as e:
        logger.error(f"Snapshot extraction failed: {e}")
        raise ComputationError(
            f"Snapshot extraction failed: {e}", "snapshot_extraction", None
        ) from e
[docs]
def perform_desnap_analysis(inputs: DeSnapParams) -> Dict[str, Union[np.ndarray, Dict]]:
    """
    Perform DeSnap analysis for event-based causality.

    This function implements the DeSnap (De-Snapshot) analysis method for
    event-based causality analysis. It derives unconditional statistics from
    conditional statistics by accounting for a conditioning variable 'D'
    (``inputs.detection_signal``). The steps are:

    1. Build threshold bins of D between ``d0`` and ``d0_max`` and average
       the event snapshots of ``inputs.original_signal`` within each bin.
    2. First regression: fit ``p_t`` and ``q_t`` from the bin means of D and
       the binned snapshot means.
    3. Second regression: least-squares estimate of ``mu_D`` from ``p_t`` and
       ``-q_t``; the unconditional mean is ``q_t + p_t * mu_D``.
    4. Third regression: estimate the covariance adjustment factor ``c`` and
       subtract ``c * outer(p_t, p_t)`` from the conditional covariance.
    5. Recompute the AR coefficients from the unconditional covariance.

    Parameters
    ----------
    inputs : DeSnapParams
        Configuration object containing all parameters for DeSnap analysis.
        If ``inputs.d0_max`` is None it is set in place to
        ``mean(D) + maxStdRatio * std(D)``.

    Returns
    -------
    Dict[str, Union[np.ndarray, Dict]]
        Dictionary containing DeSnap analysis results:
        - 'loc_size': Number of usable event locations per bin of D
        - 'event_stats': The conditional statistics passed in via `inputs`
        - 'p_t', 'q_t': Coefficients from the first linear regression
        - 'd_bin_bar': Mean D values for each bin
        - 'mean_Yt_cond': Mean Yt events per bin of D
        - 'mu_D': Estimated unconditional mean of D
        - 'event_stats_uncond': Unconditional statistics ('mean', 'Sigma'
          and 'OLS' -> 'At')
        - 'cov_pt': Covariance related to p_t
        - 'c': Covariance adjustment factor

    Raises
    ------
    ValidationError
        If `inputs` is not a DeSnapParams instance.
    ComputationError
        If any step of the analysis fails. Invalid parameter values detected
        inside the analysis (missing d0_max/maxStdRatio, non-positive N_d,
        wrong signal dimensions, degenerate bin width) are also reported as
        ComputationError, with the underlying ValidationError chained as the
        cause.

    Examples
    --------
    >>> from trancit.config import DeSnapParams
    >>> params = DeSnapParams(...)  # Configure parameters
    >>> results = perform_desnap_analysis(params)
    >>> print(f"Results keys: {list(results.keys())}")
    """
    if not isinstance(inputs, DeSnapParams):
        raise ValidationError(
            "inputs must be a DeSnapParams object", "inputs_type", type(inputs)
        )

    try:
        _validate_desnap_inputs(inputs)

        if inputs.d0_max is None:
            # Validation guarantees maxStdRatio is available here.
            inputs.d0_max = np.mean(
                inputs.detection_signal
            ) + inputs.maxStdRatio * np.std(inputs.detection_signal)

        logger.info(
            f"DeSnap analysis: d0={inputs.d0}, d0_max={inputs.d0_max}, N_d={inputs.N_d}"
        )

        bin_step = abs(inputs.d0_max - inputs.d0) / inputs.N_d
        # A zero, NaN or infinite step cannot define bin edges.
        if not 0 < bin_step < np.inf:
            raise ValidationError(
                "d0_max must be finite and different from d0",
                "d0_max",
                inputs.d0_max,
            )
        d_bin_edges = np.arange(inputs.d0, inputs.d0_max + bin_step + 1e-12, bin_step)
        num_bins = len(d_bin_edges)
        num_snapshot_vars = inputs.original_signal.shape[0] * (inputs.morder + 1)

        logger.info("Processing bins of conditioning variable D ...")
        d_bin_mean_detection, mean_events_cond_binned, loc_size = (
            _compute_binned_conditional_means(inputs, d_bin_edges)
        )
        DeSnap_results = {
            "loc_size": loc_size,
            "event_stats": inputs.event_stats,
        }

        # First Linear Regression: Fit p_t and q_t
        logger.info("Performing first linear regression for p_t and q_t...")
        try:
            from .helpers import compute_multi_variable_linear_regression

            p_t, q_t = compute_multi_variable_linear_regression(
                d_bin_mean_detection, mean_events_cond_binned
            )
            DeSnap_results["p_t"] = p_t
            DeSnap_results["q_t"] = q_t
            DeSnap_results["d_bin_bar"] = d_bin_mean_detection
            DeSnap_results["mean_Yt_cond"] = mean_events_cond_binned
        except Exception as e:
            logger.error(f"First linear regression failed: {e}")
            raise ComputationError(
                f"First linear regression failed: {e}", "linear_regression_1", None
            ) from e

        # Second Linear Regression: Estimate mu_D (unconditional mean of D)
        p_t_flat = p_t.reshape(-1, 1)
        q_t_flat = -q_t.reshape(-1)
        nan_mask_regression2 = ~np.isnan(p_t_flat.ravel()) & ~np.isnan(q_t_flat)
        if not np.any(nan_mask_regression2):
            raise ComputationError(
                "All p_t or q_t values are NaN, cannot compute mu_D",
                "mu_D_computation",
                None,
            )
        p_t_flat_valid = p_t_flat[nan_mask_regression2]
        q_t_flat_valid = q_t_flat[nan_mask_regression2]
        try:
            DeSnap_results["mu_D"] = np.linalg.lstsq(
                p_t_flat_valid, q_t_flat_valid, rcond=None
            )[0][0]
        except Exception as e:
            logger.error(f"Second linear regression failed: {e}")
            raise ComputationError(
                f"Second linear regression failed: {e}", "linear_regression_2", None
            ) from e

        DeSnap_results["event_stats_uncond"] = {}
        DeSnap_results["event_stats_uncond"]["mean"] = (
            q_t + p_t * DeSnap_results["mu_D"]
        )

        # Third Linear Regression: Compute Covariance Adjustment Factor 'c'
        logger.info(
            "Performing third linear regression for covariance adjustment factor 'c'..."
        )
        DeSnap_results["cov_pt"] = np.full(
            (inputs.l_extract, num_snapshot_vars, num_snapshot_vars), np.nan
        )
        for t in range(inputs.l_extract):
            DeSnap_results["cov_pt"][t, :, :] = np.outer(p_t[:, t], p_t[:, t])
        DeSnap_results["c"] = _estimate_covariance_factor(
            DeSnap_results["cov_pt"], inputs.event_stats["Sigma"], inputs.diff_flag
        )
        DeSnap_results["event_stats_uncond"]["Sigma"] = (
            inputs.event_stats["Sigma"] - DeSnap_results["c"] * DeSnap_results["cov_pt"]
        )

        logger.info("Calculating unconditional AR coefficients...")
        try:
            nvar_actual = inputs.event_stats["OLS"]["At"].shape[1]
        except (KeyError, AttributeError, IndexError) as err:
            raise ValidationError(
                "Could not determine 'nvar_actual' from "
                "inputs.event_stats['OLS']['At']",
                "ols_at_structure",
                None,
            ) from err
        DeSnap_results["event_stats_uncond"]["OLS"] = {}
        DeSnap_results["event_stats_uncond"]["OLS"]["At"] = (
            _compute_unconditional_ar_coefficients(
                DeSnap_results["event_stats_uncond"]["Sigma"],
                nvar_actual,
                inputs.morder,
                inputs.l_extract,
            )
        )

        logger.info(f"Successfully completed DeSnap analysis with {num_bins} bins")
        return DeSnap_results
    except ComputationError:
        # Step-specific errors already carry their own message and code.
        raise
    except Exception as e:
        logger.error(f"DeSnap analysis failed: {e}")
        raise ComputationError(
            f"DeSnap analysis failed: {e}", "desnap_analysis", ()
        ) from e


def _validate_desnap_inputs(inputs):
    """Check the DeSnap parameters used for binning; raise ValidationError if invalid."""
    if inputs.d0_max is None and inputs.maxStdRatio is None:
        raise ValidationError(
            "Either d0_max or maxStdRatio must be provided", "d0_max_missing", None
        )
    if inputs.N_d <= 0:
        raise ValidationError(
            "N_d (number of bins) must be positive", "N_d", inputs.N_d
        )
    if inputs.detection_signal.ndim != 1:
        raise ValidationError(
            "detection_signal must be 1D",
            "detection_signal_ndim",
            inputs.detection_signal.ndim,
        )
    if inputs.original_signal.ndim != 2:
        raise ValidationError(
            "original_signal must be 2D",
            "original_signal_ndim",
            inputs.original_signal.ndim,
        )
    if inputs.detection_signal.shape[0] != inputs.original_signal.shape[1]:
        raise ValidationError(
            "detection_signal and original_signal must have same time dimension",
            "signal_lengths",
            (inputs.detection_signal.shape[0], inputs.original_signal.shape[1]),
        )


def _compute_binned_conditional_means(inputs, d_bin_edges):
    """Return per-bin mean of D, per-bin mean event snapshots and per-bin location counts."""
    detection = inputs.detection_signal
    n_channels, n_time_points = inputs.original_signal.shape
    num_bins = len(d_bin_edges)
    num_snapshot_vars = n_channels * (inputs.morder + 1)

    d_bin_mean_detection = np.full(num_bins, np.nan)
    mean_events_cond_binned = np.full(
        (num_bins, num_snapshot_vars, inputs.l_extract), np.nan
    )
    loc_size = np.full(num_bins, np.nan)

    # Every bin shares the same (exclusive) upper limit; only the lower
    # threshold moves from edge to edge.
    current_bin_uplim = np.max(detection)

    # A location is usable only if its window fits inside the signal for
    # every lag: the most-lagged window starts at
    # loc - morder * tau - l_start and the lag-0 window ends at
    # loc - l_start + l_extract. Unusable locations would come back as
    # all-NaN snapshots and turn the whole bin mean into NaN.
    earliest_loc = inputs.l_start + inputs.morder * inputs.tau
    latest_loc = n_time_points + inputs.l_start - inputs.l_extract

    for n, current_bin_lolim in enumerate(d_bin_edges):
        mask = (detection >= current_bin_lolim) & (detection < current_bin_uplim)
        # Leave NaN for empty bins instead of averaging an empty slice.
        if mask.any():
            d_bin_mean_detection[n] = np.mean(detection[mask])

        temp_loc = np.where(mask)[0]
        valid_locs = temp_loc[(temp_loc >= earliest_loc) & (temp_loc <= latest_loc)]
        loc_size[n] = len(valid_locs)

        if len(valid_locs) == 0:
            logger.warning(
                f"No valid locations for snapshot extraction in bin {n+1}"
            )
            continue

        try:
            events_binned = extract_event_snapshots(
                inputs.original_signal,
                valid_locs,
                inputs.morder,
                inputs.tau,
                inputs.l_start,
                inputs.l_extract,
            )
            if events_binned.shape[2] > 0:
                mean_events_cond_binned[n, :, :] = np.mean(events_binned, axis=2)
            else:
                logger.warning(
                    f"No snapshots extracted for bin {n+1} despite "
                    f"{len(valid_locs)} valid_locs"
                )
        except Exception as e:
            # Best effort: a failing bin keeps its NaN mean and the
            # remaining bins are still processed.
            logger.error(f"Failed to extract snapshots for bin {n}: {e}")

    return d_bin_mean_detection, mean_events_cond_binned, loc_size


def _estimate_covariance_factor(cov_pt, sigma_cond, diff_flag):
    """Regress the conditional covariance on cov_pt and return the slope 'c'."""
    try:
        if diff_flag:
            # Regress time-differences of all covariance entries (no intercept).
            x_reg_c = np.diff(cov_pt, axis=0)
            y_reg_c = np.diff(sigma_cond, axis=0)
            return np.linalg.lstsq(
                x_reg_c.reshape(-1, 1), y_reg_c.reshape(-1), rcond=None
            )[0][0]
        # Regress the levels of the [0, 0] entry with an intercept; the
        # slope is the adjustment factor.
        x_reg_c_levels = cov_pt[:, 0, 0]
        y_reg_c_levels = sigma_cond[:, 0, 0]
        X_design_c = np.vstack([np.ones_like(x_reg_c_levels), x_reg_c_levels]).T
        temp_coeffs_c = np.linalg.lstsq(X_design_c, y_reg_c_levels, rcond=None)[0]
        return temp_coeffs_c[1]
    except Exception as e:
        logger.error(f"Third linear regression failed: {e}")
        raise ComputationError(
            f"Third linear regression failed: {e}", "linear_regression_3", None
        ) from e


def _compute_unconditional_ar_coefficients(sigma_uncond, nvar, morder, l_extract):
    """Return AR coefficients per time step from the unconditional covariance."""
    At = np.full((l_extract, nvar, nvar * morder), np.nan)
    for t in range(l_extract):
        Sigma_yx_uncond = sigma_uncond[t, :nvar, nvar:]
        Sigma_xx_uncond = sigma_uncond[t, nvar:, nvar:]
        # Defined before the try so the pseudo-inverse fallback never sees an
        # unassigned (or previous-step) matrix if regularization itself fails.
        Sigma_xx_uncond_reg = Sigma_xx_uncond
        try:
            Sigma_xx_uncond_reg = regularize_if_singular(Sigma_xx_uncond)
            if not np.allclose(Sigma_xx_uncond, Sigma_xx_uncond_reg):
                logger.warning(
                    f"DeSnap: Applied regularization to Sigma_xx_uncond "
                    f"at time step {t}"
                )
            At[t, :, :] = Sigma_yx_uncond @ np.linalg.inv(Sigma_xx_uncond_reg)
        except np.linalg.LinAlgError:
            logger.warning(
                f"DeSnap: Singular matrix at time step {t}, using pseudo-inverse"
            )
            At[t, :, :] = Sigma_yx_uncond @ np.linalg.pinv(Sigma_xx_uncond_reg)
        except Exception as e:
            # Leave this time step as NaN and keep going.
            logger.error(f"Failed to compute AR coefficients at time step {t}: {e}")
            continue
    return At