Skip to content

Commit deece7b

Browse files
committed
Fix aggregation validation for builtins and callables
1 parent b4e5aae commit deece7b

2 files changed

Lines changed: 52 additions & 8 deletions

File tree

neat/aggregations.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
and code for adding new user-defined ones.
44
"""
55

6+
import inspect
67
import types
78
import warnings
89
from functools import reduce
@@ -49,14 +50,27 @@ class InvalidAggregationFunction(TypeError):
4950

5051

5152
def validate_aggregation(function): # TODO: Recognize when need `reduce`
52-
if not isinstance(function,
53-
(types.BuiltinFunctionType,
54-
types.FunctionType,
55-
types.LambdaType)):
56-
raise InvalidAggregationFunction("A function object is required.")
57-
58-
if not (function.__code__.co_argcount >= 1):
59-
raise InvalidAggregationFunction("A function taking at least one argument is required")
53+
if not callable(function):
54+
raise InvalidAggregationFunction("A callable object is required.")
55+
56+
try:
57+
signature = inspect.signature(function)
58+
except (TypeError, ValueError) as exc:
59+
if isinstance(function, types.BuiltinFunctionType):
60+
return
61+
raise InvalidAggregationFunction("Unable to inspect aggregation callable signature.") from exc
62+
63+
accepts_positional = any(
64+
parameter.kind in (
65+
inspect.Parameter.POSITIONAL_ONLY,
66+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
67+
inspect.Parameter.VAR_POSITIONAL,
68+
)
69+
for parameter in signature.parameters.values()
70+
)
71+
72+
if not accepts_positional:
73+
raise InvalidAggregationFunction("A function taking at least one positional argument is required")
6074

6175

6276
class AggregationFunctionSet:

tests/test_aggregation.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,25 @@ def test_add_minabs():
7272
assert config.genome_config.aggregation_function_defs.is_valid('minabs')
7373

7474

75+
def test_add_builtin_max():
76+
local_dir = os.path.dirname(__file__)
77+
config_path = os.path.join(local_dir, 'test_configuration')
78+
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
79+
neat.DefaultSpeciesSet, neat.DefaultStagnation,
80+
config_path)
81+
82+
config.genome_config.add_aggregation('builtin_max', max)
83+
assert config.genome_config.aggregation_function_defs.get('builtin_max') is max
84+
85+
7586
def dud_function():
7687
return 0.0
7788

7889

90+
def keyword_only_function(*, items):
91+
return sum(items)
92+
93+
7994
def test_function_set():
8095
s = aggregations.AggregationFunctionSet()
8196
assert s.get('sum') is not None
@@ -135,6 +150,21 @@ def test_bad_add2():
135150
raise Exception("Should have had a TypeError/derived for dud_function")
136151

137152

153+
def test_bad_add3():
154+
local_dir = os.path.dirname(__file__)
155+
config_path = os.path.join(local_dir, 'test_configuration')
156+
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
157+
neat.DefaultSpeciesSet, neat.DefaultStagnation,
158+
config_path)
159+
160+
try:
161+
config.genome_config.add_aggregation('keyword_only_function', keyword_only_function)
162+
except TypeError:
163+
pass
164+
else:
165+
raise Exception("Should have had a TypeError/derived for keyword_only_function")
166+
167+
138168
if __name__ == '__main__':
139169
test_sum()
140170
test_product()

0 commit comments

Comments
 (0)