Which optimizer to use
This is a very, very, very short and deliberately oversimplified guide to selecting an optimization algorithm based on a minimum of information.
To select an optimizer, you need to answer two questions:
Is your criterion function differentiable?
Do you have a nonlinear least squares structure, i.e., do you sum some kind of squared residuals at the end of your criterion function?
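The two questions can be summarized as a small decision rule that maps onto the three recommendations in this guide. `choose_algorithm` is a hypothetical helper for illustration, not part of optimagic. The guide recommends `scipy_lbfgsb` for any differentiable criterion, so the least squares question only matters when no derivative is available.

```python
def choose_algorithm(differentiable: bool, least_squares: bool) -> str:
    """Hypothetical helper: map this guide's two questions to an algorithm name."""
    if differentiable:
        # Gradient-based; provide a closed-form derivative if you can.
        return "scipy_lbfgsb"
    if least_squares:
        # Derivative-free solver that exploits the residual structure.
        return "nag_dfols"
    # Derivative-free solver for a scalar criterion.
    return "nag_pybobyqa"


print(choose_algorithm(differentiable=False, least_squares=True))
```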
Define some inputs
Again, we use versions of the sphere function to illustrate how to select an algorithm in practice.
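import numpy as np
import optimagic as om
@om.mark.least_squares
def sphere(params):
    # Return the residuals; the criterion value is their sum of squares.
    return params
def sphere_gradient(params):
    # Gradient of the scalar criterion sum(params**2).
    return params * 2
start_params = np.arange(5)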
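Because `sphere` is marked as least squares, it returns residuals, and the value being minimized is their sum of squares. That is why `sphere_gradient` returns `2 * params`. The plain NumPy sketch below makes this relationship explicit. It assumes the scalar criterion is exactly the sum of squared residuals; `sphere_residuals` and `sphere_value` are hypothetical helper names, not optimagic functions.

```python
import numpy as np


def sphere_residuals(params):
    # Hypothetical helper: same body as `sphere`, the residuals are the parameters.
    return params


def sphere_value(params):
    # Hypothetical helper: scalar criterion = sum of squared residuals.
    residuals = sphere_residuals(params)
    return float(residuals @ residuals)


def sphere_gradient(params):
    # d/dx_i sum_j x_j**2 = 2 * x_i
    return params * 2


start_params = np.arange(5)
print(sphere_value(start_params))     # 30.0
print(sphere_gradient(start_params))  # [0 2 4 6 8]
```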
Differentiable criterion function
Use scipy_lbfgsb as optimizer and provide the closed-form derivative if you can. If you do not provide a derivative, optimagic will calculate it numerically, which is less precise and slower.
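To see where the extra cost of numerical derivatives comes from, here is a sketch of a central-difference gradient that counts criterion evaluations. It needs two evaluations per parameter for every gradient, so ten for the five-dimensional sphere. This is only an illustration of the general idea; it is not necessarily the scheme or step size optimagic uses internally. For this quadratic, central differences are exact in exact arithmetic, so any remaining difference from the closed form is floating-point rounding. For general functions, truncation error adds to it.

```python
import numpy as np

n_evals = 0  # counts criterion evaluations


def sphere_value(params):
    global n_evals
    n_evals += 1
    return float(params @ params)


def central_difference_gradient(fun, x, step=1e-6):
    """Central differences: two criterion evaluations per parameter."""
    x = np.asarray(x, dtype=float)
    grad = np.empty_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = step
        grad[i] = (fun(x + e) - fun(x - e)) / (2 * step)
    return grad


start_params = np.arange(5, dtype=float)
numeric = central_difference_gradient(sphere_value, start_params)
closed_form = 2 * start_params

print(n_evals)  # 10
print(np.max(np.abs(numeric - closed_form)))  # floating-point rounding only
```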
res = om.minimize(
    fun=sphere,
    params=start_params,
    algorithm="scipy_lbfgsb",
    jac=sphere_gradient,
)
res.n_fun_evals
3
Note that this solves a 5-dimensional problem with just 3 criterion evaluations. Higher-dimensional problems need more evaluations, but the algorithm scales very well to dozens and hundreds of parameters.
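To get a feel for the scaling claim, the sketch below calls SciPy's L-BFGS-B directly on a 100-dimensional sphere with a closed-form gradient. It assumes, as the name suggests, that `scipy_lbfgsb` wraps `scipy.optimize.minimize` with `method="L-BFGS-B"`. The exact evaluation count can vary across SciPy versions, so no specific number is claimed here.

```python
import numpy as np
from scipy.optimize import minimize


def sphere_value(x):
    # Scalar sphere criterion: sum of squares.
    return float(x @ x)


def sphere_gradient(x):
    # Closed-form gradient of the sphere criterion.
    return 2 * x


# A 100-dimensional version of the same problem.
x0 = np.arange(100, dtype=float)
res = minimize(sphere_value, x0, jac=sphere_gradient, method="L-BFGS-B")

print(res.success, res.nfev)
```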
If you are worried about being stuck in a local optimum, use multistart optimization.
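The idea behind multistart optimization is to run a local optimizer from several starting points and keep the best result. The sketch below is a hand-rolled version built on `scipy.optimize.minimize`; it is not optimagic's multistart feature. The one-dimensional criterion with two local minima and the helper `multistart_minimize` are hypothetical. The starting points form a deterministic grid rather than the random or quasi-random samples a real multistart implementation would typically draw.

```python
import numpy as np
from scipy.optimize import minimize


def criterion(x):
    # Two local minima: near x = 0.96 and (global) near x = -1.04.
    return (x[0] ** 2 - 1.0) ** 2 + 0.3 * x[0]


def multistart_minimize(fun, starts):
    """Run a local optimizer from every start and keep the lowest result."""
    results = [minimize(fun, x0, method="L-BFGS-B") for x0 in starts]
    return min(results, key=lambda r: r.fun)


# Deterministic grid of starting points covering [-2, 2].
starts = np.linspace(-2.0, 2.0, 9).reshape(-1, 1)
best = multistart_minimize(criterion, starts)

print(best.x, best.fun)
```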
Not differentiable, only scalar output
Use nag_pybobyqa. For this, you need to install the Py-BOBYQA package if you do not already have it:
pip install Py-BOBYQA
Then you select the algorithm as follows:
res = om.minimize(
    fun=sphere,
    params=start_params,
    algorithm="nag_pybobyqa",
)
res.n_fun_evals
Not differentiable, least squares structure
Use nag_dfols. It requires the DFO-LS package, which you can install via:
pip install DFO-LS
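Before selecting nag_pybobyqa or nag_dfols, you can check whether the underlying packages are importable. This sketch uses only the standard library. It assumes the import names `pybobyqa` (for Py-BOBYQA) and `dfols` (for DFO-LS); `is_installed` is a hypothetical helper.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    # find_spec returns None when the top-level module cannot be found.
    return importlib.util.find_spec(module_name) is not None


# Assumed import names: Py-BOBYQA -> pybobyqa, DFO-LS -> dfols.
for module in ("pybobyqa", "dfols"):
    print(module, is_installed(module))
```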
This optimizer only works if your criterion function exposes its least squares structure. It must return the residuals of the least squares problem (a numpy array or pytree) instead of their sum of squares. In this example, sphere does this because it is decorated with @om.mark.least_squares.
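Why does the residual vector help? With one residual per entry, an optimizer can model each residual separately and take Gauss-Newton-type steps, which are more informative than steps based on the scalar sum alone. Roughly speaking, DFO-LS builds approximate models of the residuals from function evaluations because it has no derivatives. The sketch below is not DFO-LS's algorithm: it is one plain Gauss-Newton step with an exact, hypothetical Jacobian for the sphere residuals. For this problem the step reaches the minimum at once.

```python
import numpy as np


def residuals(params):
    # Residuals of the sphere problem (what `sphere` returns).
    return params


def jacobian(params):
    # Hypothetical exact Jacobian: d residual_i / d x_j is the identity here.
    return np.eye(params.size)


def gauss_newton_step(params):
    """One Gauss-Newton step: solve J @ step = -r in the least squares sense."""
    r = residuals(params)
    J = jacobian(params)
    step, *_ = np.linalg.lstsq(J, -r, rcond=None)
    return params + step


x = np.arange(5, dtype=float)
x_new = gauss_newton_step(x)

# Reaches the minimum (all zeros) in one step, up to rounding.
print(x_new, float(residuals(x_new) @ residuals(x_new)))
```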
res = om.minimize(
    fun=sphere,
    params=start_params,
    algorithm="nag_dfols",
)
res.n_fun_evals
9