
Conversation

@zqiu24 (Contributor) commented Sep 29, 2025

Dear All,

I have noticed some minor issues in the forward pass of OFT; the most important is that the indexing tensors are on CPU (instead of on the device), resulting in slower training.

I have made some minor modifications and tested them on 8B fine-tuning; training is approximately 30% faster.
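To illustrate the kind of issue (a minimal sketch with made-up shapes, not the actual OFT code): indexing a CUDA tensor with CPU index tensors forces the indices to be copied to the GPU on every forward call.

```python
import torch

if torch.cuda.is_available():
    matrix = torch.randn(64, 32, 32, device="cuda")
    rows, cols = torch.triu_indices(32, 32, offset=1)   # index tensors created on CPU

    vec_slow = matrix[:, rows, cols]                     # implicit host-to-device copy of the indices
    vec_fast = matrix[:, rows.cuda(), cols.cuda()]       # keep the indices on the device instead

    assert torch.equal(vec_slow, vec_fast)               # same result, but no per-call transfer
```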

Hope for a quick review :) @BenjaminBossan

Thank you so much.

Best,

@BenjaminBossan (Member) left a comment

Thanks for this PR to make OFT faster. I found a few issues; please check my comments.

In my local testing, this PR increased the speed from 5726 tok/sec to 6582 tok/sec using the default MetaMath experiment for OFT (~15% faster).

vec = matrix[:, self.rows, self.cols]
return vec

@torch.compile
Member:

Right now, we don't torch.compile any code but instead leave it to the user to choose if they want to compile or not. Testing on my machine with a 4090, I also don't see a big difference (6582 vs 6532 tokens / sec). Did you find this to help a lot in your setting? If not, I'd suggest removing it.

Contributor Author:

Thanks for the quick reply! The thing is that I noticed it will speed it up a bit (I deliberately did not add specific configuration options like dynamic, etc., to avoid errors), but we can leave it out.

Contributor Author:

General question: if we have custom kernels (for example a Triton kernel), are we allowed to add them?

Member:

The thing is that I noticed it will speed it up a bit

How much difference does it make for you? If it's not significantly more than I observed, I'd say, let's remove it for now.

General question: if we have custom kernels (for example a Triton kernel), are we allowed to add them?

Good question. So at HF, there is a general push to make it easier to optionally make use of kernels: https://huggingface.co/blog/hello-hf-kernels. In transformers, there is already an integration, though the API is still early stage (I think there is no option for fine-grained control of what kernels to use).

For PEFT, we plan to add the same functionality but we're still working out how to exactly implement it. Once it's there, we will happily accept kernels for specific PEFT methods (and migrate existing kernels there). Until then, feel free to add your OFT kernel to the HF kernel hub, we could use it for testing the kernels integration.

Contributor Author:

Thanks for the information :) In another project I am building kernels for this orthogonalization, which makes it faster and more memory-efficient than the current torch.compile approach, but sure, we can add it later :)
Let's just remove the torch.compile for now.

if required_dtype != self.weight.dtype:
x = x.to(self.weight.dtype)

if self.rows.device != self.weight.device:
Member:

IIRC, tensor.to(device) is a no-op if the device is already the same, i.e. we can remove the device check.
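A quick check of that behaviour (tensor.to() returns the tensor itself when nothing needs to change):

```python
import torch

t = torch.arange(4)                    # already on CPU
assert t.to(torch.device("cpu")) is t  # no copy is made, so the unconditional .to() call is free

# hence the guard can be dropped:
# self.rows = self.rows.to(self.weight.device)
# self.cols = self.cols.to(self.weight.device)
```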

Comment on lines 255 to 256
self.rows = self.rows.to(self.weight.device)
self.cols = self.cols.to(self.weight.device)
Member:

I wonder: Should we also move rows and cols to the correct device directly when they're initialized?

self.oft_R[adapter_name] = OFTRotationModule(
r if not block_share else 1,
n_elements,
oft_block_size,
self.in_features,
coft=coft,
eps=eps,
block_share=block_share,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)

As these tensors are quite small, I think the memory impact is negligible but we avoid moving them during each forward call. I would still keep this code here to be safe (as it's a no-op if they're already on the right device).

Contributor Author:

Yes, we could also do this, but when it is initialized, I think it is on CPU (I checked the device of the base layer weight; it is on CPU), so how would I know the correct device? Best,

Member:

I think at this point, the weights of the base model should generally already be on the accelerator. E.g. when I run these tests with a GPU:

pytest tests/test_custom_models.py -k "test_forward_float16 and oft and not boft"

and I add this check before OFTRotationModule is assigned:

base_layer_device = self.get_base_layer().device
assert base_layer_device.type == "cuda"

it passes. So we can add the device as an input argument to OFTRotationModule and then ensure that it's moved to the right device. WDYT?
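A rough sketch of that idea (the constructor and the index construction below are illustrative, assuming the indices come from something like torch.triu_indices; the real module has more arguments):

```python
import torch
from torch import nn

class RotationModuleSketch(nn.Module):
    def __init__(self, block_size, device=None):
        super().__init__()
        # build the index tensors directly on the target device so forward never has to move them
        rows, cols = torch.triu_indices(block_size, block_size, offset=1, device=device)
        self.rows, self.cols = rows, cols

mod = RotationModuleSketch(8, device=torch.device("cpu"))
assert mod.rows.device.type == "cpu"
```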

@zqiu24 (Contributor Author) commented Oct 1, 2025

@BenjaminBossan I just pushed a newer version. Best

@BenjaminBossan (Member) left a comment

Thanks for making these changes. There are still some small issues, please check my comments.

In addition, let's add a test for the new warning. It can be added to test_initialization.py. Here is a proposal:

import json

from peft import OFTConfig

[..]

class TestOft:
    torch_device = infer_device()

    def get_model(self, bias=True):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(32, 32)

        return MyModule().eval().to(self.torch_device)

    @pytest.mark.parametrize("peft_version", ["0.17.0", "0.18.0", None])
    def test_load_outdated_oft_checkpoint_warns(self, peft_version, tmp_path, recwarn):
        # In PEFT v0.18.0, there was a small change in the OFT implementation with Cayley-Neumann enabled. As the
        # outputs change slightly, users need to be warned about it if the checkpoint stems from a PEFT version below
        # 0.18.0. When the 'peft_version' key is not in the config, it means that the version is below 0.18.0.
        config = OFTConfig(target_modules=["lin"], use_cayley_neumann=True)  # only relevant when using Cayley-Neumann
        model = get_peft_model(self.get_model(), config)
        model.save_pretrained(tmp_path)
        del model

        # overwrite the peft_version
        with open(tmp_path / "adapter_config.json") as f:
            config_json = json.load(f)

        if peft_version is None:
            del config_json["peft_version"]
        else:
            config_json["peft_version"] = peft_version

        with open(tmp_path / "adapter_config.json", "w") as f:
            json.dump(config_json, f)

        msg = "TODO"  # <= replace with final warning message
        PeftModel.from_pretrained(self.get_model(), tmp_path)

        warn_messages = [str(w.message) for w in recwarn.list]
        if peft_version == "0.18.0":
            assert not any(w.startswith(msg) for w in warn_messages)
        else:
            assert any(w.startswith(msg) for w in warn_messages)

In theory, we also need to add a check in the other direction: A user trains an OFT checkpoint with version 0.18.0 (i.e. containing the change from this PR) and then tries to load this checkpoint in 0.17.0. I think this scenario is rare enough that we can ignore it, but LMK if you think it should be added.

min_version = packaging.version.Version("0.18.0")
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
if (peft_version == "unknown") or (parsed_version < min_version):
msg = "warning message that explains what is happening"
Member:

Well, this message was just a placeholder :-D Could you please add a short explanation here?

Contributor Author:

Sure:)

Comment on lines 103 to 104
self.rows.to(device)
self.cols.to(device)
Member:

As rows and cols are tensors, calling .to() is not an in-place operation; you have to re-assign the values.
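For illustration (shown with a dtype change so it runs anywhere; device moves behave the same way):

```python
import torch

t = torch.zeros(3)            # float32
t.to(torch.float64)           # returns a new tensor, t itself is unchanged
assert t.dtype == torch.float32

t = t.to(torch.float64)       # re-assign to actually keep the result
assert t.dtype == torch.float64
```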

Contributor Author:

Done.

block_share=block_share,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
device=self.weight.device,
Member:

Slightly more robust:

Suggested change
device=self.weight.device,
device=self.get_base_layer().weight.device,

Contributor Author:

Done.

kernel_size=base_layer.kernel_size,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
device=self.weight.device,
Member:

Slightly more robust:

Suggested change
device=self.weight.device,
device=self.get_base_layer().weight.device,

Contributor Author:

Done.

"with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions. "
"Alternatively, downgrade PEFT to version 0.13.0 to use the old adapter weights."
)
if kwargs["use_caylay_neumann"]:
Member:

Let's make this more backwards compatible; also, there was a typo:

Suggested change
if kwargs["use_caylay_neumann"]:
if kwargs.get("use_cayley_neumann", False):
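For illustration, with an old config dict that predates the option:

```python
kwargs = {}                                               # old checkpoint config without the new key
assert kwargs.get("use_cayley_neumann", False) is False   # falls back to the old behaviour
# kwargs["use_caylay_neumann"] would raise a KeyError (and the key name was misspelled anyway)
```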

Contributor Author:

Done.

Comment on lines 200 to 204
peft_version = kwargs.get("peft_version", "unknown")
parsed_version = packaging.version.Version(peft_version)
min_version = packaging.version.Version("0.18.0")
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
if (peft_version == "unknown") or (parsed_version < min_version):
Member:

Suggested change
peft_version = kwargs.get("peft_version", "unknown")
parsed_version = packaging.version.Version(peft_version)
min_version = packaging.version.Version("0.18.0")
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
if (peft_version == "unknown") or (parsed_version < min_version):
peft_version = kwargs.get("peft_version", "0.0.0") # if not present, set a low dummy version
parsed_version = packaging.version.Version(peft_version)
min_version = packaging.version.Version("0.18.0")
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
if parsed_version < min_version:

I just tested and parsing with "unknown" throws an error, so let's put a dummy value here.
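For reference, a quick check of both cases:

```python
import packaging.version

try:
    packaging.version.Version("unknown")
except packaging.version.InvalidVersion:
    print("'unknown' is not parseable")            # this is the error the CI hit

# with a low dummy default, parsing always succeeds and old configs still trigger the warning
assert packaging.version.Version("0.0.0") < packaging.version.Version("0.18.0")
```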

Contributor Author:

Done.

@zqiu24 (Contributor Author) commented Oct 2, 2025

@BenjaminBossan Just updated the comments. Best,

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zqiu24 (Contributor Author) commented Oct 2, 2025

@BenjaminBossan What do these error messages mean? Best,

@BenjaminBossan (Member) left a comment

What do these error messages mean?

I forgot that the stored PEFT version may contain a commit hash, like 0.17.2@123abc. My suggested change should address that.
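For context, one way such a version string can be handled before parsing (illustrative only; the actual suggested change is not shown in this thread):

```python
import packaging.version

stored = "0.17.2@123abc"        # stored version may carry a commit hash
base = stored.split("@")[0]     # strip the hash, which is not PEP 440 compliant
assert packaging.version.Version(base) < packaging.version.Version("0.18.0")
```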

@zqiu24 (Contributor Author) commented Oct 2, 2025

@BenjaminBossan Thanks. Just updated.

@BenjaminBossan (Member) left a comment

There is a small bug with meta devices; check my comment.

block_share=block_share,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
device=self.get_base_layer().weight.device,
Member:

Ah, this doesn't work if the device is meta, as we'll get:

  self.weight = self.weight.to(device)
                  ^^^^^^^^^^^^^^^^^^^^^^

E NotImplementedError: Cannot copy out of meta tensor; no data!

Let's do something like:

device = self.get_base_layer().weight.device
[...]
self.oft_R[adapter_name] = OFTRotationModule(
    ...
    device=device if device.type != "meta" else None,

Contributor Author:

@BenjaminBossan How to reproduce this with meta?

Contributor Author:

@BenjaminBossan But then it is not correct, right? During initialization it is meta, so the variables rows / cols will never get shifted to the GPU (which is why the code is currently slow)?

Contributor Author:

@BenjaminBossan I think I can leave out this self.weight = self.weight.to(device), since it worked correctly before; I just want to make sure that self.cols and self.rows get moved correctly.

Member:

To reproduce the error, just run the failing tests:

pytest tests/test_custom_models.py::TestPeftCustomModel::test_load_model_low_cpu_mem_usage -k "oft and not boft"

During initialization it is meta, so the variables rows / cols will never get shifted to the GPU (which is why the code is currently slow)?

So the situation is the following: A user loads a trained PEFT adapter with low_cpu_mem_usage=True, the weights are at first created on meta device, but then when the actual weights are loaded, they are moved to the correct device. This should happen at the end of update_layer, when self._move_adapter_to_device_of_base_layer(adapter_name) is called. I think this method won't work on rows and cols though, as they are just tensors. I think what can be done:

  1. Make them non-persistent buffers; I believe they would then automatically be moved when oft_R is moved (see the sketch below).
  2. Override _move_adapter_to_device_of_base_layer, call super()._move_adapter_to_device_of_base_layer, then handle rows and cols explicitly.
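A small sketch of option 1 (illustrative names, assuming the indices come from torch.triu_indices):

```python
import torch
from torch import nn

class RotationSketch(nn.Module):
    def __init__(self, block_size):
        super().__init__()
        rows, cols = torch.triu_indices(block_size, block_size, offset=1)
        # persistent=False keeps them out of the state_dict (nothing extra to load from
        # checkpoints), but module.to(device) still moves them along with everything else
        self.register_buffer("rows", rows, persistent=False)
        self.register_buffer("cols", cols, persistent=False)

m = RotationSketch(4)
assert "rows" not in m.state_dict()      # not saved, so low_cpu_mem_usage loading is unaffected
if torch.cuda.is_available():
    m = m.cuda()
    assert m.rows.device.type == "cuda"  # buffers follow the module across devices
```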

kernel_size=base_layer.kernel_size,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
device=self.get_base_layer().weight.device,
Member:

Same comment as above

@zqiu24 (Contributor Author) commented Oct 3, 2025

@BenjaminBossan Ok, the newest version should be working now; I used the non-persistent buffer approach. Best,

@BenjaminBossan (Member) left a comment

Thanks for the latest updates, going with buffers is okay for me. I found a few small remaining TODOs, after that the PR should be good. Please also run make style before committing.

with open(tmp_path / "adapter_config.json", "w") as f:
json.dump(config_json, f)

msg = "TODO" # <= replace with final warning message
Member:

Please replace it with the correct error message (start of the message is fine).

min_version = packaging.version.Version("0.18.0")
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
if parsed_version < min_version:
msg = "The cayley-neumann parameterization has been slightly changed to be more numerically stable in PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, downgrade PEFT to version 0.17.0 to use the old parameterization."
Member:

Suggested change
msg = "The cayley-neumann parameterization has been slightly changed to be more numerically stable in PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, downgrade PEFT to version 0.17.0 to use the old parameterization."
msg = (
"The cayley-neumann parameterization has been slightly changed to be more numerically stable in "
"PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, "
"downgrade PEFT to version 0.17.0 to use the old parameterization."
)

Let's add some line breaks to keep <= 120 chars.

Contributor Author:

OK, done

config_json = json.load(f)

if peft_version is None:
del config_json["peft_version"]
Member:

Suggested change
del config_json["peft_version"]
config_json.pop("peft_version", None)

In case the key does not exist.

Contributor Author:

Done.

@zqiu24 (Contributor Author) commented Oct 6, 2025

@BenjaminBossan Thanks, just updated again. Best.

@BenjaminBossan (Member) commented

I checked the failing CI and most failing tests are unrelated (anything with "TypeError: CLIPTextModel.__init__() got an unexpected keyword argument 'offload_state_dict'" can be ignored). However, there are 2 actual failures:

FAILED tests/test_custom_models.py::TestPeftCustomModel::test_disable_adapters_with_merging[Conv2d 1 OFT-Conv2d-OFTConfig-config_kwargs107] - assert False
 +  where False = <built-in method allclose of type object at 0x00007FFE3FBAB860>(tensor([[-1.1555e+01, -9.5367e-06],\n        [-4.0945e+01,  0.0000e+00]], grad_fn=<LogSoftmaxBackward0>), tensor([[-1.1542e+01, -9.7751e-06],\n        [-4.0898e+01,  0.0000e+00]]), atol=0.0001, rtol=0.0001)
 +    where <built-in method allclose of type object at 0x00007FFE3FBAB860> = torch.allclose
FAILED tests/test_custom_models.py::TestPeftCustomModel::test_disable_adapters_with_merging[Conv2d 4 OFT-Conv2d-OFTConfig-config_kwargs109] - assert False
 +  where False = <built-in method allclose of type object at 0x00007FFE3FBAB860>(tensor([[-1.1555e+01, -9.5367e-06],\n        [-4.0945e+01,  0.0000e+00]], grad_fn=<LogSoftmaxBackward0>), tensor([[-1.1547e+01, -9.6559e-06],\n        [-4.0910e+01,  0.0000e+00]]), atol=0.0001, rtol=0.0001)
 +    where <built-in method allclose of type object at 0x00007FFE3FBAB860> = torch.allclose

I could reproduce these errors locally. What surprises me is that the model output is quite different compared to the main branch:

# main
>>> torch.testing.assert_close(outputs_before, outputs_disabled)
*** AssertionError: Tensor-likes are not close!

Mismatched elements: 2 / 4 (50.0%)
Greatest absolute difference: 0.002262115478515625 at index (1, 0) (up to 1e-05 allowed)
Greatest relative difference: 5.868273365194909e-05 at index (0, 0) (up to 1.3e-06 allowed)

# this branch
>>> torch.testing.assert_close(outputs_before, outputs_disabled)
*** AssertionError: Tensor-likes are not close!

Mismatched elements: 2 / 4 (50.0%)
Greatest absolute difference: 0.047229766845703125 at index (1, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.0011932472698390484 at index (0, 0) (up to 1.3e-06 allowed)

So the greatest absolute and relative errors increased by a factor of 20. Of course, we could just increase the tolerance further and the tests would pass, but I wonder if something strange could be going on with the new _cayley_batch calculation and conv layers.

@zqiu24 (Contributor Author) commented Oct 7, 2025

@BenjaminBossan Thanks for pointing this out. I will look into it. Best,

@BenjaminBossan (Member) commented

@zqiu24 Did you make any progress? If not, I think we can still merge with higher tolerance. It could just be that this example is degenerate but there is no issue in general. It would be good to have this PR merged before the next release (current plan: next week).

@zqiu24 (Contributor Author) commented Oct 14, 2025

@BenjaminBossan Thanks for this update. I will come back to this soon; I am a bit occupied with some other things and will get back to you by the end of this week. Best,

@zqiu24 (Contributor Author) commented Oct 20, 2025

@BenjaminBossan Hi Benjamin, I locally trained some models with the new formulation and did not notice major differences (no divergence or anything similar), so I think it should be fine. Should I change the tests? Best,

@BenjaminBossan (Member) commented

Thanks for checking; then this is just a test artifact. Please increase the tolerance for OFT and add a comment that this works well in practice; it's just the test having issues.

Also, could you please merge with/rebase on main? Then the PR should be good to go.

@BenjaminBossan (Member) commented

@zqiu24 I came across an OFT issue with unmerging, maybe we can check this and possibly provide a fix in this PR too.

Here is the reproducer first:

import torch
from peft import OFTConfig, get_peft_model
from transformers import AutoModelForCausalLM

model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_id)
config = OFTConfig(init_weights=False)

model = get_peft_model(model, config)
dummy_input = torch.randint(0, model.config.vocab_size, (1, 16))
model.eval()

logits = model(dummy_input)[0]
model.merge_adapter()
logits_merged = model(dummy_input)[0]
model.unmerge_adapter()
logits_unmerged = model(dummy_input)[0]

atol, rtol = 1e-4, 1e-4
assert torch.allclose(logits, logits_merged, atol=atol, rtol=rtol)
assert torch.allclose(logits, logits_unmerged, atol=atol, rtol=rtol)

This fails for me and I think the issue is this line in the unmerge method.

To me, it's not obvious why taking the transpose of oft_mat would be the inverse operation when using torch.mm. Shouldn't the actual inverse be used, i.e.:

orig_weights = torch.mm(torch.linalg.inv(oft_mat), orig_weights.to(oft_mat.dtype))

I know it's not efficient, but at least with this change, the code runs successfully. WDYT @zqiu24?

If we want to make this change, we need to update the other OFT layers too.
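A small illustration of the point (not PEFT code): when the rotation matrix is only approximately orthogonal, as the truncated Cayley-Neumann series gives, its transpose no longer undoes the merge, while the true inverse does.

```python
import torch

torch.manual_seed(0)
Q, _ = torch.linalg.qr(torch.randn(8, 8, dtype=torch.float64))
R = Q + 1e-3 * torch.randn(8, 8, dtype=torch.float64)   # perturb exact orthogonality

W = torch.randn(8, 16, dtype=torch.float64)
merged = R @ W

err_transpose = (R.T @ merged - W).abs().max()            # assumes R.T == R^-1
err_inverse = (torch.linalg.inv(R) @ merged - W).abs().max()
print(err_transpose, err_inverse)                         # clearly nonzero vs ~ machine precision
```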

@zqiu24 (Contributor Author) commented Oct 22, 2025

Hi, thanks for this comment. Actually, I think you are right. The reason I take the transpose is that an orthogonal matrix's inverse is its transpose, but apparently, because of the approximation, the matrix deviates more from orthogonality, which is why the unmerge now has a bigger error. I agree with changing it, since it is not used in training and will only be needed quite rarely.


@zqiu24 (Contributor Author) commented Oct 22, 2025

@BenjaminBossan I think the current version is able to pass the tests without changing the tolerance? It also fixes the unmerge issue. Best.

@BenjaminBossan (Member) commented

@zqiu24 Unfortunately, GH shows 54 commits and 127 changed files. Could you please try merging with/rebasing on main again, so that I can see the actual diff?

@zqiu24 (Contributor Author) commented Oct 23, 2025

@BenjaminBossan Thanks for the comment; does this work? Best,

@BenjaminBossan (Member) left a comment

Thanks for the updates, the PR LGTM.

@zqiu24 (Contributor Author) commented Oct 23, 2025

@BenjaminBossan Thank you so much for the support:)

@BenjaminBossan merged commit a18ba67 into huggingface:main Oct 23, 2025
12 of 13 checks passed