# Source code for optimagic.constraints

"""User-facing constraint classes and their resolved internal counterparts.

Each constraint class describes a constraint on a subset of the parameters that is
selected via a selector function. During constraints processing, the selectors are
resolved to positions in the flat parameter vector (``Constraint._resolve``), which
produces the ``Resolved*`` dataclass defined next to each constraint class. A
resolved constraint refers to parameters by their integer positions and carries
provenance information that links it back to the user provided constraints it was
derived from. The provenance is used to phrase error messages in terms of what the
user actually wrote, even after constraints have been rewritten or merged.

"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import KW_ONLY, dataclass
from typing import TYPE_CHECKING, Any, Callable, TypeAlias

import numpy as np
import pandas as pd
from numpy.typing import ArrayLike, NDArray

from optimagic.exceptions import InvalidConstraintError
from optimagic.optimization.algo_options import CONSTRAINTS_ABSOLUTE_TOLERANCE
from optimagic.typing import PyTree

if TYPE_CHECKING:
    from optimagic.parameters.constraints.resolution import ResolutionContext

FloatArray: TypeAlias = NDArray[np.float64]
IntArray: TypeAlias = NDArray[np.int64]


class Constraint(ABC):
    """Abstract parent of all user-facing constraint classes; used for subtyping.

    Concrete constraints implement the two private hooks declared below.

    """

    @abstractmethod
    def _to_dict(self) -> dict[str, Any]:
        """Return a dictionary representation of the constraint."""

    @abstractmethod
    def _resolve(self, context: ResolutionContext) -> ResolvedConstraint | None:
        """Translate the constraint's selectors into flat parameter positions.

        None is returned for an empty selection; such a constraint is dropped.

        """


@dataclass(frozen=True)
class ConstraintSource:
    """Provenance record pointing at the user constraint behind an internal one.

    Attributes:
        constraint: The constraint object the user provided. Dictionary constraints
            are converted to constraint objects before resolution, so this is
            always a Constraint instance.
        position: Index of that constraint within the user provided list of
            constraints.

    """

    constraint: Constraint
    position: int

    def describe(self) -> str:
        """Return a label made of the list position and the constraint's repr."""
        shown = repr(self.constraint)
        return f"constraint {self.position}: {shown}"


class ResolvedConstraint(ABC):  # noqa: B024
    """Common parent of all resolved constraints; it exists for subtyping."""


def _as_position_array(positions: Any) -> IntArray:
    """Return ``positions`` as a NumPy array with dtype int64."""
    target_dtype = np.int64
    return np.asarray(positions, dtype=target_dtype)


def _as_float_array(values: Any) -> FloatArray:
    """Return ``values`` as a NumPy array with dtype float64."""
    target_dtype = np.float64
    return np.asarray(values, dtype=target_dtype)


def identity_selector(x: PyTree) -> PyTree:
    """Return the parameters unchanged, i.e. select all of them."""
    selected = x
    return selected


@dataclass(frozen=True)
class FixedConstraint(Constraint):
    """Hold the selected parameters constant at their starting values.

    Attributes:
        selector: Callable that receives the parameters and returns the subset
            that is to be constrained. The default selects all parameters.

    Raises:
        InvalidConstraintError: If the selector is not callable.

    """

    selector: Callable[[PyTree], PyTree] = identity_selector

    def __post_init__(self) -> None:
        if callable(self.selector):
            return
        raise InvalidConstraintError("'selector' must be callable.")

    def _to_dict(self) -> dict[str, Any]:
        as_dict: dict[str, Any] = {"type": "fixed"}
        as_dict["selector"] = self.selector
        return as_dict

    def _resolve(self, context: ResolutionContext) -> ResolvedFixedConstraint | None:
        positions = context.select(self.selector)
        if len(positions) == 0:
            # An empty selection means there is nothing to constrain.
            return None
        provenance = (context.source,)
        return ResolvedFixedConstraint(index=positions, sources=provenance)
