Как определить переменную динамической формы при построении вычислительного графа с помощью Tensorflow 1.15

Информация о системе

  1. Написал ли я собственный код (в отличие от использования стандартного примера сценария, предоставленного в TensorFlow): Нет

  2. Платформа ОС и дистрибутив (например, Linux Ubuntu 16.04): Linux Ubuntu 18.04

  3. Мобильное устройство (например, iPhone 8, Pixel 2, Samsung Galaxy), если проблема возникает на мобильном устройстве:

  4. TensorFlow установлен из (исходного или бинарного): репозиторий Conda

  5. Версия TensorFlow (используйте команду ниже): 1.15

  6. Версия Python: 3.7.7

  7. Версия Bazel (при компиляции из исходников):

  8. Версия GCC/компилятора (при компиляции из исходников):

  9. Версия CUDA/cuDNN: 10.1

  10. Модель графического процессора и память: Tesla V100-SMX3-32GB

  11. Опишите текущее поведение

    tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign требует совпадения форм обоих тензоров. lhs shape= [] rhs shape= [1,1] [[{{переменная узла/назначение}}]]

Опишите ожидаемое поведение

Нет ошибок

Отдельный код для воспроизведения проблемы

import tensorflow as tf
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'

with tf.Session() as sess:
    v = tf.Variable(np.zeros(shape=[1,1]),shape=tf.TensorShape(None))
    sess.run(tf.global_variables_initializer())

Наблюдение: ошибка не появлялась, когда я использую нетерпеливый_execution_mode()

Код:

tf.enable_eager_execution()
v = tf.Variable(np.zeros([1,1]),shape=tf.TensorShape(None))
tf.print(v)
v.assign(np.ones([2,2]))
tf.print(v)    

Выход:

[[0]]
[[1 1]
 [1 1]]

Ссылка на MWE: https://colab.research.google.com/gist/amahendrakar/3fe8345db4092d520246205be4b97948/41620.ipynb


person lengoanhcat    schedule 24.07.2020    source источник


Ответы (1)


Просто нужно включить переменную ресурса, так как поведение динамической формы доступно только для этого «обновленного» класса Variable.

import tensorflow as tf
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'
tf.compat.v1.enable_resource_variables()

with tf.Session() as sess:
    v = tf.Variable(np.zeros(shape=[1,1]),shape=tf.TensorShape(None))
    sess.run(tf.global_variables_initializer())
person lengoanhcat    schedule 24.07.2020