From f26c5f4a4c7361732b3883dc142141153541eef5 Mon Sep 17 00:00:00 2001 From: Kevin Innerebner Date: Fri, 15 Jul 2022 14:45:34 +0200 Subject: [PATCH 1/6] Write matrices and frames at site for federated write --- .../federated/FederatedData.java | 6 +- .../fed/VariableFEDInstruction.java | 4 - .../runtime/io/ReaderWriterFederated.java | 49 ++++++- .../primitives/FederatedWriteTest.java | 126 ++++++++++++++++++ .../federated/FederatedWriteTest.dml | 25 ++++ .../federated/FederatedWriteTestRead.dml | 23 ++++ .../federated/FederatedWriteTestReference.dml | 24 ++++ 7 files changed, 249 insertions(+), 8 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWriteTest.java create mode 100644 src/test/scripts/functions/federated/FederatedWriteTest.dml create mode 100644 src/test/scripts/functions/federated/FederatedWriteTestRead.dml create mode 100644 src/test/scripts/functions/federated/FederatedWriteTestReference.dml diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java index 370163aaf2c..7f0b22d2dfb 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java @@ -71,7 +71,7 @@ public class FederatedData { private final Types.DataType _dataType; private final InetSocketAddress _address; - private final String _filepath; + private String _filepath; /** * The ID of default matrix/tensor on which operations get executed if no other ID is given. @@ -105,6 +105,10 @@ public long getVarID() { return _varID; } + public void setFilepath(String filepath) { + _filepath = filepath; + } + public String getFilepath() { return _filepath; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java index 197ba43cf74..1e7573ca495 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java @@ -78,10 +78,6 @@ public void processInstruction(ExecutionContext ec) { private void processWriteInstruction(ExecutionContext ec) { LOG.warn("Processing write command federated"); - // TODO Add write command to the federated site if the matrix has been modified - // this has to be done while appending some string to the federated output file. - // furthermore the outputted file on the federated sites path should be returned - // the controller. _in.processInstruction(ec); } diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java index 070c5346ab7..712fdfb33a9 100644 --- a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java +++ b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java @@ -40,10 +40,14 @@ import org.apache.hadoop.mapred.JobConf; import org.apache.sysds.common.Types; import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.controlprogram.federated.FederatedData; -import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; -import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.*; +import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; +import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.meta.DataCharacteristics; /** @@ -62,6 +66,7 @@ */ public class ReaderWriterFederated { private static final Log LOG = LogFactory.getLog(ReaderWriterFederated.class.getName()); + private static final IDSequence siteUniqueCounter = new IDSequence(true); /** * Read a federated map from disk, It is not initialized before it is used in: @@ -103,6 +108,19 @@ public static void write(String file, FederationMap fedMap) { JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); Path path = new Path(file); FileSystem fs = IOUtilFunctions.getFileSystem(path, job); + + fedMap.forEachParallel((range, data) -> { + String siteFilename = Long.toString(siteUniqueCounter.getNextID()) + '_' + path.getName(); + try { + FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, + data.getVarID(), new WriteAtSiteUDF(data.getVarID(), siteFilename))).get(); + data.setFilepath((String) response.getData()[0]); + } catch (Exception e) { + throw new DMLRuntimeException(e); + } + return null; + }); + DataOutputStream out = fs.create(path, true); ObjectMapper mapper = new ObjectMapper(); FederatedDataAddress[] outObjects = parseMap(fedMap.getMap()); @@ -207,4 +225,29 @@ public String toString() { return sb.toString(); } } + + private static class WriteAtSiteUDF extends FederatedUDF { + private static final long serialVersionUID = -6645546954618784216L; + + private final String _filename; + + public WriteAtSiteUDF(long input, String filename) { + super(new long[] {input}); + _filename = filename; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + MatrixObject mo = (MatrixObject) data[0]; + String tmpDir = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.LOCAL_TMP_DIR); + Path p = new Path(tmpDir + '/' + _filename); + mo.exportData(p.toString(), null); + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, p.toString()); + } + + @Override + public Pair getLineageItem(ExecutionContext ec) { + return null; + } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWriteTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWriteTest.java new file mode 100644 index 00000000000..1782e5534f6 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWriteTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.primitives; + +import java.util.Arrays; +import java.util.Collection; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class FederatedWriteTest extends AutomatedTestBase { + private final static String TEST_DIR = "functions/federated/"; + private final static String TEST_NAME = "FederatedWriteTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedWriteTest.class.getSimpleName() + "/"; + + private static final int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + + @Parameterized.Parameters + public static Collection data() { + // cols have to be dividable by 4 for Frame tests + return Arrays.asList(new Object[][] { + // {1, 1024}, {8, 256}, {256, 8}, {1024, 4}, {16, 2048}, + {2048, 32}}); + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"})); + } + + @Test + public void federatedMatrixWriteCP() { + federatedMatrixWrite(Types.ExecMode.SINGLE_NODE); + } + + @Test + public void federatedMatrixWriteSP() { + federatedMatrixWrite(Types.ExecMode.SPARK); + } + + public void federatedMatrixWrite(Types.ExecMode execMode) { + getAndLoadTestConfiguration(TEST_NAME); + // write input matrix + double[][] A = getRandomMatrix(rows, cols, -1, 1, 1, 1234); + writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols)); + federatedWrite(execMode, null); + } + + // TODO: frame write testcase + public void federatedWrite(Types.ExecMode execMode, Types.ValueType[] schema) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + + String HOME = SCRIPT_DIR + TEST_DIR; + + int port = getRandomAvailablePort(); + Thread t = startLocalFedWorkerThread(port); + + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + // we need the reference file to not be written to hdfs, so we get the correct format + rtplatform = Types.ExecMode.SINGLE_NODE; + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-nvargs", "in=" + input("A"), "out=" + expected("B")}; + runTest(true, false, null, -1); + + // reference file should not be written to hdfs + rtplatform = execMode; + if(rtplatform == Types.ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-explain", "-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), + "rows=" + rows, "cols=" + cols, "tmp=" + output("T")}; + runTest(true, false, null, -1); + + Assert.assertSame(getMetaData("T").getFileFormat(), Types.FileFormat.FEDERATED); + + fullDMLScriptName = HOME + TEST_NAME + "Read.dml"; + programArgs = new String[] {"-explain", "-nvargs", "tmp=" + output("T"), "out=" + output("B")}; + runTest(true, false, null, -1); + // compare via files + if(schema != null) + compareResults(schema); + else + compareResults(1e-12); + + TestUtils.shutdownThread(t); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } +} diff --git a/src/test/scripts/functions/federated/FederatedWriteTest.dml b/src/test/scripts/functions/federated/FederatedWriteTest.dml new file mode 100644 index 00000000000..e803a630ad5 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedWriteTest.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = federated(addresses=list($in, $in), + ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols))) +B = A * 42.0 +write(B, $tmp, format="federated") diff --git a/src/test/scripts/functions/federated/FederatedWriteTestRead.dml b/src/test/scripts/functions/federated/FederatedWriteTestRead.dml new file mode 100644 index 00000000000..c05fc7d99ef --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedWriteTestRead.dml @@ -0,0 +1,23 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = read($tmp) +write(A, $out) diff --git a/src/test/scripts/functions/federated/FederatedWriteTestReference.dml b/src/test/scripts/functions/federated/FederatedWriteTestReference.dml new file mode 100644 index 00000000000..f3839f9c9d7 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedWriteTestReference.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = rbind(read($in), read($in)) +B = A * 42.0 +write(B, $out) From 3435c69d2bd835caac06a0b1955a777ef2a9d095 Mon Sep 17 00:00:00 2001 From: Kevin Innerebner Date: Sun, 17 Jul 2022 15:02:04 +0200 Subject: [PATCH 2/6] Add and fix frame testcases --- .../controlprogram/caching/CacheableData.java | 8 +++- .../controlprogram/caching/FrameObject.java | 16 ++++---- .../instructions/fed/InitFEDInstruction.java | 4 ++ .../runtime/io/ReaderWriterFederated.java | 5 ++- .../primitives/FederatedWriteTest.java | 40 +++++++++++++++++-- .../federated/FederatedWriteFrameTest.dml | 24 +++++++++++ .../FederatedWriteFrameTestReference.dml | 23 +++++++++++ ...eTest.dml => FederatedWriteMatrixTest.dml} | 0 ... => FederatedWriteMatrixTestReference.dml} | 0 9 files changed, 106 insertions(+), 14 deletions(-) create mode 100644 src/test/scripts/functions/federated/FederatedWriteFrameTest.dml create mode 100644 src/test/scripts/functions/federated/FederatedWriteFrameTestReference.dml rename src/test/scripts/functions/federated/{FederatedWriteTest.dml => FederatedWriteMatrixTest.dml} (100%) rename src/test/scripts/functions/federated/{FederatedWriteTestReference.dml => FederatedWriteMatrixTestReference.dml} (100%) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java index 8cb108d478a..6dd726db6f3 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java @@ -392,8 +392,12 @@ public boolean isFederated() { if(_fedMapping == null && _metaData instanceof MetaDataFormat){ MetaDataFormat mdf = (MetaDataFormat) _metaData; if(mdf.getFileFormat() == FileFormat.FEDERATED){ - InitFEDInstruction.federateMatrix( - this, ReaderWriterFederated.read(_hdfsFileName, mdf.getDataCharacteristics())); + if (this instanceof FrameObject) + InitFEDInstruction.federateFrame((FrameObject) this, + ReaderWriterFederated.read(_hdfsFileName, mdf.getDataCharacteristics())); + else + InitFEDInstruction.federateMatrix( + this, ReaderWriterFederated.read(_hdfsFileName, mdf.getDataCharacteristics())); return true; } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java index 84ff65192d4..9ff5785e3f1 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java @@ -34,10 +34,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.instructions.spark.data.RDDObject; -import org.apache.sysds.runtime.io.FileFormatProperties; -import org.apache.sysds.runtime.io.FrameReaderFactory; -import org.apache.sysds.runtime.io.FrameWriter; -import org.apache.sysds.runtime.io.FrameWriterFactory; +import org.apache.sysds.runtime.io.*; import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.lineage.LineageRecomputeUtils; import org.apache.sysds.runtime.matrix.data.FrameBlock; @@ -287,9 +284,14 @@ protected void writeBlobToHDFS(String fname, String ofmt, int rep, FileFormatPro { MetaDataFormat iimd = (MetaDataFormat) _metaData; FileFormat fmt = (ofmt != null ? FileFormat.safeValueOf(ofmt) : iimd.getFileFormat()); - - FrameWriter writer = FrameWriterFactory.createFrameWriter(fmt, fprop); - writer.writeFrameToHDFS(_data, fname, getNumRows(), getNumColumns()); + + if(this.isFederated() && FileFormat.safeValueOf(ofmt) == FileFormat.FEDERATED) { + ReaderWriterFederated.write(fname, this._fedMapping); + } + else { + FrameWriter writer = FrameWriterFactory.createFrameWriter(fmt, fprop); + writer.writeFrameToHDFS(_data, fname, getNumRows(), getNumColumns()); + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java index 3e648bbe3b6..68965db4c34 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java @@ -451,6 +451,10 @@ public static void federateMatrix(CacheableData output, List> workers) { + federateFrame(output, workers, null); + } + public static void federateFrame(FrameObject output, List> workers, CacheBlock[] blocks) { List> fedMapping = new ArrayList<>(); for(Pair e : workers) diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java index 712fdfb33a9..246cf8e7030 100644 --- a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java +++ b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java @@ -42,6 +42,7 @@ import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.*; @@ -238,10 +239,10 @@ public WriteAtSiteUDF(long input, String filename) { @Override public FederatedResponse execute(ExecutionContext ec, Data... data) { - MatrixObject mo = (MatrixObject) data[0]; + CacheableData cd = (CacheableData) data[0]; String tmpDir = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.LOCAL_TMP_DIR); Path p = new Path(tmpDir + '/' + _filename); - mo.exportData(p.toString(), null); + cd.exportData(p.toString(), null); return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, p.toString()); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWriteTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWriteTest.java index 1782e5534f6..f887acda45c 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWriteTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWriteTest.java @@ -19,8 +19,12 @@ package org.apache.sysds.test.functions.federated.primitives; +import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; +import java.util.List; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types; @@ -29,6 +33,7 @@ import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; import org.junit.Assert; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -38,6 +43,8 @@ public class FederatedWriteTest extends AutomatedTestBase { private final static String TEST_DIR = "functions/federated/"; private final static String TEST_NAME = "FederatedWriteTest"; + private final static String TEST_NAME_MATRIX = "FederatedWriteMatrixTest"; + private final static String TEST_NAME_FRAME = "FederatedWriteFrameTest"; private final static String TEST_CLASS_DIR = TEST_DIR + FederatedWriteTest.class.getSimpleName() + "/"; private static final int blocksize = 1024; @@ -78,10 +85,37 @@ public void federatedMatrixWrite(Types.ExecMode execMode) { federatedWrite(execMode, null); } - // TODO: frame write testcase + @Test + public void federatedFrameWriteCP() throws IOException { + federatedFrameWrite(Types.ExecMode.SINGLE_NODE); + } + + @Test + @Ignore + public void federatedFrameWriteSP() throws IOException { + federatedFrameWrite(Types.ExecMode.SPARK); + } + + public void federatedFrameWrite(Types.ExecMode execMode) throws IOException { + getAndLoadTestConfiguration(TEST_NAME); + // write input matrix + double[][] A = getRandomMatrix(rows, cols, -1, 1, 1, 1234); + + List schemaList = new ArrayList<>(Collections.nCopies(cols / 4, Types.ValueType.STRING)); + schemaList.addAll(Collections.nCopies(cols / 4, Types.ValueType.FP64)); + schemaList.addAll(Collections.nCopies(cols / 4, Types.ValueType.INT64)); + schemaList.addAll(Collections.nCopies(cols / 4, Types.ValueType.BOOLEAN)); + + Types.ValueType[] schema = new Types.ValueType[cols]; + schemaList.toArray(schema); + writeInputFrameWithMTD("A", A, false, schema, Types.FileFormat.BINARY); + federatedWrite(execMode, schema); + } + public void federatedWrite(Types.ExecMode execMode, Types.ValueType[] schema) { boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; Types.ExecMode platformOld = rtplatform; + String testName = schema == null ? TEST_NAME_MATRIX : TEST_NAME_FRAME; String HOME = SCRIPT_DIR + TEST_DIR; @@ -94,7 +128,7 @@ public void federatedWrite(Types.ExecMode execMode, Types.ValueType[] schema) { // we need the reference file to not be written to hdfs, so we get the correct format rtplatform = Types.ExecMode.SINGLE_NODE; // Run reference dml script with normal matrix - fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + fullDMLScriptName = HOME + testName + "Reference.dml"; programArgs = new String[] {"-nvargs", "in=" + input("A"), "out=" + expected("B")}; runTest(true, false, null, -1); @@ -103,7 +137,7 @@ public void federatedWrite(Types.ExecMode execMode, Types.ValueType[] schema) { if(rtplatform == Types.ExecMode.SPARK) { DMLScript.USE_LOCAL_SPARK_CONFIG = true; } - fullDMLScriptName = HOME + TEST_NAME + ".dml"; + fullDMLScriptName = HOME + testName + ".dml"; programArgs = new String[] {"-explain", "-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows, "cols=" + cols, "tmp=" + output("T")}; runTest(true, false, null, -1); diff --git a/src/test/scripts/functions/federated/FederatedWriteFrameTest.dml b/src/test/scripts/functions/federated/FederatedWriteFrameTest.dml new file mode 100644 index 00000000000..489fd6c2410 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedWriteFrameTest.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = federated(type="Frame", addresses=list($in, $in), + ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols))) +write(A, $tmp, format="federated") diff --git a/src/test/scripts/functions/federated/FederatedWriteFrameTestReference.dml b/src/test/scripts/functions/federated/FederatedWriteFrameTestReference.dml new file mode 100644 index 00000000000..b697703935c --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedWriteFrameTestReference.dml @@ -0,0 +1,23 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = rbind(read($in), read($in)) +write(A, $out) diff --git a/src/test/scripts/functions/federated/FederatedWriteTest.dml b/src/test/scripts/functions/federated/FederatedWriteMatrixTest.dml similarity index 100% rename from src/test/scripts/functions/federated/FederatedWriteTest.dml rename to src/test/scripts/functions/federated/FederatedWriteMatrixTest.dml diff --git a/src/test/scripts/functions/federated/FederatedWriteTestReference.dml b/src/test/scripts/functions/federated/FederatedWriteMatrixTestReference.dml similarity index 100% rename from src/test/scripts/functions/federated/FederatedWriteTestReference.dml rename to src/test/scripts/functions/federated/FederatedWriteMatrixTestReference.dml From e899113a3e7d6ca0d592a3f34bed45b03cb13075 Mon Sep 17 00:00:00 2001 From: Kevin Innerebner Date: Sun, 17 Jul 2022 20:18:30 +0200 Subject: [PATCH 3/6] Minor cleanup --- .../runtime/controlprogram/caching/FrameObject.java | 9 ++++++--- .../runtime/instructions/fed/FEDInstructionUtils.java | 7 +------ .../runtime/instructions/fed/VariableFEDInstruction.java | 8 -------- 3 files changed, 7 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java index 9ff5785e3f1..50ff71b795a 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java @@ -34,7 +34,11 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.instructions.spark.data.RDDObject; -import org.apache.sysds.runtime.io.*; +import org.apache.sysds.runtime.io.FileFormatProperties; +import org.apache.sysds.runtime.io.FrameReaderFactory; +import org.apache.sysds.runtime.io.FrameWriter; +import org.apache.sysds.runtime.io.FrameWriterFactory; +import org.apache.sysds.runtime.io.ReaderWriterFederated; import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.lineage.LineageRecomputeUtils; import org.apache.sysds.runtime.matrix.data.FrameBlock; @@ -285,9 +289,8 @@ protected void writeBlobToHDFS(String fname, String ofmt, int rep, FileFormatPro MetaDataFormat iimd = (MetaDataFormat) _metaData; FileFormat fmt = (ofmt != null ? FileFormat.safeValueOf(ofmt) : iimd.getFileFormat()); - if(this.isFederated() && FileFormat.safeValueOf(ofmt) == FileFormat.FEDERATED) { + if(this.isFederated() && FileFormat.safeValueOf(ofmt) == FileFormat.FEDERATED) ReaderWriterFederated.write(fname, this._fedMapping); - } else { FrameWriter writer = FrameWriterFactory.createFrameWriter(fmt, fprop); writer.writeFrameToHDFS(_data, fname, getNumRows(), getNumColumns()); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java index 5e3b7a9f8b0..4f9ab9972c9 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java @@ -232,12 +232,7 @@ else if((tinst.input1.isMatrix() && ec.getCacheableData(tinst.input1).isFederate } else if(inst instanceof VariableCPInstruction ){ VariableCPInstruction ins = (VariableCPInstruction) inst; - if(ins.getVariableOpcode() == VariableOperationCode.Write - && ins.getInput1().isMatrix() - && ins.getInput3().getName().contains("federated")){ - fedinst = VariableFEDInstruction.parseInstruction(ins); - } - else if(ins.getVariableOpcode() == VariableOperationCode.CastAsFrameVariable + if(ins.getVariableOpcode() == VariableOperationCode.CastAsFrameVariable && ins.getInput1().isMatrix() && ec.getCacheableData(ins.getInput1()).isFederatedExcept(FType.BROADCAST)){ fedinst = VariableFEDInstruction.parseInstruction(ins); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java index 1e7573ca495..f62fef63fb1 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java @@ -62,9 +62,6 @@ public static VariableFEDInstruction parseInstruction(VariableCPInstruction cpIn public void processInstruction(ExecutionContext ec) { VariableOperationCode opcode = _in.getVariableOpcode(); switch(opcode) { - case Write: - processWriteInstruction(ec); - break; case CastAsMatrixVariable: processCastAsMatrixVariableInstruction(ec); break; @@ -76,11 +73,6 @@ public void processInstruction(ExecutionContext ec) { } } - private void processWriteInstruction(ExecutionContext ec) { - LOG.warn("Processing write command federated"); - _in.processInstruction(ec); - } - private void processCastAsMatrixVariableInstruction(ExecutionContext ec) { FrameObject mo1 = ec.getFrameObject(_in.getInput1()); From e1ee92de8da68099d366d0341d64b821df0c7e55 Mon Sep 17 00:00:00 2001 From: Kevin Innerebner Date: Sat, 30 Jul 2022 15:45:55 +0200 Subject: [PATCH 4/6] Rework federated write and read Instead of having a separate format `format=federated`, we write federated if the object is federated. We also do not create a JSON file with the federated addresses and ranges, instead we add this information to the MTD file. Note that we now also aligned the specification of addresses and ranges with the usage in our `federated()` function, such that the syntax is similar. --- .../java/org/apache/sysds/common/Types.java | 1 - .../apache/sysds/parser/DMLTranslator.java | 78 +++++- .../apache/sysds/parser/DataExpression.java | 7 +- .../controlprogram/caching/CacheableData.java | 50 ++-- .../controlprogram/caching/FrameObject.java | 9 +- .../controlprogram/caching/MatrixObject.java | 12 +- .../federated/FederatedData.java | 6 + .../federated/FederationMap.java | 59 +++- .../instructions/fed/FEDInstructionUtils.java | 7 +- .../fed/VariableFEDInstruction.java | 23 ++ .../runtime/io/FileFormatProperties.java | 4 +- .../runtime/io/ReaderWriterFederated.java | 254 ------------------ .../sysds/runtime/io/WriterFederated.java | 133 +++++++++ .../sysds/runtime/meta/MetaDataAll.java | 13 +- .../apache/sysds/runtime/util/HDFSTool.java | 38 ++- .../apache/sysds/test/AutomatedTestBase.java | 27 +- .../java/org/apache/sysds/test/TestUtils.java | 151 ++++++++--- .../federated/io/FederatedReaderCSV.java | 6 +- .../federated/io/FederatedReaderTest.java | 4 +- .../federated/io/FederatedSSLTest.java | 6 +- .../federated/io/FederatedTimeoutTest.java | 4 +- .../federated/io/FederatedWriterTest.java | 4 +- .../primitives/FederatedWriteTest.java | 18 +- .../federated/FederatedWriteFrameTest.dml | 2 +- .../federated/FederatedWriteMatrixTest.dml | 2 +- .../FederatedWriteMatrixTestReference.dml | 2 +- .../federated/FederatedWriteTestRead.dml | 3 +- .../io/FederatedReaderTestCreate.dml | 2 +- 28 files changed, 528 insertions(+), 397 deletions(-) delete mode 100644 src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java create mode 100644 src/main/java/org/apache/sysds/runtime/io/WriterFederated.java diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index a7cfa823aa7..ed2de75b902 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -542,7 +542,6 @@ public enum FileFormat { LIBSVM, // text libsvm sparse row representation JSONL, // text nested JSON (Line) representation BINARY, // binary block representation (dense/sparse/ultra-sparse) - FEDERATED, // A federated matrix PROTO, // protocol buffer representation HDF5; // Hierarchical Data Format (HDF) diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index c196f780c33..99691d41ca3 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -28,6 +28,7 @@ import java.util.Set; import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; @@ -88,8 +89,14 @@ import org.apache.sysds.runtime.controlprogram.Program; import org.apache.sysds.runtime.controlprogram.ProgramBlock; import org.apache.sysds.runtime.controlprogram.WhileProgramBlock; +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; +import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; +import org.apache.wink.json4j.JSONArray; +import org.apache.wink.json4j.JSONException; +import org.apache.wink.json4j.JSONObject; public class DMLTranslator { @@ -1060,9 +1067,6 @@ public void constructHops(StatementBlock sb) { // write output in binary block format ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ae.getBlocksize()); break; - case FEDERATED: - ae.setOutputParams(ae.getDim1(), ae.getDim2(), -1, ae.getUpdateType(), -1); - break; case HDF5: // write output in HDF5 format ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1); @@ -2102,9 +2106,16 @@ private Hop processDataExpression(DataExpression source, DataIdentifier target, if (target == null) { target = createTarget(source); } - + + Expression.DataOp opCode = source.getOpCode(); + if (source.isRead() && paramHops.containsKey("federated")) { + // federated read + paramHops = createFederatedParamHopsFromFederatedReadParamHops(paramHops); + opCode = Expression.DataOp.FEDERATED; + } + // construct hop based on opcode - switch(source.getOpCode()) { + switch(opCode) { case READ: currBuiltinOp = new DataOp(target.getName(), target.getDataType(), target.getValueType(), OpOpData.PERSISTENTREAD, paramHops); ((DataOp)currBuiltinOp).setFileName(((StringIdentifier)source.getVarParam(DataExpression.IO_FILENAME)).getValue()); @@ -2155,14 +2166,14 @@ private Hop processDataExpression(DataExpression source, DataIdentifier target, throw new ParseException(source.printErrorLocation() + "processDataExpression():: Unknown operation: " + source.getOpCode()); } - + //set identifier meta data (incl dimensions and blocksizes) setIdentifierParams(currBuiltinOp, source.getOutput()); - if( source.getOpCode()==DataExpression.DataOp.READ ) - ((DataOp)currBuiltinOp).setInputBlocksize(target.getBlocksize()); - else if ( source.getOpCode() == DataExpression.DataOp.WRITE ) { - ((DataOp)currBuiltinOp).setPrivacy(hops.get(target.getName()).getPrivacy()); - if( source.getVarParam(DataExpression.ROWBLOCKCOUNTPARAM) != null ) + if (opCode == DataExpression.DataOp.READ) + ((DataOp) currBuiltinOp).setInputBlocksize(target.getBlocksize()); + else if (opCode == DataExpression.DataOp.WRITE) { + ((DataOp) currBuiltinOp).setPrivacy(hops.get(target.getName()).getPrivacy()); + if (source.getVarParam(DataExpression.ROWBLOCKCOUNTPARAM) != null) currBuiltinOp.setBlocksize(Integer.parseInt( source.getVarParam(DataExpression.ROWBLOCKCOUNTPARAM).toString())); } @@ -2171,6 +2182,51 @@ else if ( source.getOpCode() == DataExpression.DataOp.WRITE ) { return currBuiltinOp; } + private HashMap createFederatedParamHopsFromFederatedReadParamHops(HashMap paramHops) { + Hop dataType = paramHops.get("data_type"); + Hop federatedString = paramHops.get("federated"); + if (!(federatedString instanceof LiteralOp) || federatedString.getValueType() != ValueType.STRING) { + throw new DMLRuntimeException("`federated` argument of read has to be a string literal"); + } + paramHops = convertFederatedStringToParamHops(((LiteralOp) federatedString).getStringValue()); + paramHops.put(DataExpression.FED_TYPE, dataType); + return paramHops; + } + + private HashMap convertFederatedStringToParamHops(String federated) { + HashMap paramHops = new HashMap<>(); + try { + JSONObject federatedJson = new JSONObject(federated); + // temporary `FederationMap` -> DataType does not matter + FederationMap federationMap = FederationMap.fromJson(federatedJson, DataType.MATRIX); + List> mapList = federationMap.getMap(); + + Hop[] federatedAddressesList = new Hop[mapList.size()]; + Hop[] federatedRangesList = new Hop[mapList.size() * 2]; + for (int i = 0 ; i < federatedAddressesList.length; ++i) { + Pair fedPair = mapList.get(i); + FederatedRange range = fedPair.getLeft(); + FederatedData data = fedPair.getRight(); + long[] beginDims = range.getBeginDims(); + long[] endDims = range.getEndDims(); + + federatedAddressesList[i] = new LiteralOp(data.getCompleteAddressPath()); + federatedRangesList[i * 2] = new NaryOp(Expression.getTempName(), DataType.LIST, ValueType.INT64, OpOpN.LIST, + new LiteralOp(beginDims[0]), new LiteralOp(beginDims[1])); + federatedRangesList[i * 2 + 1] = new NaryOp(Expression.getTempName(), DataType.LIST, ValueType.INT64, OpOpN.LIST, + new LiteralOp(endDims[0]), new LiteralOp(endDims[1])); + } + + paramHops.put(DataExpression.FED_ADDRESSES, new NaryOp(Expression.getTempName(), DataType.LIST, + ValueType.STRING, OpOpN.LIST, federatedAddressesList)); + paramHops.put(DataExpression.FED_RANGES, new NaryOp(Expression.getTempName(), DataType.LIST, + ValueType.UNKNOWN, OpOpN.LIST, federatedRangesList)); + } catch (JSONException e) { + throw new RuntimeException(e); + } + return paramHops; + } + /** * Construct HOps from parse tree: process BuiltinFunction Expressions in * MultiAssignment Statements. For all other builtin function expressions, diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java b/src/main/java/org/apache/sysds/parser/DataExpression.java index e2e3996cea4..b15310dca2a 100644 --- a/src/main/java/org/apache/sysds/parser/DataExpression.java +++ b/src/main/java/org/apache/sysds/parser/DataExpression.java @@ -103,6 +103,8 @@ public class DataExpression extends DataIdentifier public static final String PRIVACY = "privacy"; public static final String FINE_GRAINED_PRIVACY = "fine_grained_privacy"; + public static final String FEDERATED = "federated"; + // Parameter names relevant to reading/writing delimited/csv files public static final String DELIM_DELIMITER = "sep"; public static final String DELIM_HAS_HEADER_ROW = "header"; @@ -147,7 +149,9 @@ public class DataExpression extends DataIdentifier //Parameters related to dataset name/HDF4 files. HDF5_DATASET_NAME, // Parameters related to privacy - PRIVACY, FINE_GRAINED_PRIVACY)); + PRIVACY, FINE_GRAINED_PRIVACY, + // Parameter for federated data + FEDERATED)); /** Valid parameter names in arguments to read instruction */ public static final Set READ_VALID_PARAM_NAMES = new HashSet<>( @@ -972,6 +976,7 @@ public void validateExpression(HashMap ids, HashMap { + public void toJson(JSONObject mtd) throws JSONException { + JSONArray addressesJson = new JSONArray(); + JSONArray rangesJson = new JSONArray(); + + for (Pair entry : _fedMap) { + FederatedRange range = entry.getLeft(); + FederatedData data = entry.getRight(); + + addressesJson.add(data.getCompleteAddressPath()); + + JSONArray jsonBegin = new JSONArray(); + for (long dim : range.getBeginDims()) + jsonBegin.add(dim); + rangesJson.add(jsonBegin); + + JSONArray jsonEnd = new JSONArray(); + for (long dim : range.getEndDims()) + jsonEnd.add(dim); + rangesJson.add(jsonEnd); + } + JSONObject federatedJson = new JSONObject(); + federatedJson.put(DataExpression.FED_ADDRESSES, addressesJson); + federatedJson.put(DataExpression.FED_RANGES, rangesJson); + mtd.put(DataExpression.FEDERATED, federatedJson); + } + + public static FederationMap fromJson(JSONObject federatedJson, DataType dataType) throws JSONException { + JSONArray addressesJson = federatedJson.getJSONArray(DataExpression.FED_ADDRESSES); + JSONArray rangesJson = federatedJson.getJSONArray(DataExpression.FED_RANGES); + + List> fedMap = new ArrayList<>(); + for (int i = 0 ; i < addressesJson.size(); ++i) { + String[] parsedValues = parseURL(addressesJson.getString(i)); + FederatedData federatedData = new FederatedData(dataType, new InetSocketAddress(parsedValues[0], Integer.parseInt(parsedValues[1])), parsedValues[2]); + + JSONArray beginJson = rangesJson.getJSONArray(i * 2); + long[] begin = new long[beginJson.size()]; + for (int j = 0; j < begin.length; ++j) + begin[j] = beginJson.getLong(j); + + JSONArray endJson = rangesJson.getJSONArray(i * 2 + 1); + long[] end = new long[endJson.size()]; + for (int j = 0; j < end.length; ++j) + end[j] = endJson.getLong(j); + FederatedRange federatedRange = new FederatedRange(begin, end); + fedMap.add(Pair.of(federatedRange, federatedData)); + } + return new FederationMap(fedMap); + } + + private static class MappingTask implements Callable { private final FederatedRange _range; private final FederatedData _data; private final BiFunction _mappingFunction; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java index 4f9ab9972c9..82c624ac8c1 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java @@ -232,7 +232,12 @@ else if((tinst.input1.isMatrix() && ec.getCacheableData(tinst.input1).isFederate } else if(inst instanceof VariableCPInstruction ){ VariableCPInstruction ins = (VariableCPInstruction) inst; - if(ins.getVariableOpcode() == VariableOperationCode.CastAsFrameVariable + if(ins.getVariableOpcode() == VariableOperationCode.Write + && (ins.getInput1().isMatrix() || ins.getInput1().isFrame()) + && ec.getCacheableData(ins.getInput1()).isFederated()){ + fedinst = VariableFEDInstruction.parseInstruction(ins); + } + else if(ins.getVariableOpcode() == VariableOperationCode.CastAsFrameVariable && ins.getInput1().isMatrix() && ec.getCacheableData(ins.getInput1()).isFederatedExcept(FType.BROADCAST)){ fedinst = VariableFEDInstruction.parseInstruction(ins); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java index f62fef63fb1..f3f317cfd2c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java @@ -30,6 +30,7 @@ import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.controlprogram.caching.FrameObject; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; @@ -41,6 +42,8 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode; +import org.apache.sysds.runtime.io.FileFormatProperties; +import org.apache.sysds.runtime.io.WriterFederated; import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.lineage.LineageTraceable; @@ -62,6 +65,9 @@ public static VariableFEDInstruction parseInstruction(VariableCPInstruction cpIn public void processInstruction(ExecutionContext ec) { VariableOperationCode opcode = _in.getVariableOpcode(); switch(opcode) { + case Write: + processWriteInstruction(ec); + break; case CastAsMatrixVariable: processCastAsMatrixVariableInstruction(ec); break; @@ -73,6 +79,23 @@ public void processInstruction(ExecutionContext ec) { } } + private void processWriteInstruction(ExecutionContext ec) { + CacheableData cd = ec.getCacheableData(_in.getInput1()); + if(!cd.isFederated()) + throw new DMLRuntimeException( + "Federated Write: " + "Federated input expected, but invoked w/ non-federated input"); + + String fname = ec.getScalarInput(_in.getInput2().getName(), ValueType.STRING, _in.getInput2().isLiteral()).getStringValue(); + String fmtStr = _in.getInput3().getName(); + Types.FileFormat fmt = Types.FileFormat.safeValueOf(fmtStr); + FileFormatProperties formatProperties = _in.getFormatProperties(); + if( fmt != Types.FileFormat.LIBSVM && fmt != Types.FileFormat.HDF5) { + String desc = ec.getScalarInput(_in.getInput4().getName(), ValueType.STRING, _in.getInput4().isLiteral()).getStringValue(); + formatProperties.setDescription(desc); + } + WriterFederated.write(fname, cd, fmtStr, formatProperties); + } + private void processCastAsMatrixVariableInstruction(ExecutionContext ec) { FrameObject mo1 = ec.getFrameObject(_in.getInput1()); diff --git a/src/main/java/org/apache/sysds/runtime/io/FileFormatProperties.java b/src/main/java/org/apache/sysds/runtime/io/FileFormatProperties.java index 178f5718e9f..604b80fb6bd 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FileFormatProperties.java +++ b/src/main/java/org/apache/sysds/runtime/io/FileFormatProperties.java @@ -19,7 +19,9 @@ package org.apache.sysds.runtime.io; -public class FileFormatProperties +import java.io.Serializable; + +public class FileFormatProperties implements Serializable { private String description; private final int _blen; diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java deleted file mode 100644 index 246cf8e7030..00000000000 --- a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java +++ /dev/null @@ -1,254 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.sysds.runtime.io; - -import java.io.BufferedWriter; -import java.io.DataOutputStream; -import java.io.IOException; -import java.io.OutputStreamWriter; -import java.net.InetSocketAddress; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.apache.hadoop.fs.FSDataInputStream; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.mapred.JobConf; -import org.apache.sysds.common.Types; -import org.apache.sysds.conf.ConfigurationManager; -import org.apache.sysds.conf.DMLConfig; -import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.controlprogram.caching.CacheableData; -import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.controlprogram.federated.*; -import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; -import org.apache.sysds.runtime.instructions.cp.Data; -import org.apache.sysds.runtime.lineage.LineageItem; -import org.apache.sysds.runtime.meta.DataCharacteristics; - -/** - * This class serves as the reader for federated objects. To read the files a mdt file is required. The reader is - * different from the other readers in the since that it does not return a MatrixBlock but a Matrix Object wrapper, - * containing the federated Mapping. - * - * On the Matrix Object the function isFederated() will if called read in the federated locations and instantiate the - * map. The reading is done through this code. - * - * This means in practice that it circumvent the other reading code. See more in: - * - * org.apache.sysds.runtime.controlprogram.caching.MatrixObject.readBlobFromHDFS() - * org.apache.sysds.runtime.controlprogram.caching.CacheableData.isFederated() - * - */ -public class ReaderWriterFederated { - private static final Log LOG = LogFactory.getLog(ReaderWriterFederated.class.getName()); - private static final IDSequence siteUniqueCounter = new IDSequence(true); - - /** - * Read a federated map from disk, It is not initialized before it is used in: - * - * org.apache.sysds.runtime.instructions.fed.InitFEDInstruction - * - * @param file The file to read (defaults to HDFS) - * @param mc The data characteristics of the file, that can be read from the mtd file. - * @return A List of federatedRanges and Federated Data - */ - public static List> read(String file, DataCharacteristics mc) { - LOG.debug("Reading federated map from " + file); - try { - JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); - Path path = new Path(file); - FileSystem fs = IOUtilFunctions.getFileSystem(path, job); - FSDataInputStream data = fs.open(path); - ObjectMapper mapper = new ObjectMapper(); - List obj = mapper.readValue(data, new TypeReference>() { - }); - return obj.stream().map(x -> x.convert()).collect(Collectors.toList()); - } - catch(Exception e) { - throw new DMLRuntimeException("Unable to read federated matrix (" + file + ")", e); - } - } - - /** - * TODO add writing to each of the federated locations so that they also save their matrices. - * - * Currently this would write the federated matrix to disk only locally. - * - * @param file The file to save to, (defaults to HDFS paths) - * @param fedMap The federated map to save. - */ - public static void write(String file, FederationMap fedMap) { - LOG.debug("Writing federated map to " + file); - try { - JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); - Path path = new Path(file); - FileSystem fs = IOUtilFunctions.getFileSystem(path, job); - - fedMap.forEachParallel((range, data) -> { - String siteFilename = Long.toString(siteUniqueCounter.getNextID()) + '_' + path.getName(); - try { - FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, - data.getVarID(), new WriteAtSiteUDF(data.getVarID(), siteFilename))).get(); - data.setFilepath((String) response.getData()[0]); - } catch (Exception e) { - throw new DMLRuntimeException(e); - } - return null; - }); - - DataOutputStream out = fs.create(path, true); - ObjectMapper mapper = new ObjectMapper(); - FederatedDataAddress[] outObjects = parseMap(fedMap.getMap()); - try(BufferedWriter pw = new BufferedWriter(new OutputStreamWriter(out))) { - mapper.writeValue(pw, outObjects); - } - - IOUtilFunctions.deleteCrcFilesFromLocalFileSystem(fs, path); - } - catch(IOException e) { - throw new DMLRuntimeException("Unable to write test federated matrix to (" + file + "): " + e.getMessage()); - } - } - - private static FederatedDataAddress[] parseMap(List> map) { - return map.stream() - .map(e -> new FederatedDataAddress(e.getKey(), e.getValue())) - .toArray(FederatedDataAddress[]::new); - } - - /** - * This class is used for easy serialization from json using Jackson. The warnings are suppressed because the - * setters and getters only is used inside Jackson. - */ - @SuppressWarnings("unused") - private static class FederatedDataAddress { - private Types.DataType _dataType; - private InetSocketAddress _address; - private String _filepath; - private long[] _begin; - private long[] _end; - - public FederatedDataAddress() { - } - - protected FederatedDataAddress(FederatedRange fr, FederatedData fd) { - _dataType = fd.getDataType(); - _address = fd.getAddress(); - _filepath = fd.getFilepath(); - _begin = fr.getBeginDims(); - _end = fr.getEndDims(); - } - - protected Pair convert() { - FederatedRange fr = new FederatedRange(_begin, _end); - FederatedData fd = new FederatedData(_dataType, _address, _filepath); - return new ImmutablePair<>(fr, fd); - } - - public String getFilepath() { - return _filepath; - } - - public void setFilepath(String filePath) { - _filepath = filePath; - } - - public Types.DataType getDataType() { - return _dataType; - } - - public void setDataType(Types.DataType dataType) { - _dataType = dataType; - } - - public InetSocketAddress getAddress() { - return _address; - } - - public void setAddress(InetSocketAddress address) { - _address = address; - } - - public long[] getBegin() { - return _begin; - } - - public void setBegin(long[] begin) { - _begin = begin; - } - - public long[] getEnd() { - return _end; - } - - public void setEnd(long[] end) { - _end = end; - } - - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append(_dataType); - sb.append(" "); - sb.append(_address); - sb.append(" "); - sb.append(_filepath); - sb.append(" "); - sb.append(Arrays.toString(_begin)); - sb.append(" "); - sb.append(Arrays.toString(_end)); - return sb.toString(); - } - } - - private static class WriteAtSiteUDF extends FederatedUDF { - private static final long serialVersionUID = -6645546954618784216L; - - private final String _filename; - - public WriteAtSiteUDF(long input, String filename) { - super(new long[] {input}); - _filename = filename; - } - - @Override - public FederatedResponse execute(ExecutionContext ec, Data... data) { - CacheableData cd = (CacheableData) data[0]; - String tmpDir = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.LOCAL_TMP_DIR); - Path p = new Path(tmpDir + '/' + _filename); - cd.exportData(p.toString(), null); - return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, p.toString()); - } - - @Override - public Pair getLineageItem(ExecutionContext ec) { - return null; - } - } -} diff --git a/src/main/java/org/apache/sysds/runtime/io/WriterFederated.java b/src/main/java/org/apache/sysds/runtime/io/WriterFederated.java new file mode 100644 index 00000000000..d9b7c0ab11a --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/io/WriterFederated.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.runtime.io; + +import java.io.IOException; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapred.JobConf; +import org.apache.sysds.common.Types; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.controlprogram.caching.FrameObject; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.*; +import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; +import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.lineage.LineageItem; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; +import org.apache.sysds.runtime.util.HDFSTool; + +/** + * This class serves as the writer for federated objects. The workers are tasked to write their part locally, while + * the CP writes a mtd file locally containing the addresses and ranges of the workers, enabling `read()` initialization + * of federated objects. + * + * This method, in comparison to other workers, also directly writes the MTD file, this is because it is important + * that the mtd file is written *AFTER* the workers are finished writing, because their local paths depend on their + * local configuration. They write into their specified tmp directory. + */ +public class WriterFederated { + private static final Log LOG = LogFactory.getLog(WriterFederated.class.getName()); + private static final IDSequence siteUniqueCounter = new IDSequence(true); + + /** + * Write the federated partitions on the workers and create a mtd file locally to be used to re-read the federate + * object. + * + * @param file The file to save to, (defaults to HDFS paths) + * @param cd The federated object to save. + * @param outputFormat The output format of the file + * @param fileFormatProperties The file format properties + */ + public static void write(String file, CacheableData cd, String outputFormat, FileFormatProperties fileFormatProperties) { + LOG.debug("Writing federated map to " + file); + try { + JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); + Path path = new Path(file); + + FederationMap newFedMap = cd.getFedMapping().mapParallel(cd.getFedMapping().getID(), (range, data) -> { + String siteFilename = Long.toString(siteUniqueCounter.getNextID()) + '_' + path.getName(); + try { + FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, + data.getVarID(), new WriteAtSiteUDF(data.getVarID(), siteFilename, outputFormat, fileFormatProperties))).get(); + data.setFilepath((String) response.getData()[0]); + } catch (Exception e) { + throw new DMLRuntimeException(e); + } + return null; + }); + // updated filepath + cd.setFedMapping(newFedMap); + + FileSystem fs = IOUtilFunctions.getFileSystem(path, job); + IOUtilFunctions.deleteCrcFilesFromLocalFileSystem(fs, path); + if (cd instanceof MatrixObject) { + HDFSTool.writeMetaDataFile(file + ".mtd", cd.getValueType(), cd.getDataCharacteristics(), + Types.FileFormat.safeValueOf(outputFormat), cd.getPrivacyConstraint(), cd.getFedMapping()); + } else if (cd instanceof FrameObject) { + HDFSTool.writeMetaDataFile(file + ".mtd", null, ((FrameObject) cd).getSchema(), + cd.getDataType(), cd.getDataCharacteristics(), Types.FileFormat.safeValueOf(outputFormat), + cd.getPrivacyConstraint(), cd.getFedMapping()); + } else { + throw new DMLRuntimeException("TensorObject not yet supported by " + WriterFederated.class.getSimpleName()); + } + } + catch(IOException e) { + throw new DMLRuntimeException("Unable to write test federated matrix to (" + file + "): " + e.getMessage()); + } + } + + private static class WriteAtSiteUDF extends FederatedUDF { + private static final long serialVersionUID = -6645546954618784216L; + + private final String _filename; + private final String _outputFormat; + private final FileFormatProperties _fileFormatProperties; + + public WriteAtSiteUDF(long input, String filename, String outputFormat, FileFormatProperties fileFormatProperties) { + super(new long[] {input}); + _filename = filename; + _outputFormat = outputFormat; + _fileFormatProperties = fileFormatProperties; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + CacheableData cd = (CacheableData) data[0]; + // Write the file to the local tmp + String tmpDir = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.LOCAL_TMP_DIR); + Path p = new Path(tmpDir + '/' + _filename); + cd.exportData(p.toString(), _outputFormat, _fileFormatProperties); + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, p.toString()); + } + + @Override + public Pair getLineageItem(ExecutionContext ec) { + return null; + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java b/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java index cb007d27e9a..5803b0b4746 100644 --- a/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java +++ b/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java @@ -61,6 +61,8 @@ public class MetaDataAll extends DataIdentifier { protected String _delim = DataExpression.DEFAULT_DELIM_DELIMITER; protected boolean _hasHeader = false; protected boolean _sparseDelim = DataExpression.DEFAULT_DELIM_SPARSE; + // Federation map for when federated object is read from disk + protected String _federatedString = null; public MetaDataAll() { // do nothing @@ -184,7 +186,8 @@ private void parseMetaDataParam(Object key, Object val) } else setHasHeader(false); - case DataExpression.DELIM_SPARSE: setSparseDelim((boolean) val); + case DataExpression.DELIM_SPARSE: setSparseDelim((boolean) val); break; + case DataExpression.FEDERATED: setFederatedString(val.toString()); break; } } @@ -224,6 +227,14 @@ public void setSparseDelim(boolean sparseDelim) { _sparseDelim = sparseDelim; } + public void setFederatedString(String federatedString) { + _federatedString = federatedString; + } + + public String getFederatedString() { + return _federatedString; + } + public void setHasHeader(boolean hasHeader) { _hasHeader = hasHeader; } diff --git a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java index a16f3237338..6e459bfc363 100644 --- a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java +++ b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java @@ -33,6 +33,7 @@ import org.apache.hadoop.fs.permission.FsPermission; import org.apache.hadoop.io.IOUtils; import org.apache.hadoop.mapred.JobConf; +import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; import org.apache.wink.json4j.JSONException; import org.apache.wink.json4j.OrderedJSONObject; @@ -391,7 +392,12 @@ public static void writeMetaDataFile(String mtdfile, ValueType vt, DataCharacter public static void writeMetaDataFile(String mtdfile, ValueType vt, DataCharacteristics mc, FileFormat fmt, PrivacyConstraint privacyConstraint) throws IOException { - writeMetaDataFile(mtdfile, vt, null, DataType.MATRIX, mc, fmt, null, privacyConstraint); + writeMetaDataFile(mtdfile, vt, mc, fmt, privacyConstraint, null); + } + + public static void writeMetaDataFile(String mtdfile, ValueType vt, DataCharacteristics mc, FileFormat fmt, PrivacyConstraint privacyConstraint, FederationMap federationMap) + throws IOException { + writeMetaDataFile(mtdfile, vt, null, DataType.MATRIX, mc, fmt, null, privacyConstraint, federationMap); } public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics mc, FileFormat fmt) @@ -401,7 +407,12 @@ public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] s public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics mc, FileFormat fmt, PrivacyConstraint privacyConstraint) throws IOException { - writeMetaDataFile(mtdfile, vt, schema, dt, mc, fmt, null, privacyConstraint); + writeMetaDataFile(mtdfile, vt, schema, dt, mc, fmt, null, privacyConstraint, null); + } + + public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics mc, FileFormat fmt, PrivacyConstraint privacyConstraint, FederationMap federationMap) + throws IOException { + writeMetaDataFile(mtdfile, vt, schema, dt, mc, fmt, null, privacyConstraint, federationMap); } public static void writeMetaDataFile(String mtdfile, ValueType vt, DataCharacteristics dc, FileFormat fmt, FileFormatProperties formatProperties) @@ -411,24 +422,30 @@ public static void writeMetaDataFile(String mtdfile, ValueType vt, DataCharacter public static void writeMetaDataFile(String mtdfile, ValueType vt, DataCharacteristics dc, FileFormat fmt, FileFormatProperties formatProperties, PrivacyConstraint privacyConstraint) throws IOException { - writeMetaDataFile(mtdfile, vt, null, DataType.MATRIX, dc, fmt, formatProperties, privacyConstraint); + writeMetaDataFile(mtdfile, vt, null, DataType.MATRIX, dc, fmt, formatProperties, privacyConstraint, null); } public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics dc, FileFormat fmt, FileFormatProperties formatProperties) throws IOException { - writeMetaDataFile(mtdfile, vt, schema, dt, dc, fmt, formatProperties, null); + writeMetaDataFile(mtdfile, vt, schema, dt, dc, fmt, formatProperties, null, null); } public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics dc, - FileFormat fmt, FileFormatProperties formatProperties, PrivacyConstraint privacyConstraint) + FileFormat fmt, FileFormatProperties formatProperties, PrivacyConstraint privacyConstraint) + throws IOException { + writeMetaDataFile(mtdfile, vt, schema, dt, dc, fmt, formatProperties, privacyConstraint, null); + } + + public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics dc, + FileFormat fmt, FileFormatProperties formatProperties, PrivacyConstraint privacyConstraint, FederationMap federationMap) throws IOException { Path path = new Path(mtdfile); FileSystem fs = IOUtilFunctions.getFileSystem(path); try( BufferedWriter br = new BufferedWriter(new OutputStreamWriter(fs.create(path,true))) ) { - String mtd = metaDataToString(vt, schema, dt, dc, fmt, formatProperties, privacyConstraint); + String mtd = metaDataToString(vt, schema, dt, dc, fmt, formatProperties, privacyConstraint, federationMap); br.write(mtd); } catch (Exception e) { throw new IOException("Error creating and writing metadata JSON file", e); @@ -447,7 +464,7 @@ public static void writeScalarMetaDataFile(String mtdfile, ValueType vt, Privacy Path path = new Path(mtdfile); FileSystem fs = IOUtilFunctions.getFileSystem(path); try( BufferedWriter br = new BufferedWriter(new OutputStreamWriter(fs.create(path,true))) ) { - String mtd = metaDataToString(vt, null, DataType.SCALAR, null, FileFormat.TEXT, null, privacyConstraint); + String mtd = metaDataToString(vt, null, DataType.SCALAR, null, FileFormat.TEXT, null, privacyConstraint, null); br.write(mtd); } catch (Exception e) { @@ -456,7 +473,8 @@ public static void writeScalarMetaDataFile(String mtdfile, ValueType vt, Privacy } public static String metaDataToString(ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics dc, - FileFormat fmt, FileFormatProperties formatProperties, PrivacyConstraint privacyConstraint) throws JSONException, DMLRuntimeException + FileFormat fmt, FileFormatProperties formatProperties, PrivacyConstraint privacyConstraint, + FederationMap federationMap) throws JSONException, DMLRuntimeException { OrderedJSONObject mtd = new OrderedJSONObject(); // maintain order in output file @@ -524,6 +542,10 @@ public static String metaDataToString(ValueType vt, ValueType[] schema, DataType privacyConstraint.toJson(mtd); } + if ( federationMap != null ){ + federationMap.toJson(mtd); + } + return mtd.toString(4); // indent with 4 spaces } diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index c5f7d1a54bb..ddda5202096 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -76,7 +76,7 @@ import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; import org.apache.sysds.runtime.io.FrameReader; import org.apache.sysds.runtime.io.FrameReaderFactory; -import org.apache.sysds.runtime.io.ReaderWriterFederated; +import org.apache.sysds.runtime.io.WriterFederated; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixValue; @@ -608,21 +608,7 @@ protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boo return matrix; } - protected void writeInputFederatedWithMTD(String name, MatrixObject fm, PrivacyConstraint privacyConstraint){ - writeFederatedInputMatrix(name, fm.getFedMapping()); - MatrixCharacteristics mc = (MatrixCharacteristics)fm.getDataCharacteristics(); - try { - String completeMTDPath = baseDirectory + INPUT_DIR + name + ".mtd"; - HDFSTool.writeMetaDataFile(completeMTDPath, ValueType.FP64, mc, FileFormat.FEDERATED, privacyConstraint); - } - catch(IOException e) { - e.printStackTrace(); - throw new RuntimeException(e); - } - - } - - protected void writeFederatedInputMatrix(String name, FederationMap fedMap){ + protected void writeInputFederated(String name, MatrixObject fm, PrivacyConstraint privacyConstraint){ String completePath = baseDirectory + INPUT_DIR + name; try { cleanupExistingData(baseDirectory + INPUT_DIR + name, false); @@ -632,7 +618,12 @@ protected void writeFederatedInputMatrix(String name, FederationMap fedMap){ throw new RuntimeException(e); } - ReaderWriterFederated.write(completePath, fedMap); + // privacy constraint are read automatically by write from federated object -> temporarily replace + PrivacyConstraint tmp = fm.getPrivacyConstraint(); + fm.setPrivacyConstraints(privacyConstraint); + WriterFederated.write(completePath, fm, FileFormat.defaultFormatString(), null); + fm.setPrivacyConstraints(tmp); + inputDirectories.add(baseDirectory + INPUT_DIR + name); } @@ -694,7 +685,7 @@ false, new MatrixCharacteristics((long) examplesForWorkerI, ncol, federatedMatrixObject.setFedMapping(new FederationMap(FederationUtils.getNextFedDataID(), fedHashMap)); federatedMatrixObject.getFedMapping().setType(FType.ROW); - writeInputFederatedWithMTD(name, federatedMatrixObject, privacyConstraint); + writeInputFederated(name, federatedMatrixObject, privacyConstraint); } protected double[][] generateBalancedFederatedRowRanges(int numFederatedWorkers, int dataSetSize) { diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index 1096a5147e8..0aa86e935fa 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -62,9 +62,11 @@ import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.SequenceFile.Writer; +import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.FileFormat; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.functionobjects.Builtin; @@ -79,8 +81,10 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.meta.MetaDataAll; import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.wink.json4j.JSONObject; import org.junit.Assert; //import jcuda.runtime.JCuda; @@ -328,40 +332,91 @@ private static void readActualAndExpectedFile(ValueType[] schema, String expecte Path outDirectory = new Path(actualDir); Path compareFile = new Path(expectedFile); FileSystem fs = IOUtilFunctions.getFileSystem(outDirectory, conf); - FSDataInputStream fsin = fs.open(compareFile); + Map readExpected = readJIVFile(compareFile, fs, schema); + expectedValues.putAll(readExpected); + Map readActual = readJIVActualFile(outDirectory, schema); + actualValues.putAll(readActual); + } + catch(IOException e) { + fail("unable to read file: " + e.getMessage()); + } + } - try(BufferedReader compareIn = new BufferedReader(new InputStreamReader(fsin))) { - String line; - while((line = compareIn.readLine()) != null) { - StringTokenizer st = new StringTokenizer(line, " "); - int i = Integer.parseInt(st.nextToken()); - int j = Integer.parseInt(st.nextToken()); - ValueType vt = (schema != null) ? schema[j - 1] : ValueType.FP64; - Object obj = UtilFunctions.stringToObject(vt, st.nextToken()); - expectedValues.put(new CellIndex(i, j), obj); + private static Map readJIVActualFile(Path actualPath, ValueType[] schema) { + try { + FileSystem fs = IOUtilFunctions.getFileSystem(actualPath, conf); + Map values = tryReadFederatedJIVActualFile(actualPath, schema); + if (values == null) { + values = new HashMap<>(); + FileStatus[] outFiles = fs.listStatus(actualPath); + + for (FileStatus file : outFiles) { + Map partValues = readJIVFile(file.getPath(), fs, schema); + values.putAll(partValues); } } + return values; + } catch (IOException e) { + fail("unable to read file: " + e.getMessage()); + return null; + } + } - FileStatus[] outFiles = fs.listStatus(outDirectory); + private static Map tryReadFederatedJIVActualFile(Path actualPath, ValueType[] schema) { + Map values = new HashMap<>(); + try { + MetaDataAll mtd = new MetaDataAll(actualPath + ".mtd", false, true); + if (mtd.getFederatedString() == null) + return null; - for(FileStatus file : outFiles) { - FSDataInputStream fsout = fs.open(file.getPath()); - try(BufferedReader outIn = new BufferedReader(new InputStreamReader(fsout))) { - String line; - while((line = outIn.readLine()) != null) { - StringTokenizer st = new StringTokenizer(line, " "); - int i = Integer.parseInt(st.nextToken()); - int j = Integer.parseInt(st.nextToken()); - ValueType vt = (schema != null) ? schema[j - 1] : ValueType.FP64; - Object obj = UtilFunctions.stringToObject(vt, st.nextToken()); - actualValues.put(new CellIndex(i, j), obj); + JSONObject json = new JSONObject(mtd.getFederatedString()); + FederationMap map = FederationMap.fromJson(json, Types.DataType.MATRIX); + + FileSystem fs = IOUtilFunctions.getFileSystem(actualPath, conf); + map.forEachParallel((range, data) -> { + int[] beginDims = range.getBeginDimsInt(); + int beginRow = beginDims[0]; + int beginCol = beginDims[1]; + int endCol = range.getEndDimsInt()[1]; + + try { + ValueType[] partSchema = schema == null ? null : Arrays.copyOfRange(schema, beginCol, endCol); + Map partMap = readJIVFile(new Path(data.getFilepath()), fs, partSchema); + + // speed is not relevant + synchronized (values) { + for (Map.Entry entry : partMap.entrySet()) { + CellIndex key = entry.getKey(); + Object value = entry.getValue(); + values.put(new CellIndex(key.row + beginRow, key.column + beginCol), value); + } } + } catch (IOException e) { + throw new RuntimeException(e); } - } + return null; + }); + } catch (Exception ex) { + return null; } - catch(IOException e) { - fail("unable to read file: " + e.getMessage()); + return values; + } + + private static Map readJIVFile(Path filename, FileSystem fs, ValueType[] schema) throws IOException { + HashMap values = new HashMap<>(); + FSDataInputStream fsin = fs.open(filename); + try (BufferedReader compareIn = new BufferedReader(new InputStreamReader(fsin))) { + String line; + while ((line = compareIn.readLine()) != null) { + StringTokenizer st = new StringTokenizer(line, " "); + int i = Integer.parseInt(st.nextToken()); + int j = Integer.parseInt(st.nextToken()); + ValueType vt = (schema != null) ? schema[j - 1] : ValueType.FP64; + Object obj = UtilFunctions.stringToObject(vt, st.nextToken()); + values.put(new CellIndex(i, j), obj); + } } + return values; } /** @@ -479,21 +534,55 @@ public static HashMap readDMLMatrixFromHDFS(String filePath) try { Path outDirectory = new Path(filePath); - FileSystem fs = IOUtilFunctions.getFileSystem(outDirectory, conf); + if (tryReadFederatedMatrix(filePath, expectedValues)) + return expectedValues; + else { + FileSystem fs = IOUtilFunctions.getFileSystem(outDirectory, conf); - FileStatus[] outFiles = fs.listStatus(outDirectory); - for (FileStatus file : outFiles) { - FSDataInputStream outIn = fs.open(file.getPath()); - readValuesFromFileStream(outIn, expectedValues); + FileStatus[] outFiles = fs.listStatus(outDirectory); + for (FileStatus file : outFiles) { + FSDataInputStream outIn = fs.open(file.getPath()); + readValuesFromFileStream(outIn, expectedValues); + } } } catch (IOException e) { - assertTrue("could not read from file " + filePath+": "+e.getMessage(), false); + fail("could not read from file " + filePath + ": " + e.getMessage()); } return expectedValues; } + private static boolean tryReadFederatedMatrix(String filePath, HashMap values) { + try { + MetaDataAll mtd = new MetaDataAll(filePath + ".mtd", false, true); + if (mtd.getFederatedString() == null) + return false; + JSONObject json = new JSONObject(mtd.getFederatedString()); + FederationMap map = FederationMap.fromJson(json, Types.DataType.MATRIX); + map.forEachParallel((range, data) -> { + HashMap partMap = readDMLMatrixFromHDFS(data.getFilepath()); + int[] beginDims = range.getBeginDimsInt(); + int beginRow = beginDims[0]; + int beginCol = beginDims[1]; + // speed is not relevant + synchronized (values) { + for (Map.Entry entry : partMap.entrySet()) { + CellIndex key = entry.getKey(); + double value = entry.getValue(); + values.put(new CellIndex(key.row + beginRow, key.column + beginCol), value); + } + } + return null; + }); + } + catch (Exception ex) { + values.clear(); + return false; + } + return true; + } + /** * Reads values from a matrix file in OS's FS in R format * diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderCSV.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderCSV.java index e8e1b31a4a4..d9358fed8f7 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderCSV.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderCSV.java @@ -80,8 +80,8 @@ public void federatedRead( boolean header) { // Thread.sleep(10000); MatrixObject fed = FederatedTestObjectConstructor.constructFederatedInput(dim, dim, blocksize, host, begins, - ends, new int[] {port1}, new String[] {input("X1")}, input("X.json")); - writeInputFederatedWithMTD("X.json", fed, null); + ends, new int[] {port1}, new String[] {input("X1")}, input("X")); + writeInputFederated("X", fed, null); // Run reference dml script with normal matrix @@ -94,7 +94,7 @@ public void federatedRead( boolean header) { // Run federated fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + ".dml"; - programArgs = new String[] {"-stats", "-args", input("X.json")}; + programArgs = new String[] {"-stats", "-args", input("X")}; String out = runTest(null).toString(); Assert.assertTrue(heavyHittersContainsString("fed_uak+")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java index ff68c8328e0..8c8c239a159 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java @@ -107,7 +107,7 @@ public void federatedRead(Types.ExecMode execMode, int workerCount) { workerCount == 2 ? new int[] {port1, port2} : new int[] {port1}, workerCount == 2 ? new String[] {input("X1"), input("X2")} : new String[] {input("X1")}, input("X.json")); - writeInputFederatedWithMTD("X.json", fed, null); + writeInputFederated("X", fed, null); // Run reference dml script with normal matrix if(workerCount == 1) { @@ -126,7 +126,7 @@ public void federatedRead(Types.ExecMode execMode, int workerCount) { // Run federated fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + ".dml"; - programArgs = new String[] {"-stats", "-args", input("X.json")}; + programArgs = new String[] {"-stats", "-args", input("X")}; String out = runTest(null).toString(); Assert.assertTrue(heavyHittersContainsString("fed_uak+")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java index cce7a5f4c73..093d6bd46bd 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java @@ -104,12 +104,12 @@ public void federatedRead(Types.ExecMode execMode) { try { MatrixObject fed = FederatedTestObjectConstructor.constructFederatedInput( rows, cols, blocksize, host, begins, ends, new int[] {port1, port2}, - new String[] {input("X1"), input("X2")}, input("X.json")); + new String[] {input("X1"), input("X2")}, input("X")); //FIXME: reset avoids deadlock on reference script //(because federated matrix creation added to federated sites - blocks on clear) //However, there seems to be a regression regarding the SSL handling in general FederatedData.resetFederatedSites(); - writeInputFederatedWithMTD("X.json", fed, null); + writeInputFederated("X", fed, null); // Run reference dml script with normal matrix fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + (rowPartitioned ? "Row" : "Col") + "2Reference.dml"; @@ -118,7 +118,7 @@ public void federatedRead(Types.ExecMode execMode) { // Run federated fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + ".dml"; - programArgs = new String[] {"-stats", "-args", input("X.json")}; + programArgs = new String[] {"-stats", "-args", input("X")}; String out = runTest(null).toString(); Assert.assertTrue(heavyHittersContainsString("fed_uak+")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedTimeoutTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedTimeoutTest.java index e8c1ed74e15..288aabb04ff 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedTimeoutTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedTimeoutTest.java @@ -109,8 +109,8 @@ public void federatedSinglenodeRead() { ends, new int[] {port1, port2}, new String[] {input("X1"), input("X2")}, - input("X.json")); - writeInputFederatedWithMTD("X.json", fed, null); + input("X")); + writeInputFederated("X", fed, null); } catch(DMLRuntimeException e) { diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java index d8bb7431472..61d293501a8 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java @@ -91,12 +91,12 @@ public void federatedWrite(ExecMode execMode) { // Run reader and write a federated json to enable the rest of the test fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/FederatedReaderTestCreate.dml"; programArgs = new String[] {"-stats", "-explain", "-args", input("X1"), input("X2"), port1 + "", port2 + "", - input("X.json")}; + input("X")}; runTest(null); // Run reference dml script with normal matrix fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/FederatedReaderTest.dml"; - programArgs = new String[] {"-stats", "-args", input("X.json")}; + programArgs = new String[] {"-stats", "-args", input("X")}; String out = runTest(null).toString(); Assert.assertTrue(heavyHittersContainsString("fed_uak+")); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWriteTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWriteTest.java index f887acda45c..5beed7c1a45 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWriteTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWriteTest.java @@ -20,6 +20,8 @@ package org.apache.sysds.test.functions.federated.primitives; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -28,16 +30,23 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.meta.MetaDataAll; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; +import org.apache.wink.json4j.JSONArray; +import org.apache.wink.json4j.JSONException; +import org.apache.wink.json4j.JSONObject; import org.junit.Assert; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import static org.apache.sysds.runtime.instructions.fed.InitFEDInstruction.parseURL; + @RunWith(value = Parameterized.class) @net.jcip.annotations.NotThreadSafe public class FederatedWriteTest extends AutomatedTestBase { @@ -138,15 +147,18 @@ public void federatedWrite(Types.ExecMode execMode, Types.ValueType[] schema) { DMLScript.USE_LOCAL_SPARK_CONFIG = true; } fullDMLScriptName = HOME + testName + ".dml"; - programArgs = new String[] {"-explain", "-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), + programArgs = new String[] {"-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows, "cols=" + cols, "tmp=" + output("T")}; runTest(true, false, null, -1); - Assert.assertSame(getMetaData("T").getFileFormat(), Types.FileFormat.FEDERATED); + Assert.assertTrue(Files.notExists(Path.of(output("T")))); + Assert.assertTrue(Files.exists(Path.of(output("T.mtd")))); + Assert.assertNotNull(getMetaData("T").getFederatedString()); fullDMLScriptName = HOME + TEST_NAME + "Read.dml"; - programArgs = new String[] {"-explain", "-nvargs", "tmp=" + output("T"), "out=" + output("B")}; + programArgs = new String[]{"-stats", "-nvargs", "tmp=" + output("T"), "out=" + output("B")}; runTest(true, false, null, -1); + Assert.assertTrue(heavyHittersContainsString("fed_fedinit", "fed_uack+")); // compare via files if(schema != null) compareResults(schema); diff --git a/src/test/scripts/functions/federated/FederatedWriteFrameTest.dml b/src/test/scripts/functions/federated/FederatedWriteFrameTest.dml index 489fd6c2410..3960b3f6198 100644 --- a/src/test/scripts/functions/federated/FederatedWriteFrameTest.dml +++ b/src/test/scripts/functions/federated/FederatedWriteFrameTest.dml @@ -21,4 +21,4 @@ A = federated(type="Frame", addresses=list($in, $in), ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols))) -write(A, $tmp, format="federated") +write(A, $tmp) diff --git a/src/test/scripts/functions/federated/FederatedWriteMatrixTest.dml b/src/test/scripts/functions/federated/FederatedWriteMatrixTest.dml index e803a630ad5..e1a4420f729 100644 --- a/src/test/scripts/functions/federated/FederatedWriteMatrixTest.dml +++ b/src/test/scripts/functions/federated/FederatedWriteMatrixTest.dml @@ -22,4 +22,4 @@ A = federated(addresses=list($in, $in), ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols))) B = A * 42.0 -write(B, $tmp, format="federated") +write(B, $tmp) diff --git a/src/test/scripts/functions/federated/FederatedWriteMatrixTestReference.dml b/src/test/scripts/functions/federated/FederatedWriteMatrixTestReference.dml index f3839f9c9d7..1bb92259f91 100644 --- a/src/test/scripts/functions/federated/FederatedWriteMatrixTestReference.dml +++ b/src/test/scripts/functions/federated/FederatedWriteMatrixTestReference.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- A = rbind(read($in), read($in)) -B = A * 42.0 +B = colSums(A * 42.0) write(B, $out) diff --git a/src/test/scripts/functions/federated/FederatedWriteTestRead.dml b/src/test/scripts/functions/federated/FederatedWriteTestRead.dml index c05fc7d99ef..298a489fd72 100644 --- a/src/test/scripts/functions/federated/FederatedWriteTestRead.dml +++ b/src/test/scripts/functions/federated/FederatedWriteTestRead.dml @@ -19,5 +19,6 @@ # #------------------------------------------------------------- -A = read($tmp) +# Combine federated objects +A = colSums(read($tmp)) write(A, $out) diff --git a/src/test/scripts/functions/federated/io/FederatedReaderTestCreate.dml b/src/test/scripts/functions/federated/io/FederatedReaderTestCreate.dml index d2e8a471abd..858ba1971e6 100644 --- a/src/test/scripts/functions/federated/io/FederatedReaderTestCreate.dml +++ b/src/test/scripts/functions/federated/io/FederatedReaderTestCreate.dml @@ -23,4 +23,4 @@ X1 = read($1) X2 = read($2) X = federated(addresses=list("LocalHost:" +$3 + "/" +$1, "LocalHost:" +$4+ "/" +$2), ranges=list(list(0, 0), list(nrow(X1), ncol(X1)), list(nrow(X1), 0), list(nrow(X1) + nrow(X2), ncol(X1)))) -write(X, $5, format="federated") +write(X, $5) From 4d1cb504b7dbaf127279618b7a66d98db64bfe32 Mon Sep 17 00:00:00 2001 From: Kevin Innerebner Date: Sat, 30 Jul 2022 16:49:39 +0200 Subject: [PATCH 5/6] Some minor cleanups and fixes --- .../controlprogram/caching/CacheableData.java | 9 ++++---- .../federated/FederationMap.java | 4 ++-- .../apache/sysds/runtime/util/HDFSTool.java | 2 +- .../federated/test_federated_matrix_mult.py | 10 ++++----- .../tests/federated/test_federated_read.py | 6 ++--- .../java/org/apache/sysds/test/TestUtils.java | 22 +++++++++---------- 6 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java index 43a4027c616..a95161cde06 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java @@ -889,11 +889,12 @@ else if(!federatedWrite) throw new DMLRuntimeException("Reading of " + _hdfsFileName + " ("+hashCode()+") failed.", e); } } - //get object from cache - if(federatedWrite) { + + if (federatedWrite) { + // b) write the matrix WriterFederated.write(fName, this, outputFormat, formatProperties); - } - else { + } else { + //get object from cache if (_data == null) getCache(); acquire(false, _data == null); //incl. read matrix if evicted diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index 7e2c9d281a1..cfb12589889 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -694,7 +694,7 @@ public void reverseFedMap() { } } - public void toJson(JSONObject mtd) throws JSONException { + public JSONObject toJson() throws JSONException { JSONArray addressesJson = new JSONArray(); JSONArray rangesJson = new JSONArray(); @@ -717,7 +717,7 @@ public void toJson(JSONObject mtd) throws JSONException { JSONObject federatedJson = new JSONObject(); federatedJson.put(DataExpression.FED_ADDRESSES, addressesJson); federatedJson.put(DataExpression.FED_RANGES, rangesJson); - mtd.put(DataExpression.FEDERATED, federatedJson); + return federatedJson; } public static FederationMap fromJson(JSONObject federatedJson, DataType dataType) throws JSONException { diff --git a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java index 6e459bfc363..5220a2a479f 100644 --- a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java +++ b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java @@ -543,7 +543,7 @@ public static String metaDataToString(ValueType vt, ValueType[] schema, DataType } if ( federationMap != null ){ - federationMap.toJson(mtd); + mtd.put(DataExpression.FEDERATED, federationMap.toJson()); } return mtd.toString(4); // indent with 4 spaces diff --git a/src/main/python/tests/federated/test_federated_matrix_mult.py b/src/main/python/tests/federated/test_federated_matrix_mult.py index 6551e11356b..54011c558cd 100644 --- a/src/main/python/tests/federated/test_federated_matrix_mult.py +++ b/src/main/python/tests/federated/test_federated_matrix_mult.py @@ -71,21 +71,21 @@ class TestFederatedAggFn(unittest.TestCase): def setUpClass(cls): cls.sds = SystemDSContext() cls.sds.federated([fed1], [([0, 0], [dim, dim])] - ).write(fed1_file, format="federated").compute() + ).write(fed1_file).compute() cls.sds.federated([fed1, fed2], [ ([0, 0], [dim, dim]), - ([0, dim], [dim, dim*2])]).write(fed_c2_file, format="federated").compute() + ([0, dim], [dim, dim*2])]).write(fed_c2_file).compute() cls.sds.federated([fed1, fed2, fed3], [ ([0, 0], [dim, dim]), ([0, dim], [dim, dim*2]), - ([0, dim*2], [dim, dim*3])]).write(fed_c3_file, format="federated").compute() + ([0, dim*2], [dim, dim*3])]).write(fed_c3_file).compute() cls.sds.federated([fed1, fed2], [ ([0, 0], [dim, dim]), - ([dim, 0], [dim*2, dim])]).write(fed_r2_file, format="federated").compute() + ([dim, 0], [dim*2, dim])]).write(fed_r2_file).compute() cls.sds.federated([fed1, fed2, fed3], [ ([0, 0], [dim, dim]), ([dim, 0], [dim*2, dim]), - ([dim*2, 0], [dim*3, dim])]).write(fed_r3_file, format="federated").compute() + ([dim*2, 0], [dim*3, dim])]).write(fed_r3_file).compute() @classmethod def tearDownClass(cls): diff --git a/src/main/python/tests/federated/test_federated_read.py b/src/main/python/tests/federated/test_federated_read.py index 6a3c28c28fa..21ef89fd54d 100644 --- a/src/main/python/tests/federated/test_federated_read.py +++ b/src/main/python/tests/federated/test_federated_read.py @@ -66,14 +66,14 @@ class TestFederatedAggFn(unittest.TestCase): def setUpClass(cls): cls.sds = SystemDSContext() cls.sds.federated([fed1], [ - ([0, 0], [dim, dim])]).write(fed1_file, format="federated").compute() + ([0, 0], [dim, dim])]).write(fed1_file).compute() cls.sds.federated([fed1, fed2], [ ([0, 0], [dim, dim]), - ([0, dim], [dim, dim*2])]).write(fed2_file, format="federated").compute() + ([0, dim], [dim, dim*2])]).write(fed2_file).compute() cls.sds.federated([fed1, fed2, fed3], [ ([0, 0], [dim, dim]), ([0, dim], [dim, dim*2]), - ([0, dim*2], [dim, dim*3])]).write(fed3_file, format="federated").compute() + ([0, dim*2], [dim, dim*3])]).write(fed3_file).compute() @classmethod def tearDownClass(cls): diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index 0aa86e935fa..cdbe4263b79 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -332,12 +332,13 @@ private static void readActualAndExpectedFile(ValueType[] schema, String expecte Path outDirectory = new Path(actualDir); Path compareFile = new Path(expectedFile); FileSystem fs = IOUtilFunctions.getFileSystem(outDirectory, conf); - Map readExpected = readJIVFile(compareFile, fs, schema); + Map readExpected = tryReadFederatedJIVFile(compareFile, schema); + if (readExpected == null) + readExpected = readJIVFile(compareFile, fs, schema); expectedValues.putAll(readExpected); Map readActual = readJIVActualFile(outDirectory, schema); actualValues.putAll(readActual); - } - catch(IOException e) { + } catch (IOException e) { fail("unable to read file: " + e.getMessage()); } } @@ -345,7 +346,7 @@ private static void readActualAndExpectedFile(ValueType[] schema, String expecte private static Map readJIVActualFile(Path actualPath, ValueType[] schema) { try { FileSystem fs = IOUtilFunctions.getFileSystem(actualPath, conf); - Map values = tryReadFederatedJIVActualFile(actualPath, schema); + Map values = tryReadFederatedJIVFile(actualPath, schema); if (values == null) { values = new HashMap<>(); FileStatus[] outFiles = fs.listStatus(actualPath); @@ -362,7 +363,7 @@ private static Map readJIVActualFile(Path actualPath, ValueTy } } - private static Map tryReadFederatedJIVActualFile(Path actualPath, ValueType[] schema) { + private static Map tryReadFederatedJIVFile(Path actualPath, ValueType[] schema) { Map values = new HashMap<>(); try { MetaDataAll mtd = new MetaDataAll(actualPath + ".mtd", false, true); @@ -531,8 +532,7 @@ public static HashMap readDMLMatrixFromHDFS(String filePath) { HashMap expectedValues = new HashMap<>(); - try - { + try { Path outDirectory = new Path(filePath); if (tryReadFederatedMatrix(filePath, expectedValues)) return expectedValues; @@ -545,8 +545,7 @@ public static HashMap readDMLMatrixFromHDFS(String filePath) readValuesFromFileStream(outIn, expectedValues); } } - } - catch (IOException e) { + } catch (IOException e) { fail("could not read from file " + filePath + ": " + e.getMessage()); } @@ -575,8 +574,7 @@ private static boolean tryReadFederatedMatrix(String filePath, HashMap Date: Sun, 31 Jul 2022 20:31:00 +0200 Subject: [PATCH 6/6] Make filenames more unique --- .../java/org/apache/sysds/runtime/io/WriterFederated.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/io/WriterFederated.java b/src/main/java/org/apache/sysds/runtime/io/WriterFederated.java index d9b7c0ab11a..b228b5cdd89 100644 --- a/src/main/java/org/apache/sysds/runtime/io/WriterFederated.java +++ b/src/main/java/org/apache/sysds/runtime/io/WriterFederated.java @@ -69,8 +69,9 @@ public static void write(String file, CacheableData cd, String outputFormat, JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); Path path = new Path(file); - FederationMap newFedMap = cd.getFedMapping().mapParallel(cd.getFedMapping().getID(), (range, data) -> { - String siteFilename = Long.toString(siteUniqueCounter.getNextID()) + '_' + path.getName(); + long id = cd.getFedMapping().getID(); + FederationMap newFedMap = cd.getFedMapping().mapParallel(id, (range, data) -> { + String siteFilename = Long.toString(id) + '_' + siteUniqueCounter.getNextID() + '_' + path.getName(); try { FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, data.getVarID(), new WriteAtSiteUDF(data.getVarID(), siteFilename, outputFormat, fileFormatProperties))).get();