ArrayIndexOutOfBoundsException при подгонке логистической регрессии Spark ML в R

Я пытаюсь подобрать модель логистической регрессии, используя sparklyr::ml_logistic_regression. Мой набор обучающих данных содержит 42 457 строк и 785 столбцов; ответ представляет собой целое число 0/1 в столбце label, а все остальные столбцы представляют собой целочисленные функции 0/1. Мои исходные данные находятся во фрейме данных R (df), и я могу успешно подогнать модель к базе R, используя glm(label ~ ., data = df, family = binomial).

К сожалению, я не могу подогнать эту модель под ml_logistic_regression. Код выглядит следующим образом; sc — это существующее соединение Spark.

library(sparklyr)
library(tidyverse)

copy_to(sc, df, "spark_train", overwrite = TRUE)
train_tbl <- tbl(sc, "spark_train")
fit <- ml_logistic_regression(train_tbl, label ~ .)

Вот трассировка стека:

d> fit <- ml_logistic_regression(train_tbl, label ~ .)
* No rows dropped by 'na.omit' call
Error: java.lang.ArrayIndexOutOfBoundsException: 1
    at org.apache.spark.ml.classification.LogisticRegression.train(LogisticRegression.scala:343)
    at org.apache.spark.ml.classification.LogisticRegression.train(LogisticRegression.scala:159)
    at org.apache.spark.ml.Predictor.fit(Predictor.scala:90)
    at org.apache.spark.ml.Predictor.fit(Predictor.scala:71)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source)
    at java.lang.reflect.Method.invoke(Unknown Source)
    at sparklyr.Invoke$.invoke(invoke.scala:94)
    at sparklyr.StreamHandler$.handleMethodCall(stream.scala:89)
    at sparklyr.StreamHandler$.read(stream.scala:55)
    at sparklyr.BackendHandler.channelRead0(handler.scala:49)
    at sparklyr.BackendHandler.channelRead0(handler.scala:14)
    at io.netty.channel.SimpleChannelInboundHandler.channelRead(SimpleChannelInboundHandler.java:105)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
    at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
    at io.netty.handler.codec.MessageToMessageDecoder.channelRead(MessageToMessageDecoder.java:103)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
    at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
    at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:244)
    at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:308)
    at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:294)
    at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:846)
    at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:131)
    at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:511)
    at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:468)
    at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:382)
    at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:354)
    at io.netty.util.concurrent.SingleThreadEventExecutor$2.run(SingleThreadEventExecutor.java:111)
    at io.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:137)
    at java.lang.Thread.run(Unknown Source)

А вот мой sessionInfo():

R version 3.3.2 (2016-10-31)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows >= 8 x64 (build 9200)

locale:
[1] LC_COLLATE=English_United Kingdom.1252  LC_CTYPE=English_United Kingdom.1252   
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C                           
[5] LC_TIME=English_United Kingdom.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] dplyr_0.7.1      purrr_0.2.2.2    readr_1.0.0      tidyr_0.6.3     
 [5] tibble_1.3.3     ggplot2_2.2.1    tidyverse_1.1.1  sparklyr_0.5.6  
 [9] robomarker_0.1.0 devtools_1.12.0 

loaded via a namespace (and not attached):
 [1] h2o_3.10.5.2     reshape2_1.4.2   haven_1.0.0      lattice_0.20-34 
 [5] colorspace_1.3-2 htmltools_0.3.5  yaml_2.1.14      base64enc_0.1-3 
 [9] rlang_0.1.1      foreign_0.8-67   glue_1.1.1       withr_1.0.2     
[13] DBI_0.7          rappdirs_0.3.1   dbplyr_1.0.0     modelr_0.1.0    
[17] readxl_1.0.0     bindrcpp_0.2     bindr_0.1        plyr_1.8.4      
[21] stringr_1.2.0    munsell_0.4.3    commonmark_1.1   gtable_0.2.0    
[25] cellranger_1.1.0 rvest_0.3.2      psych_1.7.3.21   memoise_1.0.0   
[29] forcats_0.2.0    httpuv_1.3.3     parallel_3.3.2   broom_0.4.2     
[33] Rcpp_0.12.10     xtable_1.8-2     backports_1.0.5  scales_0.4.1    
[37] jsonlite_1.2     config_0.2       mime_0.5         mnormt_1.5-5    
[41] hms_0.3          digest_0.6.12    stringi_1.1.2    shiny_1.0.3     
[45] grid_3.3.2       rprojroot_1.2    bitops_1.0-6     tools_3.3.2     
[49] magrittr_1.5     RCurl_1.95-4.8   lazyeval_0.2.0   pkgconfig_2.0.1 
[53] xml2_1.1.1       lubridate_1.6.0  assertthat_0.1   roxygen2_6.0.1  
[57] httr_1.2.1       rstudioapi_0.6   R6_2.2.0         rsparkling_0.2.0
[61] nlme_3.1-128

Любая идея, почему это может происходить?


person jkeirstead    schedule 28.06.2017    source источник


Ответы (1)


Эта ошибка может быть вызвана наличием только одного типа меток в наборе обучающих данных. Убедитесь, что у вас есть несколько типов этикеток; в зависимости от вашей версии искры вы можете использовать только две метки (например, 0 и 1 для биномиальной регрессии).

person kingledion    schedule 14.06.2018
comment
У меня похожая проблема, и это работает для меня, спасибо - person mjimcua; 29.10.2018