# Copyright (c) 2025 Sigrun May,
# Ostfalia Hochschule für angewandte Wissenschaften
#
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT
"""Configuration models for the dataset generator."""
from __future__ import annotations
import warnings
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, TypeAlias, cast
import numpy as np
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
AnchorMode: TypeAlias = Literal["equalized", "strong"]
DistributionType = Literal[
"normal",
"lognormal",
"exp_normal", # np.exp(rng.normal()) - direct control over underlying parameters for lognormal distribution
"uniform",
"exponential",
"laplace",
]
def validate_distribution_params(
    params: dict[str, float],
    distribution: str,
) -> dict[str, float]:
    """Shared validator for distribution parameters.

    Checks, in order: unknown keys, missing required keys, and then the
    numeric constraints of each parameter. Every numeric parameter must be
    convertible to ``float`` and finite (NaN and +/-inf are rejected).

    An empty/falsy ``params`` and an unknown ``distribution`` name are both
    passed through unchanged without any checks.

    Args:
        params: Parameter dict to validate.
        distribution: Distribution type (e.g., "normal", "uniform").

    Returns:
        The same ``params`` object, unmodified, if all checks pass.

    Raises:
        ValueError: If parameters are invalid for the given distribution.
    """
    if not params:
        return params

    param_schema = {
        "normal": {"required": set(), "optional": {"loc", "scale"}},
        "uniform": {"required": {"low", "high"}, "optional": set()},
        "laplace": {"required": set(), "optional": {"loc", "scale"}},
        "exponential": {"required": set(), "optional": {"scale"}},
        "lognormal": {"required": set(), "optional": {"mean", "sigma"}},
        "exp_normal": {"required": set(), "optional": {"loc", "scale"}},
    }
    schema = param_schema.get(distribution)
    if schema is None:
        # Unknown distribution name: nothing to validate against.
        return params

    allowed = schema["required"] | schema["optional"]
    provided = set(params)

    invalid = provided - allowed
    if invalid:
        raise ValueError(f"Invalid parameters {invalid} for '{distribution}'. Allowed: {allowed}")

    missing = schema["required"] - provided
    if missing:
        raise ValueError(f"Missing required parameters {missing} for '{distribution}'")

    def _finite_float(name: str) -> float:
        """Return params[name] as a float; reject non-numeric and NaN/inf values."""
        raw = params[name]
        try:
            value = float(raw)
        except (ValueError, TypeError) as e:
            raise ValueError(f"'{name}' must be numeric, got {raw}") from e
        # Comparisons with NaN are always False, so NaN would slip past the
        # "> 0" checks below unless it is rejected explicitly here.
        if not np.isfinite(value):
            raise ValueError(f"'{name}' must be finite, got {value}")
        return value

    # Distribution-specific checks
    if distribution == "uniform":
        try:
            low = float(params["low"])
            high = float(params["high"])
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"uniform parameters must be numeric, got low={params['low']}, high={params['high']}"
            ) from e
        if not (np.isfinite(low) and np.isfinite(high)):
            raise ValueError(f"uniform: 'low' and 'high' must be finite, got low={low}, high={high}")
        if not (low < high):
            raise ValueError(f"uniform: 'high' ({high}) must be > 'low' ({low})")

    # Scale parameters must be positive (normal, laplace, exponential, exp_normal)
    if "scale" in params:
        scale_val = _finite_float("scale")
        if scale_val <= 0:
            raise ValueError(f"'scale' must be > 0, got {scale_val}")

    # Sigma must be positive (lognormal)
    if "sigma" in params:
        sigma_val = _finite_float("sigma")
        if sigma_val <= 0:
            raise ValueError(f"'sigma' must be > 0, got {sigma_val}")

    # loc and mean only need to be finite numbers
    for param_name in ("loc", "mean"):
        if param_name in params:
            _finite_float(param_name)

    return params
