Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
09b425f
Initial native function use
rnett Mar 5, 2021
278d7cc
Allow body constants
rnett Mar 5, 2021
c2b0b60
Fix body forbids
rnett Mar 5, 2021
dec04e6
Use default eager session for tensor calls
rnett Mar 5, 2021
047243b
Use default eager for single tensor call too
rnett Mar 5, 2021
696ef67
Get functions from graph
rnett Mar 5, 2021
71b8fab
Start of saver support
rnett Mar 5, 2021
3697122
Update loading, detect statefulness, use PartitionedCall
rnett Mar 6, 2021
abf0ed1
Start of dependencies
rnett Mar 6, 2021
ba65103
Support dependencies
rnett Mar 6, 2021
bafcec2
Remove unwrapping
rnett Mar 6, 2021
bb641e3
Proper attribute setters
rnett Mar 7, 2021
9e686c4
Add ignored gradient test
rnett Mar 7, 2021
b4bf605
Rebase fix
rnett Mar 9, 2021
b7ab76c
Op generation for functions
rnett Mar 13, 2021
6d8308e
Rebase fix
rnett Mar 17, 2021
eaeaf6e
SavedFunction for running functions from SavedModelBundles
rnett Apr 11, 2021
f32fbf2
Review fixes
rnett Apr 17, 2021
f892c54
Generation and better javadoc
rnett Apr 17, 2021
f485ccf
Rework pointer scopes
rnett Apr 19, 2021
5977c40
SessionFunction instead of SavedModelBundle specific class
rnett May 15, 2021
12af327
Add CallableFunction javadoc
rnett May 15, 2021
1b4bf59
Remove obsolete test
rnett May 15, 2021
54e8855
Rebase fix
rnett May 15, 2021
117b391
Formatting fixes and nits
rnett May 21, 2021
3df55b9
Add session function test, Signature.builder with name
rnett May 21, 2021
99a3403
Remove extra synchronization
rnett May 21, 2021
cecf71d
Formatting
rnett May 28, 2021
61e5ffd
New names
rnett May 30, 2021
1258276
Note on SavedModel functions
rnett May 30, 2021
b1b378e
Fix tests
rnett May 30, 2021
cba0fea
Rename name method
rnett May 30, 2021
e67ae1e
Re-add tests w/ SessionFunction
rnett May 30, 2021
c09385d
Helper methods for saving
rnett May 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Review fixes
Signed-off-by: Ryan Nett <JNett96@gmail.com>
  • Loading branch information
rnett committed May 28, 2021
commit f32fbf2afa5d231414491d7ac4696765fc3ed422
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
Expand Down Expand Up @@ -212,21 +213,36 @@ public String toString() {
return signature.toString();
}

//TODO migrate to the actual ops once they are generated
public static final String CALL_OP = "PartitionedCall";
//TODO migrate to the actual ops once they are generated
public static final String STATEFUL_CALL_OP = "StatefulPartitionedCall";


