Skip to content

Commit 2117a33

Browse files
committed
Consolidate special case regexp_match logic
1 parent 114eec6 commit 2117a33

1 file changed

Lines changed: 100 additions & 82 deletions

File tree

datafusion/functions/src/regex/regexpreplace.rs

Lines changed: 100 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ use datafusion_expr::{
4242
Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
4343
};
4444
use datafusion_macros::user_doc;
45-
use regex::Regex;
45+
use regex::{CaptureLocations, Regex};
46+
use std::borrow::Cow;
4647
use std::collections::HashMap;
4748
use std::sync::{Arc, LazyLock};
4849

@@ -201,6 +202,80 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
201202
.into_owned()
202203
}
203204

205+
struct ShortRegex {
206+
/// Shortened anchored regex used to extract capture group 1 directly.
207+
/// See [`try_build_short_extract_regex`] for details.
208+
short_re: Regex,
209+
/// Reusable capture locations for `short_re` to avoid per-row allocation.
210+
locs: CaptureLocations,
211+
}
212+
213+
/// Holds the normal compiled regex together with the optional fast path used
214+
/// for `regexp_replace(str, '^...(capture)...*$', '\1')`.
215+
struct OptimizedRegex {
216+
/// Full regex used for the normal replacement path and as a correctness fallback.
217+
re: Regex,
218+
/// Precomputed state for the direct-extraction fast path, when applicable.
219+
short_re: Option<ShortRegex>,
220+
}
221+
222+
impl OptimizedRegex {
223+
/// Builds any reusable state needed by the extraction fast path.
224+
///
225+
/// The fast path is only enabled for single replacements where the pattern
226+
/// and replacement satisfy [`try_build_short_extract_regex`].
227+
fn new(re: Regex, limit: usize, pattern: &str, replacement: &str) -> Self {
228+
let short_re = if limit == 1 {
229+
try_build_short_extract_regex(pattern, replacement)
230+
} else {
231+
None
232+
};
233+
234+
let short_re = short_re.map(|short_re| {
235+
let locs = short_re.capture_locations();
236+
ShortRegex { short_re, locs }
237+
});
238+
239+
Self { re, short_re }
240+
}
241+
242+
/// Applies the direct-extraction fast path when it preserves the result of
243+
/// `Regex::replacen`; otherwise falls back to the full regex replacement.
244+
fn replacen<'a>(
245+
&mut self,
246+
val: &'a str,
247+
limit: usize,
248+
replacement: &str,
249+
) -> Cow<'a, str> {
250+
// If this pattern is not eligible for direct extraction, use the full regex.
251+
let Some(ShortRegex { short_re, locs }) = self.short_re.as_mut() else {
252+
return self.re.replacen(val, limit, replacement);
253+
};
254+
255+
// If the shortened regex does not match, the original anchored regex would
256+
// also leave the input unchanged.
257+
if short_re.captures_read(locs, val).is_none() {
258+
return Cow::Borrowed(val);
259+
};
260+
261+
// `captures_read` succeeded, so the overall shortened match is present.
262+
let match_end = locs.get(0).unwrap().1;
263+
if memchr(b'\n', &val.as_bytes()[match_end..]).is_some() {
264+
// If there is a newline after the match, we can't use the short
265+
// regex since it won't match across lines. Fall back to the full
266+
// regex replacement.
267+
return self.re.replacen(val, limit, replacement);
268+
};
269+
// The fast path only applies to `${1}` replacements, so the result is
270+
// either capture group 1 or the empty string if that group did not match.
271+
if let Some((start, end)) = locs.get(1) {
272+
Cow::Borrowed(&val[start..end])
273+
} else {
274+
Cow::Borrowed("")
275+
}
276+
}
277+
}
278+
204279
/// For anchored patterns like `^...(capture)....*$` where the replacement
205280
/// is `\1`, build a shorter regex (stripping trailing `.*$`) and use
206281
/// `captures_read` with `CaptureLocations` for direct extraction — no
@@ -440,7 +515,7 @@ macro_rules! fetch_string_arg {
440515
/// hold a single Regex object for the replace operation. This also speeds
441516
/// up the pre-processing time of the replacement string, since it only
442517
/// needs to be processed once.
443-
fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
518+
fn regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
444519
args: &[ArrayRef],
445520
) -> Result<ArrayRef> {
446521
let array_size = args[0].len();
@@ -475,13 +550,7 @@ fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
475550
// with rust ones.
476551
let replacement = regex_replace_posix_groups(replacement);
477552

478-
// For anchored patterns like ^...(capture)....*$, build a shorter
479-
// regex and use captures_read for direct extraction.
480-
let short_re = if limit == 1 {
481-
try_build_short_extract_regex(&pattern, &replacement)
482-
} else {
483-
None
484-
};
553+
let mut opt_re = OptimizedRegex::new(re, limit, &pattern, &replacement);
485554

486555
let string_array_type = args[0].data_type();
487556
match string_array_type {
@@ -499,37 +568,13 @@ fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
499568
let mut new_offsets = BufferBuilder::<T>::new(string_array.len() + 1);
500569
new_offsets.append(T::zero());
501570

502-
if let Some(ref short_re) = short_re {
503-
let mut locs = short_re.capture_locations();
504-
string_array.iter().for_each(|val| {
505-
if let Some(val) = val {
506-
if short_re.captures_read(&mut locs, val).is_some() {
507-
let match_end = locs.get(0).unwrap().1;
508-
if memchr(b'\n', &val.as_bytes()[match_end..]).is_none() {
509-
if let Some((start, end)) = locs.get(1) {
510-
vals.append_slice(&val.as_bytes()[start..end]);
511-
}
512-
} else {
513-
// Newline in remainder: .*$ wouldn't match without 's' flag
514-
let result =
515-
re.replacen(val, limit, replacement.as_str());
516-
vals.append_slice(result.as_bytes());
517-
}
518-
} else {
519-
vals.append_slice(val.as_bytes());
520-
}
521-
}
522-
new_offsets.append(T::from_usize(vals.len()).unwrap());
523-
});
524-
} else {
525-
string_array.iter().for_each(|val| {
526-
if let Some(val) = val {
527-
let result = re.replacen(val, limit, replacement.as_str());
528-
vals.append_slice(result.as_bytes());
529-
}
530-
new_offsets.append(T::from_usize(vals.len()).unwrap());
531-
});
532-
}
571+
string_array.iter().for_each(|val| {
572+
if let Some(val) = val {
573+
let result = opt_re.replacen(val, limit, replacement.as_str());
574+
vals.append_slice(result.as_bytes());
575+
}
576+
new_offsets.append(T::from_usize(vals.len()).unwrap());
577+
});
533578

534579
let data = ArrayDataBuilder::new(GenericStringArray::<T>::DATA_TYPE)
535580
.len(string_array.len())
@@ -544,39 +589,12 @@ fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
544589

545590
let mut builder = StringViewBuilder::with_capacity(string_view_array.len());
546591

547-
if let Some(ref short_re) = short_re {
548-
let mut locs = short_re.capture_locations();
549-
for val in string_view_array.iter() {
550-
if let Some(val) = val {
551-
if short_re.captures_read(&mut locs, val).is_some() {
552-
let match_end = locs.get(0).unwrap().1;
553-
if memchr(b'\n', &val.as_bytes()[match_end..]).is_none() {
554-
if let Some((start, end)) = locs.get(1) {
555-
builder.append_value(&val[start..end]);
556-
} else {
557-
builder.append_value("");
558-
}
559-
} else {
560-
// Newline in remainder: .*$ wouldn't match without 's' flag
561-
let result =
562-
re.replacen(val, limit, replacement.as_str());
563-
builder.append_value(result);
564-
}
565-
} else {
566-
builder.append_value(val);
567-
}
568-
} else {
569-
builder.append_null();
570-
}
571-
}
572-
} else {
573-
for val in string_view_array.iter() {
574-
if let Some(val) = val {
575-
let result = re.replacen(val, limit, replacement.as_str());
576-
builder.append_value(result);
577-
} else {
578-
builder.append_null();
579-
}
592+
for val in string_view_array.iter() {
593+
if let Some(val) = val {
594+
let result = opt_re.replacen(val, limit, replacement.as_str());
595+
builder.append_value(result.as_ref());
596+
} else {
597+
builder.append_null();
580598
}
581599
}
582600

@@ -653,7 +671,7 @@ fn specialize_regexp_replace<T: OffsetSizeTrait>(
653671
arg.to_array(expansion_len)
654672
})
655673
.collect::<Result<Vec<_>>>()?;
656-
_regexp_replace_static_pattern_replace::<T>(&args)
674+
regexp_replace_static_pattern_replace::<T>(&args)
657675
}
658676

659677
// If there are no specialized implementations, we'll fall back to the
@@ -787,7 +805,7 @@ mod tests {
787805
let replacements = <$T>::from(replacement);
788806
let expected = <$T>::from(expected);
789807

790-
let re = _regexp_replace_static_pattern_replace::<$O>(&[
808+
let re = regexp_replace_static_pattern_replace::<$O>(&[
791809
Arc::new(values),
792810
Arc::new(patterns),
793811
Arc::new(replacements),
@@ -832,7 +850,7 @@ mod tests {
832850
let flags = StringArray::from(vec!["i"; 5]);
833851
let expected = <$T>::from(expected);
834852

835-
let re = _regexp_replace_static_pattern_replace::<$O>(&[
853+
let re = regexp_replace_static_pattern_replace::<$O>(&[
836854
Arc::new(values),
837855
Arc::new(patterns),
838856
Arc::new(replacements),
@@ -864,7 +882,7 @@ mod tests {
864882
let replacements = StringArray::from(vec!["foo"; 5]);
865883
let expected = StringArray::from(vec![None::<&str>; 5]);
866884

867-
let re = _regexp_replace_static_pattern_replace::<i32>(&[
885+
let re = regexp_replace_static_pattern_replace::<i32>(&[
868886
Arc::new(values),
869887
Arc::new(patterns),
870888
Arc::new(replacements),
@@ -881,7 +899,7 @@ mod tests {
881899
let replacements = StringArray::from(Vec::<Option<&str>>::new());
882900
let expected = StringArray::from(Vec::<Option<&str>>::new());
883901

884-
let re = _regexp_replace_static_pattern_replace::<i32>(&[
902+
let re = regexp_replace_static_pattern_replace::<i32>(&[
885903
Arc::new(values),
886904
Arc::new(patterns),
887905
Arc::new(replacements),
@@ -899,7 +917,7 @@ mod tests {
899917
let flags = StringArray::from(vec![None::<&str>; 5]);
900918
let expected = StringArray::from(vec![None::<&str>; 5]);
901919

902-
let re = _regexp_replace_static_pattern_replace::<i32>(&[
920+
let re = regexp_replace_static_pattern_replace::<i32>(&[
903921
Arc::new(values),
904922
Arc::new(patterns),
905923
Arc::new(replacements),
@@ -918,7 +936,7 @@ mod tests {
918936
let patterns = StringArray::from(vec!["["; 5]);
919937
let replacements = StringArray::from(vec!["foo"; 5]);
920938

921-
let re = _regexp_replace_static_pattern_replace::<i32>(&[
939+
let re = regexp_replace_static_pattern_replace::<i32>(&[
922940
Arc::new(values),
923941
Arc::new(patterns),
924942
Arc::new(replacements),
@@ -955,7 +973,7 @@ mod tests {
955973
Some("c"),
956974
]);
957975

958-
let re = _regexp_replace_static_pattern_replace::<i32>(&[
976+
let re = regexp_replace_static_pattern_replace::<i32>(&[
959977
Arc::new(values),
960978
Arc::new(patterns),
961979
Arc::new(replacements),
@@ -983,7 +1001,7 @@ mod tests {
9831001
let replacements = StringArray::from(vec!["foo"; 1]);
9841002
let expected = StringArray::from(vec![Some("b"), None, Some("foo"), None, None]);
9851003

986-
let re = _regexp_replace_static_pattern_replace::<i32>(&[
1004+
let re = regexp_replace_static_pattern_replace::<i32>(&[
9871005
Arc::new(values),
9881006
Arc::new(patterns),
9891007
Arc::new(replacements),

0 commit comments

Comments
 (0)