подгонка многомерного curve_fit в python

Я пытаюсь подогнать простую функцию к двум массивам независимых данных в python. Я понимаю, что мне нужно сгруппировать данные для моих независимых переменных в один массив, но что-то все еще кажется неправильным с тем, как я передаю переменные, когда пытаюсь выполнить подгонку. (Есть пара предыдущих сообщений, связанных с этим, но они не очень помогли.)

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def fitFunc(x_3d, a, b, c, d):
    return a + b*x_3d[0,:] + c*x_3d[1,:] + d*x_3d[0,:]*x_3d[1,:]

x_3d = np.array([[1,2,3],[4,5,6]])

p0 = [5.11, 3.9, 5.3, 2]

fitParams, fitCovariances = curve_fit(fitFunc, x_3d[:2,:], x_3d[2,:], p0)
print ' fit coefficients:\n', fitParams

Ошибка, которую я получаю, читает,

raise TypeError('Improper input: N=%s must not exceed M=%s' % (n, m)) 
TypeError: Improper input: N=4 must not exceed M=3

Какова длина M? Является ли N длиной p0? Что я здесь делаю неправильно?


person user3133865    schedule 25.12.2013    source источник


Ответы (2)


N и M определены в справке для функции. N — количество точек данных, а M — количество параметров. Таким образом, ваша ошибка в основном означает, что вам нужно как минимум столько точек данных, сколько у вас есть параметров, что имеет смысл.

Этот код работает для меня:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def fitFunc(x, a, b, c, d):
    return a + b*x[0] + c*x[1] + d*x[0]*x[1]

x_3d = np.array([[1,2,3,4,6],[4,5,6,7,8]])

p0 = [5.11, 3.9, 5.3, 2]

fitParams, fitCovariances = curve_fit(fitFunc, x_3d, x_3d[1,:], p0)
print ' fit coefficients:\n', fitParams

Я включил больше данных. Я также изменил fitFunc, чтобы он был записан в форме, которая сканируется только как функция одного x - установщик будет обращаться с вызовом этого для всех точек данных. Код, который вы опубликовали, также ссылался на x_3d[2,:], что вызывало ошибку.

person chthonicdaemon    schedule 25.12.2013
comment
Большое спасибо! (Я думаю, что M — это количество точек данных, а N — это количество параметров.) - person user3133865; 27.12.2013
comment
В справке четко указано ydata : последовательность N-длины и p0 : None, скалярная или M-длинная последовательность, поэтому N — это количество точек данных, а M — это количество параметров. Однако похоже, что сообщение об ошибке имеет их задом наперед :-). Если вы считаете, что этот ответ был полезен, рассмотрите возможность его принятия. - person chthonicdaemon; 27.12.2013
comment
@VolodimirKopey Я действительно не понимаю, что этот ответ добавляет к этому - они кажутся мне очень похожими. - person chthonicdaemon; 24.11.2014
comment
@chthonicdaemon кажется, что M и N поменялись местами с тех пор, как вы в последний раз оставили комментарий. Теперь ydata имеет длину M, а p0 может быть последовательностью длины N. - person NeutronStar; 22.12.2014
comment
@Joshua Ага, я отправил отчет об ошибке и исправление. Я полагаю, я должен был обновить этот ответ. - person chthonicdaemon; 23.12.2014

Метод по умолчанию curve_fit требует, чтобы у вас было меньше параметров для подобранной функции fitFunc, чем точек данных. У меня была такая же проблема с подбором функции, которая принимала всего 15 параметров, и у меня было только 13 точек данных. Решение состоит в том, чтобы использовать другой метод (например, dogbox или trf).

person Chris Young    schedule 28.04.2020