/**
* Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The
* inputs and outputs are keyed by the names set in the {@code Signature}.
*
* @param scope the scope to call the function in
* @param arguments the arguments to the call
* @return the outputs of the function
*/
public Map<String, Operand<?>> call(Scope scope,
Map<String, Operand<?>> arguments) {
List<Operand<?>> inputList = new ArrayList<>();

Output<?>[] inputs = new Output<?>[signature().inputNames().size()];

int i = 0;
for (String inputName : signature().inputNames()) {
Operand<?> input = arguments.get(inputName);
if (input == null) {
throw new IllegalArgumentException(
"Function " + signature().methodName() + " has parameter \"" + inputName
+ "\", but no argument was passed for it.");
}
inputList.add(input);
inputs[i] = input.asOutput();
i++;
}

scope.env().attachFunction(this);
Expand All @@ -237,26 +253,26 @@ public Map<String, Operand<?>> call(Scope scope,
OperationBuilder opBuilder = scope.env()
.opBuilder(isStateful() ? STATEFUL_CALL_OP : CALL_OP, scope.makeOpName(displayName));

opBuilder.addInputList(inputList.stream().map(Operand::asOutput).toArray(Output[]::new));
opBuilder.addInputList(inputs);

opBuilder.setAttr("f", this);
opBuilder.setAttr("Tin", inputList.stream().map(x -> x.asOutput().dataType()).toArray(DataType[]::new));
opBuilder.setAttr("Tout", signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new));
opBuilder.setAttr("Tin", inputDtypes);
opBuilder.setAttr("Tout", outputDtypes);

opBuilder = scope.apply(opBuilder);
Operation op = opBuilder.build();

int numOutputs1 = op.numOutputs();
List<Operand<?>> outputList = new ArrayList<>(signature().outputNames().size());

for (int i = 0; i < numOutputs1; i++) {
for (i = 0; i < numOutputs1; i++) {
outputList.add(op.output(i));
}

Map<String, Operand<?>> namedOutputs = new LinkedHashMap<>(signature().outputNames().size());

List<String> outputNames = new ArrayList<>(signature().outputNames());
for (int i = 0; i < outputNames.size(); i++) {
for (i = 0; i < outputNames.size(); i++) {
String outputName = outputNames.get(i);

if (i > outputList.size()) {
Expand Down Expand Up @@ -378,61 +394,60 @@ public void save(String exportDir) throws IOException {
SavedModelBundle.exporter(exportDir).withFunction(this).export();
}

/**
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA
* JIT is extremely non-obvious.
*
* Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id:
* 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
*/
private void makeJit() {
try (PointerScope scope = new PointerScope()) {
byte[] bytes = AttrValue.newBuilder().setB(true).build().toByteArray();
BytePointer trueValue = new BytePointer(bytes);

TF_Status status1 = TF_Status.newStatus();
TF_FunctionSetAttrValueProto(nativeHandle(), "_XlaMustCompile", trueValue, bytes.length, status1);
status1.throwExceptionIfNotOK();

TF_Status status2 = TF_Status.newStatus();
TF_FunctionSetAttrValueProto(nativeHandle(), "_noinline", trueValue, bytes.length, status2);
status2.throwExceptionIfNotOK();
}
}

TF_Function nativeHandle() {
if (nativeFunction.getNativeHandle().isNull()) {
throw new IllegalStateException("Function has been closed");
}
return nativeFunction.getNativeHandle();
}

private final Signature signature;
private final NativeFunction nativeFunction;
private final PointerScope scope;
private final Set<TF_Function> dependencies;

ConcreteFunction(Signature signature, NativeFunction nativeFunction, Collection<NativeFunction> availableFunctions) {
this(signature, nativeFunction, nativeFunction.getAllDependencies(availableFunctions));
}

private static boolean dataTypesMatch(List<DataType> a, List<DataType> b) {
if (a.size() != b.size()) {
return false;
/**
* Detects the signature from the handle
*/
static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction,
Collection<NativeFunction> availableFunctions) {

Signature.Builder builder = Signature.builder().methodName(nativeFunction.getFunctionDef().getSignature().getName())
.key(nativeFunction.getName());

for (ArgDef input : nativeFunction.getFunctionDef().getSignature().getInputArgList()) {
TensorInfo info = TensorInfo.newBuilder()
.setDtype(input.getType())
.setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build())
.setName(input.getName())
.build();

builder.input(input.getName(), info);
}

for (int i = 0; i < a.size(); i++) {
DataType aType = a.get(i);
DataType bType = b.get(i);
for (ArgDef outputDef : nativeFunction.getFunctionDef().getSignature().getOutputArgList()) {
TensorInfo info = TensorInfo.newBuilder()
.setDtype(outputDef.getType())
.setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build())
.setName(outputDef.getName())
.build();

if (aType != DataType.DT_INVALID && bType != DataType.DT_INVALID && !a.equals(b)) {
return false;
}
builder.output(outputDef.getName(), info);
}

return true;
return new ConcreteFunction(
builder.build(),
nativeFunction,
availableFunctions
);
}

private final Signature signature;
private final NativeFunction nativeFunction;
private final PointerScope scope;
private final Set<TF_Function> dependencies;
private final DataType[] inputDtypes;
private final DataType[] outputDtypes;

private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set<TF_Function> dependencies) {
this.signature = signature;
this.nativeFunction = nativeFunction;
Expand All @@ -452,8 +467,10 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set
+ this.signature.getOutputs().size());
}

List<DataType> inputs = this.signature.getInputs().values().stream().map(x -> x.dataType)
.collect(Collectors.toList());
inputDtypes = this.signature.getInputs().values().stream().map(x -> x.dataType)
.toArray(DataType[]::new);

List<DataType> inputs = Arrays.asList(inputDtypes);
List<DataType> nativeInputs = nativeFunction.getFunctionDef().getSignature().getInputArgList().stream()
.map(ArgDef::getType)
.collect(Collectors.toList());
Expand All @@ -464,8 +481,9 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set
+ nativeInputs + ", got " + inputs);
}

List<DataType> outputs = this.signature.getOutputs().values().stream().map(x -> x.dataType)
.collect(Collectors.toList());
outputDtypes = signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new);

