Урок от край до край за това как да изградите конвейер за прогнозиране на отлив, като използвате само Apache Spark.

Тази статия е урок за това как да изградите класификатор за прогнозиране на отлив с помощта на ML стека от Spark.

Ще използваме данни от фиктивна компания, наречена Sparkify, компания за стрийминг на музика. Наборът от данни съдържа всички видове събития, създадени от потребителите, които са взаимодействали с платформата.

Ще разгледаме по-подробно данните по време на урока. Но искам да благодаря на Udacity, че направи тези данни публични. Без тях не бих могъл да направя този урок.

Тук можете да изтеглитемини версията на набора от данни и тук целия набор от данни.

Какво е отлив на клиенти?

Оттеглянето на клиенти е, когато някой реши да спре да използва вашите продукти или услуги.

Защо е полезен такъв модел?

Изграждайки такъв модел, можем да разберем защо потребителите напускат компанията. Следователно това е един от начините за подобряване на клиентското изживяване и задържане въз основа на данните за потребителската активност.

Съдържание

  1. Определете Spark сесия
  2. Заредете набора от данни
  3. Почистете набора от данни
  4. Дефинирайте етикета за оттегляне на клиенти
  5. Проучвателен анализ на данни
  6. Инженеринг на характеристиките
  7. Моделиране
  8. Резултати
  9. Заключение

Забележка: Изпълнихме целия код в един бележник.

1️⃣ Определете Spark сесия

spark = SparkSession.\
    builder.\
    appName("Sparkify Churn Prediction").\
    getOrCreate()

Създадохме искрова сесия, наречена Sparkify Churn Prediction. С тази сесия можем да заредим и изпълним всички изчисления.

2️⃣ Заредете набора от данни

EVENT_DATA_LINK = "mini_sparkify_event_data.json"
df = spark.read.json(EVENT_DATA_LINK)
df.persist()

Изтеглихме файла mini_sparkify_event_data.jsonв същата директория с бележника и го заредихме в паметта с помощта на сесията на Spark.

df.printSchema()

|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

