From fa51a6c00e6e9f5ff8e392eccb2064970bc8ddc2 Mon Sep 17 00:00:00 2001 From: Matteo <43417658+Matteo-Omenetti@users.noreply.github.com> Date: Tue, 21 Jan 2025 15:56:59 +0100 Subject: [PATCH] feat: Code equation model (#71) Signed-off-by: Matteo Omenetti Signed-off-by: Christoph Auer Co-authored-by: Christoph Auer --- demo/demo_code_formula_predictor.py | 111 ++++ .../code_formula_predictor.py | 223 ++++++++ .../code_formula_model/models/sam.py | 514 ++++++++++++++++++ .../code_formula_model/models/sam_opt.py | 237 ++++++++ .../models/sam_opt_image_processor.py | 31 ++ tests/test_code_formula_predictor.py | 131 +++++ tests/test_data/code_formula/gt/code.txt | 27 + tests/test_data/code_formula/gt/formula.txt | 1 + tests/test_data/code_formula/images/code.png | Bin 0 -> 27000 bytes .../test_data/code_formula/images/formula.png | Bin 0 -> 4611 bytes 10 files changed, 1275 insertions(+) create mode 100644 demo/demo_code_formula_predictor.py create mode 100644 docling_ibm_models/code_formula_model/code_formula_predictor.py create mode 100644 docling_ibm_models/code_formula_model/models/sam.py create mode 100644 docling_ibm_models/code_formula_model/models/sam_opt.py create mode 100644 docling_ibm_models/code_formula_model/models/sam_opt_image_processor.py create mode 100644 tests/test_code_formula_predictor.py create mode 100644 tests/test_data/code_formula/gt/code.txt create mode 100644 tests/test_data/code_formula/gt/formula.txt create mode 100644 tests/test_data/code_formula/images/code.png create mode 100644 tests/test_data/code_formula/images/formula.png diff --git a/demo/demo_code_formula_predictor.py b/demo/demo_code_formula_predictor.py new file mode 100644 index 0000000..b6b6579 --- /dev/null +++ b/demo/demo_code_formula_predictor.py @@ -0,0 +1,111 @@ +# +# Copyright IBM Corp. 
2024 - 2024 +# SPDX-License-Identifier: MIT +# +import argparse +import logging +import os +import sys +import time +from pathlib import Path + +from huggingface_hub import snapshot_download +from PIL import Image + +from docling_ibm_models.code_formula_model.code_formula_predictor import CodeFormulaPredictor + + +def demo( + logger: logging.Logger, + artifact_path: str, + device: str, + num_threads: int, + image_dir: str, + viz_dir: str, +): + r""" + Apply LayoutPredictor on the input image directory + + If you want to load from PDF: + pdf_image = pyvips.Image.new_from_file("test_data/ADS.2007.page_123.pdf", page=0) + """ + # Create the layout predictor + code_formula_predictor = CodeFormulaPredictor(artifact_path, device=device, num_threads=num_threads) + + image_dir = Path(image_dir) + images = [] + image_names = os.listdir(image_dir) + image_names.sort() + for image_name in image_names: + image = Image.open(image_dir / image_name) + images.append(image) + + t0 = time.perf_counter() + outputs = code_formula_predictor.predict(images, ['code', 'formula'], temperature=0) + total_ms = 1000 * (time.perf_counter() - t0) + avg_ms = (total_ms / len(image_names)) if len(image_names) > 0 else 0 + logger.info( + "For {} images(ms): [total|avg] = [{:.1f}|{:.1f}]".format( + len(image_names), total_ms, avg_ms + ) + ) + + for i, output in enumerate(outputs): + logger.info(f"\nOutput {i}:\n{output}\n\n") + + +def main(args): + num_threads = int(args.num_threads) if args.num_threads is not None else None + device = args.device.lower() + image_dir = args.image_dir + viz_dir = args.viz_dir + + # Initialize logger + logging.basicConfig(level=logging.DEBUG) + logger = logging.getLogger("CodeFormulaPredictor") + logger.setLevel(logging.DEBUG) + if not logger.hasHandlers(): + handler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter( + "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + + # Ensure the viz dir + Path(viz_dir).mkdir(parents=True, exist_ok=True) + + # Download models from HF + download_path = snapshot_download(repo_id="ds4sd/CodeFormula", revision="v1.0.0") + + # Test the Code+Equation model + demo(logger, download_path, device, num_threads, image_dir, viz_dir) + + +if __name__ == "__main__": + r""" + python -m demo.demo_code_formula_predictor -i + """ + parser = argparse.ArgumentParser(description="Test the CodeFormulaPredictor") + parser.add_argument( + "-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]" + ) + parser.add_argument( + "-n", "--num_threads", required=False, default=4, help="Number of threads" + ) + parser.add_argument( + "-i", + "--image_dir", + required=True, + help="PNG images input directory", + ) + parser.add_argument( + "-v", + "--viz_dir", + required=False, + default="viz/", + help="Directory to save prediction visualizations", + ) + + args = parser.parse_args() + main(args) diff --git a/docling_ibm_models/code_formula_model/code_formula_predictor.py b/docling_ibm_models/code_formula_model/code_formula_predictor.py new file mode 100644 index 0000000..52849a7 --- /dev/null +++ b/docling_ibm_models/code_formula_model/code_formula_predictor.py @@ -0,0 +1,223 @@ +# +# Copyright IBM Corp. 
2024 - 2024 +# SPDX-License-Identifier: MIT +# +import logging +from typing import List, Union + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer + +from docling_ibm_models.code_formula_model.models.sam_opt import SamOPTForCausalLM +from docling_ibm_models.code_formula_model.models.sam_opt_image_processor import ( + SamOptImageProcessor, +) + +_log = logging.getLogger(__name__) + + +class CodeFormulaPredictor: + """ + Code and Formula Predictor using a multi-modal vision-language model. + + This class enables the prediction of code or LaTeX representations + from input images of code snippets or mathematical formulas. + + Attributes + ---------- + _device : str + The device on which the model is loaded (e.g., 'cpu' or 'cuda'). + _num_threads : int + Number of threads used for inference when running on CPU. + _tokenizer : transformers.PreTrainedTokenizer + Tokenizer for processing textual inputs to the model. + _model : transformers.PreTrainedModel + Pretrained multi-modal vision-language model. + _image_processor : transformers.ImageProcessor + Processor for normalizing and preparing input images. + _temperature : float + Sampling temperature for generation; controls randomness in predictions. + """ + + def __init__( + self, + artifacts_path: str, + device: str = "cpu", + num_threads: int = 4, + ): + """ + Initializes the CodeFormulaPredictor with the specified model artifacts. + + Parameters + ---------- + artifacts_path : str + Path to the directory containing the pretrained model files. + device : str, optional + Device to run the inference on ('cpu' or 'cuda'), by default "cpu". + num_threads : int, optional + Number of threads for CPU inference, by default 4. + """ + self._device = device + self._num_threads = num_threads + if device == "cpu": + torch.set_num_threads(self._num_threads) + + self._tokenizer = AutoTokenizer.from_pretrained( + artifacts_path, use_fast=True, padding_side="left" + ) + self._model = SamOPTForCausalLM.from_pretrained(artifacts_path).to(self._device) + self._model.eval() + + self._image_processor = SamOptImageProcessor.from_pretrained(artifacts_path) + + _log.debug("CodeFormulaModel settings: {}".format(self.info())) + + def info(self) -> dict: + """ + Retrieves configuration details of the CodeFormulaPredictor instance. + + Returns + ------- + dict + A dictionary containing configuration details such as the device and + the number of threads used. + """ + info = { + "device": self._device, + "num_threads": self._num_threads, + } + return info + + def _get_prompt(self, label: str) -> str: + """ + Constructs the prompt for the model based on the input label. + + Parameters + ---------- + label : str + The type of input, either 'code' or 'formula'. + + Returns + ------- + str + The constructed prompt including necessary tokens and query. + + Raises + ------ + NotImplementedError + If the label is not 'code' or 'formula'. + """ + if label == "code": + query = "" + elif label == "formula": + query = "" + else: + raise NotImplementedError("Label must be either code or formula") + + prompt = ( + "A chat between a curious user and an artificial intelligence" + " assistant. The assistant gives helpful, detailed, and polite answers to" + " the user's questions. 
USER:" + ) + prompt += ( + "" + "" * 256 + "" + "\n" + " ASSISTANT:" + "\n" + query + ) + + return prompt + + @torch.inference_mode() + def predict( + self, + images: List[Union[Image.Image, np.ndarray]], + labels: List[str], + temperature: float = 0.1, + ) -> List[str]: + """ + Predicts the textual representation of input images (code or LaTeX). + + Parameters + ---------- + images : List[Union[Image.Image, np.ndarray]] + List of images to be processed, provided as PIL Image objects or numpy arrays. + labels : List[str] + List of labels indicating the type of each image ('code' or 'formula'). + temperature : float, optional + Sampling temperature for generation, by default set to 0.1. + + Returns + ------- + List[str] + List of predicted textual outputs for each input image in the given input + order. + + Raises + ------ + TypeError + If any of the input images is not of a supported type (PIL Image or numpy array). + Excpetion + In case the temperature is an invalid number. + """ + if (type(temperature) != float and type(temperature) != int) or temperature < 0: + raise Exception("Temperature must be a number greater or equal to 0.") + + do_sample = True + if temperature == 0: + do_sample = False + temperature = None + + if len(labels) != len(images): + raise Exception( + "The number of images must be the same as the number of labels." + ) + + images_tmp = [] + for image in images: + if isinstance(image, Image.Image): + image = image.convert("RGB") + elif isinstance(image, np.ndarray): + image = Image.fromarray(image).convert("RGB") + else: + raise TypeError("Not supported input image format") + images_tmp.append(image) + images = images_tmp + + images_tensor = torch.stack([self._image_processor(img) for img in images]).to( + self._device + ) + + prompts = [self._get_prompt(label) for label in labels] + + tokenized = self._tokenizer(prompts, padding=True, return_tensors="pt") + tokenized = {k: v.to(self._device) for k, v in tokenized.items()} + + prompt_ids = tokenized["input_ids"] + attention_mask = tokenized["attention_mask"] + + if self._device == "cpu": + output_ids_list = self._model.generate( + input_ids=prompt_ids, + attention_mask=attention_mask, + images=images_tensor, + do_sample=do_sample, + temperature=temperature, + max_new_tokens=4096 - prompt_ids.shape[1], + use_cache=True, + ) + else: + with torch.autocast(device_type=self._device, dtype=torch.bfloat16): + output_ids_list = self._model.generate( + prompt_ids, + images=images_tensor, + do_sample=do_sample, + temperature=temperature, + max_new_tokens=4096 - prompt_ids.shape[1], + use_cache=True, + ) + + outputs = self._tokenizer.batch_decode( + output_ids_list[:, prompt_ids.shape[1] :], skip_special_tokens=True + ) + + return outputs diff --git a/docling_ibm_models/code_formula_model/models/sam.py b/docling_ibm_models/code_formula_model/models/sam.py new file mode 100644 index 0000000..98179cf --- /dev/null +++ b/docling_ibm_models/code_formula_model/models/sam.py @@ -0,0 +1,514 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This file was originally developed by Meta Platforms, Inc. as part of +# the Segment Anything project (https://github.com/facebookresearch/segment-anything). +# It has been adapted by contributors from the Vary-toy project +# (https://github.com/Ucas-HaoranWei/Vary-toy). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import partial +from typing import Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. 
+ """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + self.pos_embed = nn.Parameter( + torch.zeros( + 1, img_size // patch_size, img_size // patch_size, embed_dim + ) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) + self.net_3 = nn.Conv2d( + 512, 1024, kernel_size=3, stride=2, padding=1, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + x = self.net_2(x) + x = self.net_3(x) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock( + embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer + ) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos( + attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W) + ) + + attn = attn.softmax(dim=-1) + x = ( + (attn @ v) + .view(B, self.num_heads, H, W, -1) + .permute(0, 2, 3, 1, 4) + .reshape(B, H, W, -1) + ) + x = self.proj(x) + + return x + + +def window_partition( + x: torch.Tensor, window_size: int +) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, + window_size: int, + pad_hw: Tuple[int, int], + hw: Tuple[int, int], +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + + rel_h[:, :, :, :, None] + + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +def build_sam_vit_b(checkpoint=None, image_size=1024): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + image_size=image_size, + ) + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, + image_size=1024, +): + prompt_embed_dim = 256 + vit_patch_size = 16 + image_encoder = ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + + if checkpoint is not None: + # with open(checkpoint, "rb") as f: + state_dict = torch.load(checkpoint) + + image_encoder.load_state_dict(state_dict, strict=True) + return image_encoder diff --git a/docling_ibm_models/code_formula_model/models/sam_opt.py b/docling_ibm_models/code_formula_model/models/sam_opt.py new file mode 100644 index 0000000..c682f45 --- /dev/null +++ b/docling_ibm_models/code_formula_model/models/sam_opt.py @@ -0,0 +1,237 @@ +# Copyright 2023 Haotian Liu +# +# This file is part of the Vary project, originally located at: +# https://github.com/Ucas-HaoranWei/Vary-toy/blob/main/Vary-master/vary/model/vary_opt.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + OPTConfig, + OPTForCausalLM, + OPTModel, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) + +from docling_ibm_models.code_formula_model.models.sam import build_sam_vit_b + + +class SamOptConfig(OPTConfig): + model_type = "sam_opt" + + def __init__( + self, + sam_image_size=1024, + sam_mm_projector_in=1024, + sam_mm_projector_out=768, + **kwargs, + ): + super().__init__(**kwargs) + self.sam_image_size = sam_image_size + self.sam_mm_projector_in = sam_mm_projector_in + self.sam_mm_projector_out = sam_mm_projector_out + + +class SamOPTModel(OPTModel): + config_class = SamOptConfig + + def __init__(self, config: OPTConfig): + super(SamOPTModel, self).__init__(config) + self.vision_tower = build_sam_vit_b(image_size=config.sam_image_size) + + self.mm_projector = nn.Linear( + config.sam_mm_projector_in, config.sam_mm_projector_out + ) + + def embed_tokens(self, x): + return self.get_input_embeddings()(x) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: torch.FloatTensor = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + vision_tower = getattr(self, "vision_tower", None) + im_start_token = getattr(self.config, "im_start_token", -1) + + if input_ids.shape[1] != 1 or self.training: + with torch.set_grad_enabled(self.training): + image_features = vision_tower(images) + image_features = image_features.flatten(2).permute(0, 2, 1) + image_features = self.mm_projector(image_features) + + new_input_embeds = [] + for cur_input_ids, cur_input_embeds, cur_image_features in zip( + input_ids, inputs_embeds, image_features + ): + image_start_token_position = torch.where( + cur_input_ids == im_start_token + )[0].item() + + cur_image_features = cur_image_features.to( + device=cur_input_embeds.device + ) + num_patches = cur_image_features.shape[0] + cur_input_embeds = torch.cat( + ( + cur_input_embeds[: image_start_token_position + 1], + cur_image_features, + cur_input_embeds[ + image_start_token_position + num_patches + 1 : + ], + ), + dim=0, + ) + + new_input_embeds.append(cur_input_embeds) + + inputs_embeds = torch.stack(new_input_embeds, dim=0) + + return super(SamOPTModel, self).forward( + input_ids=None, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SamOPTForCausalLM(OPTForCausalLM): + config_class = SamOptConfig + + def __init__(self, config): + super(OPTForCausalLM, self).__init__(config) + self.model = SamOPTModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + 
token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + outputs = self.model( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + images=images, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states).contiguous() + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs + ): + token_type_ids = kwargs.get("token_type_ids", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + "images": kwargs.get("images", None), + } + ) + return model_inputs + + +AutoConfig.register("sam_opt", SamOptConfig) +AutoModelForCausalLM.register(SamOptConfig, SamOPTForCausalLM) diff --git a/docling_ibm_models/code_formula_model/models/sam_opt_image_processor.py b/docling_ibm_models/code_formula_model/models/sam_opt_image_processor.py new file mode 100644 index 0000000..75459aa --- /dev/null +++ b/docling_ibm_models/code_formula_model/models/sam_opt_image_processor.py @@ -0,0 +1,31 @@ +# +# Copyright IBM Corp. 
2024 - 2024 +# SPDX-License-Identifier: MIT +# +from PIL import Image +from torchvision.transforms import functional as F +from transformers import AutoImageProcessor +from transformers.image_processing_utils import ImageProcessingMixin + + +class SamOptImageProcessor(ImageProcessingMixin): + + def __init__(self, size=(1024, 1024), mean=None, std=None, **kwargs): + super().__init__(**kwargs) + self.size = size + self.mean = mean + self.std = std + + def __call__(self, image): + if not isinstance(image, Image.Image): + raise ValueError("Input must be a PIL Image") + + image = F.resize(image, self.size) + image = F.to_tensor(image) + + image = F.normalize(image, mean=self.mean, std=self.std) + + return image + + +AutoImageProcessor.register(SamOptImageProcessor, SamOptImageProcessor) diff --git a/tests/test_code_formula_predictor.py b/tests/test_code_formula_predictor.py new file mode 100644 index 0000000..52614f4 --- /dev/null +++ b/tests/test_code_formula_predictor.py @@ -0,0 +1,131 @@ +# +# Copyright IBM Corp. 2024 - 2024 +# SPDX-License-Identifier: MIT +# +import os +import numpy as np +import pytest +from PIL import Image + +from docling_ibm_models.code_formula_model.code_formula_predictor import CodeFormulaPredictor + +from huggingface_hub import snapshot_download + +@pytest.fixture(scope="module") +def init() -> dict: + r""" + Initialize the testing environment + """ + init = { + "num_threads": 1, + "test_imgs": [ + { + "label": "code", + "image_path": "tests/test_data/code_formula/images/code.png", + "gt_path": "tests/test_data/code_formula/gt/code.txt", + }, + { + "label": "formula", + "image_path": "tests/test_data/code_formula/images/formula.png", + "gt_path": "tests/test_data/code_formula/gt/formula.txt", + }, + ], + "info": { + "device": "auto", + "temperature": 0, + }, + } + + # Download models from HF + artifact_path = snapshot_download(repo_id="ds4sd/CodeFormula", revision="v1.0.0") + + init["artifact_path"] = artifact_path + + return init + + +def test_code_formula_predictor(init: dict): + r""" + Unit test for the CodeFormulaPredictor + """ + device = "cpu" + num_threads = 2 + + # Initialize LayoutPredictor + code_formula_predictor = CodeFormulaPredictor( + init["artifact_path"], device=device, num_threads=num_threads + ) + + # Check info + info = code_formula_predictor.info() + assert info["device"] == device, "Wronly set device" + assert info["num_threads"] == num_threads, "Wronly set number of threads" + + # Unsupported input image + is_exception = False + try: + for _ in code_formula_predictor.predict(["wrong"], ['label']): + pass + except TypeError: + is_exception = True + assert is_exception + + # wrong type for temperature + is_exception = False + try: + dummy_image = Image.new(mode="RGB", size=(100, 100), color=(255, 255, 255)) + for _ in code_formula_predictor.predict([dummy_image], ['label'], "0.1"): + pass + except Exception: + is_exception = True + assert is_exception + + # wrong value for temperature + is_exception = False + try: + dummy_image = Image.new(mode="RGB", size=(100, 100), color=(255, 255, 255)) + for _ in code_formula_predictor.predict([dummy_image], ['label'], -0.1): + pass + except Exception: + is_exception = True + assert is_exception + + # mistmatched number of images and labels + is_exception = False + try: + dummy_image = Image.new(mode="RGB", size=(100, 100), color=(255, 255, 255)) + for _ in code_formula_predictor.predict([dummy_image], ['label', 'label']): + pass + except Exception: + is_exception = True + assert is_exception + + # Predict 
on test images, not batched + temperature = init['info']['temperature'] + for d in init["test_imgs"]: + label = d['label'] + img_path = d['image_path'] + gt_path = d['gt_path'] + + with Image.open(img_path) as img, open(gt_path, 'r') as gt_fp: + gt = gt_fp.read() + + output = code_formula_predictor.predict([img], [label], temperature) + output = output[0] + + assert output == gt + + # Load images as numpy arrays + np_arr = np.asarray(img) + output = code_formula_predictor.predict([np_arr], [label], temperature) + output = output[0] + + assert output == gt + + # Predict on test images, batched + labels = [d['label'] for d in init["test_imgs"]] + images = [Image.open(d['image_path']) for d in init["test_imgs"]] + gts = [open(d['gt_path'], 'r').read() for d in init["test_imgs"]] + + outputs = code_formula_predictor.predict(images, labels, temperature) + assert outputs == gts diff --git a/tests/test_data/code_formula/gt/code.txt b/tests/test_data/code_formula/gt/code.txt new file mode 100644 index 0000000..9de7750 --- /dev/null +++ b/tests/test_data/code_formula/gt/code.txt @@ -0,0 +1,27 @@ +<_C++_> #include +using namespace std; + +int main(){ + + int n; + + while(cin>>n, n){ + int cnt=0; + + n=1000-n; + cnt+=n/500; + n%=500; + cnt+=n/100; + n%=100; + cnt+=n/50; + n%=50; + cnt+=n/10; + n%=10; + cnt+=n/5; + n%=5; + + cout<(85Bn;I5%?cY+fn9QaG2XvxENaxQ#oBxBT64`gzxmDe>)Wq20FIJ8NFD$L0sz3L7vR?- zKo)?4goKQQh=PoajEahahCzUdfsT$riieL)KuJbTMM*|MK|{~ML_^C)M?t|X_?(T4 zn~#r=nn_q(h)0ZrmyhSSL4c^Js2J!NB$${aJkKbe@%%qNesuzHQQ%bI{NaJL061JA zJTCB8FM#^#J`sVxAAtWlfN<~#h)BpNsA%X<9U5@}a6ou?I0Se^Ly0RK0!p8oz#uz!*Z_em}|1O#{l?8lY zr(Re1=J*S+JB59JE!iE+bb)_Ad^@?rbdz>bCedU23m}mYv|9T}5EwE37)bLA@Wt|h z87_!hvi;AI1pYEoeSDqKFTijA0F+G6TX{7W|9e;IYjVlk%dp3v@u!7HfAp8azm5F4 z{jm4@HD{>J>wf|A6eYiZ`vpMz>!ttoNPo!iPa~QAb)>F8h4`nD(Eq%r|1*-r|9MV- z>cszk+NUaJTON_Kwf>fUfwXAkK zahdS}YB|U)fO2dxw(^q@3@-eUZsK@M9CTHkdNM~oK!!R!60?>=)QQ-Z$|MihU7+L1 z?tANmd7~hXOgi{o+w&Iy_W<8?@vE1;h z63;Lh*_qMJ!syR-EG!oH3 zUOup;f*X7zNambDFL3s=`bPrcx3|V{ASO~VEacDj%X@P^6S5Q=FSma7gK{X`!Y-a; z4jhj2!}+=>tBSMfk`kdTdYY--7rN*KE~b_5DZG@`>WFd6K_tlZ`2~1)ad{sjE@+Si zI;}e19&urtSZQY^%_zQ&`AC37lhY!R(eHR>M21u=Q8AC|;b_T~LMu&P(QM}L0@yMO z`{6$`w?aM5?&!>B0cBkAW;<)IUQ3>N=v~{*j(fN&Z}`bLt-Rutsyb!qG0ZwJf-Kz< zDMTtn`W^I~sdA#C(N2k{&c{VzyYvA+H`;K9U!7DUMSNiLV?&y!?n(Z5S;a2^SuOdC zHy_4ep>34a@F|)*bg?lJ|DK#21|vom&FmmgUz1Hu*pH@#@dIn9Z^^R0*K_ zahw#S2HtS!^2PB4Um_%_ya(Lne&b^}Q(ej>2-9hpzFFeiSy6PKLRqk~+oGXOO@NCH zE~GbG)BXGyR(jxSBe4~kM*Im_BOP%ulOBN2X&I&?V>%?FyM~Y+f{cp3Y zSj&fkxfBuLzS@9&cFR=S@J~LB(r$4)9@8rgYo%{_g&0Ii%OpG1 z89biy;Q}?GCHq|;&_#5R&9Be4!rjkN?iM{E)ndx9k zo4`6jtoF^~$p&<4?8U4=(c|an6i1mS2HjVZoav*`Bv2o(rtG>C)#^*Kw*uAmY*iwq zc}ySb1l%{+7vU-L(mDJj8i3IOo1$5q_*L->>&V?oHsv)7Qq{6LBh3Ic;2>~Dax7OL2Z z=IoLvs)l)kV~)t3J9&KKc>6>$L7zG7MLd&}qB}B8uGsT%>?NHAY#wo4sfW2!dytN> z=b<$YPu9m^x?xVPy*&c2^ApxLp{qM@2>LRyesi+6UDZ*omk+{ZYKxh7ug3ZpZ&&qH zaku6>kDQHHJSK1KR=~crV8zYmIB@a1QIoNLghD(dw?rtZmPFTghEFkblHge}G#NLW!@5z~z!;0T zA?k@|WjcGRK&s*bf#SAaVw@tv&1Dtm$uh<3N|?c3HIc|zUO6$5kDHnX+#qBRcYHV=SCPX6{0tQhNUxn(R9m`0z3P}_cW*XXg-jfNYe&(BXvHkga^ve&B zr%iSzKM6*+I;YK}k1a|6{B6AE%aubB@9e#_v_zQ^9>FG|D~YhBGvkpC>N(Qoxs`tJ z{itNurSBQ^q_^Vu)F}pY?SfuD<*HuJc-GA!=1IaFMZt`M9_R4wl>IDjS37BZ$De-2 zQD>MMf)@u?KhHpp2vtYBKPHO;x0*!Ouv8TU+KRg@F^{0N3=OLXTQ=R$O#}bxct0Nd zVJ8ec_6EoJtVpvrQAqtNpC~B8vd^w6L^hV}J4mOGcQcok8I~Ac_raaJzS5!O+0R~AzW~9WlLvEn-;`fNHM8_XI2ljyDLwY1j40*DHRQpdId%4? 
[... remaining base85 lines of the binary blob for tests/test_data/code_formula/images/code.png (27000 bytes) omitted ...]

diff --git a/tests/test_data/code_formula/images/formula.png b/tests/test_data/code_formula/images/formula.png
new file mode 100644
index 0000000000000000000000000000000000000000..02240961d0c694b4e48f6a5fc08d8b96fff9d950
GIT binary patch
literal 4611

[... base85 lines of the binary blob for tests/test_data/code_formula/images/formula.png (4611 bytes) omitted ...]

literal 0
HcmV?d00001