"""
Relative Dynamic Causal Strength (rDCS) implementation.
This module provides the implementation of time-varying causality
including Transfer Entropy (TE), Dynamic Causal Strength (DCS),
and Relative Dynamic Causal Strength (rDCS).
"""
import logging
from typing import Dict
import numpy as np
from trancit.core.base import BaseAnalyzer, BaseResult
from trancit.core.exceptions import ComputationError, ValidationError
from trancit.utils.preprocess import regularize_if_singular
logger = logging.getLogger(__name__)
[docs]
class RelativeDCSResult(BaseResult):
"""Result container for Relative Dynamic Causal Strength analysis."""
[docs]
def __init__(
self,
transfer_entropy: np.ndarray,
dynamic_causal_strength: np.ndarray,
relative_dynamic_causal_strength: np.ndarray,
coefficients: np.ndarray,
):
"""
Initialize Relative DCS result.
Parameters
----------
transfer_entropy : np.ndarray
Transfer Entropy values
dynamic_causal_strength : np.ndarray
Dynamic Causal Strength values
relative_dynamic_causal_strength : np.ndarray
Relative Dynamic Causal Strength values
coefficients : np.ndarray
VAR coefficients
"""
super().__init__(
transfer_entropy=transfer_entropy,
dynamic_causal_strength=dynamic_causal_strength,
relative_dynamic_causal_strength=relative_dynamic_causal_strength,
coefficients=coefficients,
)
[docs]
class RelativeDCSCalculator(BaseAnalyzer):
"""
Relative Dynamic Causal Strength (rDCS) calculator.
This class implements the Relative Dynamic Causal Strength algorithm
for quantifying causal relationships relative to a baseline period.
"""
[docs]
def __init__(
self,
model_order: int,
reference_time: int,
estimation_mode: str = "OLS",
use_diagonal_covariance: bool = False,
use_old_version: bool = False,
**kwargs,
):
"""
Initialize Relative DCS calculator.
Parameters
----------
model_order : int
Model order for VAR analysis
reference_time : int
Reference time point for baseline calculation
estimation_mode : str
Estimation mode: 'OLS' or 'RLS'
use_diagonal_covariance : bool
Whether to use diagonal covariance approximation
use_old_version : bool
Whether to use old version of rDCS calculation
**kwargs
Additional configuration parameters
"""
super().__init__(
model_order=model_order,
reference_time=reference_time,
estimation_mode=estimation_mode,
use_diagonal_covariance=use_diagonal_covariance,
use_old_version=use_old_version,
**kwargs,
)
[docs]
def analyze( # type: ignore[override]
self, event_data: np.ndarray, stats: Dict, **kwargs
) -> RelativeDCSResult:
"""
Perform Relative Dynamic Causal Strength analysis.
Parameters
----------
event_data : np.ndarray
Event data array of shape (nvar * (model_order + 1), nobs, ntrials)
stats : Dict
Model statistics with keys:
- 'OLS' or 'RLS': Sub-dict with 'At' (coefficients) and 'Sigma_Et'
(residual covariance)
- 'Sigma': Covariance matrices
- 'mean': Mean values
**kwargs
Additional parameters
Returns
-------
RelativeDCSResult
Relative DCS analysis results
"""
self._validate_input_data(event_data, stats)
self._log_analysis_start(event_data.shape)
try:
causal_params = {
"ref_time": self.config["reference_time"],
"estim_mode": self.config["estimation_mode"],
"morder": self.config["model_order"],
"diag_flag": self.config["use_diagonal_covariance"],
"old_version": self.config["use_old_version"],
}
causality_measures = time_varying_causality(
event_data, stats, causal_params
)
self._log_analysis_complete()
return RelativeDCSResult(
transfer_entropy=causality_measures["TE"],
dynamic_causal_strength=causality_measures["DCS"],
relative_dynamic_causal_strength=causality_measures["rDCS"],
coefficients=stats[self.config["estimation_mode"]]["At"],
)
except Exception as e:
logger.error(f"Relative DCS analysis failed: {e}")
raise ComputationError(
f"Relative DCS analysis failed: {e}",
"rdcs_computation",
event_data.shape,
)
def _validate_config(self) -> None:
"""Validate configuration parameters."""
if self.config["model_order"] <= 0:
raise ValidationError(
"model_order must be positive",
"model_order",
self.config["model_order"],
)
if self.config["reference_time"] < 0:
raise ValidationError(
"reference_time must be non-negative",
"reference_time",
self.config["reference_time"],
)
if self.config["estimation_mode"] not in ["OLS", "RLS"]:
raise ValidationError(
"estimation_mode must be 'OLS' or 'RLS'",
"estimation_mode",
self.config["estimation_mode"],
)
def _validate_input_data( # type: ignore[override]
self, event_data: np.ndarray, stats: Dict
) -> None:
"""Validate input data format and dimensions."""
super()._validate_input_data(event_data)
if event_data.ndim != 3:
raise ValidationError(
"event_data must be 3D", "event_data_ndim", event_data.ndim
)
if not isinstance(stats, dict):
raise ValidationError(
"stats must be a dictionary", "stats_type", type(stats)
)
required_keys = ["OLS", "Sigma", "mean"]
for key in required_keys:
if key not in stats:
raise ValidationError(
f"stats must contain '{key}'", "stats_missing_key", key
)
[docs]
def time_varying_causality(
event_data: np.ndarray, stats: Dict, causal_params: Dict
) -> Dict[str, np.ndarray]:
"""
Compute time-varying causality measures for bivariate signals.
Calculates Transfer Entropy (TE), Dynamic Causal Strength (DCS), and Relative
Dynamic Causal Strength (rDCS) based on a VAR model. This function maintains
exact mathematical alignment with the previous implementation while providing
enhanced error handling and validation.
Parameters
----------
event_data : np.ndarray
Event data array of shape (nvar * (model_order + 1), nobs, ntrials).
Must be bivariate (nvar = 2).
stats : Dict
Model statistics with keys:
- 'OLS' or 'RLS': Sub-dict with 'At' (coefficients) and 'Sigma_Et'
(residual covariance).
- 'Sigma': Covariance matrices of shape
(nobs, nvar * (model_order + 1), nvar * (model_order + 1)).
- 'mean': Mean values of shape (nvar * (model_order + 1), nobs).
causal_params : Dict
Parameters with keys:
- 'ref_time': Reference time index for rDCS calculation.
- 'estim_mode': 'OLS' or 'RLS' estimation mode.
- 'morder': Model order (number of lags).
- 'diag_flag': Boolean for diagonal covariance approximation.
- 'old_version': Boolean for rDCS calculation method.
Returns
-------
Dict[str, np.ndarray]
Causality measures:
- 'TE': Transfer Entropy, shape (nobs, 2) where [:, 0] is Y->X,
[:, 1] is X->Y.
- 'DCS': Dynamic Causal Strength, shape (nobs, 2) where [:, 0] is
Y->X, [:, 1] is X->Y.
- 'rDCS': Relative Dynamic Causal Strength, shape (nobs, 2)
where [:, 0] is Y->X, [:, 1] is X->Y.
Raises
------
ValidationError
If input data or parameters are invalid.
ComputationError
If computation fails due to numerical issues.
Notes
-----
- The function assumes bivariate data (nvar = 2).
- Transfer Entropy (TE) measures directed information flow.
- Dynamic Causal Strength (DCS) measures direct causal influence.
- Relative Dynamic Causal Strength (rDCS) measures causal influence
relative to a baseline.
- All measures are computed using the Structural Causal Model (SCM) framework.
Examples
--------
>>> event_data = np.random.randn(6, 50, 10) # (nvar * (morder + 1), nobs, ntrials)
>>> stats = {
... 'OLS': {
... 'At': np.random.randn(50, 2, 4),
... 'Sigma_Et': np.array([np.eye(2) for _ in range(50)])
... },
... 'Sigma': np.random.randn(50, 6, 6),
... 'mean': np.random.randn(6, 50)
... }
>>> causal_params = {
... 'ref_time': 10, 'estim_mode': 'OLS', 'morder': 2,
... 'diag_flag': False, 'old_version': False
... }
>>> result = time_varying_causality(event_data, stats, causal_params)
>>> print(f"TE shape: {result['TE'].shape}")
>>> print(f"DCS shape: {result['DCS'].shape}")
>>> print(f"rDCS shape: {result['rDCS'].shape}")
"""
_validate_time_varying_causality_inputs(event_data, stats, causal_params)
_, nobs, ntrials = event_data.shape
nvar = stats["OLS"]["At"].shape[1]
ref_time = _normalize_ref_time(causal_params["ref_time"], nobs)
estim_mode = causal_params["estim_mode"]
morder = causal_params["morder"]
diag_flag = causal_params["diag_flag"]
old_version = causal_params["old_version"]
logger.info(
f"Computing time-varying causality: mode={estim_mode}, "
f"morder={morder}, ref_time={ref_time}"
)
causality_measures = {
"TE": np.zeros((nobs, 2)),
"DCS": np.zeros((nobs, 2)),
"rDCS": np.zeros((nobs, 2)),
}
for t in range(nobs):
try:
_compute_causality_at_timepoint(
t,
event_data,
stats,
causality_measures,
nvar,
ntrials,
ref_time,
estim_mode,
morder,
diag_flag,
old_version,
)
except Exception as e:
logger.error(f"Computation failed at time point {t}: {e}")
causality_measures["TE"][t] = np.zeros(2)
causality_measures["DCS"][t] = np.zeros(2)
causality_measures["rDCS"][t] = np.zeros(2)
logger.info("Time-varying causality computation completed")
return causality_measures
def _validate_time_varying_causality_inputs(
event_data: np.ndarray, stats: Dict, causal_params: Dict
) -> None:
"""
Validate inputs for time_varying_causality function.
Parameters
----------
event_data : np.ndarray
Event data array
stats : Dict
Model statistics
causal_params : Dict
Causality parameters
Raises
------
ValidationError
If inputs are invalid
"""
if not isinstance(event_data, np.ndarray):
raise ValidationError(
"event_data must be a NumPy array", "event_data_type", type(event_data)
)
if event_data.ndim != 3:
raise ValidationError(
"event_data must be 3D", "event_data_ndim", event_data.ndim
)
if not isinstance(stats, dict):
raise ValidationError("stats must be a dictionary", "stats_type", type(stats))
required_keys = ["OLS", "Sigma", "mean"]
for key in required_keys:
if key not in stats:
raise ValidationError(
f"stats must contain '{key}'", "stats_missing_key", key
)
if not isinstance(causal_params, dict):
raise ValidationError(
"causal_params must be a dictionary",
"causal_params_type",
type(causal_params),
)
required_params = ["ref_time", "estim_mode", "morder", "diag_flag", "old_version"]
for param in required_params:
if param not in causal_params:
raise ValidationError(
f"causal_params must contain '{param}'", "causal_params_missing", param
)
if causal_params["estim_mode"] not in ["OLS", "RLS"]:
raise ValidationError(
"estim_mode must be 'OLS' or 'RLS'",
"estim_mode",
causal_params["estim_mode"],
)
_validate_ref_time(causal_params["ref_time"])
def _validate_ref_time(ref_time: object) -> None:
"""
Validate ref_time parameter.
Parameters
----------
ref_time : object
Reference time parameter to validate
Raises
------
ValidationError
If ref_time is invalid
"""
if isinstance(ref_time, (int, np.integer)):
if ref_time < 0:
raise ValidationError("ref_time must be non-negative", "ref_time", ref_time)
else:
try:
values = list(ref_time)
except Exception:
raise ValidationError(
"ref_time must be an int or an iterable of ints",
"ref_time",
type(ref_time),
)
if len(values) == 0:
raise ValidationError(
"ref_time iterable must be non-empty", "ref_time", values
)
if np.any(np.array(values) < 0):
raise ValidationError(
"ref_time iterable values must be non-negative", "ref_time", values
)
def _normalize_ref_time(ref_time: object, nobs: int) -> int:
"""
Normalize ref_time to an integer suitable for slicing [:ref_time].
- If ref_time is an int, return it clamped to [0, nobs].
- If ref_time is an iterable (e.g., range/list/ndarray), use max(ref_time),
then clamp to [0, nobs].
"""
if isinstance(ref_time, (int, np.integer)):
idx = int(ref_time)
else:
try:
values = list(ref_time)
except Exception:
raise ValidationError(
"ref_time must be an int or an iterable of ints",
"ref_time",
type(ref_time),
)
if len(values) == 0:
raise ValidationError(
"ref_time iterable must be non-empty", "ref_time", values
)
if np.any(np.array(values) < 0):
raise ValidationError(
"ref_time iterable values must be non-negative", "ref_time", values
)
idx = int(np.max(values))
if idx < 0:
raise ValidationError("ref_time must be non-negative", "ref_time", idx)
if idx > nobs:
idx = nobs
return idx
def _compute_causality_at_timepoint(
t: int,
event_data: np.ndarray,
stats: Dict,
causality_measures: Dict[str, np.ndarray],
nvar: int,
ntrials: int,
ref_time: int,
estim_mode: str,
morder: int,
diag_flag: bool,
old_version: bool,
) -> None:
"""
Compute causality measures for a specific time point.
Parameters
----------
t : int
Time point index
event_data : np.ndarray
Event data array
stats : Dict
Model statistics
causality_measures : Dict[str, np.ndarray]
Dictionary to store results
nvar : int
Number of variables
ntrials : int
Number of trials
ref_time : int
Reference time index
estim_mode : str
Estimation mode
morder : int
Model order
diag_flag : bool
Whether to use diagonal covariance
old_version : bool
Whether to use old version of rDCS
"""
lagged_vars = event_data[2:, t, :]
coeff = stats[estim_mode]["At"][t, :, :]
residual_cov = stats[estim_mode]["Sigma_Et"][t, :, :]
a_square = coeff.reshape(nvar, nvar, morder, order="F")
b = a_square[0, 1, :] # X -> Y coupling
c = a_square[1, 0, :] # Y -> X coupling
sigy = residual_cov[0, 0] or np.finfo(float).eps
sigx = residual_cov[1, 1] or np.finfo(float).eps
x_past_start = 3
y_past_start = 2
x_past_indices = slice(x_past_start, x_past_start + 2 * morder, 2)
y_past_indices = slice(y_past_start, y_past_start + 2 * morder, 2)
cov_xp = stats["Sigma"][t, x_past_indices, x_past_indices] # X past covariance
cov_yp = stats["Sigma"][t, y_past_indices, y_past_indices] # Y past covariance
c_xyp = stats["Sigma"][
t, x_past_indices, y_past_indices
] # X past - Y past cross-covariance
c_yxp = stats["Sigma"][
t, y_past_indices, x_past_indices
] # Y past - X past cross-covariance
mean_xp = stats["mean"][x_past_indices, t] # X past mean
mean_yp = stats["mean"][y_past_indices, t] # Y past mean
cov_xp_reg = regularize_if_singular(cov_xp)
cov_yp_reg = regularize_if_singular(cov_yp)
_compute_transfer_entropy(
t,
b,
c,
sigy,
sigx,
cov_xp,
cov_yp,
c_xyp,
c_yxp,
cov_xp_reg,
cov_yp_reg,
causality_measures,
)
mean_x_ref = np.mean(
stats["mean"][x_past_indices, :ref_time], axis=1
) # X past mean
mean_y_ref = np.mean(
stats["mean"][y_past_indices, :ref_time], axis=1
) # Y past mean
cov_xp_ref = (
cov_xp
+ mean_xp @ mean_xp.T
- mean_xp @ mean_x_ref.T
- mean_x_ref @ mean_xp.T
+ mean_x_ref @ mean_x_ref.T
)
cov_yp_ref = (
cov_yp
+ mean_yp @ mean_yp.T
- mean_yp @ mean_y_ref.T
- mean_y_ref @ mean_yp.T
+ mean_y_ref @ mean_y_ref.T
)
ref_cov_xp = np.mean(
stats["Sigma"][:ref_time, x_past_indices, x_past_indices], axis=0
) # X past
ref_cov_yp = np.mean(
stats["Sigma"][:ref_time, y_past_indices, y_past_indices], axis=0
) # Y past
_compute_causal_strength_measures(
t,
b,
c,
sigy,
sigx,
cov_xp,
cov_yp,
cov_xp_ref,
cov_yp_ref,
ref_cov_xp,
ref_cov_yp,
lagged_vars,
stats,
ntrials,
ref_time,
diag_flag,
old_version,
causality_measures,
)
def _compute_transfer_entropy(
t: int,
b: np.ndarray,
c: np.ndarray,
sigy: float,
sigx: float,
cov_xp: np.ndarray,
cov_yp: np.ndarray,
c_xyp: np.ndarray,
c_yxp: np.ndarray,
cov_xp_reg: np.ndarray,
cov_yp_reg: np.ndarray,
causality_measures: Dict[str, np.ndarray],
) -> None:
"""
Compute Transfer Entropy for a specific time point.
Parameters
----------
t : int
Time point index
b : np.ndarray
X -> Y coupling coefficients
c : np.ndarray
Y -> X coupling coefficients
sigy : float
Y residual variance
sigx : float
X residual variance
cov_xp : np.ndarray
X past covariance
cov_yp : np.ndarray
Y past covariance
c_xyp : np.ndarray
X past - Y past cross-covariance
c_yxp : np.ndarray
Y past - X past cross-covariance
cov_xp_reg : np.ndarray
Regularized X past covariance
cov_yp_reg : np.ndarray
Regularized Y past covariance
causality_measures : Dict[str, np.ndarray]
Dictionary to store results
"""
# TE(X -> Y)
causality_measures["TE"][t, 1] = 0.5 * np.log(
(
sigy
+ b.T @ cov_xp @ b
- b.T @ c_xyp @ np.linalg.inv(cov_yp_reg) @ c_xyp.T @ b
)
/ sigy
)
# TE(Y -> X)
causality_measures["TE"][t, 0] = 0.5 * np.log(
(
sigx
+ c.T @ cov_yp @ c
- c.T @ c_yxp @ np.linalg.inv(cov_xp_reg) @ c_yxp.T @ c
)
/ sigx
)
def _compute_causal_strength_measures(
t: int,
b: np.ndarray,
c: np.ndarray,
sigy: float,
sigx: float,
cov_xp: np.ndarray,
cov_yp: np.ndarray,
cov_xp_ref: np.ndarray,
cov_yp_ref: np.ndarray,
ref_cov_xp: np.ndarray,
ref_cov_yp: np.ndarray,
lagged_vars: np.ndarray,
stats: Dict,
ntrials: int,
ref_time: int,
diag_flag: bool,
old_version: bool,
causality_measures: Dict[str, np.ndarray],
) -> None:
"""
Compute Dynamic Causal Strength (DCS) and Relative Dynamic Causal Strength (rDCS).
Parameters
----------
t : int
Time point index
b : np.ndarray
X -> Y coupling coefficients
c : np.ndarray
Y -> X coupling coefficients
sigy : float
Y residual variance
sigx : float
X residual variance
cov_xp : np.ndarray
X past covariance
cov_yp : np.ndarray
Y past covariance
cov_xp_ref : np.ndarray
Reference-adjusted X past covariance
cov_yp_ref : np.ndarray
Reference-adjusted Y past covariance
ref_cov_xp : np.ndarray
Reference X past covariance
ref_cov_yp : np.ndarray
Reference Y past covariance
lagged_vars : np.ndarray
Lagged variables
stats : Dict
Model statistics
ntrials : int
Number of trials
ref_time : int
Reference time index
diag_flag : bool
Whether to use diagonal covariance
old_version : bool
Whether to use old version of rDCS
causality_measures : Dict[str, np.ndarray]
Dictionary to store results
"""
if not diag_flag:
causality_measures["DCS"][t, 1] = 0.5 * np.log((sigy + b.T @ cov_xp @ b) / sigy)
causality_measures["DCS"][t, 0] = 0.5 * np.log((sigx + c.T @ cov_yp @ c) / sigx)
if old_version:
_compute_old_version_rdcs(
t,
b,
c,
sigy,
sigx,
ref_cov_xp,
ref_cov_yp,
lagged_vars,
stats,
ntrials,
ref_time,
causality_measures,
)
else:
_compute_new_version_rdcs(
t,
b,
c,
sigy,
sigx,
cov_xp_ref,
cov_yp_ref,
ref_cov_xp,
ref_cov_yp,
causality_measures,
)
else:
causality_measures["DCS"][t, 1] = 0.5 * np.log(
(sigy + b.T @ np.diag(np.diag(cov_xp)) @ b) / sigy
)
causality_measures["DCS"][t, 0] = 0.5 * np.log(
(sigx + c.T @ np.diag(np.diag(cov_yp)) @ c) / sigx
)
if old_version:
_compute_old_version_rdcs_diagonal(
t,
b,
c,
sigy,
sigx,
ref_cov_xp,
ref_cov_yp,
lagged_vars,
stats,
ntrials,
ref_time,
causality_measures,
)
else:
_compute_new_version_rdcs_diagonal(
t,
b,
c,
sigy,
sigx,
cov_xp_ref,
cov_yp_ref,
ref_cov_xp,
ref_cov_yp,
causality_measures,
)
def _compute_old_version_rdcs(
t: int,
b: np.ndarray,
c: np.ndarray,
sigy: float,
sigx: float,
ref_cov_xp: np.ndarray,
ref_cov_yp: np.ndarray,
lagged_vars: np.ndarray,
stats: Dict,
ntrials: int,
ref_time: int,
causality_measures: Dict[str, np.ndarray],
) -> None:
"""Compute old version of rDCS with full covariance."""
cov_xp_lag = (
np.dot(
lagged_vars - np.mean(stats["mean"][3:, :ref_time], axis=1),
(lagged_vars - np.mean(stats["mean"][3:, :ref_time], axis=1)).T,
)
/ ntrials
)
causality_measures["rDCS"][t, 1] = (
0.5 * np.log((sigy + b.T @ ref_cov_xp @ b) / sigy)
- 0.5
+ 0.5
* (sigy + b.T @ cov_xp_lag[2::2, 2::2] @ b)
/ (sigy + b.T @ ref_cov_xp @ b)
)
causality_measures["rDCS"][t, 0] = (
(0.5 * np.log((sigx + c.T @ ref_cov_yp @ c) / sigx))
- 0.5
+ (
0.5
* (sigx + c.T @ cov_xp_lag[1::2, 1::2] @ c)
/ (sigx + c.T @ ref_cov_yp @ c)
)
)
def _compute_new_version_rdcs(
t: int,
b: np.ndarray,
c: np.ndarray,
sigy: float,
sigx: float,
cov_xp_ref: np.ndarray,
cov_yp_ref: np.ndarray,
ref_cov_xp: np.ndarray,
ref_cov_yp: np.ndarray,
causality_measures: Dict[str, np.ndarray],
) -> None:
"""Compute new version of rDCS with full covariance."""
causality_measures["rDCS"][t, 1] = (
(0.5 * np.log((sigy + b.T @ ref_cov_xp @ b) / sigy))
- 0.5
+ (0.5 * (sigy + b.T @ cov_xp_ref @ b) / (sigy + b.T @ ref_cov_xp @ b))
)
causality_measures["rDCS"][t, 0] = (
(0.5 * np.log((sigx + c.T @ ref_cov_yp @ c) / sigx))
- 0.5
+ (0.5 * (sigx + c.T @ cov_yp_ref @ c) / (sigx + c.T @ ref_cov_yp @ c))
)
def _compute_old_version_rdcs_diagonal(
t: int,
b: np.ndarray,
c: np.ndarray,
sigy: float,
sigx: float,
ref_cov_xp: np.ndarray,
ref_cov_yp: np.ndarray,
lagged_vars: np.ndarray,
stats: Dict,
ntrials: int,
ref_time: int,
causality_measures: Dict[str, np.ndarray],
) -> None:
"""Compute old version of rDCS with diagonal covariance."""
cov_xp_lag = (
np.dot(
lagged_vars - np.mean(stats["mean"][3:, :ref_time], axis=1)[:, np.newaxis],
(
lagged_vars
- np.mean(stats["mean"][3:, :ref_time], axis=1)[:, np.newaxis]
).T,
)
/ ntrials
)
causality_measures["rDCS"][t, 1] = (
(0.5 * np.log((sigy + b.T @ np.diag(np.diag(ref_cov_xp)) @ b) / sigy))
- 0.5
+ (
0.5
* (sigy + b.T @ np.diag(np.diag(cov_xp_lag[2::2, 2::2])) @ b)
/ (sigy + b.T @ np.diag(np.diag(ref_cov_xp)) @ b)
)
)
causality_measures["rDCS"][t, 0] = (
(0.5 * np.log((sigx + c.T @ np.diag(np.diag(ref_cov_yp)) @ c) / sigx))
- 0.5
+ (
0.5
* (sigx + c.T @ np.diag(np.diag(cov_xp_lag[1::2, 1::2])) @ c)
/ (sigx + c.T @ np.diag(np.diag(ref_cov_yp)) @ c)
)
)
def _compute_new_version_rdcs_diagonal(
t: int,
b: np.ndarray,
c: np.ndarray,
sigy: float,
sigx: float,
cov_xp_ref: np.ndarray,
cov_yp_ref: np.ndarray,
ref_cov_xp: np.ndarray,
ref_cov_yp: np.ndarray,
causality_measures: Dict[str, np.ndarray],
) -> None:
"""Compute new version of rDCS with diagonal covariance."""
causality_measures["rDCS"][t, 1] = (
(0.5 * np.log((sigy + b.T @ np.diag(np.diag(ref_cov_xp)) @ b) / sigy))
- 0.5
+ (
0.5
* (sigy + b.T @ np.diag(np.diag(cov_xp_ref)) @ b)
/ (sigy + b.T @ np.diag(np.diag(ref_cov_xp)) @ b)
)
)
causality_measures["rDCS"][t, 0] = (
(0.5 * np.log((sigx + c.T @ np.diag(np.diag(ref_cov_yp)) @ c) / sigx))
- 0.5
+ (
0.5
* (sigx + c.T @ np.diag(np.diag(cov_yp_ref)) @ c)
/ (sigx + c.T @ np.diag(np.diag(ref_cov_yp)) @ c)
)
)