Comments (3)
Not entirely surprising. Floating point arithmetic can differ between CPUs. You have a degenerate model with a number of symmetries, so the local optimizer's path passes a number of branching points in the loss surface; a difference in the last bit of a single floating point calculation can send it down a different branch, producing large differences in the final result despite the tiny differences in any individual computation.
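To make the symmetry concrete: the sum of three Gaussians is unchanged if you permute the (mu, sigma, amplitude) triples, so every optimum comes in 3! = 6 equivalent copies connected by flat relabeling directions. A minimal, self-contained check (the helper names here are only for illustration):
import numpy as np

def gauss(x, mu, sigma, amplitude):
    return amplitude * np.exp(-(x - mu)**2 / (2 * sigma**2))

def trimodal(x, p):
    # p holds three (mu, sigma, amplitude) triples in a flat array
    return sum(gauss(x, *p[3 * i:3 * i + 3]) for i in range(3))

x = np.linspace(-35, 10, 451)
p = np.array([-18.3, 0.5, 0.18, -15.0, 0.5, 0.10, -11.8, 0.5, 0.22])
# Swapping which triple is "component 1" and which is "component 3"
# leaves the model output unchanged, so the loss surface has (at least)
# 3! = 6 equivalent minima.
p_swapped = np.concatenate([p[6:9], p[3:6], p[0:3]])
print(np.allclose(trimodal(x, p), trimodal(x, p_swapped)))  # True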
You could try to help this by reparameterizing your model and applying constraints so that you don't have as many symmetries. For example, represent the means as mu1 (same as before), dmu2 = mu2 - mu1, and dmu3 = mu3 - mu2, then optimize the parameter set mu1, dmu2, dmu3, etc., with dmu2 and dmu3 bounded from below by 0. That will break some of the symmetries, though the loss surface may still have a number of those branching points remaining.
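A minimal sketch of that reparameterization (the function name, starting values, and use of minimize here are only illustrative; plug in your own data and loss):
import numpy as np
from scipy.optimize import minimize, Bounds

def gauss(x, mu, sigma, amplitude):
    return amplitude * np.exp(-(x - mu)**2 / (2 * sigma**2))

def trimodal_ordered(x, mu1, dmu2, dmu3, s1, s2, s3, a1, a2, a3):
    # Means are rebuilt from non-negative offsets, so mu1 <= mu2 <= mu3
    # by construction and the component-relabeling symmetry is broken.
    mu2 = mu1 + dmu2
    mu3 = mu2 + dmu3
    return (gauss(x, mu1, s1, a1)
            + gauss(x, mu2, s2, a2)
            + gauss(x, mu3, s3, a3))

def loss(params, x, y):
    return np.sqrt(np.mean((trimodal_ordered(x, *params) - y)**2))

# dmu2 and dmu3 bounded below by 0; sigmas bounded away from 0
bounds = Bounds([-35, 0, 0, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01],
                [5, 40, 40, 5, 5, 5, 0.95, 0.95, 0.95])
x0 = [-18.3, 3.3, 3.3, 0.5, 0.5, 0.5, 0.18, 0.10, 0.22]
# result = minimize(loss, x0, args=(bins, count), bounds=bounds)
# (bins and count as in the histogram example below)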
I don't think there's much we can do about it on our end. CPU differences in floating point arithmetic are a thing, and we don't have much control over them.
This strikes me as a problem that curve_fit() is going to have a hard time solving. Fitting Gaussians is a problem with a lot of local optima, and any local optimizer is going to struggle. Since you already have bounds on reasonable values, why not use a global optimizer, such as dual_annealing()?
Example:
'''
Example trimodal fitting code to be reported to SciPy
Code and sample data based on works from Jun
'''
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit, minimize, Bounds, dual_annealing
def map_param(param_name_list, param_array):
return dict(zip(param_name_list, param_array))
# Define trimodal function
def gauss(array, mu, sigma, amplitude):
return amplitude * np.exp(-(array - mu)**2 / (2 * sigma**2))
def trimodal(array, mu1, sigma1, amplitude1,
mu2, sigma2, amplitude2,
mu3, sigma3, amplitude3):
return gauss(array, mu1, sigma1, amplitude1) + \
gauss(array, mu2, sigma2, amplitude2) + \
gauss(array, mu3, sigma3, amplitude3)
# Define the sample data, which is a histogram
bin_str = """
-35. -34.9 -34.8 -34.7 -34.6 -34.5 -34.4 -34.3 -34.2 -34.1 -34. -33.9
-33.8 -33.7 -33.6 -33.5 -33.4 -33.3 -33.2 -33.1 -33. -32.9 -32.8 -32.7
-32.6 -32.5 -32.4 -32.3 -32.2 -32.1 -32. -31.9 -31.8 -31.7 -31.6 -31.5
-31.4 -31.3 -31.2 -31.1 -31. -30.9 -30.8 -30.7 -30.6 -30.5 -30.4 -30.3
-30.2 -30.1 -30. -29.9 -29.8 -29.7 -29.6 -29.5 -29.4 -29.3 -29.2 -29.1
-29. -28.9 -28.8 -28.7 -28.6 -28.5 -28.4 -28.3 -28.2 -28.1 -28. -27.9
-27.8 -27.7 -27.6 -27.5 -27.4 -27.3 -27.2 -27.1 -27. -26.9 -26.8 -26.7
-26.6 -26.5 -26.4 -26.3 -26.2 -26.1 -26. -25.9 -25.8 -25.7 -25.6 -25.5
-25.4 -25.3 -25.2 -25.1 -25. -24.9 -24.8 -24.7 -24.6 -24.5 -24.4 -24.3
-24.2 -24.1 -24. -23.9 -23.8 -23.7 -23.6 -23.5 -23.4 -23.3 -23.2 -23.1
-23. -22.9 -22.8 -22.7 -22.6 -22.5 -22.4 -22.3 -22.2 -22.1 -22. -21.9
-21.8 -21.7 -21.6 -21.5 -21.4 -21.3 -21.2 -21.1 -21. -20.9 -20.8 -20.7
-20.6 -20.5 -20.4 -20.3 -20.2 -20.1 -20. -19.9 -19.8 -19.7 -19.6 -19.5
-19.4 -19.3 -19.2 -19.1 -19. -18.9 -18.8 -18.7 -18.6 -18.5 -18.4 -18.3
-18.2 -18.1 -18. -17.9 -17.8 -17.7 -17.6 -17.5 -17.4 -17.3 -17.2 -17.1
-17. -16.9 -16.8 -16.7 -16.6 -16.5 -16.4 -16.3 -16.2 -16.1 -16. -15.9
-15.8 -15.7 -15.6 -15.5 -15.4 -15.3 -15.2 -15.1 -15. -14.9 -14.8 -14.7
-14.6 -14.5 -14.4 -14.3 -14.2 -14.1 -14. -13.9 -13.8 -13.7 -13.6 -13.5
-13.4 -13.3 -13.2 -13.1 -13. -12.9 -12.8 -12.7 -12.6 -12.5 -12.4 -12.3
-12.2 -12.1 -12. -11.9 -11.8 -11.7 -11.6 -11.5 -11.4 -11.3 -11.2 -11.1
-11. -10.9 -10.8 -10.7 -10.6 -10.5 -10.4 -10.3 -10.2 -10.1 -10. -9.9
-9.8 -9.7 -9.6 -9.5 -9.4 -9.3 -9.2 -9.1 -9. -8.9 -8.8 -8.7
-8.6 -8.5 -8.4 -8.3 -8.2 -8.1 -8. -7.9 -7.8 -7.7 -7.6 -7.5
-7.4 -7.3 -7.2 -7.1 -7. -6.9 -6.8 -6.7 -6.6 -6.5 -6.4 -6.3
-6.2 -6.1 -6. -5.9 -5.8 -5.7 -5.6 -5.5 -5.4 -5.3 -5.2 -5.1
-5. -4.9 -4.8 -4.7 -4.6 -4.5 -4.4 -4.3 -4.2 -4.1 -4. -3.9
-3.8 -3.7 -3.6 -3.5 -3.4 -3.3 -3.2 -3.1 -3. -2.9 -2.8 -2.7
-2.6 -2.5 -2.4 -2.3 -2.2 -2.1 -2. -1.9 -1.8 -1.7 -1.6 -1.5
-1.4 -1.3 -1.2 -1.1 -1. -0.9 -0.8 -0.7 -0.6 -0.5 -0.4 -0.3
-0.2 -0.1 0. 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9
1. 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9 2. 2.1
2.2 2.3 2.4 2.5 2.6 2.7 2.8 2.9 3. 3.1 3.2 3.3
3.4 3.5 3.6 3.7 3.8 3.9 4. 4.1 4.2 4.3 4.4 4.5
4.6 4.7 4.8 4.9 5. 5.1 5.2 5.3 5.4 5.5 5.6 5.7
5.8 5.9 6. 6.1 6.2 6.3 6.4 6.5 6.6 6.7 6.8 6.9
7. 7.1 7.2 7.3 7.4 7.5 7.6 7.7 7.8 7.9 8. 8.1
8.2 8.3 8.4 8.5 8.6 8.7 8.8 8.9 9. 9.1 9.2 9.3
9.4 9.5 9.6 9.7 9.8 9.9
"""
count_str = """
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0.004 0.004 0. 0.
0. 0.004 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.004
0. 0. 0. 0. 0. 0. 0.008 0.004 0.004 0. 0.004 0.008
0.008 0.004 0. 0. 0.004 0.008 0.012 0.004 0. 0.016 0.016 0.016
0.024 0.024 0.02 0.036 0.028 0.024 0.068 0.084 0.08 0.096 0.124 0.156
0.176 0.136 0.116 0.112 0.1 0.088 0.056 0.052 0.048 0.032 0.024 0.056
0.036 0.044 0.008 0.04 0.028 0.02 0.028 0.016 0.012 0.016 0.016 0.016
0.024 0.012 0.016 0.016 0.02 0.016 0.012 0.02 0.02 0.028 0.012 0.012
0.02 0.004 0.004 0.016 0.012 0.012 0.012 0.012 0.032 0.024 0.004 0.012
0.004 0.02 0.012 0.024 0.012 0.024 0.016 0.016 0.012 0.016 0.016 0.024
0.02 0.028 0.016 0.024 0.04 0.028 0.032 0.024 0.044 0.02 0.02 0.036
0.048 0.024 0.02 0.028 0.032 0.032 0.02 0.016 0.044 0.02 0.036 0.024
0.02 0.004 0.016 0.036 0.036 0.032 0.012 0.016 0.012 0.008 0.024 0.036
0.06 0.052 0.116 0.116 0.072 0.088 0.072 0.116 0.092 0.092 0.072 0.06
0.064 0.048 0.052 0.036 0.044 0.04 0.072 0.064 0.052 0.072 0.044 0.076
0.092 0.092 0.088 0.108 0.1 0.108 0.104 0.072 0.124 0.112 0.128 0.16
0.144 0.196 0.16 0.22 0.18 0.22 0.204 0.22 0.156 0.188 0.184 0.128
0.12 0.152 0.132 0.124 0.076 0.068 0.036 0.032 0.028 0.06 0.06 0.032
0.028 0.012 0.008 0.016 0. 0. 0.008 0.012 0.004 0.008 0.012 0.
0. 0.004 0.004 0.004 0.004 0.008 0.008 0. 0.012 0.004 0. 0.004
0.004 0.004 0. 0.008 0. 0.008 0.004 0. 0. 0. 0.004 0.
0. 0. 0. 0.012 0. 0. 0. 0. 0. 0. 0. 0.
0. 0.004 0. 0.004 0. 0.004 0.004 0. 0. 0. 0.004 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.004 0.
0. 0.004 0. 0. 0. 0. 0.008 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.004
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0.004 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.004 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
"""
count = np.array([float(number) for number in count_str.split()],
dtype=np.float64)
bins = np.array([float(number) for number in bin_str.split()],
dtype=np.float64)
param_names = ['mu1', 'sigma1', 'amplitude1',
'mu2', 'sigma2', 'amplitude2',
'mu3', 'sigma3', 'amplitude3']
#initial value of the parameters
initvel = [-18.296757, 0.5, 0.1759999999999975,
-11.750452, 0.5, 0.2200000000000047,
-15.023604393005371, 0.5, 0.1]
# boundary to constrain the parameters
# NJO: Note: changed bounds to forbid zero sigma
bounds = Bounds([-35, 0.01, 0.01,
-35, 0.01, 0.01,
-35, 0.01, 0.01],
[5, 5, 0.95,
5, 5, 0.95,
5, 5, 0.95])
def loss(params, x, y_true):
y_pred = trimodal(x, **map_param(param_names, params))
return np.sqrt(np.mean((y_pred - y_true)**2))
result_da = dual_annealing(loss, x0=initvel, bounds=bounds, args=(bins, count))
print('SciPy dual_annealing result:')
for i_param, param_name in enumerate(param_names):
print(f'{param_name}: {result_da.x[i_param]:0.8f}', end='\t')
if i_param % 3 == 2:
print()
params, cov, infodict, mesg, ier = curve_fit(trimodal,
bins,
count,
initvel,
full_output=True,
method='trf',
bounds=bounds)
print('SciPy curve_fit result:')
for i_param, param_name in enumerate(param_names):
print(f'{param_name}: {params[i_param]:0.8f}', end='\t')
if i_param % 3 == 2:
print()
# Prepare to plot the fitting result and other reference data
y_fit_cf = trimodal(bins, **map_param(param_names, params))
y_fit_da = trimodal(bins, **map_param(param_names, result_da.x))
y_fit_init = trimodal(bins, **map_param(param_names, initvel))
plt.plot(bins, count, label='Histogram')
plt.plot(bins, y_fit_init, 'k--', alpha=0.5, label='Initial parameters')
plt.plot(bins, y_fit_cf, label='curve_fit')
plt.plot(bins, y_fit_da, label='dual_annealing')
plt.legend()
plt.grid()
plt.show()
Result:
This looks better to me than either of the results that curve_fit() was giving: the dual_annealing fit recognizes the peak at -22, which neither curve_fit() result found.
@rkern @nickodell Thank you for the comments and opinions. Maybe I have to admit that it is not possible to completely avoid the small discrepancies coming from the CPU architecture.
I've had a chance to test dual_annealing based on the sample code @nickodell posted here. The same code was run on AMD/Linux, Intel/Linux, and M1/macOS, and all three results are consistent.
I suspected numerical error in the numerically computed gradient, so I tried providing an analytical Jacobian function to curve_fit. Below is the fitting result.
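For reference, a minimal sketch of how an analytical Jacobian can be supplied to curve_fit through its jac argument, reusing trimodal, bins, count, initvel, and bounds from the example above (this illustrates the approach rather than reproducing my exact code):
import numpy as np
from scipy.optimize import curve_fit

def trimodal_jac(x, mu1, sigma1, amplitude1,
                 mu2, sigma2, amplitude2,
                 mu3, sigma3, amplitude3):
    # Analytical Jacobian of trimodal() w.r.t. its 9 parameters,
    # returned as an array of shape (len(x), 9).
    cols = []
    for mu, sigma, amplitude in ((mu1, sigma1, amplitude1),
                                 (mu2, sigma2, amplitude2),
                                 (mu3, sigma3, amplitude3)):
        e = np.exp(-(x - mu)**2 / (2 * sigma**2))
        cols.append(amplitude * e * (x - mu) / sigma**2)     # d/d mu
        cols.append(amplitude * e * (x - mu)**2 / sigma**3)  # d/d sigma
        cols.append(e)                                       # d/d amplitude
    return np.stack(cols, axis=-1)

params, cov = curve_fit(trimodal, bins, count, p0=initvel,
                        jac=trimodal_jac, method='trf', bounds=bounds)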