1717Expression = sympy .Basic | np .number | int | float
1818
1919
20- class Set (set [Expression ]):
21- @staticmethod
22- def wrap (obj : Expression ) -> set [Expression ]:
23- return {obj }
24-
25-
2620class List (list [Expression ]):
27- @ staticmethod
28- def wrap ( obj : Expression ) -> list [ Expression ]:
29- return [ obj ]
21+ """
22+ A list that aliases `extend` to `update` to mirror the `set` interface.
23+ """
3024
3125 def update (self , obj : Iterable [Expression ]) -> None :
3226 self .extend (obj )
3327
3428
3529Mode = Literal ['all' , 'unique' ]
36- modes : dict [Mode , type [List ] | type [Set ]] = {
30+ modes : dict [Mode , type [List ] | type [set [ Expression ] ]] = {
3731 'all' : List ,
38- 'unique' : Set
32+ 'unique' : set
3933}
4034
4135
@@ -97,7 +91,7 @@ def search(exprs: Expression | Iterable[Expression],
9791 query : type | Callable [[Any ], bool ],
9892 mode : Mode = 'unique' ,
9993 visit : Literal ['dfs' , 'bfs' , 'bfs_first_hit' ] = 'dfs' ,
100- deep : bool = False ) -> List | Set :
94+ deep : bool = False ) -> List | set [ Expression ] :
10195 """Interface to Search."""
10296
10397 assert mode in ('all' , 'unique' ), "Unknown mode"
@@ -118,7 +112,7 @@ def search(exprs: Expression | Iterable[Expression],
118112 _search = searcher .visit_preorder_first_hit
119113 else :
120114 raise ValueError (f"Unknown visit mode '{ visit } '" )
121-
115+
122116 exprs = filter (lambda e : isinstance (e , sympy .Basic ), as_tuple (exprs ))
123117 found = modes [mode ](chain (* map (_search , exprs )))
124118
0 commit comments