TensorFlow Обслуживание регрессионного метода REST API

У меня проблема с "регрессивным" API на обслуживающем сервере TensorFlow. Пожалуйста, просмотрите ссылку ниже, чтобы читать было удобнее. https://gist.github.com/krikit/32c918cc03b52315ade562267a91fa6b

Я сделал простую модель keras, которая имела два входа (x1, x2) и показывала одно выходное значение (y). С этой моделью я получал результаты ошибок от обслуживающего сервера TensorFlow, когда я использовал «регрессивный» REST API.

# the model
inputs = {
    'x1': tf.keras.layers.Input(shape=(1, ), name='x1', dtype='float32'),
    'x2': tf.keras.layers.Input(shape=(1, ), name='x2', dtype='float32'),
}
concat = tf.keras.layers.Concatenate(name='concat')([inputs['x1'], inputs['x2']])
dense = tf.keras.layers.Dense(10, use_bias=True, activation='relu', name='dense')(concat)
outputs = tf.keras.layers.Dense(1, use_bias=True, activation='sigmoid', name='y')(dense)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='SGD', loss='binary_crossentropy')
model.summary()


# training
num_exam = 10000
model.fit({'x1': np.random.randn(num_exam), 'x2': np.random.rand(num_exam)}, np.random.randn(num_exam))


# save
input_infos = {name: tf.saved_model.build_tensor_info(tensor) for name, tensor in model.input.items()}
output_infos = {'y': tf.saved_model.build_tensor_info(model.outputs[0])}
signature = tf.saved_model.build_signature_def(
    inputs=input_infos,
    outputs=output_infos,
    method_name=tf.saved_model.signature_constants.REGRESS_METHOD_NAME
)
print(signature)

model_dir = './random_regression/1'
shutil.rmtree(model_dir, ignore_errors=True)
model_builder = tf.saved_model.builder.SavedModelBuilder(model_dir)
model_builder.add_meta_graph_and_variables(
    tf.keras.backend.get_session(),
    tags=[tf.saved_model.tag_constants.SERVING, ],
    signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}
)
model_builder.save()

После того, как я сохранил модель, казалось, что все в порядке на выходе инструмента "saved_model_cli".

$ saved_model_cli show --dir ./random_regression/1 --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['x1'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: x1:0
    inputs['x2'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: x2:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['y'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: y/Sigmoid:0
  Method name is: tensorflow/serving/regress

После того, как я начал обслуживать сервер с помощью модели, я протестировал REST API методом «регресса». Но я получил ошибку, как показано ниже,

$ curl -X POST -H "Content-Type: application/json" http://localhost:8501/v1/models/random_regression/versions/1:regress -d '
{
  "examples": [
    {
      "x1": [0.1],
      "x2": [0.2]
    },
    {
      "x1": [0.1],
      "x2": [0.3]
    }
  ]
}'

Response:
{ "error": "Expected one input Tensor." }

Хотя я сделал сигнатуру регресса, API прогнозирования также был доступен.

$ curl -X POST -H "Content-Type: application/json" http://localhost:8501/v1/models/random_regression/versions/1:predict -d '
{
  "instances": [
    {
      "x1": [0.1],
      "x2": [0.2]
    },
    {
      "x1": [0.1],
      "x2": [0.3]
    }
  ]
}'

Response:
{
    "predictions": [[0.143165469], [0.124352224]
    ]
}

Причина, по которой я использую метод «регресса», заключается в том, что мне нужно поле «контекст», как показано ниже.

$ curl -X POST -H "Content-Type: application/json" http://localhost:8501/v1/models/random_regression/versions/1:regress -d '
{
  "context": {
    "x1": [0.1]
  },
  "examples": [
    {
      "x2": [0.2]
    },
    {
      "x2": [0.3]
    }
  ]
}'

Response:
{ "error": "Expected one input Tensor." }

Я очень извиняюсь за ДЛИННЫЙ ~~~ вопрос, но есть ли кто-нибудь, кто может мне помочь, пожалуйста?


person Jamie    schedule 03.06.2020    source источник


Ответы (1)


Просто чтобы убедиться, что я правильно понимаю проблему. Когда вы использовали прогнозируемый API, это нормально, он не дает никаких ошибок, но вам нужно использовать методы регрессии, так как вам нужно поле контекста?

Вы используете TF 2.0?

person the curious mind    schedule 05.06.2020
comment
Да, мне нужно использовать regress API для поля контекста. Приходится предсказывать сразу сотни примеров. У них почти такие же тяжелые функции, за исключением одной или двух разных функций. Я думаю, что это стоит тяжелой сети и времени сериализации / десериализации. В настоящее время я использую TF 1.14.0, но любая версия подойдет, если API регрессии работает. - person Jamie; 05.06.2020