@dataclass(frozen=True, eq=False)
class ResolvedFixedConstraint(ResolvedConstraint):
    """Resolved form of a constraint that fixes parameters.

    Attributes:
        index: Positions of the fixed parameters in the flat parameter vector.
        sources: The user constraints this constraint was derived from.
        value: Explicit values at which the parameters are fixed. None means the
            parameters are fixed at their start values. Explicit values only exist
            for deprecated dictionary constraints and must coincide with the start
            values.

    """

    index: IntArray
    sources: tuple[ConstraintSource, ...]
    value: Any = None

    def __post_init__(self) -> None:
        # The dataclass is frozen, so the converted array has to be stored via
        # object.__setattr__.
        positions = _as_position_array(self.index)
        object.__setattr__(self, "index", positions)
@dataclass(frozen=True)
class IncreasingConstraint(Constraint):
    """Require the selected parameters to be increasing.

    Attributes:
        selector: Callable that receives the parameters and returns the subset
            that is to be constrained. The default selects all parameters.

    Raises:
        InvalidConstraintError: If the selector is not callable.

    """

    selector: Callable[[PyTree], PyTree] = identity_selector

    def __post_init__(self) -> None:
        selector_ok = callable(self.selector)
        if not selector_ok:
            raise InvalidConstraintError("'selector' must be callable.")

    def _to_dict(self) -> dict[str, Any]:
        return dict(type="increasing", selector=self.selector)

    def _resolve(
        self, context: ResolutionContext
    ) -> ResolvedIncreasingConstraint | None:
        positions = context.select(self.selector)
        # An empty selection means there is nothing to constrain.
        if len(positions) == 0:
            return None
        return ResolvedIncreasingConstraint(
            index=positions,
            sources=(context.source,),
        )
@dataclass(frozen=True, eq=False)
class ResolvedIncreasingConstraint(ResolvedConstraint):
    """Resolved form of a constraint that keeps parameters weakly increasing.

    Attributes:
        index: Positions of the parameters in the flat parameter vector, listed in
            the order in which they have to be increasing.
        sources: The user constraints this constraint was derived from.

    """

    index: IntArray
    sources: tuple[ConstraintSource, ...]

    def __post_init__(self) -> None:
        # Frozen dataclass: bypass the blocked attribute assignment.
        positions = _as_position_array(self.index)
        object.__setattr__(self, "index", positions)
@dataclass(frozen=True)
class DecreasingConstraint(Constraint):
    """Require the selected parameters to be decreasing.

    Attributes:
        selector: Callable that receives the parameters and returns the subset
            that is to be constrained. The default selects all parameters.

    Raises:
        InvalidConstraintError: If the selector is not callable.

    """

    selector: Callable[[PyTree], PyTree] = identity_selector

    def __post_init__(self) -> None:
        if callable(self.selector):
            return
        raise InvalidConstraintError("'selector' must be callable.")

    def _to_dict(self) -> dict[str, Any]:
        as_dict: dict[str, Any] = {"type": "decreasing"}
        as_dict["selector"] = self.selector
        return as_dict

    def _resolve(
        self, context: ResolutionContext
    ) -> ResolvedDecreasingConstraint | None:
        positions = context.select(self.selector)
        if len(positions) == 0:
            # An empty selection means there is nothing to constrain.
            return None
        provenance = (context.source,)
        return ResolvedDecreasingConstraint(index=positions, sources=provenance)
@dataclass(frozen=True, eq=False)
class ResolvedDecreasingConstraint(ResolvedConstraint):
    """Resolved form of a constraint that keeps parameters weakly decreasing.

    Attributes:
        index: Positions of the parameters in the flat parameter vector, listed in
            the order in which they have to be decreasing.
        sources: The user constraints this constraint was derived from.

    """

    index: IntArray
    sources: tuple[ConstraintSource, ...]

    def __post_init__(self) -> None:
        # Frozen dataclass: bypass the blocked attribute assignment.
        positions = _as_position_array(self.index)
        object.__setattr__(self, "index", positions)
