Skip to content

Commit

Permalink
Copy tokenizer
Browse files · Browse the repository at this point in the history
  • Loading branch information
cg123 committed Sep 26, 2023
1 parent d05478c commit f9a4d51
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions bakllama.py
Original file line number · Diff line number · Diff line change
Expand Up @@ -107,7 +107,12 @@ def finalize(self):
json.dump({"metadata": {}, "weight_map": self.weight_map}, file)


def process(config: BakllamaConfig, out_path: str, clone_tensors: bool = False):
def process(
config: BakllamaConfig,
out_path: str,
clone_tensors: bool = False,
copy_tokenizer: bool = True,
):
if config.embedding_source is None:
config.embedding_source = config.layer_slices[0].model

Expand Down Expand Up @@ -179,6 +184,11 @@ def process(config: BakllamaConfig, out_path: str, clone_tensors: bool = False):
)
writer.finalize()

if copy_tokenizer:
transformers.AutoTokenizer.from_pretrained(layer_sources[0][0]).save_pretrained(
out_path, safe_serialization=True
)


def main(
config_path: str,
Expand All @@ -189,11 +199,14 @@ def main(
help="Clone tensors before saving, to allow multiple occurrences of the same layer"
),
] = False,
copy_tokenizer: bool = True,
):
with open(config_path, "r", encoding="utf-8") as file:
config = BakllamaConfig(**yaml.safe_load(file))

process(config, out_path, clone_tensors=clone_tensors)
process(
config, out_path, clone_tensors=clone_tensors, copy_tokenizer=copy_tokenizer
)


if __name__ == "__main__":
Expand Down

0 comments on commit f9a4d51

Please sign in to comment.