diff --git a/include/linux/tnum.h b/include/linux/tnum.h index 3c13240077b87a..6933db04c9ee77 100644 --- a/include/linux/tnum.h +++ b/include/linux/tnum.h @@ -55,6 +55,9 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b); /* Return @a with all but the lowest @size bytes cleared */ struct tnum tnum_cast(struct tnum a, u8 size); +/* Return @a sign-extended from @size bytes */ +struct tnum tnum_scast(struct tnum a, u8 size); + /* Returns true if @a is a known constant */ static inline bool tnum_is_const(struct tnum a) { diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c index 9dbc31b25e3d08..621d236ad4986c 100644 --- a/kernel/bpf/tnum.c +++ b/kernel/bpf/tnum.c @@ -157,6 +157,37 @@ struct tnum tnum_cast(struct tnum a, u8 size) return a; } +struct tnum tnum_scast(struct tnum a, u8 size) +{ + u64 s = size * 8 - 1; + u64 sign_mask; + u64 value_mask; + u64 new_value, new_mask; + u64 sign_bit_unknown, sign_bit_value; + u64 mask; + + if (size >= 8) { + return a; + } + + sign_mask = 1ULL << s; + value_mask = (1ULL << (s + 1)) - 1; + + new_value = a.value & value_mask; + new_mask = a.mask & value_mask; + + sign_bit_unknown = (a.mask >> s) & 1; + sign_bit_value = (a.value >> s) & 1; + + mask = ~value_mask; + + new_mask |= mask & (0 - sign_bit_unknown); + + new_value |= mask & (0 - ((sign_bit_unknown ^ 1) & sign_bit_value)); + + return TNUM(new_value, new_mask); +} + bool tnum_is_aligned(struct tnum a, u64 size) { if (!size) @@ -211,3 +242,4 @@ struct tnum tnum_const_subreg(struct tnum a, u32 value) { return tnum_with_subreg(a, tnum_const(value)); } + diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 9a7ed527e47e34..3ef154b3cfa8b0 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -6288,61 +6288,27 @@ static void set_sext64_default_val(struct bpf_reg_state *reg, int size) static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size) { - s64 init_s64_max, init_s64_min, s64_max, s64_min, u64_cval; - u64 top_smax_value, top_smin_value; - u64 num_bits = size * 8; + reg->var_off = tnum_scast(reg->var_off, size); - if (tnum_is_const(reg->var_off)) { - u64_cval = reg->var_off.value; - if (size == 1) - reg->var_off = tnum_const((s8)u64_cval); - else if (size == 2) - reg->var_off = tnum_const((s16)u64_cval); - else - /* size == 4 */ - reg->var_off = tnum_const((s32)u64_cval); - - u64_cval = reg->var_off.value; - reg->smax_value = reg->smin_value = u64_cval; - reg->umax_value = reg->umin_value = u64_cval; - reg->s32_max_value = reg->s32_min_value = u64_cval; - reg->u32_max_value = reg->u32_min_value = u64_cval; - return; - } - - top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits; - top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits; + reg->smin_value = (s64)(reg->var_off.value & ~reg->var_off.mask); + reg->smax_value = (s64)(reg->var_off.value | reg->var_off.mask); - if (top_smax_value != top_smin_value) - goto out; + reg->umin_value = (u64)reg->smin_value; + reg->umax_value = (u64)reg->smax_value; - /* find the s64_min and s64_min after sign extension */ - if (size == 1) { - init_s64_max = (s8)reg->smax_value; - init_s64_min = (s8)reg->smin_value; - } else if (size == 2) { - init_s64_max = (s16)reg->smax_value; - init_s64_min = (s16)reg->smin_value; + if (size <= 4) { + reg->s32_min_value = (s32)reg->smin_value; + reg->s32_max_value = (s32)reg->smax_value; + reg->u32_min_value = (u32)reg->umin_value; + reg->u32_max_value = (u32)reg->umax_value; } else { - init_s64_max = (s32)reg->smax_value; - init_s64_min = (s32)reg->smin_value; - } - - s64_max = max(init_s64_max, init_s64_min); - s64_min = min(init_s64_max, init_s64_min); - - /* both of s64_max/s64_min positive or negative */ - if ((s64_max >= 0) == (s64_min >= 0)) { - reg->smin_value = reg->s32_min_value = s64_min; - reg->smax_value = reg->s32_max_value = s64_max; - reg->umin_value = reg->u32_min_value = s64_min; - reg->umax_value = reg->u32_max_value = s64_max; - reg->var_off = tnum_range(s64_min, s64_max); - return; + reg->s32_min_value = S32_MIN; + reg->s32_max_value = S32_MAX; + reg->u32_min_value = 0; + reg->u32_max_value = U32_MAX; } -out: - set_sext64_default_val(reg, size); + reg_bounds_sync(reg); } static void set_sext32_default_val(struct bpf_reg_state *reg, int size) diff --git a/tools/testing/selftests/bpf/test_tnum.c b/tools/testing/selftests/bpf/test_tnum.c new file mode 100644 index 00000000000000..c072cc87eda9f1 --- /dev/null +++ b/tools/testing/selftests/bpf/test_tnum.c @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* test_tnum.c: Selftests for tnum_scast function + * + * This program tests the tnum_scast function + */ + +#include +#include +#include +#include +#include +#include + +#include "tnum.h" + +#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0])) + +struct tnum tnum_scast(struct tnum a, u8 size) +{ + u64 s = size * 8 - 1; + u64 sign_mask; + u64 value_mask; + u64 new_value, new_mask; + + if (size >= 8) { + return a; + } + + sign_mask = 1ULL << s; + value_mask = (1ULL << (s + 1)) - 1; + + new_value = a.value & value_mask; + new_mask = a.mask & value_mask; + + if (a.mask & sign_mask) { + new_mask |= ~value_mask; + } else if (a.value & sign_mask) { + new_value |= ~value_mask; + } + + return TNUM(new_value, new_mask); +} + +struct tnum_test_case { + const char *description; + struct tnum input; + u8 size; + struct tnum expected; +}; + +static int test_tnum_scast(void) +{ + int i, err = 0; + struct tnum result; + + /* Define test cases */ + struct tnum_test_case tests[] = { + /* 8-bit tests */ + { + .description = "Known positive value (8-bit)", + .input = TNUM(0x7F, 0x00), // 127 in decimal + .size = 1, + .expected = TNUM(0x000000000000007F, 0x0000000000000000), + }, + { + .description = "Known negative value (8-bit)", + .input = TNUM(0xFF, 0x00), // -1 in 8-bit signed + .size = 1, + .expected = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000), + }, + { + .description = "Unknown sign bit (8-bit)", + .input = TNUM(0x7F, 0x80), // Value 127, sign bit unknown + .size = 1, + .expected = TNUM(0x000000000000007F, 0xFFFFFFFFFFFFFF80), + }, + { + .description = "Completely unknown value (8-bit)", + .input = TNUM(0x00, 0xFF), // All bits unknown + .size = 1, + .expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF), + }, + /* 16-bit tests */ + { + .description = "Known positive value (16-bit)", + .input = TNUM(0x7FFF, 0x0000), + .size = 2, + .expected = TNUM(0x0000000000007FFF, 0x0000000000000000), + }, + { + .description = "Known negative value (16-bit)", + .input = TNUM(0xFFFF, 0x0000), // -1 in 16-bit signed + .size = 2, + .expected = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000), + }, + { + .description = "Unknown sign bit (16-bit)", + .input = TNUM(0x7FFF, 0x8000), + .size = 2, + .expected = TNUM(0x0000000000007FFF, 0xFFFFFFFFFFFF8000), + }, + { + .description = "Completely unknown value (16-bit)", + .input = TNUM(0x0000, 0xFFFF), + .size = 2, + .expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF), + }, + /* 32-bit tests */ + { + .description = "Known positive value (32-bit)", + .input = TNUM(0x7FFFFFFF, 0x00000000), + .size = 4, + .expected = TNUM(0x000000007FFFFFFF, 0x0000000000000000), + }, + { + .description = "Known negative value (32-bit)", + .input = TNUM(0xFFFFFFFF, 0x00000000), // -1 in 32-bit signed + .size = 4, + .expected = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000), + }, + { + .description = "Unknown sign bit (32-bit)", + .input = TNUM(0x7FFFFFFF, 0x80000000), + .size = 4, + .expected = TNUM(0x000000007FFFFFFF, 0xFFFFFFFF80000000), + }, + { + .description = "Completely unknown value (32-bit)", + .input = TNUM(0x00000000, 0xFFFFFFFF), + .size = 4, + .expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF), + }, + /* 64-bit tests */ + { + .description = "Known positive value (64-bit)", + .input = TNUM(0x7FFFFFFFFFFFFFFF, 0x0000000000000000), + .size = 8, + .expected = TNUM(0x7FFFFFFFFFFFFFFF, 0x0000000000000000), + }, + { + .description = "Known negative value (64-bit)", + .input = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000), + .size = 8, + .expected = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000), + }, + { + .description = "Unknown sign bit (64-bit)", + .input = TNUM(0x7FFFFFFFFFFFFFFF, 0x8000000000000000ULL), + .size = 8, + .expected = TNUM(0x7FFFFFFFFFFFFFFF, 0x8000000000000000ULL), + }, + { + .description = "Completely unknown value (64-bit)", + .input = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF), + .size = 8, + .expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF), + }, + }; + + printf("Running tnum_scast tests...\n\n"); + + for (i = 0; i < ARRAY_SIZE(tests); i++) { + struct tnum_test_case *t = &tests[i]; + + result = tnum_scast(t->input, t->size); + + printf("Test %d (%s, size=%d bytes):\n", i + 1, t->description, t->size); + printf(" Input: value=0x%016llx, mask=0x%016llx\n", + t->input.value, t->input.mask); + printf(" Expected: value=0x%016llx, mask=0x%016llx\n", + t->expected.value, t->expected.mask); + printf(" Result: value=0x%016llx, mask=0x%016llx\n", + result.value, result.mask); + + if (memcmp(&result, &t->expected, sizeof(struct tnum)) != 0) { + printf(" Fail.\n\n"); + err = 1; + } else { + printf(" Pass.\n\n"); + } + } + + if (err) + printf("Some tnum_scast tests failed.\n"); + else + printf("All tnum_scast tests passed successfully.\n"); + + return err; +} + +int main(int argc, char **argv) +{ + int err = 0; + + err |= test_tnum_scast(); + + return err; +} + diff --git a/tools/testing/selftests/bpf/tnum.h b/tools/testing/selftests/bpf/tnum.h new file mode 100644 index 00000000000000..f46be04ef570e6 --- /dev/null +++ b/tools/testing/selftests/bpf/tnum.h @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* tnum.h: Header file for tnum utility functions */ + +#ifndef __TNUM_H__ +#define __TNUM_H__ + +#include + +typedef uint64_t u64; +typedef int64_t s64; +typedef uint32_t u32; +typedef int32_t s32; +typedef uint8_t u8; + +struct tnum { + u64 value; + u64 mask; +}; + +#define TNUM(_v, _m) (struct tnum){.value = (_v), .mask = (_m)} + +/* Function prototypes */ +struct tnum tnum_scast(struct tnum a, u8 size); + +#endif /* __TNUM_H__ */ +