Funkcja stratna dla klasy niezrównoważonego klasyfikatora binarnego w przepływie Tensorowym

Próbuję zastosować deep learning dla problemu klasyfikacji binarnej z wysoką nierównowagą klas między klasami docelowymi (500k, 31K). Chcę napisać niestandardową funkcję straty, która powinna być jak: Minimalizuj (100 - ((predicted_smallerclass)/(total_smallerclass))*100)

Doceniam wszelkie wskazówki, jak mogę zbudować tę logikę.

Author: Venkata Dikshit Pappu, 2016-02-02

6 answers

Można dodać wagi klas do funkcji loss, mnożąc logity. Regularna utrata entropii krzyżowej jest następująca:

loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j])))
               = -x[class] + log(\sum_j exp(x[j]))

W przypadku ważonym:

loss(x, class) = weights[class] * -x[class] + log(\sum_j exp(weights[class] * x[j]))

Więc mnożąc logity, przeskalowujesz prognozy każdej klasy przez jej wagę klasową.

Na przykład:

ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([ratio, 1.0 - ratio])
logits = ... # shape [batch_size, 2]
weighted_logits = tf.mul(logits, class_weight) # shape [batch_size, 2]
xent = tf.nn.softmax_cross_entropy_with_logits(
  weighted_logits, labels, name="xent_raw")

Istnieje teraz standardowa funkcja strat, która obsługuje wagi na partię:

tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits, weights=weights)

Gdzie wagi powinny być przekształcone z wagi klasowej na wagę na przykład (z kształtem [batch_size]). Zobacz dokumentację tutaj .

 30
Author: ilblackdragon,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2017-06-30 20:00:03

Zaproponowany przez Ciebie kod wydaje mi się zły. Strata powinna być pomnożona przez wagę, zgadzam się.

Ale jeśli pomnożysz logit przez wagę klasy, kończysz na:

weights[class] * -x[class] + log( \sum_j exp(x[j] * weights[class]) )

Drugi termin nie jest równy:

weights[class] * log(\sum_j exp(x[j]))

Aby to pokazać, możemy przepisać to drugie jako:

log( (\sum_j exp(x[j]) ^ weights[class] )

Oto kod, który proponuję:

ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([[ratio, 1.0 - ratio]])
logits = ... # shape [batch_size, 2]

weight_per_label = tf.transpose( tf.matmul(labels
                           , tf.transpose(class_weight)) ) #shape [1, batch_size]
# this is the weight for each datapoint, depending on its label

xent = tf.mul(weight_per_label
         , tf.nn.softmax_cross_entropy_with_logits(logits, labels, name="xent_raw") #shape [1, batch_size]
loss = tf.reduce_mean(xent) #shape 1
 36
Author: JL Meunier,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2016-10-16 14:35:44

Użycie tf.nn.weighted_cross_entropy_with_logits() i ustawić pos_weight na 1 / (oczekiwany stosunek pozytywów).

 9
Author: Malay Haldar,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2017-11-25 19:43:10

Możesz sprawdzić prowadnice w tensorflow https://www.tensorflow.org/api_guides/python/contrib.losses

...

Określając stratę skalarną przeskalowujemy stratę na całą partię, czasami chcemy przeskalować stratę na próbkę partii. Na przykład, jeśli mamy pewne przykłady, które mają większe znaczenie dla nas, aby uzyskać poprawnie, możemy chcieć mieć większą stratę niż inne próbki, których błędy mają mniejsze znaczenie. W takim przypadku możemy podać wektor ciężaru długości batch_size, który powoduje, że strata dla każdej próbki w partii jest skalowana przez odpowiedni element wagowy. Na przykład, rozważmy przypadek problemu klasyfikacji, w którym chcemy zmaksymalizować naszą dokładność, ale szczególnie interesuje nas uzyskanie wysokiej dokładności dla określonej klasy:

inputs, labels = LoadData(batch_size=3)
logits = MyModelPredictions(inputs)

# Ensures that the loss for examples whose ground truth class is `3` is 5x
# higher than the loss for all other examples.
weight = tf.multiply(4, tf.cast(tf.equal(labels, 3), tf.float32)) + 1

onehot_labels = tf.one_hot(labels, num_classes=5)
tf.contrib.losses.softmax_cross_entropy(logits, onehot_labels, weight=weight)
 3
Author: Victor Mondejar-Guerra,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2017-06-05 14:56:50

Musiałem pracować z podobnym niesymetrycznym zestawem danych wielu klas i tak to przerabiałem, mam nadzieję, że pomoże to komuś szukającemu podobnego rozwiązania:

To wchodzi do twojego modułu treningowego:

from sklearn.utils.class_weight import compute_sample_weight
#use class weights for handling unbalanced dataset
if mode == 'INFER' #test/dev mode, not weighing loss in test mode
   sample_weights = np.ones(labels.shape)
else:
   sample_weights = compute_sample_weight(class_weight='balanced', y=labels)

To wchodzi w definicję klasy modelu:

#an extra placeholder for sample weights
#assuming you already have batch_size tensor
self.sample_weight = tf.placeholder(dtype=tf.float32, shape=[None],
                       name='sample_weights')
cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                       labels=self.label, logits=logits, 
                       name='cross_entropy_loss')
cross_entropy_loss = tf.reduce_sum(cross_entropy_loss*self.sample_weight) / batch_size
 2
Author: bitspersecond,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2017-12-01 17:00:46

Czy ops tf.nn.weighted_cross_entropy_with_logits () dla dwóch klas:

classes_weights = tf.constant([0.1, 1.0])
cross_entropy = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=labels, pos_weight=classes_weights)
 1
Author: Denis Shcheglov,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2017-02-10 15:34:17