KeyError при перечислении через загрузчик данных

Я пытаюсь перебрать загрузчик данных pytorch, инициализированный следующим образом:

trainDL = torch.utils.data.DataLoader (X_train, batch_size = BATCH_SIZE, shuffle = True, ** kwargs)

где X_train - это фрейм данных pandas, подобный этому:  мои панды DF

Итак, я не могу выполнить следующий оператор, так как я получаю KeyError в 'enumerate':

for batch_idx, (data, _) in enumerate(trainDL):
    {stuff}

Кто-нибудь знает, что происходит?

РЕДАКТИРОВАТЬ:

Я получаю следующую ошибку:

KeyError                                  Traceback (most recent call last)
~/.local/share/virtualenvs/Pipenv-l_wD1rT4/lib/python3.6/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   2896             try:
-> 2897                 return self._engine.get_loc(key)
   2898             except KeyError:

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 40592

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
<ipython-input-63-95142e0748bb> in <module>
----> 1 for batch_idx, (data, _) in enumerate(trainDL):
      2     print(".")

~/.local/share/virtualenvs/Pipenv-l_wD1rT4/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    344     def __next__(self):
    345         index = self._next_index()  # may raise StopIteration
--> 346         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    347         if self._pin_memory:
    348             data = _utils.pin_memory.pin_memory(data)

~/.local/share/virtualenvs/Pipenv-l_wD1rT4/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/.local/share/virtualenvs/Pipenv-l_wD1rT4/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/.local/share/virtualenvs/Pipenv-l_wD1rT4/lib/python3.6/site-packages/pandas/core/frame.py in __getitem__(self, key)
   2993             if self.columns.nlevels > 1:
   2994                 return self._getitem_multilevel(key)
-> 2995             indexer = self.columns.get_loc(key)
   2996             if is_integer(indexer):
   2997                 indexer = [indexer]

~/.local/share/virtualenvs/Pipenv-l_wD1rT4/lib/python3.6/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   2897                 return self._engine.get_loc(key)
   2898             except KeyError:
-> 2899                 return self._engine.get_loc(self._maybe_cast_indexer(key))
   2900         indexer = self.get_indexer([key], method=method, tolerance=tolerance)
   2901         if indexer.ndim > 1 or indexer.size > 1:

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 40592

person sooaran    schedule 29.10.2019    source источник
comment
Покажите фактическую ошибку.   -  person Daniel Roseman    schedule 29.10.2019


Ответы (1)


Вы должны создать torch.utils.data.Dataset упаковку для вашего набора данных.

Например:

from torch.utils.data import Dataset

class PandasDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        return self.dataframe.iloc[index]

Передайте этот объект в DataLoader, созданный вашим фреймом данных pandas, и все будет в порядке.

Пример использования с DataLoader:

import pandas as pd

df = pd.read_csv("data.csv")
dataset = PandasDataset(df)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16)
for sample in dataloader:
    ...
person Szymon Maszke    schedule 29.10.2019
comment
Здравствуйте, довольно интересный подход. Не могли бы вы привести пример того, как использовать это для загрузчика данных? Лично я использую настройку обучения pytorch по умолчанию. - person George Petropoulos; 24.06.2020
comment
@GeorgePetropoulos добавил пример - person Szymon Maszke; 24.06.2020
comment
С приведенным выше кодом я получаю следующую ошибку: TypeError: default_collate: пакет должен содержать тензоры, массивы numpy, числа, dicts или списки; найдено Pandas.series - person Raja C; 13.06.2021
comment
@RajaC преобразует вывод в torch.Tensor (по крайней мере, фрейм данных) - person Szymon Maszke; 13.06.2021