Комплексное руководство по созданию конвейера прогнозирования оттока с использованием только Apache Spark.

Эта статья представляет собой руководство по созданию классификатора прогнозирования оттока с использованием стека машинного обучения из Spark.

Мы будем использовать данные от вымышленной компании под названием Sparkify, занимающейся потоковой передачей музыки. Набор данных содержит все виды событий, созданных пользователями, которые взаимодействовали с платформой.

Мы более подробно рассмотрим данные во время урока. Но я хочу поблагодарить Udacity за обнародование этих данных. Без них я не смог бы сделать этот урок.

Здесь вы можете скачать мини-версию набора данных и здесь весь набор данных.

Что такое отток клиентов?

Отток клиентов — это когда кто-то решает прекратить использование ваших продуктов или услуг.

Чем полезна такая модель?

Построив такую ​​модель, мы сможем понять, почему пользователи уходят из компании. Таким образом, это один из способов улучшить качество обслуживания клиентов и их удержание на основе данных об активности пользователей.

Оглавление

  1. Определение сеанса Spark
  2. Загрузите набор данных
  3. Очистить набор данных
  4. Определите метку оттока клиентов
  5. Исследовательский анализ данных
  6. Разработка функций
  7. Моделирование
  8. Полученные результаты
  9. Заключение

Примечание. Мы выполнили весь код в одном блокноте.

1️⃣ Определите сеанс Spark

spark = SparkSession.\
    builder.\
    appName("Sparkify Churn Prediction").\
    getOrCreate()

Мы создали искровой сеанс под названием Прогнозирование оттока Sparkify. С помощью этого сеанса мы можем загрузить и запустить все вычисления.

2️⃣ Загрузите набор данных

EVENT_DATA_LINK = "mini_sparkify_event_data.json"
df = spark.read.json(EVENT_DATA_LINK)
df.persist()

Мы загрузили файл mini_sparkify_event_data.json в тот же каталог, что и блокнот, и загрузили его в память с помощью сеанса Spark.

df.printSchema()

|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

