Source code for optimagic.optimization.history

import warnings
from dataclasses import dataclass
from typing import Any

import numpy as np

from optimagic.typing import EvalTask, PyTree


@dataclass(frozen=True)
class HistoryEntry:
    params: PyTree
    fun: float | None
    time: float
    task: EvalTask


[docs]class History: # TODO: add counters for the relevant evaluations def __init__(self) -> None: self._params: list[PyTree] = [] self._fun: list[float | None] = [] self._time: list[float] = [] self._batches: list[int] = [] self._task: list[EvalTask] = [] def add_entry(self, entry: HistoryEntry, batch_id: int | None = None) -> None: if batch_id is None: batch_id = self._get_next_batch_id() self._params.append(entry.params) self._fun.append(entry.fun) self._time.append(entry.time) self._batches.append(batch_id) self._task.append(entry.task) def add_batch( self, batch: list[HistoryEntry], batch_size: int | None = None ) -> None: # The naming is complicated here: # batch refers to the entries to be added to the history in one go # batch_size is a property of a parallelizing algorithm that influences how # the batch_ids are assigned. It is not the same as the length of the batch. if batch_size is None: batch_size = len(batch) start = self._get_next_batch_id() n_batches = int(np.ceil(len(batch) / batch_size)) ids = np.repeat(np.arange(start, start + n_batches), batch_size)[: len(batch)] for entry, id in zip(batch, ids, strict=False): self.add_entry(entry, id) @property def params(self) -> list[PyTree]: return self._params @property def fun(self) -> list[float | None]: return self._fun @property def time(self) -> list[float]: arr = np.array(self._time) return (arr - arr[0]).tolist() @property def batches(self) -> list[int]: return self._batches @property def task(self) -> list[EvalTask]: return self._task def _get_next_batch_id(self) -> int: if not self._batches: batch = 0 else: batch = self._batches[-1] + 1 return batch # ================================================================================== # Add deprecated dict access # ================================================================================== @property def criterion(self) -> list[float | None]: msg = "The attribute `criterion` of History is deprecated. Use `fun` instead." warnings.warn(msg, FutureWarning) return self.fun @property def runtime(self) -> list[float]: msg = "The attribute `runtime` of History is deprecated. Use `time` instead." warnings.warn(msg, FutureWarning) return self.time def __getitem__(self, key: str) -> Any: msg = "dict-like access to History is deprecated. Use attribute access instead." warnings.warn(msg, FutureWarning) return getattr(self, key)