# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

from typing import List, Optional

import torch
from cog import BasePredictor, Input, Path

import data
from models import imagebind_model
from models.imagebind_model import ModalityType

# Map each supported ImageBind modality to the `data` helper that turns raw
# inputs (a string, or file path(s)) into model-ready tensors. The keys double
# as the valid `modality` choices exposed by Predictor.predict.
MODALITY_TO_PREPROCESSING = {
    ModalityType.TEXT: data.load_and_transform_text,
    ModalityType.VISION: data.load_and_transform_vision_data,
    ModalityType.AUDIO: data.load_and_transform_audio_data,
}


 | 18 | +class Predictor(BasePredictor):  | 
 | 19 | +    def setup(self):  | 
 | 20 | +        """Load the model into memory to make running multiple predictions efficient"""  | 
 | 21 | +        model = imagebind_model.imagebind_huge(pretrained=True)  | 
 | 22 | +        model.eval()  | 
 | 23 | +        self.model = model.to("cuda")  | 
 | 24 | + | 
 | 25 | +    def predict(  | 
 | 26 | +        self,  | 
 | 27 | +        input: Path = Input(  | 
 | 28 | +            description="file that you want to embed. Needs to be text, vision, or audio.",  | 
 | 29 | +            default=None,  | 
 | 30 | +        ),  | 
 | 31 | +        text_input: str = Input(  | 
 | 32 | +            description="text that you want to embed. Provide a string here instead of a text file to input if you'd like.",  | 
 | 33 | +            default=None,  | 
 | 34 | +        ),  | 
 | 35 | +        modality: str = Input(  | 
 | 36 | +            description="modality of the input you'd like to embed",  | 
 | 37 | +            choices=list(MODALITY_TO_PREPROCESSING.keys()),  | 
 | 38 | +            default=ModalityType.VISION,  | 
 | 39 | +        ),  | 
 | 40 | +    ) -> List[float]:  | 
 | 41 | +        """Infer a single embedding with the model"""  | 
 | 42 | + | 
 | 43 | +        if not input and not text_input:  | 
 | 44 | +            raise Exception(  | 
 | 45 | +                "Neither input nor text_input were provided! Provide one in order to generate an embedding"  | 
 | 46 | +            )  | 
 | 47 | + | 
 | 48 | +        modality_function = MODALITY_TO_PREPROCESSING[modality]  | 
 | 49 | + | 
 | 50 | +        if modality == "text":  | 
 | 51 | +            if input and text_input:  | 
 | 52 | +                raise Exception(  | 
 | 53 | +                    f"Input and text_input were both provided! Only provide one to generate an embedding.\nInput provided: {input}\nText Input provided: {text_input}"  | 
 | 54 | +                )  | 
 | 55 | +            if text_input:  | 
 | 56 | +                input = text_input  | 
 | 57 | +            else:  | 
 | 58 | +                with open(input, "r") as f:  | 
 | 59 | +                    text_input = f.readlines()  | 
 | 60 | +                input = text_input  | 
 | 61 | + | 
 | 62 | +        device = "cuda"  | 
 | 63 | +        model_input = {modality: modality_function([input], device)}  | 
 | 64 | + | 
 | 65 | +        with torch.no_grad():  | 
 | 66 | +            embeddings = self.model(model_input)  | 
 | 67 | +        # print(type(embeddings))  | 
 | 68 | +        emb = embeddings[modality]  | 
 | 69 | +        return emb.cpu().squeeze().tolist()  | 