# Copyright (c) 2025 Sigrun May,
# Ostfalia Hochschule für angewandte Wissenschaften
#
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT
r"""Generation of correlated feature clusters simulating pathway-like modules.
Overview
--------
This module generates *correlated Gaussian feature clusters* that can be
interpreted as simplified "pathway-like" modules (e.g., sets of co-expressed
genes or co-regulated proteins).
Each cluster is defined by:
* A correlation structure (equicorrelated or Toeplitz/AR(1)).
* A correlation strength parameter ``correlation``.
* Optionally class-specific correlation strengths to mimic activation in
specific biological conditions (e.g., tumors vs controls).
* An anchor feature with class-specific mean shifts representing diagnostic
strength (e.g., biomarker concentration changes).
The resulting clusters are concatenated horizontally.
Statistical model
-----------------
At the core, each cluster implements a multivariate Gaussian model:
* For a given cluster with n_features (p) and a correlation matrix
:math:`\Sigma`, we generate samples according to
.. math::
x \sim \mathcal{N}_p(\mu_c, \Sigma_c),
where :math:`\mu_c` and :math:`\Sigma_c` depend on class :math:`c`.
* Two correlation structures are supported:
- **Equicorrelated**:
All off-diagonal entries are equal to the correlation parameter:
.. math::
\Sigma_{ij} =
\begin{cases}
1 & i = j, \\
\rho & i \neq j.
\end{cases}
where :math:`\rho` is the correlation parameter.
- **Toeplitz / AR(1)**:
Correlation decays with distance:
.. math::
\Sigma_{ij} = \rho^{\lvert i - j \rvert}.
where :math:`\rho` is the correlation parameter.
Anchor effects (mean shifts)
-----------------------------
When ``anchor_role="informative"`` and ``anchor_effect_size`` is specified,
the anchor feature receives a class-specific mean shift:
.. math::
\mu_{anchor, c} = \text{anchor_effect_size} \cdot \mathbb{1}_{c = anchor\_class}.
Proxy features inherit this shift through correlation but with attenuated
magnitude proportional to their correlation with the anchor.
**Configuration semantics** (enforced by CorrClusterConfig validation):
- ``anchor_role="noise"`` → no mean shift (effect_size ignored if present)
- ``anchor_role="informative"`` → MUST have anchor_effect_size > 0
Limitations and biological realism
----------------------------------
The model makes several deliberate simplifications; the key points are:
1. Gaussian marginals (real data is often skewed, zero-inflated)
2. Linear dependence only (no thresholds, saturation)
3. Independent clusters (no pathway crosstalk)
4. Blockwise effects (partial activation not modeled)
5. No sample-level heterogeneity (no subtypes)
Intended use
------------
Realistic enough for teaching and benchmarking, but not a fully realistic
generative model for complex omics data.
"""
from __future__ import annotations
from typing import Any
import numpy as np
from biomedical_data_generator.config import CorrClusterConfig, DatasetConfig
# Public API of this module.
__all__ = [
    "build_correlation_matrix",
    "sample_correlated_data",
    "apply_anchor_effects",
    "sample_all_correlated_clusters",
]
# Correlation magnitudes below this threshold are treated as exactly zero:
# the sampler then draws plain i.i.d. standard normals instead of performing
# a correlated multivariate draw (see _sample_class_specific_cluster).
CORRELATION_ZERO_THRESHOLD = 1e-12
# ============================================================================
# Correlation matrix construction
# ============================================================================
def build_correlation_matrix(
    n_features: int,
    correlation: float,
    structure: str = "equicorrelated",
) -> np.ndarray:
    """Build a correlation matrix with the specified structure.

    Args:
        n_features: Number of features (matrix dimension); must be >= 2.
        correlation: Correlation parameter ``rho``. For ``'equicorrelated'``
            it must lie in ``(-1/(n_features-1), 1)`` so the matrix is
            positive definite; for ``'toeplitz'`` it must lie in ``(-1, 1)``.
        structure: Either ``'equicorrelated'`` (all off-diagonal entries equal
            ``rho``) or ``'toeplitz'`` (AR(1)-style decay ``rho**|i-j|``).

    Returns:
        Correlation matrix of shape ``(n_features, n_features)`` with dtype
        ``float64``.

    Raises:
        ValueError: If ``structure`` is unknown, ``n_features < 2``, or
            ``correlation`` is out of bounds for the chosen structure.
    """
    if structure not in {"equicorrelated", "toeplitz"}:
        raise ValueError(f"Unknown correlation structure: {structure}")
    if n_features < 2:
        raise ValueError(f"Correlation matrix requires at least two features, got {n_features}.")
    if structure == "equicorrelated":
        # Positive definiteness of the equicorrelated matrix requires
        # rho > -1/(p - 1); n_features >= 2 is guaranteed by the check above.
        lower_bound = -1.0 / (n_features - 1)
        if not (lower_bound < correlation < 1.0):
            raise ValueError(
                f"For equicorrelated with n_features={n_features}, correlation must be in "
                f"({lower_bound:.4f}, 1.0), got {correlation}."
            )
        r = np.full((n_features, n_features), correlation, dtype=np.float64)
        np.fill_diagonal(r, 1.0)
        return r
    # structure == "toeplitz" (the only remaining possibility after validation).
    if not (-1.0 < correlation < 1.0):
        raise ValueError(f"For toeplitz, correlation must be in (-1.0, 1.0), got {correlation}.")
    # |i - j| distance matrix; correlation decays geometrically with distance.
    exponents = np.abs(np.arange(n_features)[:, None] - np.arange(n_features)[None, :])
    return (correlation**exponents).astype(np.float64)
