-
Notifications
You must be signed in to change notification settings - Fork 89
add gpt-oss #1209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
zhanghuiyao
wants to merge
1
commit into
mindspore-lab:master
Choose a base branch
from
zhanghuiyao:gpt_oss
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
add gpt-oss #1209
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| import ast | ||
| import argparse | ||
| import os | ||
|
|
||
| from transformers import AutoTokenizer | ||
|
|
||
| import mindspore | ||
|
|
||
| from mindone.transformers import pipeline, AutoModelForCausalLM | ||
|
|
||
|
|
||
| def generate(args): | ||
|
|
||
| pipe = pipeline("text-generation", model=args.model_name, mindspore_dtype=mindspore.bfloat16, model_kwargs={'attn_implementation': 'eager'}) | ||
|
|
||
| messages = [ | ||
| {"role": "user", "content": args.prompt}, | ||
| ] | ||
|
|
||
| outputs = pipe( | ||
| messages, | ||
| max_new_tokens=256, | ||
| ) | ||
| print(outputs[0]["generated_text"][-1]) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser(description="qwen3 demo.") | ||
| parser.add_argument("--prompt", type=str, default="Explain quantum mechanics clearly and concisely.") | ||
| parser.add_argument("--model_name", type=str, default="openai/gpt-oss-20b", help="Path to the pre-trained model.") | ||
| parser.add_argument("--enable_synchronize", type=ast.literal_eval, default=True) | ||
| parser.add_argument("--enable_offload", type=ast.literal_eval, default=False) | ||
| parser.add_argument("--dry_run", type=str, default="") | ||
| args = parser.parse_args() | ||
|
|
||
|
|
||
| if args.dry_run: | ||
| os.environ["MS_SIMULATION_LEVEL"] = args.dry_run | ||
| if int(args.dry_run) >= 1: | ||
| os.environ["GLOG_v"] = "1" | ||
| if int(args.dry_run) >= 2: | ||
| os.environ["MS_ALLOC_CONF"] = "memory_tracker:True" | ||
| if args.enable_synchronize: | ||
| mindspore.launch_blocking() | ||
| if args.enable_offload: | ||
| mindspore.set_offload_context(offload_config={"offload_param":"cpu", "offload_path": "./offload", "offload_cpu_size":"512GB", "hbm_ratio":0.9}) | ||
|
|
||
| generate(args) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,146 @@ | ||||||
| # Copyright 2025 The HuggingFace Team. All rights reserved. | ||||||
| # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| # you may not use this file except in compliance with the License. | ||||||
| # You may obtain a copy of the License at | ||||||
| # | ||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
| # | ||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | ||||||
|
|
||||||
| import re | ||||||
| import math | ||||||
|
|
||||||
| import mindspore | ||||||
| from mindspore import nn | ||||||
|
|
||||||
| from ..utils import logging | ||||||
|
|
||||||
|
|
||||||
| logger = logging.get_logger(__name__) | ||||||
|
|
||||||
| FP4_VALUES = [ | ||||||
| +0.0, | ||||||
| +0.5, | ||||||
| +1.0, | ||||||
| +1.5, | ||||||
| +2.0, | ||||||
| +3.0, | ||||||
| +4.0, | ||||||
| +6.0, | ||||||
| -0.0, | ||||||
| -0.5, | ||||||
| -1.0, | ||||||
| -1.5, | ||||||
| -2.0, | ||||||
| -3.0, | ||||||
| -4.0, | ||||||
| -6.0, | ||||||
| ] | ||||||
|
|
||||||
|
|
||||||
| # Copied from https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/weights.py#L68 | ||||||
| def convert_moe_packed_tensors( | ||||||
| blocks, | ||||||
| scales, | ||||||
| *, | ||||||
| dtype: mindspore.Type = mindspore.bfloat16, | ||||||
| rows_per_chunk: int = 32768 * 1024, | ||||||
| ) -> mindspore.Tensor: | ||||||
| scales = scales.to(mindspore.int32) - 127 | ||||||
|
|
||||||
| assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}" | ||||||
|
|
||||||
| lut = mindspore.tensor(FP4_VALUES, dtype=dtype) | ||||||
|
|
||||||
| *prefix_shape, G, B = blocks.shape | ||||||
| rows_total = math.prod(prefix_shape) * G | ||||||
|
|
||||||
| blocks = blocks.reshape(rows_total, B) | ||||||
| scales = scales.reshape(rows_total, 1) | ||||||
|
|
||||||
| out = mindspore.mint.empty(rows_total, B * 2, dtype=dtype) | ||||||
|
|
||||||
| for r0 in range(0, rows_total, rows_per_chunk): | ||||||
| r1 = min(r0 + rows_per_chunk, rows_total) | ||||||
|
|
||||||
| blk = blocks[r0:r1] | ||||||
| exp = scales[r0:r1] | ||||||
|
|
||||||
| # nibble indices -> int64 | ||||||
| # idx_lo = (blk & 0x0F).to(mindspore.int64) | ||||||
| # idx_hi = (blk >> 4).to(mindspore.int64) | ||||||
| idx_lo = mindspore.tensor(blk.numpy() & 0x0F, mindspore.int64) | ||||||
| idx_hi = mindspore.tensor(blk.numpy() >> 4, mindspore.int64) | ||||||
|
|
||||||
| sub = out[r0:r1] | ||||||
| sub[:, 0::2] = lut[idx_lo] | ||||||
| sub[:, 1::2] = lut[idx_hi] | ||||||
|
|
||||||
| out[r0:r1] = mindspore.ops.ldexp(sub, exp) | ||||||
| del idx_lo, idx_hi, blk, exp | ||||||
|
|
||||||
| return out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) | ||||||
|
|
||||||
|
|
||||||
| class Mxfp4GptOssExperts(nn.Cell): | ||||||
| def __init__(self, config): | ||||||
| super().__init__() | ||||||
|
|
||||||
| self.num_experts = config.num_local_experts | ||||||
| self.intermediate_size = config.intermediate_size | ||||||
| self.hidden_size = config.hidden_size | ||||||
|
|
||||||
| self.gate_up_proj_blocks = mindspore.Parameter( | ||||||
| mindspore.mint.zeros((self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16), dtype=mindspore.uint8), | ||||||
| requires_grad=False, | ||||||
| ) | ||||||
| self.gate_up_proj_scales = mindspore.Parameter( | ||||||
| mindspore.mint.zeros((self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32), dtype=mindspore.uint8), | ||||||
| requires_grad=False, | ||||||
| ) | ||||||
| # self.gate_up_proj_bias = mindspore.Parameter( | ||||||
| # mindspore.mint.zeros((self.num_experts, 2 * self.intermediate_size), dtype=mindspore.float32), requires_grad=False | ||||||
| # ) | ||||||
|
|
||||||
| self.down_proj_blocks = mindspore.Parameter( | ||||||
| mindspore.mint.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=mindspore.uint8), | ||||||
| requires_grad=False, | ||||||
| ) | ||||||
| self.down_proj_scales = mindspore.Parameter( | ||||||
| mindspore.mint.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32), dtype=mindspore.uint8), | ||||||
| requires_grad=False, | ||||||
| ) | ||||||
| # self.down_proj_bias = mindspore.Parameter( | ||||||
| # mindspore.mint.zeros((self.num_experts, self.hidden_size), dtype=mindspore.float32), requires_grad=False | ||||||
| # ) | ||||||
| self.alpha = 1.702 | ||||||
|
|
||||||
| self.gate_up_proj_precision_config = None | ||||||
| self.down_proj_precision_config = None | ||||||
|
|
||||||
| self.is_dequantized = False | ||||||
|
|
||||||
| def construct(self, hidden_states: mindspore.Tensor, routing_data, gather_idx, scatter_idx) -> mindspore.Tensor: | ||||||
| raise NotImplementedError | ||||||
|
|
||||||
| # FIXME: Temporary for support MXFP4 inference | ||||||
| def dequantize(self): | ||||||
|
|
||||||
| if self.is_dequantized: | ||||||
| return | ||||||
|
|
||||||
| for proj in ["gate_up_proj", "down_proj"]: | ||||||
| blocks_attr = f"{proj}_blocks" | ||||||
| scales_attr = f"{proj}_scales" | ||||||
| dequantized = convert_moe_packed_tensors(getattr(self, blocks_attr), getattr(self, scales_attr)) | ||||||
| dequantized = dequantized.transpose(1, 2).contiguous() | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||
| setattr(self, proj, mindspore.Parameter(dequantized)) | ||||||
| delattr(self, blocks_attr) | ||||||
| delattr(self, scales_attr) | ||||||
|
|
||||||
| self.is_dequantized = True | ||||||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| # Copyright 2025 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from mindspore import nn | ||
|
|
||
| from .utils import logging | ||
|
|
||
|
|
||
| logger = logging.get_logger(__name__) | ||
|
|
||
|
|
||
| class GradientCheckpointingLayer(nn.Cell): | ||
| """Base class for layers with gradient checkpointing. | ||
|
|
||
| This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled | ||
| (`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is | ||
| enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`. | ||
|
|
||
| Important: | ||
|
|
||
| When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states) | ||
| must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients. | ||
|
|
||
| Example: | ||
|
|
||
| ```python | ||
| >>> # Correct - hidden_states passed as positional arg | ||
| >>> out = self.layer(hidden_states, attention_mask=attention_mask) | ||
|
|
||
| >>> # Incorrect - hidden_states passed as keyword arg | ||
| >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask) | ||
| ``` | ||
| """ | ||
|
|
||
| gradient_checkpointing = False | ||
|
|
||
| def __call__(self, *args, **kwargs): | ||
| if self.gradient_checkpointing and self.training: | ||
| raise NotImplementedError | ||
| return super().__call__(*args, **kwargs) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using
.numpy()for bitwise operations can be inefficient as it involves data transfer between the device (e.g., GPU/NPU) and CPU, and then back to a tensor. For better performance, you should use MindSpore's device-side bitwise operations.