diff --git a/docs/renderers.md b/docs/renderers.md index 7916053..6f4aec3 100644 --- a/docs/renderers.md +++ b/docs/renderers.md @@ -59,6 +59,8 @@ If you are considering using `XML` for your API, you may want to consider implem **.charset**: `utf-8` -**item_tag_name**: `list-item` +**.item_tag_name**: `list-item` **.root_tag_name**: `root` + +**.override_item_tag_name**: `False` diff --git a/rest_framework_xml/renderers.py b/rest_framework_xml/renderers.py index 66199a9..2152746 100644 --- a/rest_framework_xml/renderers.py +++ b/rest_framework_xml/renderers.py @@ -5,9 +5,10 @@ from django.utils import six from django.utils.xmlutils import SimplerXMLGenerator -from django.utils.six.moves import StringIO +from django.utils.six import StringIO from django.utils.encoding import force_text from rest_framework.renderers import BaseRenderer +from xml.etree import ElementTree as ET class XMLRenderer(BaseRenderer): @@ -20,6 +21,7 @@ class XMLRenderer(BaseRenderer): charset = 'utf-8' item_tag_name = 'list-item' root_tag_name = 'root' + override_item_tag_name = False def render(self, data, accepted_media_type=None, renderer_context=None): """ @@ -38,8 +40,24 @@ def render(self, data, accepted_media_type=None, renderer_context=None): xml.endElement(self.root_tag_name) xml.endDocument() + + if self.override_item_tag_name: + self._do_override_item_tag_name(stream) + return stream.getvalue() + def _do_override_item_tag_name(self, stream): + root = ET.fromstring(stream.getvalue()) + for parent in root.findall('.//*list-item/..'): + child_name = parent.tag[0:-1] + for child in list(parent): + child.tag = child_name + + stream.truncate(0) + stream.seek(0) + stream.write('\n') + stream.write(ET.tostring(root).decode('utf-8')) + def _to_xml(self, xml, data): if isinstance(data, (list, tuple)): for item in data: diff --git a/tests/test_renderers.py b/tests/test_renderers.py index 4168868..a9637f2 100644 --- a/tests/test_renderers.py +++ b/tests/test_renderers.py @@ -6,7 +6,7 @@ from django.test import TestCase from django.test.utils import skipUnless -from django.utils.six.moves import StringIO +from django.utils.six import StringIO from django.utils.translation import gettext_lazy from rest_framework_xml.renderers import XMLRenderer from rest_framework_xml.parsers import XMLParser @@ -33,6 +33,41 @@ class XMLRendererTestCase(TestCase): ] } + _complex_order_data = { + "creation_date": datetime.datetime(2017, 7, 1, 14, 30, 00), + "orderId": 1, + "positions": [ + { + "posNo": 1, + "amount": 3, + "messages": [ + { + "type": "O", + "code": "xyz" + }, + { + "type": "L", + "code": "zyx" + } + ] + }, + { + "posNo": 2, + "amount": 1, + "messages": [ + { + "type": "O", + "code": "xyz" + }, + { + "type": "L", + "code": "zyx" + } + ] + } + ] + } + def test_render_string(self): """ Test XML rendering. @@ -104,6 +139,14 @@ def test_render_lazy(self): content = renderer.render({'field': lazy}, 'application/xml') self.assertXMLContains(content, 'hello') + def test_render_override_list_item(self): + renderer = XMLRenderer() + renderer.root_tag_name = 'order' + renderer.override_item_tag_name = True + content = renderer.render(self._complex_order_data, 'application/xml') + self.assertXMLContains(content, '', renderer.root_tag_name) + self.assertXMLContains(content, '', renderer.root_tag_name) + @skipUnless(etree, 'defusedxml not installed') def test_render_and_parse_complex_data(self): """ @@ -117,7 +160,7 @@ def test_render_and_parse_complex_data(self): error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out)) self.assertEqual(self._complex_data, complex_data_out, error_msg) - def assertXMLContains(self, xml, string): - self.assertTrue(xml.startswith('\n')) - self.assertTrue(xml.endswith('')) + def assertXMLContains(self, xml, string, root_tag='root'): + self.assertTrue(xml.startswith('\n<{0}>'.format(root_tag))) + self.assertTrue(xml.endswith(''.format(root_tag))) self.assertTrue(string in xml, '%r not in %r' % (string, xml))