24
24
"""
25
25
26
26
from abc import ABC , abstractmethod
27
- from collections .abc import Mapping
27
+ from collections .abc import Callable , Hashable , Iterable , Mapping , Set
28
28
from typing import (
29
29
TYPE_CHECKING ,
30
- AbstractSet ,
31
- Callable ,
30
+ Concatenate ,
32
31
Generic ,
33
- Hashable ,
34
- Iterable ,
32
+ TypeAlias ,
35
33
TypeVar ,
36
34
cast ,
37
35
)
38
36
from warnings import warn
39
37
40
38
from immutabledict import immutabledict
41
- from typing_extensions import Concatenate , ParamSpec , TypeAlias , TypeIs
39
+ from typing_extensions import ParamSpec , TypeIs
42
40
43
41
import pymbolic .primitives as p
44
42
from pymbolic .typing import ArithmeticExpressionT , ExpressionT
@@ -640,7 +638,7 @@ class CachedCombineMapper(CachedMapper, CombineMapper):
640
638
CollectedT = TypeVar ("CollectedT" )
641
639
642
640
643
- class Collector (CombineMapper [AbstractSet [CollectedT ], P ]):
641
+ class Collector (CombineMapper [Set [CollectedT ], P ]):
644
642
"""A subclass of :class:`CombineMapper` for the common purpose of
645
643
collecting data derived from an expression in a set that gets 'unioned'
646
644
across children at each non-leaf node in the expression tree.
@@ -651,34 +649,34 @@ class Collector(CombineMapper[AbstractSet[CollectedT], P]):
651
649
"""
652
650
653
651
def combine (self ,
654
- values : Iterable [AbstractSet [CollectedT ]]
655
- ) -> AbstractSet [CollectedT ]:
652
+ values : Iterable [Set [CollectedT ]]
653
+ ) -> Set [CollectedT ]:
656
654
import operator
657
655
from functools import reduce
658
656
return reduce (operator .or_ , values , set ())
659
657
660
658
def map_constant (self , expr : object ,
661
- * args : P .args , ** kwargs : P .kwargs ) -> AbstractSet [CollectedT ]:
659
+ * args : P .args , ** kwargs : P .kwargs ) -> Set [CollectedT ]:
662
660
return set ()
663
661
664
662
def map_variable (self , expr : p .Variable ,
665
- * args : P .args , ** kwargs : P .kwargs ) -> AbstractSet [CollectedT ]:
663
+ * args : P .args , ** kwargs : P .kwargs ) -> Set [CollectedT ]:
666
664
return set ()
667
665
668
666
def map_wildcard (self , expr : p .Wildcard ,
669
- * args : P .args , ** kwargs : P .kwargs ) -> AbstractSet [CollectedT ]:
667
+ * args : P .args , ** kwargs : P .kwargs ) -> Set [CollectedT ]:
670
668
return set ()
671
669
672
670
def map_dot_wildcard (self , expr : p .DotWildcard ,
673
- * args : P .args , ** kwargs : P .kwargs ) -> AbstractSet [CollectedT ]:
671
+ * args : P .args , ** kwargs : P .kwargs ) -> Set [CollectedT ]:
674
672
return set ()
675
673
676
674
def map_star_wildcard (self , expr : p .StarWildcard ,
677
- * args : P .args , ** kwargs : P .kwargs ) -> AbstractSet [CollectedT ]:
675
+ * args : P .args , ** kwargs : P .kwargs ) -> Set [CollectedT ]:
678
676
return set ()
679
677
680
678
def map_function_symbol (self , expr : p .FunctionSymbol ,
681
- * args : P .args , ** kwargs : P .kwargs ) -> AbstractSet [CollectedT ]:
679
+ * args : P .args , ** kwargs : P .kwargs ) -> Set [CollectedT ]:
682
680
return set ()
683
681
684
682
0 commit comments