Jak zainicjować wagi w PyTorch?

Jak zainicjować wagi i błędy (na przykład inicjalizacja He lub Xavier) w sieci w PyTorch?

Author: Fábio Perez, 2018-03-22

9 answers

Pojedyncza warstwa

Aby zainicjować wagę pojedynczej warstwy, użyj funkcji z torch.nn.init. Na przykład:

conv1 = torch.nn.Conv2d(...)
torch.nn.init.xavier_uniform(conv1.weight)

Alternatywnie można zmodyfikować parametry pisząc do conv1.weight.data (co jest torch.Tensor). Przykład:

conv1.weight.data.fill_(0.01)

To samo dotyczy uprzedzeń:

conv1.bias.data.fill_(0.01)

nn.Sequential lub custom nn.Module

Przekazać funkcję inicjalizacyjną do torch.nn.Module.apply. Inicjalizuje wagi w całym nn.Module rekurencyjnie.

Apply (fn): stosuje fn rekurencyjnie do każdego podmodułu (zwracanego przez .children()) oraz self. Typowe zastosowanie obejmuje inicjalizację parametrów modelu (patrz także torch-NN-init).

Przykład:

def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
 179
Author: Fábio Perez,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-05-02 13:28:53

Porównujemy różne sposoby inicjalizacji wagi przy użyciu tej samej architektury sieci neuronowej(NN).

Wszystkie zera lub Jedynki

Jeśli zastosujesz się do zasady brzytwy Occama, możesz pomyśleć, że ustawienie wszystkich wag na 0 LUB 1 byłoby najlepszym rozwiązaniem. Tak nie jest.

Z każdą wagą taką samą, wszystkie neurony w każdej warstwie wytwarzają tę samą moc wyjściową. To sprawia, że trudno zdecydować, które wagi do dostosuj się.

    # initialize two NN's with 0 and 1 constant weights
    model_0 = Net(constant_weight=0)
    model_1 = Net(constant_weight=1)
  • Po 2 epokach:

Wykres odchudzania treningowego z inicjalizacją wagi do stałej

Validation Accuracy
9.625% -- All Zeros
10.050% -- All Ones
Training Loss
2.304  -- All Zeros
1552.281  -- All Ones

Jednolita Inicjalizacja

A równomierny rozkład ma równe prawdopodobieństwo wybrania dowolnej liczby ze zbioru liczb.

Zobaczmy, jak dobrze sieć neuronowa trenuje używając jednolitej inicjalizacji wagi, gdzie low=0.0 i high=1.0.

Poniżej zobaczymy inny sposób (poza kodem klasy Net) inicjalizacji wag sieci. Aby zdefiniować wagi poza definicją modelu możemy:

  1. Zdefiniuj funkcję, która przypisuje wagi przez typ warstwy sieciowej, następnie
  2. Zastosuj te wagi do zainicjalizowanego modelu za pomocą model.apply(fn), który stosuje funkcję do każdej warstwy modelu.
    # takes in a module and applies the specified weight initialization
    def weights_init_uniform(m):
        classname = m.__class__.__name__
        # for every Linear layer in a model..
        if classname.find('Linear') != -1:
            # apply a uniform distribution to the weights and a bias=0
            m.weight.data.uniform_(0.0, 1.0)
            m.bias.data.fill_(0)

    model_uniform = Net()
    model_uniform.apply(weights_init_uniform)
  • Po 2 epokach:

Tutaj wpisz opis obrazka

Validation Accuracy
36.667% -- Uniform Weights
Training Loss
3.208  -- Uniform Weights

Ogólna zasada ustawiania ciężarów

Ogólna zasada ustalania wag w układzie nerwowym sieć ma ustawić je tak, aby były bliskie zeru, nie będąc zbyt małe.

Dobrą praktyką jest rozpoczynanie wag w zakresie [- y, y] gdzie y=1/sqrt(n)
(n to liczba wejść do danego neuronu).

    # takes in a module and applies the specified weight initialization
    def weights_init_uniform_rule(m):
        classname = m.__class__.__name__
        # for every Linear layer in a model..
        if classname.find('Linear') != -1:
            # get the number of the inputs
            n = m.in_features
            y = 1.0/np.sqrt(n)
            m.weight.data.uniform_(-y, y)
            m.bias.data.fill_(0)

    # create a new model with these weights
    model_rule = Net()
    model_rule.apply(weights_init_uniform_rule)

