Я понимаю, что циклы for
медленные с Python
в целом. У меня есть код, который работает с некоторыми тензорами:
for batch_index, mask_batch in enumerate(mask):
mask_len = torch.sum(mask_batch).int()
if mask_len == 0:
side_input = torch.zeros((max_inp_len, side_input.shape[1])).to(mask.device)
else:
m_nonzero = mask_batch.nonzero().flatten()
first_nonzero = m_nonzero[0]
last_nonzero = m_nonzero[-1]
if side == 'left':
end_index = first_nonzero - 1
start_index = 0
elif side == 'right':
start_index = last_nonzero + 1
end_index = inputs[batch_index].size(1)
side_input = inputs[batch_index][start_index:end_index]
if end_index - start_index < max_inp_len:
pad_zeros = torch.zeros(
(max_inp_len - side_input.shape[0], side_input.shape[1])).to(mask.device)
if side == 'left':
side_input = torch.cat((pad_zeros, side_input), 0)
elif side == 'right':
side_input = torch.cat((side_input, pad_zeros), 0)
side_inputs.append(side_input)
return torch.stack(side_inputs)
Я чувствую, что этот цикл ДЕЙСТВИТЕЛЬНО замедляет работу. Есть ли способ сделать это без цикла?
threading
илиmultiprocessing
библиотек. - person decadenza   schedule 11.04.2020with torch.no_grad()
перед всем кодом? - person Corey Levinson   schedule 16.04.2020