Skip to content

Commit 3cdb93e

Browse files
committed
Remove excess register stuff
Signed-off-by: Mustafa Eyceoz <[email protected]>
1 parent b4b7b4b commit 3cdb93e

File tree

1 file changed

+0
-27
lines changed

1 file changed

+0
-27
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,15 @@
11
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
22

33
from .models import (
4-
GPTCrossLayerConfig,
5-
GPTCrossLayerForCausalLM,
6-
GPTCrossLayerModel,
74
GPTDolomiteConfig,
85
GPTDolomiteForCausalLM,
9-
GPTDolomiteForCausalLM_TP,
106
GPTDolomiteModel,
11-
MoEDolomiteConfig,
12-
MoEDolomiteForCausalLM,
13-
MoEDolomiteForCausalLM_TP,
14-
MoEDolomiteModel,
15-
RNNDolomiteConfig,
16-
RNNDolomiteForCausalLM,
17-
RNNDolomiteModel,
187
)
198

209

2110
# (AutoConfig, AutoModel, AutoModelForCausalLM)
2211
_CUSTOM_MODEL_REGISTRY = [
2312
(GPTDolomiteConfig, GPTDolomiteModel, GPTDolomiteForCausalLM),
24-
(MoEDolomiteConfig, MoEDolomiteModel, MoEDolomiteForCausalLM),
25-
(GPTCrossLayerConfig, GPTCrossLayerModel, GPTCrossLayerForCausalLM),
26-
(RNNDolomiteConfig, RNNDolomiteModel, RNNDolomiteForCausalLM),
2713
]
2814
_CUSTOM_MODEL_TYPES = []
2915
_CUSTOM_MODEL_CLASSES = []
@@ -43,16 +29,3 @@ def register_model_classes() -> None:
4329

4430
def is_custom_model(model_class: type[AutoModelForCausalLM] | type[AutoModelForSeq2SeqLM], model_type: str) -> bool:
4531
return model_class.__name__ in _CUSTOM_MODEL_CLASSES or model_type in _CUSTOM_MODEL_TYPES
46-
47-
48-
_TENSOR_PARALLEL_CLASS_MAPPING = {
49-
GPTDolomiteConfig.model_type: GPTDolomiteForCausalLM_TP,
50-
MoEDolomiteConfig.model_type: MoEDolomiteForCausalLM_TP,
51-
}
52-
53-
54-
def get_tensor_parallel_class(model_type: str) -> AutoModelForCausalLM:
55-
if model_type in _TENSOR_PARALLEL_CLASS_MAPPING:
56-
return _TENSOR_PARALLEL_CLASS_MAPPING[model_type]
57-
58-
raise ValueError(f"tensor parallel is not supported with `model_type` ({model_type})")

0 commit comments

Comments
 (0)