# Patch summary: length-prefix JAXBCoder output outside whole-stream contexts.
# Files touched: JAXBCoder.java (coder) and JAXBCoderTest.java (tests).
#
# JAXBCoder.encode
#   When context.isWholeStream is false, the value returned by
#   getEncodedElementByteSize(value, Context.OUTER) is written with VarInt.encode
#   ahead of the marshalled XML. Any Exception raised while computing or writing
#   that size is rethrown as a CoderException that carries the cause.
#   The marshaller now writes through CloseIgnoringOutputStream instead of an
#   anonymous FilterOutputStream subclass.
#
# JAXBCoder.decode
#   When context.isWholeStream is false, a length is read with VarInt.decodeLong
#   and the input is wrapped with ByteStreams.limit(inStream, limit) before it is
#   handed to the unmarshaller. The unmarshaller reads through
#   CloseIgnoringInputStream instead of an anonymous FilterInputStream subclass.
#
# CloseIgnoringInputStream / CloseIgnoringOutputStream
#   New private static nested classes whose close() is a no-op, so that a close
#   issued by JAXB does not reach the caller's stream.
#
# JAXBCoderTest
#   testEncodeDecode is renamed to testEncodeDecodeOuter.
#   testEncodeDecodeNested round-trips a TestType through TestCoder, a helper
#   coder that writes a VarInt (3), the JAXB-encoded value, and a VarLong (22L),
#   each with context.nested(), and reads them back in the same order.
#
# NOTE(review): the diff text below appears whitespace-collapsed (hunk lines are
#   joined) and some generic type arguments look stripped, e.g.
#   "List> getCoderArguments()" and "extends StandardCoder" -- presumably it has
#   to be regenerated from the repository before it can be applied; confirm.
# NOTE(review): getEncodedElementByteSize is not shown in this patch. If it
#   sizes a value by encoding it, encode() would marshal every value twice when
#   context.isWholeStream is false -- verify against the base class.
# NOTE(review): decode() does not visibly drain bytes left inside the limited
#   stream after unmarshal returns; presumably the parser consumes the whole
#   length-delimited region -- TODO confirm.
# NOTE(review): the new CoderException message spells "occured"; it is runtime
#   text and is left untouched here.
#
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/coders/JAXBCoder.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/coders/JAXBCoder.java index f683b3e1eda3..6e2833eff60e 100644 --- a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/coders/JAXBCoder.java +++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/coders/JAXBCoder.java @@ -19,6 +19,8 @@ import com.google.cloud.dataflow.sdk.util.CloudObject; import com.google.cloud.dataflow.sdk.util.Structs; +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.common.io.ByteStreams; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -71,13 +73,19 @@ public void encode(T value, OutputStream outStream, Context context) JAXBContext jaxbContext = JAXBContext.newInstance(jaxbClass); jaxbMarshaller = jaxbContext.createMarshaller(); } - - jaxbMarshaller.marshal(value, new FilterOutputStream(outStream) { - // JAXB closes the underyling stream so we must filter out those calls. 
- @Override - public void close() throws IOException { + if (!context.isWholeStream) { + try { + long size = getEncodedElementByteSize(value, Context.OUTER); + // record the number of bytes the XML consists of so when reading we only read the encoded + // value + VarInt.encode(size, outStream); + } catch (Exception e) { + throw new CoderException( + "An Exception occured while trying to get the size of an encoded representation", e); } - }); + } + + jaxbMarshaller.marshal(value, new CloseIgnoringOutputStream(outStream)); } catch (JAXBException e) { throw new CoderException(e); } @@ -91,13 +99,13 @@ public T decode(InputStream inStream, Context context) throws CoderException, IO jaxbUnmarshaller = jaxbContext.createUnmarshaller(); } + InputStream stream = inStream; + if (!context.isWholeStream) { + long limit = VarInt.decodeLong(inStream); + stream = ByteStreams.limit(inStream, limit); + } @SuppressWarnings("unchecked") - T obj = (T) jaxbUnmarshaller.unmarshal(new FilterInputStream(inStream) { - // JAXB closes the underyling stream so we must filter out those calls. - @Override - public void close() throws IOException { - } - }); + T obj = (T) jaxbUnmarshaller.unmarshal(new CloseIgnoringInputStream(stream)); return obj; } catch (JAXBException e) { throw new CoderException(e); @@ -109,6 +117,30 @@ public String getEncodingId() { return getJAXBClass().getName(); } + private static class CloseIgnoringInputStream extends FilterInputStream { + + protected CloseIgnoringInputStream(InputStream in) { + super(in); + } + + @Override + public void close() { + // Do nothing. JAXB closes the underlying stream so we must filter out those calls. + } + } + + private static class CloseIgnoringOutputStream extends FilterOutputStream { + + protected CloseIgnoringOutputStream(OutputStream out) { + super(out); + } + + @Override + public void close() throws IOException { + // JAXB closes the underlying stream so we must filter out those calls. 
+ } + } + //////////////////////////////////////////////////////////////////////////////////// // JSON Serialization details below diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/coders/JAXBCoderTest.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/coders/JAXBCoderTest.java index ae0919023a1f..26c11985c3f6 100644 --- a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/coders/JAXBCoderTest.java +++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/coders/JAXBCoderTest.java @@ -19,11 +19,18 @@ import com.google.cloud.dataflow.sdk.testing.CoderProperties; import com.google.cloud.dataflow.sdk.util.CoderUtils; +import com.google.common.collect.ImmutableList; + import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; + import javax.xml.bind.annotation.XmlRootElement; /** Unit tests for {@link JAXBCoder}. */ @@ -79,13 +86,62 @@ public boolean equals(Object obj) { } @Test - public void testEncodeDecode() throws Exception { + public void testEncodeDecodeOuter() throws Exception { JAXBCoder coder = JAXBCoder.of(TestType.class); byte[] encoded = CoderUtils.encodeToByteArray(coder, new TestType("abc", 9999)); Assert.assertEquals(new TestType("abc", 9999), CoderUtils.decodeFromByteArray(coder, encoded)); } + @Test + public void testEncodeDecodeNested() throws Exception { + JAXBCoder jaxbCoder = JAXBCoder.of(TestType.class); + TestCoder nesting = new TestCoder(jaxbCoder); + + byte[] encoded = CoderUtils.encodeToByteArray(nesting, new TestType("abc", 9999)); + Assert.assertEquals( + new TestType("abc", 9999), CoderUtils.decodeFromByteArray(nesting, encoded)); + } + + /** + * A coder that surrounds the value with two values, to demonstrate nesting. 
+ */ + private static class TestCoder extends StandardCoder { + private final JAXBCoder jaxbCoder; + public TestCoder(JAXBCoder jaxbCoder) { + this.jaxbCoder = jaxbCoder; + } + + @Override + public void encode(TestType value, OutputStream outStream, Context context) + throws CoderException, IOException { + Context subContext = context.nested(); + VarIntCoder.of().encode(3, outStream, subContext); + jaxbCoder.encode(value, outStream, subContext); + VarLongCoder.of().encode(22L, outStream, subContext); + } + + @Override + public TestType decode(InputStream inStream, Context context) + throws CoderException, IOException { + Context subContext = context.nested(); + VarIntCoder.of().decode(inStream, subContext); + TestType result = jaxbCoder.decode(inStream, subContext); + VarLongCoder.of().decode(inStream, subContext); + return result; + } + + @Override + public List> getCoderArguments() { + return ImmutableList.of(jaxbCoder); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + jaxbCoder.verifyDeterministic(); + } + } + @Test public void testEncodable() throws Exception { CoderProperties.coderSerializable(JAXBCoder.of(TestType.class));