Ошибка при прогнозировании с помощью python onnxruntime

Я создал очень простое дерево решений, используя библиотеку sklearn. Это дерево обучается на основе 4 функций:

feat1 INT
feat2 INT
feat3 FLOAT
feat4 FLOAT

А метка / целевая функция - это логическое значение (0 или 1).

Я преобразовал дерево в формат ONNX и теперь хочу использовать библиотеку onnxruntime python, чтобы сделать прогноз. Я нашел в Интернете пример кода для этого. Проблема в том, что я не понимаю, что именно происходит во всех частях этого кода, функциях и параметрах. Это приводит к тому, что я получаю сообщение об ошибке. Я искал документацию, но не могу ее найти.

В приведенном ниже коде я конвертирую модель дерева в формат ONNX. Это успешно, но части кода я не понимаю. Что мне нужно ввести в переменную initial_type на основе 4 столбцов функций и метки / целевой функции, которую я использовал ранее? Теперь я ввел FloatTensorType([None, 4], потому что у меня есть 4 столбца с характеристиками, и что это за None, я понятия не имею.

##Convert to ONNX format

initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(treeModel, initial_types=initial_type)
with open("path", "wb") as f:
    f.write(onx.SerializeToString())

В приведенном ниже коде я хочу сделать прогноз, используя библиотеку onnxruntime, но получаю эту ошибку:

RuntimeError: Either type_proto was null or it was not of sequence type

Это потому, что я не понимаю последнюю строку кода ниже. Я ввел это {input_name: [4, 8, 77.8, 143.45], потому что это четыре значения для столбцов функций. Что я здесь делаю не так?

sess = rt.InferenceSession("pathToONNXModel")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: [4, 8, 77.8, 143.45]})[0]

person user7432713    schedule 26.11.2019    source источник


Ответы (1)


Вы пробовали {input_name: numpy.array([4, 8, 77.8, 143.45], dtype=numpy.float32)}? onnxruntime требует в качестве входных данных несколько массивов.

person xadupre    schedule 02.12.2019