
Faster curve analysis #1192

Closed
nkanazawa1989 opened this issue May 31, 2023 · 4 comments
nkanazawa1989 commented May 31, 2023

Suggested feature

Background

CurveAnalysis is one of the core fitter base classes in Qiskit Experiments, and it is currently used for every calibration experiment. We also support the ParallelExperiment wrapper, which allows us to combine multiple experiment instances (with non-overlapping qubits) in the same run, so calibration experiments can be seamlessly scaled to the device level on top of this framework. In principle, we can calibrate production QC systems with 1000+ qubits with the current framework; however, the performance bottlenecks must be carefully identified and resolved at this scale.

Current status

Let's go deep inside the analysis with the simple example of a T1 experiment. Note that the most time-consuming operation in the curve analysis is figure generation. We don't plan to build our own (possibly faster) plotter; indeed, in the era of 1000+ qubit devices, no hardware engineer will visually inspect plots from every experiment instance. So let's disable the plotter for now.

The following (very naive) benchmark code is written under this assumption.

%load_ext snakeviz

import numpy as np
from qiskit_experiments.library import T1
from qiskit.providers.fake_provider import FakePerth
from qiskit_aer import AerSimulator
from qiskit_aer.noise import NoiseModel

noise_model = NoiseModel.from_backend(
    FakePerth(), thermal_relaxation=True, gate_error=False, readout_error=False
)
backend = AerSimulator.from_backend(FakePerth(), noise_model=noise_model)

exp = T1(physical_qubits=(0,), delays=np.linspace(0, 300e-6, 30), backend=backend)
exp.analysis.set_options(plot=False)  # disable the plotter
exp_data = exp.run(analysis=None).block_for_results()

%snakeviz exp.analysis._run_analysis(exp_data)

Because BaseAnalysis.run invokes another thread, I tested BaseAnalysis._run_analysis, which is the core function of the analysis, so that I can benchmark it from a Jupyter notebook (we should use another profiler to analyze the entire performance, including the initialization cost of the experiment data container).
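For reference, an equivalent profile can also be collected outside Jupyter with the stdlib cProfile instead of %snakeviz. This is only a sketch; the helper name profile_call is hypothetical, and exp / exp_data are assumed to come from the benchmark snippet above.

```python
import cProfile
import io
import pstats

def profile_call(fn, *args, top=10):
    """Profile a single call of fn(*args) and return (result, text report)."""
    pr = cProfile.Profile()
    pr.enable()
    result = fn(*args)
    pr.disable()
    # Render the top entries sorted by cumulative time into a string
    buf = io.StringIO()
    pstats.Stats(pr, stream=buf).sort_stats("cumulative").print_stats(top)
    return result, buf.getvalue()

# usage with the benchmark above (assumed objects):
#   _, report = profile_call(exp.analysis._run_analysis, exp_data)
#   print(report)
```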

[image: snakeviz profile of exp.analysis._run_analysis]

As you can see, more than 50% of the time is consumed by the scipy least_squares solver, which minimizes the residual for the model parameter search. By default, this solver numerically computes the Jacobian matrix for gradient information.
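Part of this cost can be removed even without JAX: scipy's least_squares accepts an analytic jac callback, which skips the finite-difference Jacobian estimation. A minimal sketch with a hypothetical T1-style decay model (the model, parameter values, and noise level here are illustrative, not the actual T1 analysis code):

```python
import numpy as np
from scipy.optimize import least_squares

# Hypothetical T1-style decay model: y = amp * exp(-x / tau) + base
def model(x, amp, tau, base):
    return amp * np.exp(-x / tau) + base

def residual(params, x, y):
    return y - model(x, *params)

# Analytic Jacobian of the residual vector w.r.t. (amp, tau, base):
#   dr/damp = -exp(-x/tau), dr/dtau = -amp * x / tau**2 * exp(-x/tau), dr/dbase = -1
def jac(params, x, y):
    amp, tau, base = params
    e = np.exp(-x / tau)
    return np.stack([-e, -amp * x / tau**2 * e, -np.ones_like(x)], axis=1)

rng = np.random.default_rng(0)
x = np.linspace(0, 300e-6, 30)
y = model(x, 1.0, 80e-6, 0.0) + rng.normal(scale=0.01, size=x.size)

# Passing jac= avoids numerical differentiation inside the solver
res = least_squares(residual, x0=[0.5, 100e-6, 0.1], jac=jac, args=(x, y))
```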

Proposal

There are two possible approaches to speed up this fitting operation. One could implement the minimization solver in a (potentially faster) compiled language such as Rust or Julia; for example, the fitting library could be written in Rust. However, according to the scipy documentation:

With dense Jacobians trust-region subproblems are solved by an exact method very similar to the one described in [JJMore] (and implemented in MINPACK)

Because MINPACK is a Fortran library, I'm not sure how much performance gain we could obtain in return for giving up scipy (indeed, scipy offers a rich selection of solvers).

Apart from this avenue, one could focus on the bottleneck of the Jacobian matrix computation. Because we already know the exact fit model, this approach sounds like a more natural idea. Fortunately, we don't need to write any callback to compute the Jacobian matrix; instead, we can just rely on JAX. It offers a scipy minimize wrapper, jax.scipy.optimize.minimize (which still seems to be under development, though).
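As an intermediate step (a sketch assuming JAX is installed), JAX can also supply the exact Jacobian to scipy's existing least_squares solver via jax.jacfwd, with no hand-written derivative. The noiseless synthetic data below reuses the damped-sine parameters from the test section of this issue:

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import least_squares

# Noiseless data from the damped-sine model with true params (0.35, 2.43, 0.12)
x = jnp.linspace(0.0, 1.0, 30)
y_obs = jnp.exp(-0.35 * x) * jnp.sin(2 * jnp.pi * 2.43 * x) + 0.12

def resid(params):
    a, b, c = params
    return y_obs - (jnp.exp(-a * x) * jnp.sin(2 * jnp.pi * b * x) + c)

# jacfwd builds the exact (30, 3) Jacobian; jit compiles it once
jac_fn = jax.jit(jax.jacfwd(resid))

res = least_squares(
    lambda p: np.asarray(resid(jnp.asarray(p))),
    x0=np.array([0.0, 2.0, 0.0]),
    jac=lambda p: np.asarray(jac_fn(jnp.asarray(p))),
)
```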

JAX implements a Python interface and a math library comparable with NumPy. Once we implement the residual function with its math library, we can even JIT-compile the minimization routine. As a side effect, this approach may make JAX a requirement of our software (which might not be acceptable to all users), but Qiskit recently released the arraylias package, which provides an alias of JAX/NumPy. With arraylias, we can introduce JAX as an optional dependency, and the CurveAnalysis solver may fall back to the NumPy residual function when JAX is not available. I believe this approach is much more promising.

Test

In this test code, I use a damped sinusoidal function as the fit model. The fit data has some noise along the Y-axis. Because I'm very new to JAX, the code below is just a patchwork of some example code.

from scipy.optimize import least_squares
import numpy as np
import matplotlib.pyplot as plt

import jax.numpy as jnp
from jax import jit
from jax.scipy.optimize import minimize

true_params = [0.35, 2.43, 0.12]
init_guess = [0., 2., 0.]
noise = 0.1
seed = 123

def function(x, a, b, c):
    return np.exp(-a * x) * np.sin(2 * np.pi * b * x) + c

def data_factory(a, b, c):
    rng = np.random.default_rng(seed=seed)
    
    x = np.linspace(0, 1, 30)
    y = function(x, a, b, c) + rng.normal(scale=noise, size=30)
    
    return x, y

This is the conventional NumPy residual function with the scipy least_squares solver.

x, y = data_factory(*true_params)

def fobj(params):
    return y - function(x, *params)

res = least_squares(fobj, x0=init_guess)

%timeit least_squares(fobj, x0=init_guess)

3.08 ms ± 167 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Then, try the JAX residual function and the JAX scipy minimize solver with JIT compilation.

x, y = data_factory(*true_params)

x_jax = jnp.array(x, dtype=jnp.float32)
y_jax = jnp.array(y, dtype=jnp.float32)
init_guess_jax = jnp.array(init_guess, dtype=jnp.float32)

def residual(params):
    return jnp.sum(jnp.square(y_jax - jnp.exp(-params[0] * x) * jnp.sin(2 * np.pi * params[1] * x) + params[2]))

@jit
def min_op(x0):
    result = minimize(
        residual,
        x0,
        method="BFGS",
        options=dict(gtol=1e-6),
    )
    return result

# Warm up
jax_res = min_op(init_guess_jax)

%timeit min_op(init_guess_jax)

47.5 µs ± 1.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Finally, compare the results.

plt.scatter(x, y)

print(res.x, jax_res.x)
plt.plot(x, function(x, *res.x))
plt.plot(x, function(x, *jax_res.x))

[image: comparison of the scipy and JAX fit curves against the data points]

As you can see, the JAX result is a bit off from the actual data points. I guess the difference is due to the return value of the JAX residual function; it errored when I returned a vector value, while the scipy solver could handle that properly. I believe there is some trick to fix this. On the other hand, it's worth calling attention to the performance improvement: the JAX solver consumed only 1.5% of the scipy solver's time to return the fit parameters!

Note

Because JIT compilation is invoked when the wrapped function is called for the first time, the residual function should ideally be a singleton. Otherwise, JIT compilation runs in every fitter instance and doesn't give a significant performance improvement in the parallel environment. In the above example, it was most efficient to JIT-compile the minimize function; when I JIT-compiled only the residual function, I didn't obtain any performance gain -- this needs more investigation.

I hope the future fit model will look like this:

class FitModel:

    def parameters(cls) -> List[str]:
        """Names of the fit parameters."""

    def initial_guesses(cls, x, y) -> Iterator:
        """Return an iterator over initial guesses."""

    def residual(cls, x, *args) -> DeviceArray:
        """Some cost function value; possibly JIT-compiled."""

    def func(cls, x, *args) -> DeviceArray:
        """Compute fit y values, e.g. for visualization."""

    def repr(cls) -> str:
        """Human-readable representation of the model for fitter metadata."""

CurveAnalysis can then simply consume this model object, which provides an efficient cost function along with the initial-guess generator. This allows us to completely decouple the fit model from the fit protocol and reduce the maintenance overhead of CurveAnalysis subclasses.
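As a plain-NumPy illustration of this idea (class name and signatures are hypothetical; a JAX version would swap np for jax.numpy and JIT-compile residual), the damped-sine model from the test above could be expressed as:

```python
from typing import Iterator, List
import numpy as np

class DampedSineModel:
    """Hypothetical concrete model implementing the proposed FitModel interface."""

    @classmethod
    def parameters(cls) -> List[str]:
        return ["a", "b", "c"]

    @classmethod
    def initial_guesses(cls, x, y) -> Iterator:
        # A single heuristic guess; a real model could yield several candidates
        yield [0.0, 2.0, float(np.mean(y))]

    @classmethod
    def func(cls, x, a, b, c):
        # Fit y values, e.g. for visualization
        return np.exp(-a * x) * np.sin(2 * np.pi * b * x) + c

    @classmethod
    def residual(cls, params, x, y) -> float:
        # Scalar sum-of-squares cost consumed by the fitter
        return float(np.sum(np.square(y - cls.func(x, *params))))

    @classmethod
    def repr(cls) -> str:
        return "exp(-a x) sin(2 pi b x) + c"
```

A fitter could then iterate over initial_guesses and minimize residual without knowing anything about the model's internals.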

@nkanazawa1989 nkanazawa1989 added the enhancement New feature or request label May 31, 2023

to24toro commented Jun 7, 2023

I looked into this in some detail.

By modifying the residual function as follows, we can get the same result as the scipy solver.

def residual(params):
    return jnp.sum(jnp.square(y_jax - jnp.exp(-params[0] * x) * jnp.sin(2 * np.pi * params[1] * x) - params[2]))

[image: fit result with the corrected residual function]

Also, jaxopt, an optimization library for JAX, is available.

We can use it as below:

import jaxopt

# scipy.optimize.minimize wrapper
jaxopt_res = jaxopt.ScipyMinimize(fun=residual).run(init_guess_jax)
%timeit jaxopt.ScipyMinimize(fun=residual).run(init_guess_jax)

34.2 ms ± 229 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# use .params instead of .x
plt.plot(x, function(x, *jaxopt_res.params))

@nkanazawa1989 (Collaborator, Author)

Cool, thanks @to24toro for the follow-up. Could you please compare the performance of the jaxopt and scipy optimizations (JAX cost function vs. NumPy cost function) on your platform? I think the JAX functions need a warm-up run to exclude the compilation overhead of the first call.


to24toro commented Jun 8, 2023

You are correct.
I have already tried it with jit compilation, but an error occurs when jitting jaxopt.ScipyMinimize. We cannot jit it as follows:

@jit
def jax_scipy(x):
    return jaxopt.ScipyMinimize(fun=residual).run(x)

jax_res = jax_scipy(init_guess_jax)

Meanwhile, we can jit jaxopt.BFGS, for example. Of course, I am not sure we can directly compare it with ScipyMinimize.
jaxopt.ScipyMinimize might not be supported by jax-jit (there is a discussion about jaxopt.ScipyBoundedMinimize: google/jaxopt#214 (comment)). If I take a look at the code in more detail, I could probably find something.

Just in case, here are the measured times for several approaches.

res = least_squares(fobj, x0=init_guess)
%timeit least_squares(fobj, x0=init_guess)

1.06 ms ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%timeit -n 1000 jaxopt.ScipyMinimize(fun=residual).run(init_guess_jax)

36.3 ms ± 131 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

@jit
def jax_bfgs(x):
    return jaxopt.BFGS(fun=residual).run(x)
# Warm up
jax_res = jax_bfgs(init_guess_jax)

%timeit jax_bfgs(init_guess_jax)

60.7 µs ± 193 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

@nkanazawa1989 (Collaborator, Author)

Closed and merged into #1268
