Co robi batch, repeat i shuffle z zestawem danych TensorFlow?

Obecnie uczę się TensorFlow, ale natknąłem się na zamieszanie w poniższym fragmencie kodu:

dataset = dataset.shuffle(buffer_size = 10 * batch_size) 
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()

Wiem, że najpierw zbiór danych będzie zawierał wszystkie dane, ale co shuffle(),repeat(), i batch() zrobić do zbioru danych? Proszę o pomoc z przykładem i wyjaśnieniem.

Author: Purushothaman Srikanth, 2018-11-28

3 answers

Update: tutaj jest mały notatnik współpracy dla demonstracji tej odpowiedzi.

Wyobraź sobie, że masz zbiór danych: [1, 2, 3, 4, 5, 6], Następnie:

Jak ds.shuffle () działa

dataset.shuffle(buffer_size=3) przydziela bufor o rozmiarze 3 do wybierania losowych wpisów. Bufor ten zostanie podłączony do zbioru danych źródłowych. Możemy to sobie wyobrazić tak:

Random buffer
   |
   |   Source dataset where all other elements live
   |         |
   ↓         ↓
[1,2,3] <= [4,5,6]

Załóżmy, że wpis {[7] } został pobrany z bufora losowego. Wolne miejsce wypełnia kolejny element z bufor źródłowy, czyli 4:

2 <= [1,3,4] <= [5,6]

Kontynuujemy czytanie, aż nic nie zostanie:

1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]      <= []
4 <= []      <= []

Jak ds.repeat () działa

Gdy tylko wszystkie wpisy zostaną odczytane ze zbioru danych i spróbujesz odczytać następny element, zbiór danych spowoduje wyświetlenie błędu. To jest miejsce, gdzie ds.repeat() wchodzi w grę. Spowoduje to ponowną inicjalizację zestawu danych, czyniąc go ponownie tak:

[1,2,3] <= [4,5,6]

Co da ds.Nr serii ()

The ds.batch() will take first batch_size entries and make dużo z nich. Tak więc, rozmiar wsadu 3 dla naszego przykładowego zbioru danych spowoduje utworzenie dwóch rekordów wsadowych:

[2,1,5]
[3,6,4]

Ponieważ mamy ds.repeat() przed wsadem, generowanie danych będzie kontynuowane. Ale kolejność elementów będzie inna, ze względu na ds.random(). Należy wziąć pod uwagę, że 6 nigdy nie będzie obecny w pierwszej partii, ze względu na rozmiar losowego bufora.

 75
Author: Vlad-HC,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2021-01-02 19:01:35

Następujące metody w tf.Dataset:

  1. repeat( count=0 ) metoda powtarza zbiór danych count wiele razy.
  2. shuffle( buffer_size, seed=None, reshuffle_each_iteration=None) metoda tasuje próbki w zbiorze danych. buffer_size jest liczbą próbek, które są randomizowane i zwracane jako tf.Dataset.
  3. batch(batch_size,drop_remainder=False) tworzy partie zbioru danych o rozmiarze partii podanym jako batch_size, która jest również długością partii.
 5
Author: ,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-11-28 10:53:19

Przykład, który pokazuje zapętlenie epok. Po uruchomieniu tego skryptu zwróć uwagę na różnicę w

  • dataset_gen1 - Operacja shuffle generuje więcej losowych wyjść (może to być bardziej przydatne podczas uruchamiania eksperymentów uczenia maszynowego)
  • dataset_gen2 - brak operacji shuffle wytwarza elementy w sekwencji

Inne dodatki w tym skrypcie

  • tf.data.experimental.sample_from_datasets - służy do łączenia dwóch zestawów danych. Zauważ, że operacja shuffle w tym przypadku tworzy bufor że próbki w równym stopniu z obu zestawów danych.
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # to avoid all those prints
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" # to avoid large "Kernel Launch Time"

import tensorflow as tf
if len(tf.config.list_physical_devices('GPU')):
    tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)

class Augmentations:

    def __init__(self):
        pass

    @tf.function
    def filter_even(self, x):
        if x % 2 == 0:
            return False
        else:
            return True

class Dataset:

    def __init__(self, aug, range_min=0, range_max=100):
        self.range_min = range_min
        self.range_max = range_max
        self.aug = aug

    def generator(self):
        dataset = tf.data.Dataset.from_generator(self._generator
                        , output_types=(tf.float32), args=())

        dataset = dataset.filter(self.aug.filter_even)

        return dataset
    
    def _generator(self):
        for item in range(self.range_min, self.range_max):
            yield(item)

# Can be used when you have multiple datasets that you wish to combine
class ZipDataset:

    def __init__(self, datasets):
        self.datasets = datasets
        self.datasets_generators = []
    
    def generator(self):
        for dataset in self.datasets:
            self.datasets_generators.append(dataset.generator())
        return tf.data.experimental.sample_from_datasets(self.datasets_generators)

if __name__ == "__main__":
    aug = Augmentations()
    dataset1 = Dataset(aug, 0, 100)
    dataset2 = Dataset(aug, 100, 200)
    dataset = ZipDataset([dataset1, dataset2])

    epochs = 2
    shuffle_buffer = 10
    batch_size = 4
    prefetch_buffer = 5

    dataset_gen1 = dataset.generator().shuffle(shuffle_buffer).batch(batch_size).prefetch(prefetch_buffer)
    # dataset_gen2 = dataset.generator().batch(batch_size).prefetch(prefetch_buffer) # this will output odd elements in sequence 

    for epoch in range(epochs):
        print ('\n ------------------ Epoch: {} ------------------'.format(epoch))
        for X in dataset_gen1.repeat(1): # adding .repeat() in the loop allows you to easily control the end of the loop
            print (X)
        
        # Do some stuff at end of loop
 0
Author: pmod,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2020-11-20 13:47:26