Как я могу распараллелить цикл for для использования в PyTorch?

Я понимаю, что циклы for медленные с Python в целом. У меня есть код, который работает с некоторыми тензорами:


            for batch_index, mask_batch in enumerate(mask):
                mask_len = torch.sum(mask_batch).int()

                if mask_len == 0:
                    side_input = torch.zeros((max_inp_len, side_input.shape[1])).to(mask.device)
                else:

                    m_nonzero = mask_batch.nonzero().flatten()
                    first_nonzero = m_nonzero[0]
                    last_nonzero = m_nonzero[-1]

                    if side == 'left':
                        end_index = first_nonzero - 1
                        start_index = 0
                    elif side == 'right':
                        start_index = last_nonzero + 1
                        end_index = inputs[batch_index].size(1)

                    side_input = inputs[batch_index][start_index:end_index]

                    if end_index - start_index < max_inp_len:
                        pad_zeros = torch.zeros(
                            (max_inp_len - side_input.shape[0], side_input.shape[1])).to(mask.device)
                        if side == 'left':
                            side_input = torch.cat((pad_zeros, side_input), 0)
                        elif side == 'right':
                            side_input = torch.cat((side_input, pad_zeros), 0)

                side_inputs.append(side_input)

        return torch.stack(side_inputs)

Я чувствую, что этот цикл ДЕЙСТВИТЕЛЬНО замедляет работу. Есть ли способ сделать это без цикла?


person Shamoon    schedule 07.04.2020    source источник
comment
Разве цикл for не должен быть медленным... Как вы можете так говорить? В любом случае, если вы выполняете много медленных операций в своих циклах, рассмотрите возможность использования threading или multiprocessing библиотек.   -  person decadenza    schedule 11.04.2020
comment
Если ваш цикл for состоит из перебора тензора и выполнения некоторых операций над элементами тензора, вам следует попытаться векторизовать операции.   -  person gorjan    schedule 13.04.2020
comment
Возможно ли, что это медленно, потому что вы накапливаете градиенты? Что, если поставить with torch.no_grad() перед всем кодом?   -  person Corey Levinson    schedule 16.04.2020


Ответы (2)


Python не имеет настоящего параллелизма внутри любого заданного процесса. Вам нужно будет создать ProcessPool и сделать внутреннюю часть вашего цикла функцией, принимающей batch_index, mask_batch, а затем сопоставить эту функцию с объектом mask в текущем цикле for. Дело в том, что я не знаю, будет ли PyTorch хорошо с этим работать.

Вот так

def f(batch_index, mask_batch):
    mask_len = torch.sum(mask_batch).int()

    if mask_len == 0:
        side_input = torch.zeros((max_inp_len, side_input.shape[1])).to(mask.device)
    else:
        m_nonzero = mask_batch.nonzero().flatten()
        first_nonzero = m_nonzero[0]
        last_nonzero = m_nonzero[-1]

        if side == 'left':
            end_index = first_nonzero - 1
            start_index = 0
        elif side == 'right':
            start_index = last_nonzero + 1
            end_index = inputs[batch_index].size(1)

            side_input = inputs[batch_index][start_index:end_index]

            if end_index - start_index < max_inp_len:
                pad_zeros = torch.zeros((max_inp_len - side_input.shape[0], side_input.shape[1])).to(mask.device)
                if side == 'left':
                    side_input = torch.cat((pad_zeros, side_input), 0)
                elif side == 'right':
                    side_input = torch.cat((side_input, pad_zeros), 0)
    return side_input

Другие вещи, на которые вы можете обратить внимание, — это дальнейшая векторизация кода. Большинство вещей в PyTorch и Numpy можно векторизовать, используя встроенные функции и добавляя еще одно измерение к вашим тензорам, которое представляет измерение «петли». Это позволит PyTorch обрабатывать параллелизм за вас.

PyTorch может иметь концепцию устройств, на которые вы можете накладывать разные итерации цикла, опять же, это потребует от вас создания функции для этого цикла и, возможно, использования устройства, которое оно использует, в качестве входных данных.

Наконец, вы можете воспользоваться своевременной компиляцией, такой как Numba или torch.jit, чтобы выполнить автоматическую векторизацию для вас.

Ничего из этого не сработает (скорее всего), если mask имеет неизвестную длину. Если он имеет известную длину, я думаю, что векторизация, какой бы сложной она ни была, вероятно, ваш лучший выбор.

person Ryan    schedule 14.04.2020

Вы должны создать функцию, содержащую логику итерации цикла, и запустить ее как поток для каждого столбца (см. документы здесь). Вы также можете использовать библиотеку asyncio для параллелизма, но вы, вероятно, получите меньше улучшений.

Хороший пример порождения потока для каждого элемента списка можно прочитать здесь.

person Antoine Perry    schedule 12.04.2020