Skip to content

Commit bac2ab7

Browse files
committed
Fixed #306 and added unit-tests.
- Added NumberInfo class that can get `MaxValue` or `MinValue` based on `NPTypeCode`
1 parent c81942e commit bac2ab7

3 files changed

Lines changed: 174 additions & 27 deletions

File tree

src/NumSharp.Core/Shape.cs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,15 @@ public Shape Slice(Slice[] slices, bool reduce = false)
220220

221221
public static bool operator ==(Shape a, Shape b)
222222
{
223-
if (b is null) return false;
223+
if (a is null && b is null)
224+
return true;
225+
226+
if (a is null || b is null)
227+
return false;
228+
229+
if (ReferenceEquals(a, b))
230+
return true;
231+
224232
return Equals(a, b);
225233
}
226234

@@ -302,5 +310,16 @@ public Shape Clone()
302310
{
303311
return new Shape(this.dimensions) {layout = this.layout};
304312
}
313+
314+
public Shape SubShape(int dim)
315+
{
316+
var arr = new int[dimensions.Length-dim];
317+
for (int i = 0; i < arr.Length; i++)
318+
{
319+
arr[i] = dimensions[dim + i];
320+
}
321+
322+
return arr;
323+
}
305324
}
306325
}
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
using System;
2+
using System.Numerics;
3+
using NumSharp.Backends;
4+
5+
namespace NumSharp.Utilities
6+
{
7+
public static class NumberInfo
8+
{
9+
/// <summary>
10+
/// Get the min value of given <see cref="NPTypeCode"/>.
11+
/// </summary>
12+
public static object MaxValue(this NPTypeCode typeCode)
13+
{
14+
switch (typeCode)
15+
{
16+
case NPTypeCode.Complex:
17+
return new Complex(double.MaxValue, double.MaxValue);
18+
case NPTypeCode.Boolean:
19+
return (byte)1;
20+
#if _REGEN
21+
%foreach except(supported_primitives, "Boolean", "String")%
22+
case NPTypeCode.#1:
23+
return #1.MaxValue;
24+
%
25+
#else
26+
case NPTypeCode.Byte:
27+
return Byte.MaxValue;
28+
case NPTypeCode.Int16:
29+
return Int16.MaxValue;
30+
case NPTypeCode.UInt16:
31+
return UInt16.MaxValue;
32+
case NPTypeCode.Int32:
33+
return Int32.MaxValue;
34+
case NPTypeCode.UInt32:
35+
return UInt32.MaxValue;
36+
case NPTypeCode.Int64:
37+
return Int64.MaxValue;
38+
case NPTypeCode.UInt64:
39+
return UInt64.MaxValue;
40+
case NPTypeCode.Char:
41+
return Char.MaxValue;
42+
case NPTypeCode.Double:
43+
return Double.MaxValue;
44+
case NPTypeCode.Single:
45+
return Single.MaxValue;
46+
case NPTypeCode.Decimal:
47+
return Decimal.MaxValue;
48+
#endif
49+
default:
50+
throw new ArgumentOutOfRangeException(nameof(typeCode), typeCode, null);
51+
}
52+
}
53+
54+
/// <summary>
55+
/// Get the min value of given <see cref="NPTypeCode"/>.
56+
/// </summary>
57+
public static object MinValue(this NPTypeCode typeCode)
58+
{
59+
switch (typeCode)
60+
{
61+
case NPTypeCode.Complex:
62+
return new Complex(double.MinValue, double.MinValue);
63+
case NPTypeCode.Boolean:
64+
return (byte)0;
65+
#if _REGEN
66+
%foreach except(supported_primitives, "Boolean", "String")%
67+
case NPTypeCode.#1:
68+
return #1.MinValue;
69+
%
70+
#else
71+
case NPTypeCode.Byte:
72+
return Byte.MinValue;
73+
case NPTypeCode.Int16:
74+
return Int16.MinValue;
75+
case NPTypeCode.UInt16:
76+
return UInt16.MinValue;
77+
case NPTypeCode.Int32:
78+
return Int32.MinValue;
79+
case NPTypeCode.UInt32:
80+
return UInt32.MinValue;
81+
case NPTypeCode.Int64:
82+
return Int64.MinValue;
83+
case NPTypeCode.UInt64:
84+
return UInt64.MinValue;
85+
case NPTypeCode.Char:
86+
return Char.MinValue;
87+
case NPTypeCode.Double:
88+
return Double.MinValue;
89+
case NPTypeCode.Single:
90+
return Single.MinValue;
91+
case NPTypeCode.Decimal:
92+
return Decimal.MinValue;
93+
#endif
94+
default:
95+
throw new ArgumentOutOfRangeException(nameof(typeCode), typeCode, null);
96+
}
97+
}
98+
}
99+
}