Най-критичните колони, използвани в тази статия, са следните:

  • Изпълнител: Изпълнителят на песента, която се изпълнява в момента.
  • ниво: Категорична променлива, която може да бъде платенаили безплатна.
  • страница: местоположението на потребителя в рамките на приложението (напр. страницата за влизане,
  • регистрация: Времето UTC на регистрацията.
  • sessionId: ID на текущата сесия на потребителя.
  • песен: Текущата песен, която се възпроизвежда.
  • състояние: HTTP състоянието на събитието. (напр. 200, 307, 404).
  • ts: Времето UTC на събитието.
  • userId: ID на потребителя.

3️⃣ Почистете набора от данни

В действителния Notebook, който може да бъде достъпен тук, направихме по-широка проверка на това какво да почистваме, но за да запазим нещата компактни, ще покажем само най-важното.

Проверете за NaN/празни стойности

df.select([
    F.count(F.when(F.isnull(c), c)).alias(c) for c in df.columns
]).show()

+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+
|artist|auth|firstName|gender|itemInSession|lastName|length|level|location|method|page|registration|sessionId| song|status| ts|userAgent|userId|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+
| 58392|   0|     8346|  8346|            0|    8346| 58392|    0|    8346|     0|   0|        8346|        0|58392|     0|  0|     8346|     0|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+

Повечето от празните стойности в набора от данни се сигнализират с Няма.

df.select([
     F.count(F.when(F.col(c) == "", c)).alias(c) for c in df.columns
]).show()

+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+
|artist|auth|firstName|gender|itemInSession|lastName|length|level|location|method|page|registration|sessionId|song|status| ts|userAgent|userId|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+
|     0|   0|        0|     0|            0|       0|     0|    0|       0|     0|   0|           0|        0|   0|     0|  0|        0|  8346|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+

Има изключение в колоната userId, където празните стойности се сигнализират с празен низ.

Нерегистрирани потребители

Потребителите с празен низ са нерегистрирани потребители; следователно платформата не може да им присвои никакъв идентификатор.

Проверете разпространението на страниците на нерегистрирани потребители

df.filter(F.col("userId") == "").select("page").groupby("page").count().show()

+-------------------+-----+
|               page|count|
+-------------------+-----+
|               Home| 4375|
|              About|  429|
|              Login| 3241|
|               Help|  272|
|              Error|    6|
|           Register|   18|
|Submit Registration|    5|
+-------------------+-----+

Използването на нерегистрирани потребителски данни не може да ни помогне да предвидим отлив на клиенти. Следователно ще изпуснем тези редове от Spark DataFrame.

Изтриване на нерегистрирани потребители

cleaned_df = df.filter(F.col("userId") != "")

Празни редове за изпълнители

df.filter(
    F.isnull(F.col("artist"))
).select(
    ["artist", "song", "userId", "page"]
).groupby("page").count().show()

+--------------------+-----+
|                page|count|
+--------------------+-----+
|              Cancel|   52|
|    Submit Downgrade|   63|
|         Thumbs Down| 2546|
|                Home|14457|
|           Downgrade| 2055|
|         Roll Advert| 3933|
|              Logout| 3226|
|       Save Settings|  310|
|Cancellation Conf...|   52|
|               About|  924|
|            Settings| 1514|
|               Login| 3241|
|     Add to Playlist| 6526|
|          Add Friend| 4277|
|           Thumbs Up|12551|
|                Help| 1726|
|             Upgrade|  499|
|               Error|  258|
|      Submit Upgrade|  159|
|            Register|   18|
+--------------------+-----+

Когато изпълнителят е нулев, потребителите прекарват повече време в други страници, отколкото в слушане на музика. В този случай такава информация е ценна за разбиране на поведението на регистрираните потребители.

Попълване на празни стойности

cleaned_df = cleaned_df.fillna({
    "length": 0,
    "artist": "unknown",
    "song": "unknown"
})

Като последна стъпка искаме да попълним редовете със стойности Няма. В този момент празни редове се появяват само за полета, свързани с песни, когато потребителят не слуша нищо. Информацията от тези редове е ценна и не препоръчваме да ги изпускате.

Затова задаваме дължината на песента на 0и изпълнителя и песента нанеизвестен.

4️⃣ Дефинирайте етикета за оттегляне на клиенти

cleaned_df.select(
    F.count(
        F.when(F.col("page") == "Cancellation Confirmation", "page")
    ).alias("Cancellation Confirmation")
).show()

+-------------------------+
|Cancellation Confirmation|
+-------------------------+
|                       52|
+-------------------------+

Ще считаме, че даден потребител е „отхвърлен“когато абонаментът бъде анулиран.

Забележка:Ако искахме да преминем към следващото ниво, можехме също така да използваме събитието Понижаване, за да сигнализираме за събитие на отлив на клиенти.

labeled_df = cleaned_df.withColumn(
    "churnEvent", 
    F.when(
       F.col("page") == "Cancellation Confirmation",
       1
    ).otherwise(0)
)

Като междинна стъпка ще създадем колоната churnEvent, която е флаг, който проверява за събитие „Потвърждение за анулиране“. Ние проектирахме колонатаchurnEventна нивосъбитие.

user.labeled_df = labeled_df.withColumn(
    "churn", 
    F.sum("churnEvent").over(Window.partitionBy("userId"))
)
labeled_df = labeled_df.withColumn(
    "churn", 
    F.when(F.col("churn") >= 1, 1).otherwise(0)
)

Обединихме флага churnEvent към нивото на потребителя и го отрязахме до 0 или 1. Следователно потребителят има или churn = 1, или no-churn = 0знамена във всички събития.

Не забравяйте, че в крайна сметка искаме да изградим класификатор, който предвижда отлив на потребители, а не отлив на събития. Ето защо даден потребител е или отлив = 1или без отлив = 0.

5️⃣ Проучвателен анализ на данни

Забележка:За да бъде статията кратка, ще покажем само графиките, които ни помогнаха да създадем ценни функции. Можете да видите всички графики и как са начертани тук.

Разпределение на оттеглянето на потребителите

Разпределението е силно изкривено към потребителите без напускане. Ще използваме резултата F1, за да намерим най-добрия модел в стъпката на кръстосано валидиране, за да вземем предвид този фактор.

Възможни страници

eda_df.select("page").distinct().show()

+--------------------+
|                page|
+--------------------+
|              Cancel|
|    Submit Downgrade|
|         Thumbs Down|
|                Home|
|           Downgrade|
|         Roll Advert|
|              Logout|
|       Save Settings|
|Cancellation Conf...|
|               About|
|            Settings|
|     Add to Playlist|
|          Add Friend|
|            NextSong|
|           Thumbs Up|
|                Help|
|             Upgrade|
|               Error|
|      Submit Upgrade|
+--------------------+
eda_df.filter(
    F.col("artist") != "unknown"
).select("page").distinct().show()
+--------+
|    page|
+--------+
|NextSong|
+--------+

Единствената страница, на която се възпроизвежда музика, се нарича Следваща песен.

Разпределение на средния брой посетени страници от всеки потребител

Потребителите, които са останали на платформата, са посетили средно повече страници.

Разпределение на средния брой песни, слушани от всеки потребител

Средно потребителите, които са останали на платформата, слушат повече песни.

Разпределение на средния брой слушани изпълнители от всеки потребител

Средно потребителите, които остават на платформата, слушат по-голямо разнообразие от изпълнители.

Делта време от разпределението на регистрацията

Делта време = броят секунди, откакто потребителят се е регистрирал в платформата.

Изглежда, че разликата във времето от регистрацията за отлив потребители е изкривена надясно. В същото време разпределението наno-churn е нормално разпределено. По този начин, средната стойност на времевата делта е добър предиктор.

6️⃣ Инженеринг на функции

Класификацията на отлив ще бъде извършена на ниво потребител. Следователно трябва да обединим данните за всеки потребител в един ред.

Защо да агрегираме данните на потребителско ниво?

Окончателните прогнози трябва да се използват за прогнозиране дали даден потребител е склонен да напусне компанията, а не събитието, когато потребителят ще се отпише. По този начин едно събитие е безполезно за нас, но агрегирането му е изключително ценно.

Дефинирайте някои полезни функции:

def count_with_condition(condition):
    """Utility function to count only specific rows based on the 'condition'."""
    return F.count(F.when(condition, True))

def count_distinct_with_condition(condition, values):
    """Utility function to count only distinct & specific rows based on the 'condition'."""
    return F.count_distinct(F.when(condition, values))

Създайте разработената с функции Spark DataFrame

Като обобщение ще използваме функциите, представени в стъпката EDA. В Бележника проучихме повече потенциални функции, но тези четири показаха най-голяма предсказваща сила между отлив и без отлив потребители:

  • Общият брой на посетените страници.
  • Общият брой пуснати песни.
  • Общият брой на общо артисти.
  • Час от регистрацията (в секунди).
aggregated_df = labeled_df.groupby("userId").agg(
    F.count("page").alias("pages"),
    count_with_condition(
        F.col("page") == "NextSong"
    ).alias("plays"),
    count_distinct_with_condition(
        F.col("artist") != "unknown", F.col("artist")
    ).alias("artists"),
    F.max(F.col("ts") - F.col("registration")).alias("delta"),
    F.max("churn").alias("churn")
)
aggregated_df.show(n=5)

+------+-----+-----+-------+-----------+-----+
|userId|pages|plays|artists|      delta|churn|
+------+-----+-----+-------+-----------+-----+
|100010|  381|  275|    252| 4807612000|    0|
|100014|  310|  257|    233| 7351206000|    1|
|100021|  319|  230|    207| 5593438000|    1|
|   101| 2149| 1797|   1241| 4662657000|    1|
|    11|  848|  647|    534|10754921000|    0|
+------+-----+-----+-------+-----------+-----+

Картирайте DataFrame към Spark Vector

assembler = VectorAssembler(
    inputCols=[
        "pages", "plays", "artists", "delta"
], outputCol="unscaled_features")
engineered_df = assembler.transform(aggregated_df)
engineered_df = engineered_df.select(
    F.col("unscaled_features"),
    F.col("churn").alias("label")
)

Моделите Spark очакват Spark Vector на входа. Освен това, по подразбиране, те очакват функциите за вход да бъдат в колона, наречена функции, а целта да бъде в колона, наречена етикет.

7️⃣ Моделиране

Ще обучим и тестваме три модела:

  1. Логистична регресия
  2. Наивен Бейс
  3. Градиентно усилващо дърво

Ще използваме кръстосано валидиране с три пъти, за да намерим най-добрите хиперпараметри.

Ще използваме 80% от данните във влаковия сплит и 20% в тестовия сплит.

Ние нормализирахме характеристиките с помощта на StandardScaler.

Тъй като етикетите са силно небалансирани, ще използваме резултат F1за оценка на моделите. Показателят F1 резултат работи под капака на прецизност и запомняне, което отчита проблема с небалансираното разпределение.

Дефинирайте някои полезни функции:

Пропуснете ги, ако се интересувате само от резултатите.

Синтаксисът на Spark ML е много подобен на този, използван в Sklearn.

def run(pipeline, paramGrid, train_df, test_df):
    """
    Main function used to train & test a given model.
    The training step uses cross-validation to find the best hyper-parameters for the model.
    :param pipeline: Model pipeline.
    :param paramGrid: Parameter grid used for cross-validation.
    :param train_df: Training dataframe.
    :param test_df: Testing dataframe.
    :return: the best model from cross-validation.
    """
    fitted_model = fit_model(paramGrid, pipeline, train_df)
    evaluate_model(fitted_model, test_df)
    return fitted_model

def fit_model(paramGrid, pipeline, train_df):
    """
    Function that trains the model using cross-validation.
    Also, it prints the best validation results and hyper-parameters.
    :param paramGrid: Parameter grid used for cross-validation.
    :param pipeline: Model pipeline.
    :param train_df: Training dataframe.
    :return: the best model from cross-validation.
    """
    crossval = CrossValidator(
        estimator=pipeline,
        estimatorParamMaps=paramGrid,
        evaluator = MulticlassClassificationEvaluator(
              metricName="f1", 
              beta=1.0
        ),
        parallelism=3,
        numFolds=3
    )
    fitted_model = crossval.fit(train_df)
    print_best_validation_score(fitted_model)
    print_best_parameters(fitted_model)
    return fitted_model

def create_pipeline(model):
    """
    Create a pipeline based on a model.
    :param model: The end model that will be used for training.
    :return: the built pipeline.
    """
    scaler = StandardScaler(
        inputCol="unscaled_features", 
        outputCol="features"
    )
    pipeline = Pipeline(stages=[scaler, model])
    return pipeline

def print_best_validation_score(cross_validation_model):
    """Prints the best validation score based on the results from the cross-validation model."""
    print()
    print("-" * 60)
    print(f"F1 score, on the validation split, for the best model: {np.max(cross_validation_model.avgMetrics) * 100:.2f}%")
    print("-" * 60)

def print_best_parameters(cross_validation_model):
    """Prints the best hyper-parameters based on the results from the cross-validation model."""
    parameters = cross_validation_model \
       .getEstimatorParamMaps() [np.argmax(cross_validation_model.avgMetrics)]
    print()
    print("-" * 60)
    print("Best model hyper-parameters:")
    for param, value in parameters.items():
        print(f"{param}: {value}")
    print("-" * 60)

def evaluate_model(model, test_df):
    """Evaluate the model on the test set using F1 score and print the results."""
    predictions = model.transform(test_df)
    evaluator =  MulticlassClassificationEvaluator(
            metricName="f1", 
            beta=1.0
          )
    metric = evaluator.evaluate(predictions)
    print()
    print("-" * 60)
    print(f"F1 score, on the test set is: {metric*100:.2f}%")
    print("-" * 60)
    return metric

Разделяне на данните

train_df, test_df = engineered_df.randomSplit([0.8, 0.2], seed=42)

Логистична регресия

lr = LogisticRegression()
pipeline = create_pipeline(lr)
paramGrid = ParamGridBuilder() \
    .addGrid(lr.maxIter, [10, 25, 50])  \
    .addGrid(lr.regParam, [0.05, 0.1, 0.2]) \
    .addGrid(lr.elasticNetParam, [0.05, 0.1, 0.2]) \
    .build()run(
   pipeline,
   paramGrid,
   train_df.alias("train_df_lr"),
   test_df.alias("test_df_lr")
);

Наивен Бейс

nb = NaiveBayes()
pipeline = create_pipeline(nb)
paramGrid = ParamGridBuilder() \
    .addGrid(nb.smoothing, [0.5, 1, 2])  \
    .build()run(
   pipeline,
   paramGrid,
   train_df.alias("train_df_nb"),
   test_df.alias("test_df_nb")
);

Градиентно усилване

gbt = GBTClassifier()
pipeline = create_pipeline(gbt)
paramGrid = ParamGridBuilder() \
    .addGrid(gbt.maxIter, [10, 20, 30]) \
    .addGrid(gbt.stepSize, [0.05, 0.1]) \
    .build()run(
   pipeline,
   paramGrid, 
   train_df.alias("train_df_gbt"),
   test_df.alias("test_df_gbt")
);

И трите фрагмента от кода за обучение и тестване имат една и съща структура.

Дефинирахме тръбопроводакато стъпка на мащабиране и действителния модел. Също така дефинирахме мрежата за кръстосано валидиранена хиперпараметри. В крайна сметка предаваме всички тези обекти на функцията run. Което обучава и оценява модела на дадено разделяне на данни и конфигурация.

8️⃣ Резултати

|        Model        | Validation |   Test   |
|:-------------------:|:----------:|:--------:|
| Logistic Regression |   0.6958   |  0.5952  |
|     Naive Bayes     |   0.6672   |  0.5952  |
|  Gradient Boosting  |  *0.7333*  | *0.8473* |

Моделът GBT има по-добър F1 резултат от Logistic Regression и Naive Bay и това е така, защото GBTе по-сложен модел, който може да разбере по-добре нелинейните връзки.

Обикновено методите за градиентно усилване се представят по-добре от други методи като LRили NBбез да се вземат предвид фактори като корелациятамежду две променливи или нелинейни зависимости. Базовите модели на дърветата не са чувствителни към тези проблеми, защото създават листа независимо и поддържат многоизмерни връзки.

9️⃣ Заключение

Страхотен! Успяхме да обучим приличен класификатор, използвайки само Spark.

Заредихме набора от данни, почистихме го, анализирахме го и накрая създадохме набор от полезни Характеристики за предвиждане на напускане на клиенти.

С инженерните данни ние обучихме и тествахме три модела за прогнозиране на отлив. Използвайки кръстосано валидиране, намираме най-добрите хиперпараметри за следните модели:

  • Логистична регресия
  • Наивен Бейс
  • Градиентно усилване

Сравнихме резултатите и видяхме, че моделът GBTима най-високия резултат F1при разделянето на валидирането и теста.

За да подобрим допълнително модела, можем да направим следното:

  • добавете още функции
  • разрешаване на дисбаланса на етикета
  • използвайте събитието Понижаване, за да генерирате повече етикети за отлив
  • повече хиперпараметрична настройка
  • използвайте XGBoost или LightGBM

Какви други предложения имате за допълнително подобряване на модела?

Забележка: Можете да получите достъп до хранилището на GitHub тук.

🎉 Благодаря ви, че прочетохте моята статия!

📢 Ако тази статия ви е харесала и искате да споделите моето учебно пътуване в AI, ML и MLOps, можете също да ме последвате в LinkedIn.