Add wan2.1 functionality support for Ascend NPU platform #810
Conversation
add: npu platform
```
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/cuda_communicator.py
```
Remove this comment. `DeviceCommunicatorBase` is also defined here, and your code is based on the NPU implementation of that class.
```
@@ -0,0 +1,165 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
```
Delete these comments.
```
@@ -0,0 +1,250 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
```
Delete
```python
else:
    backend = "nccl"
    logger.info("Using nccl backend for CUDA platform")
# if backend == "nccl" and not current_platform.is_cuda_alike():
```
Remove the commented-out code. I won't repeat this comment below; please do a comprehensive check for other occurrences.
```python
# Use gloo backend for non-CUDA platforms (MPS, CPU)
backend = "gloo"
logger.info("Using gloo backend for %s platform",
if backend == "nccl" or backend == "hccl":
```
Where is `backend` assigned the value "hccl"? I couldn't find any such assignment.
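For reference, a hypothetical upstream assignment that would make this branch reachable (illustrative only; the PR would need something equivalent):

```python
if current_platform.is_cuda_alike():
    backend = "nccl"
    logger.info("Using nccl backend for CUDA platform")
elif current_platform.is_npu():
    # HCCL is the Ascend collective-communication library.
    backend = "hccl"
    logger.info("Using hccl backend for NPU platform")
else:
    backend = "gloo"
```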
```python
if current_platform.is_cuda_alike():
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
if current_platform.is_npu():
```
This should be an `elif`.
```python
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
if current_platform.is_npu():
    device = torch.device(f"npu:{local_rank}")
```
This is duplicate code. Try another branching approach with less code duplication.
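One lower-duplication shape the branch could take, assuming `torch_npu` has been imported so that `torch.npu.set_device` is available (a sketch, not the PR's final code):

```python
# Pick the device prefix once instead of duplicating the branch bodies.
if current_platform.is_cuda_alike():
    device_type = "cuda"
elif current_platform.is_npu():
    device_type = "npu"
else:
    raise RuntimeError("Unsupported platform for distributed initialization")

device = torch.device(f"{device_type}:{local_rank}")
# torch.cuda and torch.npu expose the same set_device() entry point.
getattr(torch, device_type).set_device(device)
```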
fastvideo/platforms/npu.py (outdated)
```python
def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None,
                         head_size: int, dtype: torch.dtype) -> str:
    # the NPU only supports Flash Attention
    # TODO(will): Other tasks will be synchronized in subsequent updates.
```
Remove the TODO.
```python
@classmethod
def is_pin_memory_available(cls):
    return True


```
Standardize the number of blank lines.
```python
from fastvideo.training.training_pipeline import TrainingPipeline
from fastvideo.utils import is_vsa_available
from fastvideo.platforms import current_platform
if current_platform.is_npu():
```
Don't do this. It's not advisable to replace APIs wholesale without careful consideration. We need to analyze the adaptation points one by one and replace them individually.
SolitaryThinker left a comment:
Hi, thanks for your contribution! I've left some comments. Please let me know when this PR is ready for CI tests. Meanwhile, you can install and run our pre-commit linters with the following commands:
```bash
# Linting, formatting and static type checking
pre-commit install --hook-type pre-commit --hook-type commit-msg
# You can manually run pre-commit with
pre-commit run --all-files
```
.gitignore (outdated)
```
preprocess_output_text/
=======
log/
>>>>>>> 1a6592a4 (add: npu platform)
```
Please fix this leftover merge-conflict marker.
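Presumably the resolution keeps both ignore entries and drops the conflict markers (a guess from the visible hunk):

```
preprocess_output_text/
log/
```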
| "hcclComm_t", | ||
| "aclrtStream_t", | ||
| "buffer_type", | ||
| ] No newline at end of file |
Please add a newline character at the end of the last line.
```python
torch.npu.reset_peak_memory_stats()


@classmethod
def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None,
```
Does the Ascend NPU support all of these attention backends? It would be good to remove any that are not supported yet and fall back to torch SDPA.
Okay, currently only SDPA is supported, and modifications have been made.
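A minimal sketch of an SDPA-only selector, based on the method signature shown above (the returned backend class path and the enum member name are assumptions, not the project's confirmed API):

```python
@classmethod
def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None,
                         head_size: int, dtype: torch.dtype) -> str:
    # Only torch SDPA is currently supported on Ascend NPU; warn and fall
    # back if any other backend was requested.
    if selected_backend is not None and selected_backend != AttentionBackendEnum.SDPA:
        logger.warning("%s is not supported on NPU; falling back to SDPA.",
                       selected_backend)
    return "fastvideo.attention.backends.sdpa.SDPABackend"  # assumed path
```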
```python
def all_reduce(self, input_, op: torch.distributed.ReduceOp | None = None):
    pyhccl_comm = self.pyhccl_comm
    assert pyhccl_comm is not None
```
Please add a failure message to the assert.
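For example (the message wording is illustrative):

```python
assert pyhccl_comm is not None, (
    "pyhccl communicator is not initialized; all_reduce() was called "
    "before the HCCL communicator was created.")
```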
Thank you for your review and feedback! I'll address all the comments promptly and let you know once the PR is ready for CI tests.
```python
try:
    self.hccl = HCCLLibrary(library_path)
except Exception:
    # disable because of missing HCCL library
```
Please add an error message to this except block.
OK, the error message has been added.
```python
stream = current_stream()
if src == self.rank:
    buffer = buffer_type(tensor.data_ptr())
else:
```
The `if` and `else` branches contain the same code.
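If the two branches really are identical, the conditional can simply be dropped (a sketch based on the visible lines):

```python
stream = current_stream()
# Every rank passes the same raw pointer to HCCL, so no branch on src
# is needed here.
buffer = buffer_type(tensor.data_ptr())
```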
```python
collate_fn=passthrough,
num_workers=num_data_workers,
pin_memory=True,
pin_memory_device = current_platform.device_name,
```
No spaces around the equals sign in keyword arguments: `pin_memory_device=current_platform.device_name`.
```python
try:
    self.hccl = HCCLLibrary(library_path)
except Exception:
    print("disable hccl because of missing HCCL library")
```
Use `logger.error` or `logger.warning` instead of `print`.
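One way this could look, assuming a module-level `logger` (the exact wording and severity level are judgment calls):

```python
try:
    self.hccl = HCCLLibrary(library_path)
except Exception as e:
    # Route the failure through logging instead of print() so it shows up
    # in structured logs with the usual severity filtering.
    logger.warning(
        "Disabling HCCL because the HCCL library could not be loaded "
        "from %s: %s", library_path, e)
```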
| logger.info("NPU is available") | ||
| else: | ||
| logger.info("NPU is not available") | ||
| except Exception as e: |
Please catch a more specific exception if possible.
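For instance, if the availability probe can only fail on import or driver errors, something like this (the exception types are illustrative; use whatever the probe can actually raise):

```python
try:
    import torch_npu  # noqa: F401
    logger.info("NPU is available")
except (ImportError, RuntimeError) as e:
    # Catch the expected failure modes rather than a bare Exception.
    logger.info("NPU is not available: %s", e)
```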
fastvideo/platforms/interface.py (outdated)
```python
raise NotImplementedError


@classmethod
def get_torch_device(cls) -> Any:
```
Can we replace this `Any`? If we can't, just leave the return type unannotated for this function.
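If the method is meant to return a torch device namespace such as `torch.cuda` or `torch.npu`, `types.ModuleType` would be a more precise annotation (a sketch under that assumption):

```python
import types


@classmethod
def get_torch_device(cls) -> types.ModuleType:
    # Return the torch device namespace for the current platform,
    # e.g. torch.cuda on GPUs or torch.npu on Ascend.
    raise NotImplementedError
```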
Hi, please run the pre-commit checks mentioned above.
Co-authored-by: kiritorl <[email protected]>
This PR adds Ascend support for the wan2.1 functionality, with the following key implementations:

- NPU Platform Integration: NPU platform support added under the `platforms/` directory
- Communicator Enhancement
- End-to-End Functionality

This implementation allows wan2.1 to run natively on Ascend NPUs.