@dataclass(frozen=True)
class EqualityConstraint(Constraint):
    """Require all selected parameters to be equal to each other.

    Attributes:
        selector: Callable that receives the parameters and returns the subset
            that is to be constrained. The default selects all parameters.

    Raises:
        InvalidConstraintError: If the selector is not callable.

    """

    selector: Callable[[PyTree], PyTree] = identity_selector

    def __post_init__(self) -> None:
        selector_ok = callable(self.selector)
        if not selector_ok:
            raise InvalidConstraintError("'selector' must be callable.")

    def _to_dict(self) -> dict[str, Any]:
        return dict(type="equality", selector=self.selector)

    def _resolve(self, context: ResolutionContext) -> ResolvedEqualityConstraint | None:
        positions = context.select(self.selector)
        # An empty selection means there is nothing to constrain.
        if len(positions) == 0:
            return None
        return ResolvedEqualityConstraint(
            index=positions,
            sources=(context.source,),
        )
@dataclass(frozen=True, eq=False)
class ResolvedEqualityConstraint(ResolvedConstraint):
    """Resolved form of a constraint that keeps parameters equal.

    Attributes:
        index: Positions of the equal parameters in the flat parameter vector.
        sources: The user constraints this constraint was derived from.

    """

    index: IntArray
    sources: tuple[ConstraintSource, ...]

    def __post_init__(self) -> None:
        # Frozen dataclass: bypass the blocked attribute assignment.
        positions = _as_position_array(self.index)
        object.__setattr__(self, "index", positions)
@dataclass(frozen=True)
class ProbabilityConstraint(Constraint):
    """Require the selected parameters to be probabilities.

    Each selected parameter has to be positive and the selected parameters have
    to sum to 1.

    Attributes:
        selector: Callable that receives the parameters and returns the subset
            that is to be constrained. The default selects all parameters.

    Raises:
        InvalidConstraintError: If the selector is not callable.

    """

    selector: Callable[[PyTree], PyTree] = identity_selector

    def __post_init__(self) -> None:
        if callable(self.selector):
            return
        raise InvalidConstraintError("'selector' must be callable.")

    def _to_dict(self) -> dict[str, Any]:
        as_dict: dict[str, Any] = {"type": "probability"}
        as_dict["selector"] = self.selector
        return as_dict

    def _resolve(
        self, context: ResolutionContext
    ) -> ResolvedProbabilityConstraint | None:
        positions = context.select(self.selector)
        if len(positions) == 0:
            # An empty selection means there is nothing to constrain.
            return None
        provenance = (context.source,)
        return ResolvedProbabilityConstraint(index=positions, sources=provenance)
@dataclass(frozen=True, eq=False)
class ResolvedProbabilityConstraint(ResolvedConstraint):
    """Resolved form of a constraint that makes parameters positive and sum to one.

    Attributes:
        index: Positions of the parameters in the flat parameter vector.
        sources: The user constraints this constraint was derived from.

    """

    index: IntArray
    sources: tuple[ConstraintSource, ...]

    def __post_init__(self) -> None:
        # Frozen dataclass: bypass the blocked attribute assignment.
        positions = _as_position_array(self.index)
        object.__setattr__(self, "index", positions)
@dataclass(frozen=True)
class PairwiseEqualityConstraint(Constraint):
    """Require corresponding parameters of several selections to be equal.

    Attributes:
        selectors: A list of callables. Each receives the parameters and returns
            one subset of parameters to be constrained.

    Raises:
        InvalidConstraintError: If fewer than two selectors are provided or if any
            selector is not callable. During resolution, also if the selections
            do not all have the same length.

    """

    selectors: list[Callable[[PyTree], PyTree]]

    def __post_init__(self) -> None:
        if len(self.selectors) < 2:
            raise InvalidConstraintError("At least two selectors must be provided.")
        if any(not callable(candidate) for candidate in self.selectors):
            raise InvalidConstraintError("All selectors must be callable.")

    def _to_dict(self) -> dict[str, Any]:
        return dict(type="pairwise_equality", selectors=self.selectors)

    def _resolve(
        self, context: ResolutionContext
    ) -> ResolvedPairwiseEqualityConstraint | None:
        selections = tuple(context.select(chooser) for chooser in self.selectors)
        sizes = [len(selection) for selection in selections]

        # Entries are matched up position by position, so every selection has to
        # contain the same number of parameters.
        if len(set(sizes)) != 1:
            raise InvalidConstraintError(
                "All selections of a pairwise equality constraint need to have the "
                f"same length. You have lengths {sizes} in "
                f"{context.source.describe()}."
            )

        if len(selections[0]) == 0:
            # All selections are empty, so there is nothing to constrain.
            return None

        return ResolvedPairwiseEqualityConstraint(
            indices=selections,
            sources=(context.source,),
        )
