235 changes: 233 additions & 2 deletions tools/checkpoint_conversion/convert_gemma3_checkpoints.py
@@ -10,8 +10,10 @@
Usage:
```shell
cd tools/checkpoint_conversion
python convert_gemma3_checkpoints.py --preset gemma3_instruct_1b
python convert_gemma3_checkpoints.py --preset gemma3_instruct_4b
python convert_gemma3_checkpoints.py --preset gemma3_instruct_1b \
--export_safetensors
python convert_gemma3_checkpoints.py --preset gemma3_instruct_4b \
--export_safetensors
```
"""

@@ -21,16 +23,209 @@
# No GPU for conversion, makes memory management easier.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import json
import shutil

import keras # noqa: E402
import numpy as np
import tensorflow_datasets as tfds # noqa: E402
import torch
import transformers
from absl import app # noqa: E402
from absl import flags # noqa: E402
from checkpoint_conversion_utils import download_gcs_file
from gemma import gm # noqa: E402
from keras import ops # noqa: E402
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

import keras_hub # noqa: E402


def convert_to_hf_config(keras_config):
"""Convert Keras Gemma config to Hugging Face GemmaConfig."""
Comment on lines +46 to +47

medium

The docstring for this function is missing Args and Returns sections, which is inconsistent with the repository's style guide. Providing detailed docstrings improves code clarity and maintainability.1

def convert_to_hf_config(keras_config):
    """Convert Keras Gemma config to Hugging Face GemmaConfig.

    Args:
        keras_config: A Keras Gemma3 config object from the backbone.

    Returns:
        A `transformers.Gemma3TextConfig` instance.
    """

Style Guide References

Footnotes

  1. The style guide requires all public functions to have Google-style docstrings, including comprehensive documentation for all parameters and return values.

    hf_config = transformers.Gemma3TextConfig(
        vocab_size=keras_config.vocabulary_size,
        num_hidden_layers=keras_config.num_layers,
        num_attention_heads=keras_config.num_query_heads,
        num_key_value_heads=keras_config.num_key_value_heads,
        hidden_size=keras_config.hidden_dim,
        intermediate_size=keras_config.intermediate_dim,
        head_dim=keras_config.head_dim,
        max_position_embeddings=32768,
    )
    return hf_config


def export_to_hf(backbone, keras_tokenizer, path):
"""Convert a Keras Gemma model to Hugging Face format and save to path."""
Comment on lines +61 to +62

medium

The docstring for this function is missing Args and Returns sections, which is inconsistent with the repository's style guide. Providing detailed docstrings improves code clarity and maintainability.1

def export_to_hf(backbone, keras_tokenizer, path):
    """Convert a Keras Gemma model to Hugging Face format and save to path.

    Args:
        backbone: A `keras_hub.models.Gemma3Backbone` instance.
        keras_tokenizer: A `keras_hub.models.Gemma3Tokenizer` instance.
        path: str. The path to save the Hugging Face model to.
    """

Style Guide References

Footnotes

  1. The style guide requires all public functions to have Google-style docstrings, including comprehensive documentation for all parameters and return values.


    hf_config = convert_to_hf_config(backbone)
    weights_dict = {}

    # Helper function to convert bfloat16 weights to torch tensors
    def to_torch(weight):
        # Convert bfloat16 to float32 first, then to torch, then to bfloat16
        if hasattr(weight.dtype, "name") and "bfloat16" in str(weight.dtype):
            weight = np.array(weight, dtype=np.float32)
        return torch.from_numpy(weight).to(torch.bfloat16)
Comment on lines +68 to +72

medium

This helper function can be simplified to robustly handle various array types (like JAX arrays) and then used consistently throughout export_to_hf to reduce code duplication.

Currently, the conversion logic torch.from_numpy(np.array(weight, dtype=np.float32)).to(torch.bfloat16) is repeated for many weights. You can simplify to_torch to encapsulate this logic and improve maintainability.

With the suggested change, you can then refactor the rest of the function, for example:

q_kernel = block.attention.query_dense.get_weights()[0]
weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = (
    to_torch(q_kernel)
    .permute(1, 0, 2)
    .reshape(backbone.hidden_dim, -1)
    .T
)
    def to_torch(weight):
        # Convert array-like weights (e.g., from JAX) to a float32 NumPy
        # array before creating a bfloat16 torch tensor for compatibility.
        np_weight = np.array(weight, dtype=np.float32)
        return torch.from_numpy(np_weight).to(torch.bfloat16)
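
With the hardened to_torch above, the other projections can reuse the helper as well, so only the layout handling differs per weight. For example, the output projection (a sketch that mirrors the permute order already used in this PR):

o_kernel = block.attention.output_dense.get_weights()[0]
weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = (
    to_torch(o_kernel)
    .permute(2, 0, 1)
    .reshape(backbone.hidden_dim, -1)
)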


    # Token embeddings
    token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
    weights_dict["model.embed_tokens.weight"] = to_torch(token_embedding)

    for i in range(backbone.num_layers):
        block = backbone.get_layer(f"decoder_block_{i}")
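        # The Keras projection kernels keep a separate per-head axis, while
        # the Hugging Face checkpoint stores flat 2D (out_features,
        # in_features) matrices, hence the permute/reshape/transpose below.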
        q_kernel = block.attention.query_dense.get_weights()[0]
        q_kernel = (
            torch.from_numpy(np.array(q_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(1, 0, 2)
            .reshape(backbone.hidden_dim, -1)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = q_kernel

        k_kernel = block.attention.key_dense.get_weights()[0]
        k_kernel = (
            torch.from_numpy(np.array(k_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(1, 0, 2)
            .reshape(backbone.hidden_dim, -1)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = k_kernel

        v_kernel = block.attention.value_dense.get_weights()[0]
        v_kernel = (
            torch.from_numpy(np.array(v_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(1, 0, 2)
            .reshape(backbone.hidden_dim, -1)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = v_kernel

        o_kernel = block.attention.output_dense.get_weights()[0]
        o_kernel = (
            torch.from_numpy(np.array(o_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .permute(2, 0, 1)
            .reshape(backbone.hidden_dim, -1)
        )
        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = o_kernel

        q_norm = block.attention.query_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.self_attn.q_norm.weight"] = to_torch(
            q_norm
        )

        k_norm = block.attention.key_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.self_attn.k_norm.weight"] = to_torch(
            k_norm
        )

        gate_kernel = block.gating_ffw.get_weights()[0]
        gate_kernel = (
            torch.from_numpy(np.array(gate_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .T
        )
        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = gate_kernel

        up_kernel = block.gating_ffw_2.get_weights()[0]
        up_kernel = (
            torch.from_numpy(np.array(up_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .T
        )
        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = up_kernel

        down_kernel = block.ffw_linear.get_weights()[0]
        down_kernel = (
            torch.from_numpy(np.array(down_kernel, dtype=np.float32))
            .to(torch.bfloat16)
            .T
        )
        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = down_kernel

        input_layer_norm = block.pre_attention_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = to_torch(
            input_layer_norm
        )

        post_attn_norm = block.post_attention_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
            to_torch(post_attn_norm)
        )

        pre_feedforward_layernorm_weight = block.pre_ffw_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.pre_feedforward_layernorm.weight"] = (
            to_torch(pre_feedforward_layernorm_weight)
        )

        post_feedforward_layernorm_weight = block.post_ffw_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.post_feedforward_layernorm.weight"] = (
            to_torch(post_feedforward_layernorm_weight)
        )

    final_norm = backbone.get_layer("final_normalization").get_weights()[0]
    weights_dict["model.norm.weight"] = to_torch(final_norm)
    weights_dict["lm_head.weight"] = weights_dict[
        "model.embed_tokens.weight"
    ].clone()

    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "config.json"), "w") as f:
        json.dump(hf_config.to_dict(), f)
    weights_dict = {k: v.contiguous() for k, v in weights_dict.items()}
    save_file(weights_dict, os.path.join(path, "model.safetensors"))
    keras_tokenizer.save_assets(path)
    vocab_spm = os.path.join(path, "vocabulary.spm")
    tokenizer_model = os.path.join(path, "tokenizer.model")
    if os.path.exists(vocab_spm):
        shutil.move(vocab_spm, tokenizer_model)
    print("Export complete! Files saved in:", path)


def load_hf_model(model_name, device):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    model.eval()
    return model, tokenizer


def infer(
    model,
    tokenizer,
    prompt,
    device,
    max_new_tokens=30,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
):
    # Tokenize inpu

medium

There is a typo in this comment.

    # Tokenize input

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate output
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
        )
Comment on lines +215 to +222

medium

The infer function is used for a sanity check after converting the model. Currently, it uses sampling (do_sample=True), which makes the output non-deterministic. For a validation or sanity check step in a conversion script, it's better to use a deterministic generation strategy like greedy search to ensure consistent and predictable outputs.

        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=False,
        )


    # Decode generated tokens
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
    return generated_text


keras.utils.set_random_seed(42)

FLAGS = flags.FLAGS
@@ -126,6 +321,13 @@
    required=True,
)

flags.DEFINE_bool(
"export_safetensors",
False,
"Export model to Safetensors format (HuggingFace-compatible). \
Only for text-only models.",
Comment on lines +327 to +328

medium

The help string for the export_safetensors flag contains a backslash and extra whitespace, which will be literally included in the help message displayed to the user. It's better to use implicit string concatenation for multiline help messages to ensure clean formatting.

    "Export model to Safetensors format (HuggingFace-compatible). "
    "Only for text-only models.",

)


def convert_model(flax_config, text_only):
    vision_encoder = None
@@ -558,6 +760,35 @@ def main(_):
    keras_tokenizer.save_to_preset(preset)
    print(f"🏁 Preset saved to ./{preset}")

    if FLAGS.export_safetensors and text_only:
        export_dir = f"./{preset}_safetensors_export"
        print(
            f"🏃 Exporting to Safetensors (HuggingFace format) at {export_dir}"
        )
        export_to_hf(keras_model, keras_tokenizer, export_dir)
        print(f"🏁 Safetensors export complete: {export_dir}")

        local_hf_model, local_hf_tokenizer = load_hf_model(
            export_dir, device="cpu"
        )
        print("Local Hugging Face model loaded successfully!")

        print(
            "🔶 Safetensors output:",
            infer(
                local_hf_model,
                local_hf_tokenizer,
                "What is Keras?",
                "cpu",
                max_new_tokens=100,
            ),
        )
    elif FLAGS.export_safetensors:
        print(
            "⚠️ Safetensors export is only supported for text-only models. \
            Skipping export."
        )


if __name__ == "__main__":
    app.run(main)