Есть ли в TensorFlow встроенная функция потери расхождения KL?

У меня есть два тензора prob_a и prob_b с формой [None, 1000], и я хочу вычислить расхождение KL от prob_a до prob_b. Есть ли в TensorFlow для этого встроенная функция? Я пробовал использовать tf.contrib.distributions.kl(prob_a, prob_b), но он дает:

NotImplementedError: KL (dist_a || dist_b) не зарегистрирован для Tensor типа dist_a и Tensor типа dist_b

Если нет встроенной функции, что было бы хорошим решением?


person Transcendental    schedule 25.01.2017    source источник


Ответы (7)


Предполагая, что ваши входные тензоры prob_a и prob_b являются тензорами вероятностей, сумма которых равна 1 по последней оси, вы можете сделать это следующим образом:

def kl(x, y):
    X = tf.distributions.Categorical(probs=x)
    Y = tf.distributions.Categorical(probs=y)
    return tf.distributions.kl_divergence(X, Y)

result = kl(prob_a, prob_b)

Простой пример:

import numpy as np
import tensorflow as tf
a = np.array([[0.25, 0.1, 0.65], [0.8, 0.15, 0.05]])
b = np.array([[0.7, 0.2, 0.1], [0.15, 0.8, 0.05]])
sess = tf.Session()
print(kl(a, b).eval(session=sess))  # [0.88995184 1.08808468]

Вы получите тот же результат с

np.sum(a * np.log(a / b), axis=1) 

Однако эта реализация немного глючна (проверено в Tensorflow 1.8.0).

Если у вас нулевая вероятность в a, например если вы попробуете [0.8, 0.2, 0.0] вместо [0.8, 0.15, 0.05], вы получите nan, хотя по определению Кульбака-Лейблера 0 * log(0 / b) должен давать нулевой вклад.

Чтобы смягчить это, нужно добавить небольшую числовую константу. Также разумно использовать tf.distributions.kl_divergence(X, Y, allow_nan_stats=False), чтобы вызвать ошибку времени выполнения в таких ситуациях.

Кроме того, если в b есть несколько нулей, вы получите inf значений, которые не будут улавливаться опцией allow_nan_stats=False, поэтому их тоже нужно обработать.

person meferne    schedule 25.06.2018
comment
Кажется, что ваши массивы a и b суммируются до 1 на последней оси, а не на первой - person Luca Di Liello; 06.08.2019
comment
Да, правильнее было бы сказать по оси 1, а еще лучше - по последней оси. Я имел в виду ось 1, когда писал по первой оси, так как есть еще ось 0. Отредактирую ответ. Спасибо! - person meferne; 07.08.2019

Поскольку есть softmax_cross_entropy_with_logits, нет необходимости оптимизировать на KL.

KL(prob_a, prob_b)  
  = Sum(prob_a * log(prob_a/prob_b))  
  = Sum(prob_a * log(prob_a) - prob_a * log(prob_b))  
  = - Sum(prob_a * log(prob_b)) + Sum(prob_a * log(prob_a)) 
  = - Sum(prob_a * log(prob_b)) + const 
  = H(prob_a, prob_b) + const 

Если prob_a не является константой. Вы можете переписать его на подгруппу двух энтропий.

KL(prob_a, prob_b)  
  = Sum(prob_a * log(prob_a/prob_b))  
  = Sum(prob_a * log(prob_a) - prob_a * log(prob_b))  
  = - Sum(prob_a * log(prob_b)) + Sum(prob_a * log(prob_a)) 
  = H(prob_a, prob_b) - H(prob_a, prob_a)  
person Jiecheng Zhao    schedule 29.06.2017
comment
Бывают случаи, когда целевая вероятность prob_a изменяется во время оптимизации. Тогда он становится непостоянным. - person CyberPlayerOne; 21.02.2019

Я не уверен, почему это не реализовано, но, возможно, есть обходной путь. Дивергенция KL определяется как:

KL(prob_a, prob_b) = Sum(prob_a * log(prob_a/prob_b))

С другой стороны, кросс-энтропия H определяется как:

H(prob_a, prob_b) = -Sum(prob_a * log(prob_b))

Итак, если вы создадите переменную y = prob_a/prob_b, вы можете получить расхождение KL, вызвав отрицательное значение H(proba_a, y). В нотации Tensorflow что-то вроде:

KL = tf.reduce_mean(-tf.nn.softmax_cross_entropy_with_logits(prob_a, y))

person E.J. White    schedule 26.01.2017
comment
Расхождение KL должно быть 0, когда prob_a = prob_b. Но последняя строка не дает 0. - person Transcendental; 26.01.2017
comment
Да. Когда prob_a = prob_b, мы получаем y = 1. Тогда H(prob_a, y) равно нулю от log(y). Вы хотите сказать, что проверили его с помощью Tensorflow softmax_cross_entropy_with_logits(prob_a, y), и результат не был нулевым? - person E.J. White; 27.01.2017
comment
Точно. Реализация TensorFlow может немного отличаться от фактической формулы. - person Transcendental; 27.01.2017
comment
Стоит отметить, что softmax_cross_entropy_with_logits (prob_a, y) на самом деле не реализует H (prob_a, y), он реализует H (softmax (a), y). Таким образом, использование softmax_cross_entropy_with_logits будет работать только в том случае, если вы попытаетесь вычислить расхождение KL при активации функции softmax (prob_a) и получите доступ к немасштабированным логитам (a) - person shapecatcher; 20.08.2019

tf.contrib.distributions.kl принимает экземпляры tf.distribution, а не Tensor.

Пример:

  ds = tf.contrib.distributions
  p = ds.Normal(loc=0., scale=1.)
  q = ds.Normal(loc=1., scale=2.)
  kl = ds.kl_divergence(p, q)
  # ==> 0.44314718
person jvdillon    schedule 18.07.2017

Предполагая, что у вас есть доступ к логитам a и b:

prob_a = tf.nn.softmax(a)
cr_aa = tf.nn.softmax_cross_entropy_with_logits(prob_a, a)
cr_ab = tf.nn.softmax_cross_entropy_with_logits(prob_a, b)
kl_ab = tf.reduce_sum(cr_ab - cr_aa)
person Sara    schedule 21.03.2018
comment
Не пойдет! Из документации: ВНИМАНИЕ! Этот параметр предполагает использование немасштабированных логитов, поскольку он выполняет softmax для внутренних логитов для повышения эффективности. Не вызывайте эту операцию с выводом softmax, так как она даст неверные результаты (выделено мной) - person mikkola; 22.03.2018
comment
Предполагая, что у вас есть доступ к журналам a и b. Это не вызывает его на prob_a и prob_b. Он вызывает это на a и b. - person Sara; 16.11.2018

Я думаю, это может сработать:

tf.reduce_sum(p * tf.log(p/q))

где p - мое фактическое распределение вероятностей, а q - мое приблизительное распределение вероятностей.

person Akshaya Natarajan    schedule 18.01.2019

Я использовал функцию из этого кода (из this Medium post) для расчета KL-дивергенции любой заданный тензор из нормального гауссовского распределения, где sd - стандартное отклонение, а mn - тензор.

latent_loss = -0.5 * tf.reduce_sum(1.0 + 2.0 * sd - tf.square(mn) - tf.exp(2.0 * sd), 1)
person generic_stackoverflow_user    schedule 16.02.2019