Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions src/emeraldtree/ElementPath.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,22 @@ def select(context, result):

def prepare_dot_dot(next, token):
def select(context, result):
parent_map = context.parent_map
if parent_map is None:
context.parent_map = parent_map = {}
for p in context.root.iter():
for e in p:
parent_map[e] = p
for elem in result:
if elem in parent_map:
yield parent_map[elem]
parent = getattr(elem, '_parent', None)
if parent is not None:
yield parent
else:
if context.parent_map is None:
context.parent_map = {}
for p in context.root.iter():
try:
iter(p)
except TypeError:
continue
for e in p:
context.parent_map[e] = p
if elem in context.parent_map:
yield context.parent_map[elem]
return select

def prepare_predicate(next, token):
Expand Down Expand Up @@ -146,10 +153,21 @@ def select(context, result):
token = next()
if token[0] != "]":
raise SyntaxError("invalid node predicate")
def select(context, result):
for elem in result:
if elem.find(tag) is not None:
yield elem
try:
index = int(tag)
except ValueError:
def select(context, result):
for elem in result:
if elem.find(tag) is not None:
yield elem
else:
if index < 1:
raise SyntaxError("XPath position >= 1 expected")
def select(context, result):
for i, elem in enumerate(result):
if i + 1 == index:
yield elem
break
else:
raise SyntaxError("invalid predicate")
return select
Expand Down
11 changes: 7 additions & 4 deletions src/emeraldtree/tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ def test_Element_findall_bracketed_tag():
assert result[0] is b1 # b1 has 'c' childs

def test_Element_findall_dotdot():
pytest.skip('broken')
c1 = Element('c')
c2 = Element('c')
text = "text"
Expand All @@ -184,7 +183,6 @@ def test_Element_findall_dotdot():
assert result[1] is c2

def test_Element_findall_slashslash():
pytest.skip('broken')
c1 = Element('c')
c2 = Element('c')
text = "text"
Expand All @@ -199,7 +197,6 @@ def test_Element_findall_slashslash():
assert result[1] is c2

def test_Element_findall_dotslashslash():
pytest.skip('broken')
c1 = Element('c')
c2 = Element('c')
text = "text"
Expand Down Expand Up @@ -235,7 +232,6 @@ def test_Element_findall_attribute():
assert len(result) == 0

def test_Element_findall_position():
pytest.skip('not supported')
c1 = Element('c')
c2 = Element('c')
text = "text"
Expand All @@ -251,6 +247,13 @@ def test_Element_findall_position():
assert len(result) == 1
assert result[0] is c2

def test_Element_findall_position_invalid():
b1 = Element('b')
with pytest.raises(SyntaxError):
list(b1.findall('c[0]'))
with pytest.raises(SyntaxError):
list(b1.findall('c[-1]'))

def test_Element_findtext_default():
elem = Element('a')
default_text = 'defaulttext'
Expand Down
64 changes: 62 additions & 2 deletions src/emeraldtree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class Element(Node):

attrib = None

# Parent pointer (internal). None for root or detached elements.
_parent = None

##
# (Attribute) Text before first subelement. This is either a
# string or the value None, if there was no text.
Expand Down Expand Up @@ -151,6 +154,10 @@ def __init__(self, tag, attrib=None, children=(), **extra):
self.tag = tag
self.attrib = attrib
self._children = list(children)
# set parent pointers for element children
for ch in self._children:
if isinstance(ch, Element):
ch._parent = self

def __repr__(self):
return "<Element {} at {:x}>".format(repr(self.tag), id(self))
Expand Down Expand Up @@ -186,7 +193,34 @@ def __getitem__(self, index):
# @exception AssertionError If element is not a valid object.

def __setitem__(self, index, element):
self._children.__setitem__(index, element)
# clear parent of replaced children and set parent of new ones
if isinstance(index, slice):
# clear parents for removed elements
old_items = self._children[index]
for old in old_items:
if isinstance(old, Element):
old._parent = None
# assign
self._children[index] = element
# set parents for new elements
try:
iterator = iter(element)
except TypeError:
iterator = None
if iterator is not None:
for new in element:
if isinstance(new, Element):
new._parent = self
else:
try:
old = self._children[index]
except Exception:
old = None
if isinstance(old, Element):
old._parent = None
self._children[index] = element
if isinstance(element, Element):
element._parent = self

##
# Deletes the given subelement.
Expand All @@ -195,6 +229,19 @@ def __setitem__(self, index, element):
# @exception IndexError If the given element does not exist.

def __delitem__(self, index):
# clear parent pointer for removed element(s)
if isinstance(index, slice):
old_items = self._children[index]
for old in old_items:
if isinstance(old, Element):
old._parent = None
else:
try:
old = self._children[index]
except Exception:
old = None
if isinstance(old, Element):
old._parent = None
self._children.__delitem__(index)

##
Expand All @@ -205,6 +252,8 @@ def __delitem__(self, index):

def append(self, element):
self._children.append(element)
if isinstance(element, Element):
element._parent = self

##
# Appends subelements from a sequence.
Expand All @@ -215,6 +264,9 @@ def append(self, element):

def extend(self, elements):
self._children.extend(elements)
for e in elements:
if isinstance(e, Element):
e._parent = self

##
# Inserts a subelement at the given position in this element.
Expand All @@ -224,6 +276,8 @@ def extend(self, elements):

def insert(self, index, element):
self._children.insert(index, element)
if isinstance(element, Element):
element._parent = self

##
# Removes a matching subelement. Unlike the <b>find</b> methods,
Expand All @@ -236,11 +290,16 @@ def insert(self, index, element):

def remove(self, element):
self._children.remove(element)
if isinstance(element, Element):
element._parent = None

##
# Removes all subelements.

def remove_all(self):
for ch in self._children:
if isinstance(ch, Element):
ch._parent = None
self._children = []

##
Expand Down Expand Up @@ -355,7 +414,8 @@ def iter(self, tag=None):
for e in e.iter(tag):
yield e
else:
yield e
if tag is None:
yield e

##
# Creates a text iterator. The iterator loops over this element
Expand Down