@dataclass(frozen=True, eq=False)
class ResolvedPairwiseEqualityConstraint(ResolvedConstraint):
    """Resolved form of an equality constraint between multiple selections.

    Attributes:
        indices: One position array per selection. All arrays have the same length
            and corresponding entries are constrained to be equal.
        sources: The user constraints this constraint was derived from.

    """

    indices: tuple[IntArray, ...]
    sources: tuple[ConstraintSource, ...]

    def __post_init__(self) -> None:
        # Convert every selection to an int64 array; the frozen dataclass requires
        # object.__setattr__ for storing the result.
        converted = tuple(map(_as_position_array, self.indices))
        object.__setattr__(self, "indices", converted)
@dataclass(frozen=True)
class FlatCovConstraint(Constraint):
    """Constraint that ensures the selected parameters are a valid covariance matrix.

    Attributes:
        selector: A function that takes as input the parameters and returns the
            subset of parameters to be constrained. By default, all parameters are
            constrained.
        regularization: Helps in guiding the optimization towards finding a
            positive definite covariance matrix instead of only a positive
            semi-definite matrix. Larger values correspond to a higher likelihood
            of positive definiteness. Defaults to 0.

    Raises:
        InvalidConstraintError: If the selector is not callable or regularization
            is not a non-negative float or int. NaN is rejected as well.

    """

    selector: Callable[[PyTree], PyTree] = identity_selector
    _: KW_ONLY
    regularization: float = 0.0

    def _to_dict(self) -> dict[str, Any]:
        """Return the dictionary representation of the constraint."""
        return {
            "type": "covariance",
            "selector": self.selector,
            "regularization": self.regularization,
        }

    def __post_init__(self) -> None:
        """Validate the selector and the regularization value."""
        if not callable(self.selector):
            raise InvalidConstraintError("'selector' must be callable.")
        # Test `not (x >= 0)` rather than `x < 0`: NaN compares False in every
        # ordering comparison and would otherwise pass as "non-negative".
        if not isinstance(self.regularization, float | int) or not (
            self.regularization >= 0
        ):
            raise InvalidConstraintError(
                "'regularization' must be a non-negative float or int."
            )

    def _resolve(self, context: ResolutionContext) -> ResolvedFlatCovConstraint | None:
        """Resolve the selector to flat positions; None if nothing is selected."""
        index = context.select(self.selector)
        if len(index) == 0:
            return None
        return ResolvedFlatCovConstraint(
            index=index,
            regularization=self.regularization,
            sources=(context.source,),
        )
@dataclass(frozen=True, eq=False)
class ResolvedFlatCovConstraint(ResolvedConstraint):
    """Resolved form of a constraint that makes parameters a valid covariance matrix.

    Attributes:
        index: Positions of the parameters in the flat parameter vector. The
            parameters are the lower triangle of the covariance matrix in C order.
        regularization: Lower bound on the diagonal of the Cholesky factor of the
            covariance matrix that helps to keep the matrix positive definite.
        sources: The user constraints this constraint was derived from.

    """

    index: IntArray
    regularization: float
    sources: tuple[ConstraintSource, ...]

    def __post_init__(self) -> None:
        # Frozen dataclass: bypass the blocked attribute assignment.
        positions = _as_position_array(self.index)
        object.__setattr__(self, "index", positions)
