В тази бележка ще приложим линейна регресия в стрийминг настройка. Предполагаме, че не разполагаме с всички налични данни и ще генерираме произволни данни за обучение в движение.

След това ще използваме Flink, за да коригираме теглата на линейната регресия, така че MSE да бъде сведен до минимум.

Така че ще използваме алгоритъм за градиентно спускане, за да коригираме теглата. Тъй като не разполагаме с всички налични данни веднага, ще трябва да изчакаме, докато пристигнат достатъчно данни. По принцип ще изчакаме пакет от данни да бъде готов за обработка. Концепцията за партида естествено се припокрива с прозорците за глобално броене на Flink.

Добре, така че една линейна функция се характеризира с вектор от тегла и евентуално отклонение. Тъй като програмата Flink ще работи непрекъснато известно време, докато не сме доволни от теглото си, трябва да ги съхраняваме някъде. Тук влиза в действие концепцията за състоянието на Flink.

Добре, да започваме. Ето кода на основната функция, който ще разопаковаме след секунда:

    private static final int problemDim = 1;
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        SingleOutputStreamOperator<RegressionParams> dataStream = env
                .socketTextStream("localhost", 9999)
                .map(new Splitter(problemDim))
                .countWindowAll(8)
                .process(new RegressionWindowFunction(problemDim));
        dataStream.print();
        env.execute("Training");
    }

Добре, така че слушаме текстовия поток на сокета, след което обработваме получения низ с обект Splitter. След това формираме партида с размер 8 и накрая извикваме функция за процес върху това.

И така, какво е сплитерът? Неговата задача е да вземе необработения текстов низ и да го преобразува в нещо по-удобно за работа. Ще го преобразуваме в следния обект:

Така че в общи линии, пример за обучение. Така че с това обектът Splitter изглежда така:

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.ml.common.linalg.DenseVector;

import java.util.Arrays;

public class Splitter implements MapFunction<String, RegressionTrainingSample> {
    private final int dim;

    public Splitter(int dim) {
        this.dim = dim;
    }

    @Override
    public RegressionTrainingSample map(String s) {
        String[] sr = s.split(" ");
        double[] rawPointValues = new double[dim];
        int ct = 0;
        for (String token : Arrays.copyOfRange(sr, 0, dim)) {
            rawPointValues[ct] = Double.parseDouble(token);
        }

        return new RegressionTrainingSample(new DenseVector(rawPointValues), Double.parseDouble(sr[sr.length - 1]));
    }
}

След като се създаде партида от 8, се нуждаем от имплементация за функцията за обработка, която ще вземе партида от тренировъчни проби, изчислени градиенти на параметри и след това ще актуализира състоянието. На първо място, ето един клас, който представлява нашата линейна регресия:

import org.apache.flink.ml.common.linalg.DenseVector;

public class RegressionParams {
    private DenseVector weights;
    private Double bias;

    public RegressionParams(DenseVector weights, Double bias) {
        this.weights = weights;
        this.bias = bias;
    }

    public DenseVector getWeights() {
        return weights;
    }

    public Double getBias() {
        return bias;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("y=");
        int ct = 0;
        for (Double w : weights.getData()) {
            sb.append(w + "x_" + ct);
            ct++;
        }
        sb.append("+" + bias);
        return sb.toString();
    }
}

Ето как изглежда сигнатурата на нашите процесни функции

private static class RegressionWindowFunction extends ProcessAllWindowFunction<RegressionTrainingSample, RegressionParams, GlobalWindow>

Така че функцията приема тренировъчна проба като вход и изхвърля нашите регресионни параметри. За да съхраним параметрите в състояние, се нуждаем от следната декларация

private final static ValueStateDescriptor<RegressionParams> params = new ValueStateDescriptor<>("params", RegressionParams.class);

И накрая, изпълнението на метода на процеса ще изглежда така