Наиболее важные столбцы, используемые в этой статье, следующие:

  • исполнитель: исполнитель воспроизводимой в данный момент песни.
  • уровень. Категориальная переменная, которая может быть платной или бесплатной.
  • страница: местоположение пользователя в приложении (например, страница входа,
  • registration: метка времени регистрации в формате UTC.
  • sessionId: идентификатор текущего сеанса пользователя.
  • песня: текущая воспроизводимая песня.
  • status: HTTP-статус события. (например, 200, 307, 404).
  • ts: временная метка события в формате UTC.
  • userId: идентификатор пользователя.

3️⃣ Очистить набор данных

В настоящем Блокноте, доступ к которому можно получить здесь, мы сделали более широкий обзор того, что нужно чистить, но для компактности мы покажем только самое важное.

Проверить наличие NaN/пустых значений

df.select([
    F.count(F.when(F.isnull(c), c)).alias(c) for c in df.columns
]).show()

+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+
|artist|auth|firstName|gender|itemInSession|lastName|length|level|location|method|page|registration|sessionId| song|status| ts|userAgent|userId|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+
| 58392|   0|     8346|  8346|            0|    8346| 58392|    0|    8346|     0|   0|        8346|        0|58392|     0|  0|     8346|     0|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+

Большинство пустых значений в наборе данных помечаются как Нет.

df.select([
     F.count(F.when(F.col(c) == "", c)).alias(c) for c in df.columns
]).show()

+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+
|artist|auth|firstName|gender|itemInSession|lastName|length|level|location|method|page|registration|sessionId|song|status| ts|userAgent|userId|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+
|     0|   0|        0|     0|            0|       0|     0|    0|       0|     0|   0|           0|        0|   0|     0|  0|        0|  8346|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+

Существует исключение в столбце userId, где пустые значения обозначаются пустой строкой.

Незарегистрированные пользователи

Пользователи с пустой строкой являются незарегистрированными пользователями; поэтому платформа не могла присвоить им какой-либо идентификатор.

Проверить распределение страниц незарегистрированных пользователей

df.filter(F.col("userId") == "").select("page").groupby("page").count().show()

+-------------------+-----+
|               page|count|
+-------------------+-----+
|               Home| 4375|
|              About|  429|
|              Login| 3241|
|               Help|  272|
|              Error|    6|
|           Register|   18|
|Submit Registration|    5|
+-------------------+-----+

Использование незарегистрированных пользовательских данных не может помочь нам в прогнозировании оттока клиентов. Поэтому мы удалим эти строки из фрейма данных Spark.

Удалить незарегистрированных пользователей

cleaned_df = df.filter(F.col("userId") != "")

Пустые строки исполнителей

df.filter(
    F.isnull(F.col("artist"))
).select(
    ["artist", "song", "userId", "page"]
).groupby("page").count().show()

+--------------------+-----+
|                page|count|
+--------------------+-----+
|              Cancel|   52|
|    Submit Downgrade|   63|
|         Thumbs Down| 2546|
|                Home|14457|
|           Downgrade| 2055|
|         Roll Advert| 3933|
|              Logout| 3226|
|       Save Settings|  310|
|Cancellation Conf...|   52|
|               About|  924|
|            Settings| 1514|
|               Login| 3241|
|     Add to Playlist| 6526|
|          Add Friend| 4277|
|           Thumbs Up|12551|
|                Help| 1726|
|             Upgrade|  499|
|               Error|  258|
|      Submit Upgrade|  159|
|            Register|   18|
+--------------------+-----+

Когда исполнитель пуст, пользователи проводят больше времени на других страницах, чем слушая музыку. В данном случае такая информация ценна для понимания поведения зарегистрированных пользователей.

Заполнить пустые значения

cleaned_df = cleaned_df.fillna({
    "length": 0,
    "artist": "unknown",
    "song": "unknown"
})

В качестве последнего шага мы хотим заполнить строки значениями None. На данный момент пустые строки появляются только для полей, связанных с песней, когда пользователь ничего не слушает. Информация из этих строк ценна, и мы не рекомендуем их удалять.

Поэтому мы устанавливаем для длины песни значение 0, а для исполнителя и песни значение неизвестно.

4️⃣ Определите метку оттока клиентов

cleaned_df.select(
    F.count(
        F.when(F.col("page") == "Cancellation Confirmation", "page")
    ).alias("Cancellation Confirmation")
).show()

+-------------------------+
|Cancellation Confirmation|
+-------------------------+
|                       52|
+-------------------------+

Мы будем считать, что пользователь "удален", когда подписка отменена.

Примечание. Если бы мы хотели перейти на следующий уровень, мы могли бы также использовать событие Downgrade, чтобы сигнализировать об уходе клиентов.

labeled_df = cleaned_df.withColumn(
    "churnEvent", 
    F.when(
       F.col("page") == "Cancellation Confirmation",
       1
    ).otherwise(0)
)

В качестве промежуточного шага мы создадим столбец churnEvent, который представляет собой флаг, проверяющий наличие события «Подтверждение отмены». Мы разработали столбец churnEvent на уровне события.

user.labeled_df = labeled_df.withColumn(
    "churn", 
    F.sum("churnEvent").over(Window.partitionBy("userId"))
)
labeled_df = labeled_df.withColumn(
    "churn", 
    F.when(F.col("churn") >= 1, 1).otherwise(0)
)

Мы агрегировали флаг churnEvent на уровне пользователя и обрезали его до 0 или 1. Таким образом, пользователь имеет либо churn = 1, либо no-churn = 0флажков для всех событий.

Помните, в конце концов, мы хотим создать классификатор, который прогнозирует отток пользователей, а не отток событий. Вот почему пользователь имеет либо отток = 1, либо отсутствие оттока = 0.

5️⃣ Исследовательский анализ данных

Примечание. Чтобы статья была короткой, мы покажем только те графики, которые помогли нам разработать ценные функции. Вы можете увидеть все графики и как они построены здесь.

Распределение оттока пользователей

Распределение сильно перекошено в сторону пользователей без оттока. Мы будем использовать оценку F1, чтобы найти лучшую модель на этапе перекрестной проверки с учетом этого фактора.

Возможные страницы

eda_df.select("page").distinct().show()

+--------------------+
|                page|
+--------------------+
|              Cancel|
|    Submit Downgrade|
|         Thumbs Down|
|                Home|
|           Downgrade|
|         Roll Advert|
|              Logout|
|       Save Settings|
|Cancellation Conf...|
|               About|
|            Settings|
|     Add to Playlist|
|          Add Friend|
|            NextSong|
|           Thumbs Up|
|                Help|
|             Upgrade|
|               Error|
|      Submit Upgrade|
+--------------------+
eda_df.filter(
    F.col("artist") != "unknown"
).select("page").distinct().show()
+--------+
|    page|
+--------+
|NextSong|
+--------+

Единственная страница, на которой воспроизводится музыка, называется NextSong.

Распределение среднего количества посещенных страниц каждым пользователем

Пользователи, оставшиеся на платформе, в среднем посещали больше страниц.

Распределение среднего количества песен, прослушанных каждым пользователем

В среднем пользователи, оставшиеся на платформе, слушают больше песен.

Распределение среднего количества исполнителей, прослушанных каждым пользователем

В среднем пользователи, которые остаются на платформе, слушают более широкий спектр исполнителей.

Дельта времени с момента регистрации

Дельта времени = количество секунд, прошедших с момента регистрации пользователя на платформе.

Похоже, что дельта времени с момента регистрации для оттока пользователей смещена вправо. В то же время распределение без оттока имеет нормальное распределение. Таким образом, среднее значение временной дельты является хорошим предиктором.

6️⃣ Разработка функций

Классификация оттока будет выполняться на уровне пользователя. Поэтому нам нужно агрегировать данные для каждого пользователя в одной строке.

Зачем агрегировать данные на уровне пользователя?

Окончательные прогнозы следует использовать для прогнозирования того, склонен ли конкретный пользователь покинуть компанию, а не для того, чтобы пользователь отписался. Таким образом, единичное событие для нас бесполезно, но его совокупность чрезвычайно ценна.

Определите некоторые вспомогательные функции:

def count_with_condition(condition):
    """Utility function to count only specific rows based on the 'condition'."""
    return F.count(F.when(condition, True))

def count_distinct_with_condition(condition, values):
    """Utility function to count only distinct & specific rows based on the 'condition'."""
    return F.count_distinct(F.when(condition, values))

Создание фрейма данных Spark с функциональной инженерией

В качестве резюме мы будем использовать функции, представленные на этапе EDA. В Блокноте мы исследовали больше потенциальных функций, но эти четыре показали наибольшую прогностическую силу между пользователями оттока и не оттока:

  • Общее количество посещенных страниц.
  • Общее количество воспроизведенных песен.
  • Общее количество всего исполнителей.
  • Временная метка с момента регистрации (в секундах).
aggregated_df = labeled_df.groupby("userId").agg(
    F.count("page").alias("pages"),
    count_with_condition(
        F.col("page") == "NextSong"
    ).alias("plays"),
    count_distinct_with_condition(
        F.col("artist") != "unknown", F.col("artist")
    ).alias("artists"),
    F.max(F.col("ts") - F.col("registration")).alias("delta"),
    F.max("churn").alias("churn")
)
aggregated_df.show(n=5)

+------+-----+-----+-------+-----------+-----+
|userId|pages|plays|artists|      delta|churn|
+------+-----+-----+-------+-----------+-----+
|100010|  381|  275|    252| 4807612000|    0|
|100014|  310|  257|    233| 7351206000|    1|
|100021|  319|  230|    207| 5593438000|    1|
|   101| 2149| 1797|   1241| 4662657000|    1|
|    11|  848|  647|    534|10754921000|    0|
+------+-----+-----+-------+-----------+-----+

Сопоставьте DataFrame с вектором искры

assembler = VectorAssembler(
    inputCols=[
        "pages", "plays", "artists", "delta"
], outputCol="unscaled_features")
engineered_df = assembler.transform(aggregated_df)
engineered_df = engineered_df.select(
    F.col("unscaled_features"),
    F.col("churn").alias("label")
)

Spark Models ожидает на входе Spark Vector. Кроме того, по умолчанию они ожидают, что входные объекты будут находиться в столбце с именем функции, а целевые — в столбце с именем >метка.

7️⃣ Моделирование

Мы будем обучать и тестировать три модели:

  1. Логистическая регрессия
  2. Наивный Байес
  3. Дерево повышения градиента

Мы будем использовать перекрестную проверку с тремя сгибами, чтобы найти лучшие гиперпараметры.

Мы будем использовать 80 % данных в разбивке поезда и 20 % в тестовой разбивке.

Мы нормализовали функции с помощью StandardScaler.

Поскольку метки сильно несбалансированы, мы будем использовать оценку F1 для оценки моделей. Метрика Оценка F1 работает под капотом точность и отзыв, которые учитывают проблему несбалансированного распределения.

Определите некоторые вспомогательные функции:

Пропустите их, если вас интересуют только результаты.

Синтаксис Spark ML очень похож на тот, что используется в Sklearn.

def run(pipeline, paramGrid, train_df, test_df):
    """
    Main function used to train & test a given model.
    The training step uses cross-validation to find the best hyper-parameters for the model.
    :param pipeline: Model pipeline.
    :param paramGrid: Parameter grid used for cross-validation.
    :param train_df: Training dataframe.
    :param test_df: Testing dataframe.
    :return: the best model from cross-validation.
    """
    fitted_model = fit_model(paramGrid, pipeline, train_df)
    evaluate_model(fitted_model, test_df)
    return fitted_model

def fit_model(paramGrid, pipeline, train_df):
    """
    Function that trains the model using cross-validation.
    Also, it prints the best validation results and hyper-parameters.
    :param paramGrid: Parameter grid used for cross-validation.
    :param pipeline: Model pipeline.
    :param train_df: Training dataframe.
    :return: the best model from cross-validation.
    """
    crossval = CrossValidator(
        estimator=pipeline,
        estimatorParamMaps=paramGrid,
        evaluator = MulticlassClassificationEvaluator(
              metricName="f1", 
              beta=1.0
        ),
        parallelism=3,
        numFolds=3
    )
    fitted_model = crossval.fit(train_df)
    print_best_validation_score(fitted_model)
    print_best_parameters(fitted_model)
    return fitted_model

def create_pipeline(model):
    """
    Create a pipeline based on a model.
    :param model: The end model that will be used for training.
    :return: the built pipeline.
    """
    scaler = StandardScaler(
        inputCol="unscaled_features", 
        outputCol="features"
    )
    pipeline = Pipeline(stages=[scaler, model])
    return pipeline

def print_best_validation_score(cross_validation_model):
    """Prints the best validation score based on the results from the cross-validation model."""
    print()
    print("-" * 60)
    print(f"F1 score, on the validation split, for the best model: {np.max(cross_validation_model.avgMetrics) * 100:.2f}%")
    print("-" * 60)

def print_best_parameters(cross_validation_model):
    """Prints the best hyper-parameters based on the results from the cross-validation model."""
    parameters = cross_validation_model \
       .getEstimatorParamMaps() [np.argmax(cross_validation_model.avgMetrics)]
    print()
    print("-" * 60)
    print("Best model hyper-parameters:")
    for param, value in parameters.items():
        print(f"{param}: {value}")
    print("-" * 60)

def evaluate_model(model, test_df):
    """Evaluate the model on the test set using F1 score and print the results."""
    predictions = model.transform(test_df)
    evaluator =  MulticlassClassificationEvaluator(
            metricName="f1", 
            beta=1.0
          )
    metric = evaluator.evaluate(predictions)
    print()
    print("-" * 60)
    print(f"F1 score, on the test set is: {metric*100:.2f}%")
    print("-" * 60)
    return metric

Разделить данные

train_df, test_df = engineered_df.randomSplit([0.8, 0.2], seed=42)

Логистическая регрессия

lr = LogisticRegression()
pipeline = create_pipeline(lr)
paramGrid = ParamGridBuilder() \
    .addGrid(lr.maxIter, [10, 25, 50])  \
    .addGrid(lr.regParam, [0.05, 0.1, 0.2]) \
    .addGrid(lr.elasticNetParam, [0.05, 0.1, 0.2]) \
    .build()run(
   pipeline,
   paramGrid,
   train_df.alias("train_df_lr"),
   test_df.alias("test_df_lr")
);

Наивный Байес

nb = NaiveBayes()
pipeline = create_pipeline(nb)
paramGrid = ParamGridBuilder() \
    .addGrid(nb.smoothing, [0.5, 1, 2])  \
    .build()run(
   pipeline,
   paramGrid,
   train_df.alias("train_df_nb"),
   test_df.alias("test_df_nb")
);

Повышение градиента

gbt = GBTClassifier()
pipeline = create_pipeline(gbt)
paramGrid = ParamGridBuilder() \
    .addGrid(gbt.maxIter, [10, 20, 30]) \
    .addGrid(gbt.stepSize, [0.05, 0.1]) \
    .build()run(
   pipeline,
   paramGrid, 
   train_df.alias("train_df_gbt"),
   test_df.alias("test_df_gbt")
);

Все три фрагмента кода для обучения и тестирования имеют одинаковую структуру.

Мы определили конвейер как шаг масштабирования и фактической модели. Кроме того, мы определили сетку перекрестной проверки гиперпараметров. В конце мы передаем все эти объекты в функцию run. Который обучает и оценивает модель на заданном разделении данных и конфигурации.

8️⃣ Результаты

|        Model        | Validation |   Test   |
|:-------------------:|:----------:|:--------:|
| Logistic Regression |   0.6958   |  0.5952  |
|     Naive Bayes     |   0.6672   |  0.5952  |
|  Gradient Boosting  |  *0.7333*  | *0.8473* |

Модель GBT имеет лучший показатель F1, чем логистическая регрессия и наивный залив, потому что GBT — это более сложная модель, которая может лучше понять нелинейные отношения.

Обычно методы повышения градиента работают лучше, чем другие методы, такие как LR или NB, без учета таких факторов, как корреляция между двумя переменными или нелинейные отношения. Базовые модели деревьев не чувствительны к этим проблемам, потому что они создают листья независимо и поддерживают многомерные отношения.

9️⃣ Заключение

Большой! Нам удалось обучить достойный классификатор, используя только Spark.

Мы загрузили набор данных, очистили его, проанализировали и в итоге создали набор полезных функции для прогнозирования оттока клиентов.

Используя инженерные данные, мы обучили и протестировали три модели прогнозирования оттока. Используя перекрестную проверку, мы находим наилучшие гиперпараметры для следующих моделей:

  • Логистическая регрессия
  • Наивный Байес
  • Повышение градиента

Мы сравнили результаты и увидели, что модель GBT имеет самый высокий балл F1 при проверке и тестировании.

Для дальнейшего улучшения модели мы можем сделать следующее:

  • добавить больше функций
  • устранить дисбаланс этикеток
  • используйте событие Downgrade, чтобы создать больше ярлыков изменения
  • дополнительная настройка гиперпараметров
  • используйте XGBoost или LightGBM

Какие еще у вас есть предложения по дальнейшему улучшению модели?

Примечание. Доступ к репозиторию GitHub можно получить здесь.

🎉 Спасибо, что прочитали мою статью!

📢 Если вам понравилась эта статья и вы хотите поделиться своим опытом обучения искусственному интеллекту, машинному обучению и MLOps, вы также можете подписаться на меня в LinkedIn.