В тази бележка ще приложим линейна регресия в стрийминг настройка. Предполагаме, че не разполагаме с всички налични данни и ще генерираме произволни данни за обучение в движение.
След това ще използваме Flink, за да коригираме теглата на линейната регресия, така че MSE да бъде сведен до минимум.
Така че ще използваме алгоритъм за градиентно спускане, за да коригираме теглата. Тъй като не разполагаме с всички налични данни веднага, ще трябва да изчакаме, докато пристигнат достатъчно данни. По принцип ще изчакаме пакет от данни да бъде готов за обработка. Концепцията за партида естествено се припокрива с прозорците за глобално броене на Flink.
Добре, така че една линейна функция се характеризира с вектор от тегла и евентуално отклонение. Тъй като програмата Flink ще работи непрекъснато известно време, докато не сме доволни от теглото си, трябва да ги съхраняваме някъде. Тук влиза в действие концепцията за състоянието на Flink.
Добре, да започваме. Ето кода на основната функция, който ще разопаковаме след секунда:
private static final int problemDim = 1; public static void main(String[] args) throws Exception { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); SingleOutputStreamOperator<RegressionParams> dataStream = env .socketTextStream("localhost", 9999) .map(new Splitter(problemDim)) .countWindowAll(8) .process(new RegressionWindowFunction(problemDim)); dataStream.print(); env.execute("Training"); }
Добре, така че слушаме текстовия поток на сокета, след което обработваме получения низ с обект Splitter. След това формираме партида с размер 8 и накрая извикваме функция за процес върху това.
И така, какво е сплитерът? Неговата задача е да вземе необработения текстов низ и да го преобразува в нещо по-удобно за работа. Ще го преобразуваме в следния обект:
Така че в общи линии, пример за обучение. Така че с това обектът Splitter изглежда така:
import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.ml.common.linalg.DenseVector; import java.util.Arrays; public class Splitter implements MapFunction<String, RegressionTrainingSample> { private final int dim; public Splitter(int dim) { this.dim = dim; } @Override public RegressionTrainingSample map(String s) { String[] sr = s.split(" "); double[] rawPointValues = new double[dim]; int ct = 0; for (String token : Arrays.copyOfRange(sr, 0, dim)) { rawPointValues[ct] = Double.parseDouble(token); } return new RegressionTrainingSample(new DenseVector(rawPointValues), Double.parseDouble(sr[sr.length - 1])); } }
След като се създаде партида от 8, се нуждаем от имплементация за функцията за обработка, която ще вземе партида от тренировъчни проби, изчислени градиенти на параметри и след това ще актуализира състоянието. На първо място, ето един клас, който представлява нашата линейна регресия:
import org.apache.flink.ml.common.linalg.DenseVector; public class RegressionParams { private DenseVector weights; private Double bias; public RegressionParams(DenseVector weights, Double bias) { this.weights = weights; this.bias = bias; } public DenseVector getWeights() { return weights; } public Double getBias() { return bias; } @Override public String toString() { StringBuilder sb = new StringBuilder("y="); int ct = 0; for (Double w : weights.getData()) { sb.append(w + "x_" + ct); ct++; } sb.append("+" + bias); return sb.toString(); } }
Ето как изглежда сигнатурата на нашите процесни функции
private static class RegressionWindowFunction extends ProcessAllWindowFunction<RegressionTrainingSample, RegressionParams, GlobalWindow>
Така че функцията приема тренировъчна проба като вход и изхвърля нашите регресионни параметри. За да съхраним параметрите в състояние, се нуждаем от следната декларация
private final static ValueStateDescriptor<RegressionParams> params = new ValueStateDescriptor<>("params", RegressionParams.class);
И накрая, изпълнението на метода на процеса ще изглежда така
@Override public void process(ProcessAllWindowFunction<RegressionTrainingSample, RegressionParams, GlobalWindow>.Context context, Iterable<RegressionTrainingSample> elements, Collector<RegressionParams> out) throws Exception { RegressionParams regressionParams = context.globalState().getState(params).value(); if (Objects.isNull(regressionParams)) { RegressionParams freshParams = new RegressionParams(DenseVector.zeros(dim), 0.0); context.globalState().getState(params).update(freshParams); regressionParams = context.globalState().getState(params).value(); } System.out.println("Starting batch processing"); System.out.println("Regression params are " + regressionParams); RegressionParams accumulatedGradient = new RegressionParams(DenseVector.zeros(dim), 0.0); for (RegressionTrainingSample newDataPoint : elements) { System.out.println("Evaluating " + newDataPoint); Double response = calculateResponse(newDataPoint, regressionParams); System.out.println("Error is " + 0.5 * Math.pow(response - newDataPoint.getTarget(), 2)); System.out.println("Response is " + response); RegressionParams thisPointGradient = calculateGradient(newDataPoint, response); accumulatedGradient = addRegressionParamsGradients(accumulatedGradient, thisPointGradient); } System.out.println("Gradient is " + accumulatedGradient); RegressionParams newParams = adjustRegressionParamsWithGradient(regressionParams, accumulatedGradient, 0.0001); context.globalState().getState(params).update(newParams); System.out.println("After update regression params are " + context.globalState().getState(params).value()); out.collect(newParams); }
Тук няма нищо прекалено сложно, просто итерация през целия прозорец и след това изчисляване на линейния отговор и градиента, като същевременно натрупва градиентните вектори в една променлива.
Последните две функции, които липсват в това, са изчислението на линейния отговор и изчислението на производната:
private Double calculateResponse(RegressionTrainingSample newDataPoint, RegressionParams regressionParams) { return newDataPoint.getFeatures().dot(regressionParams.getWeights()) + regressionParams.getBias(); } private RegressionParams calculateGradient(RegressionTrainingSample newDataPoint, Double response) { DenseVector features = newDataPoint.getFeatures(); int maxIndex = features.size(); double[] rawGradientPoints = new double[maxIndex]; double multiplicationBasis = response - newDataPoint.getTarget(); for (int i = 0; i < maxIndex; i++) { rawGradientPoints[i] = multiplicationBasis * features.get(i); } Double biasGradient = multiplicationBasis; return new RegressionParams(new DenseVector(rawGradientPoints), biasGradient); }
Можем да тестваме всичко това с проста програма за драйвер на python, която ще публикува някои шумни стойности на кривата y=x, като тази:
import socket import time # Create a socket object s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # Get local machine name host = "localhost" # Reserve a port for your service. port = 9999 # Bind to the port s.bind((host, port)) # Now wait for client connection. s.listen(1) print(f'Listening on {host}:{port}') from random import randint, gauss while True: # Establish connection with client. c, addr = s.accept() print(f'Got connection from {addr}') while True: x = randint(0,20) message = f'{x} {x+gauss(mu=0,sigma=1)}\n' print(message) c.send(message.encode('ascii')) time.sleep(0.01)
И изходът в конзолата на изпълнителя на Flink изглежда така:
Response is 1.0749612223522655 Evaluating 16.0 15.547041844810462 Error is 0.14912812721755458 Response is 16.093170267834854 Evaluating 19.0 17.992943134585456 Error is 0.6092633209379449 Response is 19.09681207693137 Evaluating 14.0 13.276572926577254 Error is 0.3314359617404347 Response is 14.090742395103844 Evaluating 15.0 13.441248518211866 Error is 1.362418142374651 Response is 15.09195633146935 Evaluating 9.0 7.045387703950424 Error is 2.0793416746306472 Response is 9.084672713276314 Evaluating 20.0 19.67354890099053 Error is 0.09009040943596827 Response is 20.098026013296877 Evaluating 1.0 3.2358480591650274 Error is 2.334715960755332 Response is 1.0749612223522655 Gradient is y=90.47661472652189x_0+4.3415897340513645 After update regression params are y=0.9921662748928537x_0+0.07331312701335459