11import contextlib
2- from functools import wraps
2+ from functools import partial , wraps
33from typing import Any , Callable , ClassVar , Optional , Set , Union
44
55import torch
6+ from compressed_tensors .modeling import (
7+ register_key_hook ,
8+ register_query_hook ,
9+ register_value_hook ,
10+ )
611from loguru import logger
712from pydantic import BaseModel
813from torch .utils .hooks import RemovableHandle
@@ -92,7 +97,7 @@ def wrapped_hook(*args, **kwargs):
9297
9398 return hook (* args , ** kwargs )
9499
95- register_function = getattr (target , f"register_ { hook_type } _hook" )
100+ register_function = self . _get_register_function (target , hook_type )
96101 handle = register_function (wrapped_hook , ** kwargs )
97102 self ._hooks .add (handle )
98103 logger .debug (f"{ self } added { handle } " )
@@ -113,3 +118,15 @@ def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None):
113118 hook .remove ()
114119
115120 self ._hooks -= handles
121+
122+ def _get_register_function (
123+ self , target : torch .nn .Module , hook_type : str
124+ ) -> Callable :
125+ if hook_type == "query" :
126+ return partial (register_query_hook , target )
127+ elif hook_type == "key" :
128+ return partial (register_key_hook , target )
129+ elif hook_type == "value" :
130+ return partial (register_value_hook , target )
131+ else :
132+ return getattr (target , f"register_{ hook_type } _hook" )
0 commit comments