Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/javaTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ on:
- '*.html'
- 'src/main/python/**'
- 'dev/**'
- '.github/workflows/python.yml'
branches:
- main
pull_request:
Expand All @@ -38,6 +39,7 @@ on:
- '*.html'
- 'src/main/python/**'
- 'dev/**'
- '.github/workflows/python.yml'
branches:
- main

Expand Down
120 changes: 120 additions & 0 deletions src/main/java/org/apache/sysds/api/PythonDMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,40 @@
import org.apache.log4j.Logger;
import org.apache.sysds.api.jmlc.Connection;

import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UnixPipeUtils;
import py4j.DefaultGatewayServerListener;
import py4j.GatewayServer;
import py4j.Py4JNetworkException;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;


public class PythonDMLScript {

// Logger for this class.
private static final Log LOG = LogFactory.getLog(PythonDMLScript.class.getName());
// Connection returned by getConnection().
final private Connection _connection;
// py4j gateway server handle.
// NOTE(review): where this is assigned is not visible in this excerpt - confirm before relying on it being non-null.
public static GatewayServer GwS;

// File-name prefixes of the named pipes; openPipes() builds "<path>/<prefix>-<index>" from them.
private static String fromPythonBase = "py2java";
private static String toPythonBase = "java2py";
// Pipe index -> stream for data arriving from Python; null until openPipes() has been called.
public HashMap<Integer, BufferedInputStream> fromPython = null;
// Pipe index -> stream for data sent to Python; null until openPipes() has been called.
public HashMap<Integer, BufferedOutputStream> toPython = null;
// Directory passed to the most recent openPipes() call; also used in debug log messages.
public String baseDir;
// Batch size handed to the UnixPipeUtils batch read/write helpers.
// NOTE(review): the unit (bytes vs. elements) is defined by UnixPipeUtils and not visible here - confirm.
private static int BATCH_SIZE = 32*1024;

/**
* Entry point for Python API.
*
Expand Down Expand Up @@ -78,6 +101,103 @@ public Connection getConnection() {
return _connection;
}


/**
 * Opens {@code num} pairs of pipes located below {@code path} and registers
 * them under their index in {@link #fromPython} and {@link #toPython}.
 * <p>
 * For index {@code i} the input side is opened from
 * {@code path + "/" + fromPythonBase + "-" + i} and the output side from
 * {@code path + "/" + toPythonBase + "-" + i}. Any previously registered
 * pipes are replaced by fresh maps (they are not closed here).
 *
 * @param path directory that contains the pipe files; stored in {@link #baseDir}
 * @param num  number of pipe pairs to open
 * @throws IOException if {@code UnixPipeUtils} fails to open one of the pipes
 */
public void openPipes(String path, int num) throws IOException {
	fromPython = new HashMap<>(num * 2);
	toPython = new HashMap<>(num * 2);
	baseDir = path;

	final String inputPrefix = path + "/" + fromPythonBase + "-";
	final String outputPrefix = path + "/" + toPythonBase + "-";

	int idx = 0;
	while (idx < num) {
		// Keep the input-before-output order per index; the peer process
		// presumably opens its ends in the matching order.
		BufferedInputStream incoming = UnixPipeUtils.openInput(inputPrefix + idx, idx);
		LOG.debug("PY2JAVA pipe "+idx+" is ready!");
		fromPython.put(idx, incoming);

		toPython.put(idx, UnixPipeUtils.openOutput(outputPrefix + idx, idx));
		idx++;
	}
}

/**
 * Reads a dense {@code rlen x clen} matrix from the single pipe registered
 * under {@code id} and returns it as a {@link MatrixBlock}.
 * <p>
 * All argument and state checks run before any buffer is allocated, so an
 * invalid call fails fast with a descriptive {@link DMLRuntimeException}
 * instead of a {@code NegativeArraySizeException} or a
 * {@code NullPointerException} deeper in the read path.
 *
 * @param id   index of the pipe (as registered by {@code openPipes}) to read from
 * @param rlen number of rows of the incoming matrix
 * @param clen number of columns of the incoming matrix
 * @param type value type of the elements sent by Python
 * @return the matrix block holding the received values, with non-zero count
 *         and sparsity already re-evaluated
 * @throws IOException         if reading from the pipe fails
 * @throws DMLRuntimeException if the dimensions are negative, the matrix has
 *                             more than {@code Integer.MAX_VALUE} cells, the
 *                             pipes were never opened, or no pipe is
 *                             registered under {@code id}
 */
public MatrixBlock startReadingMbFromPipe(int id, int rlen, int clen, Types.ValueType type) throws IOException {
	// Widen before multiplying so the cell count cannot overflow int.
	long limit = (long) rlen * clen;
	LOG.debug("trying to read matrix from "+id+" with "+rlen+" rows and "+clen+" columns. Total size: "+limit);
	if(rlen < 0 || clen < 0)
		throw new DMLRuntimeException("Invalid matrix dimensions " + rlen + "x" + clen +
			" for reading from pipe " + id);
	if(limit > Integer.MAX_VALUE)
		throw new DMLRuntimeException("Dense NumPy array of size " + limit +
			" cannot be converted to MatrixBlock");
	if(fromPython == null)
		throw new DMLRuntimeException("FIFO Pipes are not initialized.");

	BufferedInputStream pipe = fromPython.get(id);
	if(pipe == null)
		throw new DMLRuntimeException("No FIFO pipe from Python is registered for id " + id + ".");

	MatrixBlock mb = new MatrixBlock(rlen, clen, false, -1);
	double[] denseBlock = new double[(int) limit];
	UnixPipeUtils.readNumpyArrayInBatches(pipe, id, BATCH_SIZE, (int) limit, type, denseBlock, 0);
	mb.init(denseBlock, rlen, clen);

	mb.recomputeNonZeros();
	mb.examSparsity();
	LOG.debug("Reading from Python finished");
	return mb;
}

/**
 * Reads a dense {@code rlen x clen} matrix that Python sends split across
 * several pipes, one reader task per pipe, and returns it as a
 * {@link MatrixBlock}.
 * <p>
 * Pipe {@code i} delivers {@code blockSizes[i]} elements, which are written
 * into the shared dense array at the offset equal to the sum of all
 * preceding block sizes. The block layout is validated up front so that a
 * bad layout is reported directly instead of surfacing as an
 * {@code ArrayIndexOutOfBoundsException} or {@code NullPointerException}
 * wrapped in an {@link ExecutionException} after some pipes were already
 * consumed.
 *
 * @param blockSizes number of elements to read from pipe {@code 0..n-1}
 * @param rlen       number of rows of the incoming matrix
 * @param clen       number of columns of the incoming matrix
 * @param type       value type of the elements sent by Python
 * @return the matrix block holding the received values, with non-zero count
 *         and sparsity already re-evaluated
 * @throws ExecutionException   if one of the reader tasks fails
 * @throws InterruptedException if the calling thread is interrupted while waiting
 * @throws DMLRuntimeException  if the dimensions or block sizes are invalid,
 *                              the pipes were never opened, or a required
 *                              pipe is not registered
 */
public MatrixBlock startReadingMbFromPipes(int[] blockSizes, int rlen, int clen, Types.ValueType type) throws ExecutionException, InterruptedException {
	// Widen before multiplying so the cell count cannot overflow int.
	long limit = (long) rlen * clen;
	if(rlen < 0 || clen < 0)
		throw new DMLRuntimeException("Invalid matrix dimensions " + rlen + "x" + clen +
			" for reading from pipes");
	if(limit > Integer.MAX_VALUE)
		throw new DMLRuntimeException("Dense NumPy array of size " + limit +
			" cannot be converted to MatrixBlock");
	if(fromPython == null)
		throw new DMLRuntimeException("FIFO Pipes are not initialized.");

	// Validate the whole layout before touching any pipe.
	long total = 0;
	for (int i = 0; i < blockSizes.length; i++) {
		if(blockSizes[i] < 0)
			throw new DMLRuntimeException("Negative block size " + blockSizes[i] + " for pipe " + i);
		if(fromPython.get(i) == null)
			throw new DMLRuntimeException("No FIFO pipe from Python is registered for id " + i + ".");
		total += blockSizes[i];
	}
	if(total > limit)
		throw new DMLRuntimeException("Block sizes sum up to " + total + " elements but the " +
			rlen + "x" + clen + " matrix only has " + limit + " cells");

	MatrixBlock mb = new MatrixBlock(rlen, clen, false, -1);
	double[] denseBlock = new double[(int) limit];

	// NOTE(review): the pool is not shut down here, matching the previous behavior;
	// confirm whether CommonThreadPool.get() expects callers to call shutdown().
	ExecutorService pool = CommonThreadPool.get();
	List<Future<Void>> futures = new ArrayList<>(blockSizes.length);
	try {
		int offsetOut = 0;
		for (int i = 0; i < blockSizes.length; i++) {
			BufferedInputStream pipe = fromPython.get(i);
			// Effectively-final copies for capture by the lambda.
			int id = i, blockSize = blockSizes[i], _offsetOut = offsetOut;
			Callable<Void> task = () -> {
				UnixPipeUtils.readNumpyArrayInBatches(pipe, id, BATCH_SIZE, blockSize, type, denseBlock, _offsetOut);
				return null;
			};

			futures.add(pool.submit(task));
			offsetOut += blockSize;
		}
		// Wait for all tasks and propagate exceptions
		for (Future<Void> f : futures) {
			f.get();
		}
	} finally {
		// No-op for tasks that already completed; on failure or interruption
		// this stops the remaining readers from running on unattended.
		for (Future<Void> f : futures) {
			f.cancel(true);
		}
	}

	mb.init(denseBlock, rlen, clen);
	mb.recomputeNonZeros();
	mb.examSparsity();
	return mb;
}

/**
 * Writes the given matrix block as FP64 values to the pipe registered under
 * {@code id}.
 * <p>
 * The cell count is computed in {@code long} arithmetic: multiplying the
 * two {@code int} dimensions directly can silently wrap around for large
 * matrices and would then announce a wrong element count (and log a wrong
 * byte size).
 *
 * @param id index of the pipe (as registered by {@code openPipes}) to write to
 * @param mb matrix block to send to Python
 * @throws IOException         if writing to the pipe fails
 * @throws DMLRuntimeException if the pipes were never opened, no pipe is
 *                             registered under {@code id}, or the matrix has
 *                             more than {@code Integer.MAX_VALUE} cells
 */
public void startWritingMbToPipe(int id, MatrixBlock mb) throws IOException {
	if (toPython == null)
		throw new DMLRuntimeException("FIFO Pipes are not initialized.");

	BufferedOutputStream out = toPython.get(id);
	if (out == null)
		throw new DMLRuntimeException("No FIFO pipe to Python is registered for id " + id + ".");

	int rlen = mb.getNumRows();
	int clen = mb.getNumColumns();
	// Widen before multiplying so neither the cell count nor the byte count overflows.
	long numElem = (long) rlen * clen;
	if (numElem > Integer.MAX_VALUE)
		throw new DMLRuntimeException("MatrixBlock with " + numElem +
			" cells cannot be written to a pipe as dense NumPy array");
	LOG.debug("Trying to write matrix ["+baseDir + "-"+ id+"] with "+rlen+" rows and "+clen+" columns. Total size: "+numElem*8);

	long bytes = UnixPipeUtils.writeNumpyArrayInBatches(out, id, BATCH_SIZE, (int) numElem, Types.ValueType.FP64, mb);

	LOG.debug("Writing of " + bytes +" Bytes to Python ["+baseDir + "-"+ id+"] finished");
}

/**
 * Closes every registered input and output pipe.
 * <p>
 * The method is safe to call when the pipes were never opened, and it
 * always attempts to close every stream: a failure on one stream no longer
 * prevents the remaining ones from being closed. The first failure is
 * rethrown after all streams were attempted; later failures are attached
 * to it as suppressed exceptions.
 *
 * @throws IOException if closing at least one of the streams failed
 */
public void closePipes() throws IOException {
	LOG.debug("Closing all pipes in Java");
	IOException failure = null;

	if (fromPython != null) {
		for (BufferedInputStream pipe : fromPython.values()) {
			if (pipe == null)
				continue;
			try {
				pipe.close();
			} catch (IOException e) {
				if (failure == null)
					failure = e;
				else
					failure.addSuppressed(e);
			}
		}
	}

	if (toPython != null) {
		for (BufferedOutputStream pipe : toPython.values()) {
			if (pipe == null)
				continue;
			try {
				// close() flushes buffered data first, so it can fail if the peer is gone.
				pipe.close();
			} catch (IOException e) {
				if (failure == null)
					failure = e;
				else
					failure.addSuppressed(e);
			}
		}
	}

	if (failure != null)
		throw failure;
	LOG.debug("Closed all pipes in Java");
}

protected static class DMLGateWayListener extends DefaultGatewayServerListener {
private static final Log LOG = LogFactory.getLog(DMLGateWayListener.class.getName());

Expand Down
Loading
Loading