Skip to content

Commit 7a9c595

Browse files
committed
add time event support
1 parent 0d2d924 commit 7a9c595

File tree

2 files changed

+115
-47
lines changed

2 files changed

+115
-47
lines changed

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 91 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -252,25 +252,41 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(
252252
uint64_t seq,
253253
bool isP2P,
254254
const char* profilingTitle,
255-
const std::optional<std::vector<at::Tensor>>& inputs)
255+
const std::optional<std::vector<at::Tensor>>& inputs,
256+
bool enableTiming,
257+
bool xpuEventCacheEnabled)
256258
: Work(rank, opType, profilingTitle, inputs),
257259
device_(device),
258260
workStartTime_(std::chrono::steady_clock::now()),
259261
seq_(seq),
260-
isP2P_(isP2P) {
261-
xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>();
262+
isP2P_(isP2P),
263+
timingEnabled_(enableTiming) {
264+
if (xpuEventCacheEnabled) {
265+
xcclStartEvent_ = enableTiming
266+
? XPUEventCache::get(device.index())->create(enableTiming)
267+
: nullptr;
268+
xcclEndEvent_ = XPUEventCache::get(device.index())->create(enableTiming);
269+
} else {
270+
xcclStartEvent_ = enableTiming
271+
? std::make_shared<at::xpu::XPUEvent>(xpuEventDefault)
272+
: nullptr;
273+
xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>(
274+
enableTiming ? xpuEventDefault : xpuEventDisableTiming);
275+
}
262276
stashed_for_allocator_safety_ = std::make_shared<TensorShelf>();
263277
}
264278

265279
ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
266280
: Work(w.rank_, w.opType_),
267281
device_(w.device_),
282+
xcclStartEvent_(w.xcclStartEvent_),
268283
xcclEndEvent_(w.xcclEndEvent_),
269284
blockingWait_(w.blockingWait_),
270285
workStartTime_(w.workStartTime_),
271286
seq_(w.seq_),
272287
isP2P_(w.isP2P_),
273-
stashed_for_allocator_safety_(w.stashed_for_allocator_safety_) {}
288+
stashed_for_allocator_safety_(w.stashed_for_allocator_safety_),
289+
timingEnabled_(w.timingEnabled_) {}
274290

275291
ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default;
276292

@@ -486,7 +502,7 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
486502
profilingTitle ? profilingTitle : "",
487503
inputs,
488504
outputs,
489-
nullptr,
505+
r->xcclStartEvent_.get(),
490506
r->xcclEndEvent_.get(),
491507
options_->timeout,
492508
pgStatus_,
@@ -495,6 +511,17 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
495511
return r;
496512
}
497513

514+
float ProcessGroupXCCL::WorkXCCL::getDuration() const {
515+
TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled");
516+
TORCH_CHECK(
517+
xcclStartEvent_,
518+
"getDuration only works if xcclStartEvents_ is populated, true if timing enabled");
519+
TORCH_CHECK(
520+
xcclEndEvent_,
521+
"getDuration only works if xcclEndEvents_ is populated, which should always be true");
522+
return xcclStartEvent_->elapsed_time(*xcclEndEvent_);
523+
}
524+
498525
std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
499526
const std::string& deviceKey,
500527
at::Device& device,
@@ -643,6 +670,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
643670

644671
work->stashed_for_allocator_safety_->stash(coalescedTensors_);
645672

673+
if (work->timingEnabled_) {
674+
work->xcclStartEvent_->record(stream);
675+
}
676+
646677
groupEnd();
647678

648679
work->xcclEndEvent_->record(stream);
@@ -773,6 +804,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
773804

774805
pre(stream, work);
775806

807+
if (work->timingEnabled_ && !coalescing_state_) {
808+
work->xcclStartEvent_->record(stream);
809+
}
810+
776811
for (const auto i : c10::irange(inputs.size())) {
777812
fn(inputs[i], outputs[i], *comm, stream, *cclstream);
778813
}
@@ -810,12 +845,14 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
810845
return asyncOp ? work : nullptr;
811846
}
812847

813-
template <typename Fn>
848+
template <typename Fn, typename PreProcess, typename PostProcess>
814849
c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
815850
at::Tensor& tensor,
816851
Fn fn,
817852
int peer,
818853
OpType opType,
854+
PreProcess pre,
855+
PostProcess post,
819856
const char* profilingTitle) {
820857
auto device = tensor.device();
821858
std::string key;
@@ -865,13 +902,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
865902
auto cclstream = xcclStreamsMap_.at(key).second;
866903
syncStream(device, xcclEventsMap_[key], stream);
867904

868-
if (enableNanCheck_ && opType == OpType::SEND) {
869-
checkForNan(tensor, stream);
870-
}
871-
905+
c10::intrusive_ptr<ProcessGroupNCCL::WorkXCCL> work;
872906
if (!coalescing_state_) {
873-
auto work =
874-
initWork(device, rank_, opType, true, profilingTitle, {tensor}, {});
907+
work = initWork(device, rank_, opType, true, profilingTitle, {tensor}, {});
875908
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
876909
work->outputs_->push_back(tensor);
877910

@@ -884,37 +917,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
884917
profilingTitle,
885918
{tensor},
886919
{tensor},
887-
nullptr,
920+
work->xcclStartEvent_.get(),
888921
work->xcclEndEvent_.get(),
889922
options_->timeout,
890923
pgStatus_,
891924
true);
892925

893-
c10::OptionalDeviceGuard gpuGuard(device);
894-
895-
c10::xpu::XPUCachingAllocator::recordStream(
896-
tensor.storage().data_ptr(), stream);
897-
898-
fn(tensor, *comm, stream, cclstream, p2pTargetRank);
899-
900-
work->xcclEndEvent_->record(stream);
901-
work->blockingWait_ = blockingWait_;
902-
std::vector<c10::Stream> streams = {stream.unwrap()};
903-
c10::MultiStreamGuard streamGuard(streams);
904-
std::vector<at::Device> devices{device};
905-
work->future_ = c10::make_intrusive<at::ivalue::Future>(
906-
c10::ListType::create(c10::TensorType::get()), devices);
907-
work->future_->markCompleted(at::IValue(*work->outputs_));
908-
auto id = work->trace_id_;
909-
work->future_->addCallback(
910-
[id](at::ivalue::Future&) {
911-
FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
912-
},
913-
/*use_future*/ false);
914-
915-
work->numelIn_ = work->numelOut_ = tensor.numel();
916-
setEnqueuedPgStatus(work);
917-
return work;
918926
} else {
919927
FlightRecorderXCCL::get()->record(
920928
local_id_,
@@ -930,15 +938,53 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
930938
options_->timeout,
931939
pgStatus_,
932940
true);
933-
c10::OptionalDeviceGuard gpuGuard(device);
941+
}
934942

