diff --git a/sqla_inspect/base.py b/sqla_inspect/base.py index 5d4ac7d..2c5492c 100644 --- a/sqla_inspect/base.py +++ b/sqla_inspect/base.py @@ -56,28 +56,43 @@ def get_info_field(prop): return column.info -class FormatterRegistry(dict): +class Registry(dict): """ - A registry used to store sqla columns <-> formatters association + A registry used to store sqla columns <-> datas association """ - def add_formatter(self, sqla_col_type, formatter, key_specific=None): + def add_item(self, sqla_col_type, item, key_specific=None): """ - Add a formatter to the registry - if key_specific is provided, this formatter will only be used for some - specific exports + Add an item to the registry """ if key_specific is not None: - self.setdefault(key_specific, {})[sqla_col_type] = formatter + self.setdefault(key_specific, {})[sqla_col_type] = item else: - self[sqla_col_type] = formatter + self[sqla_col_type] = item - def get_formatter(self, sqla_col, key_specific=None): - formatter = None + def get_item(self, sqla_col, key_specific=None): + item = None if key_specific is not None: - formatter = self.get(key_specific, {}).get(sqla_col.__class__) + item = self.get(key_specific, {}).get(sqla_col.__class__) - if formatter is None: - formatter = self.get(sqla_col.__class__) + if item is None: + item = self.get(sqla_col.__class__) - return formatter + return item +class FormatterRegistry(Registry): + """ + Registry specific to formatters + """ + def add_formatter(self, sqla_col_type, formatter, key_specific=None): + """ + Add a formatter to the registry + if key_specific is provided, this formatter will only be used for some + specific exports + """ + self.add_item(sqla_col_type, formatter, key_specific) + + def get_formatter(self, sqla_col, key_specific=None): + """ + Returns a formatter stored in the registry + """ + return self.get_item(sqla_col, key_specific) diff --git a/sqla_inspect/excel.py b/sqla_inspect/excel.py index 3c2dd25..1485d2d 100644 --- a/sqla_inspect/excel.py +++ b/sqla_inspect/excel.py @@ -22,16 +22,21 @@ BaseExporter, SqlaExporter, ) +from sqla_inspect.base import Registry log = logging.getLogger(__name__) + # A, B, C, ..., AA, AB, AC, ..., ZZ ASCII_UPPERCASE = list(ascii_uppercase) + list( ''.join(duple) for duple in itertools.combinations_with_replacement(ascii_uppercase, 2) ) +# To be overriden by end user +FORMAT_REGISTRY = Registry() + class XlsWriter(object): """ @@ -51,10 +56,15 @@ class XlsWriter(object): """ title = u"Export" - def __init__(self, guess_types=True): - self.book = openpyxl.workbook.Workbook(guess_types=guess_types) - self.worksheet = self.book.active - self.worksheet.title = self.title + + def __init__(self, guess_types=True, worksheet=None): + if worksheet is None: + self.book = openpyxl.workbook.Workbook(guess_types=guess_types) + self.worksheet = self.book.active + self.worksheet.title = self.title + else: + self.worksheet = worksheet + self.book = worksheet.parent def save_book(self, f_buf=None): """ @@ -96,6 +106,13 @@ def format_row(self, row): res.append(value) return res + def _populate(self): + """ + Populate headers and rows before writing our book + """ + self._render_headers() + self._render_rows() + def render(self, f_buf=None): """ Definitely render the workbook @@ -103,22 +120,28 @@ def render(self, f_buf=None): :param obj f_buf: A file buffer supporting the write and seek methods """ - self._render_headers() - self._render_rows() + self._populate() return self.save_book(f_buf) - def _render_rows(self): """ Render the rows in the current stylesheet """ _datas = getattr(self, '_datas', ()) + headers = getattr(self, 'headers', ()) for index, row in enumerate(_datas): row_number = index + 2 for col_num, value in enumerate(row): cell = self.worksheet.cell(row=row_number, column=col_num + 1) - cell.value = value + if value is not None: + cell.value = value + else: + cell.value = "" + header = headers[col_num] + format = get_cell_format(header) + if format is not None: + cell.number_format = format def _render_headers(self): """ @@ -131,6 +154,24 @@ def _render_headers(self): cell.value = col['label'] +def get_cell_format(column_dict, key=None): + """ + Return the cell format for the given column + + :param column_dict: The column datas collected during inspection + :param key: The exportation key + """ + format = column_dict.get('format') + prop = column_dict['__col__'] + + if format is None: + if hasattr(prop, 'columns'): + sqla_column = prop.columns[0] + column_type = getattr(sqla_column.type, 'impl', sqla_column.type) + format = FORMAT_REGISTRY.get_item(column_type) + return format + + class SqlaXlsExporter(XlsWriter, SqlaExporter): """ Main class used for exporting datas to the xls format @@ -170,11 +211,66 @@ class SqlaXlsExporter(XlsWriter, SqlaExporter): a.render() """ config_key = 'excel' - def __init__(self, model, guess_types=True): - XlsWriter.__init__(self, guess_types) + + def __init__(self, model, guess_types=True, worksheet=None): + self.guess_types = guess_types + self.is_root = worksheet is None + XlsWriter.__init__(self, guess_types, worksheet) SqlaExporter.__init__(self, model) + def _get_related_exporter(self, related_obj, column): + """ + returns an SqlaXlsExporter for the given related object and stores it in + the column object as a cache + """ + result = column.get('sqla_xls_exporter') + if result is None: + worksheet = self.book.create_sheet( + title=column.get('label', 'default title') + ) + result = column['sqla_xls_exporter'] = SqlaXlsExporter( + related_obj.__class__, + worksheet=worksheet + ) + return result + + def _get_relationship_cell_val(self, obj, column): + """ + Return the value to insert in a relationship cell + Handle the case of complex related datas we want to handle + """ + val = SqlaExporter._get_relationship_cell_val(self, obj, column) + if val == "": + related_key = column.get('related_key', None) + + if column['__col__'].uselist and related_key is None and self.is_root: + + # on récupère les objets liés + key = column['key'] + related_objects = getattr(obj, key, None) + if not related_objects: + return "" + else: + exporter = self._get_related_exporter( + related_objects[0], + column, + ) + for rel_obj in related_objects: + exporter.add_row(rel_obj) + + return val + + def _populate(self): + """ + Enhance the default populate script by handling related elements + """ + XlsWriter._populate(self) + for header in self.headers: + if "sqla_xls_exporter" in header: + header['sqla_xls_exporter']._populate() + + class XlsExporter(XlsWriter, BaseExporter): """ A main xls exportation tool (without sqlalchemy support) diff --git a/sqla_inspect/export.py b/sqla_inspect/export.py index e1dbb38..538847c 100644 --- a/sqla_inspect/export.py +++ b/sqla_inspect/export.py @@ -15,7 +15,7 @@ BaseSqlaInspector, FormatterRegistry, ) -BLACKLISTED_KEYS = () +BLACKLISTED_KEYS = [] # Should be completed (to see how this will be done) @@ -216,7 +216,8 @@ def _collect_relationship(self, main_infos, prop, result): # Maybe with indexes ? ( to see: on row add, append headers on the fly # if needed ) if prop.uselist: - main_infos = {} + # One to many + pass else: if "related_key" in main_infos: self._merge_many_to_one_field(main_infos, prop, result) @@ -256,12 +257,13 @@ def _merge_many_to_one_field(self, main_infos, prop, result): # We first find the related foreignkey to get the good title rel_base = list(prop.local_columns)[0] related_fkey_name = rel_base.name - for val in result: - if val['name'] == related_fkey_name: - title = val['label'] - main_infos['label'] = title - result.remove(val) - break + if not main_infos.get('keep_key', False): + for val in result: + if val['name'] == related_fkey_name: + title = val['label'] + main_infos['label'] = title + result.remove(val) + break return main_infos @@ -317,8 +319,9 @@ def _get_relationship_cell_val(self, obj, column): """ Return the value to insert in a relationship cell """ + val = "" key = column['key'] - related_key = column.get('related_key') + related_key = column.get('related_key', None) related_obj = getattr(obj, key, None) @@ -326,12 +329,13 @@ def _get_relationship_cell_val(self, obj, column): return "" if column['__col__'].uselist: # OneToMany - _vals = [] - for rel_obj in related_obj: - _vals.append( - self._get_formatted_val(rel_obj, related_key, column) - ) - val = '\n'.join(_vals) + if related_key is not None: + _vals = [] + for rel_obj in related_obj: + _vals.append( + self._get_formatted_val(rel_obj, related_key, column) + ) + val = '\n'.join(_vals) else: if related_key is not None: val = self._get_formatted_val(related_obj, related_key, column)