"""
Preprocessing utilities for Dynamic Causal Strength (DCS).
This module provides functions for data cleaning, artifact removal, and matrix
regularization. These utilities ensure data quality and numerical stability for
the DCS pipeline.
"""
import logging
from typing import Optional, Tuple
import numpy as np
from sklearn.covariance import ledoit_wolf
from ..core.exceptions import ComputationError, ValidationError
logger = logging.getLogger(__name__)
def remove_artifact_trials(
    event_data: np.ndarray, locations: np.ndarray, lower_threshold: float
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Remove trials from event data where the signal drops below a threshold.

    A trial is treated as an artifact when *any* sample of the first two
    variables (``event_data[:2]``) is strictly below ``lower_threshold``.
    Variables beyond the first two are never inspected. The matching entries
    of ``locations`` are removed together with the trials.

    Parameters
    ----------
    event_data : np.ndarray
        3D array of shape (variables, time_points, trials) containing the
        event data. Must have at least 2 variables.
    locations : np.ndarray
        1D array of shape (trials,) containing location indices for each trial.
    lower_threshold : float
        The threshold value below which trials are considered artifacts and
        removed. Python numbers and NumPy integer/floating scalars are
        accepted.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        A tuple containing:
        - The updated event_data with artifact trials removed.
        - The updated locations array with corresponding entries removed.
        When no trial is flagged, the *input objects themselves* are returned
        (no copy is made).

    Raises
    ------
    ValidationError
        If input parameters are invalid (wrong type, wrong dimensionality,
        mismatched trial counts, or fewer than 2 variables).
    ComputationError
        If artifact removal fails. The original exception is attached as
        ``__cause__``.

    Examples
    --------
    >>> import numpy as np
    >>> event_data = np.random.randn(2, 100, 10)  # (vars, time, trials)
    >>> locations = np.array([100, 200, 300, 400, 500, 600, 700, 800, 900, 1000])
    >>> cleaned_data, cleaned_locs = remove_artifact_trials(
    ...     event_data, locations, -15000
    ... )
    >>> print(f"Original trials: {event_data.shape[2]}")
    >>> print(f"Cleaned trials: {cleaned_data.shape[2]}")
    """
    # Input validation
    if not isinstance(event_data, np.ndarray) or event_data.ndim != 3:
        raise ValidationError(
            "event_data must be a 3D numpy array",
            "event_data_ndim",
            getattr(event_data, "ndim", None),
        )
    if not isinstance(locations, np.ndarray) or locations.ndim != 1:
        raise ValidationError(
            "locations must be a 1D numpy array",
            "locations_ndim",
            getattr(locations, "ndim", None),
        )
    # NumPy scalars (e.g. np.float32 / np.int64 taken from the data itself)
    # are not subclasses of int/float, so they are listed explicitly.
    if not isinstance(lower_threshold, (int, float, np.integer, np.floating)):
        raise ValidationError(
            "lower_threshold must be a number",
            "lower_threshold_type",
            type(lower_threshold),
        )

    # Check dimensions
    n_trials = event_data.shape[2]
    if n_trials != len(locations):
        raise ValidationError(
            "event_data trials dimension must match locations length",
            "dimension_mismatch",
            (n_trials, len(locations)),
        )
    if event_data.shape[0] < 2:
        raise ValidationError(
            "event_data must have at least 2 variables", "n_vars", event_data.shape[0]
        )

    try:
        # Collapse the variable and time axes: one flag per trial.
        artifact_mask = np.any(event_data[:2] < lower_threshold, axis=(0, 1))
        n_removed = int(np.count_nonzero(artifact_mask))
        if n_removed == 0:
            logger.info("No artifact trials found")
            return event_data, locations

        keep_mask = ~artifact_mask
        updated_event_data = event_data[:, :, keep_mask]
        updated_locations = locations[keep_mask]
        logger.info(
            "Removed %d artifact trials (threshold: %s)", n_removed, lower_threshold
        )
        logger.info("Remaining trials: %d", updated_event_data.shape[2])
        return updated_event_data, updated_locations
    except Exception as e:
        logger.error("Artifact removal failed: %s", e)
        raise ComputationError(
            f"Artifact removal failed: {e}", "artifact_removal", ()
        ) from e
def regularize_if_singular(
    matrix: np.ndarray,
    samples: Optional[int] = None,
    epsilon: float = 1e-4,
    threshold: float = 1e-4,
) -> np.ndarray:
    """
    Check if a matrix is (near-)singular and regularize it if needed.

    The matrix is considered singular when the absolute value of its
    determinant is strictly below ``threshold``. In that case:

    1. If ``samples`` is given and exceeds the matrix dimension, Ledoit-Wolf
       regularization is attempted and its result returned.
    2. Otherwise, or if the Ledoit-Wolf call raises, ``epsilon`` is added to
       the diagonal (``matrix + epsilon * I``).

    A matrix whose determinant passes the test is returned unchanged (the same
    object, not a copy). The input matrix is never modified in place.

    Parameters
    ----------
    matrix : np.ndarray
        Square 2D matrix to check and potentially regularize.
    samples : Optional[int], optional
        Number of samples; only used to decide whether Ledoit-Wolf
        regularization is attempted (``samples > matrix.shape[0]``).
        Python and NumPy integers are accepted. Default is None.
    epsilon : float, optional
        Small positive value to add to the diagonal if the matrix is singular.
        Default is 1e-4.
    threshold : float, optional
        Positive determinant threshold below which the matrix is considered
        singular. Default is 1e-4.

    Returns
    -------
    np.ndarray
        The original matrix if non-singular, or the regularized matrix if
        singular.

    Raises
    ------
    ValidationError
        If input parameters are invalid.
    ComputationError
        If regularization fails. The original exception is attached as
        ``__cause__``.

    Notes
    -----
    The determinant test depends on the scale and size of the matrix, and a
    single diagonal update does not guarantee that the result passes the same
    test again (e.g. ``[[1, 0], [0, 0]]`` with ``epsilon=1e-6`` ends up with a
    determinant of about 1e-6).

    Examples
    --------
    >>> import numpy as np
    >>> matrix = np.array([[1, 0], [0, 0]])  # Singular matrix
    >>> regularized = regularize_if_singular(matrix, epsilon=1e-6)
    >>> print(f"Original determinant: {np.linalg.det(matrix)}")
    >>> print(f"Regularized determinant: {np.linalg.det(regularized)}")
    """
    # NumPy scalars are not subclasses of int/float, so list them explicitly.
    real_types = (int, float, np.integer, np.floating)
    integer_types = (int, np.integer)

    if not isinstance(matrix, np.ndarray) or matrix.ndim != 2:
        raise ValidationError(
            "matrix must be a 2D numpy array",
            "matrix_ndim",
            getattr(matrix, "ndim", None),
        )
    if matrix.shape[0] != matrix.shape[1]:
        raise ValidationError("matrix must be square", "matrix_shape", matrix.shape)
    if not isinstance(epsilon, real_types) or epsilon <= 0:
        raise ValidationError("epsilon must be a positive number", "epsilon", epsilon)
    if not isinstance(threshold, real_types) or threshold <= 0:
        raise ValidationError(
            "threshold must be a positive number", "threshold", threshold
        )
    if samples is not None and (not isinstance(samples, integer_types) or samples <= 0):
        raise ValidationError("samples must be a positive integer", "samples", samples)

    try:
        det = np.linalg.det(matrix)
        # Written as "not (< threshold)" so that a NaN determinant keeps
        # falling through to the "return unchanged" path.
        if not abs(det) < threshold:
            return matrix

        logger.warning(
            "Singular matrix detected (det=%.2e), applying regularization", det
        )
        dim = matrix.shape[0]
        if samples is not None and samples > dim:
            try:
                # NOTE(review): scikit-learn documents ledoit_wolf as taking a
                # data matrix of shape (n_samples, n_features); here it is
                # handed `matrix` itself. Confirm this is intended -- doing it
                # properly would require the raw samples, which this function
                # does not receive.
                regularized_matrix = ledoit_wolf(matrix, assume_centered=False)[0]
                logger.info("Applied Ledoit-Wolf regularization")
                return regularized_matrix
            except Exception as e:
                # Best effort: fall back to diagonal loading below.
                logger.warning(
                    "Ledoit-Wolf regularization failed: %s, using "
                    "diagonal regularization",
                    e,
                )

        regularized_matrix = matrix + epsilon * np.eye(dim)
        logger.info("Applied diagonal regularization with epsilon=%s", epsilon)
        return regularized_matrix
    except Exception as e:
        logger.error("Matrix regularization failed: %s", e)
        raise ComputationError(
            f"Matrix regularization failed: {e}", "matrix_regularization", ()
        ) from e
def validate_data_quality(
    data: np.ndarray,
    check_nan: bool = True,
    check_inf: bool = True,
    check_constant: bool = True,
) -> Tuple[bool, str]:
    """
    Validate data quality by checking for common issues.

    The following checks are performed, in this order, and every issue found
    is reported (joined with ``"; "``):

    - empty array (reported on its own; no other check is run),
    - NaN values (if ``check_nan``),
    - infinite values (if ``check_inf``),
    - constant variables (if ``check_constant`` and ``data.ndim >= 2``), where
      a "variable" is one slice ``data[i]`` along the first axis,
    - extreme values, i.e. a maximum absolute value above 1e10 (always run).

    Parameters
    ----------
    data : np.ndarray
        Input data to validate.
    check_nan : bool, optional
        Whether to check for NaN values. Default is True.
    check_inf : bool, optional
        Whether to check for infinite values. Default is True.
    check_constant : bool, optional
        Whether to check for constant variables. Default is True.

    Returns
    -------
    Tuple[bool, str]
        A tuple containing:
        - is_valid : bool
            True if data passes all quality checks.
        - message : str
            Description of any issues found.

    Raises
    ------
    ValidationError
        If any of the ``check_*`` flags is not a boolean. A non-array ``data``
        does *not* raise; it is reported through the return value instead.

    Notes
    -----
    Any exception raised while running the checks (for example a dtype that
    NumPy cannot test for NaN) is caught and reported as
    ``(False, "Data quality validation failed: ...")``.

    Examples
    --------
    >>> import numpy as np
    >>> data = np.random.randn(2, 100, 10)
    >>> is_valid, message = validate_data_quality(data)
    >>> print(f"Data valid: {is_valid}")
    >>> print(f"Message: {message}")
    """
    # Input validation
    if not isinstance(data, np.ndarray):
        return False, "Input must be a numpy array"
    if not isinstance(check_nan, bool):
        raise ValidationError(
            "check_nan must be a boolean", "check_nan_type", type(check_nan)
        )
    if not isinstance(check_inf, bool):
        raise ValidationError(
            "check_inf must be a boolean", "check_inf_type", type(check_inf)
        )
    if not isinstance(check_constant, bool):
        raise ValidationError(
            "check_constant must be a boolean",
            "check_constant_type",
            type(check_constant),
        )

    # Handle empty input before anything else: the constant-variable check
    # below indexes the first element of each slice and would otherwise fail
    # with an IndexError on shapes such as (2, 0).
    if data.size == 0:
        return False, "Data is empty"

    issues = []
    try:
        if check_nan:
            nan_count = np.sum(np.isnan(data))
            if nan_count:
                issues.append(f"Found {nan_count} NaN values")

        if check_inf:
            inf_count = np.sum(np.isinf(data))
            if inf_count:
                issues.append(f"Found {inf_count} infinite values")

        if check_constant and data.ndim >= 2:
            for i, variable in enumerate(data):
                if np.all(variable == variable.flat[0]):
                    issues.append(f"Variable {i} is constant")

        max_val = np.max(np.abs(data))
        if max_val > 1e10:
            issues.append(f"Extreme values detected (max abs: {max_val:.2e})")
    except Exception as e:
        # Deliberate best effort: report the failure instead of raising.
        return False, f"Data quality validation failed: {e}"

    if issues:
        return False, "; ".join(issues)
    return True, "Data quality checks passed"
def normalize_data(
    data: np.ndarray, method: str = "zscore", axis: Optional[int] = None
) -> np.ndarray:
    """
    Normalize data using one of several methods.

    Each method subtracts a centre and divides by a scale, both computed along
    ``axis`` (or over the whole array when ``axis`` is None):

    - ``'zscore'``: centre = mean, scale = standard deviation.
    - ``'minmax'``: centre = minimum, scale = maximum - minimum.
    - ``'robust'``: centre = median, scale = median absolute deviation
      (unscaled).
    - ``'none'``: no normalization; a copy of ``data`` is returned.

    Wherever the scale is exactly zero (e.g. a constant slice) it is replaced
    by 1, so such entries come out as 0 instead of NaN/inf.

    Parameters
    ----------
    data : np.ndarray
        Input data to normalize. It is not modified.
    method : str, optional
        Normalization method: 'zscore', 'minmax', 'robust', or 'none'. Default
        is 'zscore'.
    axis : Optional[int], optional
        Axis along which to normalize. If None, normalize over all axes.
        Python and NumPy integers are accepted. Default is None.

    Returns
    -------
    np.ndarray
        Normalized data, with the same shape as ``data``.

    Raises
    ------
    ValidationError
        If input parameters are invalid.
    ComputationError
        If normalization fails (for instance an out-of-range ``axis``). The
        original exception is attached as ``__cause__``.

    Examples
    --------
    >>> import numpy as np
    >>> data = np.random.randn(2, 100, 10)
    >>> normalized = normalize_data(data, method='zscore', axis=1)
    >>> print(f"Original mean: {np.mean(data):.3f}")
    >>> print(f"Normalized mean: {np.mean(normalized):.3f}")
    """
    # Input validation
    if not isinstance(data, np.ndarray):
        raise ValidationError("data must be a numpy array", "data_type", type(data))
    if method not in ["zscore", "minmax", "robust", "none"]:
        raise ValidationError(
            "method must be one of: 'zscore', 'minmax', 'robust', 'none'",
            "method",
            method,
        )
    # np.integer covers axes that come out of NumPy computations (np.int64 ...).
    if axis is not None and not isinstance(axis, (int, np.integer)):
        raise ValidationError(
            "axis must be an integer or None", "axis_type", type(axis)
        )

    try:
        if method == "none":
            return data.copy()

        # The reductions already handle axis=None (reduce over everything), so
        # no flattened copy of the input is needed. keepdims=True keeps the
        # statistics broadcastable against `data`.
        if method == "zscore":
            center = np.mean(data, axis=axis, keepdims=True)
            scale = np.std(data, axis=axis, keepdims=True)
        elif method == "minmax":
            center = np.min(data, axis=axis, keepdims=True)
            scale = np.max(data, axis=axis, keepdims=True) - center
        else:  # "robust" -- the only remaining validated method
            center = np.median(data, axis=axis, keepdims=True)
            scale = np.median(np.abs(data - center), axis=axis, keepdims=True)

        # Guard against division by zero for constant slices.
        scale = np.where(scale == 0, 1, scale)
        normalized = (data - center) / scale

        logger.info("Successfully normalized data using %s method", method)
        return normalized
    except Exception as e:
        logger.error("Data normalization failed: %s", e)
        raise ComputationError(
            f"Data normalization failed: {e}", "data_normalization", None
        ) from e