@@ -252,25 +252,41 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(
252252 uint64_t seq,
253253 bool isP2P,
254254 const char * profilingTitle,
255- const std::optional<std::vector<at::Tensor>>& inputs)
255+ const std::optional<std::vector<at::Tensor>>& inputs,
256+ bool enableTiming,
257+ bool xpuEventCacheEnabled)
256258 : Work(rank, opType, profilingTitle, inputs),
257259 device_ (device),
258260 workStartTime_(std::chrono::steady_clock::now()),
259261 seq_(seq),
260- isP2P_(isP2P) {
261- xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>();
262+ isP2P_(isP2P),
263+ timingEnabled_(enableTiming) {
264+ if (xpuEventCacheEnabled) {
265+ xcclStartEvent_ = enableTiming
266+ ? XPUEventCache::get (device.index ())->create (enableTiming)
267+ : nullptr ;
268+ xcclEndEvent_ = XPUEventCache::get (device.index ())->create (enableTiming);
269+ } else {
270+ xcclStartEvent_ = enableTiming
271+ ? std::make_shared<at::xpu::XPUEvent>(xpuEventDefault)
272+ : nullptr ;
273+ xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>(
274+ enableTiming ? xpuEventDefault : xpuEventDisableTiming);
275+ }
262276 stashed_for_allocator_safety_ = std::make_shared<TensorShelf>();
263277}
264278
265279ProcessGroupXCCL::WorkXCCL::WorkXCCL (const WorkXCCL& w)
266280 : Work(w.rank_, w.opType_),
267281 device_(w.device_),
282+ xcclStartEvent_(w.xcclStartEvent_),
268283 xcclEndEvent_(w.xcclEndEvent_),
269284 blockingWait_(w.blockingWait_),
270285 workStartTime_(w.workStartTime_),
271286 seq_(w.seq_),
272287 isP2P_(w.isP2P_),
273- stashed_for_allocator_safety_(w.stashed_for_allocator_safety_) {}
288+ stashed_for_allocator_safety_(w.stashed_for_allocator_safety_),
289+ timingEnabled_(w.timingEnabled_) {}
274290
275291ProcessGroupXCCL::WorkXCCL::~WorkXCCL () = default ;
276292
@@ -486,7 +502,7 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
486502 profilingTitle ? profilingTitle : " " ,
487503 inputs,
488504 outputs,
489- nullptr ,
505+ r-> xcclStartEvent_ . get () ,
490506 r->xcclEndEvent_ .get (),
491507 options_->timeout ,
492508 pgStatus_,
@@ -495,6 +511,17 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
495511 return r;
496512}
497513
514+ float ProcessGroupXCCL::WorkXCCL::getDuration () const {
515+ TORCH_CHECK (timingEnabled_, " getDuration only works if timing was enabled" );
516+ TORCH_CHECK (
517+ xcclStartEvent_,
518+ " getDuration only works if xcclStartEvent_ is populated, true if timing enabled" );
519+ TORCH_CHECK (
520+ xcclEndEvent_,
521+ " getDuration only works if xcclEndEvent_ is populated, which should always be true" );
522+ return xcclStartEvent_->elapsed_time (*xcclEndEvent_);
523+ }
524+
498525std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm (
499526 const std::string& deviceKey,
500527 at::Device& device,
@@ -643,6 +670,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
643670
644671 work->stashed_for_allocator_safety_ ->stash (coalescedTensors_);
645672
673+ if (work->timingEnabled_ ) {
674+ work->xcclStartEvent_ ->record (stream);
675+ }
676+
646677 groupEnd ();
647678
648679 work->xcclEndEvent_ ->record (stream);
@@ -773,6 +804,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
773804
774805 pre (stream, work);
775806
807+ if (work->timingEnabled_ && !coalescing_state_) {
808+ work->xcclStartEvent_ ->record (stream);
809+ }
810+
776811 for (const auto i : c10::irange (inputs.size ())) {
777812 fn (inputs[i], outputs[i], *comm, stream, *cclstream);
778813 }
@@ -810,12 +845,14 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
810845 return asyncOp ? work : nullptr ;
811846}
812847
813- template <typename Fn>
848+ template <typename Fn, typename PreProcess, typename PostProcess >
814849c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint (
815850 at::Tensor& tensor,
816851 Fn fn,
817852 int peer,
818853 OpType opType,
854+ PreProcess pre ,
855+ PostProcess post ,
819856 const char * profilingTitle) {
820857 auto device = tensor.device ();
821858 std::string key;
@@ -865,13 +902,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
865902 auto cclstream = xcclStreamsMap_.at (key).second ;
866903 syncStream (device, xcclEventsMap_[key], stream);
867904
868- if (enableNanCheck_ && opType == OpType::SEND) {
869- checkForNan (tensor, stream);
870- }
871-
905+ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
872906 if (!coalescing_state_) {
873- auto work =
874- initWork (device, rank_, opType, true , profilingTitle, {tensor}, {});
907+ work = initWork (device, rank_, opType, true , profilingTitle, {tensor}, {});
875908 work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
876909 work->outputs_ ->push_back (tensor);
877910
@@ -884,37 +917,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
884917 profilingTitle,
885918 {tensor},
886919 {tensor},
887- nullptr ,
920+ work-> xcclStartEvent_ . get () ,
888921 work->xcclEndEvent_ .get (),
889922 options_->timeout ,
890923 pgStatus_,
891924 true );
892925
893- c10::OptionalDeviceGuard gpuGuard (device);
894-
895- c10::xpu::XPUCachingAllocator::recordStream (
896- tensor.storage ().data_ptr (), stream);
897-
898- fn (tensor, *comm, stream, cclstream, p2pTargetRank);
899-
900- work->xcclEndEvent_ ->record (stream);
901- work->blockingWait_ = blockingWait_;
902- std::vector<c10::Stream> streams = {stream.unwrap ()};
903- c10::MultiStreamGuard streamGuard (streams);
904- std::vector<at::Device> devices{device};
905- work->future_ = c10::make_intrusive<at::ivalue::Future>(
906- c10::ListType::create (c10::TensorType::get ()), devices);
907- work->future_ ->markCompleted (at::IValue (*work->outputs_ ));
908- auto id = work->trace_id_ ;
909- work->future_ ->addCallback (
910- [id](at::ivalue::Future&) {
911- FlightRecorderXCCL::get ()->retire_id (id, /* compute_duration*/ false );
912- },
913- /* use_future*/ false );
914-
915- work->numelIn_ = work->numelOut_ = tensor.numel ();
916- setEnqueuedPgStatus (work);
917- return work;
918926 } else {
919927 FlightRecorderXCCL::get ()->record (
920928 local_id_,
@@ -930,15 +938,53 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
930938 options_->timeout ,
931939 pgStatus_,
932940 true );
933- c10::OptionalDeviceGuard gpuGuard (device);
941+ }
934942
935- c10::xpu::XPUCachingAllocator::recordStream (
936- tensor.storage ().data_ptr (), stream);
943+ if (enableNanCheck_ && opType == OpType::SEND) {
944+ checkForNan (tensor, stream);
945+ }
946+ if (!coalescing_state_) {
947+ // Start event should only be recorded before the xcclGroupStart()
948+ if (work->timingEnabled_ ) {
949+ work->xcclStartEvent_ ->record (stream);
950+ }
937951
938- fn (tensor, *comm, stream, cclstream, p2pTargetRank);
952+ pre (stream, work);
953+ }
954+ c10::OptionalDeviceGuard gpuGuard (device);
939955
940- return nullptr ;
956+ c10::xpu::XPUCachingAllocator::recordStream (
957+ tensor.storage ().data_ptr (), stream);
958+
959+ xcclGroupStart ();
960+ fn (tensor, *comm, stream, cclstream, p2pTargetRank);
961+ xcclGroupEnd ();
962+
963+ if (!coalescing_state_) {
964+ post (stream);
965+
966+ work->xcclEndEvent_ ->record (stream);
967+ work->blockingWait_ = blockingWait_;
968+ work->numelIn_ = work->numelOut_ = tensor.numel ();
969+ {
970+ std::vector<c10::Stream> streams = {stream.unwrap ()};
971+ c10::MultiStreamGuard streamGuard (streams);
972+ std::vector<at::Device> devices{device};
973+ work->future_ = c10::make_intrusive<at::ivalue::Future>(
974+ c10::ListType::create (c10::TensorType::get ()), devices);
975+ work->future_ ->markCompleted (at::IValue (*work->outputs_ ));
976+ }
977+
978+ auto id = work->trace_id_ ;
979+ work->future_ ->addCallback (
980+ [id](at::ivalue::Future&) {
981+ FlightRecorderXCCL::get ()->retire_id (id, /* compute_duration*/ false );
982+ },
983+ /* use_future*/ false );
984+ setEnqueuedPgStatus (work);
941985 }
986+
987+ return work;
942988}
943989
944990c10::intrusive_ptr<Work> ProcessGroupXCCL::send (
@@ -2043,8 +2089,8 @@ c10::DeviceIndex ProcessGroupXCCL::guessDeviceId() const {
20432089 } else if (!usedDeviceIdxs_.empty ()) {
20442090 return *usedDeviceIdxs_.begin ();
20452091 }
2046- int devIdx =
2047- static_cast < int16_t >( globalRank () % at::detail::getXPUHooks ().getNumGPUs ());
2092+ int devIdx = static_cast < int16_t >(
2093+ globalRank () % at::detail::getXPUHooks ().getNumGPUs ());
20482094 LOG (WARNING)
20492095 << logPrefix ()
20502096 << c10::str (
0 commit comments