Co robi batch, repeat i shuffle z zestawem danych TensorFlow?
Obecnie uczę się TensorFlow, ale natknąłem się na zamieszanie w poniższym fragmencie kodu:
dataset = dataset.shuffle(buffer_size = 10 * batch_size)
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
Wiem, że najpierw zbiór danych będzie zawierał wszystkie dane, ale co shuffle()
,repeat()
, i batch()
zrobić do zbioru danych?
Proszę o pomoc z przykładem i wyjaśnieniem.
3 answers
Update: tutaj jest mały notatnik współpracy dla demonstracji tej odpowiedzi.
Wyobraź sobie, że masz zbiór danych: [1, 2, 3, 4, 5, 6]
, Następnie:
Jak ds.shuffle () działa
dataset.shuffle(buffer_size=3)
przydziela bufor o rozmiarze 3 do wybierania losowych wpisów. Bufor ten zostanie podłączony do zbioru danych źródłowych.
Możemy to sobie wyobrazić tak:
Random buffer
|
| Source dataset where all other elements live
| |
↓ ↓
[1,2,3] <= [4,5,6]
Załóżmy, że wpis {[7] } został pobrany z bufora losowego. Wolne miejsce wypełnia kolejny element z bufor źródłowy, czyli 4
:
2 <= [1,3,4] <= [5,6]
Kontynuujemy czytanie, aż nic nie zostanie:
1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6] <= []
6 <= [4] <= []
4 <= [] <= []
Jak ds.repeat () działa
Gdy tylko wszystkie wpisy zostaną odczytane ze zbioru danych i spróbujesz odczytać następny element, zbiór danych spowoduje wyświetlenie błędu.
To jest miejsce, gdzie ds.repeat()
wchodzi w grę. Spowoduje to ponowną inicjalizację zestawu danych, czyniąc go ponownie tak:
[1,2,3] <= [4,5,6]
Co da ds.Nr serii ()
The ds.batch()
will take first batch_size
entries and make dużo z nich. Tak więc, rozmiar wsadu 3 dla naszego przykładowego zbioru danych spowoduje utworzenie dwóch rekordów wsadowych:
[2,1,5]
[3,6,4]
Ponieważ mamy ds.repeat()
przed wsadem, generowanie danych będzie kontynuowane. Ale kolejność elementów będzie inna, ze względu na ds.random()
. Należy wziąć pod uwagę, że 6
nigdy nie będzie obecny w pierwszej partii, ze względu na rozmiar losowego bufora.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2021-01-02 19:01:35
Następujące metody w tf.Dataset:
-
repeat( count=0 )
metoda powtarza zbiór danychcount
wiele razy. -
shuffle( buffer_size, seed=None, reshuffle_each_iteration=None)
metoda tasuje próbki w zbiorze danych.buffer_size
jest liczbą próbek, które są randomizowane i zwracane jakotf.Dataset
. -
batch(batch_size,drop_remainder=False)
tworzy partie zbioru danych o rozmiarze partii podanym jakobatch_size
, która jest również długością partii.
Przykład, który pokazuje zapętlenie epok. Po uruchomieniu tego skryptu zwróć uwagę na różnicę w
-
dataset_gen1
- Operacja shuffle generuje więcej losowych wyjść (może to być bardziej przydatne podczas uruchamiania eksperymentów uczenia maszynowego) -
dataset_gen2
- brak operacji shuffle wytwarza elementy w sekwencji
Inne dodatki w tym skrypcie
-
tf.data.experimental.sample_from_datasets
- służy do łączenia dwóch zestawów danych. Zauważ, że operacja shuffle w tym przypadku tworzy bufor że próbki w równym stopniu z obu zestawów danych.
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # to avoid all those prints
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" # to avoid large "Kernel Launch Time"
import tensorflow as tf
if len(tf.config.list_physical_devices('GPU')):
tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)
class Augmentations:
def __init__(self):
pass
@tf.function
def filter_even(self, x):
if x % 2 == 0:
return False
else:
return True
class Dataset:
def __init__(self, aug, range_min=0, range_max=100):
self.range_min = range_min
self.range_max = range_max
self.aug = aug
def generator(self):
dataset = tf.data.Dataset.from_generator(self._generator
, output_types=(tf.float32), args=())
dataset = dataset.filter(self.aug.filter_even)
return dataset
def _generator(self):
for item in range(self.range_min, self.range_max):
yield(item)
# Can be used when you have multiple datasets that you wish to combine
class ZipDataset:
def __init__(self, datasets):
self.datasets = datasets
self.datasets_generators = []
def generator(self):
for dataset in self.datasets:
self.datasets_generators.append(dataset.generator())
return tf.data.experimental.sample_from_datasets(self.datasets_generators)
if __name__ == "__main__":
aug = Augmentations()
dataset1 = Dataset(aug, 0, 100)
dataset2 = Dataset(aug, 100, 200)
dataset = ZipDataset([dataset1, dataset2])
epochs = 2
shuffle_buffer = 10
batch_size = 4
prefetch_buffer = 5
dataset_gen1 = dataset.generator().shuffle(shuffle_buffer).batch(batch_size).prefetch(prefetch_buffer)
# dataset_gen2 = dataset.generator().batch(batch_size).prefetch(prefetch_buffer) # this will output odd elements in sequence
for epoch in range(epochs):
print ('\n ------------------ Epoch: {} ------------------'.format(epoch))
for X in dataset_gen1.repeat(1): # adding .repeat() in the loop allows you to easily control the end of the loop
print (X)
# Do some stuff at end of loop
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2020-11-20 13:47:26