List<DataType> outputs = Arrays.asList(outputDtypes);
List<DataType> nativeOutputs = nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream()
.map(ArgDef::getType)
.collect(Collectors.toList());
Expand All @@ -486,39 +504,42 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set
}

/**
* Detects the signature from the handle
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA
* JIT is extremely non-obvious.
*
* Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id:
* 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
*/
static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction,
Collection<NativeFunction> availableFunctions) {
private void makeJit() {
try (PointerScope scope = new PointerScope()) {
byte[] bytes = AttrValue.newBuilder().setB(true).build().toByteArray();
BytePointer trueValue = new BytePointer(bytes);

Signature.Builder builder = Signature.builder().methodName(nativeFunction.getFunctionDef().getSignature().getName())
.key(nativeFunction.getName());
TF_Status status1 = TF_Status.newStatus();
TF_FunctionSetAttrValueProto(nativeHandle(), "_XlaMustCompile", trueValue, bytes.length, status1);
status1.throwExceptionIfNotOK();

for (ArgDef input : nativeFunction.getFunctionDef().getSignature().getInputArgList()) {
TensorInfo info = TensorInfo.newBuilder()
.setDtype(input.getType())
.setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build())
.setName(input.getName())
.build();
TF_Status status2 = TF_Status.newStatus();
TF_FunctionSetAttrValueProto(nativeHandle(), "_noinline", trueValue, bytes.length, status2);
status2.throwExceptionIfNotOK();
}
}

builder.input(input.getName(), info);
private static boolean dataTypesMatch(List<DataType> a, List<DataType> b) {
if (a.size() != b.size()) {
return false;
}

for (ArgDef outputDef : nativeFunction.getFunctionDef().getSignature().getOutputArgList()) {
TensorInfo info = TensorInfo.newBuilder()
.setDtype(outputDef.getType())
.setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build())
.setName(outputDef.getName())
.build();
for (int i = 0; i < a.size(); i++) {
DataType aType = a.get(i);
DataType bType = b.get(i);

builder.output(outputDef.getName(), info);
if (aType != DataType.DT_INVALID && bType != DataType.DT_INVALID && !a.equals(b)) {
return false;
}
}

return new ConcreteFunction(
builder.build(),
nativeFunction,
availableFunctions
);
return true;
}


Expand Down Expand Up @@ -616,8 +637,4 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature)
return new ConcreteFunction(signature, new NativeFunction(handle), graph.getNativeFunctions());
}
}

