diff --git a/README-zh.md b/README-zh.md index b50d541..fe4ce52 100644 --- a/README-zh.md +++ b/README-zh.md @@ -84,13 +84,50 @@ DataFlex 与 LLaMA-Factory 无缝集成,为研究人员和开发者提供更 请使用以下命令进行环境配置与安装👇 +使用 `uv`: + +```bash +git clone https://github.com/OpenDCAI/DataFlex.git +cd DataFlex +uv pip install -r requirements.txt +uv pip install -e .[torch] +uv pip install llamafactory==0.9.3 +``` + +使用 `pip`: + ```bash git clone https://github.com/OpenDCAI/DataFlex.git cd DataFlex -pip install -e . +pip install -r requirements.txt +pip install -e .[torch] pip install llamafactory==0.9.3 ``` +对于首次使用的用户,建议先完成上面的基础安装。`LESS` 和 `NICE` 依赖可选包 `traker`,安装时可能会下载较大的 PyTorch/CUDA wheel。 + +使用 `uv`: + +```bash +uv pip install -e .[less] +# 或 +uv pip install -e .[nice] +``` + +使用 `pip`: + +```bash +pip install -e .[less] +# 或 +pip install -e .[nice] +``` + +如果网络不稳定,可以适当提高超时时间: + +```bash +UV_HTTP_TIMEOUT=300 uv pip install -e .[less] +``` + 启动命令与 [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) 类似。 下面给出一个使用 [LESS](https://arxiv.org/abs/2402.04333) 的示例: diff --git a/README.md b/README.md index a3613a1..b5a0c22 100644 --- a/README.md +++ b/README.md @@ -86,13 +86,50 @@ We summarize repositories related to Data Selection, Data Mixture, and Data Rewe Please use the following commands for environment setup and installation👇 +Using `uv`: + +```bash +git clone https://github.com/OpenDCAI/DataFlex.git +cd DataFlex +uv pip install -r requirements.txt +uv pip install -e .[torch] +uv pip install llamafactory==0.9.3 +``` + +Using `pip`: + ```bash git clone https://github.com/OpenDCAI/DataFlex.git cd DataFlex -pip install -e . +pip install -r requirements.txt +pip install -e .[torch] pip install llamafactory==0.9.3 ``` +For first-time users, start with the base install above. `LESS` and `NICE` require the optional `traker` dependency, which may download large PyTorch/CUDA wheels. + +Using `uv`: + +```bash +uv pip install -e .[less] +# or +uv pip install -e .[nice] +``` + +Using `pip`: + +```bash +pip install -e .[less] +# or +pip install -e .[nice] +``` + +If your network is unstable, increase the timeout: + +```bash +UV_HTTP_TIMEOUT=300 uv pip install -e .[less] +``` + The launch command is similar to [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory). Below is an example using [LESS](https://arxiv.org/abs/2402.04333) : diff --git a/pyproject.toml b/pyproject.toml index 80d1a70..10fe657 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,9 @@ minicpm_v = [ modelscope = ["modelscope"] openmind = ["openmind"] swanlab = ["swanlab"] +traker = ["traker"] +less = ["traker"] +nice = ["traker"] dev = ["pre-commit", "ruff", "pytest", "build"] [tool.setuptools] diff --git a/requirements.txt b/requirements.txt index a284568..4630c4e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,5 +25,4 @@ av librosa tyro<0.9.0 omegaconf -traker #faiss diff --git a/src/dataflex/train/selector/less_selector.py b/src/dataflex/train/selector/less_selector.py index 9fd9890..d0808d4 100644 --- a/src/dataflex/train/selector/less_selector.py +++ b/src/dataflex/train/selector/less_selector.py @@ -7,7 +7,6 @@ import torch.distributed as dist from tqdm import tqdm from torch.utils.data import DataLoader, Dataset -from trak.projectors import BasicProjector, CudaProjector, ProjectionType import json import os import glob # 用于文件查找 @@ -142,6 +141,14 @@ def _obtain_gradients(self, model, batch, gradient_type, m: Optional[torch.Tenso def _get_trak_projector(self): """获取 TRAK projector,优先使用 CUDA 版本。""" + try: + from trak.projectors import BasicProjector, CudaProjector + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "LESS selector requires the optional dependency `traker`. " + "Install it with `uv pip install -e .[less]` or `uv pip install traker`." + ) from exc + try: import fast_jl num_sms = torch.cuda.get_device_properties(self.device.index).multi_processor_count @@ -180,6 +187,14 @@ def _collect_and_save_projected_gradients(self, model, save_dir, dataset_to_use, """ 核心函数:每个进程独立计算梯度、投影,并保存带有索引的分块文件。 """ + try: + from trak.projectors import ProjectionType + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "LESS selector requires the optional dependency `traker`. " + "Install it with `uv pip install -e .[less]` or `uv pip install traker`." + ) from exc + # 1) 初始化 Projector (每个进程都需要一个) num_params = self._get_number_of_params(model) projector_class = self._get_trak_projector() @@ -386,4 +401,4 @@ def select(self, model, step_id: int, num_samples: int, **kwargs) -> List[int]: dist.broadcast_object_list(obj_list, src=0) selected_indices = obj_list[0] - return selected_indices \ No newline at end of file + return selected_indices diff --git a/src/dataflex/train/selector/nice_selector.py b/src/dataflex/train/selector/nice_selector.py index 0f7b9f9..4975c51 100644 --- a/src/dataflex/train/selector/nice_selector.py +++ b/src/dataflex/train/selector/nice_selector.py @@ -7,7 +7,6 @@ import torch.distributed as dist from tqdm import tqdm from torch.utils.data import DataLoader, Dataset -from trak.projectors import BasicProjector, CudaProjector, ProjectionType import os import glob @@ -225,6 +224,14 @@ def _obtain_gradients(self, model, batch, gradient_type: str, *, m: Optional[tor def _get_trak_projector(self): """获取 TRAK projector,优先使用 CUDA 版本。""" + try: + from trak.projectors import BasicProjector, CudaProjector + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "NICE selector requires the optional dependency `traker`. " + "Install it with `uv pip install -e .[nice]` or `uv pip install traker`." + ) from exc + try: import fast_jl num_sms = torch.cuda.get_device_properties(self.device.index).multi_processor_count @@ -425,6 +432,14 @@ def _collect_and_save_projected_gradients(self, rl_mode: bool = False, optimizer=None): """统一采集梯度、执行投影并保存,rl_mode 控制是否启用蒙特卡洛采样。""" + try: + from trak.projectors import ProjectionType + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "NICE selector requires the optional dependency `traker`. " + "Install it with `uv pip install -e .[nice]` or `uv pip install traker`." + ) from exc + # 1) 初始化 Projector (每个进程都需要一个) num_params = self._get_number_of_params(model) projector_class = self._get_trak_projector() @@ -655,4 +670,4 @@ def select(self, model, step_id: int, num_samples: int, **kwargs) -> List[int]: dist.broadcast_object_list(obj_list, src=0) selected_indices = obj_list[0] - return selected_indices \ No newline at end of file + return selected_indices