diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 304a35bd2ee80..52f0d60458ef9 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -83,6 +83,7 @@ option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in l option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump node input shapes and output data to standard output when executing the model." OFF) option(onnxruntime_USE_DML "Build with DirectML support" OFF) option(onnxruntime_USE_ACL "Build with ACL support" OFF) +option(onnxruntime_ENABLE_INSTRUMENT "Enable Instrument with Event Tracing for Windows (ETW)" OFF) set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE) #nsync tests failed on Mac Build @@ -91,6 +92,15 @@ set(ONNX_ML 1) if(NOT onnxruntime_ENABLE_PYTHON) set(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS OFF) endif() + +if(NOT WIN32) + #TODO: On Linux we may try https://github.com/microsoft/TraceLogging + if(onnxruntime_ENABLE_INSTRUMENT) + message(WARNING "Instrument is only supported on Windows now") + set(onnxruntime_ENABLE_INSTRUMENT OFF) + endif() +endif() + if(onnxruntime_USE_OPENMP) find_package(OpenMP) if (OPENMP_FOUND) diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index aea834e8ff8e9..1993858ac5331 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -10,7 +10,9 @@ file(GLOB_RECURSE onnxruntime_framework_srcs CONFIGURE_DEPENDS source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_framework_srcs}) add_library(onnxruntime_framework ${onnxruntime_framework_srcs}) - +if(onnxruntime_ENABLE_INSTRUMENT) + target_compile_definitions(onnxruntime_framework PRIVATE ONNXRUNTIME_ENABLE_INSTRUMENT) +endif() target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) onnxruntime_add_include_to_target(onnxruntime_framework onnxruntime_common onnx onnx_proto protobuf::libprotobuf) set_target_properties(onnxruntime_framework PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index a9bb53cc0e481..180dcf61893a4 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -12,6 +12,9 @@ source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_session_srcs}) add_library(onnxruntime_session ${onnxruntime_session_srcs}) install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/session DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core) onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf) +if(onnxruntime_ENABLE_INSTRUMENT) + target_compile_definitions(onnxruntime_session PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT) +endif() target_include_directories(onnxruntime_session PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS}) add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 4766c43d872d0..05278e94dbd73 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -776,6 +776,17 @@ if (onnxruntime_BUILD_SERVER) endif() +#some ETW tools +if(WIN32 AND onnxruntime_ENABLE_INSTRUMENT) + add_executable(generate_perf_report_from_etl ${ONNXRUNTIME_ROOT}/tool/etw/main.cc ${ONNXRUNTIME_ROOT}/tool/etw/eparser.h ${ONNXRUNTIME_ROOT}/tool/etw/eparser.cc ${ONNXRUNTIME_ROOT}/tool/etw/TraceSession.h ${ONNXRUNTIME_ROOT}/tool/etw/TraceSession.cc) + target_compile_definitions(generate_perf_report_from_etl PRIVATE "_CONSOLE" "_UNICODE" "UNICODE") + target_link_libraries(generate_perf_report_from_etl PRIVATE tdh Advapi32) + + add_executable(compare_two_sessions ${ONNXRUNTIME_ROOT}/tool/etw/compare_two_sessions.cc ${ONNXRUNTIME_ROOT}/tool/etw/eparser.h ${ONNXRUNTIME_ROOT}/tool/etw/eparser.cc ${ONNXRUNTIME_ROOT}/tool/etw/TraceSession.h ${ONNXRUNTIME_ROOT}/tool/etw/TraceSession.cc) + target_compile_definitions(compare_two_sessions PRIVATE "_CONSOLE" "_UNICODE" "UNICODE") + target_link_libraries(compare_two_sessions PRIVATE ${GETOPT_LIB_WIDE} tdh Advapi32) +endif() + add_executable(onnxruntime_mlas_test ${TEST_SRC_DIR}/mlas/unittest.cpp) target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}) set(onnxruntime_mlas_test_libs onnxruntime_mlas onnxruntime_common) diff --git a/include/onnxruntime/core/platform/tracing.h b/include/onnxruntime/core/platform/tracing.h new file mode 100644 index 0000000000000..fb61632aae238 --- /dev/null +++ b/include/onnxruntime/core/platform/tracing.h @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +TRACELOGGING_DECLARE_PROVIDER(telemetry_provider_handle); diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index 9272b18e8db3b..ba69f45faff6f 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -25,6 +25,22 @@ using namespace Concurrency; #endif +#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT +#include +#include "core/platform/tracing.h" +namespace { +LARGE_INTEGER OrtGetPerformanceFrequency() { + LARGE_INTEGER v; + // On systems that run Windows XP or later, the QueryPerformanceFrequency function will always succeed + // and will thus never return zero. + (void)QueryPerformanceFrequency(&v); + return v; +} + +LARGE_INTEGER perf_freq = OrtGetPerformanceFrequency(); +} // namespace +#endif + namespace onnxruntime { static Status ReleaseNodeMLValues(ExecutionFrame& frame, @@ -87,7 +103,10 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: if (p_op_kernel == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Got nullptr from GetKernel for node: ", node.Name()); - +#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT + LARGE_INTEGER kernel_start; + QueryPerformanceCounter(&kernel_start); +#endif // construct OpKernelContext // TODO: log kernel inputs? OpKernelContextInternal op_kernel_context(session_state, frame, *p_op_kernel, logger, terminate_flag_); @@ -128,7 +147,6 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: } } } - #if defined DEBUG_NODE_INPUTS_OUTPUTS utils::DumpNodeInputs(op_kernel_context, p_op_kernel->Node()); #endif @@ -202,7 +220,19 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: } } } - +#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT + LARGE_INTEGER kernel_stop; + QueryPerformanceCounter(&kernel_stop); + LARGE_INTEGER elapsed; + elapsed.QuadPart = kernel_stop.QuadPart - kernel_start.QuadPart; + elapsed.QuadPart *= 1000000; + elapsed.QuadPart /= perf_freq.QuadPart; + // Log an event + TraceLoggingWrite(telemetry_provider_handle, // handle to my provider + "OpEnd", // Event Name that should uniquely identify your event. + TraceLoggingValue(p_op_kernel->KernelDef().OpName().c_str(), "op_name"), + TraceLoggingValue(elapsed.QuadPart, "time")); +#endif if (is_profiler_enabled) { session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT, p_op_kernel->Node().Name() + "_fence_after", diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index bafac8fba9f56..15fb8c32bf680 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -271,7 +271,10 @@ void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, const s ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), "Entry exists in node ", index, " for attribute ", attribute_name); } - +#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT + session_state->parent_ = this; + GenerateGraphId(); +#endif subgraph_session_states_[index].insert(std::make_pair(attribute_name, std::move(session_state))); } diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index a8c0a9833c7fc..bff72c1d09945 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -268,6 +268,19 @@ class SessionState { std::unique_ptr node_index_info_; std::multimap> cached_feeds_fetches_managers_; +#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT + SessionState* parent_ = nullptr; + //Assign each graph in each session an unique id. + int graph_id_ = 0; + int next_graph_id_ = 1; + + void GenerateGraphId() { + SessionState* p = this; + while (p->parent_ != nullptr) p = p->parent_; + graph_id_ = p->next_graph_id_ ++; + } + +#endif }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 6186236c0291f..da0281e643b29 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -19,6 +19,10 @@ #include "core/platform/env.h" +#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT +#include "core/platform/tracing.h" +#endif + namespace onnxruntime { using namespace ::onnxruntime::common; using namespace ONNX_NAMESPACE; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 233b176ce9d6c..b9307d4dfca48 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -73,7 +73,7 @@ inline const wchar_t* GetDateFormatString() { return L"%Y-%m-%d_%H-%M-%S"; } #endif -//TODO: use LoggingManager::GetTimestamp and date::operator<< +// TODO: use LoggingManager::GetTimestamp and date::operator<< // (see ostream_sink.cc for an example) // to simplify this and match the log file timestamp format. template @@ -115,7 +115,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, insert_cast_transformer_("CastFloat16Transformer") { ORT_ENFORCE(Environment::IsInitialized(), "Environment must be initialized before creating an InferenceSession."); - InitLogger(logging_manager); session_state_.SetDataTransferMgr(&data_transfer_mgr_); @@ -144,6 +143,9 @@ InferenceSession::~InferenceSession() { LOGS(*session_logger_, ERROR) << "Unknown error during EndProfiling()"; } } +#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT + if (session_activity_started_) TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity"); +#endif } common::Status InferenceSession::RegisterExecutionProvider(std::unique_ptr p_exec_provider) { @@ -176,8 +178,8 @@ common::Status InferenceSession::RegisterExecutionProvider(std::unique_ptr p_graph_transformer, - TransformerLevel level) { +common::Status InferenceSession::RegisterGraphTransformer( + std::unique_ptr p_graph_transformer, TransformerLevel level) { if (p_graph_transformer == nullptr) { return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for graph transformer"); } @@ -185,8 +187,7 @@ common::Status InferenceSession::RegisterGraphTransformer(std::unique_ptr& transformers_to_enable) { - std::copy(transformers_to_enable.begin(), transformers_to_enable.end(), - std::back_inserter(transformers_to_enable_)); + std::copy(transformers_to_enable.begin(), transformers_to_enable.end(), std::back_inserter(transformers_to_enable_)); return Status::OK(); } @@ -213,7 +214,8 @@ common::Status InferenceSession::RegisterCustomRegistry(std::shared_ptr&)> loader, const std::string& event_name) { +common::Status InferenceSession::Load(std::function&)> loader, + const std::string& event_name) { Status status = Status::OK(); TimePoint tp; if (session_profiler_.IsEnabled()) { @@ -223,8 +225,7 @@ common::Status InferenceSession::Load(std::function l(session_mutex_); if (is_model_loaded_) { // already loaded LOGS(*session_logger_, ERROR) << "This session already contains a loaded model."; - return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, - "This session already contains a loaded model."); + return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model."); } std::shared_ptr p_tmp_model; @@ -281,14 +282,10 @@ common::Status InferenceSession::Load(const std::basic_string& model_uri) { return Status::OK(); } -common::Status InferenceSession::Load(const std::string& model_uri) { - return Load(model_uri); -} +common::Status InferenceSession::Load(const std::string& model_uri) { return Load(model_uri); } #ifdef _WIN32 -common::Status InferenceSession::Load(const std::wstring& model_uri) { - return Load(model_uri); -} +common::Status InferenceSession::Load(const std::wstring& model_uri) { return Load(model_uri); } #endif common::Status InferenceSession::Load(const ModelProto& model_proto) { @@ -522,7 +519,6 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio const auto implicit_inputs = node.ImplicitInputDefs(); ORT_RETURN_IF_ERROR_SESSIONID_(initializer.CreatePlan(&node, &implicit_inputs, session_options_.execution_mode)); - // LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(), // &*subgraph_info.session_state); @@ -554,12 +550,14 @@ common::Status InferenceSession::Initialize() { LOGS(*session_logger_, ERROR) << "Model was not loaded"; return common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."); } - if (is_inited_) { // already initialized LOGS(*session_logger_, INFO) << "Session has already been initialized."; return common::Status::OK(); } - +#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT + TraceLoggingWriteStart(session_activity, "OrtInferenceSessionActivity"); + session_activity_started_ = true; +#endif // Register default CPUExecutionProvider if user didn't provide it through the Register() calls if (!execution_providers_.Get(onnxruntime::kCpuExecutionProvider)) { LOGS(*session_logger_, INFO) << "Adding default CPU execution provider."; @@ -578,7 +576,8 @@ common::Status InferenceSession::Initialize() { } // add predefined transformers - AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, transformers_to_enable_); + AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, + transformers_to_enable_); onnxruntime::Graph& graph = model_->MainGraph(); @@ -624,7 +623,6 @@ common::Status InferenceSession::Initialize() { // handle any subgraphs ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(graph, session_state_)); is_inited_ = true; - LOGS(*session_logger_, INFO) << "Session successfully initialized."; } catch (const NotImplementedException& ex) { status = ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Exception during initialization: ", ex.what()); @@ -643,9 +641,7 @@ common::Status InferenceSession::Initialize() { return status; } -int InferenceSession::GetCurrentNumRuns() const { - return current_num_runs_.load(); -} +int InferenceSession::GetCurrentNumRuns() const { return current_num_runs_.load(); } const std::vector& InferenceSession::GetRegisteredProviderTypes() const { return execution_providers_.GetIds(); @@ -662,8 +658,7 @@ common::Status InferenceSession::CheckShapes(const std::string& input_name, auto expected_shape_sz = expected_shape.NumDimensions(); if (input_shape_sz != expected_shape_sz) { std::ostringstream ostr; - ostr << "Invalid rank for input: " << input_name - << " Got: " << input_shape_sz << " Expected: " << expected_shape_sz + ostr << "Invalid rank for input: " << input_name << " Got: " << input_shape_sz << " Expected: " << expected_shape_sz << " Please fix either the inputs or the model."; return Status(ONNXRUNTIME, INVALID_ARGUMENT, ostr.str()); } @@ -705,10 +700,8 @@ static common::Status CheckTypes(MLDataType actual, MLDataType expected) { common::Status InferenceSession::ValidateInputs(const std::vector& feed_names, const std::vector& feeds) const { if (feed_names.size() != feeds.size()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Size mismatch: feed_names has ", - feed_names.size(), "elements, but feeds has ", - feeds.size(), " elements."); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Size mismatch: feed_names has ", feed_names.size(), + "elements, but feeds has ", feeds.size(), " elements."); } for (size_t i = 0; i < feeds.size(); ++i) { @@ -716,8 +709,7 @@ common::Status InferenceSession::ValidateInputs(const std::vector& auto iter = input_def_map_.find(feed_name); if (input_def_map_.end() == iter) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Invalid Feed Input Name:", feed_name); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid Feed Input Name:", feed_name); } auto expected_type = iter->second.ml_data_type; @@ -725,8 +717,8 @@ common::Status InferenceSession::ValidateInputs(const std::vector& if (input_ml_value.IsTensor()) { // check for type if (!expected_type->IsTensorType()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", - feed_name, " is not expected to be of type tensor."); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name, + " is not expected to be of type tensor."); } auto expected_element_type = expected_type->AsTensorType()->GetElementType(); @@ -751,17 +743,14 @@ common::Status InferenceSession::ValidateInputs(const std::vector& common::Status InferenceSession::ValidateOutputs(const std::vector& output_names, const std::vector* p_fetches) const { if (p_fetches == nullptr) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Output vector pointer is NULL"); + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Output vector pointer is NULL"); } if (output_names.empty()) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "At least one output should be requested."); + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "At least one output should be requested."); } - if (!p_fetches->empty() && - (output_names.size() != p_fetches->size())) { + if (!p_fetches->empty() && (output_names.size() != p_fetches->size())) { std::ostringstream ostr; ostr << "Output vector incorrectly sized: output_names.size(): " << output_names.size() << "p_fetches->size(): " << p_fetches->size(); @@ -770,8 +759,7 @@ common::Status InferenceSession::ValidateOutputs(const std::vector& for (const auto& name : output_names) { if (model_output_names_.find(name) == model_output_names_.end()) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Invalid Output Name:" + name); + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid Output Name:" + name); } } @@ -787,6 +775,12 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector ortrun_activity; + ortrun_activity.SetRelatedActivity(session_activity); + TraceLoggingWriteStart(ortrun_activity, "OrtRun"); +#endif Status retval = Status::OK(); try { @@ -858,7 +852,9 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector InferenceSession::GetModelMetada std::lock_guard l(session_mutex_); if (!is_model_loaded_) { LOGS(*session_logger_, ERROR) << "Model was not loaded"; - return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), - nullptr); + return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); } } @@ -902,8 +897,7 @@ std::pair InferenceSession::GetModelInputs( std::lock_guard l(session_mutex_); if (!is_model_loaded_) { LOGS(*session_logger_, ERROR) << "Model was not loaded"; - return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), - nullptr); + return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); } } @@ -916,8 +910,7 @@ std::pair InferenceSession::GetOverridableI std::lock_guard l(session_mutex_); if (!is_model_loaded_) { LOGS(*session_logger_, ERROR) << "Model was not loaded"; - return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), - nullptr); + return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr); } } @@ -971,14 +964,10 @@ void InferenceSession::StartProfiling(const std::basic_string& file_prefix) { session_profiler_.StartProfiling(ss.str()); } -void InferenceSession::StartProfiling(const std::string& file_prefix) { - StartProfiling(file_prefix); -} +void InferenceSession::StartProfiling(const std::string& file_prefix) { StartProfiling(file_prefix); } #ifdef _WIN32 -void InferenceSession::StartProfiling(const std::wstring& file_prefix) { - StartProfiling(file_prefix); -} +void InferenceSession::StartProfiling(const std::wstring& file_prefix) { StartProfiling(file_prefix); } #endif void InferenceSession::StartProfiling(const logging::Logger* logger_ptr) { @@ -1026,11 +1015,11 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod for (auto elem : inputs) { auto elem_type = utils::GetMLDataType(*elem); auto elem_shape_proto = elem->Shape(); - input_def_map_.insert({elem->Name(), InputDefMetaData(elem, - elem_type, - elem_shape_proto - ? utils::GetTensorShapeFromTensorShapeProto(*elem_shape_proto) - : TensorShape())}); + input_def_map_.insert( + {elem->Name(), + InputDefMetaData( + elem, elem_type, + elem_shape_proto ? utils::GetTensorShapeFromTensorShapeProto(*elem_shape_proto) : TensorShape())}); } }; @@ -1086,10 +1075,7 @@ const logging::Logger& InferenceSession::CreateLoggerForRun(const RunOptions& ru severity = static_cast(run_options.run_log_severity_level); } - new_run_logger = logging_manager_->CreateLogger(run_log_id, - severity, - false, - run_options.run_log_verbosity_level); + new_run_logger = logging_manager_->CreateLogger(run_log_id, severity, false, run_options.run_log_verbosity_level); run_logger = new_run_logger.get(); VLOGS(*run_logger, 1) << "Created logger for run with id of " << run_log_id; @@ -1116,9 +1102,7 @@ void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) { severity = static_cast(session_options_.session_log_severity_level); } - owned_session_logger_ = logging_manager_->CreateLogger(session_options_.session_logid, - severity, - false, + owned_session_logger_ = logging_manager_->CreateLogger(session_options_.session_logid, severity, false, session_options_.session_log_verbosity_level); session_logger_ = owned_session_logger_.get(); } else { diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index a11d4e8c6b89b..791d61e62e6f3 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -24,6 +24,10 @@ #ifdef ENABLE_LANGUAGE_INTEROP_OPS #include "core/language_interop_ops/language_interop_ops.h" #endif +#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT +#include "core/platform/tracing.h" +#include +#endif namespace onnxruntime { // forward declarations class GraphTransformer; @@ -434,7 +438,6 @@ class InferenceSession { #ifdef ENABLE_LANGUAGE_INTEROP_OPS InterOpDomains interop_domains_; #endif - // used to support platform telemetry static std::atomic global_session_id_; // a monotonically increasing session id uint32_t session_id_; // the current session's id @@ -442,5 +445,10 @@ class InferenceSession { long long total_run_duration_since_last_; // the total duration (us) of Run() calls since the last report TimePoint time_sent_last_; // the TimePoint of the last report const long long kDurationBetweenSending = 1000* 1000 * 60 * 10; // duration in (us). send a report every 10 mins + +#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT + bool session_activity_started_ = false; + TraceLoggingActivity session_activity; +#endif }; } // namespace onnxruntime diff --git a/onnxruntime/tool/etw/TraceSession.cc b/onnxruntime/tool/etw/TraceSession.cc new file mode 100644 index 0000000000000..1275906dbab58 --- /dev/null +++ b/onnxruntime/tool/etw/TraceSession.cc @@ -0,0 +1,306 @@ +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License (MIT). +// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF +// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY +// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR +// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT. +// +//********************************************************* + +#include "TraceSession.h" + +namespace { + +VOID WINAPI EventRecordCallback(EVENT_RECORD* pEventRecord) +{ + auto session = (TraceSession*) pEventRecord->UserContext; + auto const& hdr = pEventRecord->EventHeader; + + if (session->startTime_ == 0) { + session->startTime_ = hdr.TimeStamp.QuadPart; + } + + auto iter = session->eventHandler_.find(hdr.ProviderId); + if (iter != session->eventHandler_.end()) { + auto const& h = iter->second; + (*h.fn_)(pEventRecord, h.ctxt_); + } +} + +ULONG WINAPI BufferCallback(EVENT_TRACE_LOGFILE* pLogFile) +{ + auto session = (TraceSession*) pLogFile->Context; + auto shouldStopFn = session->shouldStopProcessingEventsFn_; + if (shouldStopFn && (*shouldStopFn)()) { + return FALSE; // break out of ProcessTrace() + } + + return TRUE; // continue processing events +} + +bool OpenLogger( + TraceSession* session, + TCHAR const* name, + bool realtime) +{ + // Open trace + EVENT_TRACE_LOGFILE loggerInfo = {}; + /* Filled out below based on realtime: + loggerInfo.LogFileName = nullptr; + loggerInfo.LoggerName = nullptr; + */ + loggerInfo.ProcessTraceMode = PROCESS_TRACE_MODE_EVENT_RECORD | PROCESS_TRACE_MODE_RAW_TIMESTAMP; + loggerInfo.BufferCallback = BufferCallback; + loggerInfo.EventRecordCallback = EventRecordCallback; + loggerInfo.Context = session; + /* Output members (passed also to BufferCallback()): + loggerInfo.CurrentTime + loggerInfo.BuffersRead + loggerInfo.CurrentEvent + loggerInfo.LogfileHeader + loggerInfo.BufferSize + loggerInfo.Filled + loggerInfo.IsKernelTrace + */ + /* Not used: + loggerInfo.EventsLost + */ + + if (realtime) { + loggerInfo.LoggerName = const_cast(name); + loggerInfo.ProcessTraceMode |= PROCESS_TRACE_MODE_REAL_TIME; + } else { + loggerInfo.LogFileName = const_cast(name); + } + + session->traceHandle_ = OpenTrace(&loggerInfo); + if (session->traceHandle_ == INVALID_PROCESSTRACE_HANDLE) { + fprintf(stderr, "error: failed to open trace"); + auto lastError = GetLastError(); + switch (lastError) { + case ERROR_INVALID_PARAMETER: fprintf(stderr, " (Logfile is NULL)"); break; + case ERROR_BAD_PATHNAME: fprintf(stderr, " (invalid LoggerName)"); break; + case ERROR_ACCESS_DENIED: fprintf(stderr, " (access denied)"); break; + default: fprintf(stderr, " (error=%u)", lastError); break; + } + fprintf(stderr, ".\n"); + return false; + } + + // Copy desired state from loggerInfo + session->frequency_ = loggerInfo.LogfileHeader.PerfFreq.QuadPart; + return true; +} + +} + +size_t TraceSession::GUIDHash::operator()(GUID const& g) const +{ + static_assert((sizeof(g) % sizeof(size_t)) == 0, "sizeof(GUID) must be multiple of sizeof(size_t)"); + auto p = (size_t const*) &g; + auto h = (size_t) 0; + for (size_t i = 0; i < sizeof(g) / sizeof(size_t); ++i) { + h ^= p[i]; + } + return h; +} + +bool TraceSession::GUIDEqual::operator()(GUID const& lhs, GUID const& rhs) const +{ + return IsEqualGUID(lhs, rhs) != FALSE; +} + +bool TraceSession::AddProvider(GUID providerId, UCHAR level, + ULONGLONG matchAnyKeyword, ULONGLONG matchAllKeyword) +{ + auto p = eventProvider_.emplace(std::make_pair(providerId, Provider())); + if (!p.second) { + return false; + } + + auto h = &p.first->second; + h->matchAny_ = matchAnyKeyword; + h->matchAll_ = matchAllKeyword; + h->level_ = level; + return true; +} + +bool TraceSession::AddHandler(GUID providerId, EventHandlerFn handlerFn, void* handlerContext) +{ + auto p = eventHandler_.emplace(std::make_pair(providerId, Handler())); + if (!p.second) { + return false; + } + + auto h = &p.first->second; + h->fn_ = handlerFn; + h->ctxt_ = handlerContext; + return true; +} + +bool TraceSession::AddProviderAndHandler(GUID providerId, UCHAR level, + ULONGLONG matchAnyKeyword, ULONGLONG matchAllKeyword, + EventHandlerFn handlerFn, void* handlerContext) +{ + if (!AddProvider(providerId, level, matchAnyKeyword, matchAllKeyword)) + return false; + if (!AddHandler(providerId, handlerFn, handlerContext)) { + RemoveProvider(providerId); + return false; + } + return true; +} + +bool TraceSession::RemoveProvider(GUID providerId) +{ + if (sessionHandle_ != 0) { + auto status = EnableTraceEx2(sessionHandle_, &providerId, EVENT_CONTROL_CODE_DISABLE_PROVIDER, 0, 0, 0, 0, nullptr); + (void) status; + } + + return eventProvider_.erase(providerId) != 0; +} + +bool TraceSession::RemoveHandler(GUID providerId) +{ + return eventHandler_.erase(providerId) != 0; +} + +bool TraceSession::RemoveProviderAndHandler(GUID providerId) +{ + return RemoveProvider(providerId) || RemoveHandler(providerId); +} + +bool TraceSession::InitializeEtlFile(TCHAR const* inputEtlPath, ShouldStopProcessingEventsFn shouldStopFn) +{ + // Open the trace + if (!OpenLogger(this, inputEtlPath, false)) { + Finalize(); + return false; + } + + // Initialize state + shouldStopProcessingEventsFn_ = shouldStopFn; + eventsLostCount_ = 0; + buffersLostCount_ = 0; + return true; +} + +bool TraceSession::InitializeRealtime(TCHAR const* traceSessionName, ShouldStopProcessingEventsFn shouldStopFn) +{ + // Set up and start a real-time collection session + memset(&properties_, 0, sizeof(properties_)); + + properties_.Wnode.BufferSize = (ULONG) offsetof(TraceSession, sessionHandle_); + //properties_.Wnode.Guid // ETW will create Guid + properties_.Wnode.ClientContext = 1; // Clock resolution to use when logging the timestamp for each event + // 1 == query performance counter + properties_.Wnode.Flags = 0; + //properties_.BufferSize = 0; + properties_.MinimumBuffers = 200; + //properties_.MaximumBuffers = 0; + //properties_.MaximumFileSize = 0; + properties_.LogFileMode = EVENT_TRACE_REAL_TIME_MODE; + //properties_.FlushTimer = 0; + //properties_.EnableFlags = 0; + properties_.LogFileNameOffset = 0; + properties_.LoggerNameOffset = offsetof(TraceSession, loggerName_); + + auto status = StartTrace(&sessionHandle_, traceSessionName, &properties_); + if (status == ERROR_ALREADY_EXISTS) { +#ifdef _DEBUG + fprintf(stderr, "warning: trying to start trace session that already exists.\n"); +#endif + status = ControlTrace((TRACEHANDLE) 0, traceSessionName, &properties_, EVENT_TRACE_CONTROL_STOP); + if (status == ERROR_SUCCESS) { + status = StartTrace(&sessionHandle_, traceSessionName, &properties_); + } + } + if (status != ERROR_SUCCESS) { + fprintf(stderr, "error: failed to start trace session (error=%lu).\n", status); + return false; + } + + // Enable desired providers + for (auto const& p : eventProvider_) { + auto pGuid = &p.first; + auto const& h = p.second; + + status = EnableTraceEx2(sessionHandle_, pGuid, EVENT_CONTROL_CODE_ENABLE_PROVIDER, h.level_, h.matchAny_, h.matchAll_, 0, nullptr); + if (status != ERROR_SUCCESS) { + fprintf(stderr, "error: failed to enable provider {%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x}.\n", + pGuid->Data1, pGuid->Data2, pGuid->Data3, pGuid->Data4[0], pGuid->Data4[1], pGuid->Data4[2], + pGuid->Data4[3], pGuid->Data4[4], pGuid->Data4[5], pGuid->Data4[6], pGuid->Data4[7]); + Finalize(); + return false; + } + } + + // Open the trace + if (!OpenLogger(this, traceSessionName, true)) { + Finalize(); + return false; + } + + // Initialize state + shouldStopProcessingEventsFn_ = shouldStopFn; + eventsLostCount_ = 0; + buffersLostCount_ = 0; + + return true; +} + +void TraceSession::Finalize() +{ + ULONG status = ERROR_SUCCESS; + + if (traceHandle_ != INVALID_PROCESSTRACE_HANDLE) { + status = CloseTrace(traceHandle_); + traceHandle_ = INVALID_PROCESSTRACE_HANDLE; + } + + if (sessionHandle_ != 0) { + status = ControlTraceW(sessionHandle_, nullptr, &properties_, EVENT_TRACE_CONTROL_STOP); + + while (!eventProvider_.empty()) { + RemoveProvider(eventProvider_.begin()->first); + } + while (!eventHandler_.empty()) { + RemoveHandler(eventHandler_.begin()->first); + } + + sessionHandle_ = 0; + } +} + +bool TraceSession::CheckLostReports(uint32_t* eventsLost, uint32_t* buffersLost) +{ + if (sessionHandle_ == 0) { + *eventsLost = 0; + *buffersLost = 0; + return false; + } + + auto status = ControlTraceW(sessionHandle_, nullptr, &properties_, EVENT_TRACE_CONTROL_QUERY); + if (status == ERROR_MORE_DATA) { // The buffer &properties_ is too small to hold all the information + *eventsLost = 0; // for the session. If you don't need the session's property information + *buffersLost = 0; // you can ignore this error. + return false; + } + + if (status != ERROR_SUCCESS) { + fprintf(stderr, "error: failed to query trace status (%lu).\n", status); + *eventsLost = 0; + *buffersLost = 0; + return false; + } + + *eventsLost = properties_.EventsLost - eventsLostCount_; + *buffersLost = properties_.RealTimeBuffersLost - buffersLostCount_; + eventsLostCount_ = properties_.EventsLost; + buffersLostCount_ = properties_.RealTimeBuffersLost; + return *eventsLost + *buffersLost > 0; +} + diff --git a/onnxruntime/tool/etw/TraceSession.h b/onnxruntime/tool/etw/TraceSession.h new file mode 100644 index 0000000000000..4592609500465 --- /dev/null +++ b/onnxruntime/tool/etw/TraceSession.h @@ -0,0 +1,99 @@ +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License (MIT). +// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF +// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY +// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR +// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT. +// +//********************************************************* + +#pragma once + +#include +#include +#include // must be after windows.h +#include +#include + +typedef void (*EventHandlerFn)(EVENT_RECORD* pEventRecord, void* pContext); +typedef bool (*ShouldStopProcessingEventsFn)(); + +struct TraceSession { + // BEGIN trace property block, must be beginning of TraceSession + EVENT_TRACE_PROPERTIES properties_; + wchar_t loggerName_[MAX_PATH]; + // END Trace property block + + TRACEHANDLE sessionHandle_; // Must be first member after trace property block + TRACEHANDLE traceHandle_; + ShouldStopProcessingEventsFn shouldStopProcessingEventsFn_; + uint64_t startTime_; + uint64_t frequency_; + uint32_t eventsLostCount_; + uint32_t buffersLostCount_; + + // Structure to hold the mapping from provider ID to event handler function + struct GUIDHash { size_t operator()(GUID const& g) const; }; + struct GUIDEqual { bool operator()(GUID const& lhs, GUID const& rhs) const; }; + struct Provider { + ULONGLONG matchAny_; + ULONGLONG matchAll_; + UCHAR level_; + }; + struct Handler { + EventHandlerFn fn_; + void* ctxt_; + }; + std::unordered_map eventProvider_; + std::unordered_map eventHandler_; + + TraceSession() + : sessionHandle_(0) + , traceHandle_(INVALID_PROCESSTRACE_HANDLE) + , startTime_(0) + , frequency_(0) + , shouldStopProcessingEventsFn_(nullptr) + { + } + + // Usage: + // + // 1) use TraceSession::AddProvider() to add the IDs for all the providers + // you want to trace. Use TraceSession::AddHandler() to add the handler + // functions for the providers/events you want to trace. + // + // 2) call TraceSession::InitializeRealtime() or + // TraceSession::InitializeEtlFile(), to start tracing events from + // real-time collection or from a previously-captured .etl file. At this + // point, events start to be traced. + // + // 3) call ::ProcessTrace() to start collecting the events; provider + // handler functions will be called as those provider events are collected. + // ProcessTrace() will exit when shouldStopProcessingEventsFn_ returns + // true, or when the .etl file is fully consumed. + // + // 4) Finalize() to clean up. + + // AddProvider/Handler() returns false if the providerId already has a handler. + // RemoveProvider/Handler() returns false if the providerId don't have a handler. + bool AddProvider(GUID providerId, UCHAR level, ULONGLONG matchAnyKeyword, ULONGLONG matchAllKeyword); + bool AddHandler(GUID handlerId, EventHandlerFn handlerFn, void* handlerContext); + bool AddProviderAndHandler(GUID providerId, UCHAR level, ULONGLONG matchAnyKeyword, ULONGLONG matchAllKeyword, + EventHandlerFn handlerFn, void* handlerContext); + bool RemoveProvider(GUID providerId); + bool RemoveHandler(GUID handlerId); + bool RemoveProviderAndHandler(GUID providerId); + + // InitializeRealtime() and InitializeEtlFile() return false if the session + // could not be created. + bool InitializeEtlFile(TCHAR const* etlPath, ShouldStopProcessingEventsFn shouldStopProcessingEventsFn); + bool InitializeRealtime(TCHAR const* traceSessionName, ShouldStopProcessingEventsFn shouldStopProcessingEventsFn); + void Finalize(); + + // Call CheckLostReports() at any time the session is initialized to query + // how many events and buffers have been lost while tracing. + bool CheckLostReports(uint32_t* eventsLost, uint32_t* buffersLost); +}; + diff --git a/onnxruntime/tool/etw/compare_two_sessions.cc b/onnxruntime/tool/etw/compare_two_sessions.cc new file mode 100644 index 0000000000000..29cbe568b596f --- /dev/null +++ b/onnxruntime/tool/etw/compare_two_sessions.cc @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "eparser.h" +#include "TraceSession.h" + +#ifdef _WIN32 +#include +#else +#define TCHAR char +#define _tmain main +#endif + +#ifdef _WIN32 +#include "getopt.h" +#else +#include +#include +#endif + +static const GUID OrtProviderGuid = {0x54d81939, 0x62a0, 0x4dc0, {0xbf, 0x32, 0x3, 0x5e, 0xbd, 0xc7, 0xbc, 0xe9}}; + +int fetch_data(TCHAR* filename, ProfilingInfo& context) { + TraceSession session; + session.AddHandler(OrtProviderGuid, OrtEventHandler, &context); + session.InitializeEtlFile(filename, nullptr); + ULONG status = ProcessTrace(&session.traceHandle_, 1, 0, 0); + if (status != ERROR_SUCCESS && status != ERROR_CANCELLED) { + std::cout << "OpenTrace failed with " << status << std::endl; + session.Finalize(); + return -1; + } + session.Finalize(); + return 0; +} + +template +std::pair CalcMeanAndStdSquare(const T* input, size_t input_len) { + T sum = 0; + T sum_square = 0; + const size_t N = input_len; + for (size_t i = 0; i != N; ++i) { + T t = input[i]; + sum += t; + sum_square += t * t; + } + double mean = ((double)sum) / N; + double std = (sum_square - N * mean * mean) / (N - 1); + return std::make_pair(mean, std); +} +// see: "Statistical Distributions", 4th Edition, by Catherine Forbes, Merran Evans, Nicholas Hastings and Brian +// Peacock. Chapter 42: "Student’s t Distribution". I only implemented when v is even. +double TDistributionCDF(int v, double x) { + assert(v >= 2 && (v & 1) == 0); + double t = x / (2 * std::sqrt(v + x * x)); + double sum1 = 0; + double b_j = 1; + for (int j = 0; j <= (v - 2) / 2; ++j) { + sum1 += b_j / std::pow(1 + x * x / v, j); + b_j *= static_cast(2 * j + 1) / (2 * j + 2); + } + return 0.5 + t * sum1; +} + +struct TTestResult { + double mean1, mean2; + double std1, std2; + double tvalue; +}; + +template +TTestResult CalcTValue(const T* input1, size_t input1_len, const T* input2, size_t input2_len) { + TTestResult result; + auto p1 = CalcMeanAndStdSquare(input1, input1_len); + result.mean1 = p1.first; + result.std1 = std::sqrt(p1.second); + auto p2 = CalcMeanAndStdSquare(input2, input2_len); + result.mean2 = p2.first; + result.std2 = std::sqrt(p2.second); + auto diff_mean = p1.first - p2.first; + size_t n1 = input1_len; + size_t n2 = input2_len; + auto sdiff = ((n1 - 1) * p1.second + (n2 - 1) * p2.second) / (n1 + n2 - 2); + sdiff *= ((double)1) / n1 + ((double)1) / n2; + result.tvalue = diff_mean / std::sqrt(sdiff); + return result; +} + +int real_main(int argc, TCHAR* argv[]) { + if (argc < 3) { + printf("error\n"); + return -1; + } + + ProfilingInfo context1; + int ret = fetch_data(argv[1], context1); + if (ret != 0) return ret; + ProfilingInfo context2; + ret = fetch_data(argv[2], context2); + if (ret != 0) return ret; + size_t n1 = context1.time_per_run.size(); + size_t n2 = context2.time_per_run.size(); + if (n1 <= 10 || n2 <= 10) { + printf("samples are too few, please try to gather more\n"); + return -1; + } + // ignore the first run + --n1; + --n2; + if (((n1 + n2) & 1) != 0) { + if (n1 > n2) + n1--; + else + n2--; + } + TTestResult tresult = CalcTValue(context1.time_per_run.data() + 1, n1, context2.time_per_run.data() + 1, n2); + size_t freedom = n1 + n2 - 2; + double p = TDistributionCDF(static_cast(freedom), std::abs(tresult.tvalue)); + std::cout << "Mean1: " << tresult.mean1 << " std1: " << tresult.std1 << "\n" + << "Mean2: " << tresult.mean2 << " std2: " << tresult.std2 << "\n" + << "H0: Mean1 = Mean2\n" + << "H1: Mean1 != Mean2\n" + << "Test statistic: T = " << tresult.tvalue << "\n" + << "Degrees of Freedom: v = " << freedom << "\n" + << "Significance level:" << (1 - p) * 2 << ". The lower the more likely to reject H0\n"; + if (p > 0.99995) { + std::cout << "The two population means are different at the 0.0001 significance level." << std::endl; + return -1; + } else { + std::cout << "They don't have significant statistical difference." << std::endl; + return 0; + } +} + +int _tmain(int argc, TCHAR* argv[]) { + int retval = -1; + try { + retval = real_main(argc, argv); + } catch (std::exception& ex) { + fprintf(stderr, "%s\n", ex.what()); + retval = -1; + } + return retval; +} \ No newline at end of file diff --git a/onnxruntime/tool/etw/eparser.cc b/onnxruntime/tool/etw/eparser.cc new file mode 100644 index 0000000000000..883a1aef3e401 --- /dev/null +++ b/onnxruntime/tool/etw/eparser.cc @@ -0,0 +1,355 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "eparser.h" + +// Get the metadata for the event. + +// Get the length of the property data. For MOF-based events, the size is inferred from the data type +// of the property. For manifest-based events, the property can specify the size of the property value +// using the length attribute. The length attribue can specify the size directly or specify the name +// of another property in the event data that contains the size. If the property does not include the +// length attribute, the size is inferred from the data type. The length will be zero for variable +// length, null-terminated strings and structures. + +DWORD GetPropertyLength(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT PropertyLength); + +// Get the size of the array. For MOF-based events, the size is specified in the declaration or using +// the MAX qualifier. For manifest-based events, the property can specify the size of the array +// using the count attribute. The count attribue can specify the size directly or specify the name +// of another property in the event data that contains the size. + +DWORD GetArraySize(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT ArraySize); + +// Both MOF-based events and manifest-based events can specify name/value maps. The +// map values can be integer values or bit values. If the property specifies a value +// map, get the map. + +DWORD GetMapInfo(PEVENT_RECORD pEvent, LPWSTR pMapName, DWORD DecodingSource, PEVENT_MAP_INFO& pMapInfo); + +// Print the property. +template +PBYTE PrintProperties(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, DWORD PointerSize, USHORT i, PBYTE pUserData, + PBYTE pEndOfUserData, const T& t) { + TDHSTATUS status = ERROR_SUCCESS; + USHORT PropertyLength = 0; + DWORD FormattedDataSize = 0; + USHORT UserDataConsumed = 0; + LPWSTR pFormattedData = NULL; + DWORD LastMember = 0; // Last member of a structure + USHORT ArraySize = 0; + PEVENT_MAP_INFO pMapInfo = NULL; + + // Get the length of the property. + + status = GetPropertyLength(pEvent, pInfo, i, &PropertyLength); + if (ERROR_SUCCESS != status) { + wprintf(L"GetPropertyLength failed.\n"); + pUserData = NULL; + goto cleanup; + } + + // Get the size of the array if the property is an array. + + status = GetArraySize(pEvent, pInfo, i, &ArraySize); + + for (USHORT k = 0; k < ArraySize; k++) { + // If the property is a structure, print the members of the structure. + + if ((pInfo->EventPropertyInfoArray[i].Flags & PropertyStruct) == PropertyStruct) { + LastMember = pInfo->EventPropertyInfoArray[i].structType.StructStartIndex + + pInfo->EventPropertyInfoArray[i].structType.NumOfStructMembers; + + for (USHORT j = pInfo->EventPropertyInfoArray[i].structType.StructStartIndex; j < LastMember; j++) { + pUserData = PrintProperties(pEvent, pInfo, PointerSize, j, pUserData, pEndOfUserData, t); + if (NULL == pUserData) { + wprintf(L"Printing the members of the structure failed.\n"); + pUserData = NULL; + goto cleanup; + } + } + } else { + // Get the name/value mapping if the property specifies a value map. + + status = + GetMapInfo(pEvent, (PWCHAR)((PBYTE)(pInfo) + pInfo->EventPropertyInfoArray[i].nonStructType.MapNameOffset), + pInfo->DecodingSource, pMapInfo); + + if (ERROR_SUCCESS != status) { + wprintf(L"GetMapInfo failed\n"); + pUserData = NULL; + goto cleanup; + } + + // Get the size of the buffer required for the formatted data. + + status = TdhFormatProperty(pInfo, pMapInfo, PointerSize, pInfo->EventPropertyInfoArray[i].nonStructType.InType, + pInfo->EventPropertyInfoArray[i].nonStructType.OutType, PropertyLength, + (USHORT)(pEndOfUserData - pUserData), pUserData, &FormattedDataSize, pFormattedData, + &UserDataConsumed); + + if (ERROR_INSUFFICIENT_BUFFER == status) { + if (pFormattedData) { + free(pFormattedData); + pFormattedData = NULL; + } + + pFormattedData = (LPWSTR)malloc(FormattedDataSize); + if (pFormattedData == NULL) { + wprintf(L"Failed to allocate memory for formatted data (size=%lu).\n", FormattedDataSize); + status = ERROR_OUTOFMEMORY; + pUserData = NULL; + goto cleanup; + } + + // Retrieve the formatted data. + + status = TdhFormatProperty(pInfo, pMapInfo, PointerSize, pInfo->EventPropertyInfoArray[i].nonStructType.InType, + pInfo->EventPropertyInfoArray[i].nonStructType.OutType, PropertyLength, + (USHORT)(pEndOfUserData - pUserData), pUserData, &FormattedDataSize, pFormattedData, + &UserDataConsumed); + } + + if (ERROR_SUCCESS == status) { + t((PWCHAR)((PBYTE)(pInfo) + pInfo->EventPropertyInfoArray[i].NameOffset), pFormattedData); + pUserData += UserDataConsumed; + } else { + wprintf(L"TdhFormatProperty failed with %lu.\n", status); + pUserData = NULL; + goto cleanup; + } + } + } + +cleanup: + + if (pFormattedData) { + free(pFormattedData); + pFormattedData = NULL; + } + + if (pMapInfo) { + free(pMapInfo); + pMapInfo = NULL; + } + + return pUserData; +} + +DWORD GetPropertyLength(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT PropertyLength) { + DWORD status = ERROR_SUCCESS; + PROPERTY_DATA_DESCRIPTOR DataDescriptor; + DWORD PropertySize = 0; + + // If the property is a binary blob and is defined in a manifest, the property can + // specify the blob's size or it can point to another property that defines the + // blob's size. The PropertyParamLength flag tells you where the blob's size is defined. + + if ((pInfo->EventPropertyInfoArray[i].Flags & PropertyParamLength) == PropertyParamLength) { + DWORD Length = 0; // Expects the length to be defined by a UINT16 or UINT32 + DWORD j = pInfo->EventPropertyInfoArray[i].lengthPropertyIndex; + ZeroMemory(&DataDescriptor, sizeof(PROPERTY_DATA_DESCRIPTOR)); + DataDescriptor.PropertyName = (ULONGLONG)((PBYTE)(pInfo) + pInfo->EventPropertyInfoArray[j].NameOffset); + DataDescriptor.ArrayIndex = ULONG_MAX; + status = TdhGetPropertySize(pEvent, 0, NULL, 1, &DataDescriptor, &PropertySize); + status = TdhGetProperty(pEvent, 0, NULL, 1, &DataDescriptor, PropertySize, (PBYTE)&Length); + *PropertyLength = (USHORT)Length; + } else { + if (pInfo->EventPropertyInfoArray[i].length > 0) { + *PropertyLength = pInfo->EventPropertyInfoArray[i].length; + } else { + // If the property is a binary blob and is defined in a MOF class, the extension + // qualifier is used to determine the size of the blob. However, if the extension + // is IPAddrV6, you must set the PropertyLength variable yourself because the + // EVENT_PROPERTY_INFO.length field will be zero. + + if (TDH_INTYPE_BINARY == pInfo->EventPropertyInfoArray[i].nonStructType.InType && + TDH_OUTTYPE_IPV6 == pInfo->EventPropertyInfoArray[i].nonStructType.OutType) { + *PropertyLength = (USHORT)sizeof(IN6_ADDR); + } else if (TDH_INTYPE_UNICODESTRING == pInfo->EventPropertyInfoArray[i].nonStructType.InType || + TDH_INTYPE_ANSISTRING == pInfo->EventPropertyInfoArray[i].nonStructType.InType || + (pInfo->EventPropertyInfoArray[i].Flags & PropertyStruct) == PropertyStruct) { + *PropertyLength = pInfo->EventPropertyInfoArray[i].length; + } else { + wprintf(L"Unexpected length of 0 for intype %d and outtype %d\n", + pInfo->EventPropertyInfoArray[i].nonStructType.InType, + pInfo->EventPropertyInfoArray[i].nonStructType.OutType); + + status = ERROR_EVT_INVALID_EVENT_DATA; + goto cleanup; + } + } + } + +cleanup: + + return status; +} + +DWORD GetArraySize(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT ArraySize) { + DWORD status = ERROR_SUCCESS; + PROPERTY_DATA_DESCRIPTOR DataDescriptor; + DWORD PropertySize = 0; + + if ((pInfo->EventPropertyInfoArray[i].Flags & PropertyParamCount) == PropertyParamCount) { + DWORD Count = 0; // Expects the count to be defined by a UINT16 or UINT32 + DWORD j = pInfo->EventPropertyInfoArray[i].countPropertyIndex; + ZeroMemory(&DataDescriptor, sizeof(PROPERTY_DATA_DESCRIPTOR)); + DataDescriptor.PropertyName = (ULONGLONG)((PBYTE)(pInfo) + pInfo->EventPropertyInfoArray[j].NameOffset); + DataDescriptor.ArrayIndex = ULONG_MAX; + status = TdhGetPropertySize(pEvent, 0, NULL, 1, &DataDescriptor, &PropertySize); + status = TdhGetProperty(pEvent, 0, NULL, 1, &DataDescriptor, PropertySize, (PBYTE)&Count); + *ArraySize = (USHORT)Count; + } else { + *ArraySize = pInfo->EventPropertyInfoArray[i].count; + } + + return status; +} + +DWORD GetMapInfo(PEVENT_RECORD pEvent, LPWSTR pMapName, DWORD DecodingSource, PEVENT_MAP_INFO& pMapInfo) { + DWORD status = ERROR_SUCCESS; + DWORD MapSize = 0; + + // Retrieve the required buffer size for the map info. + + status = TdhGetEventMapInformation(pEvent, pMapName, pMapInfo, &MapSize); + + if (ERROR_INSUFFICIENT_BUFFER == status) { + pMapInfo = (PEVENT_MAP_INFO)malloc(MapSize); + if (pMapInfo == NULL) { + wprintf(L"Failed to allocate memory for map info (size=%lu).\n", MapSize); + status = ERROR_OUTOFMEMORY; + goto cleanup; + } + + // Retrieve the map info. + + status = TdhGetEventMapInformation(pEvent, pMapName, pMapInfo, &MapSize); + } + + if (ERROR_SUCCESS == status) { + if (DecodingSourceXMLFile == DecodingSource) { + abort(); + } + } else { + if (ERROR_NOT_FOUND == status) { + status = ERROR_SUCCESS; // This case is okay. + } else { + wprintf(L"TdhGetEventMapInformation failed with 0x%x.\n", status); + } + } + +cleanup: + + return status; +} + +LoggingEventRecord LoggingEventRecord::CreateLoggingEventRecord(EVENT_RECORD* pEvent, DWORD& status) { + LoggingEventRecord ret; + ret.event_record_ = pEvent; + status = ERROR_SUCCESS; + DWORD BufferSize = 0; + + // Retrieve the required buffer size for the event metadata. + + status = TdhGetEventInformation(pEvent, 0, NULL, nullptr, &BufferSize); + + if (ERROR_INSUFFICIENT_BUFFER != status) return ret; + ret.buffer_.resize(BufferSize); + // Retrieve the event metadata. + status = TdhGetEventInformation(pEvent, 0, NULL, ret.GetEventInfo(), &BufferSize); + return ret; +} + +void OrtEventHandler(EVENT_RECORD* pEvent, void* pContext) { + ProfilingInfo& info = *(ProfilingInfo*)pContext; + DWORD status = ERROR_SUCCESS; + + LoggingEventRecord record = LoggingEventRecord::CreateLoggingEventRecord(pEvent, status); + if (ERROR_SUCCESS != status) { + if (status == ERROR_NOT_FOUND) return; + wprintf(L"GetEventInformation failed with %lu\n", status); + abort(); + } + DWORD PointerSize = 0; + if (EVENT_HEADER_FLAG_32_BIT_HEADER == (pEvent->EventHeader.Flags & EVENT_HEADER_FLAG_32_BIT_HEADER)) { + PointerSize = 4; + } else { + PointerSize = 8; + } + + PTRACE_EVENT_INFO pInfo = record.GetEventInfo(); + const wchar_t* name = record.GetTaskName(); + if (wcscmp(name, L"OpEnd") == 0) { + if (!info.session_started || info.session_ended) return; + PBYTE pUserData = (PBYTE)pEvent->UserData; + PBYTE pEndOfUserData = (PBYTE)pEvent->UserData + pEvent->UserDataLength; + + // Print the event data for all the top-level properties. Metadata for all the + // top-level properties come before structure member properties in the + // property information array. + std::wstring opname; + long time_spent_in_this_op = 0; + for (USHORT i = 0; i < pInfo->TopLevelPropertyCount; i++) { + pUserData = PrintProperties(pEvent, pInfo, PointerSize, i, pUserData, pEndOfUserData, + [&opname, &time_spent_in_this_op](const wchar_t* key, wchar_t* value) { + if (wcscmp(key, L"op_name") == 0) { + opname = value; + } else if (wcscmp(key, L"time") == 0) { + time_spent_in_this_op = wcstol(value, nullptr, 10); + } else { + wprintf(key); + abort(); + } + }); + if (NULL == pUserData) { + wprintf(L"Printing top level properties failed.\n"); + abort(); + } + } + auto iter = info.op_stat.find(opname); + if (iter == info.op_stat.end()) { + OpStat s; + s.name = opname; + s.count = 1; + s.total_time = time_spent_in_this_op; + info.op_stat[opname] = s; + } else { + OpStat& s = iter->second; + ++s.count; + s.total_time += time_spent_in_this_op; + } + } else if (wcscmp(name, L"OrtRun") == 0) { + if (!info.session_started || info.session_ended) return; + if (pInfo->EventDescriptor.Opcode == EVENT_TRACE_TYPE_START) { + info.op_start_time = pEvent->EventHeader.TimeStamp; + ++info.ortrun_count; + } else if (pInfo->EventDescriptor.Opcode == EVENT_TRACE_TYPE_END) { + if (pEvent->EventHeader.TimeStamp.QuadPart < info.op_start_time.QuadPart) { + throw std::runtime_error("time error"); + } + info.time_per_run.push_back(pEvent->EventHeader.TimeStamp.QuadPart - info.op_start_time.QuadPart); + ++info.ortrun_end_count; + } else { + abort(); + } + } + + else if (wcscmp(name, L"OrtInferenceSessionActivity") == 0) { + if (pInfo->EventDescriptor.Opcode == EVENT_TRACE_TYPE_START) { + info.session_started = true; + } else if (pInfo->EventDescriptor.Opcode == EVENT_TRACE_TYPE_END) { + info.session_ended = true; + } else { + abort(); + } + + printf("OrtInferenceSessionActivity\n"); + } else if (wcscmp(name, L"NodeNameMapping") == 0) { + // ignore + } else { + wprintf(L"unknown event:%s\n", name); + abort(); + } +} diff --git a/onnxruntime/tool/etw/eparser.h b/onnxruntime/tool/etw/eparser.h new file mode 100644 index 0000000000000..4e7199b57ec95 --- /dev/null +++ b/onnxruntime/tool/etw/eparser.h @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void OrtEventHandler(EVENT_RECORD* pEventRecord, void* pContext); + +class LoggingEventRecord { + private: + std::vector buffer_; + EVENT_RECORD* event_record_; + + public: + const TRACE_EVENT_INFO* GetEventInfo() const { return (const TRACE_EVENT_INFO*)buffer_.data(); } + TRACE_EVENT_INFO* GetEventInfo() { return (TRACE_EVENT_INFO*)buffer_.data(); } + + const wchar_t* GetTaskName() const { + const TRACE_EVENT_INFO* p = GetEventInfo(); + return (const wchar_t*)(buffer_.data() + p->TaskNameOffset); + } + + static LoggingEventRecord CreateLoggingEventRecord(EVENT_RECORD* pEvent, DWORD& status); +}; + +struct OpStat { + std::wstring name; + size_t count = 0; + uint64_t total_time = 0; +}; + +struct ProfilingInfo { + int ortrun_count = 0; + int ortrun_end_count = 0; + int session_count = 0; + bool session_started = false; + bool session_ended = false; + LARGE_INTEGER op_start_time; + + std::unordered_map op_stat; + std::vector time_per_run; +}; + + diff --git a/onnxruntime/tool/etw/main.cc b/onnxruntime/tool/etw/main.cc new file mode 100644 index 0000000000000..d11e4bf024af4 --- /dev/null +++ b/onnxruntime/tool/etw/main.cc @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "eparser.h" +#include "TraceSession.h" + +// Turns the DEFINE_GUID for EventTraceGuid into a const. +#define INITGUID + +static const GUID OrtProviderGuid = {0x3a26b1ff, 0x7484, 0x7484, {0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d}}; + +int real_main(int argc, TCHAR* argv[]) { + ProfilingInfo context; + TraceSession session; + session.AddHandler(OrtProviderGuid, OrtEventHandler, &context); + session.InitializeEtlFile(argv[1], nullptr); + ULONG status = ProcessTrace(&session.traceHandle_, 1, 0, 0); + if (status != ERROR_SUCCESS && status != ERROR_CANCELLED) { + std::cout << "OpenTrace failed with " << status << std::endl; + session.Finalize(); + return -1; + } + session.Finalize(); + + assert(context.ortrun_count == context.ortrun_end_count); + std::vector stat_array(context.op_stat.size()); + size_t i = 0; + for (auto& p : context.op_stat) { + stat_array[i++] = &p.second; + } + std::sort(stat_array.begin(), stat_array.end(), + [](const OpStat* left, const OpStat* right) { return left->total_time > right->total_time; }); + size_t iterations = context.time_per_run.size(); + ULONG64 total_time = std::accumulate(context.time_per_run.begin() + 1, context.time_per_run.end(), (ULONG64)0); + // in microseconds + ULONG64 avg_time = total_time / (context.time_per_run.size() - 1) / 10; + double sum = 0; + for (OpStat* p : stat_array) { + if (p->name == L"Scan") { + continue; + } + uint64_t avg_time_per_op = p->total_time / iterations; + if (avg_time_per_op >= 0) { + double t = avg_time_per_op * 100.0 / avg_time; + std::wcout << p->name << L" " << p->total_time / p->count << L" " << std::fixed << std::setprecision(1) << t + << L"%\n"; + } + sum += p->total_time / (double)iterations; + } + std::wcout << L"total " << std::fixed << std::setprecision(1) << (sum * 100.0) / avg_time << L"%\n"; + return 0; +} + +int _tmain(int argc, TCHAR* argv[]) { + int retval = -1; + try { + retval = real_main(argc, argv); + } catch (std::exception& ex) { + fprintf(stderr, "%s\n", ex.what()); + retval = -1; + } + return retval; +} \ No newline at end of file diff --git a/ort.wprp b/ort.wprp new file mode 100644 index 0000000000000..8738efeb599ad --- /dev/null +++ b/ort.wprp @@ -0,0 +1,51 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file