# Copyright (c) 2025 Sigrun May,
# Ostfalia Hochschule für angewandte Wissenschaften
#
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT
"""Generation of free informative features and class separation.
This module builds numeric class labels from DatasetConfig.class_configs,
samples base values for free informative features according to per-class
distributions, and applies class-wise mean shifts controlled by
DatasetConfig.class_sep.
Correlated clusters (including anchors) are handled in `correlated.py`.
Noise features are handled in `noise.py`. The shifting logic is implemented
in `shift_classes` and can be reused by other modules (for example, for
anchor effects in correlated clusters).
"""
from __future__ import annotations
import numpy as np
from biomedical_data_generator.config import DatasetConfig
from biomedical_data_generator.utils.sampling import sample_distribution
# Names exported by a star-import of this module; only the feature generator is listed.
__all__ = [
"generate_informative_features",
]
def _class_offsets_from_sep(sep_vec: list[float]) -> np.ndarray:
    """Turn K-1 consecutive class separations into K zero-mean offsets.

    Class 0 starts at position 0 and each subsequent class is placed at the
    running total of the separations before it. The positions are then
    translated so that their mean is zero.

    Args:
        sep_vec: Separations between consecutive classes (length K-1). The
            input is converted to a float array and flattened.

    Returns:
        np.ndarray: 1-D float array of length K = len(sep_vec) + 1 holding the
        mean-centered class offsets.
    """
    gaps = np.asarray(sep_vec, dtype=float).reshape(-1)
    # First class sits at 0; every following class at the cumulative gap.
    positions = np.zeros(gaps.size + 1, dtype=float)
    positions[1:] = np.cumsum(gaps)
    return positions - positions.mean()
# ---------------------------------------------------------------------------
# Label construction
# ---------------------------------------------------------------------------
def _build_class_labels(cfg: DatasetConfig) -> np.ndarray:
    """Build numeric class labels 0..K-1 from DatasetConfig.class_configs.

    Labels are emitted in class order: the first class_configs entry yields
    ``n_samples`` zeros, the second ``n_samples`` ones, and so on.

    Args:
        cfg: DatasetConfig containing class_configs with per-class n_samples
            and the total n_samples.

    Returns:
        np.ndarray: 1-D integer array of length cfg.n_samples with labels in
        {0, ..., K-1}. An empty array when there are no classes and
        cfg.n_samples is 0.

    Raises:
        RuntimeError: If the concatenated label length does not match
            cfg.n_samples (including the case of no configured classes while
            cfg.n_samples is non-zero).
    """
    per_class = [
        np.full(cls_cfg.n_samples, idx, dtype=int)
        for idx, cls_cfg in enumerate(cfg.class_configs)
    ]
    # np.concatenate rejects an empty sequence with an opaque ValueError; treat
    # "no classes" as zero labels so the length check below reports the problem.
    y = np.concatenate(per_class, axis=0) if per_class else np.empty(0, dtype=int)
    if y.shape[0] != cfg.n_samples:
        raise RuntimeError(f"Inconsistent label construction: got {y.shape[0]} labels, expected {cfg.n_samples}.")
    return y
# ---------------------------------------------------------------------------
# Public API: generate_informative_features
# ---------------------------------------------------------------------------