from dataclasses import dataclass, field from typing import Dict, List, Optional @dataclass(frozen=True) class TextConfig: dim: int = 2048 n_layers: int = 24 vocab_size: int = 51200 max_context: int = 2048 n_heads: int = 32 prefix_attn: int = 730 @dataclass(frozen=True) class VisionConfig: enc_dim: int = 1152 enc_patch_size: int = 14 enc_n_layers: int = 27 enc_ff_dim: int = 4304 enc_n_heads: int = 16 proj_out_dim: int = 2048 crop_size: int = 378 in_channels: int = 3 max_crops: int = 12 overlap_margin: int = 4 proj_inner_dim: int = 8192 @dataclass(frozen=True) class RegionConfig: dim: int = 2048 coord_feat_dim: int = 256 coord_out_dim: int = 1024 size_feat_dim: int = 512 size_out_dim: int = 2048 inner_dim: int = 8192 @dataclass(frozen=True) class TokenizerConfig: bos_id: int = 50256 eos_id: int = 50256 templates: Dict[str, Optional[Dict[str, List[int]]]] = field( default_factory=lambda: { "caption": { "short": [198, 198, 16438, 8305, 25], "normal": [198, 198, 24334, 1159, 25], }, "query": {"prefix": [198, 198, 24361, 25], "suffix": [198, 198, 33706, 25]}, "detect": {"prefix": [198, 198, 47504, 25], "suffix": [628]}, "point": {"prefix": [198, 198, 12727, 25], "suffix": [628]}, } ) @dataclass(frozen=True) class MoondreamConfig: text: TextConfig = TextConfig() vision: VisionConfig = VisionConfig() region: RegionConfig = RegionConfig() tokenizer: TokenizerConfig = TokenizerConfig() @classmethod def from_dict(cls, config_dict: dict): text_config = TextConfig(**config_dict.get("text", {})) vision_config = VisionConfig(**config_dict.get("vision", {})) region_config = RegionConfig(**config_dict.get("region", {})) tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {})) return cls( text=text_config, vision=vision_config, region=region_config, tokenizer=tokenizer_config, ) def to_dict(self): return { "text": self.text.__dict__, "vision": self.vision.__dict__, "region": self.region.__dict__, "tokenizer": self.tokenizer.__dict__, }