2 changes: 1 addition & 1 deletion bigquery/Cargo.toml
@@ -25,7 +25,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version="1.32", features=["macros"] }
time = { version = "0.3", features = ["std", "macros", "formatting", "parsing", "serde"] }
arrow = { version = "56.1.0", default-features = false, features = ["ipc"] }
arrow = { version = "58.1.0", default-features = false, features = ["ipc"] }
base64 = "0.22"
bigdecimal = { version="0.4", features=["serde"] }
num-bigint = "0.4"
146 changes: 131 additions & 15 deletions bigquery/src/storage_write/mod.rs
@@ -1,30 +1,76 @@
use std::collections::HashMap;

use arrow::error::ArrowError;
use arrow::ipc::writer::{
write_message, CompressionContext, DictionaryTracker, EncodedData, IpcDataGenerator, IpcWriteOptions,
};
use arrow::record_batch::RecordBatch;
use google_cloud_gax::grpc::codegen::tokio_stream::Stream;
-use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::{ProtoData, Rows};
-use google_cloud_googleapis::cloud::bigquery::storage::v1::{AppendRowsRequest, ProtoRows, ProtoSchema};
use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::{ArrowData, ProtoData, Rows};
use google_cloud_googleapis::cloud::bigquery::storage::v1::{
AppendRowsRequest, ArrowRecordBatch, ArrowSchema, ProtoRows, ProtoSchema,
};
use prost_types::DescriptorProto;
-use std::collections::HashMap;

mod flow;
pub mod stream;

enum Payload {
Proto {
schema: DescriptorProto,
rows: Vec<Vec<u8>>,
},
Arrow {
serialized_schema: Vec<u8>,
serialized_record_batch: Vec<u8>,
},
}

pub struct AppendRowsRequestBuilder {
offset: Option<i64>,
trace_id: Option<String>,
missing_value_interpretations: Option<HashMap<String, i32>>,
default_missing_value_interpretation: Option<i32>,
-data: Vec<Vec<u8>>,
-schema: DescriptorProto,
payload: Payload,
}

impl AppendRowsRequestBuilder {
pub fn new(schema: DescriptorProto, data: Vec<Vec<u8>>) -> Self {
Self::with_payload(Payload::Proto { schema, rows: data })
}

pub fn new_arrow(serialized_schema: Vec<u8>, serialized_record_batch: Vec<u8>) -> Self {
Self::with_payload(Payload::Arrow {
serialized_schema,
serialized_record_batch,
})
}

pub fn from_record_batch(batch: &RecordBatch) -> Result<Self, ArrowError> {
let options = IpcWriteOptions::default();
let generator = IpcDataGenerator::default();
let mut dict_tracker = DictionaryTracker::new(true);
let mut compression = CompressionContext::default();

let schema_encoded =
generator.schema_to_bytes_with_dictionary_tracker(&batch.schema(), &mut dict_tracker, &options);
let serialized_schema = encoded_to_bytes(vec![schema_encoded], &options)?;

let (dict_encoded, batch_encoded) = generator.encode(batch, &mut dict_tracker, &options, &mut compression)?;
let mut encoded = dict_encoded;
encoded.push(batch_encoded);
let serialized_record_batch = encoded_to_bytes(encoded, &options)?;

Ok(Self::new_arrow(serialized_schema, serialized_record_batch))
}

fn with_payload(payload: Payload) -> Self {
Self {
offset: None,
trace_id: None,
missing_value_interpretations: None,
default_missing_value_interpretation: None,
-data,
-schema,
payload,
}
}

@@ -49,28 +95,98 @@ impl AppendRowsRequestBuilder {
}

pub(crate) fn build(self, stream: &str) -> AppendRowsRequest {
let rows = match self.payload {
Payload::Proto { schema, rows } => Rows::ProtoRows(ProtoData {
writer_schema: Some(ProtoSchema {
proto_descriptor: Some(schema),
}),
rows: Some(ProtoRows { serialized_rows: rows }),
}),
Payload::Arrow {
serialized_schema,
serialized_record_batch,
} => Rows::ArrowRows(ArrowData {
writer_schema: Some(ArrowSchema { serialized_schema }),
rows: Some(ArrowRecordBatch {
serialized_record_batch,
#[allow(deprecated)]
row_count: 0,
}),
}),
};
AppendRowsRequest {
write_stream: stream.to_string(),
offset: self.offset,
trace_id: self.trace_id.unwrap_or_default(),
missing_value_interpretations: self.missing_value_interpretations.unwrap_or_default(),
default_missing_value_interpretation: self.default_missing_value_interpretation.unwrap_or(0),
-rows: Some(Rows::ProtoRows(ProtoData {
-writer_schema: Some(ProtoSchema {
-proto_descriptor: Some(self.schema),
-}),
-rows: Some(ProtoRows {
-serialized_rows: self.data,
-}),
-})),
rows: Some(rows),
}
}
}

fn encoded_to_bytes(messages: Vec<EncodedData>, options: &IpcWriteOptions) -> Result<Vec<u8>, ArrowError> {
let mut buf = Vec::new();
for message in messages {
write_message(&mut buf, message, options)?;
}
Ok(buf)
}

pub fn into_streaming_request(rows: Vec<AppendRowsRequest>) -> impl Stream<Item = AppendRowsRequest> {
async_stream::stream! {
for row in rows {
yield row;
}
}
}

#[cfg(test)]
mod tests {
use std::io::{BufReader, Cursor};
use std::sync::Arc;

use arrow::array::{Int64Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::reader::StreamReader;
use arrow::record_batch::RecordBatch;
use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::Rows;

use super::AppendRowsRequestBuilder;

fn sample_batch() -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, false),
]));
let ids = Arc::new(Int64Array::from(vec![1, 2, 3]));
let names = Arc::new(StringArray::from(vec!["a", "b", "c"]));
RecordBatch::try_new(schema, vec![ids, names]).unwrap()
}

#[test]
fn from_record_batch_emits_arrow_rows_and_round_trips() {
let batch = sample_batch();
let expected_rows = batch.num_rows();

let builder = AppendRowsRequestBuilder::from_record_batch(&batch).unwrap();
let request = builder.build("projects/p/datasets/d/tables/t/streams/_default");

let Rows::ArrowRows(arrow_data) = request.rows.expect("rows set") else {
panic!("expected Arrow rows variant");
};
let schema_bytes = arrow_data.writer_schema.expect("writer_schema").serialized_schema;
let batch_bytes = arrow_data.rows.expect("rows").serialized_record_batch;
assert!(!schema_bytes.is_empty());
assert!(!batch_bytes.is_empty());

// Mirror storage.rs: concat schema + batch and decode with StreamReader.
let mut combined = schema_bytes;
combined.extend_from_slice(&batch_bytes);
let reader = StreamReader::try_new(BufReader::new(Cursor::new(combined)), None).unwrap();
let decoded: Vec<RecordBatch> = reader.collect::<Result<_, _>>().unwrap();
assert_eq!(decoded.len(), 1);
assert_eq!(decoded[0].num_rows(), expected_rows);
assert_eq!(decoded[0].num_columns(), 2);
}
}
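
Taken together, the new pieces above compose as follows — a minimal usage sketch, assuming the crate path google_cloud_bigquery and the default-stream API already exercised in the tests below (ClientConfig::new_with_auth, default_storage_writer, create_write_stream, append_rows); the dataset and table names are hypothetical.

// Sketch only: append one Arrow RecordBatch through the default write stream.
// Crate paths and the table name below are assumptions for illustration.
use std::sync::Arc;

use arrow::array::{Int64Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use futures_util::StreamExt;
use google_cloud_bigquery::client::{Client, ClientConfig};
use google_cloud_bigquery::storage_write::AppendRowsRequestBuilder;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build an in-memory batch whose schema matches the destination table.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, false),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int64Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["a", "b"])),
        ],
    )?;

    // Serialize schema + batch into an Arrow-flavored AppendRows request.
    let row = AppendRowsRequestBuilder::from_record_batch(&batch)?;

    // Send it through the default stream, mirroring the proto-based tests.
    let (config, project_id) = ClientConfig::new_with_auth().await?;
    let client = Client::new(config).await?;
    let writer = client.default_storage_writer();
    let table = format!(
        "projects/{}/datasets/my_dataset/tables/my_table", // hypothetical table
        project_id.unwrap()
    );
    let stream = writer.create_write_stream(&table).await?;
    let mut responses = stream.append_rows(vec![row]).await?;
    while let Some(res) = responses.next().await {
        println!("row errors: {}", res?.row_errors.len());
    }
    Ok(())
}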
44 changes: 43 additions & 1 deletion bigquery/src/storage_write/stream/default.rs
@@ -57,7 +57,7 @@ impl AsStream for DefaultStream {
#[cfg(test)]
mod tests {
use crate::client::{Client, ClientConfig};
-use crate::storage_write::stream::tests::{create_append_rows_request, TestData};
use crate::storage_write::stream::tests::{create_append_rows_request, create_arrow_append_rows_request, TestData};
use futures_util::StreamExt;
use google_cloud_gax::grpc::Status;
use prost::Message;
@@ -117,4 +117,46 @@ mod tests {
task.await.unwrap().unwrap();
}
}

#[serial_test::serial]
#[tokio::test]
async fn test_storage_write_arrow() {
let (config, project_id) = ClientConfig::new_with_auth().await.unwrap();
let project_id = project_id.unwrap();
let client = Client::new(config).await.unwrap();
let tables = ["write_test", "write_test_1"];
let writer = client.default_storage_writer();

let mut streams = vec![];
for i in 0..2 {
let table = format!(
"projects/{}/datasets/gcrbq_storage/tables/{}",
&project_id,
tables[i % tables.len()]
);
let stream = writer.create_write_stream(&table).await.unwrap();
streams.push(stream);
}

let mut tasks: Vec<JoinHandle<Result<(), Status>>> = vec![];
for (i, stream) in streams.into_iter().enumerate() {
tasks.push(tokio::spawn(async move {
let mut rows = vec![];
for j in 0..5 {
let values = (0..3).map(|k| format!("arrow_default_{i}_{j}_{k}")).collect();
rows.push(create_arrow_append_rows_request(values));
}
let mut result = stream.append_rows(rows).await.unwrap();
while let Some(res) = result.next().await {
let res = res?;
tracing::info!("append row errors = {:?}", res.row_errors.len());
}
Ok(())
}));
}

for task in tasks {
task.await.unwrap().unwrap();
}
}
}
15 changes: 14 additions & 1 deletion bigquery/src/storage_write/stream/mod.rs
@@ -108,9 +108,15 @@ impl DisposableStreamDelegate {

#[cfg(test)]
pub(crate) mod tests {
-use crate::storage_write::AppendRowsRequestBuilder;
use std::sync::Arc;

use arrow::array::StringArray;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use prost_types::{field_descriptor_proto, DescriptorProto, FieldDescriptorProto};

use crate::storage_write::AppendRowsRequestBuilder;

#[derive(Clone, PartialEq, ::prost::Message)]
pub(crate) struct TestData {
#[prost(string, tag = "1")]
@@ -150,4 +156,11 @@ pub(crate) mod tests {
};
AppendRowsRequestBuilder::new(proto, buf)
}

pub(crate) fn create_arrow_append_rows_request(values: Vec<String>) -> AppendRowsRequestBuilder {
let schema = Arc::new(Schema::new(vec![Field::new("col_string", DataType::Utf8, false)]));
let col = Arc::new(StringArray::from(values));
let batch = RecordBatch::try_new(schema, vec![col]).unwrap();
AppendRowsRequestBuilder::from_record_batch(&batch).unwrap()
}
}