@dataclass(frozen=True)
class FlatSDCorrConstraint(Constraint):
    """Constraint for parameters that are standard deviations and correlations.

    This constraint ensures that the selected parameters are valid standard
    deviations and correlations, i.e. that they describe a valid covariance matrix.

    Attributes:
        selector: A function that takes as input the parameters and returns the
            subset of parameters to be constrained. By default, all parameters are
            constrained.
        regularization: Helps in guiding the optimization towards finding a
            positive definite covariance matrix instead of only a positive
            semi-definite matrix. Larger values correspond to a higher likelihood
            of positive definiteness. Defaults to 0.

    Raises:
        InvalidConstraintError: If the selector is not callable or regularization
            is not a non-negative float or int. NaN is rejected as well.

    """

    selector: Callable[[PyTree], PyTree] = identity_selector
    _: KW_ONLY
    regularization: float = 0.0

    def _to_dict(self) -> dict[str, Any]:
        """Return the dictionary representation of the constraint."""
        return {
            "type": "sdcorr",
            "selector": self.selector,
            "regularization": self.regularization,
        }

    def __post_init__(self) -> None:
        """Validate the selector and the regularization value."""
        if not callable(self.selector):
            raise InvalidConstraintError("'selector' must be callable.")
        # Test `not (x >= 0)` rather than `x < 0`: NaN compares False in every
        # ordering comparison and would otherwise pass as "non-negative".
        if not isinstance(self.regularization, float | int) or not (
            self.regularization >= 0
        ):
            raise InvalidConstraintError(
                "'regularization' must be a non-negative float or int."
            )

    def _resolve(
        self, context: ResolutionContext
    ) -> ResolvedFlatSDCorrConstraint | None:
        """Resolve the selector to flat positions; None if nothing is selected."""
        index = context.select(self.selector)
        if len(index) == 0:
            return None
        return ResolvedFlatSDCorrConstraint(
            index=index,
            regularization=self.regularization,
            sources=(context.source,),
        )
@dataclass(frozen=True, eq=False)
class ResolvedFlatSDCorrConstraint(ResolvedConstraint):
    """Resolved form of a standard deviation and correlation constraint.

    Attributes:
        index: Positions of the parameters in the flat parameter vector. The
            parameters are the standard deviations followed by the lower triangle
            of the correlation matrix in C order.
        regularization: Lower bound on the diagonal of the Cholesky factor of the
            implied covariance matrix that helps to keep the matrix positive
            definite.
        sources: The user constraints this constraint was derived from.

    """

    index: IntArray
    regularization: float
    sources: tuple[ConstraintSource, ...]

    def __post_init__(self) -> None:
        # Frozen dataclass: bypass the blocked attribute assignment.
        positions = _as_position_array(self.index)
        object.__setattr__(self, "index", positions)
