Jak Pytorch Dataloader obsługuje dane o zmiennych rozmiarach?
Mam zestaw danych, który wygląda jak poniżej. To jest pierwszy element jest ID użytkownika, a następnie zestaw elementów, które są klikane przez użytkownika.
0 24104 27359 6684
0 24104 27359
1 16742 31529 31485
1 16742 31529
2 6579 19316 13091 7181 6579 19316 13091
2 6579 19316 13091 7181 6579 19316
2 6579 19316 13091 7181 6579 19316 13091 6579
2 6579 19316 13091 7181 6579
4 19577 21608
4 19577 21608
4 19577 21608 18373
5 3541 9529
5 3541 9529
6 6832 19218 14144
6 6832 19218
7 9751 23424 25067 12606 26245 23083 12606
Definiuję niestandardowy zestaw danych do obsługi moich danych dziennika kliknięć.
import torch.utils.data as data
class ClickLogDataset(data.Dataset):
def __init__(self, data_path):
self.data_path = data_path
self.uids = []
self.streams = []
with open(self.data_path, 'r') as fdata:
for row in fdata:
row = row.strip('\n').split('\t')
self.uids.append(int(row[0]))
self.streams.append(list(map(int, row[1:])))
def __len__(self):
return len(self.uids)
def __getitem__(self, idx):
uid, stream = self.uids[idx], self.streams[idx]
return uid, stream
Następnie używam Dataloadera do pobierania mini partii z danych do treningu.
from torch.utils.data.dataloader import DataLoader
clicklog_dataset = ClickLogDataset(data_path)
clicklog_data_loader = DataLoader(dataset=clicklog_dataset, batch_size=16)
for uid_batch, stream_batch in stream_data_loader:
print(uid_batch)
print(stream_batch)
Powyższy kod zwraca się inaczej niż oczekiwałem, chcę, aby stream_batch
był tensorem 2D typu integer o długości 16
. Jednak to co dostaję to lista tensorów 1D o długości 16, A lista ma tylko jeden element, jak poniżej. Dlaczego ?
#stream_batch
[tensor([24104, 24104, 16742, 16742, 6579, 6579, 6579, 6579, 19577, 19577,
19577, 3541, 3541, 6832, 6832, 9751])]
3 answers
Więc jak sobie radzisz z tym, że twoje próbki są różnej długości? torch.utils.data.DataLoader
posiada parametr collate_fn
, który służy do przekształcenia listy próbek w partię. By default it does this to lists. Możesz napisać własne collate_fn
, które na przykład 0
-blokuje dane wejściowe, obraca je do określonej wcześniej długości lub stosuje dowolną inną wybraną operację.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2019-03-07 10:23:01
Tak to robię:
def collate_fn_padd(batch):
'''
Padds batch of variable length
note: it converts things ToTensor manually here since the ToTensor transform
assume it takes in images rather than arbitrary tensors.
'''
## get sequence lengths
lengths = torch.tensor([ t.shape[0] for t in batch ]).to(device)
## padd
batch = [ torch.Tensor(t).to(device) for t in batch ]
batch = torch.nn.utils.rnn.pad_sequence(batch)
## compute mask
mask = (batch != 0).to(device)
return batch, lengths, mask
Następnie przekazuję to do klasy dataloader jako collate_fn
.
Wydaje się, że istnieje ogromna lista różnych postów na forum pytorch. Pozwól mi połączyć się z nimi wszystkimi. Wszyscy mają własne odpowiedzi i dyskusje. Nie wydaje mi się, że istnieje jeden "standardowy sposób, aby to zrobić", ale jeśli jest z autorytatywnego odniesienia, podziel się.
Byłoby miło, że idealna odpowiedź wspomina
- efektywność, np. jeśli do wykonaj przetwarzanie w GPU za pomocą palnika w funkcji collate vs numpy
Rzeczy z tego Sortuj.
Lista:
- https://discuss.pytorch.org/t/how-to-create-batches-of-a-list-of-varying-dimension-tensors/50773
- https://discuss.pytorch.org/t/how-to-create-a-dataloader-with-variable-size-input/8278
- https://discuss.pytorch.org/t/using-variable-sized-input-is-padding-required/18131
- https://discuss.pytorch.org/t/dataloader-for-various-length-of-data/6418
- https://discuss.pytorch.org/t/how-to-do-padding-based-on-lengths/24442
Bucketing: - https://discuss.pytorch.org/t/tensorflow-esque-bucket-by-sequence-length/41284
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2019-07-25 18:07:46
Jak zasugerował @Jatentaki, napisałem swoją niestandardową funkcję sortowania i działała dobrze.
def get_max_length(x):
return len(max(x, key=len))
def pad_sequence(seq):
def _pad(_it, _max_len):
return [0] * (_max_len - len(_it)) + _it
return [_pad(it, get_max_length(seq)) for it in seq]
def custom_collate(batch):
transposed = zip(*batch)
lst = []
for samples in transposed:
if isinstance(samples[0], int):
lst.append(torch.LongTensor(samples))
elif isinstance(samples[0], float):
lst.append(torch.DoubleTensor(samples))
elif isinstance(samples[0], collections.Sequence):
lst.append(torch.LongTensor(pad_sequence(samples)))
return lst
stream_dataset = StreamDataset(data_path)
stream_data_loader = torch.utils.data.dataloader.DataLoader(dataset=stream_dataset,
batch_size=batch_size,
collate_fn=custom_collate,
shuffle=False)
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2019-03-08 08:56:55