В этой заметке мы собираемся реализовать линейную регрессию в настройке потоковой передачи. Мы предполагаем, что у нас нет всех доступных данных, и мы будем генерировать случайные данные для обучения на лету.
Затем мы будем использовать Flink для корректировки весов линейной регрессии, чтобы минимизировать MSE.
Поэтому мы будем использовать алгоритм градиентного спуска для настройки весов. Поскольку у нас нет всех доступных данных сразу, нам придется подождать, пока не поступит достаточно данных. По сути, мы будем ждать, пока пакет данных будет готов к обработке. Концепция пакетной обработки естественным образом пересекается с окнами глобального подсчета Flink.
Итак, линейная функция характеризуется вектором весов и, возможно, смещением. Поскольку программа Flink будет работать некоторое время непрерывно, пока мы не будем довольны нашими весами, нам нужно их где-то хранить. Здесь в игру вступает концепция состояния Флинка.
Хорошо, начнем. Вот код основной функции, которую мы распакуем через секунду:
private static final int problemDim = 1; public static void main(String[] args) throws Exception { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); SingleOutputStreamOperator<RegressionParams> dataStream = env .socketTextStream("localhost", 9999) .map(new Splitter(problemDim)) .countWindowAll(8) .process(new RegressionWindowFunction(problemDim)); dataStream.print(); env.execute("Training"); }
Итак, мы прослушиваем текстовый поток сокета, затем обрабатываем полученную строку с помощью объекта Splitter. После этого мы формируем пакет размером 8 и, наконец, вызываем функцию процесса поверх него.
Так что же такое сплиттер? Его задача — взять необработанную текстовую строку и преобразовать ее во что-то более удобное для работы. Мы будем конвертировать его в следующий объект:
В общем, обучающая выборка. Таким образом, объект Splitter выглядит следующим образом:
import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.ml.common.linalg.DenseVector; import java.util.Arrays; public class Splitter implements MapFunction<String, RegressionTrainingSample> { private final int dim; public Splitter(int dim) { this.dim = dim; } @Override public RegressionTrainingSample map(String s) { String[] sr = s.split(" "); double[] rawPointValues = new double[dim]; int ct = 0; for (String token : Arrays.copyOfRange(sr, 0, dim)) { rawPointValues[ct] = Double.parseDouble(token); } return new RegressionTrainingSample(new DenseVector(rawPointValues), Double.parseDouble(sr[sr.length - 1])); } }
После того, как пакет из 8 создан, нам нужна реализация функции обработки, которая возьмет пакет обучающих выборок, вычислит градиенты параметров, а затем обновит состояние. Прежде всего, вот класс, который представляет нашу линейную регрессию:
import org.apache.flink.ml.common.linalg.DenseVector; public class RegressionParams { private DenseVector weights; private Double bias; public RegressionParams(DenseVector weights, Double bias) { this.weights = weights; this.bias = bias; } public DenseVector getWeights() { return weights; } public Double getBias() { return bias; } @Override public String toString() { StringBuilder sb = new StringBuilder("y="); int ct = 0; for (Double w : weights.getData()) { sb.append(w + "x_" + ct); ct++; } sb.append("+" + bias); return sb.toString(); } }
Вот как выглядит сигнатура наших функций процесса
private static class RegressionWindowFunction extends ProcessAllWindowFunction<RegressionTrainingSample, RegressionParams, GlobalWindow>
Таким образом, функция принимает обучающую выборку в качестве входных данных и выдает наши параметры регрессии. Чтобы сохранить параметры в состоянии, нам нужно следующее объявление
private final static ValueStateDescriptor<RegressionParams> params = new ValueStateDescriptor<>("params", RegressionParams.class);
Наконец, реализация метода процесса будет выглядеть так
@Override public void process(ProcessAllWindowFunction<RegressionTrainingSample, RegressionParams, GlobalWindow>.Context context, Iterable<RegressionTrainingSample> elements, Collector<RegressionParams> out) throws Exception { RegressionParams regressionParams = context.globalState().getState(params).value(); if (Objects.isNull(regressionParams)) { RegressionParams freshParams = new RegressionParams(DenseVector.zeros(dim), 0.0); context.globalState().getState(params).update(freshParams); regressionParams = context.globalState().getState(params).value(); } System.out.println("Starting batch processing"); System.out.println("Regression params are " + regressionParams); RegressionParams accumulatedGradient = new RegressionParams(DenseVector.zeros(dim), 0.0); for (RegressionTrainingSample newDataPoint : elements) { System.out.println("Evaluating " + newDataPoint); Double response = calculateResponse(newDataPoint, regressionParams); System.out.println("Error is " + 0.5 * Math.pow(response - newDataPoint.getTarget(), 2)); System.out.println("Response is " + response); RegressionParams thisPointGradient = calculateGradient(newDataPoint, response); accumulatedGradient = addRegressionParamsGradients(accumulatedGradient, thisPointGradient); } System.out.println("Gradient is " + accumulatedGradient); RegressionParams newParams = adjustRegressionParamsWithGradient(regressionParams, accumulatedGradient, 0.0001); context.globalState().getState(params).update(newParams); System.out.println("After update regression params are " + context.globalState().getState(params).value()); out.collect(newParams); }
Здесь нет ничего слишком сложного, просто итерация по всему окну, а затем вычисление линейного отклика и градиента при накоплении векторов градиента в одну переменную.
Последние две функции, отсутствующие в нем, - это расчет линейного отклика и расчет производной:
private Double calculateResponse(RegressionTrainingSample newDataPoint, RegressionParams regressionParams) { return newDataPoint.getFeatures().dot(regressionParams.getWeights()) + regressionParams.getBias(); } private RegressionParams calculateGradient(RegressionTrainingSample newDataPoint, Double response) { DenseVector features = newDataPoint.getFeatures(); int maxIndex = features.size(); double[] rawGradientPoints = new double[maxIndex]; double multiplicationBasis = response - newDataPoint.getTarget(); for (int i = 0; i < maxIndex; i++) { rawGradientPoints[i] = multiplicationBasis * features.get(i); } Double biasGradient = multiplicationBasis; return new RegressionParams(new DenseVector(rawGradientPoints), biasGradient); }
Мы можем протестировать все это с помощью простой программы-драйвера Python, которая опубликует некоторые зашумленные значения на кривой y=x, например:
import socket import time # Create a socket object s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # Get local machine name host = "localhost" # Reserve a port for your service. port = 9999 # Bind to the port s.bind((host, port)) # Now wait for client connection. s.listen(1) print(f'Listening on {host}:{port}') from random import randint, gauss while True: # Establish connection with client. c, addr = s.accept() print(f'Got connection from {addr}') while True: x = randint(0,20) message = f'{x} {x+gauss(mu=0,sigma=1)}\n' print(message) c.send(message.encode('ascii')) time.sleep(0.01)
А вывод в консоли исполнителя Flink выглядит так:
Response is 1.0749612223522655 Evaluating 16.0 15.547041844810462 Error is 0.14912812721755458 Response is 16.093170267834854 Evaluating 19.0 17.992943134585456 Error is 0.6092633209379449 Response is 19.09681207693137 Evaluating 14.0 13.276572926577254 Error is 0.3314359617404347 Response is 14.090742395103844 Evaluating 15.0 13.441248518211866 Error is 1.362418142374651 Response is 15.09195633146935 Evaluating 9.0 7.045387703950424 Error is 2.0793416746306472 Response is 9.084672713276314 Evaluating 20.0 19.67354890099053 Error is 0.09009040943596827 Response is 20.098026013296877 Evaluating 1.0 3.2358480591650274 Error is 2.334715960755332 Response is 1.0749612223522655 Gradient is y=90.47661472652189x_0+4.3415897340513645 After update regression params are y=0.9921662748928537x_0+0.07331312701335459