@dataclass(frozen=True)
class LinearConstraint(Constraint):
    """Constraint that bounds a linear combination of the selected parameters.

    This constraint ensures that a linear combination of the selected parameters
    with the 'weights' is either equal to 'value', or is bounded by 'lower_bound'
    and 'upper_bound'.

    Attributes:
        selector: A function that takes as input the parameters and returns the
            subset of parameters to be constrained. By default, all parameters are
            constrained.
        weights: The weights for the linear combination. If a scalar is provided,
            it is used for all parameters. Otherwise, it must be a numpy array,
            list, tuple or pandas Series with one entry per selected parameter.
        lower_bound: The lower bound for the linear combination. Defaults to None.
        upper_bound: The upper bound for the linear combination. Defaults to None.
        value: The value to compare the linear combination to. Defaults to None.

    Raises:
        InvalidConstraintError: If the selector is not callable, or if the weights,
            lower_bound, upper_bound, or value are not valid.

    """

    selector: Callable[[PyTree], ArrayLike | "pd.Series[float]" | float | int] = (
        identity_selector
    )
    _: KW_ONLY
    weights: ArrayLike | "pd.Series[float]" | float | int | None = None
    lower_bound: float | int | None = None
    upper_bound: float | int | None = None
    value: float | int | None = None

    def _to_dict(self) -> dict[str, Any]:
        """Return the dictionary representation; bounds/value only if not None."""
        return {
            "type": "linear",
            "selector": self.selector,
            "weights": self.weights,
            **_select_non_none(
                lower_bound=self.lower_bound,
                upper_bound=self.upper_bound,
                value=self.value,
            ),
        }

    def __post_init__(self) -> None:
        """Validate the selector, the weights and the bound/value combination."""
        if not callable(self.selector):
            raise InvalidConstraintError("'selector' must be callable.")

        if _all_none(self.lower_bound, self.upper_bound, self.value):
            raise InvalidConstraintError(
                "At least one of 'lower_bound', 'upper_bound', or 'value' must be "
                "non-None."
            )

        if self.value is not None and not _all_none(self.lower_bound, self.upper_bound):
            raise InvalidConstraintError(
                "'value' cannot be used with 'lower_bound' or 'upper_bound'."
            )

        # Tuples are accepted here because _aligned_weights handles them exactly
        # like lists; rejecting them would make that branch unreachable.
        valid_weight_types = (np.ndarray, list, tuple, pd.Series, float, int)
        if not isinstance(self.weights, valid_weight_types):
            raise InvalidConstraintError(
                "'weights' must be an array-like, a pandas Series, a float, or an int."
            )

        if self.lower_bound is not None and not isinstance(
            self.lower_bound, float | int
        ):
            raise InvalidConstraintError("'lower_bound' must be a float or an int.")

        if self.upper_bound is not None and not isinstance(
            self.upper_bound, float | int
        ):
            raise InvalidConstraintError("'upper_bound' must be a float or an int.")

        if self.value is not None and not isinstance(self.value, float | int):
            raise InvalidConstraintError("'value' must be a float or an int.")

    def _resolve(self, context: ResolutionContext) -> ResolvedLinearConstraint | None:
        """Resolve the selector to flat positions and align the weights with them.

        Missing bounds are replaced by -inf / inf and a missing value by nan.
        Returns None if the selection is empty.

        """
        index = context.select(self.selector)
        if len(index) == 0:
            return None
        return ResolvedLinearConstraint(
            index=index,
            weights=self._aligned_weights(index, context.source),
            lower_bound=-np.inf if self.lower_bound is None else self.lower_bound,
            upper_bound=np.inf if self.upper_bound is None else self.upper_bound,
            value=np.nan if self.value is None else self.value,
            sources=(context.source,),
        )

    def _aligned_weights(self, index: IntArray, source: ConstraintSource) -> FloatArray:
        """Broadcast and length-check the weights against the selected positions.

        Args:
            index: Positions of the selected parameters.
            source: Provenance of the constraint; only used in error messages.

        Returns:
            A float64 array of weights with one entry per selected parameter.

        Raises:
            InvalidConstraintError: If sequence weights have a different length
                than the selection or if the weights have an unsupported type.

        """
        if isinstance(self.weights, (np.ndarray, list, tuple, pd.Series)):
            if len(self.weights) != len(index):
                msg = (
                    f"weights of length {len(self.weights)} could not be aligned "
                    f"with the {len(index)} selected parameters in "
                    f"{source.describe()}."
                )
                raise InvalidConstraintError(msg)
            out = np.asarray(self.weights, dtype=np.float64)
        elif isinstance(self.weights, (float, int)):
            # A scalar weight applies to every selected parameter.
            out = np.full(len(index), float(self.weights))
        else:
            msg = (
                f"Invalid type for linear weights: {type(self.weights)}. The "
                f"problematic constraint is {source.describe()}."
            )
            raise InvalidConstraintError(msg)
        return out
@dataclass(frozen=True, eq=False)
class ResolvedLinearConstraint(ResolvedConstraint):
    """Resolved form of a restriction on a weighted sum of parameters.

    Attributes:
        index: Positions of the parameters in the flat parameter vector.
        weights: Weights of the parameters in the weighted sum, aligned with index.
        lower_bound: Lower bound on the weighted sum; -inf if there is none.
        upper_bound: Upper bound on the weighted sum; inf if there is none.
        value: Value at which the weighted sum is fixed; nan if it is not fixed.
        sources: The user constraints this constraint was derived from.

    """

    index: IntArray
    weights: FloatArray
    sources: tuple[ConstraintSource, ...]
    lower_bound: float = -np.inf
    upper_bound: float = np.inf
    value: float = np.nan

    def __post_init__(self) -> None:
        # Frozen dataclass: the converted arrays are stored via
        # object.__setattr__.
        positions = _as_position_array(self.index)
        object.__setattr__(self, "index", positions)
        coefficients = _as_float_array(self.weights)
        object.__setattr__(self, "weights", coefficients)
