Skip to content

Commit fc764f3

Browse files
committed
fix: clean up unnecessary changes
1 parent 8c79f57 commit fc764f3

File tree

8 files changed: 12 additions and 81 deletions.

fastvideo/distributed/device_communicators/pyhccl.py

Lines changed: 1 addition & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -29,15 +29,6 @@
2929

3030
logger = init_logger(__name__)
3131

32-
# from vllm.distributed.utils import StatelessProcessGroup
33-
# from vllm.logger import logger
34-
35-
# from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
36-
# HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
37-
# hcclRedOpTypeEnum, hcclUniqueId)
38-
# from vllm_ascend.utils import current_stream
39-
40-
4132
class PyHcclCommunicator:
4233

4334
def __init__(
@@ -173,4 +164,4 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
173164
buffer = buffer_type(tensor.data_ptr())
174165
self.hccl.hcclBroadcast(buffer, tensor.numel(),
175166
hcclDataTypeEnum.from_torch(tensor.dtype), src,
176-
self.comm, aclrtStream_t(stream.npu_stream))
167+
self.comm, aclrtStream_t(stream.npu_stream))

fastvideo/distributed/device_communicators/pyhccl_wrapper.py

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -26,9 +26,6 @@
2626
from fastvideo.utils import find_hccl_library
2727

2828
logger = init_logger(__name__)
29-
# from vllm.logger import logger
30-
31-
# from vllm_ascend.utils import find_hccl_library
3229

3330
# export types and functions from hccl to Python ===
3431
# for the original hccl definition, please check
@@ -133,10 +130,6 @@ class HCCLLibrary:
133130
ctypes.POINTER(hcclComm_t),
134131
]),
135132

