diff --git a/pyrefly/lib/alt/expr.rs b/pyrefly/lib/alt/expr.rs
index 34808f2c5..6023a72b6 100644
--- a/pyrefly/lib/alt/expr.rs
+++ b/pyrefly/lib/alt/expr.rs
@@ -1925,10 +1925,22 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
                 range,
                 Some(&|| ErrorContext::Index(self.for_display(base.clone()))),
             ),
-            Type::LiteralString | Type::Literal(Lit::Str(_)) if xs.len() <= 3 => {
+            Type::LiteralString if xs.len() <= 3 => {
                 // We could have a more precise type here, but this matches Pyright.
                 self.stdlib.str().clone().to_type()
             }
+            Type::Literal(Lit::Str(ref value)) if xs.len() <= 3 => {
+                let base_ty = Type::Literal(Lit::Str(value.clone()));
+                let context = || ErrorContext::Index(self.for_display(base_ty.clone()));
+                self.subscript_str_literal(
+                    value.as_str(),
+                    &base_ty,
+                    slice,
+                    errors,
+                    range,
+                    Some(&context),
+                )
+            }
             Type::ClassType(ref cls) | Type::SelfType(ref cls)
                 if let Some(Tuple::Concrete(elts)) = self.as_tuple(cls) =>
             {
@@ -2100,6 +2112,135 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
         }
     }
 