@Override
public void process(ProcessAllWindowFunction<RegressionTrainingSample, RegressionParams, GlobalWindow>.Context context, Iterable<RegressionTrainingSample> elements, Collector<RegressionParams> out) throws Exception {
    RegressionParams regressionParams = context.globalState().getState(params).value();
    if (Objects.isNull(regressionParams)) {
        RegressionParams freshParams = new RegressionParams(DenseVector.zeros(dim), 0.0);
        context.globalState().getState(params).update(freshParams);
        regressionParams = context.globalState().getState(params).value();
    }
    System.out.println("Starting batch processing");
    System.out.println("Regression params are " + regressionParams);
    RegressionParams accumulatedGradient = new RegressionParams(DenseVector.zeros(dim), 0.0);
    for (RegressionTrainingSample newDataPoint : elements) {
        System.out.println("Evaluating " + newDataPoint);
        Double response = calculateResponse(newDataPoint, regressionParams);
        System.out.println("Error is " + 0.5 * Math.pow(response - newDataPoint.getTarget(), 2));
        System.out.println("Response is " + response);
        RegressionParams thisPointGradient = calculateGradient(newDataPoint, response);
        accumulatedGradient = addRegressionParamsGradients(accumulatedGradient, thisPointGradient);
    }
    System.out.println("Gradient is " + accumulatedGradient);
    RegressionParams newParams = adjustRegressionParamsWithGradient(regressionParams, accumulatedGradient, 0.0001);
    context.globalState().getState(params).update(newParams);
    System.out.println("After update regression params are " + context.globalState().getState(params).value());
    out.collect(newParams);
}

Тук няма нищо прекалено сложно, просто итерация през целия прозорец и след това изчисляване на линейния отговор и градиента, като същевременно натрупва градиентните вектори в една променлива.

Последните две функции, които липсват в това, са изчислението на линейния отговор и изчислението на производната:

private Double calculateResponse(RegressionTrainingSample newDataPoint, RegressionParams regressionParams) {
    return newDataPoint.getFeatures().dot(regressionParams.getWeights()) + regressionParams.getBias();
}

private RegressionParams calculateGradient(RegressionTrainingSample newDataPoint, Double response) {
    DenseVector features = newDataPoint.getFeatures();
    int maxIndex = features.size();
    double[] rawGradientPoints = new double[maxIndex];
    double multiplicationBasis = response - newDataPoint.getTarget();
    for (int i = 0; i < maxIndex; i++) {
        rawGradientPoints[i] = multiplicationBasis * features.get(i);
    }
    Double biasGradient = multiplicationBasis;
    return new RegressionParams(new DenseVector(rawGradientPoints), biasGradient);
}

Можем да тестваме всичко това с проста програма за драйвер на python, която ще публикува някои шумни стойности на кривата y=x, като тази:

import socket
import time

# Create a socket object
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

# Get local machine name
host = "localhost"

# Reserve a port for your service.
port = 9999 

# Bind to the port
s.bind((host, port))

# Now wait for client connection.
s.listen(1)

print(f'Listening on {host}:{port}')
from random import randint, gauss
while True:
    # Establish connection with client.
    c, addr = s.accept()
    print(f'Got connection from {addr}')
    while True:
        x = randint(0,20)
        message = f'{x} {x+gauss(mu=0,sigma=1)}\n'
        print(message)
        c.send(message.encode('ascii'))
        time.sleep(0.01)

И изходът в конзолата на изпълнителя на Flink изглежда така:

Response is 1.0749612223522655
Evaluating 16.0 15.547041844810462
Error is 0.14912812721755458
Response is 16.093170267834854
Evaluating 19.0 17.992943134585456
Error is 0.6092633209379449
Response is 19.09681207693137
Evaluating 14.0 13.276572926577254
Error is 0.3314359617404347
Response is 14.090742395103844
Evaluating 15.0 13.441248518211866
Error is 1.362418142374651
Response is 15.09195633146935
Evaluating 9.0 7.045387703950424
Error is 2.0793416746306472
Response is 9.084672713276314
Evaluating 20.0 19.67354890099053
Error is 0.09009040943596827
Response is 20.098026013296877
Evaluating 1.0 3.2358480591650274
Error is 2.334715960755332
Response is 1.0749612223522655
Gradient is y=90.47661472652189x_0+4.3415897340513645
After update regression params are y=0.9921662748928537x_0+0.07331312701335459