From 9fd75a001f809f07aa99106063cd591115ee6f08 Mon Sep 17 00:00:00 2001 From: Dimitar Kanaliev Date: Mon, 30 Sep 2024 14:12:36 +0300 Subject: [PATCH 1/5] WIP: Add a signed cast for tnum --- include/linux/tnum.h | 3 + kernel/bpf/tnum.c | 33 +++++ kernel/bpf/verifier.c | 64 +++------ tools/testing/selftests/bpf/test_tnum.c | 181 ++++++++++++++++++++++++ tools/testing/selftests/bpf/tnum.h | 26 ++++ 5 files changed, 259 insertions(+), 48 deletions(-) create mode 100644 tools/testing/selftests/bpf/test_tnum.c create mode 100644 tools/testing/selftests/bpf/tnum.h diff --git a/include/linux/tnum.h b/include/linux/tnum.h index 3c13240077b87a..6933db04c9ee77 100644 --- a/include/linux/tnum.h +++ b/include/linux/tnum.h @@ -55,6 +55,9 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b); /* Return @a with all but the lowest @size bytes cleared */ struct tnum tnum_cast(struct tnum a, u8 size); +/* Return @a sign-extended from @size bytes */ +struct tnum tnum_scast(struct tnum a, u8 size); + /* Returns true if @a is a known constant */ static inline bool tnum_is_const(struct tnum a) { diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c index 9dbc31b25e3d08..3de80851c53e46 100644 --- a/kernel/bpf/tnum.c +++ b/kernel/bpf/tnum.c @@ -157,6 +157,38 @@ struct tnum tnum_cast(struct tnum a, u8 size) return a; } +struct tnum tnum_scast(struct tnum a, u8 size) +{ + int s = size * 8 - 1; + u64 sign_mask, higher_bits, new_value, new_mask; + + sign_mask = 1ULL << s; + + if (s >= 63) { + higher_bits = 0; + } else { + higher_bits = ~((1ULL << (s + 1)) - 1); + } + + new_value = a.value; + new_mask = a.mask; + + if (a.mask & sign_mask) { + new_value &= (1ULL << (s + 1)) - 1; + new_mask |= higher_bits; + } else { + if (a.value & sign_mask) { + new_value |= higher_bits; + new_mask &= ~higher_bits; + } else { + new_value &= (1ULL << (s + 1)) - 1; + new_mask &= ~higher_bits; + } + } + + return TNUM(new_value, new_mask); +} + bool tnum_is_aligned(struct tnum a, u64 size) { if (!size) @@ -211,3 +243,4 @@ struct tnum tnum_const_subreg(struct tnum a, u32 value) { return tnum_with_subreg(a, tnum_const(value)); } + diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 9a7ed527e47e34..0b8b7d5eca7625 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -6286,63 +6286,31 @@ static void set_sext64_default_val(struct bpf_reg_state *reg, int size) reg->var_off = tnum_unknown; } + static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size) { - s64 init_s64_max, init_s64_min, s64_max, s64_min, u64_cval; - u64 top_smax_value, top_smin_value; - u64 num_bits = size * 8; + s64 smin, smax; + u64 umax, umin; - if (tnum_is_const(reg->var_off)) { - u64_cval = reg->var_off.value; - if (size == 1) - reg->var_off = tnum_const((s8)u64_cval); - else if (size == 2) - reg->var_off = tnum_const((s16)u64_cval); - else - /* size == 4 */ - reg->var_off = tnum_const((s32)u64_cval); - - u64_cval = reg->var_off.value; - reg->smax_value = reg->smin_value = u64_cval; - reg->umax_value = reg->umin_value = u64_cval; - reg->s32_max_value = reg->s32_min_value = u64_cval; - reg->u32_max_value = reg->u32_min_value = u64_cval; - return; - } + reg->var_off = tnum_scast(reg->var_off, size); - top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits; - top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits; + reg->smin_value = reg->var_off.value; + reg->smax_value = reg->var_off.value | reg->var_off.mask; - if (top_smax_value != top_smin_value) - goto out; + reg->umin_value = (u64)reg->smin_value; + reg->umax_value = (u64)reg->smax_value; - /* find the s64_min and s64_min after sign extension */ - if (size == 1) { - init_s64_max = (s8)reg->smax_value; - init_s64_min = (s8)reg->smin_value; - } else if (size == 2) { - init_s64_max = (s16)reg->smax_value; - init_s64_min = (s16)reg->smin_value; - } else { - init_s64_max = (s32)reg->smax_value; - init_s64_min = (s32)reg->smin_value; + if (size <= 4) { + reg->s32_min_value = (s32)reg->smin_value; + reg->s32_max_value = (s32)reg->smax_value; + reg->u32_min_value = (u32)reg->umin_value; + reg->u32_max_value = (u32)reg->umax_value; } - s64_max = max(init_s64_max, init_s64_min); - s64_min = min(init_s64_max, init_s64_min); - - /* both of s64_max/s64_min positive or negative */ - if ((s64_max >= 0) == (s64_min >= 0)) { - reg->smin_value = reg->s32_min_value = s64_min; - reg->smax_value = reg->s32_max_value = s64_max; - reg->umin_value = reg->u32_min_value = s64_min; - reg->umax_value = reg->u32_max_value = s64_max; - reg->var_off = tnum_range(s64_min, s64_max); - return; - } + if (size < 4) + __mark_reg32_unbounded(reg); -out: - set_sext64_default_val(reg, size); + reg_bounds_sync(reg); } static void set_sext32_default_val(struct bpf_reg_state *reg, int size) diff --git a/tools/testing/selftests/bpf/test_tnum.c b/tools/testing/selftests/bpf/test_tnum.c new file mode 100644 index 00000000000000..442a81228f5a91 --- /dev/null +++ b/tools/testing/selftests/bpf/test_tnum.c @@ -0,0 +1,181 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* test_tnum.c: Selftests for tnum_scast function + * + * This program tests the tnum_scast function + */ + +#include +#include +#include +#include +#include +#include + +#include "tnum.h" + +#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0])) + +struct tnum tnum_scast(struct tnum a, u8 size) +{ + int s = size * 8 - 1; + u64 sign_mask, higher_bits, new_value, new_mask; + + sign_mask = 1ULL << s; + + if (s >= 63) { + higher_bits = 0; + } else { + higher_bits = ~((1ULL << (s + 1)) - 1); + } + + new_value = a.value; + new_mask = a.mask; + + if (a.mask & sign_mask) { + new_value &= (1ULL << (s + 1)) - 1; + new_mask |= higher_bits; + } else { + if (a.value & sign_mask) { + new_value |= higher_bits; + new_mask &= ~higher_bits; + } else { + new_value &= (1ULL << (s + 1)) - 1; + new_mask &= ~higher_bits; + } + } + + return TNUM(new_value, new_mask); +} + + +struct tnum_test_case { + const char *description; + struct tnum input; + u8 size; + struct tnum expected; +}; + +static int test_tnum_scast(void) +{ + int i, err = 0; + struct tnum result; + + /* Define test cases */ + struct tnum_test_case tests[] = { + { + .description = "Known positive value (8-bit)", + .input = TNUM(0x7F, 0x00), // 127 in decimal + .size = 1, + .expected = TNUM(0x000000000000007F, 0x0000000000000000), + }, + { + .description = "Known negative value (8-bit)", + .input = TNUM(0xFF, 0x00), // -1 in 8-bit signed + .size = 1, + .expected = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000), + }, + { + .description = "Unknown sign bit (8-bit)", + .input = TNUM(0x7F, 0x80), // Value 127, sign bit unknown + .size = 1, + .expected = TNUM(0x000000000000007F, 0xFFFFFFFFFFFFFF80), + }, + { + .description = "Completely unknown value (8-bit)", + .input = TNUM(0x00, 0xFF), // All bits unknown + .size = 1, + .expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF), + }, + { + .description = "Known positive value (16-bit)", + .input = TNUM(0x7FFF, 0x0000), + .size = 2, + .expected = TNUM(0x0000000000007FFF, 0x0000000000000000), + }, + { + .description = "Known negative value (16-bit)", + .input = TNUM(0xFFFF, 0x0000), // -1 in 16-bit signed + .size = 2, + .expected = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000), + }, + { + .description = "Unknown sign bit (16-bit)", + .input = TNUM(0x7FFF, 0x8000), + .size = 2, + .expected = TNUM(0x0000000000007FFF, 0xFFFFFFFFFFFF8000), + }, + { + .description = "Known negative value (32-bit)", + .input = TNUM(0xFFFFFFFF, 0x00000000), // -1 in 32-bit signed + .size = 4, + .expected = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000), + }, + { + .description = "Unknown sign bit (32-bit)", + .input = TNUM(0x7FFFFFFF, 0x80000000), + .size = 4, + .expected = TNUM(0x000000007FFFFFFF, 0xFFFFFFFF80000000), + }, + { + .description = "Completely unknown value (32-bit)", + .input = TNUM(0x00000000, 0xFFFFFFFF), + .size = 4, + .expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF), + }, + { + .description = "Known positive value (64-bit)", + .input = TNUM(0x7FFFFFFFFFFFFFFF, 0x0000000000000000), + .size = 8, + .expected = TNUM(0x7FFFFFFFFFFFFFFF, 0x0000000000000000), + }, + { + .description = "Known negative value (64-bit)", + .input = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000), + .size = 8, + .expected = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000), + }, + { + .description = "Unknown sign bit (64-bit)", + .input = TNUM(0x7FFFFFFFFFFFFFFF, 0x8000000000000000ULL), + .size = 8, + .expected = TNUM(0x7FFFFFFFFFFFFFFF, 0x8000000000000000ULL), + }, + }; + + printf("Running tnum_scast tests...\n"); + + for (i = 0; i < ARRAY_SIZE(tests); i++) { + struct tnum_test_case *t = &tests[i]; + + result = tnum_scast(t->input, t->size); + + if (memcmp(&result, &t->expected, sizeof(struct tnum)) != 0) { + printf("Test %d failed: %s\n", i + 1, t->description); + printf(" Input: value=0x%016llx, mask=0x%016llx\n", + t->input.value, t->input.mask); + printf(" Expected: value=0x%016llx, mask=0x%016llx\n", + t->expected.value, t->expected.mask); + printf(" Got: value=0x%016llx, mask=0x%016llx\n", + result.value, result.mask); + err = 1; + } else { + printf("Test %d passed: %s\n", i + 1, t->description); + } + } + + if (err) + printf("tnum_scast tests failed.\n"); + else + printf("All tnum_scast tests passed successfully.\n"); + + return err; +} + +int main(int argc, char **argv) +{ + int err = 0; + + err |= test_tnum_scast(); + + return err; +} diff --git a/tools/testing/selftests/bpf/tnum.h b/tools/testing/selftests/bpf/tnum.h new file mode 100644 index 00000000000000..f46be04ef570e6 --- /dev/null +++ b/tools/testing/selftests/bpf/tnum.h @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* tnum.h: Header file for tnum utility functions */ + +#ifndef __TNUM_H__ +#define __TNUM_H__ + +#include + +typedef uint64_t u64; +typedef int64_t s64; +typedef uint32_t u32; +typedef int32_t s32; +typedef uint8_t u8; + +struct tnum { + u64 value; + u64 mask; +}; + +#define TNUM(_v, _m) (struct tnum){.value = (_v), .mask = (_m)} + +/* Function prototypes */ +struct tnum tnum_scast(struct tnum a, u8 size); + +#endif /* __TNUM_H__ */ + From 18a7ca845c01ee42b72448df3afccfb9619a4b89 Mon Sep 17 00:00:00 2001 From: Dimitar Kanaliev Date: Mon, 30 Sep 2024 14:44:10 +0300 Subject: [PATCH 2/5] Address UB on 64 bit shifts --- kernel/bpf/tnum.c | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c index 3de80851c53e46..a5458ac6051804 100644 --- a/kernel/bpf/tnum.c +++ b/kernel/bpf/tnum.c @@ -159,36 +159,31 @@ struct tnum tnum_cast(struct tnum a, u8 size) struct tnum tnum_scast(struct tnum a, u8 size) { - int s = size * 8 - 1; - u64 sign_mask, higher_bits, new_value, new_mask; + u64 s = size * 8 - 1; + u64 sign_mask; + u64 value_mask; + u64 new_value, new_mask; - sign_mask = 1ULL << s; - - if (s >= 63) { - higher_bits = 0; - } else { - higher_bits = ~((1ULL << (s + 1)) - 1); + if (size >= 8) { + return a; } - new_value = a.value; - new_mask = a.mask; + sign_mask = 1ULL << s; + value_mask = (1ULL << (s + 1)) - 1; + + new_value = a.value & value_mask; + new_mask = a.mask & value_mask; if (a.mask & sign_mask) { - new_value &= (1ULL << (s + 1)) - 1; - new_mask |= higher_bits; - } else { - if (a.value & sign_mask) { - new_value |= higher_bits; - new_mask &= ~higher_bits; - } else { - new_value &= (1ULL << (s + 1)) - 1; - new_mask &= ~higher_bits; - } + new_mask |= ~value_mask; + } else if (a.value & sign_mask) { + new_value |= ~value_mask; } return TNUM(new_value, new_mask); } + bool tnum_is_aligned(struct tnum a, u64 size) { if (!size) From d7071f3e3977461e6126550de0563bd46b9d3ccf Mon Sep 17 00:00:00 2001 From: Dimitar Kanaliev Date: Mon, 30 Sep 2024 15:25:51 +0300 Subject: [PATCH 3/5] Update tests --- tools/testing/selftests/bpf/test_tnum.c | 80 +++++++++++++++---------- 1 file changed, 49 insertions(+), 31 deletions(-) diff --git a/tools/testing/selftests/bpf/test_tnum.c b/tools/testing/selftests/bpf/test_tnum.c index 442a81228f5a91..c072cc87eda9f1 100644 --- a/tools/testing/selftests/bpf/test_tnum.c +++ b/tools/testing/selftests/bpf/test_tnum.c @@ -17,37 +17,30 @@ struct tnum tnum_scast(struct tnum a, u8 size) { - int s = size * 8 - 1; - u64 sign_mask, higher_bits, new_value, new_mask; + u64 s = size * 8 - 1; + u64 sign_mask; + u64 value_mask; + u64 new_value, new_mask; - sign_mask = 1ULL << s; - - if (s >= 63) { - higher_bits = 0; - } else { - higher_bits = ~((1ULL << (s + 1)) - 1); + if (size >= 8) { + return a; } - new_value = a.value; - new_mask = a.mask; + sign_mask = 1ULL << s; + value_mask = (1ULL << (s + 1)) - 1; + + new_value = a.value & value_mask; + new_mask = a.mask & value_mask; if (a.mask & sign_mask) { - new_value &= (1ULL << (s + 1)) - 1; - new_mask |= higher_bits; - } else { - if (a.value & sign_mask) { - new_value |= higher_bits; - new_mask &= ~higher_bits; - } else { - new_value &= (1ULL << (s + 1)) - 1; - new_mask &= ~higher_bits; - } + new_mask |= ~value_mask; + } else if (a.value & sign_mask) { + new_value |= ~value_mask; } return TNUM(new_value, new_mask); } - struct tnum_test_case { const char *description; struct tnum input; @@ -62,6 +55,7 @@ static int test_tnum_scast(void) /* Define test cases */ struct tnum_test_case tests[] = { + /* 8-bit tests */ { .description = "Known positive value (8-bit)", .input = TNUM(0x7F, 0x00), // 127 in decimal @@ -86,6 +80,7 @@ static int test_tnum_scast(void) .size = 1, .expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF), }, + /* 16-bit tests */ { .description = "Known positive value (16-bit)", .input = TNUM(0x7FFF, 0x0000), @@ -104,6 +99,19 @@ static int test_tnum_scast(void) .size = 2, .expected = TNUM(0x0000000000007FFF, 0xFFFFFFFFFFFF8000), }, + { + .description = "Completely unknown value (16-bit)", + .input = TNUM(0x0000, 0xFFFF), + .size = 2, + .expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF), + }, + /* 32-bit tests */ + { + .description = "Known positive value (32-bit)", + .input = TNUM(0x7FFFFFFF, 0x00000000), + .size = 4, + .expected = TNUM(0x000000007FFFFFFF, 0x0000000000000000), + }, { .description = "Known negative value (32-bit)", .input = TNUM(0xFFFFFFFF, 0x00000000), // -1 in 32-bit signed @@ -122,6 +130,7 @@ static int test_tnum_scast(void) .size = 4, .expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF), }, + /* 64-bit tests */ { .description = "Known positive value (64-bit)", .input = TNUM(0x7FFFFFFFFFFFFFFF, 0x0000000000000000), @@ -140,31 +149,39 @@ static int test_tnum_scast(void) .size = 8, .expected = TNUM(0x7FFFFFFFFFFFFFFF, 0x8000000000000000ULL), }, + { + .description = "Completely unknown value (64-bit)", + .input = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF), + .size = 8, + .expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF), + }, }; - printf("Running tnum_scast tests...\n"); + printf("Running tnum_scast tests...\n\n"); for (i = 0; i < ARRAY_SIZE(tests); i++) { struct tnum_test_case *t = &tests[i]; result = tnum_scast(t->input, t->size); + printf("Test %d (%s, size=%d bytes):\n", i + 1, t->description, t->size); + printf(" Input: value=0x%016llx, mask=0x%016llx\n", + t->input.value, t->input.mask); + printf(" Expected: value=0x%016llx, mask=0x%016llx\n", + t->expected.value, t->expected.mask); + printf(" Result: value=0x%016llx, mask=0x%016llx\n", + result.value, result.mask); + if (memcmp(&result, &t->expected, sizeof(struct tnum)) != 0) { - printf("Test %d failed: %s\n", i + 1, t->description); - printf(" Input: value=0x%016llx, mask=0x%016llx\n", - t->input.value, t->input.mask); - printf(" Expected: value=0x%016llx, mask=0x%016llx\n", - t->expected.value, t->expected.mask); - printf(" Got: value=0x%016llx, mask=0x%016llx\n", - result.value, result.mask); + printf(" Fail.\n\n"); err = 1; } else { - printf("Test %d passed: %s\n", i + 1, t->description); + printf(" Pass.\n\n"); } } if (err) - printf("tnum_scast tests failed.\n"); + printf("Some tnum_scast tests failed.\n"); else printf("All tnum_scast tests passed successfully.\n"); @@ -179,3 +196,4 @@ int main(int argc, char **argv) return err; } + From 81520dc39582c776cbd8507308d27d4d2343f2ee Mon Sep 17 00:00:00 2001 From: Dimitar Kanaliev Date: Tue, 1 Oct 2024 13:40:46 +0300 Subject: [PATCH 4/5] minimize branching --- kernel/bpf/tnum.c | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c index a5458ac6051804..621d236ad4986c 100644 --- a/kernel/bpf/tnum.c +++ b/kernel/bpf/tnum.c @@ -163,6 +163,8 @@ struct tnum tnum_scast(struct tnum a, u8 size) u64 sign_mask; u64 value_mask; u64 new_value, new_mask; + u64 sign_bit_unknown, sign_bit_value; + u64 mask; if (size >= 8) { return a; @@ -174,16 +176,18 @@ struct tnum tnum_scast(struct tnum a, u8 size) new_value = a.value & value_mask; new_mask = a.mask & value_mask; - if (a.mask & sign_mask) { - new_mask |= ~value_mask; - } else if (a.value & sign_mask) { - new_value |= ~value_mask; - } + sign_bit_unknown = (a.mask >> s) & 1; + sign_bit_value = (a.value >> s) & 1; + + mask = ~value_mask; + + new_mask |= mask & (0 - sign_bit_unknown); + + new_value |= mask & (0 - ((sign_bit_unknown ^ 1) & sign_bit_value)); return TNUM(new_value, new_mask); } - bool tnum_is_aligned(struct tnum a, u64 size) { if (!size) From 23a77edabd9de5e6318552e8dd2aa8d176a4faa3 Mon Sep 17 00:00:00 2001 From: Dimitar Kanaliev Date: Mon, 14 Oct 2024 17:19:16 +0300 Subject: [PATCH 5/5] attempt to better deduce smin/max values --- kernel/bpf/verifier.c | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 0b8b7d5eca7625..3ef154b3cfa8b0 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -6286,16 +6286,12 @@ static void set_sext64_default_val(struct bpf_reg_state *reg, int size) reg->var_off = tnum_unknown; } - static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size) { - s64 smin, smax; - u64 umax, umin; - reg->var_off = tnum_scast(reg->var_off, size); - reg->smin_value = reg->var_off.value; - reg->smax_value = reg->var_off.value | reg->var_off.mask; + reg->smin_value = (s64)(reg->var_off.value & ~reg->var_off.mask); + reg->smax_value = (s64)(reg->var_off.value | reg->var_off.mask); reg->umin_value = (u64)reg->smin_value; reg->umax_value = (u64)reg->smax_value; @@ -6305,11 +6301,13 @@ static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size) reg->s32_max_value = (s32)reg->smax_value; reg->u32_min_value = (u32)reg->umin_value; reg->u32_max_value = (u32)reg->umax_value; + } else { + reg->s32_min_value = S32_MIN; + reg->s32_max_value = S32_MAX; + reg->u32_min_value = 0; + reg->u32_max_value = U32_MAX; } - if (size < 4) - __mark_reg32_unbounded(reg); - reg_bounds_sync(reg); }