ConcreteFunction withNewSignature(Signature signature) {
return new ConcreteFunction(signature, nativeFunction, dependencies);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.tensorflow.internal.WeakPointerScope;
import org.tensorflow.internal.c_api.TFE_Context;
import org.tensorflow.internal.c_api.TFE_ContextOptions;
import org.tensorflow.internal.c_api.TF_Function;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.op.Op;
import org.tensorflow.op.Scope;
Expand Down Expand Up @@ -290,15 +289,16 @@ public OperationBuilder opBuilder(String type, String name) {
public void attachFunction(ConcreteFunction function) {
checkSession();
try (PointerScope scope = new PointerScope()) {
attachNativeFunction(function.nativeHandle());
function.getDependencies().forEach(this::attachNativeFunction);
}
}
TF_Status status = TF_Status.newStatus();
TFE_ContextAddFunction(nativeHandle, function.nativeHandle(), status);
status.throwExceptionIfNotOK();

private void attachNativeFunction(TF_Function fn) {
TF_Status status = TF_Status.newStatus();
TFE_ContextAddFunction(nativeHandle, fn, status);
status.throwExceptionIfNotOK();
function.getDependencies().forEach(fn -> {
TF_Status status2 = TF_Status.newStatus();
TFE_ContextAddFunction(nativeHandle, fn, status2);
status2.throwExceptionIfNotOK();
});
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,17 +387,21 @@ public GraphOperationBuilder opBuilder(String type, String name) {
public void attachFunction(ConcreteFunction function) {
try (Reference ref = ref();
PointerScope scope = new PointerScope()) {
attachNativeFunction(ref.nativeHandle(), function.nativeHandle());
function.getDependencies().forEach(x -> attachNativeFunction(ref.nativeHandle(), x));
}
}
TF_Status status = TF_Status.newStatus();
TF_GraphCopyFunction(ref.nativeHandle(), function.nativeHandle(), null, status);
status.throwExceptionIfNotOK();

private void attachNativeFunction(TF_Graph graph, TF_Function fn) {
TF_Status status = TF_Status.newStatus();
TF_GraphCopyFunction(graph, fn, null, status);
status.throwExceptionIfNotOK();
function.getDependencies().forEach(x -> {
TF_Status status2 = TF_Status.newStatus();
TF_GraphCopyFunction(ref.nativeHandle(), x, null, status2);
status2.throwExceptionIfNotOK();
});
}
}

/**
* Get the graph's functions. Deallocating the function pointers is the caller's responsibility.
*/
synchronized List<NativeFunction> getNativeFunctions() {
try (Reference ref = ref();
PointerScope scope = new PointerScope()) {
Expand Down Expand Up @@ -435,17 +439,14 @@ public synchronized ConcreteFunction getFunction(String key) {
// will close unused functions when method ends
funcs.forEach(x -> x.getNativeHandle().withDeallocatorInScope());

ConcreteFunction func = null;
for (NativeFunction f : funcs) {

for (int i = 0; i < funcs.size(); i++) {

if (funcs.get(i).getName().equals(key) && func == null) {
func = ConcreteFunction.fromNativeHandle(funcs.get(i), funcs);
if (f.getName().equals(key)) {
return ConcreteFunction.fromNativeHandle(f, funcs);
}
}

return func;
}
return null;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,6 @@
* A class holding a native function handle and providing cached access to it's {@link FunctionDef}.
*/
class NativeFunction {

private final TF_Function nativeHandle;

private FunctionDef functionDef = null;
private List<String> dependencies = null;
private Boolean stateful = null;
private String name = null;

public NativeFunction(TF_Function nativeHandle) {
this.nativeHandle = nativeHandle;
}
Expand Down Expand Up @@ -159,4 +151,12 @@ synchronized Set<TF_Function> getAllDependencies(Collection<NativeFunction> avai
.map(NativeFunction::getNativeHandle)
.collect(Collectors.toSet());
}

private final TF_Function nativeHandle;

private FunctionDef functionDef = null;
private List<String> dependencies = null;
private Boolean stateful = null;
private String name = null;

}
Loading