Создайте собственный Transformer In Java spark ml

Я хочу создать собственный Spark Transformer на Java.

Transformer — это текстовый препроцессор, который действует как токенизатор. Он принимает входной столбец и выходной столбец в качестве параметров.

Я осмотрелся и нашел 2 черты Scala HasInputCol и HasOutputCol.

Как мне создать класс, расширяющий Transformer и реализующий HasInputCol и OutputCol?

Моя цель иметь что-то вроде этого.

   // Dataset that have a String column named "text"
   DataSet<Row> dataset;

   CustomTransformer customTransformer = new CustomTransformer();
   customTransformer.setInputCol("text");
   customTransformer.setOutputCol("result");

   // result that have 2 String columns named "text" and "result"
   DataSet<Row> result = customTransformer.transform(dataset);

person LonsomeHell    schedule 08.06.2017    source источник


Ответы (4)


Как предложил SergGr, вы можете расширить UnaryTransformer. Однако это довольно сложно.

ПРИМЕЧАНИЕ. Все приведенные ниже комментарии относятся к Spark версии 2.2.0.

Чтобы решить проблему, описанную в SPARK-12606, где они получали "...Param null__inputCol does not belong to...", вы должны реализовать String uid() следующим образом:

@Override
public String uid() {
    return getUid();
}

private String getUid() {

    if (uid == null) {
        uid = Identifiable$.MODULE$.randomUID("mycustom");
    }
    return uid;
}

Видимо они инициализировали uid в конструкторе. Но дело в том, что inputColoutputCol) UnaryTransformer инициализируется до того, как uid будет инициализирован в наследующем классе. См. HasInputCol:

final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name")

Вот как строится Param:

def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

Таким образом, когда оценивается parent.uid, вызывается пользовательская реализация uid(), и в этот момент uid все еще равно нулю. Реализуя uid() с ленивой оценкой, вы гарантируете, что uid() никогда не вернет значение null.

Хотя в вашем случае:

Param d7ac3108-799c-4aed-a093-c85d12833a4e__inputCol does not belong to fe3d99ba-e4eb-4e95-9412-f84188d936e3

кажется, это немного другое. Поскольку "d7ac3108-799c-4aed-a093-c85d12833a4e" != "fe3d99ba-e4eb-4e95-9412-f84188d936e3", похоже, что ваша реализация метода uid() возвращает новое значение при каждом вызове. Возможно в вашем случае было реализовано это так:

@Override
public String uid() {
    return Identifiable$.MODULE$.randomUID("mycustom");
}

Кстати, при расширении UnaryTransformer убедитесь, что функция преобразования Serializable.

person vixter    schedule 16.08.2017

Возможно, вы захотите унаследовать CustomTransformer от org.apache.spark.ml.UnaryTransformer. Вы можете попробовать что-то вроде этого:

import org.apache.spark.ml.UnaryTransformer;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import scala.Function1;
import scala.collection.JavaConversions$;
import scala.collection.immutable.Seq;

import java.util.Arrays;

public class MyCustomTransformer extends UnaryTransformer<String, scala.collection.immutable.Seq<String>, MyCustomTransformer>
{
    private final String uid = Identifiable$.MODULE$.randomUID("mycustom");

    @Override
    public String uid()
    {
        return uid;
    }


    @Override
    public Function1<String, scala.collection.immutable.Seq<String>> createTransformFunc()
    {
        // can't use labmda syntax :(
        return new scala.runtime.AbstractFunction1<String, Seq<String>>()
        {
            @Override
            public Seq<String> apply(String s)
            {
                // do the logic
                String[] split = s.toLowerCase().split("\\s");
                // convert to Scala type
                return JavaConversions$.MODULE$.iterableAsScalaIterable(Arrays.asList(split)).toList();
            }
        };
    }


    @Override
    public void validateInputType(DataType inputType)
    {
        super.validateInputType(inputType);
        if (inputType != DataTypes.StringType)
            throw new IllegalArgumentException("Input type must be string type but got " + inputType + ".");
    }

