Skip to content

Commit 78c9f7e

Browse files
authored
Merge pull request matplotlib#31151 from lzqlzzq/add-mlx-support
Add mlx support
2 parents d3b539c + 9f5d729 commit 78c9f7e

2 files changed

Lines changed: 47 additions & 1 deletion

File tree

lib/matplotlib/cbook.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2411,6 +2411,18 @@ def _is_jax_array(x):
24112411
and isinstance(x, tp))
24122412

24132413

2414+
def _is_mlx_array(x):
2415+
"""Return whether *x* is a MLX Array."""
2416+
try:
2417+
# We're intentionally not attempting to import mlx. If somebody
2418+
# has created a mlx array, mlx should already be in sys.modules.
2419+
tp = sys.modules.get("mlx.core").array
2420+
except AttributeError:
2421+
return False # Module not imported or a nonstandard module with no Array attr.
2422+
return (isinstance(tp, type) # Just in case it's a very nonstandard module.
2423+
and isinstance(x, tp))
2424+
2425+
24142426
def _is_pandas_dataframe(x):
24152427
"""Check if *x* is a Pandas DataFrame."""
24162428
try:
@@ -2454,7 +2466,10 @@ def _unpack_to_numpy(x):
24542466
# so in this case we do not want to return a function
24552467
if isinstance(xtmp, np.ndarray):
24562468
return xtmp
2457-
if _is_torch_array(x) or _is_jax_array(x) or _is_tensorflow_array(x):
2469+
if _is_torch_array(x) \
2470+
or _is_jax_array(x) \
2471+
or _is_tensorflow_array(x) \
2472+
or _is_mlx_array(x):
24582473
# using np.asarray() instead of explicitly __array__(), as the latter is
24592474
# only _one_ of many methods, and it's the last resort, see also
24602475
# https://numpy.org/devdocs/user/basics.interoperability.html#using-arbitrary-objects-in-numpy

lib/matplotlib/tests/test_cbook.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,3 +1077,34 @@ def __array__(self):
10771077
# if not mocked, and the implementation does not guarantee it
10781078
# is the same Python object, just the same values.
10791079
assert_array_equal(result, data)
1080+
1081+
1082+
def test_unpack_to_numpy_from_mlx():
1083+
"""
1084+
Test that mlx arrays are converted to NumPy arrays.
1085+
1086+
We don't want to create a dependency on mlx in the test suite, so we mock it.
1087+
"""
1088+
class Array:
1089+
def __init__(self, data):
1090+
self.data = data
1091+
1092+
def __array__(self):
1093+
return self.data
1094+
1095+
# mlx is something peculiar
1096+
# class `array` is in `mlx.core`
1097+
mlx_core = ModuleType('mlx.core')
1098+
mlx_core.array = Array
1099+
1100+
sys.modules['mlx.core'] = mlx_core
1101+
1102+
data = np.arange(10)
1103+
mlx_array = mlx_core.array(data)
1104+
1105+
result = cbook._unpack_to_numpy(mlx_array)
1106+
assert isinstance(result, np.ndarray)
1107+
# compare results, do not check for identity: the latter would fail
1108+
# if not mocked, and the implementation does not guarantee it
1109+
# is the same Python object, just the same values.
1110+
assert_array_equal(result, data)

0 commit comments

Comments
 (0)