def _cholesky_with_jitter(
corr_matrix: np.ndarray,
initial_jitter: float = 1e-10,
growth: float = 10.0,
max_tries: int = 8,
) -> np.ndarray:
r"""Compute a Cholesky factor with diagonal jitter fallback.
This helper is designed to be robust for nearly singular covariance or
correlation matrices that may arise from extreme parameter choices or
numerical round-off.
The strategy is:
1. Try plain Cholesky factorization.
2. If it fails with ``LinAlgError``, successively add a small diagonal
jitter ``eps * I`` and retry, increasing ``eps`` by ``growth`` after
each failed attempt.
3. If all attempts fail, raise a ``LinAlgError``.
Args:
corr_matrix: Symmetric positive (semi-)definite matrix. It is assumed
to be theoretically positive definite for the chosen parameters.
max_tries: Maximum number of jitter attempts after the initial
Cholesky attempt.
initial_jitter: Starting jitter value ``eps``.
growth: Factor by which ``eps`` is multiplied after each failed
attempt.
Returns:
Lower-triangular Cholesky factor ``L`` such that
``L @ L.T`` approximates ``corr_matrix`` (plus small jitter).
Raises:
np.linalg.LinAlgError: If Cholesky factorization fails even after all
jitter attempts.
Notes:
The added jitter is intentionally small and increased only as much as
required to obtain a numerically stable factor. This trades negligible
perturbations of the correlation structure for robust behavior in
ill-conditioned settings, which is typically acceptable for didactic
simulations.
"""
try:
# Fast path: no jitter needed.
return np.linalg.cholesky(corr_matrix)
except np.linalg.LinAlgError:
pass
jitter = float(initial_jitter)
identity = np.eye(corr_matrix.shape[0], dtype=corr_matrix.dtype)
for _ in range(max_tries):
try:
return np.linalg.cholesky(corr_matrix + jitter * identity)
except np.linalg.LinAlgError:
jitter *= growth
# Should never reach here due to raise in loop
raise np.linalg.LinAlgError(
"Cholesky factorization failed even after adding diagonal jitter. "
"Check correlation parameters for near-singular configurations."
)
# ============================================================================
# Anchor effect application
# ============================================================================
def apply_anchor_effects(
    x: np.ndarray,
    y: np.ndarray,
    cluster_configs: list[CorrClusterConfig],
) -> np.ndarray:
    """Apply class-specific mean shifts to anchor features.

    The anchor feature (the first feature of each cluster) receives the full
    configured effect size for samples of the target class; the remaining
    ("proxy") features receive attenuated shifts proportional to their
    empirical correlation with the anchor.

    The anchor-proxy correlations are estimated from the data *before* the
    anchor shift is applied, so the class-specific mean shift itself does not
    bias the attenuation factors.

    **Effect application logic**:

    - anchor_role="noise" → no shift (``resolve_anchor_effect_size`` yields
      0.0, the cluster is skipped)
    - anchor_role="informative" + anchor_effect_size > 0 → apply shift
    - ``anchor_class is None`` → the shift applies to samples of all classes

    Args:
        x: Feature matrix of shape (n_samples, n_features). Modified in-place
            when it is already a NumPy array; otherwise ``np.asarray`` creates
            a converted copy, which is modified and returned instead.
        y: Class labels of shape (n_samples,).
        cluster_configs: List of cluster configurations; each must provide
            ``n_cluster_features``, ``anchor_class`` and
            ``resolve_anchor_effect_size()``.

    Returns:
        The modified feature matrix (the same object as ``x`` when no
        array conversion was necessary).
    """
    x = np.asarray(x)
    y = np.asarray(y)
    feature_offset = 0
    for cluster_cfg in cluster_configs:
        n_cluster_features = cluster_cfg.n_cluster_features
        cluster_slice = slice(feature_offset, feature_offset + n_cluster_features)
        # Resolve numeric effect size (returns 0.0 for noise anchors).
        effect_size = cluster_cfg.resolve_anchor_effect_size()
        target_class = cluster_cfg.anchor_class
        # Skip if no effect (noise anchor or zero effect).
        if effect_size == 0.0:
            feature_offset += n_cluster_features
            continue
        # Identify samples in target class (None → all classes).
        if target_class is None:
            target_mask = np.ones(len(y), dtype=bool)
        else:
            target_mask = y == target_class
        # Estimate anchor-proxy correlations from the UNshifted data so the
        # class-specific mean shift does not distort the attenuation factors.
        if n_cluster_features > 1:
            cluster_corr = np.corrcoef(x[:, cluster_slice], rowvar=False)
            anchor_correlations = cluster_corr[0, 1:]
        else:
            anchor_correlations = np.empty(0, dtype=np.float64)
        # Full shift for the anchor feature (first in cluster) ...
        anchor_idx = feature_offset
        x[target_mask, anchor_idx] += effect_size
        # ... attenuated shifts for the proxies: effect_size * corr_with_anchor.
        for i, corr_with_anchor in enumerate(anchor_correlations, start=1):
            x[target_mask, feature_offset + i] += effect_size * corr_with_anchor
        feature_offset += n_cluster_features
    return x
