Распараллеливание цикла «первый выживший побеждает»

У меня есть проблема, которая при упрощении:

  1. имеет цикл, который отбирает новые точки
  2. оценивает их с помощью сложной/медленной функции
  3. принимает их, если значение превышает постоянно увеличивающийся порог.

Вот пример кода для иллюстрации:

from numpy.random import uniform
from time import sleep

def userfunction(x):
    # do something complicated
    # but computation always takes takes roughly the same time
    sleep(1) # comment this out if too slow
    xnew = uniform() # in reality, a non-trivial function of x
    y = -0.5 * xnew**2
    return xnew, y

x0, cur = userfunction([])
x = [x0] # a sequence of points

while cur < -2e-16:
    # this should be parallelised

    # search for a new point higher than a threshold
    x1, next = userfunction(x)
    if next <= cur:
        # throw away (this branch is taken 99% of the time)
        pass
    else:
        cur = next
        print cur
        x.append(x1) # note that userfunction depends on x

print x

Я хочу распараллелить это (например, в кластере), но проблема в том, что мне нужно остановить других рабочих, когда будет найдена успешная точка, или, по крайней мере, сообщить им о новом x (если им удастся превысить новый порог с более старым x, результат все еще приемлем). Пока ни одна точка не увенчалась успехом, мне нужно, чтобы рабочие повторили.

Я ищу инструменты/фреймворки, которые могут решать проблемы такого типа на любом научном языке программирования (C, C++, Python, Julia и т. д., пожалуйста, не на Fortran).

Можно ли это решить с помощью MPI полуэлегантно? Я не понимаю, как я могу информировать/прерывать/обновлять воркеров с помощью MPI.

Обновление: добавлены комментарии к коду, чтобы сказать, что большинство попыток неуспешны и не влияют на переменную, от которой зависит пользовательская функция.


person j13r    schedule 01.09.2017    source источник
comment
В пользовательской функции вам придется время от времени проверять, не было ли найдено лучшее решение другими потоками.   -  person Serge Rogatch    schedule 01.09.2017
comment
@SergeRogatch, не потребуется ли связь N ^ 2? В качестве альтернативы я мог бы заставить рабочих запрашивать у основной программы текущий x. В моей задаче успешное получение новой точки обычно происходит только в 1/1000 раз, поэтому было бы много бесполезных звонков, если бы просили работники.   -  person j13r    schedule 01.09.2017
comment
Нет, определенно не N*N связи. Рабочий информирует основной поток о наилучшем найденном значении. Основной поток сообщает об этом событии и значении всем остальным рабочим процессам. Другие рабочие время от времени проверяют это событие, и в зависимости от того, есть ли у них лучшее значение, они либо сообщают об этом основному потоку, либо завершают работу.   -  person Serge Rogatch    schedule 01.09.2017
comment
Тесно связанные stackoverflow.com/questions/43973504/   -  person Zulan    schedule 01.09.2017
comment
Не могли бы вы запустить второй поток в каждом процессе MPI, который выполняется параллельно с вашим основным кодом. Затем он будет сидеть в цикле, ожидая (блокируя) сообщение MPI, помеченное как «NEWSURVIVOR», и когда он его получит, он изменит атомарную переменную, совместно используемую с основным потоком. Основной поток будет проверять эту переменную каждый раз в своем цикле. Когда найдется новый выживший, вы просто транслируете его с тегом «NEWSURVIVOR». Просто мысль....   -  person Mark Setchell    schedule 01.09.2017


Ответы (2)


если userfunction() не займет слишком много времени, то вот вариант, который подходит для "полуэлегантно MPI"

для простоты предположим, что ранг 0 — это только оркестратор и ничего не вычисляет.

на ранге 0

cur = 0
x = []
while cur < -2e-16:
    MPI_Recv(buf=cur+x1, src=MPI_ANY_SOURCE)
    x.append(x1)
    MPI_Ibcast(buf=cur+x, root=0, request=req)
    MPI_Wait(request=req)

по рангу != 0

x0, cur = userfunction([])
x = [x0] # a sequence of points

