Skip to content

Commit

Permalink
[red-knot] Decompose bool to Literal[True, False] in unions and i…
Browse files Browse the repository at this point in the history
…ntersections
  • Loading branch information
AlexWaygood committed Jan 25, 2025
1 parent fcd0f34 commit a347b0d
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,22 @@ reveal_type(c >= d) # revealed: Literal[True]
#### Results with Ambiguity

```py
def _(x: bool, y: int):
class P:
def __lt__(self, other: "P") -> bool:
return True

def __le__(self, other: "P") -> bool:
return True

def __gt__(self, other: "P") -> bool:
return True

def __ge__(self, other: "P") -> bool:
return True

class Q(P): ...

def _(x: P, y: Q):
a = (x,)
b = (y,)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,9 @@ else:
reveal_type(x) # revealed: slice
finally:
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice`
reveal_type(x) # revealed: bool | float | slice
reveal_type(x) # revealed: bool | slice | float

reveal_type(x) # revealed: bool | float | slice
reveal_type(x) # revealed: bool | slice | float
```

## Nested `try`/`except` blocks
Expand Down Expand Up @@ -534,7 +534,7 @@ try:
reveal_type(x) # revealed: slice
finally:
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice`
reveal_type(x) # revealed: bool | float | slice
reveal_type(x) # revealed: bool | slice | float
x = 2
reveal_type(x) # revealed: Literal[2]
reveal_type(x) # revealed: Literal[2]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@ else:
if x and not x:
reveal_type(x) # revealed: Never
else:
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None

if not (x and not x):
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
else:
reveal_type(x) # revealed: Never

if x or not x:
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None
else:
reveal_type(x) # revealed: Never

if not (x or not x):
reveal_type(x) # revealed: Never
else:
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()]
reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None

if (isinstance(x, int) or isinstance(x, str)) and x:
reveal_type(x) # revealed: Literal[-1, True, "foo"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,46 @@ static_assert(
)
```

## Unions containing tuples containing `bool`

```py
from knot_extensions import is_equivalent_to, static_assert
from typing_extensions import Never, Literal

class P: ...

static_assert(is_equivalent_to(tuple[Literal[True, False]] | P, tuple[bool] | P))
static_assert(is_equivalent_to(P | tuple[bool], P | tuple[Literal[True, False]]))
```

## Unions and intersections involving `AlwaysTruthy`, `bool` and `AlwaysFalsy`

```py
from knot_extensions import AlwaysTruthy, AlwaysFalsy, static_assert, is_equivalent_to, Not
from typing_extensions import Literal

static_assert(is_equivalent_to(AlwaysTruthy | bool, Literal[False] | AlwaysTruthy))
static_assert(is_equivalent_to(AlwaysFalsy | bool, Literal[True] | AlwaysFalsy))
static_assert(is_equivalent_to(Not[AlwaysTruthy] | bool, Not[AlwaysTruthy] | Literal[True]))
static_assert(is_equivalent_to(Not[AlwaysFalsy] | bool, Literal[False] | Not[AlwaysFalsy]))
```

## Unions and intersections involving `AlwaysTruthy`, `LiteralString` and `AlwaysFalsy`

```py
from knot_extensions import AlwaysTruthy, AlwaysFalsy, static_assert, is_equivalent_to, Not, Intersection
from typing_extensions import Literal, LiteralString

# TODO: these should all pass!

