Журнал разработчика
Я написал в своей первой статье о процедурах линейной регрессии, которую добавил для Neo4j. Сегодня я хочу объяснить некоторые внутренние компоненты и почему я решил построить их именно так.
Пользовательские процедуры должны запоминать информацию между вызовами, чтобы строить и поддерживать модель машинного обучения. Это выходит за рамки типичной функциональности процедур в Neo4j. В следующей статье я исследую ключевые детали реализации набора функций и процедур, которые создают, обучают, тестируют, используют и хранят линейные модели данных графа. Я надеюсь, что это поможет вам понять мою работу или реализовать подобные собственные процедуры, определяемые пользователем.
Чтобы увидеть эти определяемые пользователем функции и процедуры в действии, прочитайте мой предыдущий пост, в котором я использую линейную регрессию для создания предсказателя цен для краткосрочной аренды в Остине, штат Техас.
Цель
Я хочу выполнить линейную регрессию для данных, которые я сохранил в Neo4j. Существует множество библиотек с инструментами для создания линейных моделей (Pandas и Numpy для Python, Commons Math для Java и т. Д.), Но чтобы использовать их напрямую, я должен экспортировать данные из моего графика в другое программное обеспечение. Чтобы смоделировать мои данные из Neo4j, мне нужно расширить функциональность языка запросов к графам Cypher.
Предпочтительным средством расширения Cypher являются определяемые пользователем функции и процедуры. Они написаны на Java, созданы, например, с помощью Maven, развернуты в базе данных Neo4j как файлы JAR и вызываются из Cypher. Мне нужно написать процедуру, которая выполняет линейную регрессию для данных графа.
Что делает эту проблему интересной?
Во-первых, давайте взглянем на типичную процедуру в Neo4j: apoc.meta.graph
. Запуск CALL apoc.meta.graph()
вернет визуальное представление схемы графика (вашей модели данных). Например, вот результат вызова этой процедуры на графике краткосрочной аренды из моего предыдущего поста:
Эта процедура обращается к информации, хранящейся в графе, для определения базовой структуры. Другие процедуры изменяют граф, например apoc.refactor.mergeNodes
, который объединяет несколько узлов в один. Однако процедуры обычно не запоминают информацию между вызовами, они просто создают поток выходных данных или изменяют данные.
Я могу создать процедуру, которая, как apoc.meta.graph
, обращается к информации на графике, чтобы создать линейную модель без сохранения какой-либо внешней информации. Я могу передать все данные сразу, выполнить вычисления методом наименьших квадратов в процедуре и вернуть параметры линейной модели. Но если я сделаю еще один вызов процедуры, она уже забудет только что созданную модель.
Но что, если я решу добавить в модель больше данных? Что, если я хочу использовать большой объем обучающих данных, который требует слишком много памяти для ввода в качестве аргумента одной процедуры?
Идея №1: сериализовать!
Моя первая попытка решения - это своего рода обходной путь, потому что вместо запоминания информации процедура сохраняет модель на графике, чтобы к ней можно было получить доступ и обновить позже. Идея состоит в том, чтобы сериализовать объект Java модели и сохранить байтовый массив в графе между вызовами процедур. Вот визуальное представление процесса сериализации и десериализации:
Примечание. Я использую
SimpleRegression
из библиотеки Apache Commons Math.SimpleRegression
выполняет обновление вычислений по входящим точкам данных, так что отдельные точки данных не сохраняются. Вместо этого он сохраняет определенную информацию, такую как среднее значение y, общее количество точек данных и т. Д., И обновляет эти значения с каждой новой точкой данных. Таким образом, с каждой дополнительной точкой данных модель выполняет вычисления и улучшается без увеличения использования памяти. Результат: когда мы сериализуем объектSimpleRegression
, соответствующий массив байтов оказывается не очень большим (по крайней мере, он не масштабируется с размером набора данных!).
Сначала я написал следующие вспомогательные функции, чтобы на протяжении всего проекта я мог преобразовать объект SimpleRegression
в byte[]
и наоборот. Это потребовало импорта из java.io.*
.
//Serializes the object into a byte array for storage static byte[] convertToBytes(Object object) throws IOException { try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutput out = new ObjectOutputStream(bos)) { out.writeObject(object); return bos.toByteArray(); } } //de serializes the byte array and returns the stored object static Object convertFromBytes(byte[] bytes) throws IOException, ClassNotFoundException { try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInput in = new ObjectInputStream(bis)) { return in.readObject(); } }
Затем, когда я в следующий раз захотел отредактировать модель, я извлек байтовый массив из графа и десериализовал его.
try { ResourceIterator<Entity> n = db.execute("MATCH (n:LinReg {ID:$ID}) RETURN n", parameters).columnAs("n"); modelNode = n.next(); byte[] model = (byte[])modelNode.getProperty("serializedModel"); R = (SimpleRegression) convertFromBytes(model); } catch (Exception e) { throw new RuntimeException("no existing model for specified independent and dependent variables and model ID"); }
И после редактирования модели сохраните новое представление byte [] обратно в том же узле.
try { byte[] byteModel = convertToBytes(R); modelNode.setProperty("serializedModel", byteModel); } catch (IOException e) { throw new RuntimeException("something went wrong, model can't be linearized so new model not stored"); }
Если вам интересно, посмотрите полный код. Обратите внимание, что эти предварительные реализации сильно отличаются (и более запутаны!) От финальной версии моих процедур линейной регрессии, представленных ниже.
Проблемы
Что, если я хочу разделить create
модель, add
данные и remove
процедуры данных? Базы данных графов постоянно получают обновления, поэтому мне нужно создать модель, которая была бы такой же гибкой, как граф. Сериализация и десериализация требуют значительных временных затрат. Вы можете передать несколько точек данных одновременно, чтобы ограничить количество вызовов процедур (и количество раз, когда модель сохраняется и извлекается), но мне нужен лучший способ сохранить промежуточную модель между вызовами процедур, чтобы она могла обновляться столько раз, сколько мне нужно.
Идея №2: Статическая карта
Статические переменные в классах Java, реализующих процедуру, существуют, пока база данных продолжает работать. Следовательно, мы можем хранить объекты модели по имени на статической карте. Модели хранятся в процедуре, поэтому каждый шаг линейной регрессии - создание, добавление данных, удаление данных и т. Д. - изолирован в отдельную процедуру, но изменяет ту же SimpleRegression
модель. Что-то вроде add
процедуры может быть вызвано один раз для каждой точки данных без серьезных потерь производительности. Это создает желаемый упрощенный дизайн. С каждым изолированным шагом процедуры понятны, и пользователь имеет больший контроль над каждым шагом построения линейной модели.
Модели хранятся в статической ConcurrentHashMap в одном из классов Java, используемых для реализации процедур: LRModel.java
. Каждый раз, когда вызывается процедура, которой требуется доступ к модели, она извлекается из models
по имени с помощью метода from
. Использование этой конкретной реализации карты позволяет осуществлять одновременный доступ из нескольких потоков.
private static ConcurrentHashMap<String, LRModel> models = new ConcurrentHashMap<>(); static LRModel from(String name) { LRModel model = models.get(name); if (model != null) return model; throw new IllegalArgumentException("No valid LR-Model " + name); }
Теперь нам нужно только сериализовать и сохранить модель перед выключением базы данных и загрузить ее обратно в статическую память процедуры при перезапуске базы данных. Ознакомьтесь с полной реализацией на Github.
Ограничения
- Если база данных неожиданно завершает работу, статические переменные будут очищены, а модель потеряна. Было бы неплохо иметь вариант резервного копирования, в котором модель сериализуется и сохраняется через регулярные промежутки времени. После сбоя базы данных перезапустите базу данных и перестройте модель.
- Сериализация - не лучший способ сохранить модель, потому что, если что-либо изменится в следующей версии Commons Math, обновленная версия может не распознать сериализованный объект
SimpleRegression
из предыдущего. - Статистика для тестовых данных не сохраняется между выключением / перезапуском базы данных. В идеале я бы сам реализовал обновление простой регрессии вместо использования Commons Math, а затем сохранил бы всю необходимую информацию о данных обучения и тестирования вместо сохранения сериализованного
SimpleRegression
.
Улучшите мою работу!
Если у вас есть идеи более стабильные, чем хранение в статических картах и сериализация, дайте мне знать или реализуйте их сами! Я призываю вас улучшить мою работу. Я хотел бы обсудить возможности, просто напишите мне в LinkedIn или @ML_auren. Ваше здоровье!