935-
c10::xpu::XPUCachingAllocator::recordStream(
936-
tensor.storage().data_ptr(), stream);
943+
if (enableNanCheck_ && opType == OpType::SEND) {
944+
checkForNan(tensor, stream);
945+
}
946+
if (!coalescing_state_) {
947+
// Start event should only be recorded before the ncclGroupStart()
948+
if (work->timingEnabled_) {
949+
work->ncclStartEvent_->record(stream);
950+
}
937951

938-
fn(tensor, *comm, stream, cclstream, p2pTargetRank);
952+
pre(stream, work);
953+
}
954+
c10::OptionalDeviceGuard gpuGuard(device);
939955

940-
return nullptr;
956+
c10::xpu::XPUCachingAllocator::recordStream(
957+
tensor.storage().data_ptr(), stream);
958+
959+
xcclGroupStart();
960+
fn(tensor, *comm, stream, cclstream, p2pTargetRank);
961+
xcclGroupEnd();
962+
963+
if (!coalescing_state_) {
964+
post(stream);
965+
966+
work->xcclEndEvent_->record(stream);
967+
work->blockingWait_ = blockingWait_;
968+
work->numelIn_ = work->numelOut_ = tensor.numel();
969+
{
970+
std::vector<c10::Stream> streams = {stream.unwrap()};
971+
c10::MultiStreamGuard streamGuard(streams);
972+
std::vector<at::Device> devices{device};
973+
work->future_ = c10::make_intrusive<at::ivalue::Future>(
974+
c10::ListType::create(c10::TensorType::get()), devices);
975+
work->future_->markCompleted(at::IValue(*work->outputs_));
976+
}
977+
978+
auto id = work->trace_id_;
979+
work->future_->addCallback(
980+
[id](at::ivalue::Future&) {
981+
FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
982+
},
983+
/*use_future*/ false);
984+
setEnqueuedPgStatus(work);
941985
}
986+
987+
return work;
942988
}
943989

944990
c10::intrusive_ptr<Work> ProcessGroupXCCL::send(
@@ -2043,8 +2089,8 @@ c10::DeviceIndex ProcessGroupXCCL::guessDeviceId() const {
20432089
} else if (!usedDeviceIdxs_.empty()) {
20442090
return *usedDeviceIdxs_.begin();
20452091
}
2046-
int devIdx =
2047-
static_cast<int16_t>(globalRank() % at::detail::getXPUHooks().getNumGPUs());
2092+
int devIdx = static_cast<int16_t>(
2093+
globalRank() % at::detail::getXPUHooks().getNumGPUs());
20482094
LOG(WARNING)
20492095
<< logPrefix()
20502096
<< c10::str(

src/xccl/ProcessGroupXCCL.hpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
103103

104104
protected:
105105
at::Device device_;
106+
std::shared_ptr<at::xpu::XPUEvent> xcclStartEvent_;
106107
std::shared_ptr<at::xpu::XPUEvent> xcclEndEvent_;
107108
bool isBarrierOp_{false};
108109
bool blockingWait_{false};
@@ -117,6 +118,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
117118
std::shared_ptr<std::vector<at::Tensor>> outputs_;
118119
std::shared_ptr<TensorShelf> stashed_for_allocator_safety_;
119120
c10::intrusive_ptr<at::ivalue::Future> future_;
121+
bool timingEnabled_;
120122
friend class ProcessGroupXCCL;
121123
};
122124

@@ -306,13 +308,33 @@ class TORCH_API ProcessGroupXCCL : public Backend {
306308
/*nanCheck =*/false);
307309
}
308310

309-
template <typename Fn>
311+
template <typename Fn>
312+
c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
313+
at::Tensor& tensor,
314+
Fn fn,
315+
int peer,
316+
OpType opType,
317+
const char* profilingTitle) {
318+
return pointToPoint(
319+
tensor,
320+
fn,
321+
peer,
322+
opType,
323+
[](at::xpu::XPUStream&,
324+
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
325+
[](at::xpu::XPUStream&) {},
326+
profilingTitle);
327+
}
328+
329+
template <typename Fn, typename PreProcess, typename PostProcess>
310330
c10::intrusive_ptr<Work> pointToPoint(
311331
at::Tensor& tensor,
312332
Fn fn,
313333
int peer,
314334
OpType opType,
315-
const char* profilingTitle = nullptr);
335+
PreProcess pre,
336+
PostProcess post,
337+
const char* profilingTitle);
316338

317339
c10::intrusive_ptr<Work> allreduce_impl(
318340
at::Tensor& tensor,

0 commit comments

Comments
 (0)