1
1
from transformers import AutoConfig , AutoModel , AutoModelForCausalLM , AutoModelForSeq2SeqLM
2
2
3
3
from .models import (
4
- GPTCrossLayerConfig ,
5
- GPTCrossLayerForCausalLM ,
6
- GPTCrossLayerModel ,
7
4
GPTDolomiteConfig ,
8
5
GPTDolomiteForCausalLM ,
9
- GPTDolomiteForCausalLM_TP ,
10
6
GPTDolomiteModel ,
11
- MoEDolomiteConfig ,
12
- MoEDolomiteForCausalLM ,
13
- MoEDolomiteForCausalLM_TP ,
14
- MoEDolomiteModel ,
15
- RNNDolomiteConfig ,
16
- RNNDolomiteForCausalLM ,
17
- RNNDolomiteModel ,
18
7
)
19
8
20
9
21
10
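# each registry entry is a (config class, model class, causal LM class) tuple,
# positionally matching (AutoConfig, AutoModel, AutoModelForCausalLM)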
22
11
_CUSTOM_MODEL_REGISTRY = [
23
12
(GPTDolomiteConfig , GPTDolomiteModel , GPTDolomiteForCausalLM ),
24
- (MoEDolomiteConfig , MoEDolomiteModel , MoEDolomiteForCausalLM ),
25
- (GPTCrossLayerConfig , GPTCrossLayerModel , GPTCrossLayerForCausalLM ),
26
- (RNNDolomiteConfig , RNNDolomiteModel , RNNDolomiteForCausalLM ),
27
13
]
28
14
_CUSTOM_MODEL_TYPES = []
29
15
_CUSTOM_MODEL_CLASSES = []
@@ -43,16 +29,3 @@ def register_model_classes() -> None:
43
29
44
30
def is_custom_model(
    model_class: type[AutoModelForCausalLM] | type[AutoModelForSeq2SeqLM], model_type: str
) -> bool:
    """Check a model class / model type pair against the custom-model lists.

    Args:
        model_class: class to check; only its ``__name__`` is looked up.
        model_type: model type string to check.

    Returns:
        ``True`` when ``model_class.__name__`` is in ``_CUSTOM_MODEL_CLASSES``
        or ``model_type`` is in ``_CUSTOM_MODEL_TYPES``; ``False`` otherwise.
    """
    # the class-name lookup is tried first; the model type is only consulted on a miss
    if model_class.__name__ in _CUSTOM_MODEL_CLASSES:
        return True

    return model_type in _CUSTOM_MODEL_TYPES
46
-
47
-
48
- _TENSOR_PARALLEL_CLASS_MAPPING = {
49
- GPTDolomiteConfig .model_type : GPTDolomiteForCausalLM_TP ,
50
- MoEDolomiteConfig .model_type : MoEDolomiteForCausalLM_TP ,
51
- }
52
-
53
-
54
- def get_tensor_parallel_class (model_type : str ) -> AutoModelForCausalLM :
55
- if model_type in _TENSOR_PARALLEL_CLASS_MAPPING :
56
- return _TENSOR_PARALLEL_CLASS_MAPPING [model_type ]
57
-
58
- raise ValueError (f"tensor parallel is not supported with `model_type` ({ model_type } )" )
0 commit comments