test/NumSharp.UnitTest/Shape.Test.cs

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,78 +4,107 @@
44
using System.Text;
55
using NumSharp.Extensions;
66
using System.Linq;
7+
using FluentAssertions;
78
using NumSharp;
89

910
namespace NumSharp.UnitTest
1011
{
1112
[TestClass]
1213
public class NDStorageTest
1314
{
14-
//[TestMethod]
15+
//TODO! [TestMethod]
1516
public void Index()
1617
{
17-
var shape0 = new Shape(4,3);
18+
var shape0 = new Shape(4, 3);
1819

1920
int idx0 = shape0.GetIndexInShape(2, 1);
20-
21-
Assert.IsTrue(idx0 == 6);
21+
22+
Assert.IsTrue(idx0 == 4*2+1);
2223
}
23-
//[TestMethod]
24+
25+
//TODO! [TestMethod]
2426
public void CheckIndexing()
2527
{
26-
var shape0 = new Shape(4,3,2);
28+
var shape0 = new Shape(4, 3, 2);
2729

2830
int[] strgDimSize = shape0.Strides;
2931

3032
int index = shape0.GetIndexInShape(1, 2, 1);
3133

32-
Assert.IsTrue(Enumerable.SequenceEqual(shape0.GetDimIndexOutShape(index),new int[]{1,2,1}));
34+
Assert.IsTrue(Enumerable.SequenceEqual(shape0.GetDimIndexOutShape(index), new int[] {1, 2, 1}));
3335

3436
var rnd = new Randomizer();
35-
var randomIndex = new int[]{rnd.Next(0,3),rnd.Next(0,2),rnd.Next(0,1)};
37+
var randomIndex = new int[] {rnd.Next(0, 3), rnd.Next(0, 2), rnd.Next(0, 1)};
3638

3739
int index1 = shape0.GetIndexInShape(randomIndex);
38-
Assert.IsTrue(Enumerable.SequenceEqual(shape0.GetDimIndexOutShape(index1),randomIndex));
40+
Assert.IsTrue(Enumerable.SequenceEqual(shape0.GetDimIndexOutShape(index1), randomIndex));
3941

40-
var shape1 = new Shape(2,3,4);
42+
var shape1 = new Shape(2, 3, 4);
4143

42-
index = shape1.GetIndexInShape(1,2,1);
43-
Assert.IsTrue(Enumerable.SequenceEqual(shape1.GetDimIndexOutShape(index),new int[]{1,2,1}));
44+
index = shape1.GetIndexInShape(1, 2, 1);
45+
Assert.IsTrue(Enumerable.SequenceEqual(shape1.GetDimIndexOutShape(index), new int[] {1, 2, 1}));
4446

45-
randomIndex = new int[]{rnd.Next(0,1),rnd.Next(0,2),rnd.Next(0,3)};
47+
randomIndex = new int[] {rnd.Next(0, 1), rnd.Next(0, 2), rnd.Next(0, 3)};
4648
index = shape1.GetIndexInShape(randomIndex);
47-
Assert.IsTrue(Enumerable.SequenceEqual(shape1.GetDimIndexOutShape(index),randomIndex));
49+
Assert.IsTrue(Enumerable.SequenceEqual(shape1.GetDimIndexOutShape(index), randomIndex));
4850

49-
randomIndex = new int[]{rnd.Next(1,10),rnd.Next(1,10),rnd.Next(1,10)};
51+
randomIndex = new int[] {rnd.Next(1, 10), rnd.Next(1, 10), rnd.Next(1, 10)};
5052

5153
var shape2 = new Shape(randomIndex);
5254

53-
randomIndex = new int[]{rnd.Next(0,shape2.Dimensions[0]),rnd.Next(0,shape2.Dimensions[1]),rnd.Next(0,shape2.Dimensions[2])};
55+
randomIndex = new int[] {rnd.Next(0, shape2.Dimensions[0]), rnd.Next(0, shape2.Dimensions[1]), rnd.Next(0, shape2.Dimensions[2])};
5456

5557
index = shape2.GetIndexInShape(randomIndex);
56-
Assert.IsTrue(Enumerable.SequenceEqual(shape2.GetDimIndexOutShape(index),randomIndex));
58+
Assert.IsTrue(Enumerable.SequenceEqual(shape2.GetDimIndexOutShape(index), randomIndex));
5759
}
58-
//[TestMethod]
60+
61+
//TODO! [TestMethod]
5962
public void CheckColRowSwitch()
6063
{
6164
var shape1 = new Shape(5);
62-
Assert.IsTrue(Enumerable.SequenceEqual(shape1.Strides,new int[]{1}));
65+
Assert.IsTrue(Enumerable.SequenceEqual(shape1.Strides, new int[] {1}));
6366

6467
shape1.ChangeTensorLayout();
65-
Assert.IsTrue(Enumerable.SequenceEqual(shape1.Strides,new int[]{1}));
68+
Assert.IsTrue(Enumerable.SequenceEqual(shape1.Strides, new int[] {1}));
6669

67-
var shape2 = new Shape(4,3);
68-
Assert.IsTrue(Enumerable.SequenceEqual(shape2.Strides,new int[]{1,4}));
70+
var shape2 = new Shape(4, 3);
71+
Assert.IsTrue(Enumerable.SequenceEqual(shape2.Strides, new int[] {1, 4}));
6972

7073
shape2.ChangeTensorLayout();
71-
Assert.IsTrue(Enumerable.SequenceEqual(shape2.Strides,new int[]{3,1}));
74+
Assert.IsTrue(Enumerable.SequenceEqual(shape2.Strides, new int[] {3, 1}));
7275

73-
var shape3 = new Shape(2,3,4);
74-
Assert.IsTrue(Enumerable.SequenceEqual(shape3.Strides,new int[]{1,2,6}));
76+
var shape3 = new Shape(2, 3, 4);
77+
Assert.IsTrue(Enumerable.SequenceEqual(shape3.Strides, new int[] {1, 2, 6}));
7578

7679
shape3.ChangeTensorLayout();
77-
Assert.IsTrue(Enumerable.SequenceEqual(shape3.Strides,new int[]{12,4,1}));
80+
Assert.IsTrue(Enumerable.SequenceEqual(shape3.Strides, new int[] {12, 4, 1}));
81+
}
82+
83+
/// <summary>
84+
/// Based on issue https://github.com/SciSharp/NumSharp/issues/306
85+
/// </summary>
86+
[TestMethod]
87+
public void EqualityComparer()
88+
{
89+
Shape a = null;
90+
Shape b = null;
91+
92+
(a == b).Should().BeTrue();
93+
(a == null).Should().BeTrue();
94+
(null == b).Should().BeTrue();
95+
96+
a = 5;
97+
b = 4;
98+
(a != b).Should().BeTrue();
99+
100+
b = 5;
101+
(a == b).Should().BeTrue();
78102

103+
a = new Shape(1, 2, 3, 4, 5);
104+
b = new Shape(1, 2, 3, 4, 5);
105+
(a == b).Should().BeTrue();
106+
b = new Shape(1, 2, 3, 4);
107+
(a != b).Should().BeTrue();
79108
}
80109
}
81110
}

0 commit comments

Comments
 (0)