Skip to content

Commit fdc2ae2

Browse files
committed
Handle more than one overload of MemoryExtensions.Contains
1 parent ad1eb4d commit fdc2ae2

3 files changed

Lines changed: 82 additions & 1 deletion

File tree

Orm/Xtensive.Orm/Linq/ExpressionExtensions.cs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
using System;
88
using System.Collections.Concurrent;
9+
using System.Collections.Generic;
910
using System.Linq;
1011
using System.Linq.Expressions;
1112
using System.Reflection;
@@ -27,6 +28,12 @@ public static class ExpressionExtensions
2728

2829
private static readonly Func<Type, MethodInfo> TupleValueAccessorFactory;
2930

31+
private static readonly Type MemoryExtensionsType = typeof(MemoryExtensions);
32+
private static readonly MethodInfo ReadOnlySpanContains2;
33+
private static readonly MethodInfo ReadOnlySpanContains3;
34+
private static readonly MethodInfo SpanContains;
35+
private static readonly MethodInfo EnumerableContains;
36+
3037
///<summary>
3138
/// Makes <see cref="Tuples.Tuple.GetValueOrDefault{T}"/> method call.
3239
///</summary>
@@ -72,6 +79,47 @@ public static Expression LiftToNullable(this Expression expression) =>
7279
/// <returns>Expression tree that wraps <paramref name="expression"/>.</returns>
7380
public static ExpressionTree ToExpressionTree(this Expression expression) => new ExpressionTree(expression);
7481

82+
/// <summary>
83+
/// Transforms <see cref="MemoryExtensions.Contains{T}(ReadOnlySpan{T}, T)"/> applied call into <see cref="Enumerable.Contains{TSource}(IEnumerable{TSource}, TSource)"/>
84+
/// if detected.
85+
/// </summary>
86+
/// <param name="mc">Possible candidate for transformation.</param>
87+
/// <returns>New instance of expression, if transformation was required, otherwise, the same expression.</returns>
88+
public static MethodCallExpression TryTransformToOldFashionContains(this MethodCallExpression mc)
89+
{
90+
if (mc.Method.DeclaringType == MemoryExtensionsType) {
91+
var genericMethod = mc.Method.GetGenericMethodDefinition();
92+
if (genericMethod == ReadOnlySpanContains2 || genericMethod == ReadOnlySpanContains3 || genericMethod == SpanContains) {
93+
var arguments = mc.Arguments;
94+
95+
Type elementType;
96+
Expression[] newArguments;
97+
98+
if (arguments[0] is MethodCallExpression mcInner && mcInner.Method.Name.Equals(WellKnown.Operator.Implicit, StringComparison.Ordinal)) {
99+
var wrappedArray = mcInner.Arguments[0];
100+
elementType = wrappedArray.Type.GetElementType();
101+
newArguments = new[] { wrappedArray, arguments[1] };
102+
}
103+
else if (arguments[0] is UnaryExpression uInner
104+
&& uInner.Method is not null
105+
&& uInner.Method.Name.Equals(WellKnown.Operator.Implicit, StringComparison.Ordinal)) {
106+
107+
elementType = uInner.Operand.Type.GetElementType();
108+
newArguments = new[] { uInner.Operand, arguments[1] };
109+
}
110+
else {
111+
return mc;
112+
}
113+
114+
var genericContains = EnumerableContains.MakeGenericMethod(elementType);
115+
var replacement = Expression.Call(genericContains, newArguments);
116+
return replacement;
117+
}
118+
return mc;
119+
}
120+
return mc;
121+
}
122+
75123

76124
// Type initializer
77125

@@ -80,6 +128,32 @@ static ExpressionExtensions()
80128
var tupleGenericAccessor = WellKnownOrmTypes.Tuple.GetMethods()
81129
.Single(mi => mi.Name == nameof(Tuple.GetValueOrDefault) && mi.IsGenericMethod);
82130
TupleValueAccessorFactory = type => tupleGenericAccessor.CachedMakeGenericMethod(type);
131+
132+
var genericReadOnlySpan = typeof(ReadOnlySpan<>);
133+
var genericSpan = typeof(Span<>);
134+
135+
var filteredByNameItems = MemoryExtensionsType.GetMethods(BindingFlags.Public | BindingFlags.Static)
136+
.Where(m => m.Name.Equals(nameof(System.MemoryExtensions.Contains), StringComparison.OrdinalIgnoreCase));
137+
138+
var spanCandidates = new List<(MethodInfo, int)>();
139+
var readonlyspanCandidates = new List<(MethodInfo, int)>();
140+
141+
foreach (var method in filteredByNameItems) {
142+
var parameters = method.GetParameters();
143+
var firstParameter = parameters[0];
144+
var genericDef = firstParameter.ParameterType.GetGenericTypeDefinition();
145+
if (genericDef == genericReadOnlySpan) {
146+
readonlyspanCandidates.Add((method, parameters.Length));
147+
}
148+
else if (genericDef == genericSpan) {
149+
spanCandidates.Add((method, parameters.Length));
150+
}
151+
}
152+
153+
ReadOnlySpanContains2 = readonlyspanCandidates.Where(c => c.Item2 == 2).Select(c => c.Item1).First();
154+
ReadOnlySpanContains3 = readonlyspanCandidates.Where(c => c.Item2 == 3).Select(c => c.Item1).FirstOrDefault();
155+
SpanContains = spanCandidates.Where(c => c.Item2 == 2).Select(c => c.Item1).First();
156+
EnumerableContains = typeof(System.Linq.Enumerable).GetMethodEx(nameof(System.Linq.Enumerable.Contains), BindingFlags.Public | BindingFlags.Static, new string[1], new object[2]);
83157
}
84158
}
85159
}

Orm/Xtensive.Orm/Orm/Linq/Translator.Expressions.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,10 @@ protected override Expression VisitMethodCall(MethodCallExpression mc)
534534
if (methodDeclaringType == typeof(System.MemoryExtensions)) {
535535
var parameters = method.GetParameters();
536536

537-
if (methodName.Equals(nameof(System.MemoryExtensions.Contains), StringComparison.Ordinal) && parameters.Length == 2){
537+
if (methodName.Equals(nameof(System.MemoryExtensions.Contains), StringComparison.Ordinal)){
538+
// There might be 2 or 3 arguments.
539+
// In case of three, last one is IEqualityComparer<T> which will probably have default value
540+
// Comparer doesn't matter in context of our queries, so we ignore it
538541
return VisitContains(mc.Arguments[0].StripImplicitCast(), mc.Arguments[1], false);
539542
}
540543
}

Orm/Xtensive.Orm/Orm/Providers/Expressions/ExpressionProcessor.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,10 @@ protected override SqlExpression VisitMethodCall(MethodCallExpression mc)
412412
if (mc.AsTupleAccess(activeParameters) != null)
413413
return VisitTupleAccess(mc);
414414

415+
if (mc.Method.Name.Equals(nameof(Enumerable.Contains), StringComparison.Ordinal)) {
416+
// there might be "innovative" implicit cast to ReadOnlySpan inside, which is not supported
417+
mc = mc.TryTransformToOldFashionContains();
418+
}
415419
var arguments = mc.Arguments.SelectToArray(a => Visit(a));
416420
var mi = mc.Method;
417421

0 commit comments

Comments
 (0)