Skip to content

Commit 2b369fb

Browse files
committed
Additonal test for "as" in Linq queries
1 parent a458baf commit 2b369fb

1 file changed

Lines changed: 323 additions & 0 deletions

File tree

Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
// Copyright (C) 2019 Xtensive LLC.
2+
// This code is distributed under MIT license terms.
3+
// See the License.txt file in the project root for more information.
4+
// Created by: Denis Kudelin
5+
// Created: 2018.01.16
6+
7+
using System;
8+
using System.Collections;
9+
using System.Collections.Generic;
10+
using System.Linq;
11+
using System.Linq.Expressions;
12+
using NUnit.Framework;
13+
using Xtensive.Core;
14+
using Xtensive.Orm.Configuration;
15+
using Xtensive.Orm.Tests.Linq.TypeAsTranslationTestModels;
16+
17+
namespace Xtensive.Orm.Tests.Linq
18+
{
19+
public class TypeAsTranslationTest : AutoBuildTest
20+
{
21+
#region Nested types
22+
23+
public class CustomExpressionReplacer : Xtensive.Linq.ExpressionVisitor
24+
{
25+
private readonly Func<Expression, Func<Expression, Expression>, Expression> visit;
26+
27+
public static Expression Visit(Expression expression, Func<Expression, Func<Expression, Expression>, Expression> visit) =>
28+
new CustomExpressionReplacer(visit).Visit(expression);
29+
30+
protected override Expression Visit(Expression exp) => visit(exp, base.Visit);
31+
32+
private CustomExpressionReplacer(Func<Expression, Func<Expression, Expression>, Expression> visit)
33+
{
34+
this.visit = visit;
35+
}
36+
}
37+
38+
private sealed class ComparisonComparer<T> : Comparer<T>
39+
{
40+
private readonly Comparison<T> comparison;
41+
42+
public static new Comparer<T> Create(Comparison<T> comparison) =>
43+
comparison == null ? throw new ArgumentNullException("comparison") : new ComparisonComparer<T>(comparison);
44+
45+
public override int Compare(T x, T y) => comparison(x, y);
46+
47+
private ComparisonComparer(Comparison<T> comparison)
48+
{
49+
this.comparison = comparison;
50+
}
51+
}
52+
#endregion
53+
54+
[Test]
55+
public void Test1()
56+
{
57+
using(var session = Domain.OpenSession())
58+
using (var tx = session.OpenTransaction()) {
59+
QueryExpressionTest(
60+
() => session.Query.All<TestEntity1>().SelectMany(
61+
x => x.EntitySet.SelectMany(
62+
y => y.EntitySet.Select(
63+
z => (x.Value1 as TestEntity3).EntitySet.Any()))));
64+
}
65+
}
66+
67+
[Test]
68+
public void Test2()
69+
{
70+
using (var session = Domain.OpenSession())
71+
using (var tx = session.OpenTransaction()) {
72+
QueryExpressionTest(
73+
() => session.Query.All<TestEntity1>().SelectMany(
74+
x => x.EntitySet.SelectMany(
75+
y => y.EntitySet.Select(
76+
z => (y.Value1 as TestEntity3).EntitySet.Any()))));
77+
}
78+
}
79+
80+
[Test]
81+
public void Test3()
82+
{
83+
using (var session = Domain.OpenSession())
84+
using (var tx = session.OpenTransaction()) {
85+
QueryExpressionTest(
86+
() =>
87+
session.Query.All<TestEntity1>().SelectMany(
88+
x => x.EntitySet.SelectMany(y => y.EntitySet.Select(z => (z.Value1 as TestEntity3).EntitySet.Any()))));
89+
}
90+
}
91+
92+
[Test]
93+
public void Test4()
94+
{
95+
using (var session = Domain.OpenSession())
96+
using (var tx = session.OpenTransaction()) {
97+
QueryExpressionTest(
98+
() =>
99+
session.Query.All<TestEntity2>()
100+
.Where(x => ((x.Value1 as TestEntity3).Value1 as TestEntity3).EntitySet.Any()).Select(x => x.Id2));
101+
}
102+
}
103+
104+
[Test]
105+
public void Test5()
106+
{
107+
using (var session = Domain.OpenSession())
108+
using (var tx = session.OpenTransaction()) {
109+
QueryExpressionTest(
110+
() => session.Query.All<TestEntity2>().Where(
111+
x => x.EntitySet.Any(y => ((x.Value1 as TestEntity3).Value1 as TestEntity3).EntitySet.Any())));
112+
}
113+
}
114+
115+
private void QueryExpressionTest<TResult>(Expression<Func<TResult>> queryExpression)
116+
{
117+
var result1 = RewriteQueryExpressionAndInvoke(queryExpression, false);
118+
var result2 = RewriteQueryExpressionAndInvoke(queryExpression, true);
119+
120+
if (!(result1 is IEnumerable) || !(result2 is IEnumerable)) {
121+
Assert.That(result1, Is.EqualTo(result2));
122+
return;
123+
}
124+
125+
var result1Array = ((IEnumerable) result1).Cast<object>().ToArray();
126+
var result2Array = ((IEnumerable) result2).Cast<object>().ToArray();
127+
128+
Assert.That(result1Array.SequenceEqual(result2Array));
129+
}
130+
131+
private TResult RewriteQueryExpressionAndInvoke<TResult>(Expression<Func<TResult>> expression, bool asEnumerable)
132+
{
133+
var orderByMethod = (asEnumerable ? typeof(Enumerable) : typeof(Queryable)).GetMethods()
134+
.Single(x => x.Name == "OrderBy" && x.GetParameters().Length == (asEnumerable ? 3 : 2));
135+
var toArrayMethod = typeof(Enumerable).GetMethod("ToArray");
136+
var asQueryableMethod = typeof(Queryable).GetMethods().Single(x => x.Name == "AsQueryable" && x.IsGenericMethod);
137+
var keyPropertyInfo = typeof(IEntity).GetProperty("Key");
138+
139+
expression = ((Expression<Func<TResult>>) CustomExpressionReplacer.Visit(
140+
expression,
141+
(e, visit) => {
142+
var result = (e = visit(e));
143+
144+
if (result != null && typeof(IQueryable<IEntity>).IsAssignableFrom(result.Type)) {
145+
var isOrderedQueryable = typeof(IOrderedQueryable).IsAssignableFrom(result.Type);
146+
var entityType = result.Type.GetGenericArguments().Single();
147+
148+
if (asEnumerable)
149+
result = Expression.Call(toArrayMethod.MakeGenericMethod(entityType), result);
150+
151+
if (!isOrderedQueryable) {
152+
var keyParameter = Expression.Parameter(entityType);
153+
var keyProperty = Expression.Property(keyParameter, keyPropertyInfo);
154+
var orderByMethodGeneric = orderByMethod.MakeGenericMethod(entityType, keyPropertyInfo.PropertyType);
155+
var parameters = orderByMethodGeneric.GetParameters();
156+
var keySelectorType = asEnumerable
157+
? parameters[1].ParameterType
158+
: parameters[1].ParameterType.GetGenericArguments().Single();
159+
var keySelector = (Expression) Expression.Lambda(keySelectorType, keyProperty, keyParameter);
160+
161+
if (asEnumerable) {
162+
keySelector = Expression.Constant(((LambdaExpression) keySelector).Compile());
163+
var comparer = Expression.Constant(
164+
ComparisonComparer<Key>.Create(
165+
(k1, k2) => Comparer.Default.Compare(k1.Value.GetValue(0), k2.Value.GetValue(0))));
166+
result = Expression.Call(orderByMethodGeneric, result, keySelector, comparer);
167+
result = Expression.Call(toArrayMethod.MakeGenericMethod(entityType), result);
168+
}
169+
else {
170+
result = Expression.Call(orderByMethodGeneric, result, keySelector);
171+
}
172+
}
173+
174+
if (asEnumerable) {
175+
result = Expression.Call(asQueryableMethod.MakeGenericMethod(entityType), result);
176+
}
177+
}
178+
179+
if (asEnumerable && (e is MemberExpression || e is MethodCallExpression)) {
180+
Expression obj;
181+
182+
var methodCall = e as MethodCallExpression;
183+
if (methodCall != null) {
184+
obj = methodCall.Object ?? methodCall.Arguments.FirstOrDefault();
185+
}
186+
else {
187+
obj = ((MemberExpression) e).Expression;
188+
}
189+
190+
if (obj != null && (typeof(IQueryable<IEntity>).IsAssignableFrom(obj.Type)
191+
|| typeof(IEntity).IsAssignableFrom(obj.Type)
192+
|| typeof(Structure).IsAssignableFrom(obj.Type))) {
193+
result = Expression.Condition(
194+
Expression.Equal(obj, Expression.Constant(null, obj.Type)),
195+
Expression.Default(result.Type),
196+
result);
197+
return result;
198+
}
199+
}
200+
201+
return result;
202+
}));
203+
204+
return expression.Compile()();
205+
}
206+
207+
protected override void PopulateData()
208+
{
209+
using (var session = Domain.OpenSession())
210+
using (var tx = session.OpenTransaction()) {
211+
var entity1a = new TestEntity1(1);
212+
213+
var entity2a = new TestEntity2(2);
214+
var entity2b = new TestEntity2(3);
215+
var entity2c = new TestEntity2(4);
216+
var entity2d = new TestEntity2(5);
217+
218+
var entity3a = new TestEntity3(6);
219+
var entity3b = new TestEntity3(7);
220+
var entity3c = new TestEntity3(8);
221+
222+
_ = entity1a.EntitySet.Add(entity2a);
223+
_ = entity1a.EntitySet.Add(entity2b);
224+
_ = entity1a.EntitySet.Add(entity2c);
225+
_ = entity1a.EntitySet.Add(entity2d);
226+
entity1a.Value1 = entity3a;
227+
228+
_ = entity2a.EntitySet.Add(entity3a);
229+
entity2a.Value1 = entity3a;
230+
_ = entity2b.EntitySet.Add(entity3b);
231+
entity2b.Value1 = entity3b;
232+
_ = entity2c.EntitySet.Add(entity3c);
233+
entity2d.Value1 = entity3c;
234+
235+
_ = entity3a.EntitySet.Add(entity1a);
236+
entity3a.Value1 = entity3c;
237+
_ = entity3b.EntitySet.Add(entity2a);
238+
entity3b.Value1 = entity3a;
239+
entity3c.Value1 = entity3b;
240+
241+
tx.Complete();
242+
}
243+
}
244+
245+
protected override DomainConfiguration BuildConfiguration()
246+
{
247+
var config = base.BuildConfiguration();
248+
config.Types.Register(typeof(ITestEntity).Assembly, typeof(ITestEntity).Namespace);
249+
return config;
250+
}
251+
}
252+
}
253+
254+
namespace Xtensive.Orm.Tests.Linq.TypeAsTranslationTestModels
255+
{
256+
public interface ITestEntity : IEntity
257+
{
258+
259+
}
260+
261+
[HierarchyRoot]
262+
public class TestEntity1 : Entity, ITestEntity
263+
{
264+
[Field]
265+
public EntitySet<TestEntity2> EntitySet { get; set; }
266+
267+
[Field]
268+
public ITestEntity Value1 { get; set; }
269+
270+
[Field, Key]
271+
public int Id { get; set; }
272+
273+
[Field]
274+
public int Id2 { get; set; }
275+
276+
public TestEntity1(int id2)
277+
{
278+
Id2 = id2;
279+
}
280+
}
281+
282+
[HierarchyRoot]
283+
public class TestEntity2 : Entity, ITestEntity
284+
{
285+
[Field]
286+
public EntitySet<TestEntity3> EntitySet { get; set; }
287+
288+
[Field, Key]
289+
public int Id { get; set; }
290+
291+
[Field]
292+
public ITestEntity Value1 { get; set; }
293+
294+
[Field]
295+
public int Id2 { get; set; }
296+
297+
public TestEntity2(int id2)
298+
{
299+
Id2 = id2;
300+
}
301+
}
302+
303+
[HierarchyRoot]
304+
public class TestEntity3 : Entity, ITestEntity
305+
{
306+
[Field, Key]
307+
public int Id { get; set; }
308+
309+
[Field]
310+
public EntitySet<ITestEntity> EntitySet { get; set; }
311+
312+
[Field]
313+
public ITestEntity Value1 { get; set; }
314+
315+
[Field]
316+
public int Id2 { get; set; }
317+
318+
public TestEntity3(int id2)
319+
{
320+
Id2 = id2;
321+
}
322+
}
323+
}

0 commit comments

Comments
 (0)