+    /// Infers the type of `base_type[index_expr]` when the base is the string literal `value`.
+    ///
+    /// Positions count `char`s rather than UTF-8 bytes, so a multi-byte character is one
+    /// element. If the index (or every slice component that is present) infers to a literal
+    /// accepted by `as_index_i64`, the result is the `Lit::Str` selected by the subscript.
+    /// A literal index outside the string reports `ErrorKind::BadIndex`. Anything else (a
+    /// tuple index, a non-literal component, a step of zero or `i64::MIN`) is delegated to
+    /// `__getitem__` on `base_type`.
+    fn subscript_str_literal(
+        &self,
+        value: &str,
+        base_type: &Type,
+        index_expr: &Expr,
+        errors: &ErrorCollector,
+        range: TextRange,
+        context: Option<&dyn Fn() -> ErrorContext>,
+    ) -> Type {
+        let fallback = || {
+            self.call_method_or_error(
+                base_type,
+                &dunder::GETITEM,
+                range,
+                &[CallArg::expr(index_expr)],
+                &[],
+                errors,
+                context,
+            )
+        };
+        let literal_index = |expr: &Expr| -> Option<i64> {
+            match self.expr_infer(expr, errors) {
+                Type::Literal(ref lit) => lit.as_index_i64(),
+                _ => None,
+            }
+        };
+
+        let chars: Vec<char> = value.chars().collect();
+        if chars.len() > i64::MAX as usize {
+            return fallback();
+        }
+        let len = chars.len() as i64;
+
+        match index_expr {
+            Expr::Tuple(_) => fallback(),
+            Expr::Slice(slice) => {
+                let step = match slice.step.as_deref() {
+                    Some(expr) => match literal_index(expr) {
+                        // `i64::MIN` cannot be negated, so it is rejected along with zero.
+                        Some(step) if step != 0 && step != i64::MIN => step,
+                        _ => return fallback(),
+                    },
+                    None => 1,
+                };
+
+                // Bounds clamp to `[0, len]` when walking forwards and to `[-1, len - 1]`
+                // when walking backwards, where `-1` stands for "before the first character".
+                let (lowest, highest) = if step > 0 { (0, len) } else { (-1, len - 1) };
+                let normalize = |bound: i64| {
+                    if bound < 0 {
+                        (bound + len).max(lowest)
+                    } else {
+                        bound.min(highest)
+                    }
+                };
+                // An omitted bound starts at the near end and stops past the far end.
+                let (default_start, default_stop) = if step > 0 {
+                    (lowest, highest)
+                } else {
+                    (highest, lowest)
+                };
+
+                let start = match slice.lower.as_deref() {
+                    Some(expr) => match literal_index(expr) {
+                        Some(bound) => normalize(bound),
+                        None => return fallback(),
+                    },
+                    None => default_start,
+                };
+                let stop = match slice.upper.as_deref() {
+                    Some(expr) => match literal_index(expr) {
+                        Some(bound) => normalize(bound),
+                        None => return fallback(),
+                    },
+                    None => default_stop,
+                };
+
+                let slice_length = if step > 0 && start < stop {
+                    (stop - start - 1) / step + 1
+                } else if step < 0 && stop < start {
+                    (start - stop - 1) / (-step) + 1
+                } else {
+                    0
+                };
+
+                // Each position is derived from `start`, so no index past the last selected
+                // character is ever computed; `get` turns an arithmetic slip into a fallback.
+                let selected: Option<String> = (0..slice_length)
+                    .map(|k| chars.get((start + k * step) as usize).copied())
+                    .collect();
+                match selected {
+                    Some(text) => Type::Literal(Lit::Str(text.into())),
+                    None => fallback(),
+                }
+            }
+            _ => {
+                let Type::Literal(lit) = self.expr_infer(index_expr, errors) else {
+                    return fallback();
+                };
+                let Some(idx) = lit.as_index_i64() else {
+                    return fallback();
+                };
+                let normalized = if idx < 0 { len + idx } else { idx };
+                if (0..len).contains(&normalized) {
+                    let ch = chars[normalized as usize];
+                    Type::Literal(Lit::Str(ch.to_string().into()))
+                } else {
+                    self.error(
+                        errors,
+                        range,
+                        ErrorInfo::Kind(ErrorKind::BadIndex),
+                        format!(
+                            "Index `{idx}` out of range for string with {} elements",
+                            chars.len()
+                        ),
+                    )
+                }
+            }
+        }
+    }
+
     fn subscript_bytes_literal(
         &self,
         bytes: &[u8],
diff --git a/pyrefly/lib/test/literal.rs b/pyrefly/lib/test/literal.rs
index 3807c1cc0..94b167833 100644
--- a/pyrefly/lib/test/literal.rs
+++ b/pyrefly/lib/test/literal.rs
@@ -127,11 +127,11 @@ def f(x: LiteralString):
 testcase!(
     test_index_literal,
     r#"
-from typing import assert_type
+from typing import Literal, assert_type
 
 def foo(x):
-    assert_type("Magic"[0], str)
-    assert_type("Magic"[3:4], str)
+    assert_type("Magic"[0], Literal['M'])
+    assert_type("Magic"[3:4], Literal['i'])
 "#,
 );
 
diff --git a/pyrefly/lib/test/simple.rs b/pyrefly/lib/test/simple.rs
index 8bfb101ec..688dd2214 100644
--- a/pyrefly/lib/test/simple.rs
+++ b/pyrefly/lib/test/simple.rs
@@ -721,6 +721,31 @@ def f(a: A):
 "#,
 );
 
+testcase!(
+    test_literal_string_subscript_precision,
+    r#"
+from typing import Literal, assert_type
+s: Literal["abcde"] = "abcde"
+ss: Literal["こんにちは"] = "こんにちは"
+em: Literal["\U0001F44D\U0001F3FC"] = "\U0001F44D\U0001F3FC"
+assert_type(s[0], Literal["a"])
+assert_type(s[0:2], Literal["ab"])
+assert_type(s[-1], Literal["e"])
+assert_type(s[2:], Literal["cde"])
+assert_type(s[::-1], Literal["edcba"])
+assert_type(s[3:0:-2], Literal["db"])
+assert_type(ss[0], Literal["こ"])
+assert_type(ss[-1], Literal["は"])
+assert_type(ss[2:], Literal["にちは"])
+assert_type(ss[:1], Literal["こ"])
+assert_type(ss[::-1], Literal["はちにんこ"])
+assert_type(em[0], Literal["\U0001F44D"])
+assert_type(em[-1], Literal["\U0001F3FC"])
+assert_type(em[:1], Literal["\U0001F44D"])
+assert_type(em[::-1], Literal["\U0001F3FC\U0001F44D"])
+    "#,
+);
+
 testcase!(
     test_invalid_annotation,
     r#"