Merge branch 'main' of https://github.com/nvidia/fuser into hopper_matmul_cta_k_fix
rdspring1 committed Jan 8, 2025
2 parents 77982fb + 9ce2112 commit dd4d385
Showing 68 changed files with 4,393 additions and 1,281 deletions.
22 changes: 15 additions & 7 deletions CMakeLists.txt
@@ -200,6 +200,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp
${NVFUSER_SRCS_DIR}/rng.cpp
${NVFUSER_SRCS_DIR}/runtime/allocations.cpp
${NVFUSER_SRCS_DIR}/runtime/executor.cpp
@@ -296,13 +297,18 @@ endif()
add_library(codegen_internal OBJECT ${NVFUSER_SRCS})

if(NOT MSVC)
# -Werror is not enabled because of the gcc 12.2 used in the manylinux image.
# Consider enabling it when we upgrade. Linked comment:
# https://github.com/NVIDIA/Fuser/pull/3001#discussion_r1772551266
target_compile_options(codegen_internal PRIVATE
-Wall -Wno-unused-function
# -Werror
)
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
target_compile_options(codegen_internal PRIVATE
-Wall -Wno-unused-function -Werror
# These warnings are not treated as errors because of the gcc 12.2 used in
# the manylinux image. Consider enabling them when we upgrade.
# Linked comment:
# https://github.com/NVIDIA/Fuser/pull/3001#discussion_r1772551266
-Wno-error=restrict -Wno-error=stringop-overflow)
else()
target_compile_options(codegen_internal PRIVATE
-Wall -Wno-unused-function -Werror)
endif()
endif()

target_compile_definitions(codegen_internal PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB")
@@ -442,6 +448,7 @@ if(BUILD_PYTHON)
list(APPEND NVFUSER_PYTHON_SRCS
${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp
${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp
${NVFUSER_SRCS_DIR}/python_frontend/schedule_bindings.cpp
)

add_library(nvf_py_internal OBJECT ${NVFUSER_PYTHON_SRCS})
@@ -575,6 +582,7 @@ list(APPEND JIT_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_resharding.cpp
${NVFUSER_ROOT}/tests/cpp/test_resize.cpp
${NVFUSER_ROOT}/tests/cpp/test_reduction_pointwise.cpp
${NVFUSER_ROOT}/tests/cpp/test_rope.cpp
${NVFUSER_ROOT}/tests/cpp/test_scalar_hoisting.cpp
${NVFUSER_ROOT}/tests/cpp/test_scatter_gather.cpp
${NVFUSER_ROOT}/tests/cpp/test_sdpa_node.cpp
71 changes: 71 additions & 0 deletions csrc/bfs.h
@@ -549,6 +549,77 @@ class BFS {
Direction allowed_direction_ = Direction::Undefined;
};

// Unlike the default BFS behavior, an Expr is considered ready to visit as
// soon as any one of its inputs or outputs has its dependencies met
template <
typename ExprT,
typename ValT,
typename DefinitionT,
typename UsesT,
typename InputsT,
typename OutputsT>
class BFSWithPermissiveDependence
: public BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT> {
public:
using NodeType =
typename BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>::
NodeType;

BFSWithPermissiveDependence(
DefinitionT definition,
UsesT uses,
InputsT inputs,
OutputsT outputs,
std::vector<NodeType> from,
std::vector<NodeType> to,
bool require_all_to_visited = true,
Direction allowed_direction = Direction::Undefined)
: BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>(
definition,
uses,
inputs,
outputs,
std::move(from),
std::move(to),
require_all_to_visited,
allowed_direction) {}

std::optional<std::pair<Direction, std::vector<NodeType>>> isReady(
const ExprT& expr) const override {
// Either any inputs or any outputs must have been visited
decltype(auto) inputs = this->inputs_(expr);
if (!inputs.empty() && this->allowed_direction_ != Direction::Backward &&
std::any_of(
inputs.begin(), inputs.end(), [&](const ValT& input) -> bool {
return this->isDependencySatisfied(input);
})) {
std::vector<NodeType> prev_nodes;
std::copy_if(
inputs.begin(),
inputs.end(),
std::back_inserter(prev_nodes),
[&](const ValT& input) -> bool { return this->isVisited(input); });
return std::make_pair(Direction::Forward, prev_nodes);
}

decltype(auto) outputs = this->outputs_(expr);
if (!outputs.empty() && this->allowed_direction_ != Direction::Forward &&
std::any_of(
outputs.begin(), outputs.end(), [&](const ValT& output) -> bool {
return this->isDependencySatisfied(output);
})) {
std::vector<NodeType> prev_nodes;
std::copy_if(
outputs.begin(),
outputs.end(),
std::back_inserter(prev_nodes),
[&](const ValT& output) -> bool { return this->isVisited(output); });
return std::make_pair(Direction::Backward, prev_nodes);
}
return std::nullopt;
}
};
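
To make the relaxed rule concrete, here is a minimal, self-contained sketch (plain C++, not nvFuser's actual API): the default BFS treats an expression as ready only once all of its input dependencies are met, while the permissive variant fires as soon as any one is.

```cpp
// Minimal sketch contrasting the two readiness rules.
#include <algorithm>
#include <cassert>
#include <vector>

// Default BFS: an Expr is ready only when every input dependency is met.
bool readyStrict(const std::vector<bool>& deps_met) {
  return std::all_of(deps_met.begin(), deps_met.end(), [](bool d) { return d; });
}

// Permissive BFS: one satisfied input (or output) is enough.
bool readyPermissive(const std::vector<bool>& deps_met) {
  return std::any_of(deps_met.begin(), deps_met.end(), [](bool d) { return d; });
}

int main() {
  std::vector<bool> deps = {true, false};  // only one of two inputs visited
  assert(!readyStrict(deps));              // default BFS would keep waiting
  assert(readyPermissive(deps));           // permissive BFS proceeds
  return 0;
}
```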

// Find the shortest path from the from vals to the to
// vals. Dependency between vals and exprs must be satisfied.
// It is an error if no valid path is found unless
25 changes: 15 additions & 10 deletions csrc/codegen.cpp
@@ -3028,17 +3028,22 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
} else {
step_code << gen_index << " += " << gen_step;
}
if (loop->isUnrolled()) {
indent() << "#pragma unroll\n";
} else if (
loop->circularBufferLoopStage() == CircularBufferLoopStage::Epilog) {
indent() << "#pragma unroll " << loop->circularBufferLoopStageDepth() - 1
<< "\n";
} else if (
loop->circularBufferLoopStage() !=
if (loop->circularBufferLoopStage() !=
CircularBufferLoopStage::NotApplicable) {
indent() << "#pragma unroll " << loop->circularBufferLoopStageDepth()
<< "\n";
// NOTE: requireUnroll is sometimes called on circular-buffered matmul
// loops when static shapes are used. To avoid hinting that the compiler
// should maximally unroll such loops leading to very long compiles, we
// handle that case explicitly here and ignore loop->isUnrolled().
//
// Unroll "prefetch" many circular buffered loops regardless of buffer
// stage (prologue, main, or epilogue)
int64_t prefetch = kernel_->summary()
.circular_buffer_info
.getCircularBufferOptionsFor(loop->iter_domain())
.prefetch;
indent() << "#pragma unroll " << prefetch << "\n";
} else if (loop->isUnrolled()) {
indent() << "#pragma unroll\n";
} else {
indent() << "#pragma unroll 1\n";
}
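
Distilled, the new logic is a three-way priority: circular buffering first, then explicit unrolling, then a pinned unroll of 1. The helper below is a hypothetical restatement of that choice, not the generator's actual code.

```cpp
// Hypothetical distillation of the pragma selection above: circular-buffered
// loops are unrolled by their prefetch distance, explicitly unrolled loops
// get an unbounded unroll hint, and all other loops are pinned to 1.
#include <cstdint>
#include <iostream>
#include <string>

std::string unrollPragma(bool circular_buffered, bool unrolled, int64_t prefetch) {
  if (circular_buffered) {
    // Circular buffering wins even if the loop is also marked unrolled,
    // avoiding maximal unrolling (and long compiles) on static-shape matmuls.
    return "#pragma unroll " + std::to_string(prefetch);
  }
  if (unrolled) {
    return "#pragma unroll";
  }
  return "#pragma unroll 1";
}

int main() {
  std::cout << unrollPragma(true, true, 2) << "\n";   // #pragma unroll 2
  std::cout << unrollPragma(false, true, 0) << "\n";  // #pragma unroll
  std::cout << unrollPragma(false, false, 0) << "\n"; // #pragma unroll 1
  return 0;
}
```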
6 changes: 4 additions & 2 deletions csrc/device_lower/analysis/circular_buffer.cpp
@@ -232,9 +232,11 @@ IterDomain* CircularBufferInfo::getCircularBufferAxis(

const CircularBufferOptions& CircularBufferInfo::getCircularBufferOptionsFor(
IterDomain* circular_buffer_axis) const {
auto concrete_id = lower_utils::getConcreteLoopID(circular_buffer_axis);
if (GpuLower::hasCurrent()) {
circular_buffer_axis = lower_utils::getConcreteLoopID(circular_buffer_axis);
}

auto maybe_depth_it = circular_buffer_options_.find(concrete_id);
auto maybe_depth_it = circular_buffer_options_.find(circular_buffer_axis);

NVF_ERROR(
maybe_depth_it != circular_buffer_options_.end(),
15 changes: 15 additions & 0 deletions csrc/device_lower/analysis/device_version.cpp
@@ -5,6 +5,8 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <cuda.h>

#include <device_lower/analysis/device_version.h>
#include <device_lower/lower2device.h>
#include <mma_type.h>
@@ -19,9 +21,22 @@ void MinimumDeviceVersion::dispatch(Val* val) {
}
if (val->dtype() == DataType::Float8_e4m3fn ||
val->dtype() == DataType::Float8_e5m2) {
// See release note
// https://docs.nvidia.com/cuda/archive/12.1.0/parallel-thread-execution/index.html#ptx-isa-version-8-1
#if (CUDA_VERSION >= 12010)
ensureVersion(
{8, 9},
"Fusion contains Float8_xxx values which was introduced in Ada (8.9)");
// See release note
// https://docs.nvidia.com/cuda/archive/11.8.0/parallel-thread-execution/index.html#ptx-isa-version-7-8
#elif (CUDA_VERSION >= 11080)
ensureVersion(
{9, 0},
"Fusion contains Float8_xxx values which was introduced in Hopper (9.0)");
#else
NVF_ERROR(
"Fusion contains Float8_xxx values which was not supported in given CUDA version");
#endif // (CUDA_VERSION >= 12010)
}
IterVisitor::dispatch(val);
}
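
The same gate can be read as a compile-time table mapping toolkit version to the minimum compute capability for Float8. A standalone sketch mirroring the branches above (illustrative only):

```cpp
// Standalone sketch of the gate above: the minimum compute capability for
// Float8 depends on which PTX ISA the CUDA toolkit ships (CUDA 12.1 / PTX
// ISA 8.1 exposes it on Ada 8.9; CUDA 11.8 / PTX ISA 7.8 only on Hopper 9.0).
#include <cuda.h>
#include <stdexcept>
#include <utility>

std::pair<int, int> minArchForFloat8() {
#if (CUDA_VERSION >= 12010)
  return {8, 9};  // Ada and newer
#elif (CUDA_VERSION >= 11080)
  return {9, 0};  // Hopper only
#else
  throw std::runtime_error("Float8 requires CUDA 11.8 or newer");
#endif
}
```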
8 changes: 4 additions & 4 deletions csrc/device_lower/pass/index.cpp
@@ -1686,8 +1686,8 @@ Val* hardCodedIndexGenerationForStMatrix(
Val* out_index = nullptr;

NVF_ERROR(
ldst->out()->dtype() == DataType::Half,
"we only support half type in stmatrix");
dataTypeSize(ldst->out()->dtype()) == 2,
"we only support 16-bit types in stmatrix");

NVF_ERROR(ldst->out()->isA<TensorView>());
TensorView* out_tv = ldst->out()->as<TensorView>();
@@ -1959,8 +1959,8 @@ Val* hardCodedIndexGenerationForStMatrixSwizzle(
"size not currently supported for stmatrix");

NVF_ERROR(
ldst->out()->dtype() == DataType::Half,
"we only support half type in stmatrix");
dataTypeSize(ldst->out()->dtype()) == 2,
"we only support 16-bit types in stmatrix");

NVF_ERROR(ldst->out()->isA<TensorView>());
TensorView* out_tv = ldst->out()->as<TensorView>();
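
The relaxed check matters because stmatrix only moves 16-bit elements; it does not care how they are interpreted. A toy illustration with a hypothetical dataTypeSize stand-in (not nvFuser's implementation) shows why BFloat16 now passes where the old dtype == Half test rejected it:

```cpp
// Toy stand-in for nvFuser's dataTypeSize(): stmatrix only cares that the
// element is 2 bytes wide, so both Half and BFloat16 qualify.
#include <cassert>
#include <cstddef>

enum class DataType { Half, BFloat16, Float };

constexpr std::size_t dataTypeSize(DataType dt) {
  switch (dt) {
    case DataType::Half:
    case DataType::BFloat16:
      return 2;
    case DataType::Float:
      return 4;
  }
  return 0;
}

int main() {
  assert(dataTypeSize(DataType::Half) == 2);      // accepted before and after
  assert(dataTypeSize(DataType::BFloat16) == 2);  // rejected by the old check
  assert(dataTypeSize(DataType::Float) != 2);     // still rejected
  return 0;
}
```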
1 change: 1 addition & 0 deletions csrc/dispatch.h
@@ -148,6 +148,7 @@ class Val;
f(HostUnit); \
f(PostOnStream); \
f(SetCurrentStream); \
f(GetCurrentStream); \
f(Wait); \
f(Synchronize); \
f(StartCoalescing); \
24 changes: 22 additions & 2 deletions csrc/host_ir/executor.cpp
@@ -188,7 +188,8 @@ HostIrEvaluator::HostIrEvaluator(
HostIrEvaluatorParams params)
: container_(std::move(container)),
communicator_(communicator),
params_(params) {
params_(params),
my_device_index_(communicator_ ? communicator_->deviceId() : 0) {
const DeviceIdxType device_index =
(communicator_ != nullptr && communicator_->is_available())
? communicator_->deviceId()
@@ -273,8 +274,27 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) {
setCurrentCUDAStream(getCUDAStream(set_current_stream->stream()));
}

void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) {
streams_.insert(
{get_current_stream->stream(),
c10::cuda::getCurrentCUDAStream(
static_cast<c10::DeviceIndex>(my_device_index_))});
}

void HostIrEvaluator::handle(Synchronize* synchronize) {
getCUDAStream(synchronize->stream()).synchronize();
cudaStream_t current_stream =
c10::cuda::getCurrentCUDAStream(
static_cast<c10::DeviceIndex>(my_device_index_))
.stream();
cudaStream_t stream_to_sync = getCUDAStream(synchronize->stream()).stream();

cudaEvent_t event = {};
NVFUSER_CUDA_RT_SAFE_CALL(
cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(event, stream_to_sync));
NVFUSER_CUDA_RT_SAFE_CALL(
cudaStreamWaitEvent(current_stream, event, cudaEventWaitDefault));
NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(event));
}
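
Rather than blocking the host with stream.synchronize(), the handler now inserts a device-side dependency: it records an event on the stream to be synchronized and makes the current stream wait on it. The underlying CUDA runtime pattern, as a standalone sketch (error checks elided):

```cpp
// Standalone sketch of the pattern used above: make `waiter` wait for all
// work currently enqueued on `waitee`, without blocking the host thread.
#include <cuda_runtime.h>

void streamWaitStream(cudaStream_t waiter, cudaStream_t waitee) {
  cudaEvent_t event = nullptr;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, waitee);         // bookmark waitee's current tail
  cudaStreamWaitEvent(waiter, event, 0);  // waiter stalls on the GPU, not host
  cudaEventDestroy(event);                // pending wait keeps its own reference
}
```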

void HostIrEvaluator::handle(PostOnStream* post_ir) {
2 changes: 2 additions & 0 deletions csrc/host_ir/executor.h
@@ -112,6 +112,7 @@ class HostIrEvaluator final : public OptOutDispatch {
private:
using OptOutDispatch::handle;
void handle(SetCurrentStream* set_current_stream) override;
void handle(GetCurrentStream* get_current_stream) override;
void handle(Synchronize* synchronize) override;
void handle(PostOnStream* post_ir) override;
void handle(Communication* communication) override;
@@ -138,6 +139,7 @@ class HostIrEvaluator final : public OptOutDispatch {
using StreamKey = std::variant<int64_t, Stream*>;
std::unordered_map<StreamKey, c10::cuda::CUDAStream> streams_;
std::unordered_map<Expr*, c10::intrusive_ptr<c10d::Work>> works_;
const int64_t my_device_index_;
};

} // namespace hir
16 changes: 16 additions & 0 deletions csrc/host_ir/host_ir.cpp
@@ -179,6 +179,22 @@ bool SetCurrentStream::sameAs(const Statement* other) const {
return false;
}

GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
auto stream = IrBuilder::createInContainer<Stream>(passkey.ir_container_);
addAttribute(stream);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream)

std::string GetCurrentStream::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "GetCurrentStream into " << stream()->toString()
<< std::endl;
return ss.str();
}

Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
: Expr(passkey, {}, {}, {expr}) {
NVF_ERROR(passkey.ir_container_ != nullptr);
22 changes: 22 additions & 0 deletions csrc/host_ir/host_ir.h
@@ -161,6 +161,28 @@ class SetCurrentStream : public Expr {
}
};

class GetCurrentStream : public Expr {
public:
using Expr::Expr;
GetCurrentStream(IrBuilderPasskey passkey);

GetCurrentStream(const GetCurrentStream& other) = delete;
GetCurrentStream& operator=(const GetCurrentStream& other) = delete;
GetCurrentStream(GetCurrentStream&& other) = delete;
GetCurrentStream& operator=(GetCurrentStream&& other) = delete;

NVFUSER_DECLARE_CLONE_AND_CREATE

std::string toString(int indent_size = 0) const override;
const char* getOpString() const override {
return "hir::GetCurrentStream";
}

Stream* stream() const {
return attributes_.at(0)->as<Stream>();
}
};
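
A typical use for capturing the ambient stream is restoring it after temporarily switching to another one. The snippet below is a plain-C++ analogue of that capture/switch/restore idiom, with an int standing in for the thread-local CUDA stream; it uses no nvFuser types and is purely illustrative.

```cpp
// Plain-C++ analogue of the idiom GetCurrentStream enables: capture the
// ambient "current stream", switch to a worker, then restore the original.
#include <cassert>

struct Context {
  int current_stream = 0;  // stands in for the thread-local CUDA stream
};

int main() {
  Context ctx;
  int captured = ctx.current_stream;  // GetCurrentStream
  ctx.current_stream = 7;             // SetCurrentStream(worker)
  // ... post work on the worker stream ...
  ctx.current_stream = captured;      // SetCurrentStream(captured)
  assert(ctx.current_stream == 0);
  return 0;
}
```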

class Wait : public Expr {
public:
using Expr::Expr;