[docs]
class ClassConfig(BaseModel):
    """Settings describing one class of the generated dataset.

    A class is characterised by how many samples it contains, which
    distribution its features are drawn from, and an optional human-readable
    label. The class index (0, 1, 2, ...) is simply the position of the entry
    in the list of class configs; a missing label is auto-generated following
    the pattern "class_{idx}".

    Args:
        n_samples: Number of samples for this class (must be >= 1).
        class_distribution: Distribution type for feature generation. Supported
            numpy random generator distributions: "normal", "lognormal",
            "uniform", "exponential", "laplace". Additionally, "exp_normal"
            for direct control over lognormal parameters.
        class_distribution_params: Parameters for the chosen distribution.
            Explicitly passed values are checked against
            ``class_distribution``.
        label: Optional descriptive name. Auto-generated as "class_0",
            "class_1", etc. if not provided.

    Examples:
        >>> # Auto-generated labels
        >>> configs = [
        ...     ClassConfig(n_samples=100),  # label -> "class_0"
        ...     ClassConfig(n_samples=50),   # label -> "class_1"
        ... ]

        >>> # Explicit semantic labels
        >>> configs = [
        ...     ClassConfig(n_samples=100, label="healthy"),
        ...     ClassConfig(n_samples=50, label="diseased"),
        ... ]

        >>> # Different distributions per class
        >>> configs = [
        ...     ClassConfig(
        ...         n_samples=50,
        ...         label="control",
        ...         class_distribution="normal",
        ...     ),
        ...     ClassConfig(
        ...         n_samples=30,
        ...         label="diseased",
        ...         class_distribution="lognormal",
        ...         class_distribution_params={"mean": 0, "sigma": 0.5},
        ...     ),
        ... ]
    """

    model_config = ConfigDict(extra="forbid")

    n_samples: int = Field(default=30, ge=1, description="Number of samples in this class.")
    class_distribution: DistributionType = Field(
        default="normal",
        description="Distribution type for base feature generation.",
    )
    class_distribution_params: dict[str, Any] = Field(
        default_factory=lambda: {"loc": 0, "scale": 1},
        description="Distribution parameters.",
    )
    label: str | None = Field(default=None, description="Label (auto-generated as 'class_{idx}' if None).")

    @field_validator("class_distribution_params")
    @classmethod
    def _validate_class_params(cls, v: dict[str, float] | None, info) -> dict[str, float] | None:
        """Check the parameter dict against the already-validated distribution name."""
        if v is not None:
            # Falls back to "normal" when class_distribution is not available
            # in the already-validated data.
            chosen = info.data.get("class_distribution", "normal")
            return validate_distribution_params(v, chosen)
        return v

    def __str__(self) -> str:
        """Short summary; label and distribution appear only when non-default."""
        candidates = (
            f"n={self.n_samples}",
            f"label='{self.label}'" if self.label else None,
            f"dist={self.class_distribution}" if self.class_distribution != "normal" else None,
        )
        shown = ", ".join(item for item in candidates if item is not None)
        return f"ClassConfig({shown})"