Poniżej porównujemy wydajność NN, wag zainicjalizowanych z równomiernym rozkładem [-0.5,0.5) i tej, której waga jest zainicjalizowana za pomocą ogólna zasada

  • Po 2 epok:

wykres pokazujący skuteczność jednolitej inicjalizacji wagi a ogólna zasada inicjalizacji

Validation Accuracy
75.817% -- Centered Weights [-0.5, 0.5)
85.208% -- General Rule [-y, y)
Training Loss
0.705  -- Centered Weights [-0.5, 0.5)
0.469  -- General Rule [-y, y)

Rozkład normalny do inicjalizacji wag

Rozkład normalny powinien mieć średnią 0 i odchylenie standardowe y=1/sqrt(n), gdzie n jest liczbą wejść do NN

    ## takes in a module and applies the specified weight initialization
    def weights_init_normal(m):
        '''Takes in a module and initializes all linear layers with weight
           values taken from a normal distribution.'''

        classname = m.__class__.__name__
        # for every Linear layer in a model
        if classname.find('Linear') != -1:
            y = m.in_features
        # m.weight.data shoud be taken from a normal distribution
            m.weight.data.normal_(0.0,1/np.sqrt(y))
        # m.bias.data should be 0
            m.bias.data.fill_(0)

Poniżej pokazujemy działanie dwóch NN, z których jeden zainicjalizowany jest za pomocą rozkładu jednorodnego , a drugi za pomocą rozkładu normalnego

  • Po 2 epok:

wydajność inicjalizacji masy przy użyciu rozkładu jednorodnego w porównaniu z rozkładem normalnym

Validation Accuracy
85.775% -- Uniform Rule [-y, y)
84.717% -- Normal Distribution
Training Loss
0.329  -- Uniform Rule [-y, y)
0.443  -- Normal Distribution
 47
Author: ashunigion,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2019-04-06 06:37:46

Aby zainicjować warstwy, zazwyczaj nie trzeba nic robić.

PyTorch zrobi to za Ciebie. Jeśli się nad tym zastanowisz, to ma sens. Dlaczego warto inicjalizować warstwy, gdy PyTorch może to zrobić zgodnie z najnowszymi trendami.

Sprawdź na przykład warstwę liniową .

W metodzie __init__ wywoła Kaiming He funkcję init.

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(3))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

Podobne jest dla innych typów warstw. Dla conv2d na przykład sprawdź tutaj.

Uwaga : Zysk prawidłowej inicjalizacji to większa szybkość treningu. Jeśli twój problem zasługuje na specjalną inicjalizację, możesz to zrobić po słowach.

 24
Author: prosti,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2020-03-04 14:54:21
    import torch.nn as nn        

    # a simple network
    rand_net = nn.Sequential(nn.Linear(in_features, h_size),
                             nn.BatchNorm1d(h_size),
                             nn.ReLU(),
                             nn.Linear(h_size, h_size),
                             nn.BatchNorm1d(h_size),
                             nn.ReLU(),
                             nn.Linear(h_size, 1),
                             nn.ReLU())

    # initialization function, first checks the module type,
    # then applies the desired changes to the weights
    def init_normal(m):
        if type(m) == nn.Linear:
            nn.init.uniform_(m.weight)

    # use the modules apply function to recursively apply the initialization
    rand_net.apply(init_normal)
 10
Author: Duane,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-12-29 03:06:33

Przepraszam za spóźnienie, mam nadzieję, że moja odpowiedź pomoże.

Do inicjalizacji wag za pomocą normal distribution Użyj:

torch.nn.init.normal_(tensor, mean=0, std=1)

Lub użyć constant distribution napisać:

torch.nn.init.constant_(tensor, value)

Lub użyć uniform distribution:

torch.nn.init.uniform_(tensor, a=0, b=1) # a: lower_bound, b: upper_bound

Możesz sprawdzić inne metody inicjalizacji tensorów tutaj

 5
Author: Luca Di Liello,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2019-03-10 23:11:10