@dataclass(frozen=True)
class NonlinearConstraint(Constraint):
    """Bound a nonlinear function of the selected parameters.

    The function of the selected parameters is either required to be equal to
    'value', or to lie between 'lower_bound' and 'upper_bound'.

    Attributes:
        selector: Callable that receives the parameters and returns the subset
            that is to be constrained. The default selects all parameters.
        func: The constraint function which is applied to the selected parameters.
        derivative: The derivative of the constraint function with respect to the
            selected parameters. Defaults to None.
        lower_bound: The lower bound for the nonlinear function. Can be a scalar or
            of the same structure as output of the constraint function. Defaults to
            None.
        upper_bound: The upper bound for the nonlinear function. Can be a scalar or
            of the same structure as output of the constraint function. Defaults to
            None.
        value: The value to compare the nonlinear function to. Can be a scalar or
            of the same structure as output of the constraint function. Defaults to
            None.
        tol: The tolerance for the constraint function. Defaults to
            `optimagic.optimization.algo_options.CONSTRAINTS_ABSOLUTE_TOLERANCE`.

    Raises:
        InvalidConstraintError: If the selector is not callable, or if the func,
            derivative, lower_bound, upper_bound, or value are not valid.

    """

    selector: Callable[[PyTree], PyTree] = identity_selector
    _: KW_ONLY
    func: Callable[[PyTree], ArrayLike | "pd.Series[float]" | float] | None = None
    derivative: Callable[[PyTree], PyTree] | None = None
    lower_bound: ArrayLike | "pd.Series[float]" | float | None = None
    upper_bound: ArrayLike | "pd.Series[float]" | float | None = None
    value: ArrayLike | "pd.Series[float]" | float | None = None
    tol: float = CONSTRAINTS_ABSOLUTE_TOLERANCE

    def __post_init__(self) -> None:
        if not callable(self.selector):
            raise InvalidConstraintError("'selector' must be callable.")

        no_bounds = _all_none(self.lower_bound, self.upper_bound)

        if no_bounds and self.value is None:
            raise InvalidConstraintError(
                "At least one of 'lower_bound', 'upper_bound', or 'value' must be "
                "non-None."
            )

        if self.value is not None and not no_bounds:
            raise InvalidConstraintError(
                "'value' cannot be used with 'lower_bound' or 'upper_bound'."
            )

        if self.tol is not None:
            tol_is_number = isinstance(self.tol, float | int)
            if not tol_is_number or self.tol < 0:
                raise InvalidConstraintError("'tol' must be non-negative.")

        # None is not callable, so this also covers a missing func.
        if not callable(self.func):
            raise InvalidConstraintError("'func' must be callable.")

        if not (self.derivative is None or callable(self.derivative)):
            raise InvalidConstraintError("'derivative' must be callable.")

    def _to_dict(self) -> dict[str, Any]:
        optional_entries = _select_non_none(
            func=self.func,
            derivative=self.derivative,
            # In the dict representation, we write _bounds instead of _bound.
            lower_bounds=self.lower_bound,
            upper_bounds=self.upper_bound,
            value=self.value,
            tol=self.tol,
        )
        return {"type": "nonlinear", "selector": self.selector, **optional_entries}

    def _resolve(self, context: ResolutionContext) -> ResolvedConstraint | None:
        message = (
            "Nonlinear constraints are directly passed to optimizers that support "
            "them and must not be resolved."
        )
        raise NotImplementedError(message)
def _all_none(*args: Any) -> bool:
    """Return True if every positional argument is None (also without arguments)."""
    return not any(candidate is not None for candidate in args)


def _select_non_none(**kwargs: Any) -> dict[str, Any]:
    """Return the keyword arguments whose value is not None, in call order."""
    present = ((name, entry) for name, entry in kwargs.items() if entry is not None)
    return dict(present)