[docs]
class BatchEffectsConfig(BaseModel):
    """Configuration for simulating batch effects.

    Simulate batch effects by adding random intercepts or scaling factors
    to a subset of features. This can be used to mimic:

    - site-to-site differences (multi-center studies),
    - instrument calibration shifts,
    - cohort / recruitment waves (temporal batches).

    **Conceptual separation of batch effect aspects**:

    - ``confounding_with_class`` controls **sampling bias**:
      which samples (classes) are recruited into which batch.
    - ``effect_strength``, ``effect_type`` and ``effect_granularity`` control
      **technical variation**: how strongly, and how coherently across features,
      the measurements shift between batches.

    Args:
        n_batches:
            Number of batches. Value 0 effectively disables batch effects.
        effect_strength: Scale of batch effects. Must be non-negative.

            - For ``effect_type="additive"``: standard deviation of the additive
              batch effects, sampled as ``Normal(0, effect_strength)``.
            - For ``effect_type="multiplicative"``: standard deviation of the
              multiplicative deviations around 1.0, sampled as
              ``1 + Normal(0, effect_strength)``.
        effect_type: Type of batch effect.

            - ``"additive"``: Additive intercepts (shifts in feature means).
            - ``"multiplicative"``: Multiplicative scaling (changes in variance/scale).
        effect_granularity: Granularity of batch effects across features:

            - ``"per_feature"``: draw distinct effects per batch and affected
              feature (shape ``(n_batches, n_affected_features)``).
            - ``"scalar"``: draw a single effect per batch and apply it
              uniformly to all affected features (global per-batch shift/scale).
        confounding_with_class: Degree of confounding between batch and class in ``[0.0, 1.0]``.
            Controls how strongly batch assignment correlates with class labels,
            simulating **recruitment bias** in multi-center studies.
            Semantics (for two classes / two batches with equal base proportions):

            - 0.0 -> independent: each batch has ~50/50 class mix.
            - 0.5 -> moderate correlation.
            - 0.8 -> strong recruitment bias (most samples of a class go to
              one batch).
            - 1.0 -> perfect confounding: each class maps to one preferred
              batch (if ``n_batches >= n_classes``).
        affected_features: Which features should be affected:

            - ``"all"``: apply batch effects to all features.
            - list of ints: explicit 0-based column indices of affected features.
        proportions: Optional target proportions for batch sizes. Values must be
            finite and non-negative and are normalized to sum to 1. If ``None``,
            batches are (approximately) equal in size.
    """

    model_config = ConfigDict(extra="forbid")

    # 0 or 1 => effectively no batch effect
    n_batches: int = Field(default=0, ge=0)
    # std of batch effects (0.0 allowed => no effect)
    effect_strength: float = Field(default=0.5, ge=0.0)
    effect_type: Literal["additive", "multiplicative"] = "additive"
    # how structured across features: per-feature vs scalar per batch
    effect_granularity: Literal["per_feature", "scalar"] = Field(default="per_feature")
    # in [0, 1], controls recruitment bias / confounding
    confounding_with_class: float = Field(default=0.0, ge=0.0, le=1.0)
    # 0-based column indices; "all" => all features
    affected_features: list[int] | Literal["all"] = "all"
    # optional batch size proportions
    proportions: list[float] | None = None

    @field_validator("proportions")
    @classmethod
    def validate_proportions(cls, v: list[float] | None, info):
        """Validate batch proportions and normalize them to sum to 1.

        Args:
            v: Raw proportions, or ``None`` when not configured.
            info: Pydantic validation info; ``info.data`` supplies the
                already-validated ``n_batches``.

        Returns:
            ``None`` if ``v`` is ``None``; otherwise a new list ``p / sum(v)``.

        Raises:
            ValueError: If the list is empty, contains a non-finite or negative
                entry, does not match a positive ``n_batches`` in length, or
                sums to a value <= 0.
        """
        if v is None:
            return v
        if len(v) == 0:
            raise ValueError("proportions must not be empty if provided.")

        for p in v:
            # NaN compares False against everything, so it would pass both the
            # "< 0" and the "sum <= 0" checks and yield NaN proportions.
            if not np.isfinite(p):
                raise ValueError(f"proportions must be finite numbers, got {p}.")
            if p < 0:
                raise ValueError(f"proportions must be non-negative, got {p}.")

        # Length is only enforced when batches are enabled (n_batches > 0);
        # n_batches is absent from info.data if its own validation failed.
        n_batches = info.data.get("n_batches")
        if isinstance(n_batches, int) and 0 < n_batches != len(v):
            raise ValueError(f"proportions length ({len(v)}) must match n_batches ({n_batches}).")

        total = float(sum(v))
        if total <= 0:
            raise ValueError(f"Sum of proportions must be > 0, got {total}.")

        # Normalize to sum to 1.0
        return [p / total for p in v]
