Skip to content

Commit 24d0514

Browse files
authored
Add Decimal128 support. (returnString#297)
1 parent e8df7f7 commit 24d0514

7 files changed

Lines changed: 40 additions & 12 deletions

File tree

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
/target
22
Cargo.lock
33
.idea/
4-

convergence-arrow/Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,5 @@ async-trait = "0.1"
1313
datafusion = "47"
1414
convergence = { path = "../convergence", version = "0.16.0" }
1515
chrono = "0.4"
16-
17-
[dev-dependencies]
1816
tokio-postgres = { version = "0.7", features = [ "with-chrono-0_4" ] }
17+
rust_decimal = { version = "1.37.1", features = ["default", "db-postgres"] }

convergence-arrow/src/table.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlState};
44
use convergence::protocol_ext::DataRowBatch;
55
use datafusion::arrow::array::{
6-
BooleanArray, Date32Array, Date64Array, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array,
7-
Int64Array, Int8Array, StringArray, StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
8-
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
6+
BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array,
7+
Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
8+
StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
9+
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array
910
};
1011
use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit};
1112
use datafusion::arrow::record_batch::RecordBatch;
@@ -47,6 +48,7 @@ pub fn record_batch_to_rows(arrow_batch: &RecordBatch, pg_batch: &mut DataRowBat
4748
DataType::Float16 => row.write_float4(array_val!(Float16Array, col, row_idx).to_f32()),
4849
DataType::Float32 => row.write_float4(array_val!(Float32Array, col, row_idx)),
4950
DataType::Float64 => row.write_float8(array_val!(Float64Array, col, row_idx)),
51+
DataType::Decimal128(p, s) => row.write_numeric_16(array_val!(Decimal128Array, col, row_idx), p, s),
5052
DataType::Utf8 => row.write_string(array_val!(StringArray, col, row_idx)),
5153
DataType::Utf8View => row.write_string(array_val!(StringViewArray, col, row_idx)),
5254
DataType::Date32 => {
@@ -103,6 +105,7 @@ pub fn data_type_to_oid(ty: &DataType) -> Result<DataTypeOid, ErrorResponse> {
103105
DataType::UInt64 => DataTypeOid::Int8,
104106
DataType::Float16 | DataType::Float32 => DataTypeOid::Float4,
105107
DataType::Float64 => DataTypeOid::Float8,
108+
DataType::Decimal128(_, _) => DataTypeOid::Numeric,
106109
DataType::Utf8 | DataType::Utf8View => DataTypeOid::Text,
107110
DataType::Date32 | DataType::Date64 => DataTypeOid::Date,
108111
DataType::Timestamp(_, None) => DataTypeOid::Timestamp,

convergence-arrow/tests/test_arrow.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ use convergence::protocol_ext::DataRowBatch;
66
use convergence::server::{self, BindOptions};
77
use convergence::sqlparser::ast::Statement;
88
use convergence_arrow::table::{record_batch_to_rows, schema_to_field_desc};
9-
use datafusion::arrow::array::{ArrayRef, Date32Array, Float32Array, Int32Array, StringArray, StringViewArray, TimestampSecondArray};
9+
use datafusion::arrow::array::{ArrayRef, Date32Array, Decimal128Array, Float32Array, Int32Array, StringArray, StringViewArray, TimestampSecondArray};
1010
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
1111
use datafusion::arrow::record_batch::RecordBatch;
1212
use std::sync::Arc;
13+
use rust_decimal::Decimal;
1314
use tokio_postgres::{connect, NoTls};
1415

1516
struct ArrowPortal {
@@ -31,6 +32,7 @@ impl ArrowEngine {
3132
fn new() -> Self {
3233
let int_col = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
3334
let float_col = Arc::new(Float32Array::from(vec![1.5, 2.5, 3.5])) as ArrayRef;
35+
let decimal_col = Arc::new(Decimal128Array::from(vec![11, 22, 33]).with_precision_and_scale(2, 0).unwrap()) as ArrayRef;
3436
let string_col = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
3537
let string_view_col = Arc::new(StringViewArray::from(vec!["aa", "bb", "cc"])) as ArrayRef;
3638
let ts_col = Arc::new(TimestampSecondArray::from(vec![1577836800, 1580515200, 1583020800])) as ArrayRef;
@@ -39,14 +41,15 @@ impl ArrowEngine {
3941
let schema = Schema::new(vec![
4042
Field::new("int_col", DataType::Int32, true),
4143
Field::new("float_col", DataType::Float32, true),
44+
Field::new("decimal_col", DataType::Decimal128(2, 0), true),
4245
Field::new("string_col", DataType::Utf8, true),
4346
Field::new("string_view_col", DataType::Utf8View, true),
4447
Field::new("ts_col", DataType::Timestamp(TimeUnit::Second, None), true),
4548
Field::new("date_col", DataType::Date32, true),
4649
]);
4750

4851
Self {
49-
batch: RecordBatch::try_new(Arc::new(schema), vec![int_col, float_col, string_col, string_view_col, ts_col, date_col])
52+
batch: RecordBatch::try_new(Arc::new(schema), vec![int_col, float_col, decimal_col, string_col, string_view_col, ts_col, date_col])
5053
.expect("failed to create batch"),
5154
}
5255
}
@@ -91,8 +94,8 @@ async fn basic_data_types() {
9194
let rows = client.query("select 1", &[]).await.unwrap();
9295
let get_row = |idx: usize| {
9396
let row = &rows[idx];
94-
let cols: (i32, f32, &str, &str, NaiveDateTime, NaiveDate) =
95-
(row.get(0), row.get(1), row.get(2), row.get(3), row.get(4), row.get(5));
97+
let cols: (i32, f32, Decimal, &str, &str, NaiveDateTime, NaiveDate) =
98+
(row.get(0), row.get(1), row.get(2), row.get(3), row.get(4), row.get(5), row.get(6));
9699
cols
97100
};
98101

@@ -101,6 +104,7 @@ async fn basic_data_types() {
101104
(
102105
1,
103106
1.5,
107+
Decimal::from(11),
104108
"a",
105109
"aa",
106110
NaiveDate::from_ymd_opt(2020, 1, 1)
@@ -115,6 +119,7 @@ async fn basic_data_types() {
115119
(
116120
2,
117121
2.5,
122+
Decimal::from(22),
118123
"b",
119124
"bb",
120125
NaiveDate::from_ymd_opt(2020, 2, 1)
@@ -129,6 +134,7 @@ async fn basic_data_types() {
129134
(
130135
3,
131136
3.5,
137+
Decimal::from(33),
132138
"c",
133139
"cc",
134140
NaiveDate::from_ymd_opt(2020, 3, 1)

convergence/Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,5 @@ futures = "0.3"
1616
sqlparser = "0.46"
1717
async-trait = "0.1"
1818
chrono = "0.4"
19-
20-
[dev-dependencies]
19+
rust_decimal = { version = "1.37.1", features = ["default", "db-postgres"] }
2120
tokio-postgres = "0.7"

convergence/src/protocol.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ data_types! {
7575
Float4 = 700, 4
7676
Float8 = 701, 8
7777

78+
Numeric = 1700, -1
79+
7880
Date = 1082, 4
7981
Timestamp = 1114, 8
8082

convergence/src/protocol_ext.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
use crate::protocol::{ConnectionCodec, FormatCode, ProtocolError, RowDescription};
44
use bytes::{BufMut, BytesMut};
55
use chrono::{NaiveDate, NaiveDateTime};
6+
use rust_decimal::Decimal;
7+
use tokio_postgres::types::{ToSql, Type};
68
use tokio_util::codec::Encoder;
79

810
/// Supports batched rows for e.g. returning portal result sets.
@@ -131,6 +133,24 @@ impl<'a> DataRowWriter<'a> {
131133
}
132134
}
133135

136+
/// Writes a numeric value for the next column.
137+
pub fn write_numeric_16(&mut self, val: i128, _p: &u8, s: &i8) {
138+
let decimal = Decimal::from_i128_with_scale(val, *s as u32);
139+
match self.parent.format_code {
140+
FormatCode::Text => {
141+
self.write_string(&decimal.to_string())
142+
}
143+
FormatCode::Binary => {
144+
let numeric_type = Type::from_oid(1700).expect("failed to create numeric type");
145+
let mut buf = BytesMut::new();
146+
decimal.to_sql(&numeric_type, &mut buf)
147+
.expect("failed to write numeric");
148+
149+
self.write_value(&buf.freeze())
150+
}
151+
};
152+
}
153+
134154
primitive_write!(write_int2, i16);
135155
primitive_write!(write_int4, i32);
136156
primitive_write!(write_int8, i64);

0 commit comments

Comments
 (0)