diff --git a/.github/workflows/cmake_test.yml b/.github/workflows/cmake_test.yml index d23c692c1..e4a5f4400 100644 --- a/.github/workflows/cmake_test.yml +++ b/.github/workflows/cmake_test.yml @@ -77,6 +77,7 @@ jobs: -D CMAKE_CXX_COMPILER_LAUNCHER=ccache \ -D CMAKE_BUILD_TYPE=RelWithDebug \ -D FUZZTEST_BUILD_TESTING=on \ + -D FUZZTEST_BUILD_FLATBUFFERS=on \ && cmake --build build -j $(nproc) \ && ctest --test-dir build -j $(nproc) --output-on-failure - name: Run all tests in default mode with gcc @@ -90,6 +91,7 @@ jobs: -D CMAKE_CXX_COMPILER_LAUNCHER=ccache \ -D CMAKE_BUILD_TYPE=RelWithDebug \ -D FUZZTEST_BUILD_TESTING=on \ + -D FUZZTEST_BUILD_FLATBUFFERS=on \ && cmake --build build_gcc -j $(nproc) \ && ctest --test-dir build_gcc -j $(nproc) --output-on-failure - name: Run end-to-end tests in fuzzing mode @@ -104,6 +106,7 @@ jobs: -D CMAKE_BUILD_TYPE=RelWithDebug \ -D FUZZTEST_FUZZING_MODE=on \ -D FUZZTEST_BUILD_TESTING=on \ + -D FUZZTEST_BUILD_FLATBUFFERS=on \ && cmake --build build -j $(nproc) \ && ctest --test-dir build -j $(nproc) --output-on-failure -R "functional_test" - name: Save new cache based on main diff --git a/CMakeLists.txt b/CMakeLists.txt index 1e34de3cc..e3803d30f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,7 @@ cmake_minimum_required(VERSION 3.19) project(fuzztest) option(FUZZTEST_BUILD_TESTING "Building the tests." OFF) +option(FUZZTEST_BUILD_FLATBUFFERS "Building the flatbuffers support." OFF) option(FUZZTEST_FUZZING_MODE "Building the fuzztest in fuzzing mode." OFF) set(FUZZTEST_COMPATIBILITY_MODE "" CACHE STRING "Compatibility mode. Available options: , libfuzzer") set(CMAKE_CXX_STANDARD 17) diff --git a/MODULE.bazel b/MODULE.bazel index 9fbd47947..28902ac8e 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -42,6 +42,10 @@ bazel_dep( name = "platforms", version = "0.0.10", ) +bazel_dep( + name = "flatbuffers", + version = "25.2.10" +) # GoogleTest is not a dev dependency, because it's needed when FuzzTest is used # with GoogleTest integration (e.g., googletest_adaptor). Note that the FuzzTest # framework can be used without GoogleTest integration as well. @@ -55,8 +59,6 @@ bazel_dep( name = "protobuf", version = "30.2", ) -# TODO(lszekeres): Make this a dev dependency, as the protobuf library is only -# required for testing. bazel_dep( name = "rules_proto", version = "7.1.0", diff --git a/cmake/BuildDependencies.cmake b/cmake/BuildDependencies.cmake index 5214fcefd..2966c0b64 100644 --- a/cmake/BuildDependencies.cmake +++ b/cmake/BuildDependencies.cmake @@ -21,6 +21,9 @@ set(proto_TAG v30.2) set(nlohmann_json_URL https://github.com/nlohmann/json.git) set(nlohmann_json_TAG v3.11.3) +set(flatbuffers_URL https://github.com/google/flatbuffers.git) +set(flatbuffers_TAG v25.2.10) + if(POLICY CMP0135) cmake_policy(SET CMP0135 NEW) set(CMAKE_POLICY_DEFAULT_CMP0135 NEW) @@ -50,6 +53,14 @@ FetchContent_Declare( URL_HASH MD5=${antlr_cpp_MD5} ) +if (FUZZTEST_BUILD_FLATBUFFERS) + FetchContent_Declare( + flatbuffers + GIT_REPOSITORY ${flatbuffers_URL} + GIT_TAG ${flatbuffers_TAG} + ) +endif() + if (FUZZTEST_BUILD_TESTING) FetchContent_Declare( @@ -87,3 +98,9 @@ if (FUZZTEST_BUILD_TESTING) FetchContent_MakeAvailable(nlohmann_json) endif () + +if (FUZZTEST_BUILD_FLATBUFFERS) + set(FLATBUFFERS_BUILD_TESTS OFF) + set(FLATBUFFERS_BUILD_INSTALL OFF) + FetchContent_MakeAvailable(flatbuffers) +endif() diff --git a/cmake/generate_cmake_from_bazel.py b/cmake/generate_cmake_from_bazel.py index 83d31d5dc..0c9079cbf 100755 --- a/cmake/generate_cmake_from_bazel.py +++ b/cmake/generate_cmake_from_bazel.py @@ -52,6 +52,7 @@ "@abseil-cpp//absl/types:optional": "absl::optional", "@abseil-cpp//absl/types:span": "absl::span", "@abseil-cpp//absl/types:variant": "absl::variant", + "@flatbuffers//:runtime_cc": "flatbuffers", "@googletest//:gtest": "GTest::gtest", "@googletest//:gtest_main": "GTest::gmock_main", "@protobuf//:protobuf": "protobuf::libprotobuf", diff --git a/domain_tests/BUILD b/domain_tests/BUILD index d71a93b06..82cd903f4 100644 --- a/domain_tests/BUILD +++ b/domain_tests/BUILD @@ -33,6 +33,24 @@ cc_test( ], ) +cc_test( + name = "arbitrary_domains_flatbuffers_test", + srcs = ["arbitrary_domains_flatbuffers_test.cc"], + deps = [ + ":domain_testing", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/random", + "@com_google_fuzztest//fuzztest:domain", + "@com_google_fuzztest//fuzztest:flatbuffers", + "@com_google_fuzztest//fuzztest:meta", + "@com_google_fuzztest//fuzztest:serialization", + "@com_google_fuzztest//fuzztest:test_flatbuffers_cc_fbs", + "@flatbuffers//:runtime_cc", + "@googletest//:gtest_main", + ], +) + cc_test( name = "arbitrary_domains_protobuf_test", srcs = ["arbitrary_domains_protobuf_test.cc"], diff --git a/domain_tests/CMakeLists.txt b/domain_tests/CMakeLists.txt index 703ee493d..092d2839a 100644 --- a/domain_tests/CMakeLists.txt +++ b/domain_tests/CMakeLists.txt @@ -19,6 +19,30 @@ fuzztest_cc_test( GTest::gmock_main ) +if (FUZZTEST_BUILD_FLATBUFFERS) + fuzztest_cc_test( + NAME + arbitrary_domains_flatbuffers_test + SRCS + "arbitrary_domains_flatbuffers_test.cc" + DEPS + absl::flat_hash_set + absl::random_bit_gen_ref + absl::random_random + absl::strings + flatbuffers + fuzztest::flatbuffers + fuzztest::domain + fuzztest::domain_testing + fuzztest::flatbuffers + GTest::gmock_main + test_flatbuffers + ) + add_dependencies(fuzztest_arbitrary_domains_flatbuffers_test + GENERATE_test_flatbuffers + ) +endif() + fuzztest_cc_test( NAME arbitrary_domains_protobuf_test diff --git a/domain_tests/arbitrary_domains_flatbuffers_test.cc b/domain_tests/arbitrary_domains_flatbuffers_test.cc new file mode 100644 index 000000000..15cbd7864 --- /dev/null +++ b/domain_tests/arbitrary_domains_flatbuffers_test.cc @@ -0,0 +1,913 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/random/random.h" +#include "flatbuffers/base.h" +#include "flatbuffers/buffer.h" +#include "flatbuffers/flatbuffer_builder.h" +#include "flatbuffers/string.h" +#include "flatbuffers/vector.h" +#include "flatbuffers/verifier.h" +#include "./fuzztest/domain.h" +#include "./domain_tests/domain_testing.h" +#include "./fuzztest/flatbuffers.h" +#include "./fuzztest/internal/meta.h" +#include "./fuzztest/internal/serialization.h" +#include "./fuzztest/internal/test_flatbuffers_generated.h" + +namespace fuzztest { +namespace { + +using ::fuzztest::internal::BoolStruct; +using ::fuzztest::internal::BoolTable; +using ::fuzztest::internal::DefaultStruct; +using ::fuzztest::internal::DefaultTable; +using ::fuzztest::internal::Enum; +using ::fuzztest::internal::OptionalTable; +using ::fuzztest::internal::RequiredTable; +using ::fuzztest::internal::StringTable; +using ::fuzztest::internal::UnionTable; +using ::testing::_; +using ::testing::Contains; +using ::testing::Each; +using ::testing::IsTrue; +using ::testing::NotNull; +using ::testing::Pair; +using ::testing::ResultOf; + +template +inline bool Eq(T rhs, T lhs) { + static_assert(!std::is_pointer_v, "T cannot be a pointer type"); + return lhs == rhs; +} + +template <> +inline bool Eq(const flatbuffers::String* rhs, + const flatbuffers::String* lhs) { + return (rhs == nullptr && lhs == nullptr) || + (rhs != nullptr && lhs != nullptr && rhs->str() == lhs->str()); +}; + +template <> +inline bool Eq(BoolStruct rhs, BoolStruct lhs) { + return Eq(rhs.b(), lhs.b()); +}; + +template <> +inline bool Eq(const DefaultStruct* rhs, + const DefaultStruct* lhs) { + if (rhs == nullptr && lhs == nullptr) { + return true; + } else if (rhs == nullptr || lhs == nullptr) { + return false; + } else { + return Eq(rhs->b(), lhs->b()) && Eq(rhs->i8(), lhs->i8()) && + Eq(rhs->i16(), lhs->i16()) && Eq(rhs->i32(), lhs->i32()) && + Eq(rhs->i64(), lhs->i64()) && Eq(rhs->u8(), lhs->u8()) && + Eq(rhs->u16(), lhs->u16()) && Eq(rhs->u32(), lhs->u32()) && + Eq(rhs->u64(), lhs->u64()) && Eq(rhs->f(), lhs->f()) && + Eq(rhs->d(), lhs->d()) && Eq(rhs->e(), lhs->e()) && + Eq(rhs->s(), lhs->s()); + } +} + +template <> +inline bool Eq(const BoolTable* rhs, const BoolTable* lhs) { + return (rhs == nullptr && lhs == nullptr) || + (rhs != nullptr && lhs != nullptr && rhs->b() == lhs->b()); +}; + +template <> +inline bool Eq>( + std::pair rhs, std::pair lhs) { + if (rhs.first == internal::Union_NONE && lhs.first == internal::Union_NONE) { + return true; + } else if (rhs.first != lhs.first) { + return false; + } else { + switch (rhs.first) { + case internal::Union_BoolTable: + return static_cast(rhs.second)->b() == + static_cast(lhs.second)->b(); + case internal::Union_StringTable: + return static_cast(rhs.second)->str()->str() == + static_cast(lhs.second)->str()->str(); + case internal::Union_BoolStruct: + return static_cast(rhs.second)->b() == + static_cast(lhs.second)->b(); + default: + CHECK(false) << "Unsupported union type"; + } + } +} + +template +inline bool VectorEq(const flatbuffers::Vector* rhs, + const flatbuffers::Vector* lhs) { + if (rhs == nullptr && lhs == nullptr) { + return true; + } else if (rhs == nullptr || lhs == nullptr) { + return false; + } + if (rhs->size() != lhs->size()) { + return false; + } + for (int i = 0; i < rhs->size(); ++i) { + if (!Eq(rhs->Get(i), lhs->Get(i))) { + return false; + } + } + return true; +}; + +inline bool VectorUnionEq( + const flatbuffers::Vector* rhs_type, + const flatbuffers::Vector<::flatbuffers::Offset>* rhs, + const flatbuffers::Vector* lhs_type, + const flatbuffers::Vector<::flatbuffers::Offset>* lhs) { + if (!VectorEq(rhs_type, lhs_type)) { + return false; + } + if (rhs == nullptr && lhs == nullptr) { + return true; + } else if (rhs == nullptr || lhs == nullptr) { + return false; + } + if (rhs->size() != lhs->size()) { + return false; + } + for (int i = 0; i < rhs->size(); ++i) { + if (!Eq(std::make_pair(rhs_type->Get(i), rhs->Get(i)), + std::make_pair(lhs_type->Get(i), lhs->Get(i)))) { + return false; + } + } + return true; +}; + +const internal::DefaultTable* CreateDefaultTable( + flatbuffers::FlatBufferBuilder& fbb) { + auto bool_table_offset = internal::CreateBoolTable(fbb, true); + auto string_table_offset = + internal::CreateStringTableDirect(fbb, "foo bar baz"); + DefaultStruct s{ + true, // b + 1, // i8 + 2, // i16 + 3, // i32 + 4, // i64 + 5, // u8 + 6, // u16 + 7, // u32 + 8, // u64 + 9, // f + 10.0, // d + internal::Enum_Second, // e + BoolStruct{true} // s + }; + std::vector v_b{true, false}; + std::vector v_i8{1, 2, 3}; + std::vector v_i16{1, 2, 3}; + std::vector v_i32{1, 2, 3}; + std::vector v_i64{1, 2, 3}; + std::vector v_u8{1, 2, 3}; + std::vector v_u16{1, 2, 3}; + std::vector v_u32{1, 2, 3}; + std::vector v_u64{1, 2, 3}; + std::vector v_f{1, 2, 3}; + std::vector v_d{1, 2, 3}; + std::vector> v_str{ + fbb.CreateString("foo"), fbb.CreateString("bar"), + fbb.CreateString("baz")}; + std::vector> v_e{ + internal::Enum_First, internal::Enum_Second, internal::Enum_Third}; + std::vector> v_t{bool_table_offset}; + std::vector> v_u_type{ + internal::Union_BoolTable, + internal::Union_StringTable, + }; + std::vector> v_u{ + bool_table_offset.Union(), + string_table_offset.Union(), + }; + std::vector v_s{s}; + auto table_offset = + internal::CreateDefaultTableDirect(fbb, + true, // b + 1, // i8 + 2, // i16 + 3, // i32 + 4, // i64 + 5, // u8 + 6, // u16 + 7, // u32 + 8, // u64 + 9.0, // f + 10.0, // d + "foo bar baz", // str + internal::Enum_Second, // e + bool_table_offset, // t + internal::Union_BoolTable, // u_type + bool_table_offset.Union(), // u + &s, // s + &v_b, // v_b + &v_i8, // v_i8 + &v_i16, // v_i16 + &v_i32, // v_i32 + &v_i64, // v_i64 + &v_u8, // v_u8 + &v_u16, // v_u16 + &v_u32, // v_u32 + &v_u64, // v_u64 + &v_f, // v_f + &v_d, // v_d + &v_str, // v_str + &v_e, // v_e + &v_t, // v_t + &v_u_type, // v_u_type + &v_u, // v_u + &v_s // v_s + ); + fbb.Finish(table_offset); + return flatbuffers::GetRoot(fbb.GetBufferPointer()); +} + +TEST(FlatbuffersMetaTest, IsFlatbuffersTable) { + static_assert(internal::is_flatbuffers_table_v); + static_assert(!internal::is_flatbuffers_table_v); + static_assert(!internal::is_flatbuffers_table_v>); +} + +TEST(FlatbuffersTableDomainImplTest, DefaultTableValueRoundTrip) { + flatbuffers::FlatBufferBuilder fbb; + auto table = CreateDefaultTable(fbb); + + auto domain = Arbitrary(); + auto corpus = domain.FromValue(table); + ASSERT_TRUE(corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*corpus)); + + auto ir = domain.SerializeCorpus(corpus.value()); + + auto new_corpus = domain.ParseCorpus(ir); + ASSERT_TRUE(new_corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*new_corpus)); + + auto new_table = domain.GetValue(*new_corpus); + EXPECT_EQ(new_table->b(), true); + EXPECT_EQ(new_table->i8(), 1); + EXPECT_EQ(new_table->i16(), 2); + EXPECT_EQ(new_table->i32(), 3); + EXPECT_EQ(new_table->i64(), 4); + EXPECT_EQ(new_table->u8(), 5); + EXPECT_EQ(new_table->u16(), 6); + EXPECT_EQ(new_table->u32(), 7); + EXPECT_EQ(new_table->u64(), 8); + EXPECT_EQ(new_table->f(), 9.0); + EXPECT_EQ(new_table->d(), 10.0); + EXPECT_EQ(new_table->str()->str(), "foo bar baz"); + EXPECT_EQ(new_table->e(), internal::Enum_Second); + EXPECT_EQ(new_table->u_type(), internal::Union_BoolTable); + EXPECT_EQ(new_table->u_as_BoolTable()->b(), true); + ASSERT_THAT(new_table->t(), NotNull()); + EXPECT_EQ(new_table->t()->b(), true); + ASSERT_THAT(new_table->s(), NotNull()); + EXPECT_EQ(new_table->s()->b(), true); + EXPECT_EQ(new_table->s()->i8(), 1); + EXPECT_EQ(new_table->s()->i16(), 2); + EXPECT_EQ(new_table->s()->i32(), 3); + EXPECT_EQ(new_table->s()->i64(), 4); + EXPECT_EQ(new_table->s()->u8(), 5); + EXPECT_EQ(new_table->s()->u16(), 6); + EXPECT_EQ(new_table->s()->u32(), 7); + EXPECT_EQ(new_table->s()->u64(), 8); + EXPECT_EQ(new_table->s()->f(), 9.0); + EXPECT_EQ(new_table->s()->d(), 10.0); + EXPECT_EQ(new_table->s()->e(), internal::Enum_Second); + EXPECT_EQ(new_table->s()->s().b(), true); + ASSERT_THAT(new_table->v_b(), NotNull()); + EXPECT_EQ(new_table->v_b()->size(), 2); + EXPECT_EQ(new_table->v_b()->Get(0), true); + EXPECT_EQ(new_table->v_b()->Get(1), false); + ASSERT_THAT(new_table->v_i8(), NotNull()); + EXPECT_EQ(new_table->v_i8()->size(), 3); + EXPECT_EQ(new_table->v_i8()->Get(0), 1); + EXPECT_EQ(new_table->v_i8()->Get(1), 2); + EXPECT_EQ(new_table->v_i8()->Get(2), 3); + ASSERT_THAT(new_table->v_i16(), NotNull()); + EXPECT_EQ(new_table->v_i16()->size(), 3); + EXPECT_EQ(new_table->v_i16()->Get(0), 1); + EXPECT_EQ(new_table->v_i16()->Get(1), 2); + EXPECT_EQ(new_table->v_i16()->Get(2), 3); + ASSERT_THAT(new_table->v_i32(), NotNull()); + EXPECT_EQ(new_table->v_i32()->size(), 3); + EXPECT_EQ(new_table->v_i32()->Get(0), 1); + EXPECT_EQ(new_table->v_i32()->Get(1), 2); + EXPECT_EQ(new_table->v_i32()->Get(2), 3); + ASSERT_THAT(new_table->v_i64(), NotNull()); + EXPECT_EQ(new_table->v_i64()->size(), 3); + EXPECT_EQ(new_table->v_i64()->Get(0), 1); + EXPECT_EQ(new_table->v_i64()->Get(1), 2); + EXPECT_EQ(new_table->v_i64()->Get(2), 3); + ASSERT_THAT(new_table->v_u8(), NotNull()); + EXPECT_EQ(new_table->v_u8()->size(), 3); + EXPECT_EQ(new_table->v_u8()->Get(0), 1); + EXPECT_EQ(new_table->v_u8()->Get(1), 2); + EXPECT_EQ(new_table->v_u8()->Get(2), 3); + ASSERT_THAT(new_table->v_u16(), NotNull()); + EXPECT_EQ(new_table->v_u16()->size(), 3); + EXPECT_EQ(new_table->v_u16()->Get(0), 1); + EXPECT_EQ(new_table->v_u16()->Get(1), 2); + EXPECT_EQ(new_table->v_u16()->Get(2), 3); + ASSERT_THAT(new_table->v_u32(), NotNull()); + EXPECT_EQ(new_table->v_u32()->size(), 3); + EXPECT_EQ(new_table->v_u32()->Get(0), 1); + EXPECT_EQ(new_table->v_u32()->Get(1), 2); + EXPECT_EQ(new_table->v_u32()->Get(2), 3); + ASSERT_THAT(new_table->v_u64(), NotNull()); + EXPECT_EQ(new_table->v_u64()->size(), 3); + EXPECT_EQ(new_table->v_u64()->Get(0), 1); + EXPECT_EQ(new_table->v_u64()->Get(1), 2); + EXPECT_EQ(new_table->v_u64()->Get(2), 3); + ASSERT_THAT(new_table->v_f(), NotNull()); + EXPECT_EQ(new_table->v_f()->size(), 3); + EXPECT_EQ(new_table->v_f()->Get(0), 1); + EXPECT_EQ(new_table->v_f()->Get(1), 2); + EXPECT_EQ(new_table->v_f()->Get(2), 3); + ASSERT_THAT(new_table->v_d(), NotNull()); + EXPECT_EQ(new_table->v_d()->size(), 3); + EXPECT_EQ(new_table->v_d()->Get(0), 1); + EXPECT_EQ(new_table->v_d()->Get(1), 2); + EXPECT_EQ(new_table->v_d()->Get(2), 3); + EXPECT_EQ(new_table->v_str()->size(), 3); + EXPECT_EQ(new_table->v_str()->Get(0)->str(), "foo"); + EXPECT_EQ(new_table->v_str()->Get(1)->str(), "bar"); + EXPECT_EQ(new_table->v_str()->Get(2)->str(), "baz"); + ASSERT_THAT(new_table->v_e(), NotNull()); + EXPECT_EQ(new_table->v_e()->size(), 3); + EXPECT_EQ(new_table->v_e()->Get(0), internal::Enum_First); + EXPECT_EQ(new_table->v_e()->Get(1), internal::Enum_Second); + EXPECT_EQ(new_table->v_e()->Get(2), internal::Enum_Third); + ASSERT_THAT(new_table->v_t(), NotNull()); + EXPECT_EQ(new_table->v_t()->size(), 1); + ASSERT_THAT(new_table->v_t()->Get(0), NotNull()); + EXPECT_EQ(new_table->v_t()->Get(0)->b(), true); + ASSERT_THAT(new_table->v_u_type(), NotNull()); + EXPECT_EQ(new_table->v_u_type()->size(), 2); + EXPECT_EQ(new_table->v_u_type()->Get(0), internal::Union_BoolTable); + EXPECT_EQ(new_table->v_u_type()->Get(1), internal::Union_StringTable); + ASSERT_THAT(new_table->v_u(), NotNull()); + EXPECT_EQ(new_table->v_u()->size(), 2); + auto v_u_0 = + static_cast(new_table->v_u()->Get(0)); + ASSERT_THAT(v_u_0, NotNull()); + EXPECT_EQ(v_u_0->b(), true); + auto v_u_1 = + static_cast(new_table->v_u()->Get(1)); + ASSERT_THAT(v_u_1, NotNull()); + ASSERT_THAT(v_u_1->str(), NotNull()); + EXPECT_EQ(v_u_1->str()->str(), "foo bar baz"); + ASSERT_THAT(new_table->v_s(), NotNull()); + EXPECT_EQ(new_table->v_s()->size(), 1); + ASSERT_THAT(new_table->v_s()->Get(0), NotNull()); + EXPECT_EQ(new_table->v_s()->Get(0)->b(), true); + EXPECT_EQ(new_table->v_s()->Get(0)->i8(), 1); + EXPECT_EQ(new_table->v_s()->Get(0)->i16(), 2); + EXPECT_EQ(new_table->v_s()->Get(0)->i32(), 3); + EXPECT_EQ(new_table->v_s()->Get(0)->i64(), 4); + EXPECT_EQ(new_table->v_s()->Get(0)->u8(), 5); + EXPECT_EQ(new_table->v_s()->Get(0)->u16(), 6); + EXPECT_EQ(new_table->v_s()->Get(0)->u32(), 7); + EXPECT_EQ(new_table->v_s()->Get(0)->u64(), 8); + EXPECT_EQ(new_table->v_s()->Get(0)->f(), 9.0); + EXPECT_EQ(new_table->v_s()->Get(0)->d(), 10.0); + EXPECT_EQ(new_table->v_s()->Get(0)->e(), internal::Enum_Second); + EXPECT_EQ(new_table->v_s()->Get(0)->s().b(), true); +} + +TEST(FlatbuffersTableDomainImplTest, InitGeneratesSeeds) { + flatbuffers::FlatBufferBuilder fbb; + auto table = CreateDefaultTable(fbb); + + auto domain = Arbitrary(); + domain.WithSeeds({table}); + + std::vector> values; + absl::BitGen bitgen; + values.reserve(1000); + for (int i = 0; i < 1000; ++i) { + Value value(domain, bitgen); + values.push_back(std::move(value)); + } + + EXPECT_THAT( + values, + Contains(ResultOf( + [table](const auto& val) { + return (Eq(val.user_value->b(), table->b()) && + Eq(val.user_value->i8(), table->i8()) && + Eq(val.user_value->i16(), table->i16()) && + Eq(val.user_value->i32(), table->i32()) && + Eq(val.user_value->i64(), table->i64()) && + Eq(val.user_value->u8(), table->u8()) && + Eq(val.user_value->u16(), table->u16()) && + Eq(val.user_value->u32(), table->u32()) && + Eq(val.user_value->u64(), table->u64()) && + Eq(val.user_value->f(), table->f()) && + Eq(val.user_value->d(), table->d()) && + Eq(val.user_value->f(), table->f()) && + Eq(val.user_value->e(), table->e()) && + Eq(val.user_value->str(), table->str()) && + Eq(val.user_value->t(), table->t()) && + Eq(std::make_pair( + static_cast(val.user_value->u_type()), + val.user_value->u()), + std::make_pair(static_cast(table->u_type()), + table->u())) && + Eq(val.user_value->s(), table->s()) && + VectorEq(val.user_value->v_b(), table->v_b()) && + VectorEq(val.user_value->v_i8(), table->v_i8()) && + VectorEq(val.user_value->v_i16(), table->v_i16()) && + VectorEq(val.user_value->v_i32(), table->v_i32()) && + VectorEq(val.user_value->v_i64(), table->v_i64()) && + VectorEq(val.user_value->v_u8(), table->v_u8()) && + VectorEq(val.user_value->v_u16(), table->v_u16()) && + VectorEq(val.user_value->v_u32(), table->v_u32()) && + VectorEq(val.user_value->v_u64(), table->v_u64()) && + VectorEq(val.user_value->v_f(), table->v_f()) && + VectorEq(val.user_value->v_d(), table->v_d()) && + VectorEq(val.user_value->v_str(), table->v_str()) && + VectorEq(val.user_value->v_e(), table->v_e()) && + VectorEq(val.user_value->v_t(), table->v_t()) && + VectorUnionEq(val.user_value->v_u_type(), + val.user_value->v_u(), table->v_u_type(), + table->v_u()) && + VectorEq(val.user_value->v_s(), table->v_s())); + }, + IsTrue()))); +} + +TEST(FlatbuffersTableDomainImplTest, EventuallyMutatesAllTableFields) { + absl::flat_hash_map mutated_fields{ + {"b", false}, {"i8", false}, {"i16", false}, + {"i32", false}, {"i64", false}, {"u8", false}, + {"u16", false}, {"u32", false}, {"u64", false}, + {"f", false}, {"d", false}, {"str", false}, + {"e", false}, {"t", false}, {"u_type", false}, + {"u", false}, {"s", false}, {"t.v_b", false}, + {"t.v_i8", false}, {"t.v_i16", false}, {"t.v_i32", false}, + {"t.v_i64", false}, {"t.v_u8", false}, {"t.v_u16", false}, + {"t.v_u32", false}, {"t.v_u64", false}, {"t.v_f", false}, + {"t.v_d", false}, {"t.v_e", false}, {"t.v_str", false}, + {"t.v_t", false}, {"t.v_u_type", false}, {"t.v_u", false}, + {"t.v_s", false}, + }; + + auto domain = Arbitrary(); + + absl::BitGen bitgen; + Value initial_val(domain, bitgen); + Value val(initial_val); + + for (size_t i = 0; i < 10'000; ++i) { + val.Mutate(domain, bitgen, {}, false); + const auto& mut = val.user_value; + const auto& init = initial_val.user_value; + + mutated_fields["b"] |= !Eq(mut->b(), init->b()); + mutated_fields["i8"] |= !Eq(mut->i8(), init->i8()); + mutated_fields["i16"] |= !Eq(mut->i16(), init->i16()); + mutated_fields["i32"] |= !Eq(mut->i32(), init->i32()); + mutated_fields["i64"] |= !Eq(mut->i64(), init->i64()); + mutated_fields["u8"] |= !Eq(mut->u8(), init->u8()); + mutated_fields["u16"] |= !Eq(mut->u16(), init->u16()); + mutated_fields["u32"] |= !Eq(mut->u32(), init->u32()); + mutated_fields["u64"] |= !Eq(mut->u64(), init->u64()); + mutated_fields["f"] |= !Eq(mut->f(), init->f()); + mutated_fields["d"] |= !Eq(mut->d(), init->d()); + mutated_fields["str"] |= Eq(mut->str(), init->str()); + mutated_fields["e"] |= !Eq(mut->e(), init->e()); + mutated_fields["t"] |= Eq(mut->t(), init->t()); + mutated_fields["u_type"] |= Eq(mut->u_type(), init->u_type()); + mutated_fields["u"] |= + !Eq(std::make_pair(static_cast(mut->u_type()), mut->u()), + std::make_pair(static_cast(init->u_type()), init->u())); + mutated_fields["s"] |= !Eq(mut->s(), init->s()); + mutated_fields["t.v_b"] |= !VectorEq(mut->v_b(), init->v_b()); + mutated_fields["t.v_i8"] |= !VectorEq(mut->v_i8(), init->v_i8()); + mutated_fields["t.v_i16"] |= !VectorEq(mut->v_i16(), init->v_i16()); + mutated_fields["t.v_i32"] |= !VectorEq(mut->v_i32(), init->v_i32()); + mutated_fields["t.v_i64"] |= !VectorEq(mut->v_i64(), init->v_i64()); + mutated_fields["t.v_u8"] |= !VectorEq(mut->v_u8(), init->v_u8()); + mutated_fields["t.v_u16"] |= !VectorEq(mut->v_u16(), init->v_u16()); + mutated_fields["t.v_u32"] |= !VectorEq(mut->v_u32(), init->v_u32()); + mutated_fields["t.v_u64"] |= !VectorEq(mut->v_u64(), init->v_u64()); + mutated_fields["t.v_f"] |= !VectorEq(mut->v_f(), init->v_f()); + mutated_fields["t.v_d"] |= !VectorEq(mut->v_d(), init->v_d()); + mutated_fields["t.v_e"] |= !VectorEq(mut->v_e(), init->v_e()); + mutated_fields["t.v_str"] |= !VectorEq(mut->v_str(), init->v_str()); + mutated_fields["t.v_t"] |= !VectorEq(mut->v_str(), init->v_str()); + mutated_fields["t.v_u_type"] |= + !VectorEq(mut->v_u_type(), init->v_u_type()); + mutated_fields["t.v_u"] |= !VectorUnionEq(mut->v_u_type(), mut->v_u(), + init->v_u_type(), init->v_u()); + mutated_fields["t.v_s"] |= !VectorEq(mut->v_s(), init->v_s()); + + bool all_mutated = true; + for (const auto& [name, mutated] : mutated_fields) { + all_mutated &= mutated; + if (!mutated) { + break; + } + } + if (all_mutated) { + break; + } + } + + EXPECT_THAT(mutated_fields, Each(Pair(_, true))); +} + +TEST(FlatbuffersTableDomainImplTest, OptionalTableEventuallyBecomeEmpty) { + flatbuffers::FlatBufferBuilder fbb; + auto bool_table_offset = internal::CreateBoolTable(fbb, true); + DefaultStruct s; + std::vector v_b{true, false}; + std::vector v_i8{}; + std::vector v_i16{}; + std::vector v_i32{}; + std::vector v_i64{}; + std::vector v_u8{}; + std::vector v_u16{}; + std::vector v_u32{}; + std::vector v_u64{}; + std::vector v_f{}; + std::vector v_d{}; + std::vector> v_str{ + fbb.CreateString(""), fbb.CreateString(""), fbb.CreateString("")}; + std::vector> v_e{}; + std::vector> v_t{}; + std::vector> v_u_type{}; + std::vector> v_u{}; + std::vector v_s{}; + auto table_offset = + internal::CreateOptionalTableDirect(fbb, + true, // b + 1, // i8 + 2, // i16 + 3, // i32 + 4, // i64 + 5, // u8 + 6, // u16 + 7, // u32 + 8, // u64 + 9.0, // f + 10.0, // d + "foo bar baz", // str + internal::Enum_Second, // e + bool_table_offset, // t + internal::Union_BoolTable, // u_type + bool_table_offset.Union(), // u + &s, + &v_b, // v_b + &v_i8, // v_i8 + &v_i16, // v_i16 + &v_i32, // v_i32 + &v_i64, // v_i64 + &v_u8, // v_u8 + &v_u16, // v_u16 + &v_u32, // v_u32 + &v_u64, // v_u64 + &v_f, // v_f + &v_d, // v_d + &v_str, // v_str + &v_e, // v_e + &v_t, // v_t + &v_u_type, // v_u_type + &v_u, // v_u + &v_s // v_s + ); + fbb.Finish(table_offset); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + + auto domain = Arbitrary(); + Value val(domain, table); + absl::BitGen bitgen; + + absl::flat_hash_map null_fields{ + {"b", false}, {"i8", false}, {"i16", false}, + {"i32", false}, {"i64", false}, {"u8", false}, + {"u16", false}, {"u32", false}, {"u64", false}, + {"f", false}, {"d", false}, {"str", false}, + {"e", false}, {"t", false}, {"u_type", false}, + {"u", false}, {"s", false}, {"t.v_b", false}, + {"t.v_i8", false}, {"t.v_i16", false}, {"t.v_i32", false}, + {"t.v_i64", false}, {"t.v_u8", false}, {"t.v_u16", false}, + {"t.v_u32", false}, {"t.v_u64", false}, {"t.v_f", false}, + {"t.v_d", false}, {"t.v_e", false}, {"t.v_str", false}, + {"t.v_t", false}, {"t.v_u_type", false}, {"t.v_u", false}, + {"t.v_s", false}, + }; + + for (size_t i = 0; i < 1'000'000; ++i) { + val.Mutate(domain, bitgen, {}, true); + const auto& v = val.user_value; + + null_fields["b"] |= !v->b().has_value(); + null_fields["i8"] |= !v->i8().has_value(); + null_fields["i16"] |= !v->i16().has_value(); + null_fields["i32"] |= !v->i32().has_value(); + null_fields["i64"] |= !v->i64().has_value(); + null_fields["u8"] |= !v->u8().has_value(); + null_fields["u16"] |= !v->u16().has_value(); + null_fields["u32"] |= !v->u32().has_value(); + null_fields["u64"] |= !v->u64().has_value(); + null_fields["f"] |= !v->f().has_value(); + null_fields["d"] |= !v->d().has_value(); + null_fields["str"] |= v->str() == nullptr; + null_fields["e"] |= !v->e().has_value(); + null_fields["t"] |= v->t() == nullptr; + null_fields["u_type"] |= v->u_type() == internal::Union_NONE; + null_fields["u"] |= v->u() == nullptr; + null_fields["s"] |= v->s() == nullptr; + null_fields["t.v_b"] |= v->v_b() == nullptr; + null_fields["t.v_i8"] |= v->v_i8() == nullptr; + null_fields["t.v_i16"] |= v->v_i16() == nullptr; + null_fields["t.v_i32"] |= v->v_i32() == nullptr; + null_fields["t.v_i64"] |= v->v_i64() == nullptr; + null_fields["t.v_u8"] |= v->v_u8() == nullptr; + null_fields["t.v_u16"] |= v->v_u16() == nullptr; + null_fields["t.v_u32"] |= v->v_u32() == nullptr; + null_fields["t.v_u64"] |= v->v_u64() == nullptr; + null_fields["t.v_f"] |= v->v_f() == nullptr; + null_fields["t.v_d"] |= v->v_d() == nullptr; + null_fields["t.v_e"] |= v->v_e() == nullptr; + null_fields["t.v_str"] |= v->v_str() == nullptr; + null_fields["t.v_t"] |= v->v_t() == nullptr; + null_fields["t.v_u_type"] |= v->v_u_type() == nullptr; + null_fields["t.v_u"] |= v->v_u() == nullptr; + null_fields["t.v_s"] |= v->v_s() == nullptr; + + bool all_null = true; + for (const auto& [name, is_null] : null_fields) { + all_null &= is_null; + if (!is_null) { + break; + } + } + if (all_null) { + break; + } + } + + EXPECT_THAT(null_fields, Each(Pair(_, true))); +} + +TEST(FlatbuffersTableDomainImplTest, RequiredTableFieldsAlwaysSet) { + flatbuffers::FlatBufferBuilder fbb; + auto bool_table_offset = internal::CreateBoolTable(fbb, true); + auto string_table_offset = + internal::CreateStringTableDirect(fbb, "foo bar baz"); + DefaultStruct s{ + true, // b + 1, // i8 + 2, // i16 + 3, // i32 + 4, // i64 + 5, // u8 + 6, // u16 + 7, // u32 + 8, // u64 + 9, // f + 10.0, // d + internal::Enum_Second, // e + BoolStruct{true} // s + }; + std::vector v_b{true, false}; + std::vector v_i8{1, 2, 3}; + std::vector v_i16{1, 2, 3}; + std::vector v_i32{1, 2, 3}; + std::vector v_i64{1, 2, 3}; + std::vector v_u8{1, 2, 3}; + std::vector v_u16{1, 2, 3}; + std::vector v_u32{1, 2, 3}; + std::vector v_u64{1, 2, 3}; + std::vector v_f{1.0, 2.0, 3.0}; + std::vector v_d{1.0, 2.0, 3.0}; + std::vector> v_str{ + fbb.CreateString("foo"), fbb.CreateString("bar"), + fbb.CreateString("baz")}; + std::vector> v_e{ + internal::Enum_First, + internal::Enum_Second, + internal::Enum_Third, + }; + std::vector> v_t{bool_table_offset}; + std::vector> v_u_type{ + internal::Union_BoolTable, internal::Union_StringTable}; + std::vector> v_u{bool_table_offset.Union(), + string_table_offset.Union()}; + std::vector v_s{s}; + auto table_offset = + internal::CreateRequiredTableDirect(fbb, + "foo bar baz", // str + bool_table_offset, // t + internal::Union_BoolTable, // u_type + bool_table_offset.Union(), // u + &s, // s + &v_b, // v_b + &v_i8, // v_i8 + &v_i16, // v_i16 + &v_i32, // v_i32 + &v_i64, // v_i64 + &v_u8, // v_u8 + &v_u16, // v_u16 + &v_u32, // v_u32 + &v_u64, // v_u64 + &v_f, // v_f + &v_d, // v_d + &v_str, // v_str + &v_e, // v_e + &v_t, // v_t + &v_u_type, // v_u_type + &v_u, // v_u + &v_s // v_s + ); + fbb.Finish(table_offset); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + + auto domain = Arbitrary(); + Value val(domain, table); + absl::BitGen bitgen; + + absl::flat_hash_map set_fields{ + {"str", false}, {"t", false}, {"u_type", false}, + {"u", false}, {"s", false}, {"t.v_b", false}, + {"t.v_i8", false}, {"t.v_i16", false}, {"t.v_i32", false}, + {"t.v_i64", false}, {"t.v_u8", false}, {"t.v_u16", false}, + {"t.v_u32", false}, {"t.v_u64", false}, {"t.v_f", false}, + {"t.v_d", false}, {"t.v_e", false}, {"t.v_str", false}, + {"t.v_t", false}, {"t.v_u_type", false}, {"t.v_u", false}, + {"t.v_s", false}, + }; + + for (size_t i = 0; i < 10'000; ++i) { + val.Mutate(domain, bitgen, {}, true); + const auto& v = val.user_value; + + set_fields["str"] |= v->str() != nullptr; + set_fields["t"] |= v->t() != nullptr; + set_fields["u_type"] |= v->u_type() != internal::Union_NONE; + set_fields["u"] |= v->u() != nullptr; + set_fields["s"] |= v->s() != nullptr; + set_fields["t.v_b"] |= v->v_b() != nullptr; + set_fields["t.v_i8"] |= v->v_i8() != nullptr; + set_fields["t.v_i16"] |= v->v_i16() != nullptr; + set_fields["t.v_i32"] |= v->v_i32() != nullptr; + set_fields["t.v_i64"] |= v->v_i64() != nullptr; + set_fields["t.v_u8"] |= v->v_u8() != nullptr; + set_fields["t.v_u16"] |= v->v_u16() != nullptr; + set_fields["t.v_u32"] |= v->v_u32() != nullptr; + set_fields["t.v_u64"] |= v->v_u64() != nullptr; + set_fields["t.v_f"] |= v->v_f() != nullptr; + set_fields["t.v_d"] |= v->v_d() != nullptr; + set_fields["t.v_e"] |= v->v_e() != nullptr; + set_fields["t.v_str"] |= v->v_str() != nullptr; + set_fields["t.v_t"] |= v->v_t() != nullptr; + set_fields["t.v_u_type"] |= v->v_u_type() != nullptr; + set_fields["t.v_u"] |= v->v_u() != nullptr; + set_fields["t.v_s"] |= v->v_s() != nullptr; + + bool all_set = true; + for (const auto& [name, is_set] : set_fields) { + all_set &= is_set; + if (!is_set) { + break; + } + } + if (all_set) { + break; + } + } + + EXPECT_THAT(set_fields, Each(Pair(_, true))); +} + +TEST(FlatbuffersTableDomainImplTest, CountNumberOfFieldsWithNull) { + flatbuffers::FlatBufferBuilder fbb; + auto table_offset = + internal::CreateDefaultTableDirect(fbb, + true, // b + 1, // i8 + 2, // i16 + 3, // i32 + 4, // i64 + 5, // u8 + 6, // u16 + 7, // u32 + 8, // u64 + 9.0, // f + 10.0, // d + "foo bar baz", // str + internal::Enum_Second // e + ); + fbb.Finish(table_offset); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + + auto domain = Arbitrary(); + auto corpus = domain.FromValue(table); + ASSERT_TRUE(corpus.has_value()); + EXPECT_EQ(domain.CountNumberOfFields(corpus.value()), 32); +} + +TEST(FlatbuffersUnionDomainImpl, ParseCorpusRejectsInvalidValues) { + auto domain = Arbitrary(); + { + flatbuffers::FlatBufferBuilder fbb; + internal::CreateUnionTable(fbb, internal::Union_BoolTable, 0); + fbb.Finish(internal::CreateUnionTable(fbb, internal::Union_BoolTable, 0)); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize()); + ASSERT_TRUE(verifier.VerifyBuffer()); + + auto corpus = domain.FromValue(table); + ASSERT_TRUE(corpus.has_value()); + EXPECT_FALSE(domain.ValidateCorpusValue(corpus.value()).ok()); + } + { + internal::IRObject ir_object; + auto& subs = ir_object.MutableSubs(); + subs.reserve(2); + + auto& u_obj = subs.emplace_back(); + auto& u_subs = u_obj.MutableSubs(); + u_subs.reserve(2); + u_subs.emplace_back(1); // id + auto& u_opt_value = u_subs.emplace_back(); // value + auto& u_opt_value_subs = u_opt_value.MutableSubs(); + u_opt_value_subs.reserve(2); + u_opt_value_subs.emplace_back(1); // has value + auto& u_inner_value = u_opt_value_subs.emplace_back(); + + u_inner_value.MutableSubs().reserve(2); + u_inner_value.MutableSubs().emplace_back(-1); // type (invalid) + u_inner_value.MutableSubs().emplace_back(); // value + + auto corpus = domain.ParseCorpus(ir_object); + ASSERT_FALSE(corpus.has_value()); + } + { + internal::IRObject ir_object; + auto& subs = ir_object.MutableSubs(); + subs.reserve(2); + + auto& u_obj = subs.emplace_back(); + auto& u_subs = u_obj.MutableSubs(); + u_subs.reserve(2); + u_subs.emplace_back(1); // id + auto& u_opt_value = u_subs.emplace_back(); // value + auto& u_opt_value_subs = u_opt_value.MutableSubs(); + u_opt_value_subs.reserve(2); + u_opt_value_subs.emplace_back(1); // has value + auto& u_inner_value = u_opt_value_subs.emplace_back(); + + u_inner_value.MutableSubs().reserve(2); + u_inner_value.MutableSubs().emplace_back( + internal::Union_BoolTable); // type + auto& bool_table = u_inner_value.MutableSubs().emplace_back(); // value + auto& bool_table_subs = bool_table.MutableSubs(); + bool_table_subs.reserve(2); + bool_table_subs.emplace_back(200); // id (invalid) + u_subs.emplace_back(); // value + + auto corpus = domain.ParseCorpus(ir_object); + ASSERT_FALSE(corpus.has_value()); + } +} + +} // namespace +} // namespace fuzztest diff --git a/fuzztest/BUILD b/fuzztest/BUILD index 6a7f41492..c8693378b 100644 --- a/fuzztest/BUILD +++ b/fuzztest/BUILD @@ -425,6 +425,34 @@ cc_library( ], ) +cc_library( + name = "flatbuffers", + srcs = [ + "internal/domains/flatbuffers_domain_impl.cc", + "internal/domains/flatbuffers_domain_impl.h", + ], + hdrs = ["flatbuffers.h"], + deps = [ + ":any", + ":domain_core", + ":logging", + ":meta", + ":serialization", + ":status", + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/base:core_headers", + "@abseil-cpp//absl/base:nullability", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/random:bit_gen_ref", + "@abseil-cpp//absl/random:distributions", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:str_format", + "@abseil-cpp//absl/synchronization", + "@flatbuffers//:runtime_cc", + ], +) + cc_library( name = "fixture_driver", srcs = ["internal/fixture_driver.cc"], @@ -804,6 +832,28 @@ cc_proto_library( deps = [":test_protobuf"], ) +# Derived from @flatbuffers//build_defs.bzl:flatbuffer_cc_library but allows output prefix for +# single source target and to have embedded schema file in the outputs. +genrule( + name = "test_flatbuffers_fbs", + srcs = ["internal/test_flatbuffers.fbs"], + outs = [ + "internal/test_flatbuffers_bfbs_generated.h", + "internal/test_flatbuffers_generated.h", + ], + cmd = "$(location @flatbuffers//:flatc) -c -o $(@D)/internal --bfbs-gen-embed --gen-name-strings $(SRCS)", + message = "Generating flatbuffer files for test_flatbuffers_fbs", + tools = ["@flatbuffers//:flatc"], +) + +cc_library( + name = "test_flatbuffers_cc_fbs", + srcs = [":test_flatbuffers_fbs"], + hdrs = [":test_flatbuffers_fbs"], + features = ["-parse_headers"], + deps = ["@flatbuffers//:runtime_cc"], +) + cc_library( name = "type_support", srcs = ["internal/type_support.cc"], diff --git a/fuzztest/CMakeLists.txt b/fuzztest/CMakeLists.txt index 8288b3168..64cd8692d 100644 --- a/fuzztest/CMakeLists.txt +++ b/fuzztest/CMakeLists.txt @@ -56,6 +56,40 @@ fuzztest_cc_library( fuzztest::fuzztest_macros ) +if (FUZZTEST_BUILD_FLATBUFFERS) + fuzztest_cc_library( + NAME + flatbuffers + HDRS + "flatbuffers.h" + "internal/domains/flatbuffers_domain_impl.h" + SRCS + "internal/domains/flatbuffers_domain_impl.cc" + DEPS + absl::algorithm_container + absl::core_headers + absl::flat_hash_map + absl::flat_hash_set + absl::nullability + absl::random_bit_gen_ref + absl::random_distributions + absl::random_random + absl::status + absl::statusor + absl::str_format + absl::strings + absl::synchronization + flatbuffers + fuzztest::any + fuzztest::domain_core + fuzztest::logging + fuzztest::meta + fuzztest::serialization + fuzztest::status + fuzztest::type_support + ) +endif() + fuzztest_cc_library( NAME fuzztest_macros @@ -805,6 +839,70 @@ if (FUZZTEST_BUILD_TESTING) "${CMAKE_CURRENT_BINARY_DIR}/.." ) + if (FUZZTEST_BUILD_FLATBUFFERS) + # Generate test flatbuffers + include_directories(${FLATBUFFERS_INCLUDE_DIR}) + set(FBS_SCHEMA_FILE "${CMAKE_CURRENT_LIST_DIR}/internal/test_flatbuffers.fbs") + set(FLATC_FLAGS "--bfbs-gen-embed" "--gen-name-strings") + + # Modified version of `flatbuffers_generate_headers` + # from https://github.com/google/flatbuffers/blob/master/CMake/BuildFlatBuffers.cmake + # Supports using an output prefix for single file header generation as well + # as the embedded schema header in the output set. + add_custom_command( + OUTPUT + "internal/test_flatbuffers_bfbs_generated.h" + "internal/test_flatbuffers_generated.h" + COMMAND + $ + -o "${CMAKE_CURRENT_BINARY_DIR}/internal" + -c + ${FBS_SCHEMA_FILE} + ${FLATC_FLAGS} + DEPENDS + flatc + ${FBS_SCHEMA_FILE} + WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" + COMMENT "Building ${FBS_SCHEMA_FILE} flatbuffers..." + ) + + # Create an additional target as add_custom_command scope is only within + # same directory (CMakeFile.txt) + add_custom_target( + GENERATE_test_flatbuffers ALL + DEPENDS + "internal/test_flatbuffers_bfbs_generated.h" + "internal/test_flatbuffers_generated.h" + COMMENT "Generating flatbuffer target test_flatbuffers" + ) + + # Set up interface library + add_library(test_flatbuffers INTERFACE) + add_dependencies( + test_flatbuffers + flatc + ${FBS_SCHEMA_FILE} + ) + target_include_directories( + test_flatbuffers + INTERFACE "${CMAKE_CURRENT_BINARY_DIR}/internal" + ) + + # Organize file layout for IDEs. + source_group( + TREE "${CMAKE_CURRENT_BINARY_DIR}/internal" + PREFIX "Flatbuffers/Generated/Headers Files" + FILES + "${CMAKE_CURRENT_BINARY_DIR}/internal/test_flatbuffers_bfbs_generated.h" + "${CMAKE_CURRENT_BINARY_DIR}/internal/test_flatbuffers_generated.h" + ) + source_group( + TREE "${CMAKE_CURRENT_SOURCE_DIR}/internal" + PREFIX "Flatbuffers/Schemas" + FILES ${FBS_SCHEMA_FILE} + ) + endif() + endif () fuzztest_cc_library( diff --git a/fuzztest/flatbuffers.h b/fuzztest/flatbuffers.h new file mode 100644 index 000000000..b70ed361b --- /dev/null +++ b/fuzztest/flatbuffers.h @@ -0,0 +1,19 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FUZZTEST_FUZZTEST_FLATBUFFERS_H_ +#define FUZZTEST_FUZZTEST_FLATBUFFERS_H_ + +#include "./fuzztest/internal/domains/flatbuffers_domain_impl.h" // IWYU pragma: export +#endif // FUZZTEST_FUZZTEST_FLATBUFFERS_H_ diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.cc b/fuzztest/internal/domains/flatbuffers_domain_impl.cc new file mode 100644 index 000000000..ab6b76341 --- /dev/null +++ b/fuzztest/internal/domains/flatbuffers_domain_impl.cc @@ -0,0 +1,555 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "./fuzztest/internal/domains/flatbuffers_domain_impl.h" + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/distributions.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "flatbuffers/base.h" +#include "flatbuffers/flatbuffer_builder.h" +#include "flatbuffers/reflection_generated.h" +#include "flatbuffers/struct.h" +#include "flatbuffers/table.h" +#include "./fuzztest/domain_core.h" +#include "./fuzztest/internal/any.h" +#include "./fuzztest/internal/domains/domain_base.h" +#include "./fuzztest/internal/domains/domain_type_erasure.h" +#include "./fuzztest/internal/meta.h" +#include "./fuzztest/internal/serialization.h" + +namespace fuzztest { +namespace internal { + +// Gets a domain for a specific struct type. +template <> +auto FlatbuffersUnionDomainImpl::GetDomainForType( + const reflection::EnumVal& enum_value) const { + const reflection::Object* object = + schema_->objects()->Get(enum_value.union_type()->index()); + return Domain( + FlatbuffersStructUntypedDomainImpl{schema_, object}); +} + +// Gets a domain for a specific table type. +template <> +auto FlatbuffersUnionDomainImpl::GetDomainForType( + const reflection::EnumVal& enum_value) const { + const reflection::Object* object = + schema_->objects()->Get(enum_value.union_type()->index()); + return Domain( + FlatbuffersTableUntypedDomainImpl{schema_, object}); +} + +FlatbuffersUnionDomainImpl::corpus_type FlatbuffersUnionDomainImpl::Init( + absl::BitGenRef prng) { + if (auto seed = this->MaybeGetRandomSeed(prng)) { + return *seed; + } + + // Unions are encoded as the combination of two fields: an enum representing + // the union choice and the offset to the actual element. + // + // The following code follows that logic. + corpus_type val; + + // Prepare `union_choice`. + auto selected_type_enumval_index = + absl::Uniform(prng, 0ul, union_def_->values()->size()); + auto type_enumval = union_def_->values()->Get(selected_type_enumval_index); + if (type_enumval == nullptr) { + return val; + } + auto type_value = type_domain_.FromValue(type_enumval->value()); + if (!type_value.has_value()) { + return val; + } + val.first = *type_value; + + // FlatBuffers reserves the enumeration constant NONE (encoded as 0) to mean + // that the union field is not set. + if (type_enumval->value() == 0 /* NONE */) { + return val; + } + + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto inner_val = + GetSubDomain(*type_enumval).Init(prng); + val.second = std::move(inner_val); + } else { + auto inner_val = + GetSubDomain(*type_enumval).Init(prng); + val.second = std::move(inner_val); + } + return val; +} + +// Mutates the corpus value. +void FlatbuffersUnionDomainImpl::Mutate( + corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, bool only_shrink) { + auto total_weight = CountNumberOfFields(val); + auto selected_weight = absl::Uniform(prng, 0ul, total_weight); + if (selected_weight == 0) { + // Mutate both type and value. + + // Deal with the type. + type_domain_.Mutate(val.first, prng, metadata, only_shrink); + val.second = GenericDomainCorpusType(std::in_place_type, nullptr); + auto type_value = type_domain_.GetValue(val.first); + if (type_value == 0 /* NONE */) { + // NONE is a special value, it means that the union is not set. + return; + } + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return; + } + + // Deal with the value. + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto inner_val = + GetSubDomain(*type_enumval).Init(prng); + val.second = std::move(inner_val); + } else { + auto inner_val = + GetSubDomain(*type_enumval).Init(prng); + val.second = std::move(inner_val); + } + } else { + // Keep the type, mutate the value. + auto type_value = type_domain_.GetValue(val.first); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return; + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto domain = GetSubDomain(*type_enumval); + domain.MutateSelectedField(val.second, prng, metadata, only_shrink, + selected_weight - 1); + } else { + auto domain = GetSubDomain(*type_enumval); + domain.MutateSelectedField(val.second, prng, metadata, only_shrink, + selected_weight - 1); + } + } +} + +uint64_t FlatbuffersUnionDomainImpl::CountNumberOfFields(corpus_type& val) { + // Unions are encoded as the combination of two fields: an enum representing + // the union choice and the offset to the actual element. + // + // In turn, count starts with 1 to take care of the first field. + uint64_t count = 1; + auto type_value = type_domain_.GetValue(val.first); + if (type_value == 0 /* NONE */) { + // Union field is not set. + return count; + } + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return count; + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto domain = GetSubDomain(*type_enumval); + count += domain.CountNumberOfFields(val.second); + } else { + auto domain = GetSubDomain(*type_enumval); + count += domain.CountNumberOfFields(val.second); + } + return count; +} + +absl::Status FlatbuffersUnionDomainImpl::ValidateCorpusValue( + const corpus_type& corpus_value) const { + // Unions are encoded as the combination of two fields: an enum representing + // the union choice and the offset to the actual element. + // + // Both type and value should be validated. + // + // Start with the type validation. + auto type_value = type_domain_.GetValue(corpus_value.first); + if (type_value == 0 /* NONE */) { + // Union field is not set. + return absl::OkStatus(); + } + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid union type: ", type_value)); + } + + // Validate the value. + if (!corpus_value.second.has_value()) { + return absl::InvalidArgumentError("Union value is not set."); + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto domain = GetSubDomain(*type_enumval); + return domain.ValidateCorpusValue(corpus_value.second); + } else { + auto domain = GetSubDomain(*type_enumval); + return domain.ValidateCorpusValue(corpus_value.second); + } +} + +// Converts the value to a corpus value. +std::optional +FlatbuffersUnionDomainImpl::FromValue(const value_type& value) const { + auto out = std::make_optional(); + auto type_value = type_domain_.FromValue(value.first); + if (type_value.has_value()) { + out->first = *type_value; + } + auto type_enumval = union_def_->values()->LookupByKey(value.first); + if (type_enumval == nullptr) { + return std::nullopt; + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + std::optional inner_corpus; + if (object->is_struct()) { + auto domain = GetSubDomain(*type_enumval); + inner_corpus = + domain.FromValue(static_cast(value.second)); + } else { + auto domain = GetSubDomain(*type_enumval); + inner_corpus = + domain.FromValue(static_cast(value.second)); + } + if (inner_corpus.has_value()) { + out->second = std::move(inner_corpus.value()); + } + return out; +} + +// Converts the IRObject to a corpus value. +std::optional +FlatbuffersUnionDomainImpl::ParseCorpus(const IRObject& obj) const { + // Follows the structure created by `SerializeCorpus` to deserialize the + // IRObject. + corpus_type out; + auto subs = obj.Subs(); + if (!subs) { + return std::nullopt; + } + + // We expect 2 fields: the type and the value. + if (subs->size() != 2) { + return std::nullopt; + } + + // Parse the type which is stored in the first field of the IRObject subs. + auto type_corpus = type_domain_.ParseCorpus((*subs)[0]); + if (!type_corpus.has_value() || + !type_domain_.ValidateCorpusValue(*type_corpus).ok()) { + return std::nullopt; + } + out.first = *type_corpus; + auto type_value = type_domain_.GetValue(out.first); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return std::nullopt; + } + + // Parse the value. + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object == nullptr) { + return std::nullopt; + } + std::optional inner_corpus; + if (object->is_struct()) { + auto domain = GetSubDomain(*type_enumval); + // The value is stored in the second field of the IRObject subs. + inner_corpus = domain.ParseCorpus((*subs)[1]); + } else { + auto domain = GetSubDomain(*type_enumval); + // The value is stored in the second field of the IRObject subs. + inner_corpus = domain.ParseCorpus((*subs)[1]); + } + + if (inner_corpus.has_value()) { + out.second = std::move(inner_corpus.value()); + } + return out; +} + +// Converts the corpus value to an IRObject. +IRObject FlatbuffersUnionDomainImpl::SerializeCorpus( + const corpus_type& value) const { + IRObject out; + auto type_value = type_domain_.GetValue(value.first); + if (type_value == 0 /* NONE */) { + return out; + } + + auto& pair = out.MutableSubs(); + // We have 2 fields: the type and the value. + pair.reserve(2); + + // Serialize the type. + pair.push_back(type_domain_.SerializeCorpus(value.first)); + + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return out; + } + + // Serialize the value. + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto domain = GetSubDomain(*type_enumval); + pair.push_back(domain.SerializeCorpus(value.second)); + } else { + auto domain = GetSubDomain(*type_enumval); + pair.push_back(domain.SerializeCorpus(value.second)); + } + return out; +} + +std::optional FlatbuffersUnionDomainImpl::BuildValue( + const corpus_type& value, flatbuffers::FlatBufferBuilder& builder) const { + // Get the object type. + auto type_value = type_domain_.GetValue(value.first); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr || type_value == 0 /* NONE */ || + !value.second.has_value()) { + return std::nullopt; + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object == nullptr) { + return std::nullopt; + } + if (object->is_struct()) { + FlatbuffersStructUntypedDomainImpl domain{schema_, object}; + return domain.BuildValue( + value.second.GetAs>(), + builder); + } else { + FlatbuffersTableUntypedDomainImpl domain{schema_, object}; + return domain.BuildTable( + value.second.GetAs>(), + builder); + } +} + +void FlatbuffersUnionDomainImpl::Printer::PrintCorpusValue( + const corpus_type& value, domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const { + auto type_value = self.type_domain_.GetValue(value.first); + auto type_enumval = self.union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return; + } + absl::Format(out, "<%s>(", type_enumval->name()->str()); + if (type_value == 0 /* NONE */) { + absl::Format(out, "NONE"); + } else { + const reflection::Object* object = + self.schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto domain = self.GetSubDomain(*type_enumval); + domain_implementor::PrintValue(domain, value.second, out, mode); + } else { + auto domain = self.GetSubDomain(*type_enumval); + domain_implementor::PrintValue(domain, value.second, out, mode); + } + } + absl::Format(out, ")"); +} + +FlatbuffersStructUntypedDomainImpl::corpus_type +FlatbuffersStructUntypedDomainImpl::Init(absl::BitGenRef prng) { + if (auto seed = this->MaybeGetRandomSeed(prng)) { + return *seed; + } + corpus_type val; + for (const auto* field : *struct_object_->fields()) { + VisitFlatbufferField(schema_, field, InitializeVisitor{*this, prng, val}); + } + return val; +} + +void FlatbuffersStructUntypedDomainImpl::Mutate( + corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, bool only_shrink) { + auto total_weight = CountNumberOfFields(val); + auto selected_weight = + absl::Uniform(absl::IntervalClosedClosed, prng, 0ul, total_weight - 1); + + MutateSelectedField(val, prng, metadata, only_shrink, selected_weight); +} + +uint64_t FlatbuffersStructUntypedDomainImpl::CountNumberOfFields( + corpus_type& val) { + uint64_t total_weight = 0; + for (const auto* field : *struct_object_->fields()) { + VisitFlatbufferField(schema_, field, + CountNumberOfFieldsVisitor{*this, total_weight, val}); + } + return total_weight; +} + +// Mutates the selected field. +// The selected field index is based on the flattened tree. +uint64_t FlatbuffersStructUntypedDomainImpl::MutateSelectedField( + corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, bool only_shrink, + uint64_t selected_field_index) { + uint64_t field_counter = 0; + for (const auto* field : *struct_object_->fields()) { + ++field_counter; + + if (field_counter == selected_field_index + 1) { + VisitFlatbufferField( + schema_, field, + MutateVisitor{*this, prng, metadata, only_shrink, val}); + return field_counter; + } + + auto base_type = field->type()->base_type(); + if (base_type == reflection::BaseType::Obj) { + auto sub_object = schema_->objects()->Get(field->type()->index()); + if (sub_object->is_struct()) { + field_counter += + GetSubDomain(field).MutateSelectedField( + val[field->id()], prng, metadata, only_shrink, + selected_field_index - field_counter); + } + } + + if (field_counter > selected_field_index) { + return field_counter; + } + } + return field_counter; +} + +absl::Status FlatbuffersStructUntypedDomainImpl::ValidateCorpusValue( + const corpus_type& corpus_value) const { + for (const auto& [id, field_corpus] : corpus_value) { + const reflection::Field* absl_nullable field = GetFieldById(id); + if (field == nullptr) continue; + absl::Status result; + VisitFlatbufferField(schema_, field, + ValidateVisitor{*this, field_corpus, result}); + if (!result.ok()) return result; + } + return absl::OkStatus(); +} + +std::optional +FlatbuffersStructUntypedDomainImpl::FromValue(const value_type& value) const { + if (value == nullptr) { + return std::nullopt; + } + corpus_type ret; + for (const auto* field : *struct_object_->fields()) { + VisitFlatbufferField(schema_, field, FromValueVisitor{*this, value, ret}); + } + return ret; +} + +std::optional +FlatbuffersStructUntypedDomainImpl::BuildValue( + const corpus_type& value, flatbuffers::FlatBufferBuilder& builder) const { + std::vector buf(struct_object_->bytesize()); + BuildValue(value, buf.data()); + builder.StartStruct(struct_object_->minalign()); + builder.PushBytes(buf.data(), buf.size()); + return builder.EndStruct(); +} + +void FlatbuffersStructUntypedDomainImpl::BuildValue(const corpus_type& value, + uint8_t* buf) const { + for (const auto* field : *struct_object_->fields()) { + VisitFlatbufferField(schema_, field, BuildValueVisitor{*this, value, buf}); + } +} + +std::optional +FlatbuffersStructUntypedDomainImpl::ParseCorpus(const IRObject& obj) const { + corpus_type out; + auto subs = obj.Subs(); + if (!subs) { + return std::nullopt; + } + out.reserve(subs->size()); + for (const auto& sub : *subs) { + auto pair_subs = sub.Subs(); + if (!pair_subs || pair_subs->size() != 2) { + return std::nullopt; + } + auto id = (*pair_subs)[0].GetScalar(); + if (!id.has_value()) { + return std::nullopt; + } + const reflection::Field* absl_nullable field = GetFieldById(id.value()); + if (field == nullptr) { + return std::nullopt; + } + std::optional inner_parsed; + VisitFlatbufferField(schema_, field, + ParseVisitor{*this, (*pair_subs)[1], inner_parsed}); + if (!inner_parsed) { + return std::nullopt; + } + out[id.value()] = *std::move(inner_parsed); + } + return out; +} + +// Converts the corpus value to an IRObject. +IRObject FlatbuffersStructUntypedDomainImpl::SerializeCorpus( + const corpus_type& value) const { + IRObject out; + auto& subs = out.MutableSubs(); + subs.reserve(value.size()); + for (const auto& [id, field_corpus] : value) { + const reflection::Field* absl_nullable field = GetFieldById(id); + if (field == nullptr) { + continue; + } + IRObject& pair = subs.emplace_back(); + auto& pair_subs = pair.MutableSubs(); + pair_subs.reserve(2); + pair_subs.emplace_back(field->id()); + VisitFlatbufferField( + schema_, field, + SerializeVisitor{*this, field_corpus, pair_subs.emplace_back()}); + } + return out; +} +} // namespace internal +} // namespace fuzztest diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.h b/fuzztest/internal/domains/flatbuffers_domain_impl.h new file mode 100644 index 000000000..ce711abb5 --- /dev/null +++ b/fuzztest/internal/domains/flatbuffers_domain_impl.h @@ -0,0 +1,2106 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLATBUFFERS_DOMAIN_IMPL_H_ +#define FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLATBUFFERS_DOMAIN_IMPL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/distributions.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "flatbuffers/base.h" +#include "flatbuffers/flatbuffer_builder.h" +#include "flatbuffers/reflection.h" +#include "flatbuffers/reflection_generated.h" +#include "flatbuffers/string.h" +#include "flatbuffers/struct.h" +#include "flatbuffers/table.h" +#include "flatbuffers/vector.h" +#include "flatbuffers/verifier.h" +#include "./fuzztest/domain_core.h" +#include "./fuzztest/internal/any.h" +#include "./fuzztest/internal/domains/arbitrary_impl.h" +#include "./fuzztest/internal/domains/domain_base.h" +#include "./fuzztest/internal/domains/domain_type_erasure.h" +#include "./fuzztest/internal/domains/element_of_impl.h" +#include "./fuzztest/internal/logging.h" +#include "./fuzztest/internal/meta.h" +#include "./fuzztest/internal/serialization.h" +#include "./fuzztest/internal/status.h" + +namespace fuzztest::internal { + +template && + !std::is_same_v>> + +// +// Flatbuffers enum detection. +// +struct FlatbuffersEnumTag { + using type = Underlying; +}; + +template +struct is_flatbuffers_enum_tag : std::false_type {}; + +template +struct is_flatbuffers_enum_tag> + : std::true_type {}; + +template +inline constexpr bool is_flatbuffers_enum_tag_v = + is_flatbuffers_enum_tag::value; + +// +// Flatbuffers vector detection. +// +template +struct FlatbuffersVectorTag { + using value_type = T; +}; + +template +struct is_flatbuffers_vector_tag : std::false_type {}; + +template +struct is_flatbuffers_vector_tag> : std::true_type {}; + +template +inline constexpr bool is_flatbuffers_vector_tag_v = + is_flatbuffers_vector_tag::value; + +struct FlatbuffersArrayTag; +struct FlatbuffersTableTag; +struct FlatbuffersStructTag; +struct FlatbuffersUnionTag; + +// Dynamic to static dispatch visitor pattern for flatbuffers vector elements. +template +auto VisitFlatbufferVectorElementField(const reflection::Schema* schema, + const reflection::Field* field, + Visitor visitor) { + auto field_index = field->type()->index(); + auto element_type = field->type()->element(); + switch (element_type) { + case reflection::BaseType::Bool: + visitor.template Visit>(field); + break; + case reflection::BaseType::Byte: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::Short: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::Int: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::Long: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::UByte: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::UShort: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::UInt: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::ULong: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::Float: + visitor.template Visit>(field); + break; + case reflection::BaseType::Double: + visitor.template Visit>(field); + break; + case reflection::BaseType::String: + visitor.template Visit>(field); + break; + case reflection::BaseType::Obj: { + auto sub_object = schema->objects()->Get(field_index); + if (sub_object->is_struct()) { + visitor.template Visit>( + field); + } else { + visitor.template Visit>( + field); + } + break; + } + case reflection::BaseType::Union: + visitor.template Visit>(field); + break; + case reflection::BaseType::UType: + // Noop: Union types are visited at the same time as their corresponding + // union values. + break; + default: // Vector of vectors and vector of arrays are not supported. + FUZZTEST_INTERNAL_CHECK(false, "Unsupported vector base type"); + } +} + +// Dynamic to static dispatch visitor pattern. +template +auto VisitFlatbufferField(const reflection::Schema* absl_nonnull schema, + const reflection::Field* absl_nonnull field, + Visitor visitor) { + auto field_index = field->type()->index(); + switch (field->type()->base_type()) { + case reflection::BaseType::Bool: + visitor.template Visit(field); + break; + case reflection::BaseType::Byte: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::Short: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::Int: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::Long: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::UByte: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::UShort: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::UInt: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::ULong: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::Float: + visitor.template Visit(field); + break; + case reflection::BaseType::Double: + visitor.template Visit(field); + break; + case reflection::BaseType::String: + visitor.template Visit(field); + break; + case reflection::BaseType::Vector: + case reflection::BaseType::Vector64: + VisitFlatbufferVectorElementField(schema, field, visitor); + break; + case reflection::BaseType::Array: + visitor.template Visit(field); + break; + case reflection::BaseType::Obj: { + auto sub_object = schema->objects()->Get(field->type()->index()); + if (sub_object->is_struct()) { + visitor.template Visit(field); + } else { + visitor.template Visit(field); + } + break; + } + case reflection::BaseType::Union: + visitor.template Visit(field); + break; + case reflection::BaseType::UType: + // Noop: Union types are visited at the same time as their corresponding + // union values. + break; + default: + FUZZTEST_INTERNAL_CHECK(false, "Unsupported base type"); + } +} + +// Flatbuffers enum domain implementation. +template +class FlatbuffersEnumDomainImpl + // Value type: Underlaying + // Corpus type: ElementOfImplCorpusType + // See ElementOfImpl for more details. + : public ElementOfImpl { + public: + using typename ElementOfImpl::DomainBase::corpus_type; + using typename ElementOfImpl::DomainBase::value_type; + + explicit FlatbuffersEnumDomainImpl(const reflection::Enum* enum_def) + : ElementOfImpl(GetEnumValues(enum_def)), + enum_def_(enum_def) {} + + auto GetPrinter() const { return Printer{*this}; } + + private: + const reflection::Enum* enum_def_; + + static std::vector GetEnumValues( + const reflection::Enum* enum_def) { + std::vector values; + values.reserve(enum_def->values()->size()); + for (const auto* value : *enum_def->values()) { + FUZZTEST_INTERNAL_CHECK( + value->value() >= std::numeric_limits::min() && + value->value() <= std::numeric_limits::max(), + "Enum value from reflection is out of range for the target type."); + values.push_back(static_cast(value->value())); + } + return values; + } + + struct Printer { + const FlatbuffersEnumDomainImpl& self; + void PrintCorpusValue(const corpus_type& value, + domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const { + if (mode == domain_implementor::PrintMode::kHumanReadable) { + auto user_value = self.GetValue(value); + absl::Format( + out, "%s", + self.enum_def_->values()->LookupByKey(user_value)->name()->str()); + } else { + absl::Format(out, "%d", value); + } + } + }; +}; + +// From flatbuffers documentation: +// Unions are encoded as the combination of two fields: an enum representing the +// union choice and the offset to the actual element. +// The type of the enum is always uint8_t as generated by the flatbuffers +// compiler. +using FlatbuffersUnionTypeDomainImpl = FlatbuffersEnumDomainImpl; + +// Domain implementation for flatbuffers struct types. +// The corpus type is a map of field ids to field values. +class FlatbuffersStructUntypedDomainImpl + : public domain_implementor::DomainBase< + // Derived, for CRTP needs. See DomainBase for more details. + FlatbuffersStructUntypedDomainImpl, + // ValueType - user facing type + const flatbuffers::Struct* absl_nonnull, + // CorpusType - internal representation of ValueType, + // a map of field ids to field values. + absl::flat_hash_map< + // a.k.a. uint16_t + decltype(static_cast(nullptr)->id()), + // Fancy wrapper around `void*`: knows about the exact type of + // stored value and can copy it using exact type copy constructor + // via `CopyFrom` method. + GenericDomainCorpusType>> { + public: + using typename FlatbuffersStructUntypedDomainImpl::DomainBase::corpus_type; + using typename FlatbuffersStructUntypedDomainImpl::DomainBase::value_type; + + FlatbuffersStructUntypedDomainImpl(const reflection::Schema* schema, + const reflection::Object* struct_object) + : schema_(schema), struct_object_(struct_object) {} + + FlatbuffersStructUntypedDomainImpl( + const FlatbuffersStructUntypedDomainImpl& other) + : schema_(other.schema_), struct_object_(other.struct_object_) { + absl::MutexLock l(&other.mutex_); + domains_ = other.domains_; + } + + FlatbuffersStructUntypedDomainImpl& operator=( + const FlatbuffersStructUntypedDomainImpl& other) { + schema_ = other.schema_; + struct_object_ = other.struct_object_; + absl::MutexLock l(&other.mutex_); + domains_ = other.domains_; + return *this; + } + + FlatbuffersStructUntypedDomainImpl(FlatbuffersStructUntypedDomainImpl&& other) + : schema_(other.schema_), struct_object_(other.struct_object_) { + absl::MutexLock l(&other.mutex_); + domains_ = std::move(other.domains_); + } + + FlatbuffersStructUntypedDomainImpl& operator=( + FlatbuffersStructUntypedDomainImpl&& other) { + schema_ = other.schema_; + struct_object_ = other.struct_object_; + absl::MutexLock l(&other.mutex_); + domains_ = std::move(other.domains_); + return *this; + } + + // Returns the domain for the given field. + // The domain is cached, and the same instance is returned for the same + // field. + template + auto& GetSubDomain(const reflection::Field* absl_nonnull field) const { + using DomainT = decltype(GetDomainForField(field)); + // Do the operation under a lock to prevent race conditions in `const` + // methods. + absl::MutexLock l(&mutex_); + auto it = domains_.find(field->id()); + if (it == domains_.end()) { + it = domains_ + .try_emplace(field->id(), std::in_place_type, + GetDomainForField(field)) + .first; + } + return it->second.template GetAs(); + } + + // Initializes the corpus value. + corpus_type Init(absl::BitGenRef prng); + + // Mutates the corpus value. + void Mutate(corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, + bool only_shrink); + + // Counts the number of fields that can be mutated. + // Returns the number of fields in the flattened tree for supported field + // types. + uint64_t CountNumberOfFields(corpus_type& val); + + // Mutates the selected field. + // The selected field index is based on the flattened tree. + uint64_t MutateSelectedField( + corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, bool only_shrink, + uint64_t selected_field_index); + + auto GetPrinter() const { return Printer{*this}; } + + absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const; + + value_type GetValue(const corpus_type& value) const { + // Untyped domain does not support GetValue since if it is a nested table + // it would need the top level table corpus value to be able to build it. + return nullptr; + } + + // Converts the struct pointer to a corpus value. + std::optional FromValue(const value_type& value) const; + + // Builds the struct in a builder. + std::optional BuildValue( + const corpus_type& value, flatbuffers::FlatBufferBuilder& builder) const; + + // Builds the struct in a buffer. + void BuildValue(const corpus_type& value, uint8_t* buf) const; + + // Converts the IRObject to a corpus value. + std::optional ParseCorpus(const IRObject& obj) const; + + // Converts the corpus value to an IRObject. + IRObject SerializeCorpus(const corpus_type& value) const; + + private: + const reflection::Schema* schema_; + const reflection::Object* struct_object_; + mutable absl::Mutex mutex_; + mutable absl::flat_hash_map< + // a.k.a. uint16_t + decltype(static_cast(nullptr)->id()), + // Fancy wrapper around `void*`: knows about the exact type of + // stored value and can copy it using exact type copy constructor + // via `CopyFrom` method. + GenericDomainCorpusType> + domains_ ABSL_GUARDED_BY(mutex_); + + const reflection::Field* absl_nullable GetFieldById( + typename corpus_type::key_type id) const { + const auto it = + absl::c_find_if(*struct_object_->fields(), + [id](const auto* field) { return field->id() == id; }); + return it != struct_object_->fields()->end() ? *it : nullptr; + } + + // Returns the domain for the given field. + template + auto GetDomainForField(const reflection::Field* absl_nonnull field) const { + if constexpr (std::is_same_v) { + // TODO(b/405938558): Implement this. + return Domain{Arbitrary()}; + } else if constexpr (is_flatbuffers_enum_tag_v) { + auto enum_object = schema_->enums()->Get(field->type()->index()); + return Domain{ + FlatbuffersEnumDomainImpl(enum_object)}; + } else if constexpr (std::is_same_v) { + const reflection::Object* sub_object = + schema_->objects()->Get(field->type()->index()); + FUZZTEST_INTERNAL_CHECK(sub_object->is_struct(), + "Field must be a struct type."); + return Domain( + FlatbuffersStructUntypedDomainImpl{schema_, sub_object}); + } else if constexpr (std::is_integral_v || std::is_floating_point_v) { + return Domain{Arbitrary()}; + } else { + FUZZTEST_INTERNAL_CHECK(false, "Unsupported type"); + // Return a no-op domain for the static checks to pass. + return Domain{Arbitrary()}; + } + } + + struct InitializeVisitor { + FlatbuffersStructUntypedDomainImpl& self; + absl::BitGenRef prng; + corpus_type& val; + + template + void Visit(const reflection::Field* absl_nonnull field) { + auto& domain = self.GetSubDomain(field); + val[field->id()] = domain.Init(prng); + } + }; + + struct CountNumberOfFieldsVisitor { + const FlatbuffersStructUntypedDomainImpl& self; + uint64_t& total_weight; + corpus_type& corpus; + + template + void Visit(const reflection::Field* absl_nonnull field) const { + if constexpr (std::is_same_v) { + // TODO(b/405938558): Implement this. + return; + } else if constexpr (std::is_same_v) { + auto sub_object = self.schema_->objects()->Get(field->type()->index()); + FUZZTEST_INTERNAL_CHECK(sub_object->is_struct(), + "Field must be a struct type."); + auto sub_domain = self.GetSubDomain(field); + total_weight += sub_domain.CountNumberOfFields(corpus.at(field->id())); + } else { + total_weight++; + } + } + }; + + struct MutateVisitor { + FlatbuffersStructUntypedDomainImpl& self; + absl::BitGenRef prng; + const domain_implementor::MutationMetadata& metadata; + bool only_shrink; + corpus_type& val; + + template + void Visit(const reflection::Field* absl_nonnull field) { + auto& domain = self.GetSubDomain(field); + if (auto it = val.find(field->id()); it != val.end()) { + domain.Mutate(it->second, prng, metadata, only_shrink); + } else if (!only_shrink) { + val[field->id()] = domain.Init(prng); + } + } + }; + + struct ParseVisitor { + const FlatbuffersStructUntypedDomainImpl& self; + const IRObject& obj; + std::optional& out; + + template + void Visit(const reflection::Field* absl_nonnull field) { + out = self.GetSubDomain(field).ParseCorpus(obj); + } + }; + + struct SerializeVisitor { + const FlatbuffersStructUntypedDomainImpl& self; + const GenericDomainCorpusType& corpus_value; + IRObject& out; + + template + void Visit(const reflection::Field* absl_nonnull field) { + out = self.GetSubDomain(field).SerializeCorpus(corpus_value); + } + }; + + struct Printer { + const FlatbuffersStructUntypedDomainImpl& self; + + void PrintCorpusValue(const corpus_type& value, + domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const { + absl::Format(out, "{"); + bool first = true; + for (const auto& [id, field_corpus] : value) { + if (!first) { + absl::Format(out, ", "); + } + const reflection::Field* absl_nullable field = self.GetFieldById(id); + if (field == nullptr) { + absl::Format(out, "", id); + } else { + VisitFlatbufferField(self.schema_, field, + PrinterVisitor{self, field_corpus, out, mode}); + } + first = false; + } + absl::Format(out, "}"); + } + }; + + struct PrinterVisitor { + const FlatbuffersStructUntypedDomainImpl& self; + const GenericDomainCorpusType& val; + domain_implementor::RawSink out; + domain_implementor::PrintMode mode; + + template + void Visit(const reflection::Field* absl_nonnull field) const { + auto& domain = self.GetSubDomain(field); + absl::Format(out, "%s: ", field->name()->str()); + domain_implementor::PrintValue(domain, val, out, mode); + } + }; + + struct ValidateVisitor { + const FlatbuffersStructUntypedDomainImpl& self; + const GenericDomainCorpusType& corpus_value; + absl::Status& out; + + template + void Visit(const reflection::Field* absl_nonnull field) { + auto& domain = self.GetSubDomain(field); + out = domain.ValidateCorpusValue(corpus_value); + if (!out.ok()) { + out = Prefix(out, absl::StrCat("Invalid value for field ", + field->name()->str())); + } + } + }; + + struct FromValueVisitor { + const FlatbuffersStructUntypedDomainImpl& self; + value_type value; + corpus_type& out; + + template + void Visit(const reflection::Field* absl_nonnull field) const { + [[maybe_unused]] + reflection::BaseType base_type = field->type()->base_type(); + auto& domain = self.GetSubDomain(field); + std::optional>> inner_corpus; + + if constexpr (is_flatbuffers_enum_tag_v) { + FUZZTEST_INTERNAL_CHECK(base_type >= reflection::BaseType::Byte && + base_type <= reflection::BaseType::ULong && + field->type()->index() >= 0, + "Field must be an enum type."); + auto inner_value = value->GetField(field->offset()); + inner_corpus = domain.FromValue(inner_value); + } else if constexpr (std::is_integral_v || + std::is_floating_point_v) { + FUZZTEST_INTERNAL_CHECK(flatbuffers::IsScalar(base_type), + "Field must be an scalar type."); + auto inner_value = value->GetField(field->offset()); + inner_corpus = domain.FromValue(inner_value); + } else if constexpr (std::is_same_v) { + auto sub_object = self.schema_->objects()->Get(field->type()->index()); + FUZZTEST_INTERNAL_CHECK( + base_type == reflection::BaseType::Obj && sub_object->is_struct(), + "Field must be a struct type."); + auto inner_value = + value->GetStruct(field->offset()); + inner_corpus = domain.FromValue(inner_value); + } + + if (inner_corpus.has_value()) { + out[field->id()] = std::move(*inner_corpus); + } + }; + }; + + struct BuildValueVisitor { + const FlatbuffersStructUntypedDomainImpl& self; + const corpus_type& corpus_value; + uint8_t* struct_ptr; + + template + void Visit(const reflection::Field* absl_nonnull field) const { + [[maybe_unused]] + reflection::BaseType base_type = field->type()->base_type(); + auto& domain = self.GetSubDomain(field); + if constexpr (is_flatbuffers_enum_tag_v || std::is_integral_v || + std::is_floating_point_v) { + FUZZTEST_INTERNAL_CHECK(flatbuffers::IsScalar(base_type), + "Field must be an scalar type."); + auto inner_value = domain.GetValue(corpus_value.at(field->id())); + flatbuffers::WriteScalar(struct_ptr + field->offset(), inner_value); + } else if constexpr (std::is_same_v) { + auto sub_object = self.schema_->objects()->Get(field->type()->index()); + FUZZTEST_INTERNAL_CHECK( + base_type == reflection::BaseType::Obj && sub_object->is_struct(), + "Field must be a struct type."); + auto inner_corpus_value = + corpus_value.at(field->id()) + .GetAs(); + FlatbuffersStructUntypedDomainImpl sub_domain(self.schema_, sub_object); + for (const auto* nested_field : *sub_object->fields()) { + VisitFlatbufferField(sub_domain.schema_, nested_field, + BuildValueVisitor{sub_domain, inner_corpus_value, + struct_ptr + field->offset()}); + } + } else if constexpr (std::is_same_v) { + // TODO (b/405938558): Implement array support. + } + } + }; +}; + +class FlatbuffersTableUntypedDomainImpl; + +// Flatbuffers union domain implementation. +class FlatbuffersUnionDomainImpl + : public domain_implementor::DomainBase< + // Derived, for CRTP needs. See DomainBase for more details. + FlatbuffersUnionDomainImpl, + // ValueType - user facing type + std::pair, + // CorpusType - internal representation of ValueType + std::pair< + // `Union choice` type representation + typename FlatbuffersUnionTypeDomainImpl::corpus_type, + // Fancy wrapper around `void*`: knows about the exact type + // of stored value and can copy it using exact type copy + // constructor via `AnyBase::CopyFrom` method. + GenericDomainCorpusType>> { + public: + using typename FlatbuffersUnionDomainImpl::DomainBase::corpus_type; + using typename FlatbuffersUnionDomainImpl::DomainBase::value_type; + + FlatbuffersUnionDomainImpl(const reflection::Schema* schema, + const reflection::Enum* union_def) + : schema_(schema), union_def_(union_def), type_domain_(union_def) {} + + FlatbuffersUnionDomainImpl(const FlatbuffersUnionDomainImpl& other) + : schema_(other.schema_), + union_def_(other.union_def_), + type_domain_(other.type_domain_) { + absl::MutexLock l(&other.mutex_); + domains_ = other.domains_; + } + + FlatbuffersUnionDomainImpl(FlatbuffersUnionDomainImpl&& other) + : schema_(other.schema_), + union_def_(other.union_def_), + type_domain_(std::move(other.type_domain_)) { + absl::MutexLock l(&other.mutex_); + domains_ = std::move(other.domains_); + } + + FlatbuffersUnionDomainImpl& operator=( + const FlatbuffersUnionDomainImpl& other) { + schema_ = other.schema_; + union_def_ = other.union_def_; + type_domain_ = other.type_domain_; + absl::MutexLock l(&other.mutex_); + domains_ = other.domains_; + return *this; + } + + FlatbuffersUnionDomainImpl& operator=(FlatbuffersUnionDomainImpl&& other) { + schema_ = other.schema_; + union_def_ = other.union_def_; + type_domain_ = std::move(other.type_domain_); + absl::MutexLock l(&other.mutex_); + domains_ = std::move(other.domains_); + return *this; + } + + // Initializes the corpus value. + corpus_type Init(absl::BitGenRef prng); + + // Mutates the corpus value. + void Mutate(corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, + bool only_shrink); + + uint64_t CountNumberOfFields(corpus_type& val); + + auto GetPrinter() const { return Printer{*this}; } + + absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const; + + // UNSUPPORTED: Flatbuffers unions user values are not supported. + value_type GetValue(const corpus_type& value) const { + FUZZTEST_INTERNAL_CHECK(false, "GetValue is not supported for unions."); + } + + // Gets the type of the union field. + auto GetType(const corpus_type& value) const { + return type_domain_.GetValue(value.first); + } + + // Creates flatbuffer from the corpus value. + std::optional BuildValue( + const corpus_type& value, flatbuffers::FlatBufferBuilder& builder) const; + + // Converts the value to a corpus value. + std::optional FromValue(const value_type& value) const; + + // Converts the IRObject to a corpus value. + std::optional ParseCorpus(const IRObject& obj) const; + + // Converts the corpus value to an IRObject. + IRObject SerializeCorpus(const corpus_type& value) const; + + // Returns the domain for the given enum value. + template + auto& GetSubDomain(const reflection::EnumVal& enum_value) const { + using DomainT = decltype(GetDomainForType(enum_value)); + absl::MutexLock l(&mutex_); + auto it = domains_.find(enum_value.value()); + if (it == domains_.end()) { + it = domains_ + .try_emplace(enum_value.value(), std::in_place_type, + GetDomainForType(enum_value)) + .first; + } + return it->second.GetAs(); + } + + private: + const reflection::Schema* schema_; + const reflection::Enum* union_def_; + FlatbuffersEnumDomainImpl type_domain_; + mutable absl::Mutex mutex_; + mutable absl::flat_hash_map + domains_ ABSL_GUARDED_BY(mutex_); + + // Creates new or returns existing domain for the given enum value. + template + auto GetDomainForType(const reflection::EnumVal& enum_value) const; + + struct Printer { + const FlatbuffersUnionDomainImpl& self; + + void PrintCorpusValue(const corpus_type& value, + domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const; + }; +}; + +// Domain implementation for flatbuffers untyped tables. +// The corpus type is a map of field ids to field values. +class FlatbuffersTableUntypedDomainImpl + : public domain_implementor::DomainBase< + // Derived, for CRTP needs. See DomainBase for more details. + FlatbuffersTableUntypedDomainImpl, + // ValueType - user facing type + const flatbuffers::Table* absl_nonnull, + // CorpusType - internal representation of ValueType, + // a map of field ids to field values. + absl::flat_hash_map< + // a.k.a. uint16_t + decltype(static_cast(nullptr)->id()), + // Fancy wrapper around `void*`: knows about the exact type of + // stored value and can copy it using exact type copy constructor + // via `CopyFrom` method. + GenericDomainCorpusType>> { + public: + using typename FlatbuffersTableUntypedDomainImpl::DomainBase::corpus_type; + using typename FlatbuffersTableUntypedDomainImpl::DomainBase::value_type; + + explicit FlatbuffersTableUntypedDomainImpl( + const reflection::Schema* schema, const reflection::Object* table_object) + : schema_(schema), table_object_(table_object) {} + + FlatbuffersTableUntypedDomainImpl( + const FlatbuffersTableUntypedDomainImpl& other) + : schema_(other.schema_), table_object_(other.table_object_) { + absl::MutexLock l_other(&other.mutex_); + absl::MutexLock l_this(&mutex_); + domains_ = other.domains_; + } + + FlatbuffersTableUntypedDomainImpl& operator=( + const FlatbuffersTableUntypedDomainImpl& other) { + schema_ = other.schema_; + table_object_ = other.table_object_; + absl::MutexLock l_other(&other.mutex_); + absl::MutexLock l_this(&mutex_); + domains_ = other.domains_; + return *this; + } + + FlatbuffersTableUntypedDomainImpl(FlatbuffersTableUntypedDomainImpl&& other) + : schema_(other.schema_), table_object_(other.table_object_) { + absl::MutexLock l_other(&other.mutex_); + absl::MutexLock l_this(&mutex_); + domains_ = std::move(other.domains_); + } + + FlatbuffersTableUntypedDomainImpl& operator=( + FlatbuffersTableUntypedDomainImpl&& other) { + schema_ = other.schema_; + table_object_ = other.table_object_; + absl::MutexLock l_other(&other.mutex_); + absl::MutexLock l_this(&mutex_); + domains_ = std::move(other.domains_); + return *this; + } + + // Initializes the corpus value. + corpus_type Init(absl::BitGenRef prng) { + if (auto seed = this->MaybeGetRandomSeed(prng)) { + return *seed; + } + corpus_type val; + for (const auto* field : *table_object_->fields()) { + VisitFlatbufferField(schema_, field, InitializeVisitor{*this, prng, val}); + } + return val; + } + + // Mutates the corpus value. + void Mutate(corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, + bool only_shrink) { + auto total_weight = CountNumberOfFields(val); + auto selected_weight = + absl::Uniform(absl::IntervalClosedClosed, prng, 0ul, total_weight - 1); + + MutateSelectedField(val, prng, metadata, only_shrink, selected_weight); + } + + // Returns the domain for the given vector field. + template + auto GetDomainForVectorField(const reflection::Field* field) const { + if constexpr (is_flatbuffers_enum_tag_v) { + auto enum_object = schema_->enums()->Get(field->type()->index()); + auto inner = OptionalOf( + VectorOf( + FlatbuffersEnumDomainImpl(enum_object)) + .WithMaxSize(std::numeric_limits::max())); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>{inner}; + } else if constexpr (std::is_same_v) { + auto table_object = schema_->objects()->Get(field->type()->index()); + auto inner = OptionalOf( + VectorOf(FlatbuffersTableUntypedDomainImpl{schema_, table_object}) + .WithMaxSize(std::numeric_limits::max())); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>>{ + inner}; + } else if constexpr (std::is_same_v) { + auto struct_object = schema_->objects()->Get(field->type()->index()); + auto inner = OptionalOf( + VectorOf(FlatbuffersStructUntypedDomainImpl{schema_, struct_object}) + .WithMaxSize(std::numeric_limits::max())); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>>{ + inner}; + } else if constexpr (std::is_same_v) { + auto union_type = schema_->enums()->Get(field->type()->index()); + auto inner = OptionalOf( + VectorOf(FlatbuffersUnionDomainImpl{schema_, union_type}) + .WithMaxSize(std::numeric_limits::max())); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>{inner}; + } else { + auto inner = OptionalOf( + VectorOf(ArbitraryImpl()) + .WithMaxSize(std::numeric_limits::max())); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>>{inner}; + } + } + + // Returns the domain for the given field. + template + auto GetDomainForField(const reflection::Field* field) const { + if constexpr (std::is_same_v) { + FUZZTEST_INTERNAL_CHECK( + false, "Arrays in tables are not supported in flatbuffers."); + // Return a placeholder domain to make the compiler happy. + return Domain>{Arbitrary>()}; + } else if constexpr (is_flatbuffers_enum_tag_v) { + auto enum_object = schema_->enums()->Get(field->type()->index()); + auto domain = + OptionalOf(FlatbuffersEnumDomainImpl(enum_object)); + if (!field->optional()) { + domain.SetWithoutNull(); + } + return Domain>{domain}; + } else if constexpr (std::is_same_v) { + auto table_object = schema_->objects()->Get(field->type()->index()); + auto inner = + OptionalOf(FlatbuffersTableUntypedDomainImpl{schema_, table_object}); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>{inner}; + } else if constexpr (std::is_same_v) { + auto struct_object = schema_->objects()->Get(field->type()->index()); + auto inner = OptionalOf( + FlatbuffersStructUntypedDomainImpl{schema_, struct_object}); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>(inner); + } else if constexpr (std::is_same_v) { + auto union_type = schema_->enums()->Get(field->type()->index()); + auto inner = OptionalOf(FlatbuffersUnionDomainImpl{schema_, union_type}); + return Domain>{inner}; + } else if constexpr (is_flatbuffers_vector_tag_v) { + return GetDomainForVectorField(field); + } else { + auto inner = OptionalOf(ArbitraryImpl()); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>{inner}; + } + } + + // Returns the domain for the given field. + // The domain is cached, and the same instance is returned for the same + // field. + template + auto& GetSubDomain(const reflection::Field* absl_nonnull field) const { + using DomainT = decltype(GetDomainForField(field)); + // Do the operation under a lock to prevent race conditions in `const` + // methods. + absl::MutexLock l(&mutex_); + auto it = domains_.find(field->id()); + if (it == domains_.end()) { + it = domains_ + .try_emplace(field->id(), std::in_place_type, + GetDomainForField(field)) + .first; + } + return it->second.template GetAs(); + } + + // Counts the number of fields that can be mutated. + // Returns the number of fields in the flattened tree for supported field + // types. + uint64_t CountNumberOfFields(corpus_type& val) { + uint64_t total_weight = 0; + for (const auto* field : *table_object_->fields()) { + VisitFlatbufferField( + schema_, field, CountNumberOfFieldsVisitor{*this, total_weight, val}); + } + return total_weight; + } + + // Mutates the selected field. + // The selected field index is based on the flattened tree. + uint64_t MutateSelectedField( + corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, bool only_shrink, + uint64_t selected_field_index) { + uint64_t field_counter = 0; + for (const auto* field : *table_object_->fields()) { + ++field_counter; + + if (field_counter == selected_field_index + 1) { + VisitFlatbufferField( + schema_, field, + MutateVisitor{*this, prng, metadata, only_shrink, val}); + return field_counter; + } + + auto base_type = field->type()->base_type(); + if (base_type == reflection::BaseType::Obj) { + auto sub_object = schema_->objects()->Get(field->type()->index()); + if (sub_object->is_struct()) { + field_counter += + GetSubDomain(field).MutateSelectedField( + val[field->id()], prng, metadata, only_shrink, + selected_field_index - field_counter); + } else { + field_counter += + GetSubDomain(field).MutateSelectedField( + val[field->id()], prng, metadata, only_shrink, + selected_field_index - field_counter); + } + } + + if (base_type == reflection::BaseType::Vector || + base_type == reflection::BaseType::Vector64) { + auto elem_type = field->type()->element(); + if (elem_type == reflection::BaseType::Obj) { + auto sub_object = schema_->objects()->Get(field->type()->index()); + if (!sub_object->is_struct()) { + field_counter += + GetSubDomain>(field) + .MutateSelectedField(val[field->id()], prng, metadata, + only_shrink, + selected_field_index - field_counter); + } else { + field_counter += + GetSubDomain>(field) + .MutateSelectedField(val[field->id()], prng, metadata, + only_shrink, + selected_field_index - field_counter); + } + } else if (elem_type == reflection::BaseType::Union) { + field_counter += + GetSubDomain>(field) + .MutateSelectedField(val[field->id()], prng, metadata, + only_shrink, + selected_field_index - field_counter); + } + } + + if (base_type == reflection::BaseType::Union) { + field_counter += + GetSubDomain(field).MutateSelectedField( + val[field->id()], prng, metadata, only_shrink, + selected_field_index - field_counter); + } + + if (field_counter > selected_field_index) { + return field_counter; + } + } + return field_counter; + } + + auto GetPrinter() const { return Printer{*this}; } + + absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const { + for (const auto& [id, field_corpus] : corpus_value) { + const reflection::Field* absl_nullable field = GetFieldById(id); + if (field == nullptr) continue; + absl::Status result; + VisitFlatbufferField(schema_, field, + ValidateVisitor{*this, field_corpus, result}); + if (!result.ok()) return result; + } + return absl::OkStatus(); + } + + value_type GetValue(const corpus_type& value) const { + // Untyped domain does not support GetValue since if it is a nested table + // it would need the top level table corpus value to be able to build it. + return nullptr; + } + + // Converts the table pointer to a corpus value. + std::optional FromValue(const value_type& value) const { + if (value == nullptr) { + return std::nullopt; + } + corpus_type ret; + for (const auto* field : *table_object_->fields()) { + VisitFlatbufferField(schema_, field, FromValueVisitor{*this, value, ret}); + } + return ret; + } + + // Converts the IRObject to a corpus value. + std::optional ParseCorpus(const IRObject& obj) const { + corpus_type out; + auto subs = obj.Subs(); + if (!subs) { + return std::nullopt; + } + // Follows the structure created by `SerializeCorpus` to deserialize the + // IRObject. + + // subs->size() represents the number of fields in the table. + out.reserve(subs->size()); + for (const auto& sub : *subs) { + auto pair_subs = sub.Subs(); + // Each field is represented by a pair of field id and the serialized + // corpus value. + if (!pair_subs || pair_subs->size() != 2) { + return std::nullopt; + } + + // Deserialize the field id. + auto id = (*pair_subs)[0].GetScalar(); + if (!id.has_value()) { + return std::nullopt; + } + + // Get information about the field from reflection. + const reflection::Field* absl_nullable field = GetFieldById(id.value()); + if (field == nullptr) { + return std::nullopt; + } + + if (field->type()->base_type() == reflection::BaseType::UType) { + // Union types are handled as part of the union field. + continue; + } + + // Deserialize the field corpus value. + std::optional inner_parsed; + VisitFlatbufferField(schema_, field, + ParseVisitor{*this, (*pair_subs)[1], inner_parsed}); + if (!inner_parsed) { + return std::nullopt; + } + out[id.value()] = *std::move(inner_parsed); + } + return out; + } + + // Converts the corpus value to an IRObject. + IRObject SerializeCorpus(const corpus_type& value) const { + IRObject out; + auto& subs = out.MutableSubs(); + subs.reserve(value.size()); + + // Each field is represented by a pair of field id and the serialized + // corpus value. + for (const auto& [id, field_corpus] : value) { + // Get information about the field from reflection. + const reflection::Field* absl_nullable field = GetFieldById(id); + if (field == nullptr) { + continue; + } + IRObject& pair = subs.emplace_back(); + auto& pair_subs = pair.MutableSubs(); + pair_subs.reserve(2); + + // Serialize the field id. + pair_subs.emplace_back(field->id()); + + // Serialize the field corpus value. + VisitFlatbufferField( + schema_, field, + SerializeVisitor{*this, field_corpus, pair_subs.emplace_back()}); + } + return out; + } + + uint32_t BuildTable(const corpus_type& value, + flatbuffers::FlatBufferBuilder& builder) const { + // Add all the fields to the builder. + + // Offsets is the map of field id to its offset in the table. + absl::flat_hash_map + offsets; + + // Some fields are stored inline in the flatbuffer table itself (a.k.a + // "inline fields") and some are referenced by their offsets (a.k.a. "out of + // line fields"). + // + // "Out of line fields" shall be added to the builder first, so that we can + // refer to them in the final table. + for (const auto& [id, field_corpus] : value) { + const reflection::Field* absl_nullable field = GetFieldById(id); + if (field == nullptr) { + continue; + } + // Take care of strings, and tables. + VisitFlatbufferField( + schema_, field, + TableFieldBuilderVisitor{*this, builder, offsets, field_corpus}); + } + + // Now it is time to build the final table. + uint32_t table_start = builder.StartTable(); + for (const auto& [id, field_corpus] : value) { + const reflection::Field* absl_nullable field = GetFieldById(id); + if (field == nullptr) { + continue; + } + + // Visit all fields. + // + // Inline fields will be stored in the table itself, out of line fields + // will be referenced by their offsets. + VisitFlatbufferField( + schema_, field, + TableBuilderVisitor{*this, builder, offsets, field_corpus}); + } + return builder.EndTable(table_start); + } + + private: + const reflection::Schema* absl_nonnull schema_; + const reflection::Object* absl_nonnull table_object_; + mutable absl::Mutex mutex_; + mutable absl::flat_hash_map + domains_ ABSL_GUARDED_BY(mutex_); + + const reflection::Field* absl_nullable GetFieldById( + typename corpus_type::key_type id) const { + const auto it = + absl::c_find_if(*table_object_->fields(), + [id](const auto* field) { return field->id() == id; }); + return it != table_object_->fields()->end() ? *it : nullptr; + } + + struct SerializeVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + const GenericDomainCorpusType& corpus_value; + IRObject& out; + + template + void Visit(const reflection::Field* absl_nonnull field) { + out = self.GetSubDomain(field).SerializeCorpus(corpus_value); + } + }; + + struct FromValueVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + value_type value; + corpus_type& out; + + template + void Visit(const reflection::Field* absl_nonnull field) const { + [[maybe_unused]] + reflection::BaseType base_type = field->type()->base_type(); + auto& domain = self.GetSubDomain(field); + value_type_t> inner_value; + + if constexpr (is_flatbuffers_enum_tag_v) { + FUZZTEST_INTERNAL_CHECK(base_type >= reflection::BaseType::Byte && + base_type <= reflection::BaseType::ULong, + "Field must be an enum type."); + if (field->optional() && !value->CheckField(field->offset())) { + inner_value = std::nullopt; + } else { + inner_value = std::make_optional(value->GetField( + field->offset(), field->default_integer())); + } + } else if constexpr (std::is_integral_v) { + FUZZTEST_INTERNAL_CHECK(base_type >= reflection::BaseType::Bool && + base_type <= reflection::BaseType::ULong, + "Field must be an integer type."); + if (field->optional() && !value->CheckField(field->offset())) { + inner_value = std::nullopt; + } else { + inner_value = std::make_optional( + value->GetField(field->offset(), field->default_integer())); + } + } else if constexpr (std::is_floating_point_v) { + FUZZTEST_INTERNAL_CHECK(base_type >= reflection::BaseType::Float && + base_type <= reflection::BaseType::Double, + "Field must be a floating point type."); + if (field->optional() && !value->CheckField(field->offset())) { + inner_value = std::nullopt; + } else { + inner_value = std::make_optional( + value->GetField(field->offset(), field->default_real())); + } + } else if constexpr (std::is_same_v) { + FUZZTEST_INTERNAL_CHECK(base_type == reflection::BaseType::String, + "Field must be a string type."); + if (!value->CheckField(field->offset())) { + inner_value = std::nullopt; + } else { + inner_value = std::make_optional( + value->GetPointer(field->offset())->str()); + } + } else if constexpr (std::is_same_v) { + auto sub_object = self.schema_->objects()->Get(field->type()->index()); + FUZZTEST_INTERNAL_CHECK( + base_type == reflection::BaseType::Obj && !sub_object->is_struct(), + "Field must be a table type."); + inner_value = + value->GetPointer(field->offset()); + } else if constexpr (std::is_same_v) { + auto sub_object = self.schema_->objects()->Get(field->type()->index()); + FUZZTEST_INTERNAL_CHECK( + base_type == reflection::BaseType::Obj && sub_object->is_struct(), + "Field must be a struct type."); + inner_value = + value->GetStruct(field->offset()); + } else if constexpr (is_flatbuffers_vector_tag_v) { + FUZZTEST_INTERNAL_CHECK(base_type == reflection::BaseType::Vector || + base_type == reflection::BaseType::Vector64, + "Field must be a vector type."); + if (!value->CheckField(field->offset())) { + inner_value = std::nullopt; + } else { + VisitVector>( + field, inner_value); + } + } else if constexpr (std::is_same_v) { + constexpr char kUnionTypeFieldSuffix[] = "_type"; + auto enumdef = self.schema_->enums()->Get(field->type()->index()); + auto type_field = self.table_object_->fields()->LookupByKey( + (field->name()->str() + kUnionTypeFieldSuffix).c_str()); + if (type_field == nullptr) { + return; + } + auto union_type = value->GetField(type_field->offset(), 0); + if (union_type > 0 /* NONE */) { + auto enumval = enumdef->values()->LookupByKey(union_type); + auto union_object = + self.schema_->objects()->Get(enumval->union_type()->index()); + if (union_object->is_struct()) { + auto union_value = value->template GetPointer( + field->offset()); + inner_value = std::make_pair(union_type, union_value); + } else { + auto union_value = + value->GetPointer(field->offset()); + inner_value = std::make_pair(union_type, union_value); + } + } + } + + auto inner = domain.FromValue(inner_value); + if (inner) { + out[field->id()] = std::move(inner.value()); + } + }; + + template + void VisitVector(const reflection::Field* field, + value_type_t& inner_value) const { + if constexpr (std::is_integral_v || + std::is_floating_point_v) { + auto vec = value->GetPointer*>( + field->offset()); + inner_value = std::make_optional(std::vector()); + inner_value->reserve(vec->size()); + for (auto i = 0; i < vec->size(); ++i) { + inner_value->push_back(vec->Get(i)); + } + } else if constexpr (is_flatbuffers_enum_tag_v) { + using Underlaying = typename ElementType::type; + auto vec = value->GetPointer*>( + field->offset()); + inner_value = std::make_optional(std::vector()); + inner_value->reserve(vec->size()); + for (auto i = 0; i < vec->size(); ++i) { + inner_value->push_back(vec->Get(i)); + } + } else if constexpr (std::is_same_v) { + auto vec = value->GetPointer< + flatbuffers::Vector>*>( + field->offset()); + inner_value = std::make_optional(std::vector()); + inner_value->reserve(vec->size()); + for (auto i = 0; i < vec->size(); ++i) { + inner_value->push_back(vec->Get(i)->str()); + } + } else if constexpr (std::is_same_v) { + auto vec = value->GetPointer< + flatbuffers::Vector>*>( + field->offset()); + inner_value = + std::make_optional(std::vector()); + inner_value->reserve(vec->size()); + for (auto i = 0; i < vec->size(); ++i) { + inner_value->push_back(vec->Get(i)); + } + } else if constexpr (std::is_same_v) { + const reflection::Object* struct_object = + self.schema_->objects()->Get(field->type()->index()); + auto vec = + value->template GetPointer*>( + field->offset()); + if (vec == nullptr) { + return; + } + inner_value = + std::make_optional(std::vector()); + inner_value->reserve(vec->size()); + for (std::remove_pointer_t::size_type i = 0; + i < vec->size(); ++i) { + const uint8_t* struct_data_ptr = + vec->Data() + i * struct_object->bytesize(); + inner_value->push_back( + reinterpret_cast(struct_data_ptr)); + } + } else if constexpr (std::is_same_v) { + constexpr char kUnionTypeFieldSuffix[] = "_type"; + auto type_field = self.table_object_->fields()->LookupByKey( + (field->name()->str() + kUnionTypeFieldSuffix).c_str()); + if (type_field == nullptr) { + return; + } + auto type_vec = value->GetPointer*>( + type_field->offset()); + auto value_vec = + value->GetPointer>*>( + field->offset()); + inner_value = std::make_optional( + typename std::decay_t::value_type{}); + inner_value->reserve(value_vec->size()); + for (auto i = 0; i < value_vec->size(); ++i) { + inner_value->push_back( + std::make_pair(type_vec->Get(i), value_vec->Get(i))); + } + } + } + }; + + // Create out-of-line table fields, see `BuildTable` for details. + struct TableFieldBuilderVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + flatbuffers::FlatBufferBuilder& builder; + absl::flat_hash_map& + offsets; + const typename corpus_type::value_type::second_type& corpus_value; + + template + void Visit(const reflection::Field* absl_nonnull field) { + if constexpr (std::is_same_v) { + auto& domain = self.GetSubDomain(field); + auto user_value = domain.GetValue(corpus_value); + if (user_value.has_value()) { + auto offset = + builder.CreateString(user_value->data(), user_value->size()).o; + offsets.insert({field->id(), offset}); + } + } else if constexpr (std::is_same_v) { + FlatbuffersTableUntypedDomainImpl inner_domain( + self.schema_, self.schema_->objects()->Get(field->type()->index())); + auto opt_corpus = + corpus_value + .GetAs>(); + if (std::holds_alternative(opt_corpus)) { + auto inner_corpus = std::get(opt_corpus) + .GetAs(); + auto offset = inner_domain.BuildTable(inner_corpus, builder); + offsets.insert({field->id(), offset}); + } + } else if constexpr (is_flatbuffers_vector_tag_v) { + VisitVector(field, self.GetSubDomain(field)); + } else if constexpr (std::is_same_v) { + const reflection::Enum* union_type = + self.schema_->enums()->Get(field->type()->index()); + FlatbuffersUnionDomainImpl inner_domain{self.schema_, union_type}; + auto opt_corpus = + corpus_value + .GetAs>(); + if (std::holds_alternative(opt_corpus)) { + auto inner_corpus = + std::get(opt_corpus) + .GetAs>(); + auto offset = inner_domain.BuildValue(inner_corpus, builder); + if (offset.has_value()) { + offsets.insert({field->id(), *offset}); + } + } + } + } + + private: + template || + std::is_floating_point_v || + is_flatbuffers_enum_tag_v, + int> = 0> + void VisitVector(const reflection::Field* field, const Domain& domain) { + auto value = domain.GetValue(corpus_value); + if (value && (!value->empty() || !field->optional())) { + offsets.insert({field->id(), builder.CreateVector(*value).o}); + } else if (!value && !field->optional()) { + // Handle case where value is std::nullopt but field is not optional + // Create an empty vector of the appropriate type. + if constexpr (is_flatbuffers_enum_tag_v) { + offsets.insert( + {field->id(), + builder.CreateVector(std::vector()).o}); + } else { + offsets.insert( + {field->id(), builder.CreateVector(std::vector()).o}); + } + } + } + + template < + typename Element, typename Domain, + std::enable_if_t, int> = 0> + void VisitVector(const reflection::Field* field, const Domain& domain) { + auto opt_corpus = + corpus_value + .GetAs>(); + if (std::holds_alternative(opt_corpus)) { + return; + } + auto container_corpus = std::get(opt_corpus) + .GetAs>(); + if (field->optional() && container_corpus.empty()) { + return; + } + + FlatbuffersTableUntypedDomainImpl inner_domain( + self.schema_, self.schema_->objects()->Get(field->type()->index())); + std::vector> vec_offsets; + vec_offsets.reserve(container_corpus.size()); + for (auto& inner_corpus : container_corpus) { + auto offset = inner_domain.BuildTable(inner_corpus, builder); + vec_offsets.push_back(offset); + } + offsets.insert({field->id(), builder.CreateVector(vec_offsets).o}); + } + + template , + int> = 0> + void VisitVector(const reflection::Field* field, const Domain& domain) { + auto struct_object = self.schema_->objects()->Get(field->type()->index()); + auto opt_corpus = corpus_value.template GetAs< + std::variant>(); + if (std::holds_alternative(opt_corpus)) { + return; + } + auto container_corpus = std::get(opt_corpus) + .template GetAs>(); + uint8_t* vec_ptr = nullptr; + FlatbuffersStructUntypedDomainImpl inner_domain(self.schema_, + struct_object); + auto vec_offset = builder.CreateUninitializedVector( + container_corpus.size(), struct_object->bytesize(), + struct_object->minalign(), &vec_ptr); + size_t i = 0; + for (const auto& inner_corpus : container_corpus) { + uint8_t* current_struct_ptr = vec_ptr + i * struct_object->bytesize(); + inner_domain.BuildValue(inner_corpus, current_struct_ptr); + ++i; + } + offsets.insert({field->id(), vec_offset}); + } + + template , int> = 0> + void VisitVector(const reflection::Field* field, const Domain& domain) { + auto value = domain.GetValue(corpus_value); + if (!value) { + return; + } + std::vector> vec_offsets; + vec_offsets.reserve(value->size()); + for (const auto& str : *value) { + auto offset = builder.CreateString(str); + vec_offsets.push_back(offset); + } + offsets.insert({field->id(), builder.CreateVector(vec_offsets).o}); + } + + template < + typename Element, typename Domain, + std::enable_if_t, int> = 0> + void VisitVector(const reflection::Field* field, const Domain& domain) { + const reflection::Enum* union_type = + self.schema_->enums()->Get(field->type()->index()); + constexpr char kUnionTypeFieldSuffix[] = "_type"; + const reflection::Field* type_field = + self.table_object_->fields()->LookupByKey( + (field->name()->str() + kUnionTypeFieldSuffix).c_str()); + + auto opt_corpus = + corpus_value + .GetAs>(); + if (std::holds_alternative(opt_corpus)) { + return; + } + FlatbuffersUnionDomainImpl inner_domain{self.schema_, union_type}; + auto container_corpus = + std::get(opt_corpus) + .GetAs>>(); + + std::vector>::first_type> + vec_types; + vec_types.reserve(container_corpus.size()); + vec_types.reserve(container_corpus.size()); + std::vector> vec_offsets; + vec_offsets.reserve(container_corpus.size()); + vec_offsets.reserve(container_corpus.size()); + for (auto& inner_corpus : container_corpus) { + auto offset = inner_domain.BuildValue(inner_corpus, builder); + if (offset.has_value()) { + vec_offsets.push_back(*offset); + vec_types.push_back(inner_domain.GetType(inner_corpus)); + } + } + offsets.insert({field->id(), builder.CreateVector(vec_offsets).o}); + offsets.insert({type_field->id(), builder.CreateVector(vec_types).o}); + } + }; + + // Create complete table: store "inline fields" values inline, and store + // just offsets for "out-of-line fields". See `BuildTable` for details. + struct TableBuilderVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + flatbuffers::FlatBufferBuilder& builder; + absl::flat_hash_map& + offsets; + const typename corpus_type::value_type::second_type& corpus_value; + + template + void Visit(const reflection::Field* absl_nonnull field) const { + auto size = flatbuffers::GetTypeSize(field->type()->base_type()); + if constexpr (std::is_integral_v || std::is_floating_point_v || + is_flatbuffers_enum_tag_v) { + auto& domain = self.GetSubDomain(field); + auto v = domain.GetValue(corpus_value); + if (!v) { + return; + } + // Store "inline field" value inline. + builder.Align(size); + builder.PushBytes(reinterpret_cast(&v), size); + builder.TrackField(field->offset(), builder.GetSize()); + } else if constexpr (std::is_same_v || + is_flatbuffers_vector_tag_v) { + // "Out-of-line field". Store just offset. + if constexpr (is_flatbuffers_vector_tag_v) { + if constexpr (std::is_same_v) { + constexpr char kUnionTypeFieldSuffix[] = "_type"; + const reflection::Field* type_field = + self.table_object_->fields()->LookupByKey( + (field->name()->str() + kUnionTypeFieldSuffix).c_str()); + if (auto it = offsets.find(type_field->id()); it != offsets.end()) { + builder.AddOffset(type_field->offset(), + flatbuffers::Offset<>(it->second)); + } + } + } + if (auto it = offsets.find(field->id()); it != offsets.end()) { + builder.AddOffset( + field->offset(), + flatbuffers::Offset(it->second)); + } + } else if constexpr (std::is_same_v) { + // "Out-of-line field". Store just offset. + if (auto it = offsets.find(field->id()); it != offsets.end()) { + builder.AddOffset( + field->offset(), + flatbuffers::Offset(it->second)); + } + } else if constexpr (std::is_same_v) { + FlatbuffersStructUntypedDomainImpl domain( + self.schema_, self.schema_->objects()->Get(field->type()->index())); + auto opt_corpus = corpus_value.template GetAs< + std::variant>(); + if (std::holds_alternative(opt_corpus)) { + return; + } + auto inner_corpus = + std::get(opt_corpus) + .template GetAs>(); + auto offset = domain.BuildValue(inner_corpus, builder); + if (offset.has_value()) { + builder.AddStructOffset(field->offset(), offset.value()); + } + } else if constexpr (std::is_same_v) { + // From flatbuffers documentation: + // Unions are encoded as the combination of two fields: an enum + // representing the union choice and the offset to the actual element + const reflection::Enum* union_type = + self.schema_->enums()->Get(field->type()->index()); + FlatbuffersUnionDomainImpl domain(self.schema_, union_type); + if (auto it = offsets.find(field->id()); it != offsets.end()) { + // Store just an offset to the actual union element. + builder.AddOffset(field->offset(), + flatbuffers::Offset(it->second)); + + constexpr char kUnionTypeFieldSuffix[] = "_type"; + const reflection::Field* type_field = + self.table_object_->fields()->LookupByKey( + (field->name()->str() + kUnionTypeFieldSuffix).c_str()); + auto opt_corpus = corpus_value.GetAs< + std::variant>(); + if (std::holds_alternative(opt_corpus)) { + return; + } + auto inner_corpus = std::get(opt_corpus) + .GetAs>(); + auto type_value = domain.GetType(inner_corpus); + auto size = flatbuffers::GetTypeSize(type_field->type()->base_type()); + // Store the type value inline. + builder.Align(size); + builder.PushBytes(reinterpret_cast(&type_value), + size); + builder.TrackField(type_field->offset(), builder.GetSize()); + } + } + } + }; + + struct ParseVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + const IRObject& obj; + std::optional& out; + + template + void Visit(const reflection::Field* absl_nonnull field) { + out = self.GetSubDomain(field).ParseCorpus(obj); + } + }; + + struct ValidateVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + const GenericDomainCorpusType& corpus_value; + absl::Status& out; + + template + void Visit(const reflection::Field* absl_nonnull field) { + auto& domain = self.GetSubDomain(field); + out = domain.ValidateCorpusValue(corpus_value); + if (!out.ok()) { + out = Prefix(out, absl::StrCat("Invalid value for field ", + field->name()->str())); + } + } + }; + + struct InitializeVisitor { + FlatbuffersTableUntypedDomainImpl& self; + absl::BitGenRef prng; + corpus_type& val; + + template + void Visit(const reflection::Field* absl_nonnull field) { + auto& domain = self.GetSubDomain(field); + val[field->id()] = domain.Init(prng); + } + }; + + struct CountNumberOfFieldsVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + uint64_t& total_weight; + corpus_type& corpus; + + template + void Visit(const reflection::Field* absl_nonnull field) const { + // Add the weight of the field itself. + total_weight += 1; + + auto domain = self.GetSubDomain(field); + if (auto it = corpus.find(field->id()); it != corpus.end()) { + // Add the weight of the field corpus. + total_weight += domain.CountNumberOfFields(it->second); + } + } + }; + + struct MutateVisitor { + FlatbuffersTableUntypedDomainImpl& self; + absl::BitGenRef prng; + const domain_implementor::MutationMetadata& metadata; + bool only_shrink; + corpus_type& val; + + template + void Visit(const reflection::Field* absl_nonnull field) { + auto& domain = self.GetSubDomain(field); + if (auto it = val.find(field->id()); it != val.end()) { + domain.Mutate(it->second, prng, metadata, only_shrink); + } else if (!only_shrink) { + val[field->id()] = domain.Init(prng); + } + } + }; + + struct Printer { + const FlatbuffersTableUntypedDomainImpl& self; + + void PrintCorpusValue(const corpus_type& value, + domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const { + absl::Format(out, "{"); + bool first = true; + for (const auto& [id, field_corpus] : value) { + if (!first) { + absl::Format(out, ", "); + } + const reflection::Field* absl_nullable field = self.GetFieldById(id); + if (field == nullptr) { + absl::Format(out, "", id); + } else { + VisitFlatbufferField(self.schema_, field, + PrinterVisitor{self, field_corpus, out, mode}); + } + first = false; + } + absl::Format(out, "}"); + } + }; + + struct PrinterVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + const GenericDomainCorpusType& val; + domain_implementor::RawSink out; + domain_implementor::PrintMode mode; + + template + void Visit(const reflection::Field* absl_nonnull field) const { + auto& domain = self.GetSubDomain(field); + absl::Format(out, "%s: ", field->name()->str()); + domain_implementor::PrintValue(domain, val, out, mode); + } + }; +}; + +// Domain implementation for flatbuffers generated table classes. +// The corpus type is a pair of: +// - A map of field ids to field values. +// - The serialized buffer of the table. +template +class FlatbuffersTableDomainImpl + : public domain_implementor::DomainBase< + // Derived, for CRTP needs. See DomainBase for more details. + FlatbuffersTableDomainImpl, + // ValueType - user facing type, exact flatbuffer + const T* absl_nonnull, + // CorpusType - internal representation of ValueType + std::pair< + // Map of field ids to field values (note *Untyped*). + typename FlatbuffersTableUntypedDomainImpl::corpus_type, + // ^^^^^^^ + // Serialized flatbuffer. + std::vector>> { + public: + using typename FlatbuffersTableDomainImpl::DomainBase::corpus_type; + using typename FlatbuffersTableDomainImpl::DomainBase::value_type; + + FlatbuffersTableDomainImpl() { + flatbuffers::Verifier verifier(T::BinarySchema::data(), + T::BinarySchema::size()); + FUZZTEST_INTERNAL_CHECK(reflection::VerifySchemaBuffer(verifier), + "Invalid schema for flatbuffers table."); + auto schema = reflection::GetSchema(T::BinarySchema::data()); + auto table_object = + schema->objects()->LookupByKey(T::GetFullyQualifiedName()); + inner_ = FlatbuffersTableUntypedDomainImpl{schema, table_object}; + } + + FlatbuffersTableDomainImpl(const FlatbuffersTableDomainImpl& other) + : inner_(other.inner_) { + builder_.Clear(); + } + + FlatbuffersTableDomainImpl& operator=( + const FlatbuffersTableDomainImpl& other) { + if (this == &other) return *this; + inner_ = other.inner_; + builder_.Clear(); + return *this; + } + + FlatbuffersTableDomainImpl(FlatbuffersTableDomainImpl&& other) + : inner_(std::move(other.inner_)) { + builder_.Clear(); + } + + FlatbuffersTableDomainImpl& operator=(FlatbuffersTableDomainImpl&& other) { + if (this == &other) return *this; + inner_ = std::move(other.inner_); + builder_.Clear(); + return *this; + } + + // Initializes the table with random values. + corpus_type Init(absl::BitGenRef prng) { + if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed; + + // Create new map of field ids to field values + auto val = inner_->Init(prng); + // Serialize the map into a flatbuffer + auto offset = inner_->BuildTable(val, builder_); + builder_.Finish(flatbuffers::Offset(offset)); + // Store the serialized buffer in a vector. + auto buffer = + std::vector(builder_.GetBufferPointer(), + builder_.GetBufferPointer() + builder_.GetSize()); + builder_.Clear(); + + // Return corpus value: pair of the map and the serialized buffer. + return std::make_pair(val, std::move(buffer)); + } + + // Returns the number of fields in the table. + uint64_t CountNumberOfFields(corpus_type& val) { + return inner_->CountNumberOfFields(val.first); + } + + // Mutates the given corpus value. + void Mutate(corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, + bool only_shrink) { + // Modify values in the map. + inner_->Mutate(val.first, prng, metadata, only_shrink); + // Serialize the map into a flatbuffer and store it in vector + val.second = BuildBuffer(val.first); + } + + // Converts corpus value into the exact flatbuffer. + value_type GetValue(const corpus_type& value) const { + return flatbuffers::GetRoot(value.second.data()); + } + + // Creates corpus value from the exact flatbuffer. + std::optional FromValue(const value_type& value) const { + auto val = inner_->FromValue((const flatbuffers::Table*)value); + if (!val.has_value()) return std::nullopt; + return std::make_optional(std::make_pair(*val, BuildBuffer(*val))); + } + + // Returns the printer for the table. + auto GetPrinter() const { return Printer{*inner_}; } + + // Returns the parsed corpus value. + std::optional ParseCorpus(const IRObject& obj) const { + auto val = inner_->ParseCorpus(obj); + if (!val.has_value()) return std::nullopt; + return std::make_optional(std::make_pair(*val, BuildBuffer(*val))); + } + + // Returns the serialized corpus value. + IRObject SerializeCorpus(const corpus_type& corpus_value) const { + return inner_->SerializeCorpus(corpus_value.first); + } + + // Returns the status of the given corpus value. + absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const { + return inner_->ValidateCorpusValue(corpus_value.first); + } + + private: + std::optional inner_; + mutable flatbuffers::FlatBufferBuilder builder_; + + struct Printer { + const FlatbuffersTableUntypedDomainImpl& inner; + + void PrintCorpusValue(const corpus_type& value, + domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const { + inner.GetPrinter().PrintCorpusValue(value.first, out, mode); + } + }; + + std::vector BuildBuffer( + const typename corpus_type::first_type& val) const { + auto offset = inner_->BuildTable(val, builder_); + builder_.Finish(flatbuffers::Offset(offset)); + auto buffer = + std::vector(builder_.GetBufferPointer(), + builder_.GetBufferPointer() + builder_.GetSize()); + builder_.Clear(); + return buffer; + } +}; + +template +class ArbitraryImpl>> + : public FlatbuffersTableDomainImpl {}; + +} // namespace fuzztest::internal +#endif // FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLATBUFFERS_DOMAIN_IMPL_H_ diff --git a/fuzztest/internal/meta.h b/fuzztest/internal/meta.h index 4ddada107..c36d80b9a 100644 --- a/fuzztest/internal/meta.h +++ b/fuzztest/internal/meta.h @@ -200,6 +200,22 @@ template inline constexpr bool is_protocol_buffer_enum_v = IsProtocolBufferEnumImpl(true); +template +inline constexpr bool is_flatbuffers_table_v = false; + +// Flatbuffers tables generated structs do not have a public base class, so we +// check for a few specific methods: +// - T is a struct. +// - T has a `Builder` type. +// - T has a `BinarySchema` type with a static method `data()` (only available +// when passing `--bfbs-gen-embed` to the flatbuffer compiler). +// - T has a static method called `GetFullyQualifiedName` (only available when +// passing `--gen-name-strings` to the flatbuffer compiler). +template +inline constexpr bool is_flatbuffers_table_v< + T, std::void_t> = true; + template inline constexpr bool has_size_v = Requires([](auto v) -> decltype(v.size()) {}); diff --git a/fuzztest/internal/test_flatbuffers.fbs b/fuzztest/internal/test_flatbuffers.fbs new file mode 100644 index 000000000..a34fa4abd --- /dev/null +++ b/fuzztest/internal/test_flatbuffers.fbs @@ -0,0 +1,140 @@ +namespace fuzztest.internal; + +enum Enum: byte { + First, + Second, + Third +} + +struct BoolStruct { + b: bool; +} + +struct DefaultStruct { + b: bool; + i8: byte; + i16: short; + i32: int; + i64: long; + u8: ubyte; + u16: ushort; + u32: uint; + u64: ulong; + f: float; + d: double; + e: Enum; + s: BoolStruct; +} + +table BoolTable { + b: bool; +} + +table StringTable { + str: string; +} + +union Union { + BoolTable, + StringTable, + BoolStruct, +} + +table UnionTable { + u: Union; +} + +table DefaultTable { + b: bool; + i8: byte; + i16: short; + i32: int; + i64: long; + u8: ubyte; + u16: ushort; + u32: uint; + u64: ulong; + f: float; + d: double; + str: string; + e: Enum; + t: BoolTable; + u: Union; + s: DefaultStruct; + v_b: [bool]; + v_i8: [byte]; + v_i16: [short]; + v_i32: [int]; + v_i64: [long]; + v_u8: [ubyte]; + v_u16: [ushort]; + v_u32: [uint]; + v_u64: [ulong]; + v_f: [float]; + v_d: [double]; + v_str: [string]; + v_e: [Enum]; + v_t: [BoolTable]; + v_u: [Union]; + v_s: [DefaultStruct]; +} + +table OptionalTable { + b: bool = null; + i8: byte = null; + i16: short = null; + i32: int = null; + i64: long = null; + u8: ubyte = null; + u16: ushort = null; + u32: uint = null; + u64: ulong = null; + f: float = null; + d: double = null; + str: string; + e: Enum = null; + t: BoolTable; + u: Union; + s: DefaultStruct; + v_b: [bool]; + v_i8: [byte]; + v_i16: [short]; + v_i32: [int]; + v_i64: [long]; + v_u8: [ubyte]; + v_u16: [ushort]; + v_u32: [uint]; + v_u64: [ulong]; + v_f: [float]; + v_d: [double]; + v_str: [string]; + v_e: [Enum]; + v_t: [BoolTable]; + v_u: [Union]; + v_s: [DefaultStruct]; +} + +table RequiredTable { + str: string (required); + t: BoolTable (required); + u: Union (required); + s: DefaultStruct (required); + v_b: [bool] (required); + v_i8: [byte] (required); + v_i16: [short] (required); + v_i32: [int] (required); + v_i64: [long] (required); + v_u8: [ubyte] (required); + v_u16: [ushort] (required); + v_u32: [uint] (required); + v_u64: [ulong] (required); + v_f: [float] (required); + v_d: [double] (required); + v_str: [string] (required); + v_e: [Enum] (required); + v_t: [BoolTable] (required); + v_u: [Union] (required); + v_s: [DefaultStruct] (required); +} + +root_type DefaultTable;