# ============================================================================
# High-level cluster generation
# ============================================================================
def _sample_class_specific_cluster(
    y: np.ndarray,
    n_features: int,
    cluster_cfg: CorrClusterConfig,
    structure: str,
    rng: np.random.Generator,
) -> np.ndarray:
    """Generate a cluster block with per-class correlation strengths.

    For each class label present in ``y``, the rows belonging to that class
    are filled with a draw whose correlation strength comes from
    ``cluster_cfg.get_correlation_for_class``. Near-zero correlations fall
    back to plain i.i.d. standard-normal sampling.

    Args:
        y: Class labels as 1D array of length n_samples.
        n_features: Number of features in this cluster.
        cluster_cfg: Cluster configuration with class-specific correlations.
        structure: Correlation structure ('equicorrelated' or 'toeplitz').
        rng: Random number generator.

    Returns:
        Feature block of shape (n_samples, n_features) with class-specific
        correlation patterns.
    """
    block = np.empty((len(y), n_features), dtype=float)
    # np.unique yields labels in sorted order, so rng consumption is
    # deterministic for a given label set.
    for label in np.unique(y):
        label_int = int(label)
        member_mask = y == label_int
        member_count = int(member_mask.sum())
        if member_count == 0:
            continue
        rho = float(cluster_cfg.get_correlation_for_class(label_int))
        if abs(rho) < CORRELATION_ZERO_THRESHOLD:
            # Effectively uncorrelated: skip the multivariate machinery.
            draws = rng.standard_normal(size=(member_count, n_features))
        else:
            draws = sample_correlated_data(
                member_count,
                n_features,
                rho,
                structure=structure,
                rng=rng,
            )
        block[member_mask, :] = draws
    return block