136-
# HcclResult HcclAllReduce(
137-
# void *sendBuf, void *recvBuf, uint64_t count,
138-
# HcclDataType dataType, HcclReduceOp op, HcclComm comm,
139-
# aclrtStream stream);
140133
Function("HcclAllReduce", hcclResult_t, [
141134
buffer_type,
142135
buffer_type,

fastvideo/models/loader/component_loader.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -273,6 +273,7 @@ def load_model(self,
273273

274274
# Explicitly move model to target device after loading weights
275275
model = model.to(target_device)
276+
276277
if use_cpu_offload:
277278
# Disable FSDP for MPS as it's not compatible
278279
if current_platform.is_mps():

fastvideo/models/loader/fsdp_load.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -104,6 +104,7 @@ def maybe_load_fsdp_model(
104104
if not training_mode and not fsdp_inference:
105105
hsdp_replicate_dim = world_size
106106
hsdp_shard_dim = 1
107+
107108
if current_platform.is_npu():
108109
with torch.device("cpu"):
109110
device_mesh = init_device_mesh(
@@ -118,7 +119,7 @@ def maybe_load_fsdp_model(
118119
# (Replicate(), Shard(dim=0))
119120
mesh_shape=(hsdp_replicate_dim, hsdp_shard_dim),
120121
mesh_dim_names=("replicate", "shard"),
121-
)
122+
)
122123
shard_model(model,
123124
cpu_offload=cpu_offload,
124125
reshard_after_forward=True,

fastvideo/pipelines/composed_pipeline_base.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,7 @@ def __init__(self,
6565
if self._required_config_modules is None:
6666
raise NotImplementedError(
6767
"Subclass must set _required_config_modules")
68+
6869
maybe_init_distributed_environment_and_model_parallel(
6970
fastvideo_args.tp_size, fastvideo_args.sp_size)
7071

@@ -151,6 +152,7 @@ def from_pretrained(cls,
151152
assert fastvideo_args.pipeline_config.dit_precision == 'fp32', 'only fp32 is supported for training'
152153

153154
logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args)
155+
154156
pipe = cls(model_path,
155157
fastvideo_args,
156158
required_config_modules=required_config_modules,

fastvideo/platforms/npu.py

Lines changed: 3 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -1,31 +1,11 @@
1-
#
2-
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3-
#
4-
# Licensed under the Apache License, Version 2.0 (the "License");
5-
# you may not use this file except in compliance with the License.
6-
# You may obtain a copy of the License at
7-
#
8-
# http://www.apache.org/licenses/LICENSE-2.0
9-
#
10-
# Unless required by applicable law or agreed to in writing, software
11-
# distributed under the License is distributed on an "AS IS" BASIS,
12-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
# See the License for the specific language governing permissions and
14-
# limitations under the License.
15-
# This file is a part of the vllm-ascend project.
16-
#
17-
181
import gc
192
import os
203
from datetime import timedelta
214
from typing import TYPE_CHECKING, Optional, Tuple
225

236
import torch
24-
# import vllm.envs as envs
257
from torch.distributed import ProcessGroup
268
from torch.distributed.distributed_c10d import PrefixStore
27-
# from vllm.logger import logger
28-
# from vllm.platforms import Platform, PlatformEnum
299

3010
import os
3111
from collections.abc import Callable
@@ -42,21 +22,6 @@
4222
PlatformEnum)
4323
from fastvideo.utils import import_pynvml
4424

45-
# import vllm_ascend.envs as envs_ascend
46-
# from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config
47-
# from vllm_ascend.utils import (ASCEND_QUATIZATION_METHOD,
48-
# check_torchair_cache_exist,
49-
# delete_torchair_cache_file,
50-
# update_aclgraph_sizes)
51-
52-
# if TYPE_CHECKING:
53-
# from vllm.config import ModelConfig, VllmConfig
54-
# from vllm.utils import FlexibleArgumentParser
55-
# else:
56-
# ModelConfig = None
57-
# VllmConfig = None
58-
# FlexibleArgumentParser = None
59-
6025
logger = init_logger(__name__)
6126

6227
class NPUPlatform(Platform):
@@ -115,8 +80,8 @@ def clear_npu_memory(cls):
11580
@classmethod
11681
def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None,
11782
head_size: int, dtype: torch.dtype) -> str:
118-
# TODO(will): maybe come up with a more general interface for local attention
119-
# if distributed is False, we always try to use Flash attn
83+
# the NPU only supports Flash Attention
84+
# TODO(will): Other tasks will be synchronized in subsequent updates.
12085

12186
logger.info("Trying FASTVIDEO_ATTENTION_BACKEND=%s",
12287
envs.FASTVIDEO_ATTENTION_BACKEND)
@@ -216,9 +181,6 @@ def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None,
216181

217182
return "fastvideo.attention.backends.flash_attn.FlashAttentionBackend"
218183

219-
@classmethod
220-
def get_punica_wrapper(cls) -> str:
221-
return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"
222184

223185
@classmethod
224186
def get_current_memory_usage(cls,
@@ -235,19 +197,6 @@ def get_device_communicator_cls(cls) -> str:
235197
def is_pin_memory_available(cls):
236198
return True
237199

238-
# @classmethod
239-
# def supports_v1(cls, model_config: ModelConfig) -> bool:
240-
# """Returns whether the current platform can support v1 for the supplied
241-
# model configuration.
242-
# """
243-
# return True
244-
245-
# @classmethod
246-
# def get_piecewise_backend_cls(cls) -> str:
247-
# """
248-
# Get piecewise backend class for piecewise graph.
249-
# """
250-
# return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend" # noqa
251200

252201
@classmethod
253202
def stateless_init_device_torch_dist_pg(
@@ -276,12 +225,8 @@ def stateless_init_device_torch_dist_pg(
276225
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
277226
backend_options)
278227
device = torch.device("npu")
279-
# TODO(Yizhou): Like we mentioned above, _set_default_backend is not
280-
# implemented in the 2.5.1 version of PyTorch. But we need to set it
281-
# after the latest version is released.
282-
# pg._set_default_backend(backend_type)
283228
backend_class._set_sequence_number_for_group()
284229
backend_type = ProcessGroup.BackendType.CUSTOM
285230

286231
pg._register_backend(device, backend_type, backend_class)
287-
return pg
232+
return pg

fastvideo/training/training_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,6 @@ def __init__(
7676
raise ValueError("lora rank must be set when using lora training")
7777

7878
set_random_seed(fastvideo_args.seed) # for lora param init
79-
breakpoint()
8079
super().__init__(model_path, fastvideo_args, required_config_modules,
8180
loaded_modules) # type: ignore
8281

@@ -396,7 +395,6 @@ def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch:
396395
num_latent_t=self.training_args.num_latent_t)
397396

398397
training_batch = self._build_attention_metadata(training_batch)
399-
400398
training_batch = self._build_input_kwargs(training_batch)
401399
training_batch = self._transformer_forward_and_compute_loss(
402400
training_batch)
@@ -488,6 +486,7 @@ def train(self) -> None:
488486
training_batch.current_timestep = step
489487
training_batch.current_vsa_sparsity = current_vsa_sparsity
490488
training_batch = self.train_one_step(training_batch)
489+
491490
loss = training_batch.total_loss
492491
grad_norm = training_batch.grad_norm
493492

@@ -528,6 +527,7 @@ def train(self) -> None:
528527
logger.info(
529528
"GPU memory usage after validation: %s MB, trainable params: %sB",
530529
gpu_memory_usage, trainable_params)
530+
531531
wandb.finish()
532532
save_checkpoint(self.transformer, self.global_rank,
533533
self.training_args.output_dir,

fastvideo/training/wan_training_pipeline.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
if current_platform.is_npu():
1414
import torch_npu
1515
from torch_npu.contrib import transfer_to_npu
16-
from msprobe.pytorch import PrecisionDebugger, seed_all
1716

1817
vsa_available = is_vsa_available()
1918

@@ -59,7 +58,7 @@ def initialize_validation_pipeline(self, training_args: TrainingArgs):
5958

6059
def main(args) -> None:
6160
logger.info("Starting training pipeline...")
62-
breakpoint()
61+
6362
pipeline = WanTrainingPipeline.from_pretrained(
6463
args.pretrained_model_name_or_path, args=args)
6564
args = pipeline.training_args
@@ -69,7 +68,6 @@ def main(args) -> None:
6968

7069
if __name__ == "__main__":
7170
argv = sys.argv
72-
seed_all(seed=42, mode=True)
7371
from fastvideo.fastvideo_args import TrainingArgs
7472
from fastvideo.utils import FlexibleArgumentParser
7573
parser = FlexibleArgumentParser()

0 commit comments

Comments (0)