[docs]
class CorrClusterConfig(BaseModel):
    """Correlated feature cluster simulating coordinated biomarker patterns.

    A cluster models a group of biomarkers that move together (for example
    markers of one metabolic pathway or proteins of one signaling cascade).
    One marker is the "anchor" (driver); the remaining ones are "proxies"
    (followers).

    **Two correlation modes are supported:**

    1) Global correlation (most common): ``correlation`` is a single float and
       every sample shares the same correlation pattern, e.g.::

           correlation = 0.7
           structure = "equicorrelated"

    2) Class-specific correlation: ``correlation`` is a ``dict[int, float]``,
       e.g. a pathway that is only active in class 1::

           correlation = {0: 0.0, 1: 0.8}

       Classes missing from the dict resolve to 0.0 (independent cluster).
       The correlation *structure* is global for the cluster and applies to
       all classes.

    Args:
        n_cluster_features:
            Number of biomarkers in the cluster (including anchor). Must be >= 1.
        structure:
            Correlation structure for this cluster:

            - "equicorrelated": all pairwise correlations are equal.
            - "toeplitz": correlation decays with feature distance.
        correlation:
            Either a single global correlation strength (float) or a mapping
            {class_index -> correlation} for class-specific correlations.
            Typical magnitudes:

            - 0.0 = independent
            - 0.3 ~ weak correlation
            - 0.5 ~ moderate correlation
            - 0.8+ ~ strong correlation
        anchor_role:
            "informative" or "noise".
        anchor_effect_size:
            "small" (0.5), "medium" (1.0), "large" (1.5), custom > 0, or None.
        anchor_class:
            Class index the anchor predicts (if informative). None -> all classes.
        label:
            Descriptive name for documentation.
    """

    model_config = ConfigDict(extra="forbid")

    # Core cluster structure and correlation settings -------------------------
    # NOTE: field order matters. The correlation validator reads the already
    # validated n_cluster_features and structure from info.data.
    n_cluster_features: int = Field(
        ...,
        ge=1,
        description="Number of biomarkers in cluster (including anchor).",
    )
    structure: Literal["equicorrelated", "toeplitz"] = Field(
        default="equicorrelated",
        description="Correlation structure shared by all classes.",
    )
    # Either a global correlation or a per-class mapping
    correlation: float | dict[int, float] = Field(
        default=0.8,
        description=(
            "Global correlation strength (float) or per-class mapping "
            "{class_index: correlation}. Values must satisfy PD constraints."
        ),
    )
    anchor_role: Literal["informative", "noise"] = "informative"
    anchor_effect_size: Literal["small", "medium", "large"] | float | None = None
    anchor_class: int | None = None
    label: str | None = None

    @field_validator("n_cluster_features")
    @classmethod
    def _validate_size(cls, v: int) -> int:
        """Reject clusters with fewer than one marker."""
        if v >= 1:
            return v
        raise ValueError(f"n_cluster_features must be >= 1, got {v}.")

    @field_validator("correlation")
    @classmethod
    def _validate_correlation(cls, v, info):
        """Check every correlation value against the bounds of the structure.

        With rho the correlation between two features and p the cluster size:

        - equicorrelated: -1/(p-1) < rho < 1 (no lower bound for p == 1)
        - toeplitz: |rho| < 1

        The check is skipped when no positive ``n_cluster_features`` is
        available in the already-validated data.
        """
        size = int(info.data.get("n_cluster_features", 0))
        if size <= 0:
            return v

        equicorrelated = info.data.get("structure", "equicorrelated") == "equicorrelated"
        if not equicorrelated:
            lower_bound = -1.0
        elif size > 1:
            lower_bound = -1.0 / (size - 1)
        else:
            lower_bound = float("-inf")

        def _require_in_bounds(rho: float) -> None:
            # Strict inequalities on both sides; NaN fails the chained comparison.
            if lower_bound < rho < 1.0:
                return
            if equicorrelated:
                raise ValueError(
                    f"correlation={rho} invalid for equicorrelated with "
                    f"n_cluster_features={size}; "
                    f"require {lower_bound:.6f} < correlation < 1."
                )
            raise ValueError(f"correlation={rho} invalid for toeplitz; require |correlation| < 1.")

        if not isinstance(v, dict):
            _require_in_bounds(float(v))
            return v

        for class_index, rho in v.items():
            if class_index < 0:
                raise ValueError(f"correlation keys must be >= 0, got {class_index}.")
            _require_in_bounds(float(rho))
        return v

    @field_validator("anchor_effect_size")
    @classmethod
    def _validate_effect_size(cls, v) -> Literal["small", "medium", "large"] | float | None:
        """Accept None, one of the three presets, or a number > 0."""
        if v is None:
            return None

        if isinstance(v, str):
            if v in ("small", "medium", "large"):
                return cast(Literal["small", "medium", "large"], v)
            raise ValueError(f"anchor_effect_size must be 'small', 'medium', or 'large', got '{v}'.")

        try:
            numeric = float(v)
            if numeric <= 0:
                # Raised inside the try on purpose: it is re-raised below with
                # the general message and this error as its cause.
                raise ValueError(f"anchor_effect_size must be > 0, got {numeric}.")
        except (TypeError, ValueError) as e:
            raise ValueError(
                f"anchor_effect_size must be 'small'/'medium'/'large' or positive float, got {v}."
            ) from e
        return numeric

    @field_validator("anchor_class")
    @classmethod
    def _validate_anchor_class(cls, v: int | None) -> int | None:
        """Allow None or a non-negative class index."""
        if v is None or v >= 0:
            return v
        raise ValueError(f"anchor_class must be >= 0 or None, got {v}.")

    # Convenience methods -----------------------------------------------------
    def is_class_specific(self) -> bool:
        """Return True if ``correlation`` is a per-class mapping."""
        return isinstance(self.correlation, dict)

    def get_correlation_for_class(self, class_idx: int) -> float:
        """Resolve the correlation that applies to ``class_idx``.

        - Class-specific mode: the mapped value, or 0.0 if the class is not listed.
        - Global mode: the single correlation value.
        """
        if self.is_class_specific():
            per_class = cast(dict[int, float], self.correlation)
            return float(per_class.get(class_idx, 0.0))
        return float(self.correlation)  # type: ignore[arg-type]

    def resolve_anchor_effect_size(self) -> float:
        """Return ``anchor_effect_size`` as a number (None resolves to 1.0)."""
        size = self.anchor_effect_size
        if size is None:
            return 1.0  # default to medium
        if isinstance(size, str):
            presets = {"small": 0.5, "medium": 1.0, "large": 1.5}
            return presets[size]
        return float(size)

    def __str__(self) -> str:
        """Compact one-line summary; optional parts appear only when set."""
        shown_corr = "class-specific" if isinstance(self.correlation, dict) else self.correlation
        candidates = (
            f"n_cluster_features={self.n_cluster_features}",
            f"correlation={shown_corr}",
            self.anchor_role if self.anchor_role != "informative" else None,
            f"effect={self.anchor_effect_size}" if self.anchor_effect_size else None,
            f"'{self.label}'" if self.label else None,
        )
        joined = ", ".join(item for item in candidates if item is not None)
        return f"CorrCluster({joined})"
