Skip to content

Commit

Permalink
[WAYANG-#8] add PyFlatMapOperator and correction of types capture
Browse files Browse the repository at this point in the history
Signed-off-by: bertty <[email protected]>
  • Loading branch information
Bertty Contreras-Rojas authored and berttty committed Apr 8, 2022
1 parent 8b5048f commit d60c170
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 62 deletions.
15 changes: 13 additions & 2 deletions python/src/pywy/dataquanta.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,19 @@ def map(self: "DataQuanta[In]", f: Function) -> "DataQuanta[Out]":
def flatmap(self: "DataQuanta[In]", f: FlatmapFunction) -> "DataQuanta[IterableOut]":
return DataQuanta(self.context, self._connect(FlatmapOperator(f)))

def store_textfile(self: "DataQuanta[In]", path: str):
last: List[SinkOperator] = [cast(SinkOperator, self._connect(TextFileSink(path, self.operator.outputSlot[0])))]
def store_textfile(self: "DataQuanta[In]", path: str, end_line: str = None):
last: List[SinkOperator] = [
cast(
SinkOperator,
self._connect(
TextFileSink(
path,
self.operator.outputSlot[0],
end_line
)
)
)
]
plan = PywyPlan(self.context.plugins, last)

plug = self.context.plugins.pop()
Expand Down
8 changes: 6 additions & 2 deletions python/src/pywy/operators/sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ def __repr__(self):


class TextFileSink(SinkUnaryOperator):

path: str
end_line: str

def __init__(self, path: str, input_type: GenericTco):
def __init__(self, path: str, input_type: GenericTco, end_line: str = None):
super().__init__('TextFile', input_type)
self.path = path
if input_type != str and end_line is None:
self.end_line = '\n'
else:
self.end_line = end_line

def __str__(self):
return super().__str__()
Expand Down
8 changes: 0 additions & 8 deletions python/src/pywy/operators/unary.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from itertools import chain
from pywy.operators.base import PywyOperator
from pywy.types import (
GenericTco,
Expand Down Expand Up @@ -68,13 +67,6 @@ def __init__(self, fm_function: FlatmapFunction):
super().__init__("Flatmap", types[0], types[1])
self.fm_function = fm_function

# TODO remove wrapper
def getWrapper(self):
udf = self.fm_function
def func(iterator):
return chain.from_iterable(map(udf, iterator))
return func

def __str__(self):
return super().__str__()

Expand Down
8 changes: 8 additions & 0 deletions python/src/pywy/platforms/python/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pywy.core import ChannelDescriptor
from pywy.core import Executor
from pywy.core import PywyPlan
from pywy.operators import TextFileSource
from pywy.platforms.python.channels import PY_ITERATOR_CHANNEL_DESCRIPTOR
from pywy.platforms.python.operator.py_execution_operator import PyExecutionOperator

Expand All @@ -17,6 +18,7 @@ def execute(self, plan):

# TODO get this information by a configuration and ideally by the context
descriptor_default: ChannelDescriptor = PY_ITERATOR_CHANNEL_DESCRIPTOR
files_pool = []

def execute(op_current: NodeOperator, op_next: NodeOperator):
if op_current is None:
Expand Down Expand Up @@ -66,4 +68,10 @@ def execute(op_current: NodeOperator, op_next: NodeOperator):

py_next.inputChannel = py_current.outputChannel

if isinstance(py_current, TextFileSource):
files_pool.append(py_current.outputChannel[0].provide_iterable())

graph.traversal(graph.starting_nodes, execute)
# close the files used during the execution
for f in files_pool:
f.close()
1 change: 1 addition & 0 deletions python/src/pywy/platforms/python/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
PYWY_OPERATOR_MAPPINGS.add_mapping(PyTextFileSourceOperator())
PYWY_OPERATOR_MAPPINGS.add_mapping(PyTextFileSinkOperator())
PYWY_OPERATOR_MAPPINGS.add_mapping(PyMapOperator())
PYWY_OPERATOR_MAPPINGS.add_mapping(PyFlatmapOperator())

2 changes: 2 additions & 0 deletions python/src/pywy/platforms/python/operator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pywy.platforms.python.operator.py_execution_operator import PyExecutionOperator
from pywy.platforms.python.operator.py_unary_filter import PyFilterOperator
from pywy.platforms.python.operator.py_unary_map import PyMapOperator
from pywy.platforms.python.operator.py_unary_flatmap import PyFlatmapOperator
from pywy.platforms.python.operator.py_source_textfile import PyTextFileSourceOperator
from pywy.platforms.python.operator.py_sink_textfile import PyTextFileSinkOperator

Expand All @@ -10,4 +11,5 @@
PyTextFileSourceOperator,
PyTextFileSinkOperator,
PyMapOperator,
PyFlatmapOperator,
]
8 changes: 5 additions & 3 deletions python/src/pywy/platforms/python/operator/py_sink_textfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,22 @@ class PyTextFileSinkOperator(TextFileSink, PyExecutionOperator):
def __init__(self, origin: TextFileSink = None):
path = None if origin is None else origin.path
type_class = None if origin is None else origin.inputSlot[0]
super().__init__(path, type_class)
end_line = None if origin is None else origin.end_line
super().__init__(path, type_class, end_line)

def execute(self, inputs: List[Type[CH_T]], outputs: List[Type[CH_T]]):
self.validate_channels(inputs, outputs)
if isinstance(inputs[0], PyIteratorChannel):
file = open(self.path, 'w')
py_in_iter_channel: PyIteratorChannel = inputs[0]
iterable = py_in_iter_channel.provide_iterable()
if self.inputSlot[0] == str:

if self.inputSlot[0] == str and self.end_line is None:
for element in iterable:
file.write(element)
else:
for element in iterable:
file.write("{}\n".format(str(element)))
file.write("{}{}".format(str(element), self.end_line))
file.close()

else:
Expand Down
49 changes: 49 additions & 0 deletions python/src/pywy/platforms/python/operator/py_unary_flatmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from itertools import chain
from typing import Set, List, Type

from pywy.core.channel import CH_T
from pywy.operators.unary import FlatmapOperator
from pywy.platforms.python.operator.py_execution_operator import PyExecutionOperator
from pywy.platforms.python.channels import (
ChannelDescriptor,
PyIteratorChannel,
PY_ITERATOR_CHANNEL_DESCRIPTOR,
PY_CALLABLE_CHANNEL_DESCRIPTOR,
PyCallableChannel
)


class PyFlatmapOperator(FlatmapOperator, PyExecutionOperator):

def __init__(self, origin: FlatmapOperator = None):
fm_function = None if origin is None else origin.fm_function
super().__init__(fm_function)

def execute(self, inputs: List[Type[CH_T]], outputs: List[Type[CH_T]]):
self.validate_channels(inputs, outputs)
udf = self.fm_function
if isinstance(inputs[0], PyIteratorChannel):
py_in_iter_channel: PyIteratorChannel = inputs[0]
py_out_iter_channel: PyIteratorChannel = outputs[0]
py_out_iter_channel.accept_iterable(chain.from_iterable(map(udf, py_in_iter_channel.provide_iterable())))
elif isinstance(inputs[0], PyCallableChannel):
py_in_call_channel: PyCallableChannel = inputs[0]
py_out_call_channel: PyCallableChannel = outputs[0]

def fm_func(iterator):
return chain.from_iterable(map(udf, iterator))

py_out_call_channel.accept_callable(
PyCallableChannel.concatenate(
fm_func,
py_in_call_channel.provide_callable()
)
)
else:
raise Exception("Channel Type does not supported")

def get_input_channeldescriptors(self) -> Set[ChannelDescriptor]:
return {PY_ITERATOR_CHANNEL_DESCRIPTOR, PY_CALLABLE_CHANNEL_DESCRIPTOR}

def get_output_channeldescriptors(self) -> Set[ChannelDescriptor]:
return {PY_ITERATOR_CHANNEL_DESCRIPTOR, PY_CALLABLE_CHANNEL_DESCRIPTOR}
124 changes: 89 additions & 35 deletions python/src/pywy/tests/integration/python_platform_test.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,127 @@
import logging
import os
import unittest
import tempfile
from typing import List
from itertools import chain
from typing import List, Iterable

from pywy.config import RC_TEST_DIR as ROOT
from pywy.dataquanta import WayangContext
from pywy.plugins import PYTHON

logger = logging.getLogger(__name__)


class TestIntegrationPythonPlatform(unittest.TestCase):

file_10e0: str

def setUp(self):
self.file_10e0 = "{}/10e0MB.input".format(ROOT)
pass

def test_grep(self):
@staticmethod
def seed_small_grep(validation_file):
def pre(a: str) -> bool:
return 'six' in a

fd, path_tmp = tempfile.mkstemp()

WayangContext() \
dq = WayangContext() \
.register(PYTHON) \
.textfile(self.file_10e0) \
.filter(pre) \
.store_textfile(path_tmp)

lines_filter: List[str]
with open(self.file_10e0, 'r') as f:
lines_filter = list(filter(pre, f.readlines()))
selectivity = len(list(lines_filter))
.textfile(validation_file) \
.filter(pre)

return dq, path_tmp, pre

def validate_files(self,
validation_file,
outputed_file,
read_and_convert_validation,
read_and_convert_outputed,
delete_outputed=True,
print_variable=False):
lines_filter: List[int]
with open(validation_file, 'r') as f:
lines_filter = list(read_and_convert_validation(f))
selectivity = len(lines_filter)

lines_platform: List[str]
with open(path_tmp, 'r') as fp:
lines_platform = fp.readlines()
lines_platform: List[int]
with open(outputed_file, 'r') as fp:
lines_platform = list(read_and_convert_outputed(fp))
elements = len(lines_platform)
os.remove(path_tmp)

if delete_outputed:
os.remove(outputed_file)

if print_variable:
logger.info(f"{lines_platform=}")
logger.info(f"{lines_filter=}")
logger.info(f"{elements=}")
logger.info(f"{selectivity=}")

self.assertEqual(selectivity, elements)
self.assertEqual(lines_filter, lines_platform)

def test_grep(self):

dq, path_tmp, pre = self.seed_small_grep(self.file_10e0)

dq.store_textfile(path_tmp)

def convert_validation(file):
return filter(pre, file.readlines())

def convert_outputed(file):
return file.readlines()

self.validate_files(
self.file_10e0,
path_tmp,
convert_validation,
convert_outputed
)

def test_dummy_map(self):
def pre(a: str) -> bool:
return 'six' in a

def convert(a: str) -> int:
return len(a)

fd, path_tmp = tempfile.mkstemp()
dq, path_tmp, pre = self.seed_small_grep(self.file_10e0)

WayangContext() \
.register(PYTHON) \
.textfile(self.file_10e0) \
.filter(pre) \
.map(convert) \
dq.map(convert) \
.store_textfile(path_tmp)

lines_filter: List[int]
with open(self.file_10e0, 'r') as f:
lines_filter = list(map(convert, filter(pre, f.readlines())))
selectivity = len(list(lines_filter))
def convert_validation(file):
return map(convert, filter(pre, file.readlines()))

lines_platform: List[int]
with open(path_tmp, 'r') as fp:
lines_platform = list(map(lambda x: int(x), fp.readlines()))
elements = len(lines_platform)
os.remove(path_tmp)
def convert_outputed(file):
return map(lambda x: int(x), file.read().splitlines())

self.assertEqual(selectivity, elements)
self.assertEqual(lines_filter, lines_platform)
self.validate_files(
self.file_10e0,
path_tmp,
convert_validation,
convert_outputed
)

def test_dummy_flatmap(self):
def fm_func(string: str) -> Iterable[str]:
return string.strip().split(" ")

dq, path_tmp, pre = self.seed_small_grep(self.file_10e0)

dq.flatmap(fm_func) \
.store_textfile(path_tmp, '\n')

def convert_validation(file):
return chain.from_iterable(map(fm_func, filter(pre, file.readlines())))

def convert_outputed(file):
return file.read().splitlines()

self.validate_files(
self.file_10e0,
path_tmp,
convert_validation,
convert_outputed
)
26 changes: 16 additions & 10 deletions python/src/pywy/tests/unit/dataquanta/dataquanta_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import unittest
from typing import Tuple, Callable
from typing import Tuple, Callable, Iterable
from unittest.mock import Mock

from pywy.dataquanta import WayangContext
from pywy.dataquanta import DataQuanta
from pywy.exception import PywyException
from pywy.operators import *
from pywy.types import FlatmapFunction


class TestUnitCoreTranslator(unittest.TestCase):
Expand Down Expand Up @@ -110,13 +112,16 @@ def validate_flatmap(self, flatted: DataQuanta, operator: PywyOperator):
def test_flatmap_lambda(self):
(operator, dq) = self.build_seed()
func: Callable = lambda x: x.split(" ")
flatted = dq.flatmap(func)
self.validate_flatmap(flatted, operator)
try:
flatted = dq.flatmap(func)
self.validate_flatmap(flatted, operator)
except PywyException as e:
self.assertTrue("the return for the FlatmapFunction is not Iterable" in str(e))

def test_flatmap_func(self):
(operator, dq) = self.build_seed()

def fmfunc(i: str) -> str:
def fmfunc(i: str) -> Iterable[str]:
for x in range(len(i)):
yield str(x)

Expand All @@ -126,9 +131,10 @@ def fmfunc(i: str) -> str:
def test_flatmap_func_lambda(self):
(operator, dq) = self.build_seed()

def fmfunc(i):
for x in range(len(i)):
yield str(x)

flatted = dq.flatmap(lambda x: fmfunc(x))
self.validate_flatmap(flatted, operator)
try:
fm_func_lambda: Callable[[str], Iterable[str]] = lambda i: [str(x) for x in range(len(i))]
flatted = dq.flatmap(fm_func_lambda)
self.assertRaises("the current implementation does not support lambdas")
# self.validate_flatmap(flatted, operator)
except PywyException as e:
self.assertTrue("the return for the FlatmapFunction is not Iterable" in str(e))
Loading

0 comments on commit d60c170

Please sign in to comment.