Skip to content

Commit 0ab14fa

Browse files
authored
Merge pull request #82 from AArnott/fix80
Avoid using == operator on structs that do not define them
2 parents a9b40bc + 1be96b4 commit 0ab14fa

6 files changed

Lines changed: 172 additions & 16 deletions

File tree

src/ImmutableObjectGraph.Generation.Tests/ImmutableObjectGraph.Generation.Tests.csproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@
6666
<Generator>MSBuild:GenerateCodeFromAttributes</Generator>
6767
</Compile>
6868
<Compile Include="TestSources\ImmutableDictionaryHelpers.Tests.cs" />
69+
<Compile Include="TestSources\ImmutableWithComplexStructField.cs">
70+
<Generator>MSBuild:GenerateCodeFromAttributes</Generator>
71+
</Compile>
72+
<Compile Include="TestSources\ImmutableWithComplexStructField.Tests.cs" />
6973
<Compile Include="TestSources\MSBuild.cs">
7074
<Generator>MSBuild:GenerateCodeFromAttributes</Generator>
7175
</Compile>
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
namespace ImmutableObjectGraph.Generation.Tests.TestSources
2+
{
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Linq;
6+
using System.Text;
7+
using System.Threading.Tasks;
8+
using Xunit;
9+
10+
public class ImmutableWithComplexStructFieldTests
11+
{
12+
[Fact]
13+
public void StructWithoutOperatorsAlwaysRecreatesObjectWithChangedValue()
14+
{
15+
var originalStruct = new SomeStructWithMultipleFields(5, 0);
16+
var structWithModifiedSecondField = new SomeStructWithMultipleFields(5, 1);
17+
var v1 = ImmutableWithComplexStructField.Create(someStructField: originalStruct);
18+
var v2 = v1.With(someStructField: structWithModifiedSecondField);
19+
Assert.Equal(structWithModifiedSecondField.Field2, v2.SomeStructField.Field2);
20+
}
21+
22+
[Fact]
23+
public void StructWithoutOperatorsAlwaysRecreatesObjectWithSameValue()
24+
{
25+
var s1 = new SomeStructWithMultipleFields(1, 2);
26+
var v1 = ImmutableWithComplexStructField.Create(someStructField: s1);
27+
var v2 = v1.With(someStructField: s1);
28+
29+
// The object should have been recreated since equality between the two struct values
30+
// cannot be determined without their operator defined.
31+
Assert.NotSame(v1, v2);
32+
Assert.Equal(s1.Field1, v2.SomeStructField.Field1);
33+
}
34+
35+
[Fact(Skip = "Not yet passing (#81).")]
36+
public void StructWithoutOperatorsAlwaysRecreatesObjectWithoutValue()
37+
{
38+
var s1 = new SomeStructWithMultipleFields(1, 2);
39+
var v1 = ImmutableWithComplexStructField.Create(someStructField: s1);
40+
var v2 = v1.With(); // omit the struct value altogether
41+
Assert.Same(v1, v2);
42+
}
43+
44+
[Fact]
45+
public void StructWithOperatorsRecreatesObjectWithChangedValue()
46+
{
47+
var s12 = new SomeStructWithMultipleFieldsAndOperator(1, 2);
48+
var s13 = new SomeStructWithMultipleFieldsAndOperator(1, 3);
49+
var v1 = ImmutableWithComplexStructField.Create(someStructFieldWithOperator: s12);
50+
var v2 = v1.With(someStructFieldWithOperator: s13);
51+
Assert.NotSame(v1, v2);
52+
Assert.Equal(s13.Field2, v2.SomeStructFieldWithOperator.Field2);
53+
}
54+
55+
[Fact(Skip = "Not yet passing (#81).")]
56+
public void StructWithOperatorsRecyclesObjectWithSameValue()
57+
{
58+
var s12 = new SomeStructWithMultipleFieldsAndOperator(1, 2);
59+
var v1 = ImmutableWithComplexStructField.Create(someStructFieldWithOperator: s12);
60+
var v2 = v1.With(someStructFieldWithOperator: s12);
61+
Assert.Same(v1, v2);
62+
}
63+
}
64+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
namespace ImmutableObjectGraph.Generation.Tests.TestSources
2+
{
3+
using System;
4+
5+
[GenerateImmutable]
6+
partial class ImmutableWithComplexStructField
7+
{
8+
readonly SomeStructWithMultipleFields someStructField;
9+
10+
readonly SomeStructWithMultipleFieldsAndOperator someStructFieldWithOperator;
11+
}
12+
13+
struct SomeStructWithMultipleFields
14+
{
15+
internal SomeStructWithMultipleFields(int field1, int field2)
16+
{
17+
this.Field1 = field1;
18+
this.Field2 = field2;
19+
}
20+
21+
internal int Field1 { get; }
22+
23+
internal int Field2 { get; }
24+
}
25+
26+
struct SomeStructWithMultipleFieldsAndOperator
27+
{
28+
internal SomeStructWithMultipleFieldsAndOperator(int field1, int field2)
29+
{
30+
this.Field1 = field1;
31+
this.Field2 = field2;
32+
}
33+
34+
internal int Field1 { get; }
35+
36+
internal int Field2 { get; }
37+
38+
public static bool operator ==(SomeStructWithMultipleFieldsAndOperator one, SomeStructWithMultipleFieldsAndOperator two)
39+
{
40+
return one.Field1 == two.Field1 && one.Field2 == two.Field2;
41+
}
42+
43+
public static bool operator !=(SomeStructWithMultipleFieldsAndOperator one, SomeStructWithMultipleFieldsAndOperator two)
44+
{
45+
return !(one == two);
46+
}
47+
48+
public override bool Equals(object obj)
49+
{
50+
throw new NotImplementedException();
51+
}
52+
53+
public override int GetHashCode()
54+
{
55+
throw new NotImplementedException();
56+
}
57+
}
58+
}

src/ImmutableObjectGraph.Generation/CodeGen.cs

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,35 @@ private static IdentifierNameSyntax GetGenerationalMethodName(IdentifierNameSynt
210210
return SyntaxFactory.IdentifierName(baseName.Identifier.ValueText + generation.ToString(CultureInfo.InvariantCulture));
211211
}
212212

213+
/// <summary>
214+
/// Checks whether a type defines equality operators for itself.
215+
/// </summary>
216+
/// <param name="symbol">The type to check.</param>
217+
/// <returns><c>true</c> if the == and != operators are defined on the type.</returns>
218+
private static bool HasEqualityOperators(ITypeSymbol symbol)
219+
{
220+
Requires.NotNull(symbol, nameof(symbol));
221+
222+
if (symbol.IsReferenceType)
223+
{
224+
// Reference types inherit their equality operators from System.Object.
225+
return true;
226+
}
227+
228+
if (symbol.SpecialType != SpecialType.None)
229+
{
230+
// C# knows how to run equality checks for special (built-in) types like int.
231+
return true;
232+
}
233+
234+
var equalityOperators = from method in symbol.GetMembers().OfType<IMethodSymbol>()
235+
where method.MethodKind == MethodKind.BuiltinOperator || method.MethodKind == MethodKind.UserDefinedOperator
236+
where method.Parameters.Length == 2 && method.Parameters.All(p => p.Type == symbol)
237+
where method.Name == "op_Equality"
238+
select method;
239+
return equalityOperators.Any();
240+
}
241+
213242
private void ReportDiagnostic(string id, SyntaxNode blamedSyntax, params string[] formattingArgs)
214243
{
215244
Requires.NotNull(blamedSyntax, nameof(blamedSyntax));
@@ -533,18 +562,20 @@ private IEnumerable<MethodDeclarationSyntax> CreateWithCoreMethods()
533562
private MemberDeclarationSyntax CreateWithFactoryMethod()
534563
{
535564
// (field.IsDefined && field.Value != this.field)
536-
Func<IdentifierNameSyntax, IdentifierNameSyntax, ExpressionSyntax> isChangedByNames = (propertyName, fieldName) =>
537-
SyntaxFactory.ParenthesizedExpression(
538-
SyntaxFactory.BinaryExpression(
539-
SyntaxKind.LogicalAndExpression,
540-
Syntax.OptionalIsDefined(fieldName),
565+
Func<IdentifierNameSyntax, IdentifierNameSyntax, ITypeSymbol, ExpressionSyntax> isChangedByNames = (propertyName, fieldName, fieldType) =>
566+
fieldType == null || HasEqualityOperators(fieldType) ?
567+
(ExpressionSyntax)SyntaxFactory.ParenthesizedExpression(
541568
SyntaxFactory.BinaryExpression(
542-
SyntaxKind.NotEqualsExpression,
543-
Syntax.OptionalValue(fieldName),
544-
Syntax.ThisDot(propertyName))));
545-
Func<MetaField, ExpressionSyntax> isChanged = v => isChangedByNames(v.NameAsProperty, v.NameAsField);
569+
SyntaxKind.LogicalAndExpression,
570+
Syntax.OptionalIsDefined(fieldName),
571+
SyntaxFactory.BinaryExpression(
572+
SyntaxKind.NotEqualsExpression,
573+
Syntax.OptionalValue(fieldName),
574+
Syntax.ThisDot(propertyName)))) :
575+
Syntax.OptionalIsDefined(fieldName);
576+
Func<MetaField, ExpressionSyntax> isChanged = v => isChangedByNames(v.NameAsProperty, v.NameAsField, v.Symbol.Type);
546577
var anyChangesExpression =
547-
new ExpressionSyntax[] { isChangedByNames(IdentityPropertyName, IdentityParameterName) }.Concat(
578+
new ExpressionSyntax[] { isChangedByNames(IdentityPropertyName, IdentityParameterName, null) }.Concat(
548579
this.applyToMetaType.AllFields.Select(isChanged))
549580
.ChainBinaryExpressions(SyntaxKind.LogicalOrExpression);
550581

src/ImmutableObjectGraph/Optional.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
namespace ImmutableObjectGraph {
2-
using System;
3-
using System.Collections.Generic;
4-
using System.Linq;
5-
using System.Text;
6-
using System.Threading.Tasks;
1+
namespace ImmutableObjectGraph
2+
{
3+
using System.Diagnostics;
74

85
public static class Optional {
6+
[DebuggerStepThrough]
97
public static Optional<T> For<T>(T value) {
108
return value;
119
}

src/ImmutableObjectGraph/Optional`1.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public static implicit operator Optional<T>(T value) {
5757
/// </summary>
5858
/// <param name="defaultValue">The default value to use if a value was not specified.</param>
5959
/// <returns>The value.</returns>
60+
[DebuggerStepThrough]
6061
public T GetValueOrDefault(T defaultValue) {
6162
return this.IsDefined ? this.value : defaultValue;
6263
}

0 commit comments

Comments
 (0)