# ---------------------------------------------------------------------------
# Dataset configuration
# ---------------------------------------------------------------------------
[docs]
class DatasetConfig(BaseModel):
    """Configuration for synthetic dataset generation.

    This model defines the *input-level* controls for building a synthetic dataset.
    It combines:

    - Role counts: `n_informative` and `n_noise` (cluster anchors count towards these)
    - Correlated clusters: `corr_clusters` (each 1 anchor + (k-1) proxies)
    - Class definitions: `class_configs` (with per-class n_samples and labels)
    - Optional batch effects

    *Examples (counting)*:

    1) One cluster k=4 with an informative anchor, plus n_informative=3, n_noise=2::

           proxies_from_clusters = (4-1) = 3
           n_features_expected = 3 + 2 + 3 = 8
           Breakdown:
           - informative_anchors = 1 -> free_informative = 3 - 1 = 2
           - noise_anchors = 0 -> free_noise = 2 - 0 = 2

    2) Two clusters: k=5 (informative anchor), k=3 ("noise" anchor),
       n_informative=4, n_noise=3::

           proxies_from_clusters = (5-1) + (3-1) = 6
           n_features_expected = 4 + 3 + 6 = 13
           Breakdown:
           - informative_anchors = 1 -> free_informative = 4 - 1 = 3
           - noise_anchors = 1 -> free_noise = 3 - 1 = 2

    *Derived quantities*:
        These attributes are **derived** and must not be passed by the user:

        - ``n_samples`` (int): Total samples (derived from `class_configs`).
        - ``n_features`` (int): Total number of features of the complete dataset
          (derived from n_informative, n_noise, and corr_clusters).
        - ``n_classes`` (int): Number of classes (derived from `class_configs`).
        - ``n_informative_free`` (int): Informative features not used as anchors.
        - ``n_noise_free`` (int): Noise features not used as anchors.

    Args:
        n_informative (int): Number of informative features, including the
            anchors of clusters with ``anchor_role="informative"``.
        n_noise (int): Number of noise features, including the anchors of
            clusters with ``anchor_role="noise"``.
        class_configs (list[ClassConfig]): List of class definitions (at least two).
            If omitted, the field default (two classes) is used.
        class_sep (float | Sequence[float]): Class separation values (length n_classes - 1); scalar is broadcast.
        corr_clusters (list[CorrClusterConfig]): List of CorrClusterConfig defining correlated feature clusters.
        corr_between (float): Correlation between different clusters/roles (0 = independent).
        noise_distribution (str): Distribution for noise features. Can be any supported `DistributionType`.
        noise_distribution_params (dict): Parameters for noise distribution.
        prefixed_feature_naming (bool):
            If True, role-based prefixed feature names:

            * Free informative: i1, i2, ...
            * Free noise: n1, n2, ...
            * Correlated: corr{cid}_anchor, corr{cid}_2, ..., corr{cid}_k

            If False, use generic feature_{i} naming.
            Default: True.
        prefix_informative (str): Prefix for informative features (if prefixed_feature_naming=True). Default: "i".
        prefix_noise (str): Prefix for noise features (if prefixed_feature_naming=True). Default: "n".
        prefix_corr (str): Prefix for correlated cluster features (if prefixed_feature_naming=True). Default: "corr".
        batch_effects (BatchEffectsConfig): Optional BatchEffectsConfig for simulating batch effects.
        random_state (int | None): Global random seed for dataset generation.

    Methods:
        count_informative_anchors(): Return number of informative anchors across all clusters.
        count_noise_anchors(): Return number of noise anchors across all clusters.
        breakdown(): Return dict with detailed feature/class counts.

    Validation:
        Before model construction:

        - Forbid manual `n_samples`, `n_classes`, `n_features`.
        - Normalize `class_sep`: broadcast scalar to length `n_classes - 1` or validate sequence length.

        After model construction:

        - Raise `n_informative` to `#informative_anchors` (with a ``UserWarning``) if it is smaller.
        - Ensure `n_noise >= #noise_anchors`.
        - Check `corr_between` in [-1, 1].
        - Ensure `anchor_class` indices < `n_classes`.
        - Require at least one non-zero `class_sep` if `n_informative_free > 0`.
        - Auto-generate missing class labels as ``class_{idx}``.

    Raises:
        ValueError: On invalid numeric ranges or inconsistent counts.
        TypeError: For invalid types in `class_configs` or `class_sep`.

    Examples:
        >>> # Basic dataset with two classes
        >>> cfg = DatasetConfig(
        ...     n_informative=5,
        ...     n_noise=3,
        ...     class_configs=[
        ...         ClassConfig(n_samples=50, label="healthy"),
        ...         ClassConfig(n_samples=50, label="diseased"),
        ...     ],
        ...     corr_clusters=[
        ...         CorrClusterConfig(
        ...             n_cluster_features=4,
        ...             correlation=0.8,
        ...             anchor_role="informative",
        ...             anchor_effect_size="medium",
        ...             anchor_class=1,
        ...             label="Metabolic Pathway A"
        ...         ),
        ...         CorrClusterConfig(
        ...             n_cluster_features=3,
        ...             correlation=0.5,
        ...             anchor_role="noise",
        ...             label="Random Noise Cluster"
        ...         )
        ...     ],
        ...     corr_between=0.1,
        ...     noise_distribution="normal",
        ...     noise_distribution_params={"loc": 0, "scale": 1},
        ...     prefixed_feature_naming=True,
        ...     random_state=42
        ... )
    """

    model_config = ConfigDict(extra="forbid", use_enum_values=True)

    # Core dataset structure
    n_informative: int = Field(default=2, ge=0)
    n_noise: int = Field(default=0, ge=0)

    # Multi-class controls
    class_configs: list[ClassConfig] = Field(
        [ClassConfig(n_samples=30, label="healthy"), ClassConfig(n_samples=30, label="diseased")], min_length=2
    )
    class_sep: list[float] = Field(
        default_factory=lambda: [1.5],
        description="Class separation values (normalized to length n_classes - 1).",
    )

    # Noise distribution (NumPy Generator API)
    noise_distribution: DistributionType = "normal"
    noise_distribution_params: dict[str, Any] = Field(default_factory=lambda: {"loc": 0, "scale": 1})

    # Naming
    prefixed_feature_naming: bool = True
    prefix_informative: str = "i"
    prefix_noise: str = "n"
    prefix_corr: str = "corr"

    # Correlated structure
    corr_clusters: list[CorrClusterConfig] = Field(default_factory=list)
    corr_between: float = 0.0  # correlation between different clusters/roles (0 = independent)

    # Batch effects
    batch_effects: BatchEffectsConfig | None = None

    # Global seed
    random_state: int | None = None

    @staticmethod
    def _iter_cluster_dicts(
        raw_config: Mapping[str, Any],
    ) -> Iterable[Mapping[str, Any]]:
        """Return the cluster definitions of raw_config as a list of mappings.

        ``CorrClusterConfig`` entries are converted with ``model_dump()``;
        mapping entries are passed through as-is. A missing or empty
        ``corr_clusters`` yields an empty list.

        This helper is kept for potential external use (e.g., pre-inspection of
        raw YAML configs). It is not used in the main validation path.

        Raises:
            TypeError: If an entry is neither a Mapping nor a CorrClusterConfig.
        """
        clusters: Any = raw_config.get("corr_clusters")  # list[dict] / list[CorrClusterConfig] / None
        if not clusters:
            return []
        out: list[Mapping[str, Any]] = []
        for cc in clusters:
            if isinstance(cc, CorrClusterConfig):
                out.append(cc.model_dump())
            elif isinstance(cc, Mapping):
                out.append(cc)
            else:
                raise TypeError(
                    f"corr_clusters entries must be Mapping or CorrClusterConfig, got {type(cc).__name__}"
                )
        return out

    @classmethod
    def _validate_sep_value(cls, class_separation: Any) -> float:
        """Validate a single class separation value (numeric & finite).

        Raises:
            TypeError: If the value cannot be converted to float.
            ValueError: If the value is NaN or infinite.
        """
        try:
            fv = float(class_separation)
        except (TypeError, ValueError) as e:
            raise TypeError(f"class_sep entries must be numeric, got {class_separation!r}") from e
        if not np.isfinite(fv):
            raise ValueError(f"class_sep entries must be finite numbers, got {class_separation!r}")
        return fv

    # ------------------------------------------------------ before validator
    @model_validator(mode="before")
    @classmethod
    def _normalize_and_validate(cls, data: Any) -> Any:
        """Normalize incoming data BEFORE model construction.

        - Forbids manual `n_samples`, `n_classes`, `n_features`.
        - Normalizes `class_sep` to a list of length `n_classes - 1`.

        Returns:
            ``data`` itself if it is already a ``DatasetConfig``; otherwise a
            shallow dict copy with ``class_sep`` replaced by the normalized list.

        Raises:
            TypeError: If ``data`` is not a mapping, or ``class_configs`` /
                ``class_sep`` have an unsupported type.
            ValueError: If a derived attribute is set manually, fewer than two
                classes are given, or ``class_sep`` has the wrong length.
        """
        if isinstance(data, cls):
            return data
        if not isinstance(data, Mapping):
            raise TypeError(f"DatasetConfig expects a mapping-like raw_config, got {type(data).__name__}")

        d: dict[str, Any] = dict(data)

        # Forbid manual override of derived attributes
        for forbidden in ("n_samples", "n_classes", "n_features"):
            if forbidden in d:
                raise ValueError(
                    f"{forbidden} is derived from class_configs/corr_clusters and "
                    "must not be set manually on DatasetConfig."
                )

        if "class_configs" in d:
            classes = d["class_configs"]
        else:
            # Key omitted: the field default will be used by pydantic, so
            # class_sep must be sized against that default instead of failing.
            classes = cls.model_fields["class_configs"].default
        if not isinstance(classes, Sequence) or isinstance(classes, (str, bytes)):
            raise TypeError("class_configs must be a non-string sequence of class definitions.")
        n_classes = len(classes)
        if n_classes < 2:
            raise ValueError(f"At least two classes are required, got {n_classes}.")

        # Normalize class_sep:
        # - scalar -> broadcast
        # - sequence -> validate entries and length
        raw_sep = d.get("class_sep", [1.5])
        if isinstance(raw_sep, (int, float)):
            sep_list = [cls._validate_sep_value(raw_sep)] * (n_classes - 1)
        elif isinstance(raw_sep, Sequence) and not isinstance(raw_sep, (str, bytes)):
            sep_list = [cls._validate_sep_value(v) for v in raw_sep]
        else:
            raise TypeError(f"class_sep must be a number or sequence, got {type(raw_sep).__name__}")
        if len(sep_list) != n_classes - 1:
            raise ValueError(f"class_sep length must be n_classes - 1 ({n_classes - 1}), got {len(sep_list)}.")
        d["class_sep"] = sep_list
        return d

    # ------------------------------------------------------ field validation
    @field_validator("noise_distribution_params")
    @classmethod
    def _validate_noise_params(cls, v: dict[str, float] | None, info) -> dict[str, float] | None:
        """Validate distribution parameters match the chosen noise distribution."""
        if v is None:
            return v
        distribution = info.data.get("noise_distribution", "normal")
        return validate_distribution_params(v, distribution)

    # ------------------------------------------------------ after validators
    @model_validator(mode="after")
    def _enforce_minimum_informative(self):
        """Raise n_informative to the number of informative anchors, with a warning."""
        required = self.count_informative_anchors()
        if self.n_informative < required:
            old = self.n_informative
            # Written via object.__setattr__, i.e. not through the model's own
            # attribute assignment.
            object.__setattr__(self, "n_informative", required)
            warnings.warn(
                f"[DatasetConfig] n_informative was increased from {old} to {required} "
                f"because your correlated clusters define {required} informative anchors.",
                UserWarning,
            )
        return self

    @model_validator(mode="after")
    def _auto_generate_labels(self):
        """Auto-generate labels as 'class_{idx}' if not provided (None or empty)."""
        for idx, cls_cfg in enumerate(self.class_configs):
            if cls_cfg.label is None or cls_cfg.label == "":
                # ClassConfig is a BaseModel, so we need object.__setattr__
                object.__setattr__(cls_cfg, "label", f"class_{idx}")
        return self

    @model_validator(mode="after")
    def _enforce_nonzero_class_sep(self):
        """Require a non-zero class_sep entry when free informative features exist."""
        if self.n_informative_free > 0:
            if all(abs(s) < 1e-12 for s in self.class_sep):
                raise ValueError(
                    "class_sep must contain at least one non-zero value when "
                    "n_informative_free > 0; otherwise informative features "
                    "carry no class signal."
                )
        return self

    @model_validator(mode="after")
    def __post_init__(self):
        """Sanity checks tying together anchors, counts, and classes.

        Raises:
            ValueError: If counts are negative or smaller than the number of
                anchors, ``corr_between`` is outside [-1, 1], or a cluster's
                ``anchor_class`` exceeds the highest class index.
        """
        if self.n_noise < 0:
            raise ValueError("n_noise must be >= 0.")
        if self.n_informative < 0:
            raise ValueError("n_informative must be >= 0.")

        inf_anchors = self.count_informative_anchors()
        noise_anchors = self.count_noise_anchors()
        if self.n_informative < inf_anchors:
            raise ValueError(f"n_informative ({self.n_informative}) < number of informative anchors ({inf_anchors}).")
        if self.n_noise < noise_anchors:
            raise ValueError(f"n_noise ({self.n_noise}) < number of noise anchors ({noise_anchors}).")

        # Corr-between range sanity check
        if not (-1.0 <= float(self.corr_between) <= 1.0):
            raise ValueError(f"corr_between must lie in [-1, 1], got {self.corr_between}.")

        # anchor_class indices must be < n_classes
        max_idx = self.n_classes - 1
        for cluster in self.corr_clusters or []:
            if cluster.anchor_class is not None and cluster.anchor_class > max_idx:
                raise ValueError(
                    f"CorrClusterConfig.anchor_class={cluster.anchor_class} "
                    f"but only {self.n_classes} classes are defined (max index {max_idx})."
                )
        return self

    @classmethod
    def from_yaml(cls, path: str) -> DatasetConfig:
        """Load from YAML and validate via the same pipeline.

        An empty YAML document is treated as an empty mapping.
        """
        import yaml  # local import to keep core dependencies lean

        with open(path, encoding="utf-8") as f:
            raw_config: dict[str, Any] = yaml.safe_load(f) or {}
        return cls.model_validate(raw_config)

    def count_informative_anchors(self) -> int:
        """Count clusters whose anchor is 'informative'.

        Returns:
            The number of clusters with anchor_role == "informative".
        """
        return sum(1 for c in (self.corr_clusters or []) if c.anchor_role == "informative")

    def count_noise_anchors(self) -> int:
        """Count clusters whose anchor is 'noise' (non-informative anchor).

        Returns:
            The number of clusters with anchor_role == "noise".
        """
        return sum(1 for c in (self.corr_clusters or []) if c.anchor_role == "noise")

    @staticmethod
    def _proxies_from_clusters(
        clusters: Iterable[CorrClusterConfig] | None,
    ) -> int:
        """Number of additional features contributed by all clusters.

        For a cluster of size k, proxies = max(0, k - 1) regardless of anchor_role.
        """
        if not clusters:
            return 0
        return sum(max(0, int(c.n_cluster_features) - 1) for c in clusters)

    @property
    def n_informative_free(self) -> int:
        """Informative features outside clusters (excludes informative anchors)."""
        return max(self.n_informative - self.count_informative_anchors(), 0)

    @property
    def n_noise_free(self) -> int:
        """Independent noise features (excludes noise anchors)."""
        return max(self.n_noise - self.count_noise_anchors(), 0)

    # ------------------------------ derived global counts ---------------------
    @property
    def n_samples(self) -> int:
        """Total samples (derived from class_configs)."""
        return sum(c.n_samples for c in self.class_configs)

    @property
    def n_classes(self) -> int:
        """Number of classes (derived from class_configs)."""
        return len(self.class_configs)

    @property
    def n_features(self) -> int:
        """Total number of features (informative + noise + cluster proxies)."""
        proxies = self._proxies_from_clusters(self.corr_clusters)
        return int(self.n_informative + self.n_noise + proxies)

    # ------------------------------ class-level helpers ----------------------
    @property
    def class_labels(self) -> list[str]:
        """List of class labels (auto-generated or user-provided)."""
        return [
            c.label if (c.label is not None and c.label != "") else f"class_{i}"
            for i, c in enumerate(self.class_configs)
        ]

    @property
    def class_counts(self) -> dict[int, int]:
        """Class counts as dict {class_idx: n_samples}."""
        return {idx: c.n_samples for idx, c in enumerate(self.class_configs)}

    def breakdown(self) -> dict[str, int]:
        """Structured feature counts incl. cluster proxies and anchor split.

        Returns:
            A dict with keys:

            - n_informative_total
            - n_informative_anchors
            - n_informative_free
            - n_noise_total
            - n_noise_anchors
            - n_noise_free
            - proxies_from_clusters
            - n_features
        """
        proxies = self._proxies_from_clusters(self.corr_clusters)
        n_inf_anchors = self.count_informative_anchors()
        n_noise_anchors = self.count_noise_anchors()
        return {
            "n_informative_total": int(self.n_informative),
            "n_informative_anchors": int(n_inf_anchors),
            "n_informative_free": int(max(self.n_informative - n_inf_anchors, 0)),
            "n_noise_total": int(self.n_noise),
            "n_noise_anchors": int(n_noise_anchors),
            "n_noise_free": int(max(self.n_noise - n_noise_anchors, 0)),
            "proxies_from_clusters": int(proxies),
            "n_features": int(self.n_features),
        }