    @Override
    public DataType outputDataType()
    {
        return DataTypes.createArrayType(DataTypes.StringType, true); // or false? depends on your data
    }
}
person SergGr    schedule 08.06.2017
comment
Это не работает. Я думаю, это из-за ошибки. я получаю java.lang.IllegalArgumentException: requirement failed: Param d7ac3108-799c-4aed-a093-c85d12833a4e__inputCol does not belong to fe3d99ba-e4eb-4e95-9412-f84188d936e3. - person LonsomeHell; 09.06.2017
comment
@LonsomeHell, просто чтобы перепроверить, вы уверены, что настроили его с допустимым столбцом ввода? - person SergGr; 09.06.2017
comment
Да, я использовал setInput с допустимым именем столбца. - person LonsomeHell; 12.06.2017
comment
Я думаю, что это связано с этой ошибкой issues.apache.org/jira/browse/SPARK- 12606 - person LonsomeHell; 12.06.2017
comment
Мне нравится трюк с использованием AbstractFunction1, поэтому вам не нужно реализовывать ВСЕ методы. - person Mike Pone; 09.04.2019

Я немного опоздал на вечеринку, но у меня есть несколько примеров пользовательских преобразований Java Spark здесь: https://github.com/dafrenchyman/spark/tree/master/src/main/java/com/mrsharky/spark/ml/feature

Вот пример только с входным столбцом, но вы можете легко добавить выходной столбец, следуя тем же шаблонам. Однако это не реализует читателей и писателей. Вам нужно будет проверить ссылку выше, чтобы увидеть, как это сделать.

public class DropColumns extends Transformer implements Serializable, 
DefaultParamsWritable {

    private StringArrayParam _inputCols;
    private final String _uid;

    public DropColumns(String uid) {
        _uid = uid;
    }

    public DropColumns() {
        _uid = DropColumns.class.getName() + "_" + 
UUID.randomUUID().toString();
    }

    // Getters
    public String[] getInputCols() { return get(_inputCols).get(); }

   // Setters
   public DropColumns setInputCols(String[] columns) {
       _inputCols = inputCols();
       set(_inputCols, columns);
       return this;
   }

public DropColumns setInputCols(List<String> columns) {
    String[] columnsString = columns.toArray(new String[columns.size()]);
    return setInputCols(columnsString);
}

public DropColumns setInputCols(String column) {
    String[] columns = new String[]{column};
    return setInputCols(columns);
}

// Overrides
@Override
public Dataset<Row> transform(Dataset<?> data) {
    List<String> dropCol = new ArrayList<String>();
    Dataset<Row> newData = null;
    try {
        for (String currColumn : this.get(_inputCols).get() ) {
            dropCol.add(currColumn);
        }
        Seq<String> seqCol = JavaConverters.asScalaIteratorConverter(dropCol.iterator()).asScala().toSeq();      
        newData = data.drop(seqCol);
    } catch (Exception ex) {
        ex.printStackTrace();
    }
    return newData;
}

@Override
public Transformer copy(ParamMap extra) {
    DropColumns copied = new DropColumns();
    copied.setInputCols(this.getInputCols());
    return copied;
}

@Override
public StructType transformSchema(StructType oldSchema) {
    StructField[] fields = oldSchema.fields();  
    List<StructField> newFields = new ArrayList<StructField>();
    List<String> columnsToRemove = Arrays.asList( get(_inputCols).get() );
    for (StructField currField : fields) {
        String fieldName = currField.name();
        if (!columnsToRemove.contains(fieldName)) {
            newFields.add(currField);
        }
    }
    StructType schema = DataTypes.createStructType(newFields);
    return schema;
}

@Override
public String uid() {
    return _uid;
}

@Override
public MLWriter write() {
    return new DropColumnsWriter(this);
}

@Override
public void save(String path) throws IOException {
    write().saveImpl(path);
}

public static MLReader<DropColumns> read() {
    return new DropColumnsReader();
}

public StringArrayParam inputCols() {
    return new StringArrayParam(this, "inputCols", "Columns to be dropped");
}

public DropColumns load(String path) {
    return ( (DropColumnsReader) read()).load(path);
}
}
person J Pierret    schedule 15.06.2018