while cur < -2e-16:
    MPI_Ibcast(buf=newcur+newx, root=0, request=req
    # search for a new point higher than a threshold
    x1, next = userfunction(x)
    if next <= cur:
        # throw away (this branch is taken 99% of the time)
        MPI_Test(request=ret, flag=found)
        if found:
            MPI_Wait(request)   
    else:
        cur = next
        MPI_Send(buffer=cur+x1, dest=0)
        MPI_Wait(request)

для правильной обработки требуется дополнительная логика - ранг 0 также выполняет вычисления - несколько рангов находят решение одновременно, последующие сообщения должны потребляться рангом 0

строго говоря, задача не «прерывается», когда найдено решение другой задачи. вместо этого каждая задача периодически проверяет, не было ли решение найдено другой задачей. поэтому есть задержка между временем, когда решение, если оно где-то найдено, и все задачи перестают искать решения, но если userfunction() не занимает "слишком долго", это выглядит очень приемлемым для меня.

person Gilles Gouaillardet    schedule 01.09.2017
comment
Это похоже на хорошее начало. Что дальше()? Я думаю, что в этом решении отсутствует внешний цикл. Как я могу обновить рабочих, чтобы сообщить им о новом массиве x и новом пороговом значении cur? - person j13r; 01.09.2017
comment
Эта зависимость от данных является трудностью. Вы можете предположить, что первая ветвь if используется примерно в 99% случаев, поэтому зависимость довольно слабая. Однако, если будет взят другой, все работники должны быть проинформированы. - person j13r; 01.09.2017
comment
пользовательская функция действительно недетерминирована (в начале у нее есть вызов случайного числа uniform()). - person j13r; 01.09.2017
comment
о, я вижу, я отредактировал свой ответ и соответственно удалил свои комментарии - person Gilles Gouaillardet; 01.09.2017
comment
Хм, у меня проблемы с пониманием сообщения. Не могли бы вы уточнить часть кода rank=0 (для меня нормально, если rank=0 не выполняет никаких вычислений)? В частности, я не понимаю, что делает MPI_test в цикле и когда установлено значение found. Кроме того, я думаю, должен ли быть цикл вокруг обоих сегментов кода? - person j13r; 01.09.2017
comment
@ j13r j13r я обновил псевдокод, теперь он более понятен? - person Gilles Gouaillardet; 04.09.2017

Я решил это примерно с помощью следующего кода.

На данный момент это передает только curmax, но можно отправить другой массив со вторым тегом широковещательной рассылки +.

import numpy
import time

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

import logging
logging.basicConfig(filename='mpitest%d.log' % rank,level=logging.DEBUG)
logFormatter = logging.Formatter("[%(name)s %(levelname)s]: %(message)s")
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(logFormatter)
consoleHandler.setLevel(logging.INFO)
logging.getLogger().addHandler(consoleHandler)

log = logging.getLogger(__name__)

if rank == 0:
    curmax = numpy.random.random()
    seq = [curmax]
    log.info('%d broadcasting starting value %f...' % (rank, curmax))
    comm.Ibcast(numpy.array([curmax]))

    was_updated = False
    while True:
        # check if news available
        status = MPI.Status()
        a_avail = comm.iprobe(source=MPI.ANY_SOURCE, tag=12, status=status)
        if a_avail:
            sugg = comm.recv(source=status.Get_source(), tag=12)
            log.info('%d received new limit from %d: %s' % (rank, status.Get_source(), sugg))
            if sugg < curmax:
                curmax = sugg
                seq.append(curmax)
                log.info('%d updating to %s' % (rank, curmax))
                was_updated = True
            else:
                # ignore
                pass
        # check if next message is already waiting:
        if comm.iprobe(source=MPI.ANY_SOURCE, tag=12):
            # consume it first before broadcasting outdated info
            continue

        if was_updated:
            log.info('%d broadcasting new limit %f...' % (rank, curmax))
            comm.Ibcast(numpy.array([curmax]))
            was_updated = False
        else:
            # no message waiting for us and no broadcast done, so pause
            time.sleep(0.1)
        print

    print data, rank
else:
    log.info('%d waiting for root to send us starting value...' % (rank))
    nextmax = numpy.empty(1, dtype=float)
    comm.Ibcast(nextmax).Wait()

    amax = float(nextmax)
    numpy.random.seed(rank)
    update_req = comm.Ibcast(nextmax)
    while True:
        a = numpy.random.uniform()
        if a < amax:
            log.info('%d found new: %s, sending to root' % (rank, a))
            amax = a
            comm.isend(a, dest=0, tag=12)
        s = update_req.Get_status()
        #log.info('%d bcast status: %s' % (rank, s))
        if s:
            update_req.Wait()
            log.info('%d receiving new limit from root, %s' % (rank, nextmax))
            amax = float(nextmax)
            update_req = comm.Ibcast(nextmax)
person j13r    schedule 19.09.2017