|
6 | 6 | ) |
7 | 7 | from keras_hub.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE |
8 | 8 | from keras_hub.src.utils.preset_utils import builtin_presets |
9 | | -from keras_hub.src.utils.preset_utils import find_subclass |
10 | 9 | from keras_hub.src.utils.preset_utils import get_preset_loader |
11 | 10 | from keras_hub.src.utils.preset_utils import get_preset_saver |
12 | 11 | from keras_hub.src.utils.python_utils import classproperty |
@@ -171,43 +170,38 @@ def from_preset( |
171 | 170 | ) |
172 | 171 | ``` |
173 | 172 | """ |
174 | | - if cls == Preprocessor: |
| 173 | + if cls is Preprocessor: |
175 | 174 | raise ValueError( |
176 | 175 | "Do not call `Preprocessor.from_preset()` directly. Instead " |
177 | 176 | "choose a particular task preprocessing class, e.g. " |
178 | 177 | "`keras_hub.models.TextClassifierPreprocessor.from_preset()`." |
179 | 178 | ) |
180 | 179 |
|
181 | 180 | loader = get_preset_loader(preset) |
182 | | - backbone_cls = loader.check_backbone_class() |
183 | | - # Detect the correct subclass if we need to. |
184 | | - if cls.backbone_cls != backbone_cls: |
185 | | - cls = find_subclass(preset, cls, backbone_cls) |
186 | | - return loader.load_preprocessor(cls, config_file, **kwargs) |
| 181 | + return loader.load_preprocessor( |
| 182 | + cls=cls, config_file=config_file, kwargs=kwargs |
| 183 | + ) |
187 | 184 |
|
188 | 185 | @classmethod |
189 | | - def _add_missing_kwargs(cls, loader, kwargs): |
190 | | - """Fill in required kwargs when loading from preset. |
191 | | -
|
192 | | - This is a private method hit when loading a preprocessing layer that |
193 | | - was not directly saved in the preset. This method should fill in |
194 | | - all required kwargs required to call the class constructor. For almost, |
195 | | - all preprocessors, the only required args are `tokenizer`, |
196 | | - `image_converter`, and `audio_converter`, but this can be overridden, |
197 | | - e.g. for a preprocessor with multiple tokenizers for different |
198 | | - encoders. |
| 186 | + def _from_defaults(cls, loader, kwargs): |
| 187 | + """Load a preprocessor from default values. |
| 188 | +
|
| 189 | + This is a private method hit for loading a preprocessing layer that was |
| 190 | + not directly saved in the preset. Usually this means loading a |
| 191 | + tokenizer, image_converter and/or audio_converter and calling the |
| 192 | + constructor. But this can be overridden by subclasses as needed. |
199 | 193 | """ |
| 194 | + defaults = {} |
| 195 | + # Allow loading any tokenizer, image_converter or audio_converter config |
| 196 | + # we find on disk. We allow mixing and matching tokenizers and |
| 197 | + # preprocessing layers (though this is usually not a good idea). |
200 | 198 | if "tokenizer" not in kwargs and cls.tokenizer_cls: |
201 | | - kwargs["tokenizer"] = loader.load_tokenizer(cls.tokenizer_cls) |
| 199 | + defaults["tokenizer"] = loader.load_tokenizer() |
202 | 200 | if "audio_converter" not in kwargs and cls.audio_converter_cls: |
203 | | - kwargs["audio_converter"] = loader.load_audio_converter( |
204 | | - cls.audio_converter_cls |
205 | | - ) |
| 201 | + defaults["audio_converter"] = loader.load_audio_converter() |
206 | 202 | if "image_converter" not in kwargs and cls.image_converter_cls: |
207 | | - kwargs["image_converter"] = loader.load_image_converter( |
208 | | - cls.image_converter_cls |
209 | | - ) |
210 | | - return kwargs |
| 203 | + defaults["image_converter"] = loader.load_image_converter() |
| 204 | + return cls(**{**defaults, **kwargs}) |
211 | 205 |
|
212 | 206 | def load_preset_assets(self, preset): |
213 | 207 | """Load all static assets needed by the preprocessing layer. |
|
0 commit comments