В этой заметке мы собираемся реализовать линейную регрессию в настройке потоковой передачи. Мы предполагаем, что у нас нет всех доступных данных, и мы будем генерировать случайные данные для обучения на лету.

Затем мы будем использовать Flink для корректировки весов линейной регрессии, чтобы минимизировать MSE.

Поэтому мы будем использовать алгоритм градиентного спуска для настройки весов. Поскольку у нас нет всех доступных данных сразу, нам придется подождать, пока не поступит достаточно данных. По сути, мы будем ждать, пока пакет данных будет готов к обработке. Концепция пакетной обработки естественным образом пересекается с окнами глобального подсчета Flink.

Итак, линейная функция характеризуется вектором весов и, возможно, смещением. Поскольку программа Flink будет работать некоторое время непрерывно, пока мы не будем довольны нашими весами, нам нужно их где-то хранить. Здесь в игру вступает концепция состояния Флинка.

Хорошо, начнем. Вот код основной функции, которую мы распакуем через секунду:

    private static final int problemDim = 1;
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        SingleOutputStreamOperator<RegressionParams> dataStream = env
                .socketTextStream("localhost", 9999)
                .map(new Splitter(problemDim))
                .countWindowAll(8)
                .process(new RegressionWindowFunction(problemDim));
        dataStream.print();
        env.execute("Training");
    }

Итак, мы прослушиваем текстовый поток сокета, затем обрабатываем полученную строку с помощью объекта Splitter. После этого мы формируем пакет размером 8 и, наконец, вызываем функцию процесса поверх него.

Так что же такое сплиттер? Его задача — взять необработанную текстовую строку и преобразовать ее во что-то более удобное для работы. Мы будем конвертировать его в следующий объект:

В общем, обучающая выборка. Таким образом, объект Splitter выглядит следующим образом:

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.ml.common.linalg.DenseVector;

import java.util.Arrays;

public class Splitter implements MapFunction<String, RegressionTrainingSample> {
    private final int dim;

    public Splitter(int dim) {
        this.dim = dim;
    }

    @Override
    public RegressionTrainingSample map(String s) {
        String[] sr = s.split(" ");
        double[] rawPointValues = new double[dim];
        int ct = 0;
        for (String token : Arrays.copyOfRange(sr, 0, dim)) {
            rawPointValues[ct] = Double.parseDouble(token);
        }

        return new RegressionTrainingSample(new DenseVector(rawPointValues), Double.parseDouble(sr[sr.length - 1]));
    }
}

После того, как пакет из 8 создан, нам нужна реализация функции обработки, которая возьмет пакет обучающих выборок, вычислит градиенты параметров, а затем обновит состояние. Прежде всего, вот класс, который представляет нашу линейную регрессию:

import org.apache.flink.ml.common.linalg.DenseVector;

public class RegressionParams {
    private DenseVector weights;
    private Double bias;

    public RegressionParams(DenseVector weights, Double bias) {
        this.weights = weights;
        this.bias = bias;
    }

    public DenseVector getWeights() {
        return weights;
    }

    public Double getBias() {
        return bias;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("y=");
        int ct = 0;
        for (Double w : weights.getData()) {
            sb.append(w + "x_" + ct);
            ct++;
        }
        sb.append("+" + bias);
        return sb.toString();
    }
}

Вот как выглядит сигнатура наших функций процесса

private static class RegressionWindowFunction extends ProcessAllWindowFunction<RegressionTrainingSample, RegressionParams, GlobalWindow>

Таким образом, функция принимает обучающую выборку в качестве входных данных и выдает наши параметры регрессии. Чтобы сохранить параметры в состоянии, нам нужно следующее объявление

private final static ValueStateDescriptor<RegressionParams> params = new ValueStateDescriptor<>("params", RegressionParams.class);

Наконец, реализация метода процесса будет выглядеть так

