
add pyproject.toml with build dependencies #958

Open: wants to merge 2 commits into base: main
Conversation

@dhellmann commented May 17, 2024

Adding this metadata means it is possible to pip install flash-attn
without pre-installing the packages imported in setup.py. It is still
possible to follow the existing manual instructions, too, but not
requiring pre-installation also makes it easier for someone to build
from the source dist.

Potentially related to #875, #833, #453, #416, #253, #246, #188

dhellmann added 2 commits May 17, 2024 09:02
Adding this metadata means it is possible to `pip install flash-attn`
without pre-installing the packages imported in setup.py. It is still
possible to follow the existing manual instructions, too, but by not
requiring pre-installation it _also_ makes it easier for someone to
build from the source dist.

Signed-off-by: Doug Hellmann <[email protected]>
The build package provides a front-end to the standard APIs for
generating build artifacts like sdists and wheels from source
code. Using build means the pyproject.toml file is processed, so the
build dependencies imported by setup.py can be installed before
setup.py itself is processed.

Signed-off-by: Doug Hellmann <[email protected]>
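
For illustration only (not part of the commits): driving that front-end from Python looks roughly like this, equivalent to running python -m build in the source checkout.

```python
# Illustrative sketch, not part of this PR: drive the PEP 517 front-end from
# Python. build reads pyproject.toml, installs the declared build requirements
# into an isolated environment, and only then runs setup.py.
import subprocess
import sys

subprocess.run([sys.executable, "-m", "build", "--sdist", "--wheel", "."], check=True)
```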
@Skylion007

Unfortunately, this also prevents source builds from compiling against the right version of PyTorch (due to build isolation), and there is no standard way to turn it off. The only way to do it is to use a pre-pyproject.toml package and rely on legacy support for setup.py.

@dhellmann (Author)

> Unfortunately, this also prevents source builds from compiling against the right version of PyTorch (due to build isolation), and there is no standard way to turn it off. The only way to do it is to use a pre-pyproject.toml package and rely on legacy support for setup.py.

This isn't correct. I still install torch ahead of time and build without isolation, using the pip wheel command with the --no-build-isolation option. setup.py is still run, and performs all of the validation it does now.

Ideally, when we have something like wheel selectors implemented, that will also help with this situation.

@technillogue

Please use something like this; your package is impossible to use otherwise. Does anyone have a fork that declares build dependencies correctly?

@bionicles

Hi, could we take a fresh look at setup.py to get py311+ working?

[screenshot omitted]

I think the issue is that the Python stdlib deleted "packaging".
Luckily these are pretty replaceable: we could recapitulate parse with a one-liner, and the Version thing isn't strictly necessary:

def tuple_from_version(v: str) -> Tuple[int, ...]:
    # Keep only the numeric release segment so "2.3.0+cu121" parses as (2, 3, 0).
    return tuple(int(p) for p in v.split("+")[0].split(".") if p.isdigit())

Apache 2.0 on that. I'm guessing parse returns Version? The lack of type hints makes it hard to know what the get_cuda_bare_metal_version function is expected to return.

[screenshots omitted]

Not sure what to do about this one, though: [screenshot omitted]

@bionicles

OK, I believe packaging is not the issue, as I have made these changes and the compilation still hangs forever here:

[screenshot omitted]

# Copyright (c) 2023, Tri Dao.

import sys
import warnings
import os
import re
import ast
from pathlib import Path
import platform


# from packaging.version import parse, Version

from setuptools import setup, find_packages
import subprocess

import urllib.request
import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import torch
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
)

from typing import Tuple


def tuple_from_version(v: str) -> Tuple[int, ...]:
    # Keep only the numeric release segment so "2.3.0+cu121" parses as (2, 3, 0).
    return tuple(int(p) for p in v.split("+")[0].split(".") if p.isdigit())


with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

PACKAGE_NAME = "flash_attn"

BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"


def get_platform():
    """
    Returns the platform name as used in wheel filenames.
    """
    if sys.platform.startswith("linux"):
        return f"linux_{platform.uname().machine}"
    elif sys.platform == "darwin":
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{mac_version}_x86_64"
    elif sys.platform == "win32":
        return "win_amd64"
    else:
        raise ValueError("Unsupported platform: {}".format(sys.platform))


def version_from_cuda_dir(cuda_dir: str) -> Tuple[int, ...]:
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    # print("release_idx", release_idx)
    target_output = output[release_idx]
    # print("target_output", target_output)
    cuda_bare_metal_version = tuple_from_version(target_output.rstrip(","))
    print("cuda_bare_metal_version", cuda_bare_metal_version)
    return cuda_bare_metal_version


def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def append_nvcc_threads(nvcc_extra_args):
    nvcc_threads = os.getenv("NVCC_THREADS") or "4"
    return nvcc_extra_args + ["--threads", nvcc_threads]


cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])

if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(
        os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")
    ):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_cuda_home_none("flash_attn")
    # Check if CUDA 11 is installed for compute capability 8.0
    cc_flag = []
    bare_metal_version = None
    if CUDA_HOME is not None:
        bare_metal_version = version_from_cuda_dir(CUDA_HOME)
        if bare_metal_version < (11, 6):
            raise RuntimeError(
                "FlashAttention is only supported on CUDA 11.6 and above.  "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )
    # cc_flag.append("-gencode")
    # cc_flag.append("arch=compute_75,code=sm_75")
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_80,code=sm_80")
    if CUDA_HOME is not None and isinstance(bare_metal_version, tuple):
        if bare_metal_version >= (11, 8):
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_90,code=sm_90")

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True
    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=[
                "csrc/flash_attn/flash_api.cpp",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"] + generator_flag,
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_HALF2_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        # "--ptxas-options=-v",
                        # "--ptxas-options=-O2",
                        # "-lineinfo",
                        # "-DFLASHATTENTION_DISABLE_BACKWARD",
                        # "-DFLASHATTENTION_DISABLE_DROPOUT",
                        # "-DFLASHATTENTION_DISABLE_ALIBI",
                        # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                        # "-DFLASHATTENTION_DISABLE_LOCAL",
                    ]
                    + generator_flag
                    + cc_flag
                ),
            },
            include_dirs=[
                Path(this_dir) / "csrc" / "flash_attn",
                Path(this_dir) / "csrc" / "flash_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )


def get_package_version():
    this_path = Path(this_dir) / "flash_attn" / "__init__.py"
    with open(this_path, "r") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    assert version_match is not None, "no package version match"
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    else:
        return str(public_version)


def get_wheel_url():
    # Determine the version numbers that will be used to determine the correct wheel
    # We're using the CUDA version used to build torch, not the one currently installed
    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
    # torch_cuda_version = parse(torch.version.cuda)
    # torch_version_raw = parse(torch.__version__)
    torch_cuda_version = tuple_from_version(torch.version.cuda)
    torch_version_raw = tuple_from_version(torch.__version__)
    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
    # to save CI time. Minor versions should be compatible.
    torch_cuda_version = (
        tuple_from_version("11.8")
        if torch_cuda_version[0] == 11
        else tuple_from_version("12.2")
    )
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    flash_version = get_package_version()
    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
    cuda_version = f"{torch_cuda_version[0]}{torch_cuda_version[1]}"
    torch_version = f"{torch_version_raw[0]}.{torch_version_raw[1]}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    wheel_url = BASE_WHEEL_URL.format(
        tag_name=f"v{flash_version}", wheel_name=wheel_filename
    )
    return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
    find an existing wheel (which is currently the case for all flash attention installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuits the standard full build pipeline.
    """

    def run(self):
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            assert self.dist_dir is not None, "cannot cache without a dist_dir"
            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            if not os.path.exists(self.dist_dir):
                os.makedirs(self.dist_dir)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            os.rename(wheel_filename, wheel_path)
        except (urllib.error.HTTPError, urllib.error.URLError):
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()


class NinjaBuildExtension(BuildExtension):
    def __init__(self, *args, **kwargs) -> None:
        # do not override env MAX_JOBS if already exists
        if not os.environ.get("MAX_JOBS"):
            import psutil

            # calculate the maximum allowed NUM_JOBS based on cores
            n_cpus = os.cpu_count()
            assert n_cpus is not None, "no cpus?!"
            max_num_jobs_cores = max(1, n_cpus // 2)

            # calculate the maximum allowed NUM_JOBS based on free memory
            free_memory_gb = psutil.virtual_memory().available / (
                1024**3
            )  # free memory in GB
            max_num_jobs_memory = int(
                free_memory_gb / 9
            )  # each JOB peak memory cost is ~8-9GB when threads = 4

            # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
            max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
            os.environ["MAX_JOBS"] = str(max_jobs)

        super().__init__(*args, **kwargs)


setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(
        exclude=(
            "build",
            "csrc",
            "include",
            "tests",
            "dist",
            "docs",
            "benchmarks",
            "flash_attn.egg-info",
        )
    ),
    author="Tri Dao",
    author_email="[email protected]",
    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Dao-AILab/flash-attention",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass=(
        {"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension}
        if ext_modules
        else {
            "bdist_wheel": CachedWheelsCommand,
        }
    ),
    python_requires=">=3.7",
    install_requires=[
        "torch",
        "einops",
    ],
    setup_requires=[
        "packaging",
        "psutil",
        "ninja",
    ],
)

@collinsleewyatt

> Unfortunately, this also prevents source builds from compiling against the right version of PyTorch (due to build isolation), and there is no standard way to turn it off. The only way to do it is to use a pre-pyproject.toml package and rely on legacy support for setup.py.

Does this mean that under the pyproject.toml implementation, flash-attention would always pull the latest version of torch and ignore any currently installed version?

@dhellmann (Author)

> Unfortunately, this also prevents source builds from compiling against the right version of PyTorch (due to build isolation), and there is no standard way to turn it off. The only way to do it is to use a pre-pyproject.toml package and rely on legacy support for setup.py.
>
> Does this mean that under the pyproject.toml implementation, flash-attention would always pull the latest version of torch and ignore any currently installed version?

It depends on how the build is run.

In our build, we

  1. Apply this patch.
  2. Create a virtualenv with the version of torch we want along with the other build dependencies.
  3. Run pip wheel with the --no-build-isolation flag.

With --no-build-isolation, pip runs the build step using the environment where pip is installed. That gives us control over creating that virtualenv, including the torch version.

If --no-build-isolation is not used, then pip will create a virtualenv and install the build dependencies. With this patch, and without the flag, pip will pull in the latest version of torch.
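
For concreteness, a rough sketch of steps 2 and 3 above; the virtualenv path and torch version are placeholders, not values from this PR.

```python
# Rough sketch of the workflow described above; the venv path and the torch
# version are placeholders. Step 1 (applying this patch) is assumed done.
import subprocess

venv_python = "/path/to/build-venv/bin/python"  # hypothetical virtualenv

# 2. Install the desired torch plus the other build dependencies into the venv.
subprocess.run(
    [venv_python, "-m", "pip", "install",
     "torch==2.3.1", "setuptools", "packaging", "psutil", "ninja", "wheel"],
    check=True,
)

# 3. Build the wheel in that environment without build isolation, so the
#    pre-installed torch is the one setup.py sees.
subprocess.run(
    [venv_python, "-m", "pip", "wheel", "--no-build-isolation", "-w", "dist/", "."],
    check=True,
)
```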

@collinsleewyatt

Are there other projects which target different pytorch versions that have gotten around the build isolation limitation?

The problem with setup.py is that it attempts to import torch and wheel, before calling setup() with torch and wheel in install_requires. Not sure, but maybe it's possible to modify setup.py to use dynamic imports or call setup() before said imports? Any thoughts on the feasibility of the former solution? Happy to work on a patch for this.

@dhellmann (Author)

For our project, we are also building

  • deepspeed
  • torchvision
  • triton
  • vllm
  • vllm_flash_attn
  • xformers

All of these depend on torch for their builds. They all list either just the package or the package with a version range (like >=2.3.1) in their pyproject.toml. We take the same approach with them: we construct the virtualenv with the version of torch we are using before building the package.

None of these packages are perfect. We're carrying some patches for all of them to fix up the builds, in some cases similar to the patch I've proposed here and in other cases to remove a pinned version. For example, vllm pins torch in their requirements list.

@mgorny (Contributor) commented Jan 15, 2025

> Does this mean that under the pyproject.toml implementation, flash-attention would always pull the latest version of torch and ignore any currently installed version?

I think there may be some confusion here.

  1. If you call setup.py directly, it just uses the installed version, and the pull request doesn't change anything in that scenario.
  2. If you install it from source via pip install ..., it will normally pull the newest version to build the wheel. However, without the pull request, pip install ... doesn't work at all — it fails because pip does build isolation anyway, and tries to build it in a virtual environment without pytorch at all.
  3. If you install it from source via pip install --no-build-isolation ..., then the current version is used for the build — the pull request doesn't change anything in that scenario.
  4. When installing the built wheel (either from PyPI or after pip install), the current version of PyTorch is also used (provided it satisfies install_requires) — so again, nothing changes.

So unless I'm missing something, this pull request fixes scenario 2, which currently doesn't work at all, and doesn't cause any regressions in the remaining scenarios.

@rgommers commented Jan 15, 2025

> So unless I'm missing something, this pull request fixes scenario 2.

There are two flavors of that, with pip install . or pip install flash-attn being run in an environment that:

  1. Doesn't have torch installed yet
  2. Does have a version of torch installed already

(1) will indeed be fixed with this PR.
(2) will work if the already-installed torch version is the latest one. If it's older, then due to build isolation, flash-attn will be built against the latest torch and then find an older torch at runtime in the environment. Because this package builds PyTorch extensions with torch.utils.cpp_extension and PyTorch does not have a stable ABI, this will result in bad things (crashes probably, perhaps incorrect results). I suspect that @Skylion007's comment was hinting at the problem with the latter scenario, and that this is why the pyproject.toml that was present before was moved to flash_attn/pyproject.toml 1.5 years ago in commit 73bd3f3.

That ABI issue is indeed bad (crashes are worse than no build starting). The good news is that this is easily fixable: in setup.py, query the version of PyTorch used at build time, and then dynamically update install_requires to contain torch==version_at_build_time. That will fix the metadata in the wheel produced by pip install, and ensure a new torch version is pulled in.
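
A hypothetical sketch of that idea (not code from this PR): setup.py reads the torch version present at build time and pins exactly that version in install_requires.

```python
# Hypothetical sketch, not code from this PR: pin the runtime torch
# requirement to the exact version seen at build time, since PyTorch does not
# guarantee ABI stability across versions.
import torch
from setuptools import setup

build_time_torch = torch.__version__.split("+")[0]  # "2.3.1+cu121" -> "2.3.1"

setup(
    name="flash_attn",
    install_requires=[
        f"torch=={build_time_torch}",
        "einops",
    ],
    # ... the remaining arguments from the existing setup() call ...
)
```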

If that were added to this PR, I'd say build isolation is no longer a problem, that's simply a feature of how Python packaging works. And all four of the scenarios @mgorny outlined will work as they should.

Given the number of related issues and PRs on this repo, as well as incoming links, it looks like a fair number of people would benefit from having this PR merged. @tridao, with that change added, would you be willing to merge this PR? Or is there anything else you'd like to see or are worried about?

@weiji14 (Contributor) left a comment


> The good news is that this is easily fixable: in setup.py, query the version of PyTorch used at build time, and then dynamically update install_requires to contain torch==version_at_build_time. That will fix the metadata in the wheel produced by pip install, and ensure a new torch version is pulled in.
>
> If that were added to this PR, I'd say build isolation is no longer a problem

This exactly 👆🏻. We are effectively doing something similar for the conda-forge flash-attn package, which has a constraint to install the same (minor) version of pytorch at runtime as the one used at build time. E.g. flash-attn=2.7.2 can only be installed with pytorch>=2.5.1,<2.6.0a0. I'm not sure about PyTorch's ABI guarantees, but constraining to allow any patch version within a minor version seems reasonable, I suppose.
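
For illustration, a hypothetical helper (not code from this repo or the feedstock) that derives that kind of same-minor-series constraint from the build-time torch version:

```python
# Hypothetical helper, not code from this repo: turn the build-time torch
# version into a conda-forge-style "same minor series" requirement,
# e.g. "2.5.1+cu124" -> "torch>=2.5.1,<2.6.0a0".
import torch


def torch_minor_range(version: str = torch.__version__) -> str:
    release = version.split("+")[0]  # drop any local tag like "+cu124"
    major, minor = (int(p) for p in release.split(".")[:2])
    return f"torch>={release},<{major}.{minor + 1}.0a0"


print(torch_minor_range("2.5.1+cu124"))  # torch>=2.5.1,<2.6.0a0
```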

We've actually created a minimal pyproject.toml file for the conda-forge recipe (paired with a simplified setup.py) that has worked nicely since May 2024. It would be nice if this pyproject.toml file were just available upstream here.

@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools", "packaging", "psutil", "ninja", "torch", "wheel"]

We've been able to get away with just setuptools, torch and ninja at https://github.com/conda-forge/flash-attn-feedstock/blob/9e6cd984d2d3018ff92b3ad4cbb867641638ba39/recipe/pyproject.toml#L2 when building the flash-attn package for conda-forge.

Suggested change:
- requires = ["setuptools", "packaging", "psutil", "ninja", "torch", "wheel"]
+ requires = ["setuptools", "ninja", "torch"]

Do packaging, psutil and wheel need to be declared here? They have been declared in setup_requires since #937, FWIW.


packaging is used directly in setup.py, so it must be declared here.


Same with psutil and wheel. The wheel package is imported at the top level. psutil is imported via a local import unless the user sets the MAX_JOBS env var. The build-system requires list must include all non-stdlib packages that are needed by the build backend or by setup.py.

@rgommers

> I'm not sure about PyTorch's ABI guarantees, but constraining to allow any patch version within a minor version seems reasonable, I suppose.

Technically, PyTorch does not offer any ABI stability at all, even for micro/bugfix releases, so exactly matching versions are needed. In the past, bugfix releases have tended to be mostly compatible, but not always.
