diff --git a/src/syntax/keyword.rs b/src/syntax/keyword.rs index 350146e..eb004d6 100644 --- a/src/syntax/keyword.rs +++ b/src/syntax/keyword.rs @@ -45,6 +45,16 @@ pub fn keyword_node(input: Input) -> IResult { tracing::instrument(level = "debug", skip(input), fields(input = input.s)) )] pub fn affiliated_keyword_nodes(input: Input) -> IResult, ()> { + if !starts_with_keyword_prefix(input) { + return Ok((input, Vec::new())); + } + + if let Some(key) = peek_keyword_key(input) { + if input.c.affiliated_keywords.iter().all(|w| w != key) && !key.starts_with("ATTR_") { + return Ok((input, Vec::new())); + } + } + let mut children = vec![]; let mut i = input; @@ -73,6 +83,14 @@ pub fn affiliated_keyword_nodes(input: Input) -> IResult IResult, ()> { + if !starts_with_keyword_prefix(input) { + return Ok((input, Vec::new())); + } + + if peek_keyword_key(input).is_some_and(|key| !key.eq_ignore_ascii_case("TBLFM")) { + return Ok((input, Vec::new())); + } + let mut children = vec![]; let mut i = input; @@ -94,6 +112,10 @@ pub fn tblfm_keyword_nodes(input: Input) -> IResult, () } fn keyword_node_base(input: Input<'_>) -> IResult, (&str, Vec), ()> { + if !starts_with_keyword_prefix(input) { + return Err(nom::Err::Error(())); + } + let (input, (ws, hash_plus)) = (space0, hash_plus_token).parse(input)?; let (input, (key, optional, colon)) = alt((key_with_optional, key)).parse(input)?; @@ -123,6 +145,48 @@ fn keyword_node_base(input: Input<'_>) -> IResult, (&str, Vec) -> bool { + let bytes = input.as_bytes(); + let mut i = 0; + + while matches!(bytes.get(i), Some(b' ' | b'\t')) { + i += 1; + } + + bytes.get(i..i + 2) == Some(b"#+") +} + +#[inline] +fn peek_keyword_key(input: Input<'_>) -> Option<&str> { + let bytes = input.as_bytes(); + let mut start = 0; + + while matches!(bytes.get(start), Some(b' ' | b'\t')) { + start += 1; + } + + if bytes.get(start..start + 2) != Some(b"#+") { + return None; + } + + start += 2; + let key_start = start; + + while let Some(byte) = bytes.get(start) { + if byte.is_ascii_whitespace() || matches!(byte, b':' | b'[') { + break; + } + start += 1; + } + + if start == key_start { + None + } else { + Some(&input.as_str()[key_start..start]) + } +} + fn key(input: Input) -> IResult, Input), ()> { let (input, output) = verify( recognize(( @@ -289,3 +353,32 @@ fn parse() { assert!(keyword_node(("#+KE Y: VALUE", config).into()).is_err()); assert!(keyword_node(("#+ KEY: VALUE", config).into()).is_err()); } + +#[test] +fn keyword_prefix_detection() { + let config = &crate::ParseConfig::default(); + + assert!(starts_with_keyword_prefix(("#+KEY: value", config).into())); + assert!(starts_with_keyword_prefix( + (" \t#+KEY: value", config).into() + )); + assert_eq!( + peek_keyword_key(("#+KEY: value", config).into()), + Some("KEY") + ); + assert_eq!( + peek_keyword_key(("#+CAPTION[short]: value", config).into()), + Some("CAPTION") + ); + assert_eq!( + peek_keyword_key(("#+ATTR_HTML: value", config).into()), + Some("ATTR_HTML") + ); + assert!(!starts_with_keyword_prefix( + ("regular paragraph", config).into() + )); + assert!(!starts_with_keyword_prefix( + (" # not an org keyword", config).into() + )); + assert!(!starts_with_keyword_prefix(("", config).into())); +} diff --git a/src/syntax/object.rs b/src/syntax/object.rs index 265c9bd..5e3a254 100644 --- a/src/syntax/object.rs +++ b/src/syntax/object.rs @@ -88,26 +88,38 @@ impl<'a> Iterator for ObjectPositions<'a> { return None; } - let previous = self.pos; - let i = self.finder.find(&self.input.as_bytes()[self.pos..])?; - let p = self.pos + i; - - self.pos = p + 1; - - debug_assert!( - previous < self.pos && self.pos <= self.input.s.len(), - "{} < {} < {}", - previous, - self.pos, - self.input.s.len() - ); - - // a valid object requires at least two characters - if self.input.s.len() - p < 2 { - return None; + while self.pos < self.input.len() { + let previous = self.pos; + let i = self.finder.find(&self.input.as_bytes()[self.pos..])?; + let p = self.pos + i; + + self.pos = p + 1; + + debug_assert!( + previous < self.pos && self.pos <= self.input.s.len(), + "{} < {} < {}", + previous, + self.pos, + self.input.s.len() + ); + + // a valid object requires at least two characters + if self.input.s.len() - p < 2 { + return None; + } + + let bytes = &self.input.as_bytes()[p..]; + if bytes[0] == b'c' && !bytes.starts_with(b"call_") { + continue; + } + if bytes[0] == b's' && !bytes.starts_with(b"src_") { + continue; + } + + return Some(self.input.take_split(p)); } - Some(self.input.take_split(p)) + None } } @@ -280,14 +292,22 @@ fn positions() { assert_eq!(vec[0].0.s, "{3}"); let vec = ObjectPositions::standard(("*{()}//s\nc<<", &config).into()).collect::>(); - assert_eq!(vec.len(), 7); + assert_eq!(vec.len(), 5); assert_eq!(vec[0].0.s, "*{()}//s\nc<<"); assert_eq!(vec[1].0.s, "{()}//s\nc<<"); assert_eq!(vec[2].0.s, "//s\nc<<"); assert_eq!(vec[3].0.s, "/s\nc<<"); - assert_eq!(vec[4].0.s, "s\nc<<"); - assert_eq!(vec[5].0.s, "c<<"); - assert_eq!(vec[6].0.s, "<<"); + assert_eq!(vec[4].0.s, "<<"); + + let vec = + ObjectPositions::standard(("call_square(4) and src_rust{let x = 1;}", &config).into()) + .collect::>(); + assert!(vec + .iter() + .any(|(input, _)| input.s == "call_square(4) and src_rust{let x = 1;}")); + assert!(vec + .iter() + .any(|(input, _)| input.s == "src_rust{let x = 1;}")); } #[test]