@Override
public void process(ProcessAllWindowFunction<RegressionTrainingSample, RegressionParams, GlobalWindow>.Context context, Iterable<RegressionTrainingSample> elements, Collector<RegressionParams> out) throws Exception {
    RegressionParams regressionParams = context.globalState().getState(params).value();
    if (Objects.isNull(regressionParams)) {
        RegressionParams freshParams = new RegressionParams(DenseVector.zeros(dim), 0.0);
        context.globalState().getState(params).update(freshParams);
        regressionParams = context.globalState().getState(params).value();
    }
    System.out.println("Starting batch processing");
    System.out.println("Regression params are " + regressionParams);
    RegressionParams accumulatedGradient = new RegressionParams(DenseVector.zeros(dim), 0.0);
    for (RegressionTrainingSample newDataPoint : elements) {
        System.out.println("Evaluating " + newDataPoint);
        Double response = calculateResponse(newDataPoint, regressionParams);
        System.out.println("Error is " + 0.5 * Math.pow(response - newDataPoint.getTarget(), 2));
        System.out.println("Response is " + response);
        RegressionParams thisPointGradient = calculateGradient(newDataPoint, response);
        accumulatedGradient = addRegressionParamsGradients(accumulatedGradient, thisPointGradient);
    }
    System.out.println("Gradient is " + accumulatedGradient);
    RegressionParams newParams = adjustRegressionParamsWithGradient(regressionParams, accumulatedGradient, 0.0001);
    context.globalState().getState(params).update(newParams);
    System.out.println("After update regression params are " + context.globalState().getState(params).value());
    out.collect(newParams);
}

Здесь нет ничего слишком сложного, просто итерация по всему окну, а затем вычисление линейного отклика и градиента при накоплении векторов градиента в одну переменную.

Последние две функции, отсутствующие в нем, - это расчет линейного отклика и расчет производной:

private Double calculateResponse(RegressionTrainingSample newDataPoint, RegressionParams regressionParams) {
    return newDataPoint.getFeatures().dot(regressionParams.getWeights()) + regressionParams.getBias();
}

private RegressionParams calculateGradient(RegressionTrainingSample newDataPoint, Double response) {
    DenseVector features = newDataPoint.getFeatures();
    int maxIndex = features.size();
    double[] rawGradientPoints = new double[maxIndex];
    double multiplicationBasis = response - newDataPoint.getTarget();
    for (int i = 0; i < maxIndex; i++) {
        rawGradientPoints[i] = multiplicationBasis * features.get(i);
    }
    Double biasGradient = multiplicationBasis;
    return new RegressionParams(new DenseVector(rawGradientPoints), biasGradient);
}

Мы можем протестировать все это с помощью простой программы-драйвера Python, которая опубликует некоторые зашумленные значения на кривой y=x, например:

import socket
import time

# Create a socket object
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

# Get local machine name
host = "localhost"

# Reserve a port for your service.
port = 9999 

# Bind to the port
s.bind((host, port))

# Now wait for client connection.
s.listen(1)

print(f'Listening on {host}:{port}')
from random import randint, gauss
while True:
    # Establish connection with client.
    c, addr = s.accept()
    print(f'Got connection from {addr}')
    while True:
        x = randint(0,20)
        message = f'{x} {x+gauss(mu=0,sigma=1)}\n'
        print(message)
        c.send(message.encode('ascii'))
        time.sleep(0.01)

А вывод в консоли исполнителя Flink выглядит так:

Response is 1.0749612223522655
Evaluating 16.0 15.547041844810462
Error is 0.14912812721755458
Response is 16.093170267834854
Evaluating 19.0 17.992943134585456
Error is 0.6092633209379449
Response is 19.09681207693137
Evaluating 14.0 13.276572926577254
Error is 0.3314359617404347
Response is 14.090742395103844
Evaluating 15.0 13.441248518211866
Error is 1.362418142374651
Response is 15.09195633146935
Evaluating 9.0 7.045387703950424
Error is 2.0793416746306472
Response is 9.084672713276314
Evaluating 20.0 19.67354890099053
Error is 0.09009040943596827
Response is 20.098026013296877
Evaluating 1.0 3.2358480591650274
Error is 2.334715960755332
Response is 1.0749612223522655
Gradient is y=90.47661472652189x_0+4.3415897340513645
After update regression params are y=0.9921662748928537x_0+0.07331312701335459