diff --git a/flight/flight-core/pom.xml b/flight/flight-core/pom.xml index b1f755844e..6013c6dcf6 100644 --- a/flight/flight-core/pom.xml +++ b/flight/flight-core/pom.xml @@ -62,6 +62,7 @@ under the License. io.grpc grpc-core + runtime io.grpc @@ -145,6 +146,12 @@ under the License. test-jar test + + org.apache.commons + commons-lang3 + 3.20.0 + test + diff --git a/flight/flight-core/src/main/java/module-info.java b/flight/flight-core/src/main/java/module-info.java index 669797ac93..9bafc5fddf 100644 --- a/flight/flight-core/src/main/java/module-info.java +++ b/flight/flight-core/src/main/java/module-info.java @@ -30,7 +30,6 @@ requires com.google.protobuf; requires com.google.protobuf.util; requires io.grpc; - requires io.grpc.internal; requires io.grpc.netty; requires io.grpc.protobuf; requires io.grpc.stub; diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index ab4eab3048..b768559ca7 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -18,9 +18,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; -import com.google.common.io.ByteStreams; import com.google.protobuf.ByteString; -import com.google.protobuf.CodedInputStream; import com.google.protobuf.CodedOutputStream; import com.google.protobuf.WireFormat; import io.grpc.Drainable; @@ -40,8 +38,10 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.apache.arrow.flight.FlightDataParser.ArrowBufReader; +import org.apache.arrow.flight.FlightDataParser.FlightDataReader; +import org.apache.arrow.flight.FlightDataParser.InputStreamReader; import org.apache.arrow.flight.grpc.AddWritableBuffer; -import org.apache.arrow.flight.grpc.GetReadableBuffer; import org.apache.arrow.flight.impl.Flight.FlightData; 
import org.apache.arrow.flight.impl.Flight.FlightDescriptor; import org.apache.arrow.memory.ArrowBuf; @@ -55,10 +55,14 @@ import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.MetadataVersion; import org.apache.arrow.vector.types.pojo.Schema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** The in-memory representation of FlightData used to manage a stream of Arrow messages. */ class ArrowMessage implements AutoCloseable { + private static final Logger LOG = LoggerFactory.getLogger(ArrowMessage.class); + // If true, deserialize Arrow data by giving Arrow a reference to the underlying gRPC buffer // instead of copying the data. Defaults to true. public static final boolean ENABLE_ZERO_COPY_READ; @@ -75,19 +79,10 @@ class ArrowMessage implements AutoCloseable { if (zeroCopyWriteFlag == null) { zeroCopyWriteFlag = System.getenv("ARROW_FLIGHT_ENABLE_ZERO_COPY_WRITE"); } ENABLE_ZERO_COPY_READ = !"false".equalsIgnoreCase(zeroCopyReadFlag); ENABLE_ZERO_COPY_WRITE = "true".equalsIgnoreCase(zeroCopyWriteFlag); } - private static final int DESCRIPTOR_TAG = - (FlightData.FLIGHT_DESCRIPTOR_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; - private static final int BODY_TAG = - (FlightData.DATA_BODY_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; - private static final int HEADER_TAG = - (FlightData.DATA_HEADER_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; - private static final int APP_METADATA_TAG = - (FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; - private static final Marshaller NO_BODY_MARSHALLER = ProtoUtils.marshaller(FlightData.getDefaultInstance()); @@ -212,7 +207,7 @@ public ArrowMessage(FlightDescriptor descriptor) { this.tryZeroCopyWrite = false; } - private ArrowMessage( + ArrowMessage( FlightDescriptor descriptor, MessageMetadataResult message, ArrowBuf 
appMetadata, @@ -280,101 +275,16 @@ public Iterable getBufs() { } private static ArrowMessage frame(BufferAllocator allocator, final InputStream stream) { - - try { - FlightDescriptor descriptor = null; - MessageMetadataResult header = null; - ArrowBuf body = null; - ArrowBuf appMetadata = null; - while (stream.available() > 0) { - final int tagFirstByte = stream.read(); - if (tagFirstByte == -1) { - break; - } - int tag = readRawVarint32(tagFirstByte, stream); - switch (tag) { - case DESCRIPTOR_TAG: - { - int size = readRawVarint32(stream); - byte[] bytes = new byte[size]; - ByteStreams.readFully(stream, bytes); - descriptor = FlightDescriptor.parseFrom(bytes); - break; - } - case HEADER_TAG: - { - int size = readRawVarint32(stream); - byte[] bytes = new byte[size]; - ByteStreams.readFully(stream, bytes); - header = MessageMetadataResult.create(ByteBuffer.wrap(bytes), size); - break; - } - case APP_METADATA_TAG: - { - int size = readRawVarint32(stream); - appMetadata = allocator.buffer(size); - GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, ENABLE_ZERO_COPY_READ); - break; - } - case BODY_TAG: - if (body != null) { - // only read last body. - body.getReferenceManager().release(); - body = null; - } - int size = readRawVarint32(stream); - body = allocator.buffer(size); - GetReadableBuffer.readIntoBuffer(stream, body, size, ENABLE_ZERO_COPY_READ); - break; - - default: - // ignore unknown fields. - } + FlightDataReader reader; + if (ENABLE_ZERO_COPY_READ) { + reader = ArrowBufReader.tryArrowBufReader(allocator, stream); + if (reader != null) { + return reader.toMessage(); } - // Protobuf implementations can omit empty fields, such as body; for some message types, like - // RecordBatch, - // this will fail later as we still expect an empty buffer. In those cases only, fill in an - // empty buffer here - - // in other cases, like Schema, having an unexpected empty buffer will also cause failures. 
- // We don't fill in defaults for fields like header, for which there is no reasonable default, - // or for appMetadata - // or descriptor, which are intended to be empty in some cases. - if (header != null) { - switch (HeaderType.getHeader(header.headerType())) { - case SCHEMA: - // Ignore 0-length buffers in case a Protobuf implementation wrote it out - if (body != null && body.capacity() == 0) { - body.close(); - body = null; - } - break; - case DICTIONARY_BATCH: - case RECORD_BATCH: - // A Protobuf implementation can skip 0-length bodies, so ensure we fill it in here - if (body == null) { - body = allocator.getEmpty(); - } - break; - case NONE: - case TENSOR: - default: - // Do nothing - break; - } - } - return new ArrowMessage(descriptor, header, appMetadata, body); - } catch (Exception ioe) { - throw new RuntimeException(ioe); } - } - - private static int readRawVarint32(InputStream is) throws IOException { - int firstByte = is.read(); - return readRawVarint32(firstByte, is); - } - private static int readRawVarint32(int firstByte, InputStream is) throws IOException { - return CodedInputStream.readRawVarint32(firstByte, is); + reader = new InputStreamReader(allocator, stream); + return reader.toMessage(); } /** diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java new file mode 100644 index 0000000000..2de7aafbc0 --- /dev/null +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.flight; + +import com.google.common.io.ByteStreams; +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.WireFormat; +import io.grpc.Detachable; +import io.grpc.HasByteBuffer; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Objects; +import org.apache.arrow.flight.impl.Flight.FlightData; +import org.apache.arrow.flight.impl.Flight.FlightDescriptor; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.ForeignAllocation; +import org.apache.arrow.memory.util.MemoryUtil; +import org.apache.arrow.vector.ipc.message.MessageMetadataResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Parses FlightData protobuf messages into ArrowMessage objects. + * + *

This class handles parsing from both regular InputStreams (with data copying) and ArrowBuf + * (with zero-copy slicing for large fields like app_metadata and body). + * + *

Small fields (descriptor, header) are always copied. Large fields (app_metadata, body) use + * zero-copy slicing when parsing from ArrowBuf. + */ +final class FlightDataParser { + + // Protobuf wire format tags for FlightData fields + private static final int DESCRIPTOR_TAG = + (FlightData.FLIGHT_DESCRIPTOR_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + private static final int HEADER_TAG = + (FlightData.DATA_HEADER_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + private static final int BODY_TAG = + (FlightData.DATA_BODY_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + private static final int APP_METADATA_TAG = + (FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + + /** Base class for FlightData readers with common parsing logic. */ + abstract static class FlightDataReader { + protected final BufferAllocator allocator; + + protected FlightDescriptor descriptor; + protected MessageMetadataResult header; + protected ArrowBuf appMetadata; + protected ArrowBuf body; + + FlightDataReader(BufferAllocator allocator) { + this.allocator = allocator; + } + + /** Parses the FlightData and returns an ArrowMessage. 
*/ + final ArrowMessage toMessage() { + try { + parseFields(); + ArrowBuf adjustedBody = adjustBodyForHeaderType(); + ArrowMessage message = new ArrowMessage(descriptor, header, appMetadata, adjustedBody); + // Ownership transferred to ArrowMessage + appMetadata = null; + body = null; + return message; + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + cleanup(); + } + } + + private ArrowBuf adjustBodyForHeaderType() { + if (header == null) { + return body; + } + switch (ArrowMessage.HeaderType.getHeader(header.headerType())) { + case SCHEMA: + if (body != null && body.capacity() == 0) { + body.close(); + return null; + } + break; + case DICTIONARY_BATCH: + case RECORD_BATCH: + if (body == null) { + return allocator.getEmpty(); + } + break; + case NONE: + case TENSOR: + default: + break; + } + return body; + } + + private void parseFields() throws IOException { + while (hasRemaining()) { + int tag = readTag(); + if (tag == -1) { + break; + } + int size = readLength(); + switch (tag) { + case DESCRIPTOR_TAG: + { + byte[] bytes = readBytes(size); + descriptor = FlightDescriptor.parseFrom(bytes); + break; + } + case HEADER_TAG: + { + byte[] bytes = readBytes(size); + header = MessageMetadataResult.create(ByteBuffer.wrap(bytes), size); + break; + } + case APP_METADATA_TAG: + { + // Called before reading a new value to handle duplicate protobuf fields + // (last occurrence wins per spec) and prevent memory leaks. + closeAppMetadata(); + appMetadata = readBuffer(size); + break; + } + case BODY_TAG: + { + // Called before reading a new value to handle duplicate protobuf fields + // (last occurrence wins per spec) and prevent memory leaks. + closeBody(); + body = readBuffer(size); + break; + } + default: + // ignore unknown fields + } + } + } + + /** Returns true if there is more data to read. */ + protected abstract boolean hasRemaining() throws IOException; + + /** Reads the next protobuf tag, or -1 if no more data. 
*/ + protected abstract int readTag() throws IOException; + + /** Reads a varint-encoded length. */ + protected abstract int readLength() throws IOException; + + /** Reads the specified number of bytes into a new byte array. */ + protected abstract byte[] readBytes(int size) throws IOException; + + /** Reads the specified number of bytes into an ArrowBuf. */ + protected abstract ArrowBuf readBuffer(int size) throws IOException; + + /** Called in finally block to clean up resources. Subclasses can override to add cleanup. */ + protected void cleanup() { + closeAppMetadata(); + closeBody(); + } + + private void closeAppMetadata() { + if (appMetadata != null) { + appMetadata.close(); + appMetadata = null; + } + } + + private void closeBody() { + if (body != null) { + body.close(); + body = null; + } + } + } + + /** Parses FlightData from an InputStream, copying data into Arrow-managed buffers. */ + static final class InputStreamReader extends FlightDataReader { + private final InputStream stream; + + InputStreamReader(BufferAllocator allocator, InputStream stream) { + super(allocator); + this.stream = stream; + } + + @Override + protected boolean hasRemaining() throws IOException { + return stream.available() > 0; + } + + @Override + protected int readTag() throws IOException { + int tagFirstByte = stream.read(); + if (tagFirstByte == -1) { + return -1; + } + return CodedInputStream.readRawVarint32(tagFirstByte, stream); + } + + @Override + protected int readLength() throws IOException { + int firstByte = stream.read(); + return CodedInputStream.readRawVarint32(firstByte, stream); + } + + @Override + protected byte[] readBytes(int size) throws IOException { + byte[] bytes = new byte[size]; + ByteStreams.readFully(stream, bytes); + return bytes; + } + + @Override + protected ArrowBuf readBuffer(int size) throws IOException { + ArrowBuf buf = allocator.buffer(size); + byte[] heapBytes = new byte[size]; + ByteStreams.readFully(stream, heapBytes); + 
buf.writeBytes(heapBytes); + buf.writerIndex(size); + return buf; + } + } + + /** Parses FlightData from an ArrowBuf, using zero-copy slicing for large fields. */ + static final class ArrowBufReader extends FlightDataReader { + private static final Logger LOG = LoggerFactory.getLogger(ArrowBufReader.class); + + private final ArrowBuf backingBuffer; + private final CodedInputStream codedInput; + + ArrowBufReader(BufferAllocator allocator, ArrowBuf backingBuffer) { + super(allocator); + this.backingBuffer = backingBuffer; + ByteBuffer buffer = backingBuffer.nioBuffer(0, (int) backingBuffer.capacity()); + this.codedInput = CodedInputStream.newInstance(buffer); + } + + static ArrowBufReader tryArrowBufReader(BufferAllocator allocator, InputStream stream) { + if (!(stream instanceof Detachable) || !(stream instanceof HasByteBuffer)) { + return null; + } + + HasByteBuffer hasByteBuffer = (HasByteBuffer) stream; + if (!hasByteBuffer.byteBufferSupported()) { + return null; + } + + ByteBuffer peekBuffer = hasByteBuffer.getByteBuffer(); + if (peekBuffer == null || !peekBuffer.isDirect()) { + return null; + } + + try { + int available = stream.available(); + if (available > 0 && peekBuffer.remaining() < available) { + return null; + } + } catch (IOException ioe) { + return null; + } + + InputStream detachedStream = ((Detachable) stream).detach(); + ByteBuffer detachedBuffer = ((HasByteBuffer) detachedStream).getByteBuffer(); + + long bufferAddress = MemoryUtil.getByteBufferAddress(detachedBuffer); + int bufferSize = Objects.requireNonNull(detachedBuffer).remaining(); + + ForeignAllocation foreignAllocation = + new ForeignAllocation(bufferSize, bufferAddress + detachedBuffer.position()) { + @Override + protected void release0() { + closeQuietly(detachedStream); + } + }; + + try { + ArrowBuf backingBuffer = allocator.wrapForeignAllocation(foreignAllocation); + return new ArrowBufReader(allocator, backingBuffer); + } catch (Throwable t) { + closeQuietly(detachedStream); + throw 
t; + } + } + + private static void closeQuietly(InputStream stream) { + if (stream != null) { + try { + stream.close(); + } catch (IOException e) { + LOG.debug("Error closing detached gRPC stream", e); + } + } + } + + @Override + protected void cleanup() { + super.cleanup(); + backingBuffer.close(); + } + + @Override + protected boolean hasRemaining() throws IOException { + return !codedInput.isAtEnd(); + } + + @Override + protected int readTag() throws IOException { + int tag = codedInput.readTag(); + return tag == 0 ? -1 : tag; + } + + @Override + protected int readLength() throws IOException { + return codedInput.readRawVarint32(); + } + + @Override + protected byte[] readBytes(int size) throws IOException { + // Reads size bytes and creates a copy + return codedInput.readRawBytes(size); + } + + @Override + protected ArrowBuf readBuffer(int size) throws IOException { + // CodedInputStream advances the shared ByteBuffer; use its read count for zero-copy slicing. + int offset = codedInput.getTotalBytesRead(); + codedInput.skipRawBytes(size); + backingBuffer.getReferenceManager().retain(); + return backingBuffer.slice(offset, size); + } + } +} diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java deleted file mode 100644 index 45c32a86c6..0000000000 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.flight.grpc; - -import com.google.common.base.Throwables; -import com.google.common.io.ByteStreams; -import io.grpc.internal.ReadableBuffer; -import java.io.IOException; -import java.io.InputStream; -import java.lang.reflect.Field; -import org.apache.arrow.memory.ArrowBuf; - -/** - * Enable access to ReadableBuffer directly to copy data from a BufferInputStream into a target - * ByteBuffer/ByteBuf. - * - *

This could be solved by BufferInputStream exposing Drainable. - */ -public class GetReadableBuffer { - - private static final Field READABLE_BUFFER; - private static final Class BUFFER_INPUT_STREAM; - - static { - Field tmpField = null; - Class tmpClazz = null; - try { - Class clazz = Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream"); - - Field f = clazz.getDeclaredField("buffer"); - f.setAccessible(true); - // don't set until we've gotten past all exception cases. - tmpField = f; - tmpClazz = clazz; - } catch (Exception e) { - new RuntimeException("Failed to initialize GetReadableBuffer, falling back to slow path", e) - .printStackTrace(); - } - READABLE_BUFFER = tmpField; - BUFFER_INPUT_STREAM = tmpClazz; - } - - /** - * Extracts the ReadableBuffer for the given input stream. - * - * @param is Must be an instance of io.grpc.internal.ReadableBuffers$BufferInputStream or null - * will be returned. - */ - public static ReadableBuffer getReadableBuffer(InputStream is) { - - if (BUFFER_INPUT_STREAM == null || !is.getClass().equals(BUFFER_INPUT_STREAM)) { - return null; - } - - try { - return (ReadableBuffer) READABLE_BUFFER.get(is); - } catch (Exception ex) { - throw Throwables.propagate(ex); - } - } - - /** - * Helper method to read a gRPC-provided InputStream into an ArrowBuf. - * - * @param stream The stream to read from. Should be an instance of {@link #BUFFER_INPUT_STREAM}. - * @param buf The buffer to read into. - * @param size The number of bytes to read. - * @param fastPath Whether to enable the fast path (i.e. detect whether the stream is a {@link - * #BUFFER_INPUT_STREAM}). - * @throws IOException if there is an error reading form the stream - */ - public static void readIntoBuffer( - final InputStream stream, final ArrowBuf buf, final int size, final boolean fastPath) - throws IOException { - ReadableBuffer readableBuffer = fastPath ? 
getReadableBuffer(stream) : null; - if (readableBuffer != null) { - readableBuffer.readBytes(buf.nioBuffer(0, size)); - } else { - byte[] heapBytes = new byte[size]; - ByteStreams.readFully(stream, heapBytes); - buf.writeBytes(heapBytes); - } - buf.writerIndex(size); - } -} diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java new file mode 100644 index 0000000000..df9852d02e --- /dev/null +++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.arrow.flight; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.google.common.collect.Iterables; +import com.google.common.io.ByteStreams; +import com.google.protobuf.ByteString; +import com.google.protobuf.CodedOutputStream; +import io.grpc.Detachable; +import io.grpc.HasByteBuffer; +import io.grpc.protobuf.ProtoUtils; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import org.apache.arrow.flight.impl.Flight.FlightData; +import org.apache.arrow.flight.impl.Flight.FlightDescriptor; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests FlightData parsing including duplicate field handling, well-formed messages, and zero-copy + * behavior. Covers both InputStream (with copying) and ArrowBuf (zero-copy) parsing paths. Verifies + * that duplicate protobuf fields use last-occurrence-wins semantics without memory leaks. 
+ */ +public class TestArrowMessageParse { + + private BufferAllocator allocator; + + @BeforeEach + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @AfterEach + public void tearDown() { + allocator.close(); + } + + /** Verifies duplicate app_metadata fields via InputStream path use last-occurrence-wins. */ + @Test + public void testDuplicateAppMetadataInputStream() throws Exception { + byte[] firstAppMetadata = new byte[] {1, 2, 3}; + byte[] secondAppMetadata = new byte[] {4, 5, 6, 7, 8}; + + byte[] serialized = + buildFlightDataDescriptors( + List.of( + Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, firstAppMetadata), + Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, secondAppMetadata))); + InputStream stream = new ByteArrayInputStream(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + ArrowBuf appMetadata = message.getApplicationMetadata(); + assertNotNull(appMetadata); + // Use readableBytes() instead of capacity() since allocator may round up + assertEquals(secondAppMetadata.length, appMetadata.readableBytes()); + + byte[] actual = new byte[secondAppMetadata.length]; + appMetadata.getBytes(0, actual); + assertArrayEquals(secondAppMetadata, actual); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + /** + * Verifies duplicate app_metadata fields via zero-copy ArrowBuf path use last-occurrence-wins. 
+ */ + @Test + public void testDuplicateAppMetadataArrowBuf() throws Exception { + byte[] firstAppMetadata = new byte[] {1, 2, 3}; + byte[] secondAppMetadata = new byte[] {4, 5, 6, 7, 8}; + + // Verify clean start + assertEquals(0, allocator.getAllocatedMemory()); + + byte[] serialized = + buildFlightDataDescriptors( + List.of( + Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, firstAppMetadata), + Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, secondAppMetadata))); + InputStream stream = MockGrpcInputStream.ofDirectBuffer(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + ArrowBuf appMetadata = message.getApplicationMetadata(); + assertNotNull(appMetadata); + assertEquals(secondAppMetadata.length, appMetadata.readableBytes()); + + byte[] actual = new byte[secondAppMetadata.length]; + appMetadata.getBytes(0, actual); + assertArrayEquals(secondAppMetadata, actual); + + // Zero-copy: only the backing buffer (serialized message) should be allocated + assertEquals(serialized.length, allocator.getAllocatedMemory()); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + /** Verifies duplicate body fields via InputStream path use last-occurrence-wins. 
*/ + @Test + public void testDuplicateBodyInputStream() throws Exception { + byte[] firstBody = new byte[] {10, 20, 30}; + byte[] secondBody = new byte[] {40, 50, 60, 70}; + + byte[] serialized = + buildFlightDataDescriptors( + List.of( + Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, firstBody), + Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, secondBody))); + InputStream stream = new ByteArrayInputStream(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); + assertNotNull(body); + assertEquals(secondBody.length, body.readableBytes()); + + byte[] actual = new byte[secondBody.length]; + body.getBytes(0, actual); + assertArrayEquals(secondBody, actual); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + /** Verifies duplicate body fields via zero-copy ArrowBuf path use last-occurrence-wins. */ + @Test + public void testDuplicateBodyArrowBuf() throws Exception { + byte[] firstBody = new byte[] {10, 20, 30}; + byte[] secondBody = new byte[] {40, 50, 60, 70}; + + // Verify clean start + assertEquals(0, allocator.getAllocatedMemory()); + + byte[] serialized = + buildFlightDataDescriptors( + List.of( + Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, firstBody), + Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, secondBody))); + InputStream stream = MockGrpcInputStream.ofDirectBuffer(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); + assertNotNull(body); + assertEquals(secondBody.length, body.readableBytes()); + + byte[] actual = new byte[secondBody.length]; + body.getBytes(0, actual); + assertArrayEquals(secondBody, actual); + + // Zero-copy: only the backing buffer (serialized message) should be allocated + assertEquals(serialized.length, allocator.getAllocatedMemory()); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + /** Verifies 
well-formed FlightData message parsing via InputStream path. */ + @Test + public void testFieldsInputStream() throws Exception { + byte[] appMetadataBytes = new byte[] {100, 101, 102}; + byte[] bodyBytes = new byte[] {50, 51, 52, 53, 54}; + FlightDescriptor expectedDescriptor = createTestDescriptor(); + + byte[] serialized = buildFlightDataWithBothFields(appMetadataBytes, bodyBytes); + InputStream stream = new ByteArrayInputStream(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + // Verify descriptor + assertEquals(expectedDescriptor, message.getDescriptor()); + + // Verify header is present (Schema message type) + assertEquals(ArrowMessage.HeaderType.SCHEMA, message.getMessageType()); + + // Verify app metadata + ArrowBuf appMetadata = message.getApplicationMetadata(); + assertNotNull(appMetadata); + assertEquals(appMetadataBytes.length, appMetadata.readableBytes()); + byte[] actualAppMetadata = new byte[appMetadataBytes.length]; + appMetadata.getBytes(0, actualAppMetadata); + assertArrayEquals(appMetadataBytes, actualAppMetadata); + + // Verify body + ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); + assertNotNull(body); + assertEquals(bodyBytes.length, body.readableBytes()); + byte[] actualBody = new byte[bodyBytes.length]; + body.getBytes(0, actualBody); + assertArrayEquals(bodyBytes, actualBody); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + /** Verifies well-formed FlightData message parsing via zero-copy ArrowBuf path. 
*/ + @Test + public void testFieldsArrowBuf() throws Exception { + byte[] appMetadataBytes = new byte[] {100, 101, 102}; + byte[] bodyBytes = new byte[] {50, 51, 52, 53, 54}; + FlightDescriptor expectedDescriptor = createTestDescriptor(); + + assertEquals(0, allocator.getAllocatedMemory()); + + byte[] serialized = buildFlightDataWithBothFields(appMetadataBytes, bodyBytes); + InputStream stream = MockGrpcInputStream.ofDirectBuffer(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + // Verify descriptor + assertEquals(expectedDescriptor, message.getDescriptor()); + + // Verify header is present (Schema message type) + assertEquals(ArrowMessage.HeaderType.SCHEMA, message.getMessageType()); + + // Verify app metadata + ArrowBuf appMetadata = message.getApplicationMetadata(); + assertNotNull(appMetadata); + assertEquals(appMetadataBytes.length, appMetadata.readableBytes()); + byte[] actualAppMetadata = new byte[appMetadataBytes.length]; + appMetadata.getBytes(0, actualAppMetadata); + assertArrayEquals(appMetadataBytes, actualAppMetadata); + + // Verify body + ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); + assertNotNull(body); + assertEquals(bodyBytes.length, body.readableBytes()); + byte[] actualBody = new byte[bodyBytes.length]; + body.getBytes(0, actualBody); + assertArrayEquals(bodyBytes, actualBody); + + // Zero-copy: only the backing buffer (serialized message) should be allocated + assertEquals(serialized.length, allocator.getAllocatedMemory()); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + /** Verifies that heap buffers fall back to InputStream path without calling detach(). 
*/ + @Test + public void testHeapBufferFallbackDoesNotDetach() throws Exception { + byte[] appMetadataBytes = new byte[] {8, 9}; + byte[] bodyBytes = new byte[] {10, 11, 12}; + + byte[] serialized = buildFlightDataWithBothFields(appMetadataBytes, bodyBytes); + MockGrpcInputStream stream = MockGrpcInputStream.ofHeapBuffer(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + assertNotNull(message.getDescriptor()); + assertEquals(0, stream.getDetachCount()); + } + } + + /** Verifies fallback to InputStream path when getByteBuffer() returns null. */ + @Test + public void testNullByteBufferFallbackToInputStream() throws Exception { + byte[] appMetadataBytes = new byte[] {20, 21, 22}; + byte[] bodyBytes = new byte[] {30, 31, 32, 33}; + FlightDescriptor expectedDescriptor = createTestDescriptor(); + + byte[] serialized = buildFlightDataWithBothFields(appMetadataBytes, bodyBytes); + MockGrpcInputStream stream = new MockGrpcInputStream(ByteBuffer.wrap(serialized), false); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + assertEquals(expectedDescriptor, message.getDescriptor()); + + ArrowBuf appMetadata = message.getApplicationMetadata(); + assertNotNull(appMetadata); + byte[] actualAppMetadata = new byte[appMetadataBytes.length]; + appMetadata.getBytes(0, actualAppMetadata); + assertArrayEquals(appMetadataBytes, actualAppMetadata); + + ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); + assertNotNull(body); + byte[] actualBody = new byte[bodyBytes.length]; + body.getBytes(0, actualBody); + assertArrayEquals(bodyBytes, actualBody); + + assertEquals(0, stream.getDetachCount()); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + // Helper methods to build complete FlightData messages + + private FlightDescriptor createTestDescriptor() { + return FlightDescriptor.newBuilder() + .setType(FlightDescriptor.DescriptorType.PATH) + .addPath("test") + .addPath("path") 
+ .build(); + } + + private byte[] createSchemaHeader() { + Schema schema = + new Schema( + Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8()))); + ByteBuffer headerBuffer = MessageSerializer.serializeMetadata(schema, IpcOption.DEFAULT); + byte[] headerBytes = new byte[headerBuffer.remaining()]; + headerBuffer.get(headerBytes); + return headerBytes; + } + + private byte[] buildFlightDataWithBothFields(byte[] appMetadata, byte[] body) throws IOException { + FlightData flightData = + FlightData.newBuilder() + .setFlightDescriptor(createTestDescriptor()) + .setDataHeader(ByteString.copyFrom(createSchemaHeader())) + .setAppMetadata(ByteString.copyFrom(appMetadata)) + .setDataBody(ByteString.copyFrom(body)) + .build(); + try (InputStream grpcStream = + ProtoUtils.marshaller(FlightData.getDefaultInstance()).stream(flightData)) { + return ByteStreams.toByteArray(grpcStream); + } + } + + // Helper methods to build FlightData messages with duplicate fields + + private byte[] buildFlightDataDescriptors(List> descriptors) + throws IOException { + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + CodedOutputStream cos = CodedOutputStream.newInstance(baos); + + for (Pair descriptor : descriptors) { + cos.writeBytes(descriptor.getKey(), ByteString.copyFrom(descriptor.getValue())); + } + cos.flush(); + return baos.toByteArray(); + } + + /** Mock InputStream implementing gRPC's Detachable and HasByteBuffer for testing zero-copy. 
*/
  private static class MockGrpcInputStream extends InputStream
      implements Detachable, HasByteBuffer {
    // Whether this stream reports ByteBuffer support to its consumer.
    private final boolean byteBufferSupported;
    // Unread content; set to null by detach() and close().
    private ByteBuffer buffer;
    // How many times detach() has been called on this instance.
    private int detachCount;

    private MockGrpcInputStream(ByteBuffer buffer, boolean byteBufferSupported) {
      this.byteBufferSupported = byteBufferSupported;
      this.buffer = buffer;
    }

    /** Creates a ByteBuffer-capable stream backed by a direct buffer holding {@code data}. */
    static MockGrpcInputStream ofDirectBuffer(byte[] data) {
      ByteBuffer direct = ByteBuffer.allocateDirect(data.length);
      direct.put(data);
      // Switch from writing to reading: position 0, limit at the end of the data.
      direct.flip();
      return new MockGrpcInputStream(direct, true);
    }

    /** Creates a ByteBuffer-capable stream backed by a heap buffer wrapping {@code data}. */
    static MockGrpcInputStream ofHeapBuffer(byte[] data) {
      ByteBuffer heap = ByteBuffer.wrap(data);
      return new MockGrpcInputStream(heap, true);
    }

    @Override
    public boolean byteBufferSupported() {
      return byteBufferSupported;
    }

    /** Returns the backing buffer itself, or null when ByteBuffer support is switched off. */
    @Override
    public ByteBuffer getByteBuffer() {
      if (!byteBufferSupported) {
        return null;
      }
      return buffer;
    }

    /** Counts the call and moves the backing buffer into a new stream, emptying this one. */
    @Override
    public InputStream detach() {
      detachCount++;
      ByteBuffer handedOff = buffer;
      buffer = null;
      return new MockGrpcInputStream(handedOff, byteBufferSupported);
    }

    int getDetachCount() {
      return detachCount;
    }

    @Override
    public int read() {
      if (buffer == null || !buffer.hasRemaining()) {
        return -1;
      }
      // Mask to return the byte as an unsigned value in 0..255.
      return buffer.get() & 0xFF;
    }

    @Override
    public int available() {
      if (buffer == null) {
        return 0;
      }
      return buffer.remaining();
    }

    @Override
    public void close() {
      buffer = null;
    }
  }
}