Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion README-zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,50 @@ DataFlex 与 LLaMA-Factory 无缝集成,为研究人员和开发者提供更

请使用以下命令进行环境配置与安装👇

使用 `uv`:

```bash
git clone https://github.com/OpenDCAI/DataFlex.git
cd DataFlex
uv pip install -r requirements.txt
uv pip install -e .[torch]
uv pip install llamafactory==0.9.3
```

使用 `pip`:

```bash
git clone https://github.com/OpenDCAI/DataFlex.git
cd DataFlex
pip install -e .
pip install -r requirements.txt
pip install -e .[torch]
pip install llamafactory==0.9.3
```

对于首次使用的用户,建议先完成上面的基础安装。`LESS` 和 `NICE` 依赖可选包 `traker`,安装时可能会下载较大的 PyTorch/CUDA wheel。

使用 `uv`:

```bash
uv pip install -e .[less]
# 或
uv pip install -e .[nice]
```

使用 `pip`:

```bash
pip install -e .[less]
# 或
pip install -e .[nice]
```

如果网络不稳定,可以适当提高超时时间:

```bash
UV_HTTP_TIMEOUT=300 uv pip install -e .[less]
```

启动命令与 [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) 类似。
下面给出一个使用 [LESS](https://arxiv.org/abs/2402.04333) 的示例:

Expand Down
39 changes: 38 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,50 @@ We summarize repositories related to Data Selection, Data Mixture, and Data Rewe

Please use the following commands for environment setup and installation👇

Using `uv`:

```bash
git clone https://github.com/OpenDCAI/DataFlex.git
cd DataFlex
uv pip install -r requirements.txt
uv pip install -e .[torch]
uv pip install llamafactory==0.9.3
```

Using `pip`:

```bash
git clone https://github.com/OpenDCAI/DataFlex.git
cd DataFlex
pip install -e .
pip install -r requirements.txt
pip install -e .[torch]
pip install llamafactory==0.9.3
```

For first-time users, start with the base install above. `LESS` and `NICE` require the optional `traker` dependency, which may download large PyTorch/CUDA wheels.

Using `uv`:

```bash
uv pip install -e .[less]
# or
uv pip install -e .[nice]
```

Using `pip`:

```bash
pip install -e .[less]
# or
pip install -e .[nice]
```

If your network is unstable, increase the timeout:

```bash
UV_HTTP_TIMEOUT=300 uv pip install -e .[less]
```

The launch command is similar to [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory).
Below is an example using [LESS](https://arxiv.org/abs/2402.04333) :

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ minicpm_v = [
modelscope = ["modelscope"]
openmind = ["openmind"]
swanlab = ["swanlab"]
traker = ["traker"]
less = ["traker"]
nice = ["traker"]
dev = ["pre-commit", "ruff", "pytest", "build"]

[tool.setuptools]
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,4 @@ av
librosa
tyro<0.9.0
omegaconf
traker
#faiss
19 changes: 17 additions & 2 deletions src/dataflex/train/selector/less_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch.distributed as dist
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from trak.projectors import BasicProjector, CudaProjector, ProjectionType
import json
import os
import glob # 用于文件查找
Expand Down Expand Up @@ -142,6 +141,14 @@ def _obtain_gradients(self, model, batch, gradient_type, m: Optional[torch.Tenso

def _get_trak_projector(self):
"""获取 TRAK projector,优先使用 CUDA 版本。"""
try:
from trak.projectors import BasicProjector, CudaProjector
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"LESS selector requires the optional dependency `traker`. "
"Install it with `uv pip install -e .[less]` or `uv pip install traker`."
) from exc

try:
import fast_jl
num_sms = torch.cuda.get_device_properties(self.device.index).multi_processor_count
Expand Down Expand Up @@ -180,6 +187,14 @@ def _collect_and_save_projected_gradients(self, model, save_dir, dataset_to_use,
"""
核心函数:每个进程独立计算梯度、投影,并保存带有索引的分块文件。
"""
try:
from trak.projectors import ProjectionType
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"LESS selector requires the optional dependency `traker`. "
"Install it with `uv pip install -e .[less]` or `uv pip install traker`."
) from exc

# 1) 初始化 Projector (每个进程都需要一个)
num_params = self._get_number_of_params(model)
projector_class = self._get_trak_projector()
Expand Down Expand Up @@ -386,4 +401,4 @@ def select(self, model, step_id: int, num_samples: int, **kwargs) -> List[int]:
dist.broadcast_object_list(obj_list, src=0)
selected_indices = obj_list[0]

return selected_indices
return selected_indices
19 changes: 17 additions & 2 deletions src/dataflex/train/selector/nice_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch.distributed as dist
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from trak.projectors import BasicProjector, CudaProjector, ProjectionType
import os
import glob

Expand Down Expand Up @@ -225,6 +224,14 @@ def _obtain_gradients(self, model, batch, gradient_type: str, *, m: Optional[tor

def _get_trak_projector(self):
"""获取 TRAK projector,优先使用 CUDA 版本。"""
try:
from trak.projectors import BasicProjector, CudaProjector
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"NICE selector requires the optional dependency `traker`. "
"Install it with `uv pip install -e .[nice]` or `uv pip install traker`."
) from exc

try:
import fast_jl
num_sms = torch.cuda.get_device_properties(self.device.index).multi_processor_count
Expand Down Expand Up @@ -425,6 +432,14 @@ def _collect_and_save_projected_gradients(self,
rl_mode: bool = False,
optimizer=None):
"""统一采集梯度、执行投影并保存,rl_mode 控制是否启用蒙特卡洛采样。"""
try:
from trak.projectors import ProjectionType
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"NICE selector requires the optional dependency `traker`. "
"Install it with `uv pip install -e .[nice]` or `uv pip install traker`."
) from exc

# 1) 初始化 Projector (每个进程都需要一个)
num_params = self._get_number_of_params(model)
projector_class = self._get_trak_projector()
Expand Down Expand Up @@ -655,4 +670,4 @@ def select(self, model, step_id: int, num_samples: int, **kwargs) -> List[int]:
dist.broadcast_object_list(obj_list, src=0)
selected_indices = obj_list[0]

return selected_indices
return selected_indices
Loading