Еще позже к вечеринке у меня есть еще одно обновление. Мне было трудно найти информацию о расширении Spark Transformers для Java, поэтому я публикую свои выводы здесь.

Я также работал над пользовательскими преобразователями в Java. На момент написания было немного проще включить функцию сохранения/загрузки. Можно создать записываемые параметры, реализуя DefaultParamsWritable. Однако реализация DefaultParamsReadable по-прежнему приводит к исключению для меня, но есть простой обходной путь.

Вот базовая реализация переименования столбца:

public class ColumnRenamer extends Transformer implements DefaultParamsWritable {
    /**
     * A custom Spark transformer that renames the inputCols to the outputCols.
     * 
     * We would also like to implement DefaultParamsReadable<ColumnRenamer>, but
     * there appears to be a bug in DefaultParamsReadable when used in Java, see:
     * https://issues.apache.org/jira/browse/SPARK-17048
     **/
    private final String uid_;
    private StringArrayParam inputCols_;
    private StringArrayParam outputCols_;
    private HashMap<String, String> renameMap;

    public ColumnRenamer() {
        this(Identifiable.randomUID("ColumnRenamer"));
    }

    public ColumnRenamer(String uid) {
        this.uid_ = uid;
        init();
    }

    @Override
    public String uid() {
        return uid_;
    }

    @Override
    public Transformer copy(ParamMap extra) {
        return defaultCopy(extra);
    }

    /**
     * The below method is a work around, see:
     * https://issues.apache.org/jira/browse/SPARK-17048
     **/
    public static MLReader<ColumnRenamer> read() {
        return new DefaultParamsReader<>();
    }

    public Dataset<Row> transform(Dataset<?> dataset) {
        Dataset<Row> transformedDataset = dataset.toDF();
        // Check schema.
        transformSchema(transformedDataset.schema(), true); // logging = true
        // Rename columns.
        for (Map.Entry<String, String> entry: renameMap.entrySet()) {
            String inputColName = entry.getKey();
            String outputColName = entry.getValue();
            transformedDataset = transformedDataset
                .withColumnRenamed(inputColName, outputColName);
        }
        return transformedDataset;
    }

    @Override
    public StructType transformSchema(StructType schema) {

        // Validate the parameters here...
        
        String[] inputCols = getInputCols();
        String[] outputCols = getOutputCols();
        // Create rename mapping.
        renameMap = new HashMap<> ();
        for (int i = 0; i < inputCols.length; i++) {
            renameMap.put(inputCols[i], outputCols[i]);
        }
        // Rename columns.
        ArrayList<StructField> fields = new ArrayList<> ();
        for (StructField field: schema.fields()) {
            String columnName = field.name();
            if (renameMap.containsKey(columnName)) {
                columnName = renameMap.get(columnName);
            }
            fields.add(new StructField(
                columnName, field.dataType(), field.nullable(), field.metadata()
            ));
        }
        // Return as StructType.
        return new StructType(fields.toArray(new StructField[0]));
    }

    private void init() {
        inputCols_ = new StringArrayParam(this, "inputCols", "input column names");
        outputCols_ = new StringArrayParam(this, "outputCols", "output column names");
    }

    public StringArrayParam inputCols() {
        return inputCols_;
    }
    public ColumnRenamer setInputCols(String[] value) {
        set(inputCols_, value);
        return this;
    }
    public String[] getInputCols() {
        return getOrDefault(inputCols_);
    }

    public StringArrayParam outputCols() {
        return outputCols_;
    }
    public ColumnRenamer setOutputCols(String[] value) {
        set(outputCols_, value);
        return this;
    }
    public String[] getOutputCols() {
        return getOrDefault(outputCols_);
    }
}
person Nicio    schedule 26.04.2021