Sklep SGDClassifier partial fit

Próbuję użyć SGD do sklasyfikowania dużego zbioru danych. Ponieważ dane są zbyt duże, aby zmieścić się w pamięci, chciałbym użyć metody partial_fit do wytrenowania klasyfikatora. Wybrałem próbkę zbioru danych (100 000 wierszy), który pasuje do pamięci, aby przetestować fit vs. partial_fit :

from sklearn.linear_model import SGDClassifier

def batches(l, n):
    for i in xrange(0, len(l), n):
        yield l[i:i+n]

clf1 = SGDClassifier(shuffle=True, loss='log')
clf1.fit(X, Y)

clf2 = SGDClassifier(shuffle=True, loss='log')
n_iter = 60
for n in range(n_iter):
    for batch in batches(range(len(X)), 10000):
        clf2.partial_fit(X[batch[0]:batch[-1]+1], Y[batch[0]:batch[-1]+1], classes=numpy.unique(Y))

Następnie testuję oba klasyfikatory z identycznym zestawem testów. W pierwszym przypadku uzyskuję dokładność 100%. Jak rozumiem SGD domyślnie przechodzi 5 razy nad danymi treningowymi (n_iter = 5).

W drugim przypadku, muszę przejść 60 razy nad danymi, aby osiągnąć taką samą dokładność.

Dlaczego ta różnica (5 vs. 60)? Czy robię coś nie tak?

Author: David M., 2014-07-07

1 answers

W końcu znalazłem odpowiedź. Musisz przetasować dane treningowe pomiędzy każdą iteracją , ponieważ ustawienie shuffle=True podczas tworzenia instancji model nie przetasuje danych podczas używania partial_fit (dotyczy tylko fit). Uwaga: pomocne byłoby znaleźć te informacje na sklepn.linear_model.Strona SGDClassifier .

Zmieniony kod otrzymuje brzmienie:

from sklearn.linear_model import SGDClassifier
import random
clf2 = SGDClassifier(loss='log') # shuffle=True is useless here
shuffledRange = range(len(X))
n_iter = 5
for n in range(n_iter):
    random.shuffle(shuffledRange)
    shuffledX = [X[i] for i in shuffledRange]
    shuffledY = [Y[i] for i in shuffledRange]
    for batch in batches(range(len(shuffledX)), 10000):
        clf2.partial_fit(shuffledX[batch[0]:batch[-1]+1], shuffledY[batch[0]:batch[-1]+1], classes=numpy.unique(Y))
 51
Author: David M.,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2014-07-15 09:51:27