1 change: 1 addition & 0 deletions mindone/transformers/__init__.py
@@ -504,6 +504,7 @@
EsmModel,
EsmPreTrainedModel,
)
from .models.evolla import EvollaForProteinText2Text, EvollaModel, EvollaPreTrainedModel, EvollaProcessor
from .models.falcon import (
FalconForCausalLM,
FalconForQuestionAnswering,
1 change: 1 addition & 0 deletions mindone/transformers/models/__init__.py
@@ -74,6 +74,7 @@
encoder_decoder,
ernie,
esm,
evolla,
falcon,
falcon_mamba,
fastspeech2_conformer,
1 change: 1 addition & 0 deletions mindone/transformers/models/auto/processing_auto.py
@@ -55,6 +55,7 @@
("chameleon", "ChameleonProcessor"),
("chinese_clip", "ChineseCLIPProcessor"),
("colpali", "ColPaliProcessor"),
("evolla", "EvollaProcessor"),
("flava", "FlavaProcessor"),
("idefics", "IdeficsProcessor"),
("instructblip", "InstructBlipProcessor"),
2 changes: 2 additions & 0 deletions mindone/transformers/models/evolla/__init__.py
@@ -0,0 +1,2 @@
from .modeling_evolla import EvollaForProteinText2Text, EvollaModel, EvollaPreTrainedModel
from .processing_evolla import EvollaProcessor
1,747 changes: 1,747 additions & 0 deletions mindone/transformers/models/evolla/modeling_evolla.py

Large diffs are not rendered by default.

265 changes: 265 additions & 0 deletions mindone/transformers/models/evolla/processing_evolla.py
@@ -0,0 +1,265 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for EVOLLA.
"""

import os
from typing import Optional, Union

from transformers.models.auto import AutoTokenizer

import mindspore as ms

from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin

PROTEIN_VALID_KEYS = ["aa_seq", "foldseek", "msa"]


class EvollaProcessor(ProcessorMixin):
r"""
Constructs an EVOLLA processor which wraps a Llama tokenizer and a SaProt tokenizer (EsmTokenizer) into a single processor.

[`EvollaProcessor`] offers all the functionalities of [`EsmTokenizer`] and [`LlamaTokenizerFast`]. See the
docstring of [`~EvollaProcessor.__call__`] and [`~EvollaProcessor.decode`] for more information.

Args:
protein_tokenizer (`EsmTokenizer`):
An instance of [`EsmTokenizer`]. The protein tokenizer is a required input.
tokenizer (`LlamaTokenizerFast`, *optional*):
An instance of [`LlamaTokenizerFast`]. Although it defaults to `None` in the signature, the tokenizer is a required input.
protein_max_length (`int`, *optional*, defaults to 1024):
The maximum length of the tokenized protein sequence.
text_max_length (`int`, *optional*, defaults to 512):
The maximum length of the tokenized text prompt.
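
Example (a minimal usage sketch; the checkpoint name `westlake-repl/Evolla-10B-hf` and the
protein/question below are illustrative placeholders, not values shipped with this change):

>>> from mindone.transformers import EvollaProcessor
>>> processor = EvollaProcessor.from_pretrained("westlake-repl/Evolla-10B-hf")
>>> protein = {"aa_seq": "MEEPQSDPSV", "foldseek": "dddddddddd"}
>>> messages = [{"role": "user", "content": "What is the function of this protein?"}]
>>> inputs = processor(proteins=[protein], messages_list=[messages], return_tensors="ms")
>>> sorted(inputs.keys())
['attention_mask', 'input_ids', 'protein_attention_mask', 'protein_input_ids']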
"""

attributes = ["protein_tokenizer", "tokenizer"]
valid_kwargs = ["sequence_max_length"]
# protein_tokenizer_class = "EsmTokenizer"
# tokenizer_class = "LlamaTokenizerFast"
protein_tokenizer_class = "AutoTokenizer"
tokenizer_class = "AutoTokenizer"
protein_tokenizer_dir_name = "protein_tokenizer"
# tokenizer_dir_name = "text_tokenizer"

def __init__(self, protein_tokenizer, tokenizer=None, protein_max_length=1024, text_max_length=512, **kwargs):
if protein_tokenizer is None:
raise ValueError("You need to specify an `protein_tokenizer`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")

super().__init__(protein_tokenizer, tokenizer)

self.tokenizer.pad_token = "<|reserved_special_token_0|>"
self.protein_max_length = protein_max_length
self.text_max_length = text_max_length

def process_proteins(self, proteins, protein_max_length=1024):
sa_sequences = []
for protein in proteins:
aa_seq = protein.get("aa_seq")
foldseek = protein.get("foldseek")
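# Build the SaProt structure-aware sequence: each residue letter (upper case) is interleaved with its
# foldseek structure token (lower case), e.g. aa_seq "ACD" + foldseek "pqr" -> "ApCqDr" (illustrative).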
sa_sequence = "".join([s.upper() + f.lower() for s, f in zip(aa_seq, foldseek)])
sa_sequences.append(sa_sequence)

sa_tokens = self.protein_tokenizer.batch_encode_plus(
sa_sequences, return_tensors="np", truncation=True, max_length=protein_max_length, padding=True
)
return sa_tokens

def process_text(
self,
texts,
text_max_length: int = 512,
):
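# Each element of `texts` is one chat (a list of {"role", "content"} dicts); the text tokenizer's chat
# template renders it into a single prompt string, with the generation prompt appended.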
prompts = []
for messages in texts:
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
prompts.append(prompt)

prompt_inputs = self.tokenizer(
prompts,
add_special_tokens=False,
return_tensors="np",
padding="longest",
truncation=True,
max_length=text_max_length,
)
return prompt_inputs

def __call__(
self,
proteins: Optional[Union[list[dict], dict]] = None,
messages_list: Optional[Union[list[list[dict]], list[dict]]] = None,
protein_max_length: Optional[int] = None,
text_max_length: Optional[int] = None,
**kwargs,
):
r"""This method takes batched or non-batched proteins and messages_list and converts them into format that can be used by
the model.

Args:
proteins (`Union[List[dict], dict]`):
A list of dictionaries or a single dictionary containing the following keys:
- `"aa_seq"` (`str`) -- The amino acid sequence of the protein.
- `"foldseek"` (`str`) -- The foldseek string of the protein.
messages_list (`Union[List[List[dict]], List[dict]]`):
A list of lists of dictionaries or a list of dictionaries containing the following keys:
- `"role"` (`str`) -- The role of the message.
- `"content"` (`str`) -- The content of the message.
protein_max_length (`int`, *optional*, defaults to 1024):
The maximum length of the tokenized protein sequence.
text_max_length (`int`, *optional*, defaults to 512):
The maximum length of the text.

Return:
a `BatchFeature` with the following keys, each of shape `(batch_size, sequence_length)` (NumPy arrays for `return_tensors="np"`, `mindspore.Tensor` for `return_tensors="ms"`):
- `protein_input_ids` -- The input IDs for the protein sequence.
- `protein_attention_mask` -- The attention mask for the protein sequence.
- `input_ids` -- The input IDs for the text prompt.
- `attention_mask` -- The attention mask for the text prompt.
"""
# proteins and messages_list should be provided
if proteins is None or messages_list is None:
raise ValueError("You need to specify `messages_list` and `proteins`.")

protein_max_length = protein_max_length if protein_max_length is not None else self.protein_max_length
text_max_length = text_max_length if text_max_length is not None else self.text_max_length

# proteins should be List[dict]
if isinstance(proteins, dict):
proteins = [proteins]
# messages_list should be List[List[dict]]
if isinstance(messages_list, (list, tuple)) and not isinstance(messages_list[0], (list, tuple)):
messages_list = [messages_list]
# Check if batched proteins are in the correct format
if isinstance(proteins, (list, tuple)) and not all(isinstance(p, dict) for p in proteins):
raise ValueError("The proteins should be a list of dictionaries, but not all elements are dictionaries.")
if isinstance(proteins, (list, tuple)) and not all(
all(k in PROTEIN_VALID_KEYS for k in p.keys()) for p in proteins
):
raise ValueError(
"Each protein should be a dictionary whose keys are among: "
f"{', '.join(PROTEIN_VALID_KEYS)}. "
f"But got: {proteins}"
)
# Check if batched messages_list is in the correct format
if isinstance(messages_list, (list, tuple)):
for messages in messages_list:
if not isinstance(messages, (list, tuple)):
raise TypeError(f"Each messages in messages_list should be a list instead of {type(messages)}.")
if not all(isinstance(m, dict) for m in messages):
raise ValueError(
"Each `messages` in messages_list should be a list of dictionaries, but not all of its elements are dictionaries."
)
if any(len(m.keys()) != 2 for m in messages) or any(
set(m.keys()) != {"role", "content"} for m in messages
):
raise ValueError(
"Each message should be a dictionary with exactly two keys, 'role' and 'content'. "
f"But got: {messages}"
)
else:
raise ValueError(
f"The messages_list should be a list of lists of dictionaries, but it's {type(messages_list)}."
)
sa_tokens = self.process_proteins(proteins, protein_max_length)

text_tokens = self.process_text(messages_list, text_max_length)

return_tensors = kwargs.get("return_tensors", "np")
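# Both tokenizers above were called with return_tensors="np", so the arrays are NumPy by default;
# for return_tensors="ms" they are wrapped into mindspore.Tensor below.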

if return_tensors == "np":
return BatchFeature(
data={
"protein_input_ids": sa_tokens["input_ids"],
"protein_attention_mask": sa_tokens["attention_mask"],
"input_ids": text_tokens["input_ids"],
"attention_mask": text_tokens["attention_mask"],
}
)
elif return_tensors == "ms":
return BatchFeature(
data={
"protein_input_ids": ms.Tensor(sa_tokens["input_ids"]),
"protein_attention_mask": ms.Tensor(sa_tokens["attention_mask"]),
"input_ids": ms.Tensor(text_tokens["input_ids"]),
"attention_mask": ms.Tensor(text_tokens["attention_mask"]),
}
)
else:
raise ValueError(f"Invalid return_tensors: {return_tensors}")

def batch_decode(self, *args, **kwargs):
return self.tokenizer.batch_decode(*args, **kwargs)

def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)

def protein_batch_decode(self, *args, **kwargs):
return self.protein_tokenizer.batch_decode(*args, **kwargs)

def protein_decode(self, *args, **kwargs):
return self.protein_tokenizer.decode(*args, **kwargs)

# overwrite to save the protein tokenizer in a separate folder
# Adapted from instructblip.processing_instructblip.py
# (https://github.com/huggingface/transformers/blob/9b479a245b793cac2a8b2e87c6d8e81bb24e20c4/src/\
# transformers/models/instructblip/processing_instructblip.py#L191-L221)
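# Resulting layout (illustrative): the text tokenizer and processor config are written to
# `save_directory/`, while the protein (Esm) tokenizer files go to
# `save_directory/protein_tokenizer/`, which is where `from_pretrained` below reloads them from.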
def save_pretrained(self, save_directory, **kwargs):
# only save the protein tokenizer in sub_dir
self.protein_tokenizer.save_pretrained(os.path.join(save_directory, self.protein_tokenizer_dir_name))

# we modify the attributes so that only the text tokenizer is saved in the main folder
protein_tokenizer_present = "protein_tokenizer" in self.attributes
# find the correct position of it in the attributes list
protein_tokenizer_index = self.attributes.index("protein_tokenizer") if protein_tokenizer_present else None
if protein_tokenizer_present and protein_tokenizer_index is not None:
self.attributes.remove("protein_tokenizer")

outputs = super().save_pretrained(save_directory, **kwargs)

if protein_tokenizer_present and protein_tokenizer_index is not None:
self.attributes.insert(protein_tokenizer_index, "protein_tokenizer")

return outputs

# overwrite to load the protein tokenizer from a separate folder
# Adapted from instructblip.processing_instructblip.py
# (https://github.com/huggingface/transformers/blob/9b479a245b793cac2a8b2e87c6d8e81bb24e20c4/src/transformers/\
# models/instructblip/processing_instructblip.py#L191-L221)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)

# if return_unused_kwargs is set, a tuple is returned whose second element is the unused kwargs
if isinstance(processor, tuple):
processor = processor[0]
protein_tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, subfolder=cls.protein_tokenizer_dir_name
)

processor.protein_tokenizer = protein_tokenizer

return processor


__all__ = ["EvollaProcessor"]