Gemma3 text keras hf checkpoint conversion #2433 (base: master)
tools/checkpoint_conversion/convert_gemma3_checkpoints.py
@@ -10,8 +10,10 @@
 Usage:
 ```shell
 cd tools/checkpoint_conversion
-python convert_gemma3_checkpoints.py --preset gemma3_instruct_1b
-python convert_gemma3_checkpoints.py --preset gemma3_instruct_4b
+python convert_gemma3_checkpoints.py --preset gemma3_instruct_1b \
+    --export_safetensors
+python convert_gemma3_checkpoints.py --preset gemma3_instruct_4b \
+    --export_safetensors
 ```
 """

@@ -21,16 +23,209 @@
# No GPU for conversion, makes memory management easier.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import json
import shutil

import keras  # noqa: E402
import numpy as np
import tensorflow_datasets as tfds  # noqa: E402
import torch
import transformers
from absl import app  # noqa: E402
from absl import flags  # noqa: E402
from checkpoint_conversion_utils import download_gcs_file
from gemma import gm  # noqa: E402
from keras import ops  # noqa: E402
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

import keras_hub  # noqa: E402


def convert_to_hf_config(keras_config):
    """Convert a Keras Gemma config to a Hugging Face `Gemma3TextConfig`."""
    hf_config = transformers.Gemma3TextConfig(
        vocab_size=keras_config.vocabulary_size,
        num_hidden_layers=keras_config.num_layers,
        num_attention_heads=keras_config.num_query_heads,
        num_key_value_heads=keras_config.num_key_value_heads,
        hidden_size=keras_config.hidden_dim,
        intermediate_size=keras_config.intermediate_dim,
        head_dim=keras_config.head_dim,
        max_position_embeddings=32768,
    )
    return hf_config


def export_to_hf(backbone, keras_tokenizer, path):
    """Convert a Keras Gemma model to Hugging Face format and save to path."""

Comment on lines +61 to +62:
The docstring for this function is missing an `Args` section, which is
inconsistent with the repository's style guide: all public functions need
Google-style docstrings documenting every parameter and return value.
Suggested change:

def export_to_hf(backbone, keras_tokenizer, path):
    """Convert a Keras Gemma model to Hugging Face format and save to path.

    Args:
        backbone: A `keras_hub.models.Gemma3Backbone` instance.
        keras_tokenizer: A `keras_hub.models.Gemma3Tokenizer` instance.
        path: str. The path to save the Hugging Face model to.
    """

    hf_config = convert_to_hf_config(backbone)
    weights_dict = {}

    # Helper function to convert bfloat16 weights to torch tensors
    def to_torch(weight):
        # Convert bfloat16 to float32 first, then to torch, then to bfloat16
        if hasattr(weight.dtype, "name") and "bfloat16" in str(weight.dtype):
            weight = np.array(weight, dtype=np.float32)
        return torch.from_numpy(weight).to(torch.bfloat16)

Comment on lines +68 to +72:
This helper function can be simplified to robustly handle various array
types (like JAX arrays) and then used consistently throughout the function.
Currently, the conversion logic repeats the float32-to-bfloat16 round-trip
inline for each kernel. With the suggested change, you can then refactor
the rest of the function, for example:

q_kernel = block.attention.query_dense.get_weights()[0]
weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = (
    to_torch(q_kernel)
    .permute(1, 0, 2)
    .reshape(backbone.hidden_dim, -1)
    .T
)

Suggested change:

def to_torch(weight):
    # Convert array-like weights (e.g., from JAX) to a float32 NumPy
    # array before creating a bfloat16 torch tensor for compatibility.
    np_weight = np.array(weight, dtype=np.float32)
    return torch.from_numpy(np_weight).to(torch.bfloat16)

    # Token embeddings
    token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
    weights_dict["model.embed_tokens.weight"] = to_torch(token_embedding)

    for i in range(backbone.num_layers):
        block = backbone.get_layer(f"decoder_block_{i}")
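        # The q/k/v kernels are (num_heads, hidden_dim, head_dim); HF
        # expects (num_heads * head_dim, hidden_dim), hence the
        # permute(1, 0, 2) -> reshape -> transpose chain below.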
        q_kernel = block.attention.query_dense.get_weights()[0]
        q_kernel = (
            torch.from_numpy(np.array(q_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(1, 0, 2)
            .reshape(backbone.hidden_dim, -1)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = q_kernel

        k_kernel = block.attention.key_dense.get_weights()[0]
        k_kernel = (
            torch.from_numpy(np.array(k_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(1, 0, 2)
            .reshape(backbone.hidden_dim, -1)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = k_kernel

        v_kernel = block.attention.value_dense.get_weights()[0]
        v_kernel = (
            torch.from_numpy(np.array(v_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(1, 0, 2)
            .reshape(backbone.hidden_dim, -1)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = v_kernel
        o_kernel = block.attention.output_dense.get_weights()[0]
        o_kernel = (
            torch.from_numpy(np.array(o_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(2, 0, 1)
            .reshape(backbone.hidden_dim, -1)
        )
        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = o_kernel

        q_norm = block.attention.query_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.self_attn.q_norm.weight"] = to_torch(
            q_norm
        )

        k_norm = block.attention.key_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.self_attn.k_norm.weight"] = to_torch(
            k_norm
        )
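
        # The gating/up/down feedforward kernels are plain 2-D kernels of
        # shape (in_features, out_features); HF `nn.Linear` stores
        # (out_features, in_features), so a transpose suffices.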
        gate_kernel = block.gating_ffw.get_weights()[0]
        gate_kernel = (
            torch.from_numpy(np.array(gate_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .T
        )
        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = gate_kernel

        up_kernel = block.gating_ffw_2.get_weights()[0]
        up_kernel = (
            torch.from_numpy(np.array(up_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .T
        )
        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = up_kernel

        down_kernel = block.ffw_linear.get_weights()[0]
        down_kernel = (
            torch.from_numpy(np.array(down_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .T
        )
        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = down_kernel

        input_layer_norm = block.pre_attention_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = to_torch(
            input_layer_norm
        )

        post_attn_norm = block.post_attention_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
            to_torch(post_attn_norm)
        )

        pre_feedforward_layernorm_weight = block.pre_ffw_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.pre_feedforward_layernorm.weight"] = (
            to_torch(pre_feedforward_layernorm_weight)
        )

        post_feedforward_layernorm_weight = block.post_ffw_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.post_feedforward_layernorm.weight"] = (
            to_torch(post_feedforward_layernorm_weight)
        )

    final_norm = backbone.get_layer("final_normalization").get_weights()[0]
    weights_dict["model.norm.weight"] = to_torch(final_norm)
weights_dict["lm_head.weight"] = weights_dict[ | ||
"model.embed_tokens.weight" | ||
].clone() | ||

    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "config.json"), "w") as f:
        json.dump(hf_config.to_dict(), f)
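    # `save_file` requires contiguous tensors; the transposes and permutes
    # above create non-contiguous views.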
    weights_dict = {k: v.contiguous() for k, v in weights_dict.items()}
    save_file(weights_dict, os.path.join(path, "model.safetensors"))
    keras_tokenizer.save_assets(path)
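    # Keras saves the SentencePiece model as `vocabulary.spm`, but Hugging
    # Face tokenizer loaders expect `tokenizer.model`.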
    vocab_spm = os.path.join(path, "vocabulary.spm")
    tokenizer_model = os.path.join(path, "tokenizer.model")
    if os.path.exists(vocab_spm):
        shutil.move(vocab_spm, tokenizer_model)
    print("Export complete! Files saved in:", path)


def load_hf_model(model_name, device):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    model.eval()
    return model, tokenizer


def infer(
    model,
    tokenizer,
    prompt,
    device,
    max_new_tokens=30,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
):
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate output
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
        )

Comment on lines +215 to +222:
The call to `model.generate` uses `do_sample=True`, which makes the
verification output non-deterministic. For comparing a converted checkpoint
against a reference, deterministic decoding is preferable. Suggested change:

outputs = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_k=top_k,
    top_p=top_p,
    do_sample=False,
)

    # Decode generated tokens
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
    return generated_text


keras.utils.set_random_seed(42)

FLAGS = flags.FLAGS

@@ -126,6 +321,13 @@
    required=True,
)

flags.DEFINE_bool(
    "export_safetensors",
    False,
    "Export model to Safetensors format (HuggingFace-compatible). \
    Only for text-only models.",

Comment on lines +327 to +328:
The help string for the `export_safetensors` flag uses a backslash line
continuation, which embeds the next line's leading whitespace in the
message. Prefer implicit string concatenation. Suggested change:

    "Export model to Safetensors format (HuggingFace-compatible). "
    "Only for text-only models.",
)


def convert_model(flax_config, text_only):
    vision_encoder = None

@@ -558,6 +760,35 @@ def main(_):
    keras_tokenizer.save_to_preset(preset)
    print(f"🏁 Preset saved to ./{preset}")

    if FLAGS.export_safetensors and text_only:
        export_dir = f"./{preset}_safetensors_export"
        print(
            f"🏃 Exporting to Safetensors (HuggingFace format) at {export_dir}"
        )
        export_to_hf(keras_model, keras_tokenizer, export_dir)
        print(f"🏁 Safetensors export complete: {export_dir}")

        local_hf_model, local_hf_tokenizer = load_hf_model(
            export_dir, device="cpu"
        )
        print("Local Hugging Face model loaded successfully!")

        print(
            "🔶 Safetensors output:",
            infer(
                local_hf_model,
                local_hf_tokenizer,
                "What is Keras?",
                "cpu",
                max_new_tokens=100,
            ),
        )
    elif FLAGS.export_safetensors:
        print(
            "⚠️ Safetensors export is only supported for text-only models. \
            Skipping export."
        )


if __name__ == "__main__":
    app.run(main)