Jeśli potrzebujesz dodatkowej elastyczności, Możesz również ustawić wagi ręcznie .

Powiedz, że masz wszystkie wejścia:

import torch
import torch.nn as nn

input = torch.ones((8, 8))
print(input)
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

I chcesz zrobić gęstą warstwę bez uprzedzeń (abyśmy mogli wizualizować):

d = nn.Linear(8, 8, bias=False)

Ustaw wszystkie wagi na 0,5 (lub cokolwiek innego):

d.weight.data = torch.full((8, 8), 0.5)
print(d.weight.data)

Waga:

Out[14]: 
tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000]])
Wszystkie Twoje ciężary wynoszą 0,5. Podaj dane:
d(input)
Out[13]: 
tensor([[4., 4., 4., 4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4., 4., 4., 4.]], grad_fn=<MmBackward>)

Pamiętaj, że każdy neuron otrzymuje 8 wejść, z których wszystkie mają waga 0,5 i wartość 1 (i bez biasu), więc sumuje się do 4 dla każdego.

 4
Author: Nicolas Gervais,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2019-12-22 03:43:07

Iteracja nad parametrami

Jeśli nie możesz użyć apply, na przykład jeśli model nie zaimplementuje Sequential bezpośrednio:

To samo dla wszystkich

# see UNet at https://github.com/milesial/Pytorch-UNet/tree/master/unet


def init_all(model, init_func, *params, **kwargs):
    for p in model.parameters():
        init_func(p, *params, **kwargs)

model = UNet(3, 10)
init_all(model, torch.nn.init.normal_, mean=0., std=1) 
# or
init_all(model, torch.nn.init.constant_, 1.) 

W zależności od kształtu

def init_all(model, init_funcs):
    for p in model.parameters():
        init_func = init_funcs.get(len(p.shape), init_funcs["default"])
        init_func(p)

model = UNet(3, 10)
init_funcs = {
    1: lambda x: torch.nn.init.normal_(x, mean=0., std=1.), # can be bias
    2: lambda x: torch.nn.init.xavier_normal_(x, gain=1.), # can be weight
    3: lambda x: torch.nn.init.xavier_uniform_(x, gain=1.), # can be conv1D filter
    4: lambda x: torch.nn.init.xavier_uniform_(x, gain=1.), # can be conv2D filter
    "default": lambda x: torch.nn.init.constant(x, 1.), # everything else
}

init_all(model, init_funcs)

Możesz spróbować za pomocą torch.nn.init.constant_(x, len(x.shape)), aby sprawdzić, czy są one odpowiednio zainicjowane:

init_funcs = {
    "default": lambda x: torch.nn.init.constant_(x, len(x.shape))
}
 2
Author: ted,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2019-07-19 16:08:46

Ponieważ do tej pory nie miałem wystarczającej reputacji, nie mogę dodać komentarza pod

ODPOWIEDŹ dodana przez prosti W Jun 26 '19 o 13: 16.

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(3))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

Ale chcę podkreślić, że w rzeczywistości znamy pewne założenia w artykule Kaiming He, zagłębianie się w prostowniki: przekraczanie wydajności na poziomie człowieka w klasyfikacji ImageNet , nie jest właściwe, choć wygląda na to, że celowo zaprojektowana metoda inicjalizacji sprawia, że hit w praktyce.

Np. w podsekcji Przypadku propagacji wstecznej zakładają, że $w_l$ i $\delta y_l$ są od siebie niezależne. Ale jak wszyscy wiemy, weźmy mapę wyniku $ \ delta y^L_i$ jako instancję, często jest to $y_i-softmax(y^L_i)=y_i-softmax(w^l_ix^L_i)$ jeśli używamy typowego celu funkcji strat entropii krzyżowej.

Więc myślę, że prawdziwy powód, dla którego Inicjalizacja działa dobrze, pozostaje do rozwikłania. Bo wszyscy byli świadkami jego moc na zwiększenie szkolenia deep learning.

 1
Author: Glory Chen,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2020-03-09 02:06:22

Jeśli widzisz ostrzeżenie o deprecjacji (@Fábio Perez)...

def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
 0
Author: Joseph Konan,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2019-05-08 09:08:16