diff --git a/ruby/red-arrow-format/lib/arrow-format/array.rb b/ruby/red-arrow-format/lib/arrow-format/array.rb index 6474869c30c..be30bffe48e 100644 --- a/ruby/red-arrow-format/lib/arrow-format/array.rb +++ b/ruby/red-arrow-format/lib/arrow-format/array.rb @@ -79,54 +79,34 @@ def initialize(type, size, validity_buffer, values_buffer) super(type, size, validity_buffer) @values_buffer = values_buffer end - end - class Int8Array < IntArray def to_a - apply_validity(@values_buffer.values(:S8, 0, @size)) + apply_validity(@values_buffer.values(@type.buffer_type, 0, @size)) end end + class Int8Array < IntArray + end + class UInt8Array < IntArray - def to_a - apply_validity(@values_buffer.values(:U8, 0, @size)) - end end class Int16Array < IntArray - def to_a - apply_validity(@values_buffer.values(:s16, 0, @size)) - end end class UInt16Array < IntArray - def to_a - apply_validity(@values_buffer.values(:u16, 0, @size)) - end end class Int32Array < IntArray - def to_a - apply_validity(@values_buffer.values(:s32, 0, @size)) - end end class UInt32Array < IntArray - def to_a - apply_validity(@values_buffer.values(:u32, 0, @size)) - end end class Int64Array < IntArray - def to_a - apply_validity(@values_buffer.values(:s64, 0, @size)) - end end class UInt64Array < IntArray - def to_a - apply_validity(@values_buffer.values(:u64, 0, @size)) - end end class FloatingPointArray < Array @@ -393,6 +373,27 @@ def to_a end end + class MapArray < VariableSizeListArray + def to_a + super.collect do |entries| + if entries.nil? + entries + else + hash = {} + entries.each do |key, value| + hash[key] = value + end + hash + end + end + end + + private + def offset_type + :s32 # TODO: big endian support + end + end + class UnionArray < Array def initialize(type, size, types_buffer, children) super(type, size, nil) @@ -432,24 +433,27 @@ def to_a end end - class MapArray < VariableSizeListArray + class DictionaryArray < Array + def initialize(type, size, validity_buffer, indices_buffer, dictionary) + super(type, size, validity_buffer) + @indices_buffer = indices_buffer + @dictionary = dictionary + end + def to_a - super.collect do |entries| - if entries.nil? - entries + values = [] + @dictionary.each do |dictionary_chunk| + values.concat(dictionary_chunk.to_a) + end + buffer_type = @type.index_type.buffer_type + indices = apply_validity(@indices_buffer.values(buffer_type, 0, @size)) + indices.collect do |index| + if index.nil? + nil else - hash = {} - entries.each do |key, value| - hash[key] = value - end - hash + values[index] end end end - - private - def offset_type - :s32 # TODO: big endian support - end end end diff --git a/ruby/red-arrow-format/lib/arrow-format/field.rb b/ruby/red-arrow-format/lib/arrow-format/field.rb index ac531750f76..090113cfe6b 100644 --- a/ruby/red-arrow-format/lib/arrow-format/field.rb +++ b/ruby/red-arrow-format/lib/arrow-format/field.rb @@ -18,10 +18,12 @@ module ArrowFormat class Field attr_reader :name attr_reader :type - def initialize(name, type, nullable) + attr_reader :dictionary_id + def initialize(name, type, nullable, dictionary_id) @name = name @type = type @nullable = nullable + @dictionary_id = dictionary_id end def nullable? diff --git a/ruby/red-arrow-format/lib/arrow-format/file-reader.rb b/ruby/red-arrow-format/lib/arrow-format/file-reader.rb index bf50bfd1cd3..545638ca902 100644 --- a/ruby/red-arrow-format/lib/arrow-format/file-reader.rb +++ b/ruby/red-arrow-format/lib/arrow-format/file-reader.rb @@ -49,17 +49,65 @@ def initialize(input) validate @footer = read_footer - @record_batches = @footer.record_batches + @record_batch_blocks = @footer.record_batches @schema = read_schema(@footer.schema) + @dictionaries = read_dictionaries end def n_record_batches - @record_batches.size + @record_batch_blocks.size end def read(i) - block = @record_batches[i] + fb_message, body = read_block(@record_batch_blocks[i]) + fb_header = fb_message.header + unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch) + raise FileReadError.new(@buffer, + "Not a record batch message: #{i}: " + + fb_header.class.name) + end + read_record_batch(fb_header, @schema, body) + end + + def each + return to_enum(__method__) {n_record_batches} unless block_given? + + @record_batch_blocks.size.times do |i| + yield(read(i)) + end + end + + private + def validate + minimum_size = STREAMING_FORMAT_START_OFFSET + + FOOTER_SIZE_SIZE + + END_MARKER_SIZE + if @buffer.size < minimum_size + raise FileReadError.new(@buffer, + "Input must be larger than or equal to " + + "#{minimum_size}: #{@buffer.size}") + end + + start_marker = @buffer.slice(0, START_MARKER_SIZE) + if start_marker != MAGIC_BUFFER + raise FileReadError.new(@buffer, "No start marker") + end + end_marker = @buffer.slice(@buffer.size - END_MARKER_SIZE, + END_MARKER_SIZE) + if end_marker != MAGIC_BUFFER + raise FileReadError.new(@buffer, "No end marker") + end + end + + def read_footer + footer_size_offset = @buffer.size - END_MARKER_SIZE - FOOTER_SIZE_SIZE + footer_size = @buffer.get_value(FOOTER_SIZE_FORMAT, footer_size_offset) + footer_data = @buffer.slice(footer_size_offset - footer_size, + footer_size) + Org::Apache::Arrow::Flatbuf::Footer.new(footer_data) + end + def read_block(block) offset = block.offset # If we can report property error information, we can use @@ -101,54 +149,65 @@ def read(i) metadata = @buffer.slice(offset, metadata_length) fb_message = Org::Apache::Arrow::Flatbuf::Message.new(metadata) - fb_header = fb_message.header - unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch) - raise FileReadError.new(@buffer, - "Not a record batch message: #{i}: " + - fb_header.class.name) - end offset += metadata_length body = @buffer.slice(offset, block.body_length) - read_record_batch(fb_header, @schema, body) - end - def each - return to_enum(__method__) {n_record_batches} unless block_given? - - @record_batches.size.times do |i| - yield(read(i)) - end + [fb_message, body] end - private - def validate - minimum_size = STREAMING_FORMAT_START_OFFSET + - FOOTER_SIZE_SIZE + - END_MARKER_SIZE - if @buffer.size < minimum_size - raise FileReadError.new(@buffer, - "Input must be larger than or equal to " + - "#{minimum_size}: #{@buffer.size}") - end + def read_dictionaries + dictionary_blocks = @footer.dictionaries + return nil if dictionary_blocks.nil? - start_marker = @buffer.slice(0, START_MARKER_SIZE) - if start_marker != MAGIC_BUFFER - raise FileReadError.new(@buffer, "No start marker") + dictionary_fields = {} + @schema.fields.each do |field| + next unless field.type.is_a?(DictionaryType) + dictionary_fields[field.dictionary_id] = field end - end_marker = @buffer.slice(@buffer.size - END_MARKER_SIZE, - END_MARKER_SIZE) - if end_marker != MAGIC_BUFFER - raise FileReadError.new(@buffer, "No end marker") + + dictionaries = {} + dictionary_blocks.each do |block| + fb_message, body = read_block(block) + fb_header = fb_message.header + unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::DictionaryBatch) + raise FileReadError.new(@buffer, + "Not a dictionary batch message: " + + fb_header.inspect) + end + + id = fb_header.id + if fb_header.delta? + unless dictionaries.key?(id) + raise FileReadError.new(@buffer, + "A delta dictionary batch message " + + "must exist after a non delta " + + "dictionary batch message: " + + fb_header.inspect) + end + else + if dictionaries.key?(id) + raise FileReadError.new(@buffer, + "Multiple non delta dictionary batch " + + "messages for the same ID is invalid: " + + fb_header.inspect) + end + end + + value_type = dictionary_fields[id].type.value_type + schema = Schema.new([Field.new("dummy", value_type, true, nil)]) + record_batch = read_record_batch(fb_header.data, schema, body) + if fb_header.delta? + dictionaries[id] << record_batch.columns[0] + else + dictionaries[id] = [record_batch.columns[0]] + end end + dictionaries end - def read_footer - footer_size_offset = @buffer.size - END_MARKER_SIZE - FOOTER_SIZE_SIZE - footer_size = @buffer.get_value(FOOTER_SIZE_FORMAT, footer_size_offset) - footer_data = @buffer.slice(footer_size_offset - footer_size, - footer_size) - Org::Apache::Arrow::Flatbuf::Footer.new(footer_data) + def find_dictionary(id) + @dictionaries[id] end end end diff --git a/ruby/red-arrow-format/lib/arrow-format/readable.rb b/ruby/red-arrow-format/lib/arrow-format/readable.rb index 5a247c822a4..11db4685d90 100644 --- a/ruby/red-arrow-format/lib/arrow-format/readable.rb +++ b/ruby/red-arrow-format/lib/arrow-format/readable.rb @@ -26,6 +26,8 @@ require_relative "org/apache/arrow/flatbuf/date" require_relative "org/apache/arrow/flatbuf/date_unit" require_relative "org/apache/arrow/flatbuf/decimal" +require_relative "org/apache/arrow/flatbuf/dictionary_encoding" +require_relative "org/apache/arrow/flatbuf/dictionary_batch" require_relative "org/apache/arrow/flatbuf/duration" require_relative "org/apache/arrow/flatbuf/fixed_size_binary" require_relative "org/apache/arrow/flatbuf/floating_point" @@ -40,11 +42,12 @@ require_relative "org/apache/arrow/flatbuf/message" require_relative "org/apache/arrow/flatbuf/null" require_relative "org/apache/arrow/flatbuf/precision" +require_relative "org/apache/arrow/flatbuf/record_batch" require_relative "org/apache/arrow/flatbuf/schema" require_relative "org/apache/arrow/flatbuf/struct_" require_relative "org/apache/arrow/flatbuf/time" -require_relative "org/apache/arrow/flatbuf/timestamp" require_relative "org/apache/arrow/flatbuf/time_unit" +require_relative "org/apache/arrow/flatbuf/timestamp" require_relative "org/apache/arrow/flatbuf/union" require_relative "org/apache/arrow/flatbuf/union_mode" require_relative "org/apache/arrow/flatbuf/utf8" @@ -67,32 +70,7 @@ def read_field(fb_field) when Org::Apache::Arrow::Flatbuf::Bool type = BooleanType.singleton when Org::Apache::Arrow::Flatbuf::Int - case fb_type.bit_width - when 8 - if fb_type.signed? - type = Int8Type.singleton - else - type = UInt8Type.singleton - end - when 16 - if fb_type.signed? - type = Int16Type.singleton - else - type = UInt16Type.singleton - end - when 32 - if fb_type.signed? - type = Int32Type.singleton - else - type = UInt32Type.singleton - end - when 64 - if fb_type.signed? - type = Int64Type.singleton - else - type = UInt64Type.singleton - end - end + type = read_type_int(fb_type) when Org::Apache::Arrow::Flatbuf::FloatingPoint case fb_type.precision when Org::Apache::Arrow::Flatbuf::Precision::SINGLE @@ -173,14 +151,52 @@ def read_field(fb_field) type = Decimal128Type.new(fb_type.precision, fb_type.scale) end end - Field.new(fb_field.name, type, fb_field.nullable?) + + dictionary = fb_field.dictionary + if dictionary + dictionary_id = dictionary.id + index_type = read_type_int(dictionary.index_type) + type = DictionaryType.new(index_type, type, dictionary.ordered?) + else + dictionary_id = nil + end + Field.new(fb_field.name, type, fb_field.nullable?, dictionary_id) + end + + def read_type_int(fb_type) + case fb_type.bit_width + when 8 + if fb_type.signed? + Int8Type.singleton + else + UInt8Type.singleton + end + when 16 + if fb_type.signed? + Int16Type.singleton + else + UInt16Type.singleton + end + when 32 + if fb_type.signed? + Int32Type.singleton + else + UInt32Type.singleton + end + when 64 + if fb_type.signed? + Int64Type.singleton + else + UInt64Type.singleton + end + end end def read_record_batch(fb_record_batch, schema, body) n_rows = fb_record_batch.length nodes = fb_record_batch.nodes buffers = fb_record_batch.buffers - columns = @schema.fields.collect do |field| + columns = schema.fields.collect do |field| read_column(field, nodes, buffers, body) end RecordBatch.new(schema, n_rows, columns) @@ -242,6 +258,11 @@ def read_column(field, nodes, buffers, body) read_column(child, nodes, buffers, body) end field.type.build_array(length, types, children) + when DictionaryType + indices_buffer = buffers.shift + indices = body.slice(indices_buffer.offset, indices_buffer.length) + dictionary = find_dictionary(field.dictionary_id) + field.type.build_array(length, validity, indices, dictionary) end end end diff --git a/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb b/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb index ae231fccbc6..8682f3e826b 100644 --- a/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb +++ b/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb @@ -151,6 +151,8 @@ def initialize(&on_read) end @state = :schema @schema = nil + @dictionaries = nil + @dictionary_fields = nil end def next_required_size @@ -170,8 +172,23 @@ def process_message(message, body) case @state when :schema process_schema_message(message, body) - when :record_batch - process_record_batch_message(message, body) + when :initial_dictionaries + header = message.header + unless header.is_a?(Org::Apache::Arrow::Flatbuf::DictionaryBatch) + raise ReadError.new("Not a dictionary batch message: " + + header.inspect) + end + process_dictionary_batch_message(message, body) + if @dictionaries.size == @dictionary_fields.size + @state = :data + end + when :data + case message.header + when Org::Apache::Arrow::Flatbuf::DictionaryBatch + process_dictionary_batch_message(message, body) + when Org::Apache::Arrow::Flatbuf::RecordBatch + process_record_batch_message(message, body) + end end end @@ -183,17 +200,43 @@ def process_schema_message(message, body) end @schema = read_schema(header) - # TODO: initial dictionaries support - @state = :record_batch + @dictionaries = {} + @dictionary_fields = {} + @schema.fields.each do |field| + next unless field.type.is_a?(DictionaryType) + @dictionary_fields[field.dictionary_id] = field + end + if @dictionaries.size < @dictionary_fields.size + @state = :initial_dictionaries + else + @state = :data + end end - def process_record_batch_message(message, body) + def process_dictionary_batch_message(message, body) header = message.header - unless header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch) - raise ReadError.new("Not a record batch message: " + + if @state == :initial_dictionaries and header.delta? + raise ReadError.new("An initial dictionary batch message must be " + + "a non delta dictionary batch message: " + header.inspect) end + field = @dictionary_fields[header.id] + value_type = field.type.value_type + schema = Schema.new([Field.new("dummy", value_type, true, nil)]) + record_batch = read_record_batch(header.data, schema, body) + if header.delta? + @dictionaries[header.id] << record_batch.columns[0] + else + @dictionaries[header.id] = [record_batch.columns[0]] + end + end + def find_dictionary(id) + @dictionaries[id] + end + + def process_record_batch_message(message, body) + header = message.header @on_read.call(read_record_batch(header, @schema, body)) end end diff --git a/ruby/red-arrow-format/lib/arrow-format/type.rb b/ruby/red-arrow-format/lib/arrow-format/type.rb index d6d8b7bb81a..a8a73f8ea79 100644 --- a/ruby/red-arrow-format/lib/arrow-format/type.rb +++ b/ruby/red-arrow-format/lib/arrow-format/type.rb @@ -78,6 +78,10 @@ def name "Int8" end + def buffer_type + :S8 + end + def build_array(size, validity_buffer, values_buffer) Int8Array.new(self, size, validity_buffer, values_buffer) end @@ -98,6 +102,10 @@ def name "UInt8" end + def buffer_type + :U8 + end + def build_array(size, validity_buffer, values_buffer) UInt8Array.new(self, size, validity_buffer, values_buffer) end @@ -118,6 +126,10 @@ def name "Int16" end + def buffer_type + :s16 + end + def build_array(size, validity_buffer, values_buffer) Int16Array.new(self, size, validity_buffer, values_buffer) end @@ -138,6 +150,10 @@ def name "UInt16" end + def buffer_type + :u16 + end + def build_array(size, validity_buffer, values_buffer) UInt16Array.new(self, size, validity_buffer, values_buffer) end @@ -158,6 +174,10 @@ def name "Int32" end + def buffer_type + :s32 + end + def build_array(size, validity_buffer, values_buffer) Int32Array.new(self, size, validity_buffer, values_buffer) end @@ -178,6 +198,10 @@ def name "UInt32" end + def buffer_type + :u32 + end + def build_array(size, validity_buffer, values_buffer) UInt32Array.new(self, size, validity_buffer, values_buffer) end @@ -198,6 +222,10 @@ def name "Int64" end + def buffer_type + :s64 + end + def build_array(size, validity_buffer, values_buffer) Int64Array.new(self, size, validity_buffer, values_buffer) end @@ -218,6 +246,10 @@ def name "UInt64" end + def buffer_type + :u64 + end + def build_array(size, validity_buffer, values_buffer) UInt64Array.new(self, size, validity_buffer, values_buffer) end @@ -631,4 +663,28 @@ def build_array(size, types_buffer, children) SparseUnionArray.new(self, size, types_buffer, children) end end + + class DictionaryType < Type + attr_reader :index_type + attr_reader :value_type + attr_reader :ordered + def initialize(index_type, value_type, ordered) + super() + @index_type = index_type + @value_type = value_type + @ordered = ordered + end + + def name + "Dictionary" + end + + def build_array(size, validity_buffer, indices_buffer, dictionary) + DictionaryArray.new(self, + size, + validity_buffer, + indices_buffer, + dictionary) + end + end end diff --git a/ruby/red-arrow-format/test/test-reader.rb b/ruby/red-arrow-format/test/test-reader.rb index cddcea484fb..654158f0478 100644 --- a/ruby/red-arrow-format/test/test-reader.rb +++ b/ruby/red-arrow-format/test/test-reader.rb @@ -824,6 +824,19 @@ def test_read read) end end + + sub_test_case("Dictionary") do + def build_array + values = ["a", "b", "c", nil, "a"] + string_array = Arrow::StringArray.new(values) + string_array.dictionary_encode + end + + def test_read + assert_equal([{"value" => ["a", "b", "c", nil, "a"]}], + read) + end + end end end end