48 changes: 48 additions & 0 deletions examples/transformers/gpt_oss/generate.py
@@ -0,0 +1,48 @@
import ast
import argparse
import os

from transformers import AutoTokenizer

import mindspore

from mindone.transformers import pipeline, AutoModelForCausalLM


def generate(args):

pipe = pipeline("text-generation", model=args.model_name, mindspore_dtype=mindspore.bfloat16, model_kwargs={'attn_implementation': 'eager'})

messages = [
{"role": "user", "content": args.prompt},
]

outputs = pipe(
messages,
max_new_tokens=256,
)
print(outputs[0]["generated_text"][-1])


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="gpt-oss demo.")
parser.add_argument("--prompt", type=str, default="Explain quantum mechanics clearly and concisely.")
parser.add_argument("--model_name", type=str, default="openai/gpt-oss-20b", help="Path to the pre-trained model.")
parser.add_argument("--enable_synchronize", type=ast.literal_eval, default=True)
parser.add_argument("--enable_offload", type=ast.literal_eval, default=False)
parser.add_argument("--dry_run", type=str, default="")
args = parser.parse_args()


if args.dry_run:
os.environ["MS_SIMULATION_LEVEL"] = args.dry_run
if int(args.dry_run) >= 1:
os.environ["GLOG_v"] = "1"
if int(args.dry_run) >= 2:
os.environ["MS_ALLOC_CONF"] = "memory_tracker:True"
if args.enable_synchronize:
mindspore.launch_blocking()
if args.enable_offload:
mindspore.set_offload_context(offload_config={"offload_param":"cpu", "offload_path": "./offload", "offload_cpu_size":"512GB", "hbm_ratio":0.9})

generate(args)
2 changes: 1 addition & 1 deletion mindone/safetensors/mindspore.py
@@ -125,7 +125,7 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, ms.Tensor]:

def _np2ms(np_dict: Dict[str, np.ndarray]) -> Dict[str, ms.Tensor]:
for k, v in np_dict.items():
np_dict[k] = ms.Parameter(v, name=k)
np_dict[k] = ms.tensor(v)
return np_dict


3 changes: 3 additions & 0 deletions mindone/transformers/__init__.py
@@ -643,3 +643,6 @@
MiniMaxPreTrainedModel,
)
from .models.vjepa2 import VJEPA2ForVideoClassification, VJEPA2Model, VJEPA2PreTrainedModel

if version.parse(transformers.__version__) >= version.parse("4.55.0"):
from .models.gpt_oss import *
146 changes: 146 additions & 0 deletions mindone/transformers/integrations/mxfp4.py
@@ -0,0 +1,146 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import math

import mindspore
from mindspore import nn

from ..utils import logging


logger = logging.get_logger(__name__)

FP4_VALUES = [
+0.0,
+0.5,
+1.0,
+1.5,
+2.0,
+3.0,
+4.0,
+6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]
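# MXFP4 layout, for reference: each uint8 in `blocks` packs two 4-bit codes that
# index FP4_VALUES (low nibble -> even output column, high nibble -> odd column),
# and every 16-byte block (32 values) shares one uint8 exponent in `scales`
# (bias 127), applied via ldexp, i.e. value * 2 ** (scale - 127).
# Example: byte 0x2F with scale 129 decodes to FP4_VALUES[0xF] * 4 = -24.0 and
# FP4_VALUES[0x2] * 4 = 4.0.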


# Copied from https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/weights.py#L68
def convert_moe_packed_tensors(
blocks,
scales,
*,
dtype: mindspore.Type = mindspore.bfloat16,
rows_per_chunk: int = 32768 * 1024,
) -> mindspore.Tensor:
scales = scales.to(mindspore.int32) - 127

assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}"

lut = mindspore.tensor(FP4_VALUES, dtype=dtype)

*prefix_shape, G, B = blocks.shape
rows_total = math.prod(prefix_shape) * G

blocks = blocks.reshape(rows_total, B)
scales = scales.reshape(rows_total, 1)

out = mindspore.mint.empty(rows_total, B * 2, dtype=dtype)

for r0 in range(0, rows_total, rows_per_chunk):
r1 = min(r0 + rows_per_chunk, rows_total)

blk = blocks[r0:r1]
exp = scales[r0:r1]

# nibble indices -> int64
# idx_lo = (blk & 0x0F).to(mindspore.int64)
# idx_hi = (blk >> 4).to(mindspore.int64)
idx_lo = mindspore.tensor(blk.numpy() & 0x0F, mindspore.int64)
idx_hi = mindspore.tensor(blk.numpy() >> 4, mindspore.int64)
Comment on lines +77 to +78 (Contributor review, severity: medium):

Using .numpy() for bitwise operations can be inefficient as it involves data transfer between the device (e.g., GPU/NPU) and CPU, and then back to a tensor. For better performance, you should use MindSpore's device-side bitwise operations.

Suggested change:
- idx_lo = mindspore.tensor(blk.numpy() & 0x0F, mindspore.int64)
- idx_hi = mindspore.tensor(blk.numpy() >> 4, mindspore.int64)
+ idx_lo = mindspore.ops.bitwise_and(blk, 0x0F).to(mindspore.int64)
+ idx_hi = mindspore.ops.right_shift(blk, 4).to(mindspore.int64)

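As an editor's aside, here is a minimal, self-contained sketch (not part of the PR) that runs both nibble-extraction paths on a toy packed tensor; it assumes `mindspore.ops.bitwise_and` and `mindspore.ops.right_shift` accept uint8 tensor operands, as in recent MindSpore releases.

```python
# Editor's sketch (not from the PR): compare the .numpy() round-trip used in
# convert_moe_packed_tensors with the device-side variant from the suggestion.
import numpy as np

import mindspore
from mindspore import ops

packed = mindspore.tensor(np.array([[0x2F, 0x81]], dtype=np.uint8))

# Host round-trip (current code path): Tensor -> numpy -> Tensor.
lo_host = mindspore.tensor(packed.numpy() & 0x0F, mindspore.int64)
hi_host = mindspore.tensor(packed.numpy() >> 4, mindspore.int64)

# Device-side variant from the review suggestion (mask/shift kept as tensors).
lo_dev = ops.bitwise_and(packed, mindspore.tensor(0x0F, mindspore.uint8)).to(mindspore.int64)
hi_dev = ops.right_shift(packed, mindspore.tensor(4, mindspore.uint8)).to(mindspore.int64)

# Both paths should yield low nibbles [[15, 1]] and high nibbles [[2, 8]].
print(lo_host, lo_dev)
print(hi_host, hi_dev)
```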

sub = out[r0:r1]
sub[:, 0::2] = lut[idx_lo]
sub[:, 1::2] = lut[idx_hi]

out[r0:r1] = mindspore.ops.ldexp(sub, exp)
del idx_lo, idx_hi, blk, exp

return out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)


class Mxfp4GptOssExperts(nn.Cell):
def __init__(self, config):
super().__init__()

self.num_experts = config.num_local_experts
self.intermediate_size = config.intermediate_size
self.hidden_size = config.hidden_size

self.gate_up_proj_blocks = mindspore.Parameter(
mindspore.mint.zeros((self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16), dtype=mindspore.uint8),
requires_grad=False,
)
self.gate_up_proj_scales = mindspore.Parameter(
mindspore.mint.zeros((self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32), dtype=mindspore.uint8),
requires_grad=False,
)
# self.gate_up_proj_bias = mindspore.Parameter(
# mindspore.mint.zeros((self.num_experts, 2 * self.intermediate_size), dtype=mindspore.float32), requires_grad=False
# )

self.down_proj_blocks = mindspore.Parameter(
mindspore.mint.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=mindspore.uint8),
requires_grad=False,
)
self.down_proj_scales = mindspore.Parameter(
mindspore.mint.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32), dtype=mindspore.uint8),
requires_grad=False,
)
# self.down_proj_bias = mindspore.Parameter(
# mindspore.mint.zeros((self.num_experts, self.hidden_size), dtype=mindspore.float32), requires_grad=False
# )
self.alpha = 1.702

self.gate_up_proj_precision_config = None
self.down_proj_precision_config = None

self.is_dequantized = False

def construct(self, hidden_states: mindspore.Tensor, routing_data, gather_idx, scatter_idx) -> mindspore.Tensor:
raise NotImplementedError

# FIXME: temporary path to support MXFP4 inference by dequantizing the packed weights
def dequantize(self):

if self.is_dequantized:
return

for proj in ["gate_up_proj", "down_proj"]:
blocks_attr = f"{proj}_blocks"
scales_attr = f"{proj}_scales"
dequantized = convert_moe_packed_tensors(getattr(self, blocks_attr), getattr(self, scales_attr))
dequantized = dequantized.transpose(1, 2).contiguous()
Contributor review comment (severity: medium):

The .contiguous() call is likely a remnant from PyTorch and is often not necessary in MindSpore. The transpose operation in MindSpore returns a new contiguous tensor. Removing this call would make the code cleaner and avoid potential confusion.

Suggested change:
- dequantized = dequantized.transpose(1, 2).contiguous()
+ dequantized = dequantized.transpose(1, 2)

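A quick way to check the reviewer's claim on a given build (an editor's sketch, assuming `mindspore.mint.transpose` and `Tensor.is_contiguous` are available, i.e. MindSpore 2.3 or later):

```python
# Editor's sketch (not from the PR): inspect whether a transposed tensor is
# already contiguous on the current backend before dropping the .contiguous() call.
import mindspore as ms
from mindspore import mint

x = mint.arange(0, 24, dtype=ms.float32).reshape(2, 3, 4)
y = mint.transpose(x, 1, 2)  # swap dims 1 and 2
print(y.shape, y.is_contiguous())
```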
setattr(self, proj, mindspore.Parameter(dequantized))
delattr(self, blocks_attr)
delattr(self, scales_attr)

self.is_dequantized = True
51 changes: 51 additions & 0 deletions mindone/transformers/modeling_layers.py
@@ -0,0 +1,51 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from mindspore import nn

from .utils import logging


logger = logging.get_logger(__name__)


class GradientCheckpointingLayer(nn.Cell):
"""Base class for layers with gradient checkpointing.

This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled
(`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is
enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`.

Important:

When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states)
must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients.

Example:

```python
>>> # Correct - hidden_states passed as positional arg
>>> out = self.layer(hidden_states, attention_mask=attention_mask)

>>> # Incorrect - hidden_states passed as keyword arg
>>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
```
"""

gradient_checkpointing = False

def __call__(self, *args, **kwargs):
if self.gradient_checkpointing and self.training:
raise NotImplementedError
return super().__call__(*args, **kwargs)
26 changes: 18 additions & 8 deletions mindone/transformers/modeling_utils.py
@@ -60,6 +60,7 @@
import mindspore as ms
from mindspore import Parameter, Tensor, mint, nn, ops
from mindspore.nn import CrossEntropyLoss, Identity
from mindspore.nn.utils import no_init_parameters

from .activations import get_activation
from .generation.utils import GenerationMixin
@@ -94,9 +95,7 @@ def _get_pt2ms_mappings(m):
mappings = {} # pt_param_name: (ms_param_name, pt_param_to_ms_param_func)
for name, cell in m.cells_and_names():
if isinstance(cell, (nn.Conv1d, nn.Conv1dTranspose)):
mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ms.Parameter(
ops.expand_dims(x, axis=-2), name=x.name
)
mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ops.expand_dims(x, axis=-2)
if "weight_norm_cell" in name:
ori_name = name.replace(".weight_norm_cell", "")
mappings[f"{ori_name}.weight_g"] = f"{ori_name}.weight_g", lambda x: ms.Parameter(
@@ -349,9 +348,9 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_shar
local_state = {v.name: v for k, v in model_to_load.parameters_and_names()}
for k, v in state_dict.items():
if k in local_state:
v.set_dtype(local_state[k].dtype)
state_dict[k] = ms.Parameter(v.to(local_state[k].dtype), name=k)
else:
pass # unexpect key keeps origin dtype
state_dict[k] = ms.Parameter(v, name=k)  # unexpected keys keep their original dtype
cm = silence_mindspore_logger() if is_sharded else nullcontext()
with cm:
ms.load_param_into_net(model_to_load, state_dict, strict_load=True)
@@ -383,7 +382,12 @@ def _get_name(self):

def to(self, dtype: Optional[ms.Type] = None):
for p in self.get_parameters():
p.set_dtype(dtype)
if p.dtype in (ms.uint8,):
logger.warning(
f"Do not convert to {dtype_to_str(dtype)}, param: {p.name}"
)
else:
p.set_dtype(dtype)
return self

def float(self):
@@ -974,7 +978,8 @@ def _from_config(cls, config, **kwargs):
mindspore_dtype=mindspore_dtype,
)

model = cls(config, **kwargs)
with no_init_parameters():
model = cls(config, **kwargs)

# We cannot set default mindspore dtype. So we need to cast model weights after creating.
if mindspore_dtype is not None:
@@ -984,6 +989,8 @@
f"convert model:{model.__class__.__name__} parameters to mindspore_dtype {dtype_to_str(mindspore_dtype)}"
)

model.init_parameters_data()

return model

def get_input_embeddings(self) -> nn.Cell:
@@ -2344,7 +2351,8 @@ def from_pretrained(
config, use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype
)

model = cls(config, *model_args, **model_kwargs)
with no_init_parameters():
model = cls(config, *model_args, **model_kwargs)

# Make sure to tie the weights correctly
model.tie_weights()
@@ -2400,6 +2408,8 @@
# make sure token embedding weights are still tied if needed
model.tie_weights()

model.init_parameters_data()

# Set model in evaluation mode to deactivate DropOut modules by default
model.set_train(False)

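For context on the two changes above: parameters are created lazily under `no_init_parameters()`, dtypes are set while the data is still lazy, and `init_parameters_data()` materializes everything in one pass. A minimal sketch of that pattern (an editor's illustration with a hypothetical `TinyBlock` cell, not code from this PR):

```python
# Editor's sketch (not from the PR): the delayed-initialization pattern used in
# _from_config / from_pretrained above, on a tiny stand-in cell.
import mindspore as ms
from mindspore import nn
from mindspore.nn.utils import no_init_parameters


class TinyBlock(nn.Cell):  # hypothetical stand-in for a real model class
    def __init__(self):
        super().__init__()
        self.proj = nn.Dense(4, 4)

    def construct(self, x):
        return self.proj(x)


with no_init_parameters():       # parameters carry metadata only, no data yet
    net = TinyBlock()

for p in net.get_parameters():   # dtype cast is cheap while data is lazy
    p.set_dtype(ms.bfloat16)

net.init_parameters_data()       # materialize all parameters in one pass
print([(p.name, p.dtype) for p in net.get_parameters()])
```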
7 changes: 7 additions & 0 deletions mindone/transformers/models/__init__.py
@@ -17,6 +17,10 @@
import transformers
from packaging import version

if version.parse(transformers.__version__) >= version.parse("4.55.0"):
from ..utils import LossKwargs
transformers.utils.LossKwargs = LossKwargs

from . import (
albert,
aria,
@@ -98,3 +102,6 @@

if version.parse(transformers.__version__) >= version.parse("4.53.0"):
from . import glm4v, minimax, vjepa2

if version.parse(transformers.__version__) >= version.parse("4.55.0"):
from . import gpt_oss
1 change: 1 addition & 0 deletions mindone/transformers/models/auto/auto_factory.py
@@ -461,6 +461,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"subfolder",
"use_auth_token",
"token",
"mindspore_dtype",
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
code_revision = kwargs.pop("code_revision", None)
4 changes: 4 additions & 0 deletions mindone/transformers/models/auto/configuration_auto.py
@@ -278,6 +278,10 @@
CONFIG_MAPPING_NAMES.update({"minimax": "MiniMaxConfig", "vjepa2": "VJEPA2Model"})
MODEL_NAMES_MAPPING.update({"minimax": "MiniMax", "vjepa2": "VJEPA2Model"})

if version.parse(transformers.__version__) >= version.parse("4.55.0"):
CONFIG_MAPPING_NAMES.update({"gpt_oss": "GptOssConfig"})
MODEL_NAMES_MAPPING.update({"gpt_oss": "GptOss"})


def model_type_to_module_name(key):
"""Converts a config key to the corresponding module."""