Skip to content

Commit

Permalink
Selective load ONNX schema by specific opset_version (onnx#3266)
Browse files Browse the repository at this point in the history
* fix debug build

Signed-off-by: Chun-Wei Chen <[email protected]>

* prototype: add operator_versions to handle dynamically schema register

Signed-off-by: Chun-Wei Chen <[email protected]>

* typo

Signed-off-by: Chun-Wei Chen <[email protected]>

* find the latest opset and handle missing

Signed-off-by: Chun-Wei Chen <[email protected]>

* fix mac error about unordered_map

Signed-off-by: Chun-Wei Chen <[email protected]>

* complete all opset versions

Signed-off-by: Chun-Wei Chen <[email protected]>

* while adding the newer opset, remove the old one

Signed-off-by: Chun-Wei Chen <[email protected]>

* use ONNX_TRY

Signed-off-by: Chun-Wei Chen <[email protected]>

* reverse the order for loading

Signed-off-by: Chun-Wei Chen <[email protected]>

* fix bug

Signed-off-by: Chun-Wei Chen <[email protected]>

* skip debug check if partially load schema

Signed-off-by: Chun-Wei Chen <[email protected]>

* check repeated register before register

Signed-off-by: Chun-Wei Chen <[email protected]>

* correct; not use map()

Signed-off-by: Chun-Wei Chen <[email protected]>

* remove unnecessary const

Signed-off-by: Chun-Wei Chen <[email protected]>

* simplified with the original function

Signed-off-by: Chun-Wei Chen <[email protected]>

* use integer instead of bool

Signed-off-by: Chun-Wei Chen <[email protected]>

* don't insert version > max_version

Signed-off-by: Chun-Wei Chen <[email protected]>

* fix debug mode condition; change ver name; update comments

Signed-off-by: Chun-Wei Chen <[email protected]>

* try enable ONNX_DISABLE_STATIC_REGISTRATION

Signed-off-by: Chun-Wei Chen <[email protected]>

* fix useless name_space

Signed-off-by: Chun-Wei Chen <[email protected]>

* set -DONNX_DISABLE_STATIC_REGISTRATION=ON in the right place

Signed-off-by: Chun-Wei Chen <[email protected]>

* move def

Signed-off-by: Chun-Wei Chen <[email protected]>

* include the missing lib

Signed-off-by: Chun-Wei Chen <[email protected]>

* fix

Signed-off-by: Chun-Wei Chen <[email protected]>

* update cmake

Signed-off-by: Chun-Wei Chen <[email protected]>

* recover NOEBUG

Signed-off-by: Chun-Wei Chen <[email protected]>

* still use __ONNX_DISABLE_STATIC_REGISTRATION

Signed-off-by: Chun-Wei Chen <[email protected]>

* skip for ONNX_DISABLE_STATIC_REGISTRATION

Signed-off-by: Chun-Wei Chen <[email protected]>

* correct the right def

Signed-off-by: Chun-Wei Chen <[email protected]>

* register     RegisterOnnxOperatorSetSchema if needed

Signed-off-by: Chun-Wei Chen <[email protected]>

* fix def

Signed-off-by: Chun-Wei Chen <[email protected]>

* add missing library

Signed-off-by: Chun-Wei Chen <[email protected]>

* decouple tests

Signed-off-by: Chun-Wei Chen <[email protected]>

* cd back

Signed-off-by: Chun-Wei Chen <[email protected]>

* refactor cd path

Signed-off-by: Chun-Wei Chen <[email protected]>

* correct path

Signed-off-by: Chun-Wei Chen <[email protected]>

* use $ instead

Signed-off-by: Chun-Wei Chen <[email protected]>

* try simple test

Signed-off-by: Chun-Wei Chen <[email protected]>

* correct CI

Signed-off-by: Chun-Wei Chen <[email protected]>

* fix typo

Signed-off-by: Chun-Wei Chen <[email protected]>

* add more tests

Signed-off-by: Chun-Wei Chen <[email protected]>

* update comments

Signed-off-by: Chun-Wei Chen <[email protected]>

* switch testing places

Signed-off-by: Chun-Wei Chen <[email protected]>

* correct gtest_filter

Signed-off-by: Chun-Wei Chen <[email protected]>

* update by comments

Signed-off-by: Chun-Wei Chen <[email protected]>

* typo

Signed-off-by: Chun-Wei Chen <[email protected]>

* remove

Signed-off-by: Chun-Wei Chen <[email protected]>

* move GetRegisteredSchemaCount to private

Signed-off-by: Chun-Wei Chen <[email protected]>

* remove in .h as well

Signed-off-by: Chun-Wei Chen <[email protected]>

* remove space

Signed-off-by: Chun-Wei Chen <[email protected]>

Co-authored-by: Michał Karzyński <[email protected]>
  • Loading branch information
jcwchen and postrational authored Mar 31, 2021
1 parent 3c8160e commit 2fe362c
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 16 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/win_no_exception_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ jobs:
$Env:PATH="$ENV:PATH;$protoc_path;$protoc_lib_path;$protobuf_include_path"
cd ../../../onnx
echo "Build ONNX"
cmake -G "Visual Studio 16 2019" -A $arch -DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON -DONNX_WERROR=ON -DONNX_DISABLE_EXCEPTIONS=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_BUILD_TYPE=Release -DONNX_USE_MSVC_STATIC_RUNTIME=ON -DONNX_ML=1 -DONNX_BUILD_TESTS=ON -S . -B .setuptools-cmake-build\
cd .setuptools-cmake-build\
Expand All @@ -84,3 +85,16 @@ jobs:
if($lastexitcode -ne 0) {
EXIT 1
}
cd ..
git clean -xdf
echo "Build ONNX with non-static registration for testing selective ONNX schema loading"
cmake -G "Visual Studio 16 2019" -A $arch -DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON -DONNX_WERROR=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_BUILD_TYPE=Release -DONNX_USE_MSVC_STATIC_RUNTIME=ON -DONNX_ML=1 -DONNX_BUILD_TESTS=ON -DONNX_DISABLE_STATIC_REGISTRATION=ON -S . -B .setuptools-cmake-build\
cd .setuptools-cmake-build\
msbuild onnx.sln /m /p:Configuration=Release
echo "Only test selective ONNX schema loading"
Release\onnx_gtests.exe --gtest_filter="SchemaRegistrationTest*"
if($lastexitcode -ne 0) {
EXIT 1
}
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ option(ONNX_BUILD_TESTS "Build ONNX C++ APIs Tests" OFF)
option(ONNX_USE_LITE_PROTO "Use lite protobuf instead of full." OFF)
option(ONNXIFI_ENABLE_EXT "Enable onnxifi extensions." OFF)
option(ONNX_DISABLE_EXCEPTIONS "Disable exception handling." OFF)
option(ONNX_DISABLE_STATIC_REGISTRATION "Disable static registration for onnx operator schemas." OFF)

if(NOT DEFINED ONNX_ML)
if(DEFINED ENV{ONNX_ML})
set(DEFAULT_ONNX_ML $ENV{ONNX_ML})
Expand Down
4 changes: 4 additions & 0 deletions cmake/Utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ function(add_onnx_global_defines target)
if(ONNX_USE_LITE_PROTO)
target_compile_definitions(${target} PUBLIC "ONNX_USE_LITE_PROTO=1")
endif()

if(ONNX_DISABLE_STATIC_REGISTRATION)
target_compile_definitions(${target} PUBLIC "__ONNX_DISABLE_STATIC_REGISTRATION")
endif()
endfunction()

function(add_whole_archive_flag lib output_var)
Expand Down
23 changes: 23 additions & 0 deletions onnx/defs/operator_sets.h
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,8 @@ class OpSet_Onnx_ver14 {
};

inline void RegisterOnnxOperatorSetSchema() {
// 0 means all versions of ONNX schema have been loaded
OpSchemaRegistry::Instance()->SetLoadedSchemaVersion(0);
RegisterOpSetSchema<OpSet_Onnx_ver1>();
RegisterOpSetSchema<OpSet_Onnx_ver2>();
RegisterOpSetSchema<OpSet_Onnx_ver3>();
Expand All @@ -981,4 +983,25 @@ inline void RegisterOnnxOperatorSetSchema() {
RegisterOpSetSchema<OpSet_Onnx_ver14>();
}

inline void RegisterOnnxOperatorSetSchema(int target_version) {
// Sets to record the loaded version and prevent the full operator check in Debug mode
OpSchemaRegistry::Instance()->SetLoadedSchemaVersion(target_version);
// Update here if opset_version bumps
// These calls for schema registration here are required to be in descending order for this to work correctly
RegisterOpSetSchema<OpSet_Onnx_ver14>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver13>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver12>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver11>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver10>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver9>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver8>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver7>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver6>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver5>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver4>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver3>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver2>(target_version);
RegisterOpSetSchema<OpSet_Onnx_ver1>(target_version);
}

} // namespace ONNX_NAMESPACE
26 changes: 17 additions & 9 deletions onnx/defs/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@
#include "onnx/common/stl_backports.h"

namespace ONNX_NAMESPACE {
// -1 means ONNX schema hasn't been loaded yet
// 0 means all versions of ONNX schema have been loaded
// Other positive integer means the ONNX schemas for the specified version have been loaded
int OpSchemaRegistry::loaded_schema_version = -1;

void RegisterSchema(OpSchema&& schema) {
OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration = schema;
// By default if opset_version_to_load=0, it registers all opset schema for all opset versions
// Otherwise, it only registers the latest schema according to opset_version_to_load
void RegisterSchema(OpSchema schema, int opset_version_to_load) {
OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration(schema, opset_version_to_load);
}

#ifndef NDEBUG
Expand Down Expand Up @@ -861,13 +867,15 @@ OpName_Domain_Version_Schema_Map& OpSchemaRegistry::map() {

#ifndef NDEBUG
size_t dbg_registered_schema_count = GetRegisteredSchemaCount() - dbg_initial_schema_count;

ONNX_ASSERTM(
dbg_registered_schema_count == ONNX_DBG_GET_COUNT_IN_OPSETS(),
"%u schema were exposed from operator sets and automatically placed into the static registry. "
"%u were expected based on calls to registration macros. Operator set functions may need to be updated.",
dbg_registered_schema_count,
ONNX_DBG_GET_COUNT_IN_OPSETS());
// Check enabled only if schemas for all opset versions are loaded
if (OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 0) {
ONNX_ASSERTM(
dbg_registered_schema_count == ONNX_DBG_GET_COUNT_IN_OPSETS(),
"%u schema were exposed from operator sets and automatically placed into the static registry. "
"%u were expected based on calls to registration macros. Operator set functions may need to be updated.",
dbg_registered_schema_count,
ONNX_DBG_GET_COUNT_IN_OPSETS());
}
#endif
}

Expand Down
30 changes: 23 additions & 7 deletions onnx/defs/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ class OpSchemaRegistry final : public ISchemaRegistry {

class OpSchemaRegisterOnce final {
public:
OpSchemaRegisterOnce(OpSchema& op_schema) {
OpSchemaRegisterOnce(OpSchema& op_schema, int opset_version_to_load=0) {
ONNX_TRY {
op_schema.Finalize();

Expand All @@ -963,6 +963,10 @@ class OpSchemaRegistry final : public ISchemaRegistry {
auto& op_name = op_schema.Name();
auto& op_domain = op_schema.domain();
auto ver = op_schema.SinceVersion();
// Stops because the opset_version is higher than opset_version_to_load
if (opset_version_to_load != 0 && ver > opset_version_to_load) {
return;
}

if (m[op_name][op_domain].count(ver)) {
const auto& schema = m[op_name][op_domain][ver];
Expand All @@ -974,6 +978,10 @@ class OpSchemaRegistry final : public ISchemaRegistry {
<< schema.file() << " line " << schema.line() << std::endl;
fail_schema(err.str());
}
// Return early if schema for the targeted opset version has already been loaded
if (opset_version_to_load != 0 && !m[op_name][op_domain].empty()) {
return;
}

auto ver_range_map = DomainToVersionRange::Instance().Map();
auto ver_range_it = ver_range_map.find(op_domain);
Expand Down Expand Up @@ -1060,7 +1068,12 @@ class OpSchemaRegistry final : public ISchemaRegistry {
const std::string& domain = ONNX_DOMAIN) const override {
return Schema(key, maxInclusiveVersion, domain);
}

static void SetLoadedSchemaVersion(int target_version) {
loaded_schema_version = target_version;
}
static int GetLoadedSchemaVersion() {
return loaded_schema_version;
}
private:
// OpSchemaRegistry should not need to be instantiated except statically
// within this class
Expand All @@ -1078,6 +1091,7 @@ class OpSchemaRegistry final : public ISchemaRegistry {
*/
static OpName_Domain_Version_Schema_Map& GetMapWithoutEnsuringRegistration();
static OpName_Domain_Version_Schema_Map& map();
static int loaded_schema_version;

public:
static const std::vector<OpSchema> get_all_schemas_with_history() {
Expand All @@ -1104,12 +1118,15 @@ class OpSchemaRegistry final : public ISchemaRegistry {
}
};

void RegisterSchema(OpSchema&& schema);
void RegisterSchema(OpSchema schema, int opset_version_to_load=0);

// Registers all schema of a given operator set
// Registers the latest opset schema before opset_version_to_load
// By default opset_version_to_load=0 means it will register all versions
template <class T>
void RegisterOpSetSchema() {
T::ForEachSchema(RegisterSchema);
void RegisterOpSetSchema(int opset_version_to_load=0) {
T::ForEachSchema([opset_version_to_load](OpSchema&& schema) {
RegisterSchema(schema, opset_version_to_load);
});
};

// Forward declaration for the non-specialized GetOpSchema method. This
Expand Down Expand Up @@ -1152,7 +1169,6 @@ OpSchema GetOpSchema();
size_t dbg_count_check_##name##_##domain##_ver##ver = \
(dbg_included_in_static_opset) ? ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() \
: 0;

#ifdef NDEBUG
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() 0
#else
Expand Down
43 changes: 43 additions & 0 deletions onnx/test/cpp/schema_registration_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

#include <iostream>
#include "gtest/gtest.h"
#include "onnx/defs/operator_sets.h"
#include "onnx/defs/schema.h"

using namespace ONNX_NAMESPACE;

namespace ONNX_NAMESPACE {
namespace Test {

// By default ONNX registers all opset versions and selective schema loading cannot be tested
// So this test is run only when static registration is disabled
TEST(SchemaRegistrationTest, RegisterSpecifiedOpsetSchemaVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == -1);
RegisterOnnxOperatorSetSchema(13);
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 13);

auto opSchema = OpSchemaRegistry::Schema("Add");
EXPECT_NE(nullptr, opSchema);
EXPECT_EQ(opSchema->SinceVersion(), 13);

// Should not find opset 12
opSchema = OpSchemaRegistry::Schema("Add", 12);
EXPECT_EQ(nullptr, opSchema);

// Should not find opset 14
opSchema = OpSchemaRegistry::Schema("Trilu");
EXPECT_EQ(nullptr, opSchema);

// Acos-7 is the latest Acos before specified 13
opSchema = OpSchemaRegistry::Schema("Acos");
EXPECT_NE(nullptr, opSchema);
EXPECT_EQ(opSchema->SinceVersion(), 7);
#endif
}

} // namespace Test
} // namespace ONNX_NAMESPACE

0 comments on commit 2fe362c

Please sign in to comment.