diff --git a/MODULE.bazel b/MODULE.bazel index 9de1d249a..043422f44 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -42,6 +42,10 @@ bazel_dep( name = "platforms", version = "0.0.10", ) +bazel_dep( + name = "flatbuffers", + version = "25.2.10" +) # GoogleTest is not a dev dependency, because it's needed when FuzzTest is used # with GoogleTest integration (e.g., googletest_adaptor). Note that the FuzzTest # framework can be used without GoogleTest integration as well. @@ -55,8 +59,6 @@ bazel_dep( name = "protobuf", version = "30.2", ) -# TODO(lszekeres): Make this a dev dependency, as the protobuf library is only -# required for testing. bazel_dep( name = "rules_proto", version = "7.1.0", diff --git a/domain_tests/BUILD b/domain_tests/BUILD index 436a2c9fc..4e207e380 100644 --- a/domain_tests/BUILD +++ b/domain_tests/BUILD @@ -33,6 +33,21 @@ cc_test( ], ) +cc_test( + name = "arbitrary_domains_flatbuffers_test", + srcs = ["arbitrary_domains_flatbuffers_test.cc"], + deps = [ + ":domain_testing", + "@abseil-cpp//absl/random", + "@com_google_fuzztest//fuzztest:domain", + "@com_google_fuzztest//fuzztest:flatbuffers", + "@com_google_fuzztest//fuzztest:meta", + "@com_google_fuzztest//fuzztest:test_flatbuffers_cc_fbs", + "@flatbuffers//:runtime_cc", + "@googletest//:gtest_main", + ], +) + cc_test( name = "arbitrary_domains_protobuf_test", srcs = ["arbitrary_domains_protobuf_test.cc"], diff --git a/domain_tests/arbitrary_domains_flatbuffers_test.cc b/domain_tests/arbitrary_domains_flatbuffers_test.cc new file mode 100644 index 000000000..aca9ea523 --- /dev/null +++ b/domain_tests/arbitrary_domains_flatbuffers_test.cc @@ -0,0 +1,526 @@ +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/random/random.h" +#include "flatbuffers/base.h" +#include "flatbuffers/buffer.h" +#include "flatbuffers/flatbuffer_builder.h" +#include "flatbuffers/string.h" +#include "flatbuffers/vector.h" +#include "./fuzztest/domain.h" +#include "./domain_tests/domain_testing.h" +#include "./fuzztest/flatbuffers.h" +#include "./fuzztest/internal/meta.h" +#include "./fuzztest/test_flatbuffers_generated.h" + +namespace fuzztest { +namespace { + +using ::fuzztest::internal::NestedTestFbsTable; +using ::fuzztest::internal::OptionalRequiredTestFbsTable; +using ::fuzztest::internal::SimpleTestFbsTable; +using ::fuzztest::internal::TestFbsEnum; +using ::fuzztest::internal::UnionTestFbsTable; +using ::fuzztest::internal::VectorsTestFbsTable; +using ::testing::Contains; +using ::testing::IsTrue; +using ::testing::ResultOf; + +TEST(FlatbuffersMetaTest, IsFlatbuffersTable) { + static_assert(internal::is_flatbuffers_table_v); + static_assert(!internal::is_flatbuffers_table_v); +} + +TEST(FlatbuffersTableDomainImplTest, SimpleTestFbsTableValueRoundTrip) { + auto domain = Arbitrary(); + + flatbuffers::FlatBufferBuilder fbb; + auto table_offset = internal::CreateSimpleTestFbsTableDirect( + fbb, true, 1.0, "foo bar baz", internal::TestFbsEnum_Second); + fbb.Finish(table_offset); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + + auto corpus = domain.FromValue(table); + ASSERT_TRUE(corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*corpus)); + + auto ir = domain.SerializeCorpus(corpus.value()); + + auto new_corpus = domain.ParseCorpus(ir); + ASSERT_TRUE(new_corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*new_corpus)); + + auto new_table = domain.GetValue(*new_corpus); + EXPECT_EQ(new_table->b(), true); + EXPECT_EQ(new_table->f(), 1.0); + EXPECT_EQ(new_table->str()->str(), "foo bar baz"); + EXPECT_TRUE(new_table->e() == internal::TestFbsEnum_Second); +} + +TEST(FlatbuffersTableDomainImplTest, InitGeneratesSeeds) { + auto domain = Arbitrary(); + + flatbuffers::FlatBufferBuilder fbb; + auto table_offset = internal::CreateSimpleTestFbsTableDirect( + fbb, true, 1.0, "foo bar baz", internal::TestFbsEnum_Second); + fbb.Finish(table_offset); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + + domain.WithSeeds({table}); + + std::vector> values; + absl::BitGen bitgen; + values.reserve(1000); + for (int i = 0; i < 1000; ++i) { + Value value(domain, bitgen); + values.push_back(std::move(value)); + } + + EXPECT_THAT( + values, + Contains(ResultOf( + [table](const auto& val) { + bool has_same_str = + val.user_value->str() == nullptr && table->str() == nullptr; + if (val.user_value->str() != nullptr && table->str() != nullptr) { + has_same_str = + val.user_value->str()->str() == table->str()->str(); + } + return (val.user_value->b() == table->b() && + val.user_value->f() == table->f() && + val.user_value->e() == table->e() && has_same_str); + }, + IsTrue()))); +} + +TEST(FlatbuffersTableDomainImplTest, EventuallyMutatesAllTableFields) { + auto domain = Arbitrary(); + + absl::BitGen bitgen; + Value val(domain, bitgen); + + const auto verify_field_changes = [&](std::string_view name, auto get) { + Set values; + + int iterations = 10'000; + while (--iterations > 0 && values.size() < 2) { + values.insert(get(val.user_value)); + val.Mutate(domain, bitgen, {}, false); + } + EXPECT_GT(iterations, 0) + << "Field: " << name << " -- " << testing::PrintToString(values); + }; + + verify_field_changes("b", [](auto v) { return v->b(); }); + verify_field_changes("f", [](auto v) { return v->f(); }); + verify_field_changes("str", + [](auto v) { return v->str() ? v->str()->str() : ""; }); + verify_field_changes("e", [](auto v) { return v->e(); }); +} + +TEST(FlatbuffersTableDomainImplTest, OptionalFieldsEventuallyBecomeEmpty) { + auto domain = Arbitrary(); + + absl::BitGen bitgen; + Value val(domain, bitgen); + + const auto verify_field_becomes_null = [&](std::string_view name, auto has) { + for (int i = 0; i < 10'000; ++i) { + val.Mutate(domain, bitgen, {}, false); + if (!has(val.user_value)) { + break; + } + } + EXPECT_FALSE(has(val.user_value)) << "Field never became unset: " << name; + }; + + verify_field_becomes_null("opt_scalar", + [](auto v) { return v->opt_scalar().has_value(); }); + verify_field_becomes_null("opt_str", + [](auto v) { return v->opt_str() != nullptr; }); +} + +TEST(FlatbuffersTableDomainImplTest, DefaultAndRequiredFieldsAlwaysSet) { + auto domain = Arbitrary(); + + absl::BitGen bitgen; + Value val(domain, bitgen); + + const auto verify_field_always_set = [&](std::string_view name, auto has) { + for (int i = 0; i < 10'000; ++i) { + val.Mutate(domain, bitgen, {}, false); + if (!has(val.user_value)) { + break; + } + } + EXPECT_TRUE(has(val.user_value)) << "Field is not set: " << name; + }; + + verify_field_always_set("def_scalar", [](auto v) { return true; }); + verify_field_always_set("req_str", + [](auto v) { return v->req_str() != nullptr; }); +} + +TEST(FlatbuffersTableDomainImplTest, NestedTableValueRoundTrip) { + auto domain = Arbitrary(); + absl::BitGen bitgen; + Value val(domain, bitgen); + + flatbuffers::FlatBufferBuilder fbb; + auto child_offset = internal::CreateSimpleTestFbsTableDirect( + fbb, true, 1.0, "foo bar baz", internal::TestFbsEnum_Second); + auto parent_offset = internal::CreateNestedTestFbsTable(fbb, child_offset); + fbb.Finish(parent_offset); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + + auto parent_corpus = domain.FromValue(table); + ASSERT_TRUE(parent_corpus.has_value()); + + auto ir = domain.SerializeCorpus(parent_corpus.value()); + + auto new_corpus = domain.ParseCorpus(ir); + ASSERT_TRUE(new_corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*new_corpus)); + + auto new_table = domain.GetValue(parent_corpus.value()); + EXPECT_NE(new_table->t(), nullptr); + EXPECT_EQ(new_table->t()->b(), true); + EXPECT_EQ(new_table->t()->f(), 1.0); + EXPECT_NE(new_table->t()->str(), nullptr); + EXPECT_EQ(new_table->t()->str()->str(), "foo bar baz"); + EXPECT_TRUE(new_table->t()->e() == internal::TestFbsEnum_Second); +} + +TEST(FlatbuffersTableDomainImplTest, EventuallyMutatesAllNestedTableFields) { + auto domain = Arbitrary(); + absl::BitGen bitgen; + Value val(domain, bitgen); + + const auto verify_field_changes = [&](std::string_view name, auto get) { + Set values; + + int iterations = 10'000; + while (--iterations > 0 && values.size() < 2) { + auto value = get(val.user_value); + if (value.has_value()) { + values.insert(*value); + } + val.Mutate(domain, bitgen, {}, false); + } + EXPECT_GT(iterations, 0) + << "Field: " << name << " -- " << testing::PrintToString(values); + }; + + verify_field_changes("t.b", [](auto v) { + return v->t() ? std::make_optional(v->t()->b()) : std::nullopt; + }); + verify_field_changes("t.f", [](auto v) { + return v->t() ? std::make_optional(v->t()->f()) : std::nullopt; + }); + verify_field_changes("t.str", [](auto v) { + return v->t() ? v->t()->str() ? std::make_optional(v->t()->str()->str()) + : std::nullopt + : std::nullopt; + }); + verify_field_changes("t.e", [](auto v) { + return v->t() ? std::make_optional(v->t()->e()) : std::nullopt; + }); +} + +TEST(FlatbuffersTableDomainImplTest, VectorsSerializeAndDeserialize) { + auto domain = Arbitrary(); + + absl::BitGen bitgen; + Value val(domain, bitgen); + + flatbuffers::FlatBufferBuilder fbb; + std::vector> str_offsets; + for (const auto& str : {"foo", "bar", "baz"}) { + str_offsets.push_back(fbb.CreateString(str)); + } + std::vector> table_offsets; + for (const auto& str : {"foo", "bar", "baz"}) { + table_offsets.push_back(internal::CreateSimpleTestFbsTableDirect( + fbb, true, 1.0, str, internal::TestFbsEnum_Second)); + } + std::vector b{true, false}; + std::vector i8{1, 2, 3}; + std::vector i16{1, 2, 3}; + std::vector i32{1, 2, 3}; + std::vector i64{1, 2, 3}; + std::vector u8{1, 2, 3}; + std::vector u16{1, 2, 3}; + std::vector u32{1, 2, 3}; + std::vector u64{1, 2, 3}; + std::vector f{1, 2, 3}; + std::vector d{1, 2, 3}; + std::vector> e{ + TestFbsEnum::TestFbsEnum_First, TestFbsEnum::TestFbsEnum_Second, + TestFbsEnum::TestFbsEnum_Third}; + auto table_offset = internal::CreateVectorsTestFbsTableDirect( + fbb, &b, &i8, &i16, &i32, &i64, &u8, &u16, &u32, &u64, &f, &d, + &str_offsets, &e, &table_offsets); + fbb.Finish(table_offset); + auto table = + flatbuffers::GetRoot(fbb.GetBufferPointer()); + + auto corpus = domain.FromValue(table); + auto ir = domain.SerializeCorpus(corpus.value()); + { + auto new_corpus = domain.ParseCorpus(ir); + ASSERT_TRUE(new_corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*new_corpus)); + + auto new_table = domain.GetValue(*new_corpus); + ASSERT_NE(new_table, nullptr); + ASSERT_NE(new_table->b(), nullptr); + EXPECT_EQ(new_table->b()->size(), 2); + EXPECT_EQ(new_table->b()->Get(0), true); + EXPECT_EQ(new_table->b()->Get(1), false); + ASSERT_NE(new_table->i8(), nullptr); + EXPECT_EQ(new_table->i8()->size(), 3); + EXPECT_EQ(new_table->i8()->Get(0), 1); + EXPECT_EQ(new_table->i8()->Get(1), 2); + EXPECT_EQ(new_table->i8()->Get(2), 3); + ASSERT_NE(new_table->i16(), nullptr); + EXPECT_EQ(new_table->i16()->size(), 3); + EXPECT_EQ(new_table->i16()->Get(0), 1); + EXPECT_EQ(new_table->i16()->Get(1), 2); + EXPECT_EQ(new_table->i16()->Get(2), 3); + ASSERT_NE(new_table->i32(), nullptr); + EXPECT_EQ(new_table->i32()->size(), 3); + EXPECT_EQ(new_table->i32()->Get(0), 1); + EXPECT_EQ(new_table->i32()->Get(1), 2); + EXPECT_EQ(new_table->i32()->Get(2), 3); + ASSERT_NE(new_table->i64(), nullptr); + EXPECT_EQ(new_table->i64()->size(), 3); + EXPECT_EQ(new_table->i64()->Get(0), 1); + EXPECT_EQ(new_table->i64()->Get(1), 2); + EXPECT_EQ(new_table->i64()->Get(2), 3); + ASSERT_NE(new_table->u8(), nullptr); + EXPECT_EQ(new_table->u8()->size(), 3); + EXPECT_EQ(new_table->u8()->Get(0), 1); + EXPECT_EQ(new_table->u8()->Get(1), 2); + EXPECT_EQ(new_table->u8()->Get(2), 3); + ASSERT_NE(new_table->u16(), nullptr); + EXPECT_EQ(new_table->u16()->size(), 3); + EXPECT_EQ(new_table->u16()->Get(0), 1); + EXPECT_EQ(new_table->u16()->Get(1), 2); + EXPECT_EQ(new_table->u16()->Get(2), 3); + ASSERT_NE(new_table->u32(), nullptr); + EXPECT_EQ(new_table->u32()->size(), 3); + EXPECT_EQ(new_table->u32()->Get(0), 1); + EXPECT_EQ(new_table->u32()->Get(1), 2); + EXPECT_EQ(new_table->u32()->Get(2), 3); + ASSERT_NE(new_table->u64(), nullptr); + EXPECT_EQ(new_table->u64()->size(), 3); + EXPECT_EQ(new_table->u64()->Get(0), 1); + EXPECT_EQ(new_table->u64()->Get(1), 2); + EXPECT_EQ(new_table->u64()->Get(2), 3); + ASSERT_NE(new_table->f(), nullptr); + EXPECT_EQ(new_table->f()->size(), 3); + EXPECT_EQ(new_table->f()->Get(0), 1); + EXPECT_EQ(new_table->f()->Get(1), 2); + EXPECT_EQ(new_table->f()->Get(2), 3); + ASSERT_NE(new_table->d(), nullptr); + EXPECT_EQ(new_table->d()->size(), 3); + EXPECT_EQ(new_table->d()->Get(0), 1); + EXPECT_EQ(new_table->d()->Get(1), 2); + EXPECT_EQ(new_table->d()->Get(2), 3); + ASSERT_NE(new_table->e(), nullptr); + EXPECT_EQ(new_table->e()->size(), 3); + EXPECT_EQ(new_table->e()->Get(0), internal::TestFbsEnum_First); + EXPECT_EQ(new_table->e()->Get(1), internal::TestFbsEnum_Second); + EXPECT_EQ(new_table->e()->Get(2), internal::TestFbsEnum_Third); + EXPECT_EQ(new_table->str()->size(), 3); + EXPECT_EQ(new_table->str()->Get(0)->str(), "foo"); + EXPECT_EQ(new_table->str()->Get(1)->str(), "bar"); + EXPECT_EQ(new_table->str()->Get(2)->str(), "baz"); + ASSERT_NE(new_table->t(), nullptr); + EXPECT_EQ(new_table->t()->size(), 3); + EXPECT_EQ(new_table->t()->Get(0)->b(), true); + EXPECT_EQ(new_table->t()->Get(1)->b(), true); + EXPECT_EQ(new_table->t()->Get(2)->b(), true); + EXPECT_EQ(new_table->t()->Get(0)->f(), 1.0); + EXPECT_EQ(new_table->t()->Get(1)->f(), 1.0); + EXPECT_EQ(new_table->t()->Get(2)->f(), 1.0); + EXPECT_EQ(new_table->t()->Get(0)->str()->str(), "foo"); + EXPECT_EQ(new_table->t()->Get(1)->str()->str(), "bar"); + EXPECT_EQ(new_table->t()->Get(2)->str()->str(), "baz"); + EXPECT_EQ(new_table->t()->Get(0)->e(), internal::TestFbsEnum_Second); + EXPECT_EQ(new_table->t()->Get(1)->e(), internal::TestFbsEnum_Second); + EXPECT_EQ(new_table->t()->Get(2)->e(), internal::TestFbsEnum_Second); + } +} + +TEST(FlatbuffersTableDomainImplTest, EventuallyMutatesAllVectorFields) { + auto domain = Arbitrary(); + + absl::BitGen bitgen; + Value val(domain, bitgen); + + const auto verify_field_changes = [&](std::string_view name, auto get) { + Set values; + + int iterations = 10'000; + while (--iterations > 0 && values.size() < 2) { + auto value = get(val.user_value); + if (value.has_value()) { + values.insert(*value); + } + val.Mutate(domain, bitgen, {}, false); + } + EXPECT_GT(iterations, 0) + << "Field: " << name << " -- " << testing::PrintToString(values); + }; + + verify_field_changes("t.b", [](auto v) { + return v && v->t() ? std::make_optional(v->b()) : std::nullopt; + }); + verify_field_changes("t.i8", [](auto v) { + return v && v->i8() ? std::make_optional(v->i8()) : std::nullopt; + }); + verify_field_changes("t.i16", [](auto v) { + return v && v->i16() ? std::make_optional(v->i16()) : std::nullopt; + }); + verify_field_changes("t.i32", [](auto v) { + return v && v->i32() ? std::make_optional(v->i32()) : std::nullopt; + }); + verify_field_changes("t.i64", [](auto v) { + return v && v->i64() ? std::make_optional(v->i64()) : std::nullopt; + }); + verify_field_changes("t.u8", [](auto v) { + return v && v->u8() ? std::make_optional(v->u8()) : std::nullopt; + }); + verify_field_changes("t.u16", [](auto v) { + return v && v->u16() ? std::make_optional(v->u16()) : std::nullopt; + }); + verify_field_changes("t.u32", [](auto v) { + return v && v->u32() ? std::make_optional(v->u32()) : std::nullopt; + }); + verify_field_changes("t.u64", [](auto v) { + return v && v->u64() ? std::make_optional(v->u64()) : std::nullopt; + }); + verify_field_changes("t.f", [](auto v) { + return v && v->f() ? std::make_optional(v->f()) : std::nullopt; + }); + verify_field_changes("t.d", [](auto v) { + return v && v->d() ? std::make_optional(v->d()) : std::nullopt; + }); + verify_field_changes("t.e", [](auto v) { + return v && v->e() ? std::make_optional(v->e()) : std::nullopt; + }); + verify_field_changes("t.str", [](auto v) { + return v && v->str() ? std::make_optional(v->str()) : std::nullopt; + }); + verify_field_changes("t.t", [](auto v) { + return v && v->t() ? std::make_optional(v->t()) : std::nullopt; + }); +} + +TEST(FlatbuffersTableDomainImplTest, UnionFieldsSerializeAndDeserialize) { + flatbuffers::FlatBufferBuilder fbb; + auto child_table_offset = internal::CreateSimpleTestFbsTableDirect( + fbb, true, 1.0, "foo bar baz", internal::TestFbsEnum_Second); + auto vec_child_one = internal::CreateSimpleTestFbsTableDirect( + fbb, true, 1.0, "foo bar baz", internal::TestFbsEnum_Second); + auto vec_child_two = internal::CreateOptionalRequiredTestFbsTableDirect( + fbb, true, std::nullopt, "foo bar baz"); + auto vec_types = + fbb.CreateVector({internal::Union_SimpleTestFbsTable, + internal::Union_OptionalRequiredTestFbsTable}); + auto vec_values = + fbb.CreateVector({vec_child_one.Union(), vec_child_two.Union()}); + auto parent_offset = internal::CreateUnionTestFbsTable( + fbb, internal::Union_SimpleTestFbsTable, child_table_offset.Union(), + vec_types, vec_values); + fbb.Finish(parent_offset); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + + auto domain = Arbitrary(); + auto corpus = domain.FromValue(table); + auto ir = domain.SerializeCorpus(*corpus); + auto new_corpus = domain.ParseCorpus(ir); + ASSERT_TRUE(new_corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*new_corpus)); + auto new_table = domain.GetValue(*new_corpus); + ASSERT_NE(new_table, nullptr); + ASSERT_NE(new_table->u(), nullptr); + ASSERT_NE(new_table->u_as_SimpleTestFbsTable(), nullptr); + EXPECT_EQ(new_table->u_as_SimpleTestFbsTable()->b(), true); + EXPECT_EQ(new_table->u_as_SimpleTestFbsTable()->f(), 1.0); + EXPECT_EQ(new_table->u_as_SimpleTestFbsTable()->str()->str(), "foo bar baz"); + EXPECT_EQ(new_table->u_as_SimpleTestFbsTable()->e(), + internal::TestFbsEnum_Second); + + ASSERT_NE(new_table->u_vec(), nullptr); + ASSERT_EQ(new_table->u_vec()->size(), 2); + auto u_vec_one = + static_cast(new_table->u_vec()->Get(0)); + ASSERT_NE(u_vec_one, nullptr); + EXPECT_EQ(u_vec_one->b(), true); + EXPECT_EQ(u_vec_one->f(), 1.0); + EXPECT_EQ(u_vec_one->str()->str(), "foo bar baz"); + EXPECT_EQ(u_vec_one->e(), internal::TestFbsEnum_Second); + + auto u_vec_two = static_cast( + new_table->u_vec()->Get(1)); + ASSERT_NE(u_vec_two, nullptr); + EXPECT_EQ(u_vec_two->def_scalar(), true); + EXPECT_EQ(u_vec_two->opt_scalar(), std::nullopt); + ASSERT_NE(u_vec_two->req_str(), nullptr); + EXPECT_EQ(u_vec_two->req_str()->str(), "foo bar baz"); + EXPECT_EQ(u_vec_two->opt_str(), nullptr); +} + +TEST(FlatbuffersTableDomainImplTest, UnionFieldsEventuallyMutate) { + auto domain = Arbitrary(); + + absl::BitGen bitgen; + Value val(domain, bitgen); + + const auto verify_field_changes = [&](std::string_view name, auto get) { + Set values; + + int iterations = 10'000; + while (--iterations > 0 && values.size() < 2) { + auto value = get(val.user_value); + values.insert(value); + val.Mutate(domain, bitgen, {}, false); + } + EXPECT_GT(iterations, 0) + << "Field: " << name << " -- " << testing::PrintToString(values); + }; + + verify_field_changes("u_type", [](auto v) { return v->u_type(); }); + verify_field_changes("u_as_SimpleTestFbsTable", + [](auto v) { return v->u_as_SimpleTestFbsTable(); }); + verify_field_changes("u_as_OptionalRequiredTestFbsTable", [](auto v) { + return v->u_as_OptionalRequiredTestFbsTable(); + }); + verify_field_changes("u_vec_type", [](auto v) { return v->u_vec_type(); }); + verify_field_changes("u_vec", [](auto v) { return v->u_vec(); }); + verify_field_changes("u_vec[0].as_SimpleTestFbsTable", [](auto v) { + return v->u_vec() && v->u_vec()->size() > 0 && + v->u_vec_type()->Get(0) == internal::Union_SimpleTestFbsTable + ? static_cast(v->u_vec()->Get(0)) + : nullptr; + }); + verify_field_changes("u_vec[0].as_OptionalRequiredTestFbsTable", [](auto v) { + return v->u_vec() && v->u_vec()->size() > 0 && + v->u_vec_type()->Get(0) == + internal::Union_OptionalRequiredTestFbsTable + ? static_cast( + v->u_vec()->Get(0)) + : nullptr; + }); +} + +} // namespace +} // namespace fuzztest diff --git a/fuzztest/BUILD b/fuzztest/BUILD index b6041809b..f89a96c87 100644 --- a/fuzztest/BUILD +++ b/fuzztest/BUILD @@ -15,6 +15,7 @@ # FuzzTest: a coverage-guided fuzzing / property-based testing framework. load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@flatbuffers//:build_defs.bzl", "flatbuffer_library_public") load("@rules_proto//proto:defs.bzl", "proto_library") package(default_visibility = ["//visibility:public"]) @@ -345,6 +346,7 @@ cc_library( ":serialization", ":status", ":type_support", + "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/base:no_destructor", "@abseil-cpp//absl/container:flat_hash_map", @@ -422,6 +424,38 @@ cc_library( ], ) +cc_library( + name = "flatbuffers", + srcs = [ + "internal/domains/flatbuffers_domain_impl.cc", + "internal/domains/flatbuffers_domain_impl.h", + ], + hdrs = ["flatbuffers.h"], + deps = [ + ":any", + ":domain_core", + ":logging", + ":meta", + ":serialization", + ":status", + ":type_support", + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/base:core_headers", + "@abseil-cpp//absl/base:nullability", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/random", + "@abseil-cpp//absl/random:bit_gen_ref", + "@abseil-cpp//absl/random:distributions", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:str_format", + "@abseil-cpp//absl/synchronization", + "@flatbuffers//:runtime_cc", + ], +) + cc_library( name = "fixture_driver", srcs = ["internal/fixture_driver.cc"], @@ -799,6 +833,28 @@ cc_proto_library( deps = [":test_protobuf"], ) +flatbuffer_library_public( + name = "test_flatbuffers_fbs", + srcs = ["internal/test_flatbuffers.fbs"], + outs = [ + "test_flatbuffers_bfbs_generated.h", + "test_flatbuffers_generated.h", + ], + flatc_args = [ + "--bfbs-gen-embed", + "--gen-name-strings", + ], + language_flag = "-c", +) + +cc_library( + name = "test_flatbuffers_cc_fbs", + srcs = [":test_flatbuffers_fbs"], + hdrs = [":test_flatbuffers_fbs"], + features = ["-parse_headers"], + deps = ["@flatbuffers//:runtime_cc"], +) + cc_library( name = "type_support", srcs = ["internal/type_support.cc"], diff --git a/fuzztest/flatbuffers.h b/fuzztest/flatbuffers.h new file mode 100644 index 000000000..b70ed361b --- /dev/null +++ b/fuzztest/flatbuffers.h @@ -0,0 +1,19 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FUZZTEST_FUZZTEST_FLATBUFFERS_H_ +#define FUZZTEST_FUZZTEST_FLATBUFFERS_H_ + +#include "./fuzztest/internal/domains/flatbuffers_domain_impl.h" // IWYU pragma: export +#endif // FUZZTEST_FUZZTEST_FLATBUFFERS_H_ diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.cc b/fuzztest/internal/domains/flatbuffers_domain_impl.cc new file mode 100644 index 000000000..8a7950f98 --- /dev/null +++ b/fuzztest/internal/domains/flatbuffers_domain_impl.cc @@ -0,0 +1,299 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "./fuzztest/internal/domains/flatbuffers_domain_impl.h" + +#include +#include +#include +#include + +#include "absl/random/bit_gen_ref.h" +#include "absl/random/distributions.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "flatbuffers/base.h" +#include "flatbuffers/flatbuffer_builder.h" +#include "flatbuffers/reflection_generated.h" +#include "flatbuffers/table.h" +#include "./fuzztest/domain_core.h" +#include "./fuzztest/internal/any.h" +#include "./fuzztest/internal/domains/domain_base.h" +#include "./fuzztest/internal/domains/domain_type_erasure.h" +#include "./fuzztest/internal/logging.h" +#include "./fuzztest/internal/meta.h" +#include "./fuzztest/internal/serialization.h" + +namespace fuzztest { +namespace internal { + +FlatbuffersUnionDomainImpl::corpus_type FlatbuffersUnionDomainImpl::Init( + absl::BitGenRef prng) { + if (auto seed = this->MaybeGetRandomSeed(prng)) { + return *seed; + } + corpus_type val; + auto selected_type_enumval_index = + absl::Uniform(prng, 0ul, union_def_->values()->size()); + auto type_enumval = union_def_->values()->Get(selected_type_enumval_index); + if (type_enumval == nullptr) { + return val; + } + auto type_value = type_domain_.FromValue(type_enumval->value()); + if (!type_value.has_value()) { + return val; + } + val.first = *type_value; + if (type_enumval->value() == 0 /* NONE */) { + return val; + } + + auto domain = GetTableDomain(*type_enumval); + if (domain == nullptr) { + return val; + } + + auto inner_val = domain->Init(prng); + val.second = GenericDomainCorpusType(std::in_place_type, + std::move(inner_val)); + return val; +} + +// Mutates the corpus value. +void FlatbuffersUnionDomainImpl::Mutate( + corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, bool only_shrink) { + auto total_weight = CountNumberOfFields(val); + auto selected_weight = absl::Uniform(prng, 0ul, total_weight); + if (selected_weight == 0) { + type_domain_.Mutate(val.first, prng, metadata, only_shrink); + val.second = GenericDomainCorpusType(std::in_place_type, nullptr); + auto type_value = type_domain_.GetValue(val.first); + if (type_value == 0) { + return; + } + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return; + } + auto domain = GetTableDomain(*type_enumval); + if (domain == nullptr) { + return; + } + auto inner_val = domain->Init(prng); + val.second = GenericDomainCorpusType( + std::in_place_type, std::move(inner_val)); + } else { + auto type_value = type_domain_.GetValue(val.first); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return; + } + auto domain = GetTableDomain(*type_enumval); + if (domain == nullptr) { + return; + } + auto inner_val = val.second.template GetAs< + corpus_type_t>>(); + domain->MutateSelectedField(inner_val, prng, metadata, only_shrink, + selected_weight - 1); + } +} + +uint64_t FlatbuffersUnionDomainImpl::CountNumberOfFields(corpus_type& val) { + uint64_t count = 1; + auto type_value = type_domain_.GetValue(val.first); + if (type_value == 0 /* NONE */) { + return count; + } + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return count; + } + auto domain = GetTableDomain(*type_enumval); + if (domain != nullptr) { + auto inner_val = val.second.template GetAs< + corpus_type_t>>(); + count += domain->CountNumberOfFields(inner_val); + } + return count; +} + +absl::Status FlatbuffersUnionDomainImpl::ValidateCorpusValue( + const corpus_type& corpus_value) const { + auto type_value = type_domain_.GetValue(corpus_value.first); + if (type_value == 0) { + return absl::OkStatus(); + } + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return absl::OkStatus(); + } + auto domain = GetTableDomain(*type_enumval); + if (domain == nullptr) { + return absl::OkStatus(); + } + auto inner_corpus_value = corpus_value.second.template GetAs< + corpus_type_t>>(); + return domain->ValidateCorpusValue(inner_corpus_value); +} + +std::optional +FlatbuffersUnionDomainImpl::FromValue(const value_type& value) const { + std::optional out{{}}; + auto type_value = type_domain_.FromValue(value.first); + if (type_value.has_value()) { + out->first = *type_value; + } + auto type_enumval = union_def_->values()->LookupByKey(value.first); + if (type_enumval == nullptr) { + return std::nullopt; + } + auto domain = GetTableDomain(*type_enumval); + if (domain != nullptr) { + auto inner_value = + domain->FromValue(static_cast(value.second)); + if (inner_value.has_value()) { + out->second = GenericDomainCorpusType( + std::in_place_type, + std::move(*inner_value)); + } + } + return out; +} + +// Converts the IRObject to a corpus value. +std::optional +FlatbuffersUnionDomainImpl::ParseCorpus(const IRObject& obj) const { + corpus_type out; + auto subs = obj.Subs(); + if (!subs) { + return std::nullopt; + } + if (subs->size() != 2) { + return std::nullopt; + } + + auto type_corpus = type_domain_.ParseCorpus((*subs)[0]); + if (!type_corpus.has_value()) { + return std::nullopt; + } + out.first = *type_corpus; + auto type_value = type_domain_.GetValue(out.first); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return std::nullopt; + } + auto domain = GetTableDomain(*type_enumval); + if (domain == nullptr) { + return std::nullopt; + } + + auto inner_corpus = domain->ParseCorpus((*subs)[1]); + if (inner_corpus.has_value()) { + out.second = GenericDomainCorpusType( + std::in_place_type< + typename std::remove_pointer_t::value_type>, + *inner_corpus); + } + return out; +} + +// Converts the corpus value to an IRObject. +IRObject FlatbuffersUnionDomainImpl::SerializeCorpus( + const corpus_type& value) const { + IRObject out; + auto& pair = out.MutableSubs(); + pair.reserve(2); + + auto type_value = type_domain_.GetValue(value.first); + pair.push_back(type_domain_.SerializeCorpus(value.first)); + + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return out; + } + auto domain = GetTableDomain(*type_enumval); + if (domain == nullptr) { + return out; + } + pair.push_back(domain->SerializeCorpus( + value.second.template GetAs< + corpus_type_t>>())); + return out; +} + +std::optional FlatbuffersUnionDomainImpl::BuildValue( + const corpus_type& value, flatbuffers::FlatBufferBuilder& builder) const { + auto type_value = type_domain_.GetValue(value.first); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return std::nullopt; + } + auto domain = GetTableDomain(*type_enumval); + if (domain == nullptr) { + return std::nullopt; + } + return domain->BuildTable( + value.second.template GetAs< + corpus_type_t>>(), + builder); +} + +FlatbuffersTableUntypedDomainImpl* FlatbuffersUnionDomainImpl::GetTableDomain( + const reflection::EnumVal& enum_value) const { + absl::MutexLock l(&mutex_); + auto it = domains_.find(enum_value.value()); + if (it == domains_.end()) { + auto base_type = enum_value.union_type()->base_type(); + if (base_type == reflection::BaseType::None) { + return nullptr; + } + FUZZTEST_INTERNAL_CHECK(base_type == reflection::BaseType::Obj, + "EnumVal union type is not a BaseType::Obj"); + auto object = schema_->objects()->Get(enum_value.union_type()->index()); + if (object->is_struct()) { + // TODO(b/405939014): Support structs. + return nullptr; + } + it = domains_ + .emplace(enum_value.value(), + FlatbuffersTableUntypedDomainImpl{schema_, object}) + .first; + } + return &it->second; +} + +void FlatbuffersUnionDomainImpl::Printer::PrintCorpusValue( + const corpus_type& value, domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const { + auto type_value = self.type_domain_.GetValue(value.first); + auto type_enumval = self.union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return; + } + absl::Format(out, "<%s>(", type_enumval->name()->str()); + auto domain = self.GetTableDomain(*type_enumval); + if (domain == nullptr) { + absl::Format(out, "UNSUPPORTED_UNION_TYPE"); + return; + } + auto inner_corpus_value = value.second.template GetAs< + corpus_type_t>>(); + domain_implementor::PrintValue(*domain, inner_corpus_value, out, mode); + absl::Format(out, ")"); +} +} // namespace internal +} // namespace fuzztest diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.h b/fuzztest/internal/domains/flatbuffers_domain_impl.h new file mode 100644 index 000000000..5dffa5d46 --- /dev/null +++ b/fuzztest/internal/domains/flatbuffers_domain_impl.h @@ -0,0 +1,1535 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLATBUFFERS_DOMAIN_IMPL_H_ +#define FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLATBUFFERS_DOMAIN_IMPL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/distributions.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "flatbuffers/base.h" +#include "flatbuffers/flatbuffer_builder.h" +#include "flatbuffers/reflection.h" +#include "flatbuffers/reflection_generated.h" +#include "flatbuffers/string.h" +#include "flatbuffers/table.h" +#include "flatbuffers/vector.h" +#include "flatbuffers/verifier.h" +#include "./fuzztest/domain_core.h" +#include "./fuzztest/internal/any.h" +#include "./fuzztest/internal/domains/arbitrary_impl.h" +#include "./fuzztest/internal/domains/domain_base.h" +#include "./fuzztest/internal/domains/domain_type_erasure.h" +#include "./fuzztest/internal/domains/element_of_impl.h" +#include "./fuzztest/internal/logging.h" +#include "./fuzztest/internal/meta.h" +#include "./fuzztest/internal/serialization.h" +#include "./fuzztest/internal/status.h" + +namespace fuzztest::internal { + +template && + !std::is_same_v>> +struct FlatbuffersEnumTag { + using type = Underlying; +}; + +template +struct is_flatbuffers_enum_tag : std::false_type {}; + +template +struct is_flatbuffers_enum_tag> + : std::true_type {}; + +template +inline constexpr bool is_flatbuffers_enum_tag_v = + is_flatbuffers_enum_tag::value; + +template +struct FlatbuffersVectorTag { + using value_type = T; +}; + +template +struct is_flatbuffers_vector_tag : std::false_type {}; + +template +struct is_flatbuffers_vector_tag> : std::true_type {}; + +template +inline constexpr bool is_flatbuffers_vector_tag_v = + is_flatbuffers_vector_tag::value; + +struct FlatbuffersArrayTag; +struct FlatbuffersTableTag; +struct FlatbuffersStructTag; +struct FlatbuffersUnionTag; + +// Dynamic to static dispatch visitor pattern for flatbuffers vector elements. +template +auto VisitFlatbufferVectorElementField(const reflection::Schema* schema, + const reflection::Field* field, + Visitor visitor) { + auto field_index = field->type()->index(); + auto element_type = field->type()->element(); + switch (element_type) { + case reflection::BaseType::Bool: + visitor.template Visit>(field); + break; + case reflection::BaseType::Byte: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::Short: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::Int: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::Long: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::UByte: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::UShort: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::UInt: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::ULong: + if (field_index >= 0) { + visitor + .template Visit>>( + field); + } else { + visitor.template Visit>(field); + } + break; + case reflection::BaseType::Float: + visitor.template Visit>(field); + break; + case reflection::BaseType::Double: + visitor.template Visit>(field); + break; + case reflection::BaseType::String: + visitor.template Visit>(field); + break; + case reflection::BaseType::Obj: { + auto sub_object = schema->objects()->Get(field_index); + if (sub_object->is_struct()) { + visitor.template Visit>( + field); + } else { + visitor.template Visit>( + field); + } + break; + } + case reflection::BaseType::Union: + visitor.template Visit>(field); + break; + case reflection::BaseType::UType: + // Noop: Union types are visited at the same time as their corresponding + // union values. + break; + default: // Vector of vectors and vector of arrays are not supported. + FUZZTEST_INTERNAL_CHECK(false, "Unsupported vector base type"); + } +} + +// Dynamic to static dispatch visitor pattern. +template +auto VisitFlatbufferField(const reflection::Schema* absl_nonnull schema, + const reflection::Field* absl_nonnull field, + Visitor visitor) { + auto field_index = field->type()->index(); + switch (field->type()->base_type()) { + case reflection::BaseType::Bool: + visitor.template Visit(field); + break; + case reflection::BaseType::Byte: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::Short: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::Int: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::Long: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::UByte: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::UShort: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::UInt: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::ULong: + if (field_index >= 0) { + visitor.template Visit>(field); + } else { + visitor.template Visit(field); + } + break; + case reflection::BaseType::Float: + visitor.template Visit(field); + break; + case reflection::BaseType::Double: + visitor.template Visit(field); + break; + case reflection::BaseType::String: + visitor.template Visit(field); + break; + case reflection::BaseType::Vector: + case reflection::BaseType::Vector64: { + VisitFlatbufferVectorElementField(schema, field, visitor); + break; + case reflection::BaseType::Array: + visitor.template Visit(field); + break; + case reflection::BaseType::Obj: { + auto sub_object = schema->objects()->Get(field->type()->index()); + if (sub_object->is_struct()) { + visitor.template Visit(field); + } else { + visitor.template Visit(field); + } + break; + } + case reflection::BaseType::Union: + visitor.template Visit(field); + break; + case reflection::BaseType::UType: + // Noop: Union types are visited at the same time as their corresponding + // union values. + break; + default: + FUZZTEST_INTERNAL_CHECK(false, "Unsupported base type"); + } + } +} + +// Flatbuffers enum domain implementation. +template +class FlatbuffersEnumDomainImpl : public ElementOfImpl { + public: + using typename ElementOfImpl::DomainBase::corpus_type; + using typename ElementOfImpl::DomainBase::value_type; + + explicit FlatbuffersEnumDomainImpl(const reflection::Enum* enum_def) + : ElementOfImpl(GetEnumValues(enum_def)), + enum_def_(enum_def) {} + + auto GetPrinter() const { return Printer{*this}; } + + private: + const reflection::Enum* enum_def_; + + static std::vector GetEnumValues( + const reflection::Enum* enum_def) { + std::vector values; + values.reserve(enum_def->values()->size()); + for (const auto* value : *enum_def->values()) { + values.push_back(value->value()); + } + return values; + } + + struct Printer { + const FlatbuffersEnumDomainImpl& self; + void PrintCorpusValue(const corpus_type& value, + domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const { + if (mode == domain_implementor::PrintMode::kHumanReadable) { + auto user_value = self.GetValue(value); + absl::Format( + out, "%s", + self.enum_def_->values()->LookupByKey(user_value)->name()->str()); + } else { + absl::Format(out, "%d", value); + } + } + }; +}; + +using FlatbuffersUnionTypeDomainImpl = FlatbuffersEnumDomainImpl< + decltype(static_cast(nullptr)->value())>; + +class FlatbuffersTableUntypedDomainImpl; + +// Flatbuffers union domain implementation. +class FlatbuffersUnionDomainImpl + : public domain_implementor::DomainBase< + FlatbuffersUnionDomainImpl, + std::pair, + std::pair> { + public: + using typename FlatbuffersUnionDomainImpl::DomainBase::corpus_type; + using typename FlatbuffersUnionDomainImpl::DomainBase::value_type; + + FlatbuffersUnionDomainImpl(const reflection::Schema* schema, + const reflection::Enum* union_def) + : schema_(schema), union_def_(union_def), type_domain_(union_def) {} + + FlatbuffersUnionDomainImpl(const FlatbuffersUnionDomainImpl& other) + : schema_(other.schema_), + union_def_(other.union_def_), + type_domain_(other.type_domain_) { + absl::MutexLock l(&other.mutex_); + domains_ = other.domains_; + } + + FlatbuffersUnionDomainImpl(FlatbuffersUnionDomainImpl&& other) + : schema_(other.schema_), + union_def_(other.union_def_), + type_domain_(std::move(other.type_domain_)) { + absl::MutexLock l(&other.mutex_); + domains_ = std::move(other.domains_); + } + + FlatbuffersUnionDomainImpl& operator=( + const FlatbuffersUnionDomainImpl& other) { + schema_ = other.schema_; + union_def_ = other.union_def_; + type_domain_ = other.type_domain_; + absl::MutexLock l(&other.mutex_); + domains_ = other.domains_; + return *this; + } + + FlatbuffersUnionDomainImpl& operator=(FlatbuffersUnionDomainImpl&& other) { + schema_ = other.schema_; + union_def_ = other.union_def_; + type_domain_ = std::move(other.type_domain_); + absl::MutexLock l(&other.mutex_); + domains_ = std::move(other.domains_); + return *this; + } + + // Initializes the corpus value. + corpus_type Init(absl::BitGenRef prng); + + // Mutates the corpus value. + void Mutate(corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, + bool only_shrink); + + uint64_t CountNumberOfFields(corpus_type& val); + + auto GetPrinter() const { return Printer{*this}; } + + absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const; + + // UNSUPPORTED: Flatbuffers unions user values are not supported. + value_type GetValue(const corpus_type& value) const { + FUZZTEST_INTERNAL_CHECK(false, "GetValue is not supported for unions."); + } + + auto GetType(const corpus_type& value) const { + return type_domain_.GetValue(value.first); + } + + std::optional BuildValue( + const corpus_type& value, flatbuffers::FlatBufferBuilder& builder) const; + + std::optional FromValue(const value_type& value) const; + + // Converts the IRObject to a corpus value. + std::optional ParseCorpus(const IRObject& obj) const; + + // Converts the corpus value to an IRObject. + IRObject SerializeCorpus(const corpus_type& value) const; + + private: + const reflection::Schema* schema_; + const reflection::Enum* union_def_; + FlatbuffersEnumDomainImpl type_domain_; + mutable absl::Mutex mutex_; + mutable absl::flat_hash_map + domains_ ABSL_GUARDED_BY(mutex_); + + FlatbuffersTableUntypedDomainImpl* GetTableDomain( + const reflection::EnumVal& enum_value) const; + + struct Printer { + const FlatbuffersUnionDomainImpl& self; + + void PrintCorpusValue(const corpus_type& value, + domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const; + }; +}; + +// Domain implementation for flatbuffers untyped tables. +// The corpus type is a pair of: +// - A map of field ids to field values. +// - The serialized buffer of the table. +class FlatbuffersTableUntypedDomainImpl + : public domain_implementor::DomainBase< + FlatbuffersTableUntypedDomainImpl, + const flatbuffers::Table* absl_nonnull, + absl::flat_hash_map< + decltype(static_cast(nullptr)->id()), + GenericDomainCorpusType>> { + public: + using typename FlatbuffersTableUntypedDomainImpl::DomainBase::corpus_type; + using typename FlatbuffersTableUntypedDomainImpl::DomainBase::value_type; + + explicit FlatbuffersTableUntypedDomainImpl( + const reflection::Schema* schema, const reflection::Object* table_object) + : schema_(schema), table_object_(table_object) {} + + FlatbuffersTableUntypedDomainImpl( + const FlatbuffersTableUntypedDomainImpl& other) + : schema_(other.schema_), table_object_(other.table_object_) { + absl::MutexLock l(&other.mutex_); + domains_ = other.domains_; + } + + FlatbuffersTableUntypedDomainImpl& operator=( + const FlatbuffersTableUntypedDomainImpl& other) { + schema_ = other.schema_; + table_object_ = other.table_object_; + absl::MutexLock l(&other.mutex_); + domains_ = other.domains_; + return *this; + } + + FlatbuffersTableUntypedDomainImpl(FlatbuffersTableUntypedDomainImpl&& other) + : schema_(other.schema_), table_object_(other.table_object_) { + absl::MutexLock l(&other.mutex_); + domains_ = std::move(other.domains_); + } + + FlatbuffersTableUntypedDomainImpl& operator=( + FlatbuffersTableUntypedDomainImpl&& other) { + schema_ = other.schema_; + table_object_ = other.table_object_; + absl::MutexLock l(&other.mutex_); + domains_ = std::move(other.domains_); + return *this; + } + + // Initializes the corpus value. + corpus_type Init(absl::BitGenRef prng) { + if (auto seed = this->MaybeGetRandomSeed(prng)) { + return *seed; + } + corpus_type val; + for (const auto* field : *table_object_->fields()) { + VisitFlatbufferField(schema_, field, InitializeVisitor{*this, prng, val}); + } + return val; + } + + // Mutates the corpus value. + void Mutate(corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, + bool only_shrink) { + auto total_weight = CountNumberOfFields(val); + auto selected_weight = + absl::Uniform(absl::IntervalClosedClosed, prng, 0ul, total_weight - 1); + + MutateSelectedField(val, prng, metadata, only_shrink, selected_weight); + } + + // Returns the domain for the given vector field. + template + auto GetDomainForVectorField(const reflection::Field* field) const { + if constexpr (is_flatbuffers_enum_tag_v) { + auto enum_object = schema_->enums()->Get(field->type()->index()); + auto inner = OptionalOf( + VectorOf( + FlatbuffersEnumDomainImpl(enum_object)) + .WithMaxSize(std::numeric_limits::max())); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>{inner}; + } else if constexpr (std::is_same_v) { + auto table_object = schema_->objects()->Get(field->type()->index()); + auto inner = OptionalOf( + VectorOf(FlatbuffersTableUntypedDomainImpl{schema_, table_object}) + .WithMaxSize(std::numeric_limits::max())); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>>{ + inner}; + } else if constexpr (std::is_same_v) { + // TODO(b/399123660): implement this. + return Domain>(OptionalOf(ArbitraryImpl())); + } else if constexpr (std::is_same_v) { + auto union_type = schema_->enums()->Get(field->type()->index()); + auto inner = OptionalOf( + VectorOf(FlatbuffersUnionDomainImpl{schema_, union_type}) + .WithMaxSize(std::numeric_limits::max())); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>{inner}; + } else { + auto inner = OptionalOf( + VectorOf(ArbitraryImpl()) + .WithMaxSize(std::numeric_limits::max())); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>>{inner}; + } + } + + // Returns the domain for the given field. + template + auto GetDomainForField(const reflection::Field* field) const { + if constexpr (std::is_same_v) { + // TODO(b/399123660): Implement this. + return Domain>(OptionalOf(ArbitraryImpl())); + } else if constexpr (is_flatbuffers_enum_tag_v) { + auto enum_object = schema_->enums()->Get(field->type()->index()); + auto domain = + OptionalOf(FlatbuffersEnumDomainImpl(enum_object)); + if (!field->optional()) { + domain.SetWithoutNull(); + } + return Domain>{domain}; + } else if constexpr (std::is_same_v) { + auto table_object = schema_->objects()->Get(field->type()->index()); + auto inner = + OptionalOf(FlatbuffersTableUntypedDomainImpl{schema_, table_object}); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>{inner}; + } else if constexpr (std::is_same_v) { + // TODO(b/399123660): Implement this. + return Domain>(OptionalOf(ArbitraryImpl())); + } else if constexpr (std::is_same_v) { + auto union_type = schema_->enums()->Get(field->type()->index()); + auto inner = OptionalOf(FlatbuffersUnionDomainImpl{schema_, union_type}); + return Domain>{inner}; + } else if constexpr (is_flatbuffers_vector_tag_v) { + return GetDomainForVectorField(field); + } else { + auto inner = OptionalOf(ArbitraryImpl()); + if (!field->optional()) { + inner.SetWithoutNull(); + } + return Domain>{inner}; + } + } + + // Returns the domain for the given field. + // The domain is cached, and the same instance is returned for the same + // field. + template + auto& GetSubDomain(const reflection::Field* field) const { + using DomainT = decltype(GetDomainForField(field)); + // Do the operation under a lock to prevent race conditions in `const` + // methods. + absl::MutexLock l(&mutex_); + auto it = domains_.find(field->id()); + if (it == domains_.end()) { + it = domains_ + .try_emplace(field->id(), std::in_place_type, + GetDomainForField(field)) + .first; + } + return it->second.template GetAs(); + } + + // Counts the number of fields that can be mutated. + // Returns the number of fields in the flattened tree for supported field + // types. + uint64_t CountNumberOfFields(corpus_type& val) { + uint64_t total_weight = 0; + for (const auto* field : *table_object_->fields()) { + reflection::BaseType base_type = field->type()->base_type(); + if (IsScalarType(base_type) || + base_type == reflection::BaseType::String) { + ++total_weight; + } else if (base_type == reflection::BaseType::Obj) { + auto sub_object = schema_->objects()->Get(field->type()->index()); + // TODO(b/405939014): Support structs. + if (!sub_object->is_struct()) { + ++total_weight; + auto& sub_domain = GetSubDomain(field); + total_weight += sub_domain.CountNumberOfFields(val[field->id()]); + } + } else if (base_type == reflection::BaseType::Vector || + base_type == reflection::BaseType::Vector64) { + ++total_weight; + auto elem_type = field->type()->element(); + if (IsScalarType(elem_type) || + elem_type == reflection::BaseType::String) { + ++total_weight; + } else if (elem_type == reflection::BaseType::Obj) { + auto sub_object = schema_->objects()->Get(field->type()->index()); + if (!sub_object->is_struct()) { + ++total_weight; + auto sub_domain = + GetSubDomain>(field); + total_weight += sub_domain.CountNumberOfFields(val[field->id()]); + } + } else if (elem_type == reflection::BaseType::Union) { + ++total_weight; + auto& sub_domain = + GetSubDomain>(field); + total_weight += sub_domain.CountNumberOfFields(val[field->id()]); + } + } else if (base_type == reflection::BaseType::Union) { + ++total_weight; + auto& sub_domain = GetSubDomain(field); + total_weight += sub_domain.CountNumberOfFields(val[field->id()]); + } + } + return total_weight; + } + + // Mutates the selected field. + // The selected field index is based on the flattened tree. + uint64_t MutateSelectedField( + corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, bool only_shrink, + uint64_t selected_field_index) { + uint64_t field_counter = 0; + for (const auto* field : *table_object_->fields()) { + ++field_counter; + + if (field_counter == selected_field_index + 1) { + VisitFlatbufferField( + schema_, field, + MutateVisitor{*this, prng, metadata, only_shrink, val}); + return field_counter; + } + + auto base_type = field->type()->base_type(); + if (base_type == reflection::BaseType::Obj) { + auto sub_object = schema_->objects()->Get(field->type()->index()); + if (!sub_object->is_struct()) { + field_counter += + GetSubDomain(field).MutateSelectedField( + val[field->id()], prng, metadata, only_shrink, + selected_field_index - field_counter); + } + } + + if (base_type == reflection::BaseType::Vector || + base_type == reflection::BaseType::Vector64) { + auto elem_type = field->type()->element(); + if (elem_type == reflection::BaseType::Obj) { + auto sub_object = schema_->objects()->Get(field->type()->index()); + if (!sub_object->is_struct()) { + field_counter += + GetSubDomain>(field) + .MutateSelectedField(val[field->id()], prng, metadata, + only_shrink, + selected_field_index - field_counter); + } + } else if (elem_type == reflection::BaseType::Union) { + field_counter += + GetSubDomain>(field) + .MutateSelectedField(val[field->id()], prng, metadata, + only_shrink, + selected_field_index - field_counter); + } + } + + if (base_type == reflection::BaseType::Union) { + field_counter += + GetSubDomain(field).MutateSelectedField( + val[field->id()], prng, metadata, only_shrink, + selected_field_index - field_counter); + } + + if (field_counter > selected_field_index) { + return field_counter; + } + } + return field_counter; + } + + auto GetPrinter() const { return Printer{*this}; } + + absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const { + for (const auto& [id, field_corpus] : corpus_value) { + absl::Nullable field = GetFieldById(id); + if (field == nullptr) continue; + absl::Status result; + VisitFlatbufferField(schema_, field, + ValidateVisitor{*this, field_corpus, result}); + if (!result.ok()) return result; + } + return absl::OkStatus(); + } + + value_type GetValue(const corpus_type& value) const { + // Untyped domain does not support GetValue since if it is a nested table + // it would need the top level table corpus value to be able to build it. + return nullptr; + } + + // Converts the table pointer to a corpus value. + std::optional FromValue(const value_type& value) const { + if (value == nullptr) { + return std::nullopt; + } + corpus_type ret; + for (const auto* field : *table_object_->fields()) { + VisitFlatbufferField(schema_, field, FromValueVisitor{*this, value, ret}); + } + return ret; + } + + // Converts the IRObject to a corpus value. + std::optional ParseCorpus(const IRObject& obj) const { + corpus_type out; + auto subs = obj.Subs(); + if (!subs) { + return std::nullopt; + } + out.reserve(subs->size()); + for (const auto& sub : *subs) { + auto pair_subs = sub.Subs(); + if (!pair_subs || pair_subs->size() != 2) { + return std::nullopt; + } + auto id = (*pair_subs)[0].GetScalar(); + if (!id.has_value()) { + return std::nullopt; + } + absl::Nullable field = GetFieldById(id.value()); + if (field == nullptr) { + return std::nullopt; + } + std::optional inner_parsed; + VisitFlatbufferField(schema_, field, + ParseVisitor{*this, (*pair_subs)[1], inner_parsed}); + if (!inner_parsed) { + return std::nullopt; + } + out[id.value()] = *std::move(inner_parsed); + } + return out; + } + + // Converts the corpus value to an IRObject. + IRObject SerializeCorpus(const corpus_type& value) const { + IRObject out; + auto& subs = out.MutableSubs(); + subs.reserve(value.size()); + for (const auto& [id, field_corpus] : value) { + absl::Nullable field = GetFieldById(id); + if (field == nullptr) { + continue; + } + IRObject& pair = subs.emplace_back(); + auto& pair_subs = pair.MutableSubs(); + pair_subs.reserve(2); + pair_subs.emplace_back(field->id()); + VisitFlatbufferField( + schema_, field, + SerializeVisitor{*this, field_corpus, pair_subs.emplace_back()}); + } + return out; + } + + uint32_t BuildTable(const corpus_type& value, + flatbuffers::FlatBufferBuilder& builder) const { + // Add all the fields to the builder. + absl::flat_hash_map + offsets; + for (const auto& [id, field_corpus] : value) { + absl::Nullable field = GetFieldById(id); + if (field == nullptr) { + continue; + } + VisitFlatbufferField( + schema_, field, + TableFieldBuilderVisitor{*this, builder, offsets, field_corpus}); + } + // Build the table with the out of line fields offsets and inline fields. + uint32_t table_start = builder.StartTable(); + for (const auto& [id, field_corpus] : value) { + absl::Nullable field = GetFieldById(id); + if (field == nullptr) { + continue; + } + VisitFlatbufferField( + schema_, field, + TableBuilderVisitor{*this, builder, offsets, field_corpus}); + } + return builder.EndTable(table_start); + } + + private: + const reflection::Schema* absl_nonnull schema_; + const reflection::Object* absl_nonnull table_object_; + mutable absl::Mutex mutex_; + mutable absl::flat_hash_map + domains_ ABSL_GUARDED_BY(mutex_); + + absl::Nullable GetFieldById( + typename corpus_type::key_type id) const { + const auto it = + absl::c_find_if(*table_object_->fields(), + [id](const auto* field) { return field->id() == id; }); + return it != table_object_->fields()->end() ? *it : nullptr; + } + + bool IsScalarType(reflection::BaseType base_type) const { + switch (base_type) { + case reflection::BaseType::Bool: + case reflection::BaseType::Byte: + case reflection::BaseType::Short: + case reflection::BaseType::Int: + case reflection::BaseType::Long: + case reflection::BaseType::UByte: + case reflection::BaseType::UShort: + case reflection::BaseType::UInt: + case reflection::BaseType::ULong: + case reflection::BaseType::Float: + case reflection::BaseType::Double: + return true; + default: + return false; + } + } + + bool IsTypeSupported(reflection::BaseType base_type) const { + return IsScalarType(base_type) || base_type == reflection::BaseType::String; + } + + struct SerializeVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + const GenericDomainCorpusType& corpus_value; + IRObject& out; + + template + void Visit(const reflection::Field* absl_nonnull field) { + out = self.GetSubDomain(field).SerializeCorpus(corpus_value); + } + }; + + struct FromValueVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + value_type value; + corpus_type& out; + + template + void Visit(const reflection::Field* absl_nonnull field) const { + [[maybe_unused]] + reflection::BaseType base_type = field->type()->base_type(); + auto& domain = self.GetSubDomain(field); + value_type_t> inner_value; + + if constexpr (is_flatbuffers_enum_tag_v) { + FUZZTEST_INTERNAL_CHECK(base_type >= reflection::BaseType::Byte && + base_type <= reflection::BaseType::ULong, + "Field must be an enum type."); + if (field->optional() && !value->CheckField(field->offset())) { + inner_value = std::nullopt; + } else { + inner_value = + std::make_optional(value->template GetField( + field->offset(), field->default_integer())); + } + } else if constexpr (std::is_integral_v) { + FUZZTEST_INTERNAL_CHECK(base_type >= reflection::BaseType::Bool && + base_type <= reflection::BaseType::ULong, + "Field must be an integer type."); + if (field->optional() && !value->CheckField(field->offset())) { + inner_value = std::nullopt; + } else { + inner_value = std::make_optional(value->template GetField( + field->offset(), field->default_integer())); + } + } else if constexpr (std::is_floating_point_v) { + FUZZTEST_INTERNAL_CHECK(base_type >= reflection::BaseType::Float && + base_type <= reflection::BaseType::Double, + "Field must be a floating point type."); + if (field->optional() && !value->CheckField(field->offset())) { + inner_value = std::nullopt; + } else { + inner_value = std::make_optional(value->template GetField( + field->offset(), field->default_real())); + } + } else if constexpr (std::is_same_v) { + FUZZTEST_INTERNAL_CHECK(base_type == reflection::BaseType::String, + "Field must be a string type."); + if (!value->CheckField(field->offset())) { + inner_value = std::nullopt; + } else { + inner_value = std::make_optional( + value->template GetPointer(field->offset()) + ->str()); + } + } else if constexpr (std::is_same_v) { + auto sub_object = self.schema_->objects()->Get(field->type()->index()); + FUZZTEST_INTERNAL_CHECK( + base_type == reflection::BaseType::Obj && !sub_object->is_struct(), + "Field must be a table type."); + inner_value = value->template GetPointer( + field->offset()); + } else if constexpr (is_flatbuffers_vector_tag_v) { + FUZZTEST_INTERNAL_CHECK(base_type == reflection::BaseType::Vector || + base_type == reflection::BaseType::Vector64, + "Field must be a vector type."); + if (!value->CheckField(field->offset())) { + inner_value = std::nullopt; + } else { + VisitVector>( + field, inner_value); + } + } else if constexpr (std::is_same_v) { + constexpr char kUnionTypeFieldSuffix[] = "_type"; + auto enumdef = self.schema_->enums()->Get(field->type()->index()); + auto type_field = self.table_object_->fields()->LookupByKey( + (field->name()->str() + kUnionTypeFieldSuffix).c_str()); + if (type_field == nullptr) { + return; + } + auto union_type = + value->template GetField(type_field->offset(), 0); + auto enumval = enumdef->values()->LookupByKey(union_type); + auto union_object = + self.schema_->objects()->Get(enumval->union_type()->index()); + if (union_object->is_struct()) { + // TODO: (b/405939014) support structs in unions. + } else { + auto union_value = + value->template GetPointer(field->offset()); + inner_value = std::make_pair(union_type, union_value); + } + } + + auto inner = domain.FromValue(inner_value); + if (inner) { + out[field->id()] = *std::move(inner); + } + }; + + template + void VisitVector(const reflection::Field* field, + value_type_t& inner_value) const { + if constexpr (std::is_integral_v || + std::is_floating_point_v) { + auto vec = + value->template GetPointer*>( + field->offset()); + inner_value = std::make_optional(std::vector()); + inner_value->reserve(vec->size()); + for (auto i = 0; i < vec->size(); ++i) { + inner_value->push_back(vec->Get(i)); + } + } else if constexpr (is_flatbuffers_enum_tag_v) { + using Underlaying = typename ElementType::type; + auto vec = + value->template GetPointer*>( + field->offset()); + inner_value = std::make_optional(std::vector()); + inner_value->reserve(vec->size()); + for (auto i = 0; i < vec->size(); ++i) { + inner_value->push_back(vec->Get(i)); + } + } else if constexpr (std::is_same_v) { + auto vec = value->template GetPointer< + flatbuffers::Vector>*>( + field->offset()); + inner_value = std::make_optional(std::vector()); + inner_value->reserve(vec->size()); + for (auto i = 0; i < vec->size(); ++i) { + inner_value->push_back(vec->Get(i)->str()); + } + } else if constexpr (std::is_same_v) { + auto vec = value->template GetPointer< + flatbuffers::Vector>*>( + field->offset()); + inner_value = + std::make_optional(std::vector()); + inner_value->reserve(vec->size()); + for (auto i = 0; i < vec->size(); ++i) { + inner_value->push_back(vec->Get(i)); + } + } else if constexpr (std::is_same_v) { + constexpr char kUnionTypeFieldSuffix[] = "_type"; + auto type_field = self.table_object_->fields()->LookupByKey( + (field->name()->str() + kUnionTypeFieldSuffix).c_str()); + if (type_field == nullptr) { + return; + } + auto type_vec = + value->template GetPointer*>( + type_field->offset()); + auto value_vec = value->template GetPointer< + flatbuffers::Vector>*>(field->offset()); + inner_value = std::make_optional( + typename std::decay_t::value_type{}); + inner_value->reserve(value_vec->size()); + for (auto i = 0; i < value_vec->size(); ++i) { + inner_value->push_back( + std::make_pair(type_vec->Get(i), value_vec->Get(i))); + } + } + } + }; + + struct TableFieldBuilderVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + flatbuffers::FlatBufferBuilder& builder; + absl::flat_hash_map& + offsets; + const typename corpus_type::value_type::second_type& corpus_value; + + template + void Visit(const reflection::Field* absl_nonnull field) const { + if constexpr (std::is_same_v) { + auto& domain = self.GetSubDomain(field); + auto user_value = domain.GetValue(corpus_value); + if (user_value.has_value()) { + auto offset = + builder.CreateString(user_value->data(), user_value->size()).o; + offsets.insert({field->id(), offset}); + } + } else if constexpr (std::is_same_v) { + FlatbuffersTableUntypedDomainImpl inner_domain( + self.schema_, self.schema_->objects()->Get(field->type()->index())); + auto opt_corpus = corpus_value.template GetAs< + std::variant>(); + if (std::holds_alternative( + opt_corpus)) { + auto inner_corpus = + std::get(opt_corpus) + .template GetAs(); + auto offset = inner_domain.BuildTable(inner_corpus, builder); + offsets.insert({field->id(), offset}); + } + } else if constexpr (is_flatbuffers_vector_tag_v) { + VisitVector(field, self.GetSubDomain(field)); + } else if constexpr (std::is_same_v) { + const reflection::Enum* union_type = + self.schema_->enums()->Get(field->type()->index()); + FlatbuffersUnionDomainImpl inner_domain{self.schema_, union_type}; + auto opt_corpus = corpus_value.template GetAs< + std::variant>(); + if (std::holds_alternative( + opt_corpus)) { + auto inner_corpus = + std::get(opt_corpus) + .template GetAs>(); + auto offset = inner_domain.BuildValue(inner_corpus, builder); + if (offset.has_value()) { + offsets.insert({field->id(), *offset}); + } + } + } + } + + private: + template + void VisitVector(const reflection::Field* field, + const Domain& domain) const { + if constexpr (std::is_integral_v || + std::is_floating_point_v) { + auto value = domain.GetValue(corpus_value); + if (!value) { + return; + } + offsets.insert({field->id(), builder.CreateVector(*value).o}); + } else if constexpr (is_flatbuffers_enum_tag_v) { + auto value = domain.GetValue(corpus_value); + if (!value) { + return; + } + offsets.insert({field->id(), builder.CreateVector(*value).o}); + } + if constexpr (std::is_same_v) { + FlatbuffersTableUntypedDomainImpl domain( + self.schema_, self.schema_->objects()->Get(field->type()->index())); + auto opt_corpus = corpus_value.template GetAs< + std::variant>(); + if (std::holds_alternative(opt_corpus)) { + return; + } + auto container_corpus = + std::get(opt_corpus) + .template GetAs>(); + std::vector> vec_offsets; + for (auto& inner_corpus : container_corpus) { + auto offset = domain.BuildTable(inner_corpus, builder); + vec_offsets.push_back(offset); + } + offsets.insert({field->id(), builder.CreateVector(vec_offsets).o}); + } else if constexpr (std::is_same_v) { + auto value = domain.GetValue(corpus_value); + if (!value) { + return; + } + std::vector> vec_offsets; + for (const auto& str : *value) { + auto offset = builder.CreateString(str); + vec_offsets.push_back(offset); + } + offsets.insert({field->id(), builder.CreateVector(vec_offsets).o}); + } else if constexpr (std::is_same_v) { + const reflection::Enum* union_type = + self.schema_->enums()->Get(field->type()->index()); + FlatbuffersUnionDomainImpl domain{self.schema_, union_type}; + constexpr char kUnionTypeFieldSuffix[] = "_type"; + const reflection::Field* type_field = + self.table_object_->fields()->LookupByKey( + (field->name()->str() + kUnionTypeFieldSuffix).c_str()); + + auto opt_corpus = corpus_value.template GetAs< + std::variant>(); + if (std::holds_alternative(opt_corpus)) { + return; + } + auto container_corpus = + std::get(opt_corpus) + .template GetAs>>(); + + std::vector< + typename value_type_t>::first_type> + vec_types; + std::vector> vec_offsets; + for (auto& inner_corpus : container_corpus) { + auto offset = domain.BuildValue(inner_corpus, builder); + if (offset.has_value()) { + vec_offsets.push_back(*offset); + vec_types.push_back(domain.GetType(inner_corpus)); + } + } + offsets.insert({field->id(), builder.CreateVector(vec_offsets).o}); + offsets.insert({type_field->id(), builder.CreateVector(vec_types).o}); + } + } + }; + + struct TableBuilderVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + flatbuffers::FlatBufferBuilder& builder; + absl::flat_hash_map& + offsets; + const typename corpus_type::value_type::second_type& corpus_value; + + template + void Visit(const reflection::Field* absl_nonnull field) const { + auto size = flatbuffers::GetTypeSize(field->type()->base_type()); + if constexpr (std::is_integral_v || std::is_floating_point_v || + is_flatbuffers_enum_tag_v) { + auto& domain = self.GetSubDomain(field); + auto v = domain.GetValue(corpus_value); + if (!v) { + return; + } + builder.Align(size); + builder.PushBytes(reinterpret_cast(&v), size); + builder.TrackField(field->offset(), builder.GetSize()); + } else if constexpr (std::is_same_v || + is_flatbuffers_vector_tag_v) { + if constexpr (is_flatbuffers_vector_tag_v) { + if constexpr (std::is_same_v) { + constexpr char kUnionTypeFieldSuffix[] = "_type"; + const reflection::Field* type_field = + self.table_object_->fields()->LookupByKey( + (field->name()->str() + kUnionTypeFieldSuffix).c_str()); + if (auto it = offsets.find(type_field->id()); it != offsets.end()) { + builder.AddOffset(type_field->offset(), + flatbuffers::Offset<>(it->second)); + } + } + } + if (auto it = offsets.find(field->id()); it != offsets.end()) { + builder.AddOffset( + field->offset(), + flatbuffers::Offset(it->second)); + } + } else if constexpr (std::is_same_v) { + if (auto it = offsets.find(field->id()); it != offsets.end()) { + builder.AddOffset( + field->offset(), + flatbuffers::Offset(it->second)); + } + } else if constexpr (std::is_same_v) { + const reflection::Enum* union_type = + self.schema_->enums()->Get(field->type()->index()); + FlatbuffersUnionDomainImpl domain(self.schema_, union_type); + if (auto it = offsets.find(field->id()); it != offsets.end()) { + builder.AddOffset(field->offset(), + flatbuffers::Offset(it->second)); + + constexpr char kUnionTypeFieldSuffix[] = "_type"; + const reflection::Field* type_field = + self.table_object_->fields()->LookupByKey( + (field->name()->str() + kUnionTypeFieldSuffix).c_str()); + auto opt_corpus = corpus_value.template GetAs>(); + if (std::holds_alternative(opt_corpus)) { + return; + } + auto inner_corpus = + std::get(opt_corpus) + .template GetAs>(); + auto type_value = domain.GetType(inner_corpus); + auto size = flatbuffers::GetTypeSize(type_field->type()->base_type()); + builder.Align(size); + builder.PushBytes(reinterpret_cast(&type_value), + size); + builder.TrackField(type_field->offset(), builder.GetSize()); + } + } + } + }; + + struct ParseVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + const IRObject& obj; + std::optional& out; + + template + void Visit(const reflection::Field* absl_nonnull field) { + out = self.GetSubDomain(field).ParseCorpus(obj); + } + }; + + struct ValidateVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + const GenericDomainCorpusType& corpus_value; + absl::Status& out; + + template + void Visit(const reflection::Field* absl_nonnull field) { + auto& domain = self.GetSubDomain(field); + out = domain.ValidateCorpusValue(corpus_value); + if (!out.ok()) { + out = Prefix(out, absl::StrCat("Invalid value for field ", + field->name()->str())); + } + } + }; + + struct InitializeVisitor { + FlatbuffersTableUntypedDomainImpl& self; + absl::BitGenRef prng; + corpus_type& val; + + template + void Visit(const reflection::Field* absl_nonnull field) { + auto& domain = self.GetSubDomain(field); + val[field->id()] = domain.Init(prng); + } + }; + + struct MutateVisitor { + FlatbuffersTableUntypedDomainImpl& self; + absl::BitGenRef prng; + const domain_implementor::MutationMetadata& metadata; + bool only_shrink; + corpus_type& val; + + template + void Visit(const reflection::Field* absl_nonnull field) { + auto& domain = self.GetSubDomain(field); + if (auto it = val.find(field->id()); it != val.end()) { + domain.Mutate(it->second, prng, metadata, only_shrink); + } else if (!only_shrink) { + val[field->id()] = domain.Init(prng); + } + } + }; + + struct Printer { + const FlatbuffersTableUntypedDomainImpl& self; + + void PrintCorpusValue(const corpus_type& value, + domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const { + absl::Format(out, "{"); + bool first = true; + for (const auto& [id, field_corpus] : value) { + if (!first) { + absl::Format(out, ", "); + } + absl::Nullable field = self.GetFieldById(id); + if (field == nullptr) { + absl::Format(out, "", id); + } else { + VisitFlatbufferField(self.schema_, field, + PrinterVisitor{self, field_corpus, out, mode}); + } + first = false; + } + absl::Format(out, "}"); + } + }; + + struct PrinterVisitor { + const FlatbuffersTableUntypedDomainImpl& self; + const GenericDomainCorpusType& val; + domain_implementor::RawSink out; + domain_implementor::PrintMode mode; + + template + void Visit(const reflection::Field* absl_nonnull field) const { + auto& domain = self.GetSubDomain(field); + absl::Format(out, "%s: ", field->name()->str()); + domain_implementor::PrintValue(domain, val, out, mode); + } + }; +}; + +// Domain implementation for flatbuffers generated table classes. +template +class FlatbuffersTableDomainImpl + : public domain_implementor::DomainBase< + FlatbuffersTableDomainImpl, const T* absl_nonnull, + std::pair, + std::vector>> { + public: + static_assert( + Requires([](auto) -> decltype(T::BinarySchema::data()) {}), + "The flatbuffers generated class must be generated with the " + "`--bfbs-gen-embed` flag."); + static_assert( + Requires([](auto) -> decltype(T::GetFullyQualifiedName()) {}), + "The flatbuffers generated class must be generated with the " + "`--gen-name-strings` flag."); + + using typename FlatbuffersTableDomainImpl::DomainBase::corpus_type; + using typename FlatbuffersTableDomainImpl::DomainBase::value_type; + + FlatbuffersTableDomainImpl() { + flatbuffers::Verifier verifier(T::BinarySchema::data(), + T::BinarySchema::size()); + FUZZTEST_INTERNAL_CHECK(reflection::VerifySchemaBuffer(verifier), + "Invalid schema for flatbuffers table."); + auto schema = reflection::GetSchema(T::BinarySchema::data()); + auto table_object = + schema->objects()->LookupByKey(T::GetFullyQualifiedName()); + inner_ = FlatbuffersTableUntypedDomainImpl{schema, table_object}; + } + + FlatbuffersTableDomainImpl(const FlatbuffersTableDomainImpl& other) + : inner_(other.inner_) { + builder_.Clear(); + } + + FlatbuffersTableDomainImpl& operator=( + const FlatbuffersTableDomainImpl& other) { + if (this == &other) return *this; + inner_ = other.inner_; + builder_.Clear(); + return *this; + } + + FlatbuffersTableDomainImpl(FlatbuffersTableDomainImpl&& other) + : inner_(std::move(other.inner_)) { + builder_.Clear(); + } + + FlatbuffersTableDomainImpl& operator=(FlatbuffersTableDomainImpl&& other) { + if (this == &other) return *this; + inner_ = std::move(other.inner_); + builder_.Clear(); + return *this; + } + + // Initializes the table with random values. + corpus_type Init(absl::BitGenRef prng) { + if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed; + auto val = inner_->Init(prng); + auto offset = inner_->BuildTable(val, builder_); + builder_.Finish(flatbuffers::Offset(offset)); + auto buffer = + std::vector(builder_.GetBufferPointer(), + builder_.GetBufferPointer() + builder_.GetSize()); + builder_.Clear(); + return std::make_pair(val, std::move(buffer)); + } + + // Returns the number of fields in the table. + uint64_t CountNumberOfFields(corpus_type& val) { + return inner_->CountNumberOfFields(val.first); + } + + // Mutates the given corpus value. + void Mutate(corpus_type& val, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, + bool only_shrink) { + inner_->Mutate(val.first, prng, metadata, only_shrink); + val.second = BuildBuffer(val.first); + } + + // Returns the parsed corpus value. + value_type GetValue(const corpus_type& value) const { + return flatbuffers::GetRoot(value.second.data()); + } + + // Returns the parsed corpus value. + std::optional FromValue(const value_type& value) const { + auto val = inner_->FromValue((const flatbuffers::Table*)value); + if (!val.has_value()) return std::nullopt; + return std::make_optional(std::make_pair(*val, BuildBuffer(*val))); + } + + // Returns the printer for the table. + auto GetPrinter() const { return Printer{*inner_}; } + + // Returns the parsed corpus value. + std::optional ParseCorpus(const IRObject& obj) const { + auto val = inner_->ParseCorpus(obj); + if (!val.has_value()) return std::nullopt; + return std::make_optional(std::make_pair(*val, BuildBuffer(*val))); + } + + // Returns the serialized corpus value. + IRObject SerializeCorpus(const corpus_type& corpus_value) const { + return inner_->SerializeCorpus(corpus_value.first); + } + + // Returns the status of the given corpus value. + absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const { + return inner_->ValidateCorpusValue(corpus_value.first); + } + + private: + std::optional inner_; + mutable flatbuffers::FlatBufferBuilder builder_; + + struct Printer { + const FlatbuffersTableUntypedDomainImpl& inner; + + void PrintCorpusValue(const corpus_type& value, + domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const { + inner.GetPrinter().PrintCorpusValue(value.first, out, mode); + } + }; + + std::vector BuildBuffer( + const typename corpus_type::first_type& val) const { + auto offset = inner_->BuildTable(val, builder_); + builder_.Finish(flatbuffers::Offset(offset)); + auto buffer = + std::vector(builder_.GetBufferPointer(), + builder_.GetBufferPointer() + builder_.GetSize()); + builder_.Clear(); + return buffer; + } +}; + +template +class ArbitraryImpl>> + : public FlatbuffersTableDomainImpl {}; +} // namespace fuzztest::internal +#endif // FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLATBUFFERS_DOMAIN_IMPL_H_ diff --git a/fuzztest/internal/meta.h b/fuzztest/internal/meta.h index 4ddada107..83e84f64a 100644 --- a/fuzztest/internal/meta.h +++ b/fuzztest/internal/meta.h @@ -200,6 +200,24 @@ template inline constexpr bool is_protocol_buffer_enum_v = IsProtocolBufferEnumImpl(true); +template +inline constexpr bool is_flatbuffers_table_v = false; + +// Flatbuffers tables generated structs do not have a public base class, so we +// check for a few specific methods: +// - T is a struct. +// - T has a `Builder` type. +// - T has a `BinarySchema` type with a `data()` method (only available when +// passing `--bfbs-gen-embed` to the flatbuffer compiler) +// - T has a static method called `GetFullyQualifiedName` (only available when +// passing `--gen-name-strings` to the flatbuffer compiler). +template +inline constexpr bool + is_flatbuffers_table_v>> = + Requires([](auto) -> typename T::Builder {}) && + Requires([](auto) -> decltype(T::BinarySchema::data()) {}) && + Requires([](auto) -> decltype(T::GetFullyQualifiedName()) {}); + template inline constexpr bool has_size_v = Requires([](auto v) -> decltype(v.size()) {}); diff --git a/fuzztest/internal/test_flatbuffers.fbs b/fuzztest/internal/test_flatbuffers.fbs new file mode 100644 index 000000000..0efb9eb6c --- /dev/null +++ b/fuzztest/internal/test_flatbuffers.fbs @@ -0,0 +1,54 @@ +namespace fuzztest.internal; + +enum TestFbsEnum: byte { + First, + Second, + Third +} + +table SimpleTestFbsTable { + b: bool; + f: float; + str: string; + e: TestFbsEnum; +} + +table NestedTestFbsTable { + t: SimpleTestFbsTable; +} + +table OptionalRequiredTestFbsTable { + def_scalar: bool = true; + opt_scalar: bool = null; + req_str: string (required); + opt_str: string; +} + +table VectorsTestFbsTable { + b: [bool]; + i8: [byte]; + i16: [short]; + i32: [int]; + i64: [long]; + u8: [ubyte]; + u16: [ushort]; + u32: [uint]; + u64: [ulong]; + f: [float]; + d: [double]; + str: [string]; + e: [TestFbsEnum]; + t: [SimpleTestFbsTable]; +} + +union Union { + OptionalRequiredTestFbsTable, + SimpleTestFbsTable, +} + +table UnionTestFbsTable { + u: Union; + u_vec: [Union]; +} + +root_type SimpleTestFbsTable;