Skip to content

Commit 5533c8c

Browse files
committed
Rewrote np.randint(...) [#292] and added unit tests.
1 parent bac2ab7 commit 5533c8c

4 files changed

Lines changed: 191 additions & 32 deletions

File tree

src/NumSharp.Core/Backends/TypedArrayStorage.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,7 @@ protected static Array _ChangeTypeOfArray(Array sourceArray, Type to_dtype)
908908

909909
protected void _Allocate(Shape shape, Type dtype, Array values)
910910
{
911-
_Shape = shape;
911+
_Shape = shape ?? throw new ArgumentNullException(nameof(shape));
912912

913913
if (dtype != null)
914914
{
@@ -928,6 +928,9 @@ protected void _Allocate(Shape shape, Type dtype, Array values)
928928
/// <param name="shape">The shape of the array.</param>
929929
public void Allocate(Shape shape, Type dtype = null)
930930
{
931+
if (shape == null)
932+
throw new ArgumentNullException(nameof(shape));
933+
931934
dtype = dtype ?? DType;
932935
_Allocate(shape, dtype, Arrays.Create(dtype, new int[] {shape.Size}));
933936
}
@@ -975,9 +978,10 @@ public void Allocate(Array values)
975978
public void Allocate(Array values, Shape shape)
976979
{
977980
if (values == null)
978-
{
979981
throw new ArgumentNullException(nameof(values));
980-
}
982+
983+
if (shape == null)
984+
throw new ArgumentNullException(nameof(shape));
981985

982986
Type elementType = values.GetType();
983987
// ReSharper disable once PossibleNullReferenceException
Lines changed: 134 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,155 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Linq;
4+
using System.Runtime.CompilerServices;
45
using System.Security.Cryptography;
56
using System.Text;
7+
using NumSharp.Backends;
8+
using NumSharp.Utilities;
69

710
namespace NumSharp
811
{
912
public partial class NumPyRandom
1013
{
11-
public NDArray randint(int low, int size = 1)
14+
/// <summary>
15+
/// Return random integers from the “discrete uniform” distribution of the specified dtype in the “half-open” interval [low, high). If high is None (the default), then results are from [0, low).
16+
/// </summary>
17+
/// <param name="low">Lowest (signed) integer to be drawn from the distribution (unless high=-1, in which case this parameter is one above the highest such integer).</param>
18+
/// <param name="high">If provided, one above the largest (signed) integer to be drawn from the distribution (see above for behavior if high=-1).</param>
19+
/// <param name="size">The shape of the array.</param>
20+
/// <param name="dtype">Desired dtype of the result. All dtypes are determined by their name, i.e., ‘int64’, ‘int’, etc, so byteorder is not available and a specific precision may have different C types depending on the platform. The default value is ‘np.int’.</param>
21+
/// <returns></returns>
22+
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.random.randint.html</remarks>
23+
public NDArray randint(long low, long high = -1, Shape size = null, Type dtype = null)
1224
{
13-
var data = new int[size];
14-
for (int i = 0; i < data.Length; i++)
25+
dtype = dtype ?? np.int32;
26+
var typecode = dtype.GetTypeCode();
27+
if (high == -1)
1528
{
16-
data[i] = randomizer.Next(low, int.MaxValue);
29+
high = low;
30+
low = 0;
1731
}
1832

19-
var np = new NDArray(typeof(int), size);
20-
np.ReplaceData(data);
33+
if (size == null || (size.NDim == 1 && size.Size == 1))
34+
return NDArray.Scalar(randomizer.NextLong(low, high), typecode);
2135

22-
return np;
23-
}
24-
25-
public NDArray randint(int low, int? high = null, Shape shape = null)
26-
{
27-
if(high == null)
28-
{
29-
high = int.MaxValue;
30-
}
31-
if(shape == null)
32-
{
33-
shape = new Shape(high.Value - low);
34-
}
35-
var data = new int[shape.Size];
36-
for(int i = 0; i < data.Length; i++)
36+
var nd = new NDArray(dtype, size); //allocation called inside.
37+
switch (typecode)
3738
{
38-
data[i] = randomizer.Next(low, high.Value);
39-
}
39+
#if _REGEN
40+
%foreach supported_numericals,supported_numericals_lowercase%
41+
case NPTypeCode.#1:
42+
{
43+
var data = (#2[])nd.Array;
44+
for (int i = 0; i < data.Length; i++)
45+
data[i] = Convert.To#1(randomizer.NextLong(low, high));
46+
47+
break;
48+
}
49+
%
50+
#else
51+
case NPTypeCode.Byte:
52+
{
53+
var data = (byte[])nd.Array;
54+
for (int i = 0; i < data.Length; i++)
55+
data[i] = Convert.ToByte(randomizer.NextLong(low, high));
56+
57+
break;
58+
}
59+
60+
case NPTypeCode.Int16:
61+
{
62+
var data = (short[])nd.Array;
63+
for (int i = 0; i < data.Length; i++)
64+
data[i] = Convert.ToInt16(randomizer.NextLong(low, high));
65+
66+
break;
67+
}
68+
69+
case NPTypeCode.UInt16:
70+
{
71+
var data = (ushort[])nd.Array;
72+
for (int i = 0; i < data.Length; i++)
73+
data[i] = Convert.ToUInt16(randomizer.NextLong(low, high));
74+
75+
break;
76+
}
77+
78+
case NPTypeCode.Int32:
79+
{
80+
var data = (int[])nd.Array;
81+
for (int i = 0; i < data.Length; i++)
82+
data[i] = Convert.ToInt32(randomizer.NextLong(low, high));
83+
84+
break;
85+
}
4086

41-
var np = new NDArray(typeof(int), shape.Dimensions.ToArray());
42-
np.ReplaceData(data);
87+
case NPTypeCode.UInt32:
88+
{
89+
var data = (uint[])nd.Array;
90+
for (int i = 0; i < data.Length; i++)
91+
data[i] = Convert.ToUInt32(randomizer.NextLong(low, high));
92+
93+
break;
94+
}
95+
96+
case NPTypeCode.Int64:
97+
{
98+
var data = (long[])nd.Array;
99+
for (int i = 0; i < data.Length; i++)
100+
data[i] = Convert.ToInt64(randomizer.NextLong(low, high));
101+
102+
break;
103+
}
104+
105+
case NPTypeCode.UInt64:
106+
{
107+
var data = (ulong[])nd.Array;
108+
for (int i = 0; i < data.Length; i++)
109+
data[i] = Convert.ToUInt64(randomizer.NextLong(low, high));
110+
111+
break;
112+
}
113+
114+
case NPTypeCode.Char:
115+
{
116+
var data = (char[])nd.Array;
117+
for (int i = 0; i < data.Length; i++)
118+
data[i] = Convert.ToChar(randomizer.NextLong(low, high));
119+
120+
break;
121+
}
122+
123+
case NPTypeCode.Double:
124+
{
125+
var data = (double[])nd.Array;
126+
for (int i = 0; i < data.Length; i++)
127+
data[i] = Convert.ToDouble(randomizer.NextLong(low, high));
128+
129+
break;
130+
}
131+
132+
case NPTypeCode.Single:
133+
{
134+
var data = (float[])nd.Array;
135+
for (int i = 0; i < data.Length; i++)
136+
data[i] = Convert.ToSingle(randomizer.NextLong(low, high));
137+
138+
break;
139+
}
140+
141+
case NPTypeCode.Decimal:
142+
{
143+
var data = (decimal[])nd.Array;
144+
for (int i = 0; i < data.Length; i++)
145+
data[i] = Convert.ToDecimal(randomizer.NextLong(low, high));
146+
147+
break;
148+
}
149+
#endif
150+
}
43151

44-
return np;
152+
return nd;
45153
}
46154
}
47155
}

test/NumSharp.UnitTest/Random/Randomizer.Tests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public void SaveAndRestore()
2222
[TestMethod]
2323
public void CompareRandomizerToRandom()
2424
{
25-
var rnd = new Random(42);
25+
var rnd = new System.Random(42);
2626
var rndizer = new Randomizer(42);
2727

2828
rnd.Next().Should().Be(rndizer.Next());

test/NumSharp.UnitTest/Random/np.random.randint.Test.cs

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,64 @@
55
using System.Collections.Generic;
66
using System.Linq;
77
using System.Text;
8+
using FluentAssertions;
89

910
namespace NumSharp.UnitTest.RandomSampling
1011
{
1112
[TestClass]
12-
public class NdArrayRandomRandIntTest
13+
public class NdArrayRandomRandIntTests
1314
{
1415
[TestMethod]
1516
public void randint()
1617
{
17-
var a = np.random.RandomState().randint(low: 0, high: 10, shape: new Shape(5, 5));
18+
var a = np.random.RandomState().randint(low: 0, high: 10, size: new Shape(5, 5));
1819
Assert.IsTrue(a.Data<int>().Count(x => x < 10) == 25);
1920
}
21+
22+
/// <summary>
23+
/// Based on issue https://github.com/SciSharp/NumSharp/issues/292
24+
/// </summary>
25+
[TestMethod]
26+
public void randint_2()
27+
{
28+
for (int i = 0; i < 50; i++)
29+
{
30+
var result_1 = np.random.randint(2, size: 10); // 10 numbers between [2, int.MaxValue)
31+
result_1.Array.As<int[]>().All(v => v >= 0 && v < 2).Should().BeTrue();
32+
result_1.Array.As<int[]>().Should().HaveCount(10);
33+
34+
var result_2 = np.random.randint(5, size: new Shape(2, 4)); // 8 numbers between [5, int.MaxValue)
35+
result_2.Array.As<int[]>().All(v => v >= 0 && v < 5).Should().BeTrue();
36+
result_2.Array.As<int[]>().Should().HaveCount(2 * 4);
37+
38+
var result_3 = np.random.randint(5, size: new Shape(2, 4)); // 2x4 matrix with elements between [5, int.MaxValue)
39+
result_3.Array.As<int[]>().All(v => v >= 0 && v < 5).Should().BeTrue();
40+
result_3.Array.As<int[]>().Should().HaveCount(2 * 4);
41+
42+
var result_4 = np.random.randint(low: 0, high: 5); // throws System.NullReferenceException
43+
result_4.Array.As<int[]>().All(v => v >= 0 && v < 5).Should().BeTrue();
44+
result_4.Array.As<int[]>().Should().HaveCount(1);
45+
46+
var result_5 = np.random.randint(0, 5, null); // throws System.NullReferenceException (equivalent to result_4)
47+
result_5.Array.As<int[]>().All(v => v >= 0 && v < 5).Should().BeTrue();
48+
result_5.Array.As<int[]>().Should().HaveCount(1);
49+
50+
var result_6 = np.random.randint(5); // Does not even compile
51+
result_6.Array.As<int[]>().All(v => v >= 0 && v < 5).Should().BeTrue();
52+
result_6.Array.As<int[]>().Should().HaveCount(1);
53+
54+
var result_7 = np.random.randint(low: 0, high: 10, size: new Shape(2, 2)); // 2x2 matrix with elements between [0, 10)
55+
result_7.Array.As<int[]>().All(v => v >= 0 && v < 10).Should().BeTrue();
56+
result_7.Array.As<int[]>().Should().HaveCount(2 * 2);
57+
58+
var result_8 = np.random.randint(1, 5, 8); // 8 numbers between [1, 5)
59+
result_8.Array.As<int[]>().All(v => v >= 1 && v < 5).Should().BeTrue();
60+
result_8.Array.As<int[]>().Should().HaveCount(8);
61+
62+
var result_9 = np.random.randint(1, 5, new Shape(3, 2)); // 3x2 matrix with elements between [1, 5)
63+
result_9.Array.As<int[]>().All(v => v >= 1 && v < 5).Should().BeTrue();
64+
result_9.Array.As<int[]>().Should().HaveCount(3 * 2);
65+
}
66+
}
2067
}
2168
}

0 commit comments

Comments
 (0)