Source code for optimagic.visualization.history_plots

import inspect
import itertools
from pathlib import Path
from typing import Any

import numpy as np
import plotly.graph_objects as go
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten

from optimagic.config import PLOTLY_PALETTE, PLOTLY_TEMPLATE
from optimagic.logging.logger import LogReader, SQLiteLogOptions
from optimagic.optimization.algorithm import Algorithm
from optimagic.optimization.history import History
from optimagic.optimization.optimize_result import OptimizeResult
from optimagic.parameters.tree_registry import get_registry
from optimagic.typing import Direction


[docs]def criterion_plot( results, names=None, max_evaluations=None, template=PLOTLY_TEMPLATE, palette=PLOTLY_PALETTE, stack_multistart=False, monotone=False, show_exploration=False, ): """Plot the criterion history of an optimization. Args: results (Union[List, Dict][Union[OptimizeResult, pathlib.Path, str]): A (list or dict of) optimization results with collected history. If dict, then the key is used as the name in a legend. names (Union[List[str], str]): Names corresponding to res or entries in res. max_evaluations (int): Clip the criterion history after that many entries. template (str): The template for the figure. Default is "plotly_white". palette (Union[List[str], str]): The coloring palette for traces. Default is "qualitative.Plotly". stack_multistart (bool): Whether to combine multistart histories into a single history. Default is False. monotone (bool): If True, the criterion plot becomes monotone in the sense that only that at each iteration the current best criterion value is displayed. Default is False. show_exploration (bool): If True, exploration samples of a multistart optimization are visualized. Default is False. Returns: plotly.graph_objs._figure.Figure: The figure. """ # ================================================================================== # Process inputs # ================================================================================== results = _harmonize_inputs_to_dict(results, names) if not isinstance(palette, list): palette = [palette] palette = itertools.cycle(palette) fun_or_monotone_fun = "monotone_fun" if monotone else "fun" # ================================================================================== # Extract plotting data from results objects / data base # ================================================================================== data = [] for name, res in results.items(): if isinstance(res, OptimizeResult): _data = _extract_plotting_data_from_results_object( res, stack_multistart, show_exploration, plot_name="criterion_plot" ) elif isinstance(res, (str, Path)): _data = _extract_plotting_data_from_database( res, stack_multistart, show_exploration ) else: msg = "results must be (or contain) an OptimizeResult or a path to a log" f"file, but is type {type(res)}." raise TypeError(msg) _data["name"] = name data.append(_data) # ================================================================================== # Create figure # ================================================================================== fig = go.Figure() plot_multistart = ( len(data) == 1 and data[0]["is_multistart"] and not stack_multistart ) # ================================================================================== # Plot multistart paths if plot_multistart: scatter_kws = { "connectgaps": True, "showlegend": False, } for i, local_history in enumerate(data[0]["local_histories"]): history = getattr(local_history, fun_or_monotone_fun) if max_evaluations is not None and len(history) > max_evaluations: history = history[:max_evaluations] trace = go.Scatter( x=np.arange(len(history)), y=history, mode="lines", name=str(i), line_color="#bab0ac", **scatter_kws, ) fig.add_trace(trace) # ================================================================================== # Plot main optimization objects for _data in data: if stack_multistart and _data["stacked_local_histories"] is not None: _history = _data["stacked_local_histories"] else: _history = _data["history"] history = getattr(_history, fun_or_monotone_fun) if max_evaluations is not None and len(history) > max_evaluations: history = history[:max_evaluations] scatter_kws = { "connectgaps": True, "showlegend": not plot_multistart, } _color = next(palette) if not isinstance(_color, str): msg = "highlight_palette needs to be a string or list of strings, but its " f"entry is of type {type(_color)}." raise TypeError(msg) line_kws = { "color": _color, } trace = go.Scatter( x=np.arange(len(history)), y=history, mode="lines", name="best result" if plot_multistart else _data["name"], line=line_kws, **scatter_kws, ) fig.add_trace(trace) fig.update_layout( template=template, xaxis_title_text="No. of criterion evaluations", yaxis_title_text="Criterion value", legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95}, ) return fig
def _harmonize_inputs_to_dict(results, names): """Convert all valid inputs for results and names to dict[str, OptimizeResult].""" # convert scalar case to list case if not isinstance(names, list) and names is not None: names = [names] if isinstance(results, (OptimizeResult, str, Path)): results = [results] if names is not None and len(names) != len(results): raise ValueError("len(results) needs to be equal to len(names).") # handle dict case if isinstance(results, dict): if names is not None: results_dict = dict(zip(names, list(results.values()), strict=False)) else: results_dict = results # unlabeled iterable of results else: names = range(len(results)) if names is None else names results_dict = dict(zip(names, results, strict=False)) # convert keys to strings results_dict = {_convert_key_to_str(k): v for k, v in results_dict.items()} return results_dict def _convert_key_to_str(key: Any) -> str: if inspect.isclass(key) and issubclass(key, Algorithm): out = str(key.name) elif isinstance(key, Algorithm): out = str(key.name) else: out = str(key) return out
[docs]def params_plot( result, selector=None, max_evaluations=None, template=PLOTLY_TEMPLATE, show_exploration=False, ): """Plot the params history of an optimization. Args: result (Union[OptimizeResult, pathlib.Path, str]): An optimization results with collected history. If dict, then the key is used as the name in a legend. selector (callable): A callable that takes params and returns a subset of params. If provided, only the selected subset of params is plotted. max_evaluations (int): Clip the criterion history after that many entries. template (str): The template for the figure. Default is "plotly_white". show_exploration (bool): If True, exploration samples of a multistart optimization are visualized. Default is False. Returns: plotly.graph_objs._figure.Figure: The figure. """ # ================================================================================== # Process inputs # ================================================================================== if isinstance(result, OptimizeResult): data = _extract_plotting_data_from_results_object( result, stack_multistart=True, show_exploration=show_exploration, plot_name="params_plot", ) start_params = result.start_params elif isinstance(result, (str, Path)): data = _extract_plotting_data_from_database( result, stack_multistart=True, show_exploration=show_exploration, ) start_params = data["start_params"] else: raise TypeError("result must be an OptimizeResult or a path to a log file.") if data["stacked_local_histories"] is not None: history = data["stacked_local_histories"].params else: history = data["history"].params # ================================================================================== # Create figure # ================================================================================== fig = go.Figure() registry = get_registry(extended=True) hist_arr = np.array([tree_just_flatten(p, registry=registry) for p in history]).T names = leaf_names(start_params, registry=registry) if selector is not None: flat, treedef = tree_flatten(start_params, registry=registry) helper = tree_unflatten(treedef, list(range(len(flat))), registry=registry) selected = np.array(tree_just_flatten(selector(helper), registry=registry)) names = [names[i] for i in selected] hist_arr = hist_arr[selected] for name, data in zip(names, hist_arr, strict=False): if max_evaluations is not None and len(data) > max_evaluations: plot_data = data[:max_evaluations] else: plot_data = data trace = go.Scatter( x=np.arange(len(plot_data)), y=plot_data, mode="lines", name=name, ) fig.add_trace(trace) fig.update_layout( template=template, xaxis_title_text="No. of criterion evaluations", yaxis_title_text="Parameter value", legend={"yanchor": "top", "xanchor": "right", "y": 0.95, "x": 0.95}, ) return fig
def _extract_plotting_data_from_results_object( res, stack_multistart, show_exploration, plot_name ): """Extract data for plotting from results object. Args: res (OptmizeResult): An optimization results object. stack_multistart (bool): Whether to combine multistart histories into a single history. Default is False. show_exploration (bool): If True, exploration samples of a multistart optimization are visualized. Default is False. plot_name (str): Name of the plotting function that calls this function. Used for rasing errors. Returns: dict: - "history": The results history - "direction": maximize or minimize - "is_multistart": Whether the optimization used multistart - "local_histories": All other multistart histories except for 'history'. If not available is None. If show_exploration is True, the exploration phase is added as the first entry. - "stacked_local_histories": If stack_multistart is True the local histories are stacked into a single one. """ if res.history is None: msg = f"{plot_name} requires an optimize result with history. Enable history " "collection by setting collect_history=True when calling maximize or minimize." raise ValueError(msg) is_multistart = res.multistart_info is not None if is_multistart: local_histories = [opt.history for opt in res.multistart_info.local_optima] else: local_histories = None if stack_multistart and local_histories is not None: stacked = _get_stacked_local_histories(local_histories, res.direction) if show_exploration: fun = res.multistart_info.exploration_results.tolist()[::-1] + stacked.fun params = res.multistart_info.exploration_sample[::-1] + stacked.params stacked = History( direction=stacked.direction, fun=fun, params=params, # TODO: This needs to be fixed start_time=len(fun) * [None], stop_time=len(fun) * [None], batches=len(fun) * [None], task=len(fun) * [None], ) else: stacked = None data = { "history": res.history, "direction": Direction(res.direction), "is_multistart": is_multistart, "local_histories": local_histories, "stacked_local_histories": stacked, } return data def _extract_plotting_data_from_database(res, stack_multistart, show_exploration): """Extract data for plotting from database. Args: res (str or pathlib.Path): A path to an optimization database. stack_multistart (bool): Whether to combine multistart histories into a single history. Default is False. show_exploration (bool): If True, exploration samples of a multistart optimization are visualized. Default is False. Returns: dict: - "history": The results history - "direction": maximize or minimize - "is_multistart": Whether the optimization used multistart - "local_histories": All other multistart histories except for 'history'. If not available is None. If show_exploration is True, the exploration phase is added as the first entry. - "stacked_local_histories": If stack_multistart is True the local histories are stacked into a single one. """ reader = LogReader.from_options(SQLiteLogOptions(res)) _problem_table = reader.problem_df direction = _problem_table["direction"].tolist()[-1] _history, local_histories, exploration = reader.read_multistart_history(direction) if stack_multistart and local_histories is not None: stacked = _get_stacked_local_histories(local_histories, direction, _history) if show_exploration: stacked["params"] = exploration["params"][::-1] + stacked["params"] stacked["criterion"] = exploration["criterion"][::-1] + stacked["criterion"] else: stacked = None history = History( direction=direction, fun=_history["fun"], params=_history["params"], start_time=_history["time"], # TODO (@janosg): Retrieve `stop_time` from `hist` once it is available. # https://github.com/optimagic-dev/optimagic/pull/553 stop_time=len(_history["fun"]) * [None], task=len(_history["fun"]) * [None], batches=list(range(len(_history["fun"]))), ) data = { "history": history, "direction": direction, "is_multistart": local_histories is not None, "local_histories": local_histories, "stacked_local_histories": stacked, "start_params": reader.read_start_params(), } return data def _get_stacked_local_histories(local_histories, direction, history=None): """Stack local histories. Local histories is a list of dictionaries, each of the same structure. We transform this to a dictionary of lists. Finally, when the data is read from the database we append the best history at the end. """ stacked = {"criterion": [], "params": [], "runtime": []} for hist in local_histories: stacked["criterion"].extend(hist.fun) stacked["params"].extend(hist.params) stacked["runtime"].extend(hist.time) # append additional history is necessary if history is not None: stacked["criterion"].extend(history.fun) stacked["params"].extend(history.params) stacked["runtime"].extend(history.time) return History( direction=direction, fun=stacked["criterion"], params=stacked["params"], start_time=stacked["runtime"], # TODO (@janosg): Retrieve `stop_time` from `hist` once it is available for the # IterationHistory. # https://github.com/optimagic-dev/optimagic/pull/553 stop_time=len(stacked["criterion"]) * [None], task=len(stacked["criterion"]) * [None], batches=list(range(len(stacked["criterion"]))), )