# error: [static-assert-error]
static_assert(is_equivalent_to(AlwaysTruthy | LiteralString, Literal[""] | AlwaysTruthy))
# error: [static-assert-error]
static_assert(is_equivalent_to(AlwaysFalsy | LiteralString, Intersection[LiteralString, Not[Literal[""]]] | AlwaysFalsy))
# error: [static-assert-error]
static_assert(is_equivalent_to(Not[AlwaysTruthy] | LiteralString, Not[AlwaysTruthy] | Intersection[LiteralString, Not[Literal[""]]]))
# error: [static-assert-error]
static_assert(is_equivalent_to(Not[AlwaysFalsy] | LiteralString, Literal[""] | Not[AlwaysFalsy]))
```

[the equivalence relation]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-equivalent
102 changes: 87 additions & 15 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,19 @@ impl<'db> Type<'db> {
}
}

#[must_use]
pub fn normalized(self, db: &'db dyn Db) -> Self {
const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)];

match self {
Type::Instance(InstanceType { class }) if class.is_known(db, KnownClass::Bool) => {
Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS)))
}
// TODO: decompose `LiteralString` into `Literal[""] | TruthyLiteralString`? --Alex
_ => self,
}
}

/// Return true if this type is a [subtype of] type `target`.
///
/// This method returns `false` if either `self` or `other` is not fully static.
Expand Down Expand Up @@ -840,7 +853,7 @@ impl<'db> Type<'db> {
return false;
}

match (self, target) {
match (self.normalized(db), target.normalized(db)) {
// We should have handled these immediately above.
(Type::Dynamic(_), _) | (_, Type::Dynamic(_)) => {
unreachable!("Non-fully-static types do not participate in subtyping!")
Expand Down Expand Up @@ -932,7 +945,7 @@ impl<'db> Type<'db> {
KnownClass::Str.to_instance(db).is_subtype_of(db, target)
}
(Type::BooleanLiteral(_), _) => {
KnownClass::Bool.to_instance(db).is_subtype_of(db, target)
KnownClass::Int.to_instance(db).is_subtype_of(db, target)
}
(Type::IntLiteral(_), _) => KnownClass::Int.to_instance(db).is_subtype_of(db, target),
(Type::BytesLiteral(_), _) => {
Expand Down Expand Up @@ -1048,6 +1061,14 @@ impl<'db> Type<'db> {
if self.is_gradual_equivalent_to(db, target) {
return true;
}
let normalized_self = self.normalized(db);
if normalized_self != self {
return normalized_self.is_assignable_to(db, target);
}
let normalized_target = target.normalized(db);
if normalized_target != target {
return self.is_assignable_to(db, normalized_target);
}
match (self, target) {
// Never can be assigned to any type.
(Type::Never, _) => true,
Expand Down Expand Up @@ -1148,13 +1169,13 @@ impl<'db> Type<'db> {
pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
// TODO equivalent but not identical types: TypedDicts, Protocols, type aliases, etc.

match (self, other) {
match (self.normalized(db), other.normalized(db)) {
(Type::Union(left), Type::Union(right)) => left.is_equivalent_to(db, right),
(Type::Intersection(left), Type::Intersection(right)) => {
left.is_equivalent_to(db, right)
}
(Type::Tuple(left), Type::Tuple(right)) => left.is_equivalent_to(db, right),
_ => self.is_fully_static(db) && other.is_fully_static(db) && self == other,
(left, right) => left == right && left.is_fully_static(db) && right.is_fully_static(db),
}
}

Expand Down Expand Up @@ -1189,11 +1210,14 @@ impl<'db> Type<'db> {
///
/// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations
pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
if self == other {
let left = self.normalized(db);
let right = other.normalized(db);

if left == right {
return true;
}

match (self, other) {
match (left, right) {
(Type::Dynamic(_), Type::Dynamic(_)) => true,

(Type::SubclassOf(first), Type::SubclassOf(second)) => {
Expand Down Expand Up @@ -1221,6 +1245,15 @@ impl<'db> Type<'db> {
/// Note: This function aims to have no false positives, but might return
/// wrong `false` answers in some cases.
pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool {
let normalized_self = self.normalized(db);
if normalized_self != self {
return normalized_self.is_disjoint_from(db, other);
}
let normalized_other = other.normalized(db);
if normalized_other != other {
return self.is_disjoint_from(db, normalized_other);
}

match (self, other) {
(Type::Never, _) | (_, Type::Never) => true,

Expand Down Expand Up @@ -4354,8 +4387,10 @@ impl<'db> UnionType<'db> {
pub fn to_sorted_union(self, db: &'db dyn Db) -> Self {
let mut new_elements = self.elements(db).to_vec();
for element in &mut new_elements {
if let Type::Intersection(intersection) = element {
intersection.sort(db);
match element {
Type::Intersection(intersection) => intersection.sort(db),
Type::Tuple(tuple) => tuple.sort_inner_unions(db),
_ => {}
}
}
new_elements.sort_unstable_by(union_elements_ordering);
Expand Down Expand Up @@ -4453,10 +4488,26 @@ impl<'db> IntersectionType<'db> {
/// according to a canonical ordering.
#[must_use]
pub fn to_sorted_intersection(self, db: &'db dyn Db) -> Self {
let mut positive = self.positive(db).clone();
let mut positive: FxOrderSet<Type<'db>> = self
.positive(db)
.iter()
.map(|ty| match ty {
Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_inner_unions(db)),
_ => *ty,
})
.collect();

positive.sort_unstable_by(union_elements_ordering);

let mut negative = self.negative(db).clone();
let mut negative: FxOrderSet<Type<'db>> = self
.negative(db)
.iter()
.map(|ty| match ty {
Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_inner_unions(db)),
_ => *ty,
})
.collect();

negative.sort_unstable_by(union_elements_ordering);

IntersectionType::new(db, positive, negative)
Expand Down Expand Up @@ -4591,23 +4642,44 @@ pub struct TupleType<'db> {
}

impl<'db> TupleType<'db> {
pub fn from_elements<T: Into<Type<'db>>>(
db: &'db dyn Db,
types: impl IntoIterator<Item = T>,
) -> Type<'db> {
pub fn from_elements<I, T>(db: &'db dyn Db, types: I) -> Type<'db>
where
I: IntoIterator<Item = T>,
T: Into<Type<'db>>,
{
let mut elements = vec![];

for ty in types {
let ty = ty.into();
if ty.is_never() {
return Type::Never;
}
elements.push(ty);
elements.push(ty.normalized(db));
}

Type::Tuple(Self::new(db, elements.into_boxed_slice()))
}

#[must_use]
pub fn with_sorted_inner_unions(self, db: &'db dyn Db) -> Self {
let elements: Box<[Type<'db>]> = self
.elements(db)
.iter()
.map(|ty| match ty {
Type::Union(union) => Type::Union(union.to_sorted_union(db)),
Type::Intersection(intersection) => {
Type::Intersection(intersection.to_sorted_intersection(db))
}
_ => *ty,
})
.collect();
TupleType::new(db, elements)
}

pub fn sort_inner_unions(&mut self, db: &'db dyn Db) {
*self = self.with_sorted_inner_unions(db);
}

pub fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
let self_elements = self.elements(db);
let other_elements = other.elements(db);
Expand Down
Loading

0 comments on commit a347b0d

Please sign in to comment.