diff --git a/docs/renderers.md b/docs/renderers.md
index 7916053..6f4aec3 100644
--- a/docs/renderers.md
+++ b/docs/renderers.md
@@ -59,6 +59,8 @@ If you are considering using `XML` for your API, you may want to consider implem
**.charset**: `utf-8`
-**item_tag_name**: `list-item`
+**.item_tag_name**: `list-item`
**.root_tag_name**: `root`
+
+**.override_item_tag_name**: `False`
diff --git a/rest_framework_xml/renderers.py b/rest_framework_xml/renderers.py
index 66199a9..2152746 100644
--- a/rest_framework_xml/renderers.py
+++ b/rest_framework_xml/renderers.py
@@ -5,9 +5,10 @@
from django.utils import six
from django.utils.xmlutils import SimplerXMLGenerator
-from django.utils.six.moves import StringIO
+from django.utils.six import StringIO
from django.utils.encoding import force_text
from rest_framework.renderers import BaseRenderer
+from xml.etree import ElementTree as ET
class XMLRenderer(BaseRenderer):
@@ -20,6 +21,7 @@ class XMLRenderer(BaseRenderer):
charset = 'utf-8'
item_tag_name = 'list-item'
root_tag_name = 'root'
+ override_item_tag_name = False
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -38,8 +40,24 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
xml.endElement(self.root_tag_name)
xml.endDocument()
+
+ if self.override_item_tag_name:
+ self._do_override_item_tag_name(stream)
+
return stream.getvalue()
+ def _do_override_item_tag_name(self, stream):
+ root = ET.fromstring(stream.getvalue())
+ for parent in root.findall('.//*list-item/..'):
+ child_name = parent.tag[0:-1]
+ for child in list(parent):
+ child.tag = child_name
+
+ stream.truncate(0)
+ stream.seek(0)
+ stream.write('\n')
+ stream.write(ET.tostring(root).decode('utf-8'))
+
def _to_xml(self, xml, data):
if isinstance(data, (list, tuple)):
for item in data:
diff --git a/tests/test_renderers.py b/tests/test_renderers.py
index 4168868..a9637f2 100644
--- a/tests/test_renderers.py
+++ b/tests/test_renderers.py
@@ -6,7 +6,7 @@
from django.test import TestCase
from django.test.utils import skipUnless
-from django.utils.six.moves import StringIO
+from django.utils.six import StringIO
from django.utils.translation import gettext_lazy
from rest_framework_xml.renderers import XMLRenderer
from rest_framework_xml.parsers import XMLParser
@@ -33,6 +33,41 @@ class XMLRendererTestCase(TestCase):
]
}
+ _complex_order_data = {
+ "creation_date": datetime.datetime(2017, 7, 1, 14, 30, 00),
+ "orderId": 1,
+ "positions": [
+ {
+ "posNo": 1,
+ "amount": 3,
+ "messages": [
+ {
+ "type": "O",
+ "code": "xyz"
+ },
+ {
+ "type": "L",
+ "code": "zyx"
+ }
+ ]
+ },
+ {
+ "posNo": 2,
+ "amount": 1,
+ "messages": [
+ {
+ "type": "O",
+ "code": "xyz"
+ },
+ {
+ "type": "L",
+ "code": "zyx"
+ }
+ ]
+ }
+ ]
+ }
+
def test_render_string(self):
"""
Test XML rendering.
@@ -104,6 +139,14 @@ def test_render_lazy(self):
content = renderer.render({'field': lazy}, 'application/xml')
self.assertXMLContains(content, 'hello')
+ def test_render_override_list_item(self):
+ renderer = XMLRenderer()
+ renderer.root_tag_name = 'order'
+ renderer.override_item_tag_name = True
+ content = renderer.render(self._complex_order_data, 'application/xml')
+ self.assertXMLContains(content, '', renderer.root_tag_name)
+ self.assertXMLContains(content, '', renderer.root_tag_name)
+
@skipUnless(etree, 'defusedxml not installed')
def test_render_and_parse_complex_data(self):
"""
@@ -117,7 +160,7 @@ def test_render_and_parse_complex_data(self):
error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
self.assertEqual(self._complex_data, complex_data_out, error_msg)
- def assertXMLContains(self, xml, string):
- self.assertTrue(xml.startswith('\n'))
- self.assertTrue(xml.endswith(''))
+ def assertXMLContains(self, xml, string, root_tag='root'):
+ self.assertTrue(xml.startswith('\n<{0}>'.format(root_tag)))
+ self.assertTrue(xml.endswith('{0}>'.format(root_tag)))
self.assertTrue(string in xml, '%r not in %r' % (string, xml))