Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 198 additions & 1 deletion pyrefly/lib/alt/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1925,10 +1925,22 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
range,
Some(&|| ErrorContext::Index(self.for_display(base.clone()))),
),
Type::LiteralString | Type::Literal(Lit::Str(_)) if xs.len() <= 3 => {
Type::LiteralString if xs.len() <= 3 => {
// We could have a more precise type here, but this matches Pyright.
self.stdlib.str().clone().to_type()
}
Type::Literal(Lit::Str(ref value)) if xs.len() <= 3 => {
let base_ty = Type::Literal(Lit::Str(value.clone()));
let context = || ErrorContext::Index(self.for_display(base_ty.clone()));
self.subscript_str_literal(
value.as_str(),
&base_ty,
slice,
errors,
range,
Some(&context),
)
}
Type::ClassType(ref cls) | Type::SelfType(ref cls)
if let Some(Tuple::Concrete(elts)) = self.as_tuple(cls) =>
{
Expand Down Expand Up @@ -2100,6 +2112,191 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}

fn subscript_str_literal(
&self,
value: &str,
base_type: &Type,
index_expr: &Expr,
errors: &ErrorCollector,
range: TextRange,
context: Option<&dyn Fn() -> ErrorContext>,
) -> Type {
let fallback = || {
self.call_method_or_error(
base_type,
&dunder::GETITEM,
range,
&[CallArg::expr(index_expr)],
&[],
errors,
context,
)
};

if matches!(index_expr, Expr::Tuple(_)) {
return fallback();
}

let literal_index = |expr: &Expr| -> Option<i64> {
match self.expr_infer(expr, errors) {
Type::Literal(ref lit) => lit.as_index_i64(),
_ => None,
}
};

let chars: Vec<char> = value.chars().collect();
let len_usize = chars.len();
if len_usize > i64::MAX as usize {
return fallback();
}
let len = len_usize as i64;

if let Expr::Slice(slice) = index_expr {
let step = match slice.step.as_deref() {
Some(expr) => match literal_index(expr) {
Some(value) if value != 0 => value,
_ => return fallback(),
},
None => 1,
};

if step == i64::MIN {
return fallback();
}

let mut start = match slice.lower.as_deref() {
Some(expr) => match literal_index(expr) {
Some(value) => value,
None => return fallback(),
},
None => {
if step < 0 {
len.saturating_sub(1)
} else {
0
}
}
};

let mut stop = match slice.upper.as_deref() {
Some(expr) => match literal_index(expr) {
Some(value) => value,
None => return fallback(),
},
None => {
if step < 0 {
match len.checked_add(1) {
Some(v) => -v,
None => return fallback(),
}
} else {
len
}
}
};

if step > 0 {
if start < 0 {
start += len;
if start < 0 {
start = 0;
}
} else if start > len {
start = len;
}

if stop < 0 {
stop += len;
if stop < 0 {
stop = 0;
}
} else if stop > len {
stop = len;
}
} else {
if start < 0 {
start += len;
if start < 0 {
start = -1;
}
} else if start >= len {
start = len.saturating_sub(1);
}

if stop < 0 {
stop += len;
if stop < 0 {
stop = -1;
}
} else if stop >= len {
stop = len.saturating_sub(1);
}
}

let slice_length = if step < 0 {
if stop < start {
(start - stop - 1) / (-step) + 1
} else {
0
}
} else if start < stop {
(stop - start - 1) / step + 1
} else {
0
};

if slice_length <= 0 {
return Type::Literal(Lit::Str("".into()));
}

if slice_length as usize as i64 != slice_length {
return fallback();
}

let mut result = String::new();
let mut idx = start;
for _ in 0..slice_length as usize {
if idx < 0 || idx >= len {
return fallback();
}
let Some(&ch) = chars.get(idx as usize) else {
return fallback();
};
result.push(ch);
idx = match idx.checked_add(step) {
Some(next) => next,
None => return fallback(),
};
}

Type::Literal(Lit::Str(result.into()))
} else {
let idx_ty = self.expr_infer(index_expr, errors);
if let Type::Literal(lit) = idx_ty {
if let Some(idx) = lit.as_index_i64() {
let normalized = if idx < 0 { len + idx } else { idx };
if normalized >= 0 && normalized < len {
let ch = chars[normalized as usize];
let mut buf = String::new();
buf.push(ch);
return Type::Literal(Lit::Str(buf.into()));
} else {
return self.error(
errors,
range,
ErrorInfo::Kind(ErrorKind::BadIndex),
format!(
"Index `{idx}` out of range for string with {} elements",
chars.len()
),
);
}
}
}
fallback()
}
}

fn subscript_bytes_literal(
&self,
bytes: &[u8],
Expand Down
6 changes: 3 additions & 3 deletions pyrefly/lib/test/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ def f(x: LiteralString):
testcase!(
test_index_literal,
r#"
from typing import assert_type
from typing import Literal, assert_type

def foo(x):
assert_type("Magic"[0], str)
assert_type("Magic"[3:4], str)
assert_type("Magic"[0], Literal['M'])
assert_type("Magic"[3:4], Literal['i'])
"#,
);

Expand Down
25 changes: 25 additions & 0 deletions pyrefly/lib/test/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,31 @@ def f(a: A):
"#,
);

testcase!(
test_literal_string_subscript_precision,
r#"
from typing import Literal, assert_type
s: Literal["abcde"] = "abcde"
ss: Literal["こんにちは"] = "こんにちは"
em: Literal["\U0001F44D\U0001F3FC"] = "\U0001F44D\U0001F3FC"
assert_type(s[0], Literal["a"])
assert_type(s[0:2], Literal["ab"])
assert_type(s[-1], Literal["e"])
assert_type(s[2:], Literal["cde"])
assert_type(s[::-1], Literal["edcba"])
assert_type(s[3:0:-2], Literal["db"])
assert_type(ss[0], Literal["こ"])
assert_type(ss[-1], Literal["は"])
assert_type(ss[2:], Literal["にちは"])
assert_type(ss[:1], Literal["こ"])
assert_type(ss[::-1], Literal["はちにんこ"])
assert_type(em[0], Literal["\U0001F44D"])
assert_type(em[-1], Literal["\U0001F3FC"])
assert_type(em[:1], Literal["\U0001F44D"])
assert_type(em[::-1], Literal["\U0001F3FC\U0001F44D"])
"#,
);

testcase!(
test_invalid_annotation,
r#"
Expand Down