diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java index 8db08b870d0..2d624d23abd 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java @@ -166,7 +166,7 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int ro out.quickSetValue(i, outputCol, getCode(in, i)); } }*/ - + protected void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){ // Apply loop tiling to exploit CPU caches double[] codes = getCodeCol(in, rowStart, blk); @@ -343,7 +343,7 @@ public Callable getBuildTask(CacheBlock in) { throw new DMLRuntimeException("Trying to get the Build task of an Encoder which does not require building"); } - public Callable getPartialBuildTask(CacheBlock in, int startRow, + public Callable getPartialBuildTask(CacheBlock in, int startRow, int blockSize, HashMap ret) { throw new DMLRuntimeException( "Trying to get the PartialBuild task of an Encoder which does not support partial building"); @@ -409,7 +409,7 @@ protected void setApplyRowBlocksPerColumn(int nPart) { } public enum EncoderType { - Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite + Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, Udf } /* diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java index 00e588ee17e..03b20f0b824 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,6 +19,9 @@ package org.apache.sysds.runtime.transform.encode; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.List; import org.apache.sysds.api.DMLScript; @@ -45,8 +48,8 @@ public class ColumnEncoderUDF extends ColumnEncoder { //TODO pass execution context through encoder factory for arbitrary functions not just builtin //TODO integration into IPA to ensure existence of unoptimized functions - - private final String _fName; + + private String _fName; public int _domainSize = 1; protected ColumnEncoderUDF(int ptCols, String name) { @@ -72,7 +75,7 @@ public void build(CacheBlock in) { public List> getBuildTasks(CacheBlock in) { return null; } - + @Override public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) { long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; @@ -82,7 +85,7 @@ public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowSta MatrixBlock col = out.slice(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, new MatrixBlock()); ec.setVariable("I", new ListObject(new Data[] {ParamservUtils.newMatrixObject(col, true)})); ec.setVariable("O", ParamservUtils.newMatrixObject(col, true)); - + //call UDF function via eval machinery var fun = new EvalNaryCPInstruction(null, "eval", "", new CPOperand("O", ValueType.FP64, DataType.MATRIX), @@ -124,14 +127,14 @@ else if(columnEncoder instanceof ColumnEncoderFeatureHash){ } } } - + @Override protected ColumnApplyTask getSparseTask(CacheBlock in, MatrixBlock out, int outputCol, int startRow, int blk) { throw new DMLRuntimeException("UDF encoders do not support sparse tasks."); } - + @Override public void mergeAt(ColumnEncoder other) { if(other instanceof ColumnEncoderUDF) @@ -164,5 +167,21 @@ protected double getCode(CacheBlock in, int row) { @Override protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize) { throw new DMLRuntimeException("UDF encoders only support full column access."); - } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + LOG.debug("Writing ColumnEncoderUTF to create"); + super.writeExternal(out); + out.writeInt(_domainSize); + out.writeUTF(_fName); + } + + @Override + public void readExternal(ObjectInput in) throws IOException { + LOG.debug("reading ColumnEncoderUTF"); + super.readExternal(in); + _domainSize = in.readInt(); + _fName = in.readUTF(); + LOG.debug("set _fName: "+_fName); + }} } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java index f7f7a7f990c..adcf409ca21 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java @@ -92,7 +92,7 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V List mvIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, TfMethod.IMPUTE.toString(), minCol, maxCol))); List udfIDs = TfMetaUtils.parseUDFColIDs(jSpec, colnames, minCol, maxCol); - + // create individual encoders if(!rcIDs.isEmpty()) for(Integer id : rcIDs) @@ -103,7 +103,7 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V if(!ptIDs.isEmpty()) for(Integer id : ptIDs) addEncoderToMap(new ColumnEncoderPassThrough(id), colEncoders); - + if(!binIDs.isEmpty()) for(Object o : (JSONArray) jSpec.get(TfMethod.BIN.toString())) { JSONObject colspec = (JSONObject) o; @@ -130,7 +130,7 @@ else if ("EQUI-HEIGHT".equals(method)) for(Integer id : udfIDs) addEncoderToMap(new ColumnEncoderUDF(id, name), colEncoders); } - + // create composite decoder of all created encoders for(Entry> listEntry : colEncoders.entrySet()) { if(DMLScript.STATISTICS) @@ -200,6 +200,8 @@ else if(columnEncoder instanceof ColumnEncoderPassThrough) return EncoderType.PassThrough.ordinal(); else if(columnEncoder instanceof ColumnEncoderRecode) return EncoderType.Recode.ordinal(); + else if(columnEncoder instanceof ColumnEncoderUDF) + return EncoderType.Udf.ordinal(); throw new DMLRuntimeException("Unsupported encoder type: " + columnEncoder.getClass().getCanonicalName()); } @@ -216,6 +218,8 @@ public static ColumnEncoder createInstance(int type) { return new ColumnEncoderPassThrough(); case Recode: return new ColumnEncoderRecode(); + case Udf: + return new ColumnEncoderUDF(); default: throw new DMLRuntimeException("Unsupported encoder type: " + etype); } diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java index 1586a51b7d3..81513784716 100644 --- a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java +++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,28 +29,30 @@ import org.apache.sysds.test.TestUtils; import org.apache.sysds.utils.Statistics; -public class TransformEncodeUDFTest extends AutomatedTestBase +public class TransformEncodeUDFTest extends AutomatedTestBase { private final static String TEST_NAME1 = "TransformEncodeUDF1"; //min-max private final static String TEST_NAME2 = "TransformEncodeUDF2"; //scale w/ defaults + private final static String TEST_NAME3 = "TransformEncodeUDF3"; //simple custom UDF private final static String TEST_DIR = "functions/transform/"; private final static String TEST_CLASS_DIR = TEST_DIR + TransformEncodeUDFTest.class.getSimpleName() + "/"; - + //dataset and transform tasks without missing values private final static String DATASET = "homes3/homes.csv"; - + @Override public void setUp() { TestUtils.clearAssertionInformation(); addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}) ); addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}) ); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"R"}) ); } - + @Test public void testUDF1Singlenode() { runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME1); } - + @Test public void testUDF1Hybrid() { runTransformTest(ExecMode.HYBRID, TEST_NAME1); @@ -63,29 +65,39 @@ public void testUDF2Singlenode() { @Test public void testUDF2Hybrid() { - runTransformTest(ExecMode.HYBRID, TEST_NAME2); + runTransformTest(ExecMode.HYBRID, TEST_NAME2); } - + + @Test + public void testUDF3Singlenode() { + runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME3); + } + + @Test + public void testUDF3Hybrid() { + runTransformTest(ExecMode.HYBRID, TEST_NAME3); + } + private void runTransformTest(ExecMode rt, String testname) { //set runtime platform ExecMode rtold = setExecMode(rt); - + try { getAndLoadTestConfiguration(testname); - + String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testname + ".dml"; programArgs = new String[]{"-explain", "-nvargs", "DATA=" + DATASET_DIR + DATASET, "R="+output("R")}; //compare transformencode+scale vs transformencode w/ UDF - runTest(true, false, null, -1); - + runTest(true, false, null, -1); + double ret = HDFSTool.readDoubleFromHDFSFile(output("R")); Assert.assertEquals(Double.valueOf(148*9), Double.valueOf(ret)); - + if( rt == ExecMode.HYBRID ) { Long num = Long.valueOf(Statistics.getNoOfExecutedSPInst()); Assert.assertEquals("Wrong number of executed Spark instructions: " + num, Long.valueOf(0), num); diff --git a/src/test/scripts/functions/transform/TransformEncodeUDF3.dml b/src/test/scripts/functions/transform/TransformEncodeUDF3.dml new file mode 100644 index 00000000000..6f8ae6f47b7 --- /dev/null +++ b/src/test/scripts/functions/transform/TransformEncodeUDF3.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +F1 = read($DATA, data_type="frame", format="csv"); + + + +