diff --git a/pyndl/__init__.py b/pyndl/__init__.py index 4f6a19a..47f6e07 100644 --- a/pyndl/__init__.py +++ b/pyndl/__init__.py @@ -35,7 +35,7 @@ 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Information Analysis', - ] +] def sysinfo(): @@ -62,9 +62,12 @@ def sysinfo(): if uname.sysname == "Linux": _, *lines = os.popen("free -m").readlines() - for identifier in ["Mem:", "Swap:"]: - memory = [line for line in lines if identifier in line][0] - _, total, used, *_ = memory.split() + for identifier in ("Mem:", "Swap:"): + memory = [line for line in lines if identifier in line] + if len(memory) > 0: + _, total, used, *_ = memory[0].split() + else: + total, used = '?', '?' osinfo += "{} {}MiB/{}MiB\n".format(identifier, used, total) osinfo += "\n" diff --git a/pyndl/activation.py b/pyndl/activation.py index 5f0b52f..97e96e2 100644 --- a/pyndl/activation.py +++ b/pyndl/activation.py @@ -9,15 +9,21 @@ import multiprocessing as mp import ctypes from collections import defaultdict, OrderedDict +from typing import Iterable, List, Dict, Optional, Tuple, Union import numpy as np import xarray as xr from . import io +from .types import AnyWeights, CollectionEvent, AnyEvent, Path, CueCollection, Collection # pylint: disable=W0621 -def activation(events, weights, number_of_threads=1, remove_duplicates=None, ignore_missing_cues=False): +def activation(events: Union[Path, Iterable[AnyEvent]], + weights: AnyWeights, + number_of_threads: int = 1, + remove_duplicates: Optional[bool] = None, + ignore_missing_cues: bool = False) -> Union[xr.DataArray, Dict[str, np.ndarray]]: """ Estimate activations for given events in event file and outcome-cue weights. @@ -58,10 +64,13 @@ def activation(events, weights, number_of_threads=1, remove_duplicates=None, ign returned if weights is instance of dict """ - if isinstance(events, str): - events = io.events_from_file(events) + event_list = [] # type: Iterable[CollectionEvent] + if isinstance(events, Path): + event_list = io.events_from_file(events) + else: + event_list = events - events = (cues for cues, outcomes in events) + cues_gen = (cues for cues, outcomes in event_list) # type: Iterable[CueCollection] if remove_duplicates is None: def check_no_duplicates(cues): if len(cues) != len(set(cues)): @@ -69,9 +78,9 @@ def check_no_duplicates(cues): 'remove_duplicates=True'.format(' '.join(cues))) else: return set(cues) - events = (check_no_duplicates(cues) for cues in events) + cues_gen = (check_no_duplicates(cues) for cues in cues_gen) elif remove_duplicates is True: - events = (set(cues) for cues in events) + cues_gen = (set(cues) for cues in cues_gen) if isinstance(weights, xr.DataArray): cues = weights.coords["cues"].values.tolist() @@ -81,10 +90,10 @@ def check_no_duplicates(cues): cue_map = OrderedDict(((cue, ii) for ii, cue in enumerate(cues))) if ignore_missing_cues: event_cue_indices_list = (tuple(cue_map[cue] for cue in event_cues if cue in cues) - for event_cues in events) + for event_cues in cues_gen) else: event_cue_indices_list = (tuple(cue_map[cue] for cue in event_cues) - for event_cues in events) + for event_cues in cues_gen) # pylint: disable=W0621 activations = _activation_matrix(list(event_cue_indices_list), weights.values, number_of_threads) @@ -95,14 +104,14 @@ def check_no_duplicates(cues): dims=('outcomes', 'events')) elif isinstance(weights, dict): assert number_of_threads == 1, "Estimating activations with multiprocessing is not implemented for dicts." 
- activations = defaultdict(lambda: np.zeros(len(events))) - events = list(events) + cues_list = list(cues_gen) + activation_dict = defaultdict(lambda: np.zeros(len(cues_list))) # type: Dict[str, np.ndarray] for outcome, cue_dict in weights.items(): - _activations = activations[outcome] - for row, cues in enumerate(events): + _activations = activation_dict[outcome] + for row, cues in enumerate(cues_list): for cue in cues: - _activations[row] += cue_dict[cue] - return activations + _activations[row] += cue_dict[cue] # type: ignore + return activation_dict else: raise ValueError("Weights other than xarray.DataArray or dicts are not supported.") @@ -130,7 +139,8 @@ def _run_mp_activation_matrix(event_index, cue_indices): activations[:, event_index] = weights[:, cue_indices].sum(axis=1) -def _activation_matrix(indices_list, weights, number_of_threads): +def _activation_matrix(indices_list: List[Tuple[int, ...]], + weights: np.ndarray, number_of_threads: int) -> np.ndarray: """ Estimate activation for indices in weights @@ -160,12 +170,13 @@ def _activation_matrix(indices_list, weights, number_of_threads): activations[:, row] = weights[:, event_cues].sum(axis=1) return activations else: - shared_activations = mp.RawArray(ctypes.c_double, int(np.prod(activations_dim))) + # type stubs seem to be incorrect for multiprocessing lib. 2018-05-16 + shared_activations = mp.RawArray(ctypes.c_double, int(np.prod(activations_dim))) # type: ignore weights = np.ascontiguousarray(weights) - shared_weights = mp.sharedctypes.copy(np.ctypeslib.as_ctypes(np.float64(weights))) + shared_weights = mp.sharedctypes.copy(np.ctypeslib.as_ctypes(np.float64(weights))) # type: ignore initargs = (shared_weights, weights.shape, shared_activations, activations_dim) with mp.Pool(number_of_threads, initializer=_init_mp_activation_matrix, initargs=initargs) as pool: pool.starmap(_run_mp_activation_matrix, enumerate(indices_list)) - activations = np.ctypeslib.as_array(shared_activations) + activations = np.ctypeslib.as_array(shared_activations) # type: ignore activations.shape = activations_dim return activations diff --git a/pyndl/corpus.py b/pyndl/corpus.py index 2ba3510..80b72c1 100644 --- a/pyndl/corpus.py +++ b/pyndl/corpus.py @@ -12,6 +12,7 @@ import gzip import multiprocessing import xml.etree.ElementTree +from typing import Iterator __version__ = '0.2.0' @@ -19,7 +20,7 @@ PUNCTUATION = tuple(".,:;?!()[]'") -def _parse_time_string(time_string): +def _parse_time_string(time_string: str) -> float: """ parses string and returns time in seconds. @@ -32,7 +33,7 @@ def _parse_time_string(time_string): float(frames) / FRAMES_PER_SECOND) -def read_clean_gzfile(gz_file_path, *, break_duration=2.0): +def read_clean_gzfile(gz_file_path: str, *, break_duration=2.0) -> Iterator[str]: """ Generator that opens and reads a gunzipped xml subtitle file, while all xml tags and timestamps are removed. 
@@ -68,8 +69,10 @@ def read_clean_gzfile(gz_file_path, *, break_duration=2.0): text = word_tag.text if text in PUNCTUATION: words.append(text) - else: + elif text is not None: words.extend((' ', text)) + else: + raise ValueError("Text content of word tag is None.") result = ''.join(words) result = result.strip() @@ -112,7 +115,7 @@ class JobParseGz(): """ - def __init__(self, break_duration): + def __init__(self, break_duration: float) -> None: self.break_duration = break_duration def run(self, filename): @@ -126,7 +129,7 @@ def run(self, filename): return (lines, not_found) -def create_corpus_from_gz(directory, outfile, *, n_threads=1, verbose=False): +def create_corpus_from_gz(directory: str, outfile: str, *, n_threads=1, verbose=False): """ Create a corpus file from a set of gunziped (.gz) files in a directory. diff --git a/pyndl/count.py b/pyndl/count.py index 58b21d0..c3bd7e8 100644 --- a/pyndl/count.py +++ b/pyndl/count.py @@ -15,6 +15,7 @@ import itertools import multiprocessing import sys +from typing import Tuple def _job_cues_outcomes(event_file_name, start, step, verbose=False): @@ -45,8 +46,8 @@ def _job_cues_outcomes(event_file_name, start, step, verbose=False): return (nn + 1, cues, outcomes) -def cues_outcomes(event_file_name, - *, number_of_processes=2, verbose=False): +def cues_outcomes(event_file_name: str, + *, number_of_processes=2, verbose=False) -> Tuple[int, Counter, Counter]: """ Counts cues and outcomes in event_file_name using number_of_processes processes. @@ -65,8 +66,8 @@ def cues_outcomes(event_file_name, verbose) for start in range(number_of_processes))) n_events = 0 - cues = Counter() - outcomes = Counter() + cues = Counter() # type: Counter + outcomes = Counter() # type: Counter for nn, cues_process, outcomes_process in results: n_events += nn cues += cues_process @@ -116,8 +117,9 @@ def _job_words_symbols(corpus_file_name, start, step, lower_case=False, return (words, symbols) -def words_symbols(corpus_file_name, - *, number_of_processes=2, lower_case=False, verbose=False): +def words_symbols(corpus_file_name: str, *, + number_of_processes=2, lower_case=False, + verbose=False) -> Tuple[Counter, Counter]: """ Counts words and symbols in corpus_file_name using number_of_processes processes. @@ -136,8 +138,8 @@ def words_symbols(corpus_file_name, verbose) for start in range(number_of_processes))) - words = Counter() - symbols = Counter() + words = Counter() # type: Counter + symbols = Counter() # type: Counter for words_process, symbols_process in results: words += words_process symbols += symbols_process @@ -148,7 +150,7 @@ def words_symbols(corpus_file_name, return words, symbols -def save_counter(counter, filename, *, header='key\tfreq\n'): +def save_counter(counter: Counter, filename: str, *, header='key\tfreq\n') -> None: """ Saves a counter object into a tab delimitered text file. @@ -159,7 +161,7 @@ def save_counter(counter, filename, *, header='key\tfreq\n'): dfile.write('{key}\t{count}\n'.format(key=key, count=count)) -def load_counter(filename): +def load_counter(filename: str) -> Counter: """ Loads a counter out of a tab delimitered text file. 
@@ -167,7 +169,7 @@ def load_counter(filename): with open(filename, 'rt') as dfile: # skip header dfile.readline() - counter = Counter() + counter = Counter() # type: Counter for line in dfile: key, count = line.strip().split('\t') if key in counter.keys(): diff --git a/pyndl/io.py b/pyndl/io.py index 7170d09..b24a5af 100644 --- a/pyndl/io.py +++ b/pyndl/io.py @@ -9,12 +9,15 @@ """ import gzip -from collections import Iterator, Iterable +from collections import Iterable +from typing import Iterator, List, Optional, Tuple, Union, cast import pandas as pd +from .types import CollectionEvent, StringEvent -def events_from_file(event_path, compression="gzip"): + +def events_from_file(event_path: str, compression: Optional[str] = "gzip") -> Iterator[Tuple[List[str], List[str]]]: """ Yields events for all events in a gzipped event file. @@ -30,8 +33,8 @@ def events_from_file(event_path, compression="gzip"): ------ cues, outcomes : list, list a tuple of two lists containing cues and outcomes - """ + if compression == "gzip": event_file = gzip.open(event_path, 'rt') elif compression is None: @@ -51,8 +54,11 @@ def events_from_file(event_path, compression="gzip"): event_file.close() -def events_to_file(events, file_path, delimiter="\t", compression="gzip", - columns=("cues", "outcomes")): +def events_to_file(events: Union[Iterator[StringEvent], Iterator[CollectionEvent], pd.DataFrame], + file_path: str, + delimiter: str = "\t", + compression: Optional[str] = "gzip", + columns: Tuple[str, str] = ("cues", "outcomes")) -> None: """ Writes events to a file @@ -75,9 +81,11 @@ def events_to_file(events, file_path, delimiter="\t", compression="gzip", """ if isinstance(events, pd.DataFrame): - events = events_from_dataframe(events) + collection_events = events_from_dataframe(events) elif isinstance(events, (Iterator, Iterable)): - events = events_from_list(events) + collection_events = events_from_list(cast(Union[Iterator[StringEvent], + Iterator[CollectionEvent]], + events)) else: raise ValueError("events should either be a pd.DataFrame or an Iterator or an Iterable.") @@ -91,7 +99,7 @@ def events_to_file(events, file_path, delimiter="\t", compression="gzip", try: out_file.write("{}\n".format(delimiter.join(columns))) - for cues, outcomes in events: + for cues, outcomes in collection_events: if isinstance(cues, list) and isinstance(outcomes, list): line = "{}{}{}\n".format("_".join(cues), delimiter, @@ -105,7 +113,8 @@ def events_to_file(events, file_path, delimiter="\t", compression="gzip", out_file.close() -def events_from_dataframe(df, columns=("cues", "outcomes")): +def events_from_dataframe(df: pd.DataFrame, + columns: Tuple[str, str] = ("cues", "outcomes")) -> Iterator[CollectionEvent]: """ Yields events for all events in a pandas dataframe. @@ -130,7 +139,7 @@ def events_from_dataframe(df, columns=("cues", "outcomes")): yield (cues, outcomes) -def events_from_list(lst): +def events_from_list(lst: Union[Iterator[StringEvent], Iterator[CollectionEvent]]) -> Iterator[CollectionEvent]: """ Yields events for all events in a list. diff --git a/pyndl/ndl.py b/pyndl/ndl.py index 6de76f6..af774ca 100644 --- a/pyndl/ndl.py +++ b/pyndl/ndl.py @@ -16,6 +16,14 @@ import threading import time import warnings +from typing import ( + Iterator, + Dict, + List, + Optional, + Tuple, + Union, +) import cython import pandas as pd @@ -27,7 +35,7 @@ from . import preprocess from . import ndl_parallel from . import io - +from . 
import types warnings.simplefilter('always', DeprecationWarning) @@ -40,11 +48,12 @@ def events_from_file(event_path): return io.events_from_file(event_path) -def ndl(events, alpha, betas, lambda_=1.0, *, +def ndl(events: types.Path, alpha: float, betas: Tuple[float, float], + lambda_=1.0, *, method='openmp', weights=None, number_of_threads=8, len_sublists=10, remove_duplicates=None, verbose=False, temporary_directory=None, - events_per_temporary_file=10000000): + events_per_temporary_file=10000000) -> xr.DataArray: """ Calculate the weights for all_outcomes over all events in event_file given by the files path. @@ -103,11 +112,12 @@ def ndl(events, alpha, betas, lambda_=1.0, *, cpu_time_start = time.process_time() # preprocessing - n_events, cues, outcomes = count.cues_outcomes(events, - number_of_processes=number_of_threads, - verbose=verbose) - cues = list(cues.keys()) - outcomes = list(outcomes.keys()) + n_events, cues_counter, outcomes_counter =\ + count.cues_outcomes(events, + number_of_processes=number_of_threads, + verbose=verbose) + cues = list(cues_counter.keys()) + outcomes = list(outcomes_counter.keys()) cue_map = OrderedDict(((cue, ii) for ii, cue in enumerate(cues))) outcome_map = OrderedDict(((outcome, ii) for ii, outcome in enumerate(outcomes))) @@ -173,7 +183,7 @@ def ndl(events, alpha, betas, lambda_=1.0, *, elif method == 'threading': part_lists = slice_list(all_outcome_indices, len_sublists) - working_queue = Queue(len(part_lists)) + working_queue = Queue(len(part_lists)) # type: Queue threads = [] queue_lock = threading.Lock() @@ -219,11 +229,18 @@ def worker(): return weights -def _attributes(event_path, number_events, alpha, betas, lambda_, cpu_time, - wall_time, function, method=None, attrs=None): +def _attributes(event_path: types.Path, number_events: int, + alpha: Union[float, int, Dict[str, float]], betas: Tuple[float, float], + lambda_: float, cpu_time: float, wall_time: float, + function: str, method=None, attrs=None) -> Dict[str, str]: + if not isinstance(alpha, (float, int)): + alpha_str = 'varying' + else: + alpha_str = str(alpha) + width = max([len(ss) for ss in (event_path, str(number_events), - str(alpha), + alpha_str, str(betas), str(lambda_), function, @@ -235,13 +252,10 @@ def _attributes(event_path, number_events, alpha, betas, lambda_, cpu_time, def _format(value): return '{0: <{width}}'.format(value, width=width) - if not isinstance(alpha, (float, int)): - alpha = 'varying' - new_attrs = {'date': _format(time.strftime("%Y-%m-%d %H:%M:%S")), 'event_path': _format(event_path), 'number_events': _format(number_events), - 'alpha': _format(str(alpha)), + 'alpha': _format(alpha_str), 'betas': _format(str(betas)), 'lambda': _format(str(lambda_)), 'function': _format(function), @@ -283,10 +297,10 @@ class WeightDict(defaultdict): """ # pylint: disable=W0613 - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(lambda: defaultdict(float)) - self._attrs = OrderedDict() + self._attrs = OrderedDict() # type: OrderedDict if 'attrs' in kwargs: self.attrs = kwargs['attrs'] @@ -302,9 +316,11 @@ def attrs(self, attrs): self._attrs = OrderedDict(attrs) -def dict_ndl(events, alphas, betas, lambda_=1.0, *, +def dict_ndl(events: Union[types.Path, Iterator[types.CollectionEvent]], + alphas: Union[float, Dict[str, float]], + betas: Tuple[float, float], lambda_=1.0, *, weights=None, inplace=False, remove_duplicates=None, - make_data_array=False, verbose=False): + make_data_array=False, verbose=False) -> Union[xr.DataArray, 
WeightDict]: """ Calculate the weights for all_outcomes over all events in event_file. @@ -458,7 +474,7 @@ def dict_ndl(events, alphas, betas, lambda_=1.0, *, return weights -def slice_list(list_, len_sublists): +def slice_list(list_: List[int], len_sublists: int) -> List[List[int]]: r""" Slices a list in sublists with the length len_sublists. diff --git a/pyndl/ndl_parallel.pyi b/pyndl/ndl_parallel.pyi new file mode 100644 index 0000000..26f264f --- /dev/null +++ b/pyndl/ndl_parallel.pyi @@ -0,0 +1,16 @@ + + +def learn_inplace(binary_file_paths, weights, + alpha, beta1, + beta2, lambda_, + all_outcomes, + chunksize, + number_of_threads): + ... + + +def learn_inplace_2(binary_file_paths, weights, + alpha, beta1, + beta2, lambda_, + all_outcomes): + ... diff --git a/pyndl/preprocess.py b/pyndl/preprocess.py index f67924a..4843eec 100644 --- a/pyndl/preprocess.py +++ b/pyndl/preprocess.py @@ -14,10 +14,31 @@ import re import sys import time - - -def bandsample(population, sample_size=50000, *, cutoff=5, seed=None, - verbose=False): +from abc import abstractmethod +from io import TextIOWrapper +from collections import ( + Counter, + OrderedDict, +) + +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + Union, + TypeVar, + Generic +) + +from . import types + + +def bandsample(population: Counter, sample_size=50000, *, cutoff=5, seed=None, + verbose=False) -> Counter: """ Creates a sample of size sample_size out of the population using band sampling. @@ -25,26 +46,26 @@ def bandsample(population, sample_size=50000, *, cutoff=5, seed=None, """ # make a copy of the population # filter all words with freq < cutoff - population = [(word, freq) for word, freq in population.items() if freq >= - cutoff] + population_list = [(word, freq) for word, freq in population.items() + if freq >= cutoff] if seed is not None: raise NotImplementedError("Reproducable bandsamples by seeding are not properly implemented yet.") # shuffle words with same frequency rand = random.Random(seed) - rand.shuffle(population) - population.sort(key=lambda x: x[1]) # lowest -> highest freq + rand.shuffle(population_list) + population_list.sort(key=lambda x: x[1]) # lowest -> highest freq - step = sum(freq for word, freq in population) / sample_size + step = sum(freq for word, freq in population_list) / sample_size if verbose: print("step %.2f" % step) accumulator = 0 index = 0 sample = list() - while 0 <= index < len(population): - word, freq = population[index] + while 0 <= index < len(population_list): + word, freq = population_list[index] accumulator += freq if verbose: print("%s\t%i\t%.2f" % (word, freq, accumulator)) @@ -53,15 +74,15 @@ def bandsample(population, sample_size=50000, *, cutoff=5, seed=None, accumulator -= step if verbose: print("add\t%s\t%.2f" % (word, accumulator)) - del population[index] + del population_list[index] while accumulator >= step and index >= 1: index -= 1 - sample.append(population[index]) + sample.append(population_list[index]) accumulator -= step if verbose: - word, freq = population[index] + word, freq = population_list[index] print(" add\t%s\t%.2f" % (word, accumulator)) - del population[index] + del population_list[index] else: # only add to index if no element was removed # if element was removed, index points at next element already @@ -69,11 +90,13 @@ def bandsample(population, sample_size=50000, *, cutoff=5, seed=None, if verbose and index % 1000 == 0: print(".", end="") sys.stdout.flush() - sample = collections.Counter({key: value for key, value 
in sample}) - return sample + sample_counter = collections.Counter({key: value for key, value in sample}) + return sample_counter -def ngrams_to_word(occurrences, n_chars, outfile, remove_duplicates=True): +def ngrams_to_word(occurrences: Iterator[types.StringEvent], + n_chars: int, outfile: TextIOWrapper, + remove_duplicates=True) -> None: """ Process the occurrences and write them to outfile. @@ -95,18 +118,23 @@ def ngrams_to_word(occurrences, n_chars, outfile, remove_duplicates=True): else: # take either occurrence = cues + outcomes phrase_string = "#" + re.sub("_", "#", occurrence) + "#" - ngrams = (phrase_string[i:(i + n_chars)] for i in - range(len(phrase_string) - n_chars + 1)) - if not ngrams or not occurrence: + ngrams_it = (phrase_string[i:(i + n_chars)] for i in + range(len(phrase_string) - n_chars + 1)) + if not ngrams_it or not occurrence: continue + ngrams = [] # type: Iterable if remove_duplicates: - ngrams = set(ngrams) + ngrams = set(ngrams_it) occurrence = "_".join(set(occurrence.split("_"))) + else: + ngrams = ngrams_it outfile.write("{}\t{}\n".format("_".join(ngrams), occurrence)) -def process_occurrences(occurrences, outfile, *, - cue_structure="trigrams_to_word", remove_duplicates=True): +def process_occurrences(occurrences: Iterator[types.StringEvent], + outfile: TextIOWrapper, *, + cue_structure="trigrams_to_word", + remove_duplicates=True) -> None: """ Process the occurrences and write them to outfile. @@ -139,8 +167,8 @@ def process_occurrences(occurrences, outfile, *, raise NotImplementedError('cue_structure=%s is not implemented yet.' % cue_structure) -def create_event_file(corpus_file, - event_file, +def create_event_file(corpus_file: types.Path, + event_file: types.Path, symbols="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ", *, context_structure="document", @@ -149,7 +177,7 @@ def create_event_file(corpus_file, cue_structure="trigrams_to_word", lower_case=False, remove_duplicates=True, - verbose=False): + verbose=False) -> None: """ Create an text based event file from a corpus file. @@ -304,7 +332,7 @@ def process_context(line): with gzip.open(event_file, "wt") as outfile: outfile.write("cues\toutcomes\n") - words = [] + words = [] # type: List[str] for ii, line in enumerate(corpus): if verbose and ii % 100000 == 0: print(".", end="") @@ -351,7 +379,19 @@ def process_context(line): process_words(words) -class JobFilter(): +class JobFilterBase(): + def process_cues(self, cues): + ... + + def process_outcomes(self, cues): + ... + + +KeepCues = TypeVar('KeepCues', str, Iterable[types.Cue]) +KeepOutcomes = TypeVar('KeepOutcomes', str, Iterable[types.Outcome]) + + +class JobFilter(JobFilterBase, Generic[KeepCues, KeepOutcomes]): # pylint: disable=E0202,missing-docstring """ @@ -363,13 +403,19 @@ class JobFilter(): Using a closure is not possible as it is not pickable / serializable. 
""" + keep_cues = None # type: KeepCues + keep_outcomes = None # type: KeepOutcomes @staticmethod - def return_empty_string(): + def return_empty_string() -> str: return '' - def __init__(self, keep_cues, keep_outcomes, remove_cues, remove_outcomes, - cue_map, outcome_map): + def __init__(self, keep_cues: KeepCues, + keep_outcomes: KeepOutcomes, + remove_cues: Optional[types.Collection[types.Cue]], + remove_outcomes: Optional[types.Collection[types.Outcome]], + cue_map: Optional[Dict[types.Cue, types.Cue]], + outcome_map: Optional[Dict[types.Outcome, types.Outcome]]) -> None: if ((cue_map is not None and remove_cues is not None) or (cue_map is not None and keep_cues != 'all') or (remove_cues is not None and keep_cues != 'all')): @@ -379,70 +425,68 @@ def __init__(self, keep_cues, keep_outcomes, remove_cues, remove_outcomes, (remove_outcomes is not None and keep_outcomes != 'all')): raise ValueError('You can either specify keep_outcomes, remove_outcomes, or outcome_map.') + # Type checking cannot handle assign to a method. 2018-05-16 if cue_map is not None: self.cue_map = collections.defaultdict(self.return_empty_string, cue_map) - self.process_cues = self.process_cues_map + self.process_cues = self.process_cues_map # type: ignore elif remove_cues is not None: self.remove_cues = set(remove_cues) - self.process_cues = self.process_cues_remove + self.process_cues = self.process_cues_remove # type: ignore elif keep_cues == 'all': self.keep_cues = 'all' - self.process_cues = self.process_cues_all + self.process_cues = self.process_cues_all # type: ignore else: self.keep_cues = keep_cues - self.process_cues = self.process_cues_keep + self.process_cues = self.process_cues_keep # type: ignore + if outcome_map is not None: self.outcome_map = collections.defaultdict(self.return_empty_string, outcome_map) - self.process_outcomes = self.process_outcomes_map + self.process_outcomes = self.process_outcomes_map # type: ignore elif remove_outcomes is not None: self.remove_outcomes = set(remove_outcomes) - self.process_outcomes = self.process_outcomes_remove + self.process_outcomes = self.process_outcomes_remove # type: ignore elif keep_outcomes == 'all': self.keep_outcomes = 'all' - self.process_outcomes = self.process_outcomes_all + self.process_outcomes = self.process_outcomes_all # type: ignore + elif isinstance(keep_outcomes, Iterable): + self.keep_outcomes = set(keep_outcomes) # type: ignore + self.process_outcomes = self.process_outcomes_keep # type: ignore else: - self.keep_outcomes = set(keep_outcomes) - self.process_outcomes = self.process_outcomes_keep + raise NotImplementedError('Unsupported variable combination.') - def process_cues(self, cues): - raise NotImplementedError("Needs to be implemented or assigned by a specific method.") - - def process_cues_map(self, cues): + def process_cues_map(self, cues: types.CueCollection) -> types.CueCollection: cues = [self.cue_map[cue] for cue in cues] return [cue for cue in cues if cue] - def process_cues_remove(self, cues): + def process_cues_remove(self, cues: types.CueCollection) -> types.CueCollection: return [cue for cue in cues if cue not in self.remove_cues] - def process_cues_keep(self, cues): + def process_cues_keep(self, cues: types.CueCollection) -> types.CueCollection: return [cue for cue in cues if cue in self.keep_cues] - def process_cues_all(self, cues): + def process_cues_all(self, cues: types.CueCollection) -> types.CueCollection: return cues - def process_outcomes(self, outcomes): - raise NotImplementedError("Needs to be implemented or 
assigned by a specific method.") - def process_outcomes_map(self, outcomes): outcomes = [self.outcome_map[outcome] for outcome in outcomes] return [outcome for outcome in outcomes if outcome] - def process_outcomes_remove(self, outcomes): + def process_outcomes_remove(self, outcomes: types.OutcomeCollection) -> types.OutcomeCollection: return [outcome for outcome in outcomes if outcome not in self.remove_outcomes] - def process_outcomes_keep(self, outcomes): + def process_outcomes_keep(self, outcomes: types.OutcomeCollection) -> types.OutcomeCollection: return [outcome for outcome in outcomes if outcome in self.keep_outcomes] - def process_outcomes_all(self, outcomes): + def process_outcomes_all(self, outcomes: types.OutcomeCollection) -> types.OutcomeCollection: return outcomes - def job(self, line): + def job(self, line: str) -> Optional[str]: try: - cues, outcomes = line.strip('\n').split("\t") + cues_str, outcomes_str = line.strip('\n').split("\t") except ValueError: raise ValueError("tabular event file need to have two tab separated columns") - cues = cues.split("_") - outcomes = outcomes.split("_") + cues = cues_str.split("_") + outcomes = outcomes_str.split("_") cues = self.process_cues(cues) outcomes = self.process_outcomes(outcomes) # no cues left? @@ -454,12 +498,13 @@ def job(self, line): return processed_line -def filter_event_file(input_event_file, output_event_file, *, +def filter_event_file(input_event_file: types.Path, + output_event_file: types.Path, *, keep_cues="all", keep_outcomes="all", remove_cues=None, remove_outcomes=None, cue_map=None, outcome_map=None, number_of_processes=1, chunksize=100000, - verbose=False): + verbose=False) -> None: """ Filter an event file by a list or a map of cues and outcomes. @@ -528,7 +573,7 @@ def filter_event_file(input_event_file, output_event_file, *, CURRENT_VERSION = 2048 + 215 -def read_binary_file(binary_file_path): +def read_binary_file(binary_file_path: types.Path) -> Iterator[types.IdCollectionEvent]: with open(binary_file_path, "rb") as binary_file: magic_number = to_integer(binary_file.read(4)) if not magic_number == MAGIC_NUMBER: @@ -550,15 +595,17 @@ def read_binary_file(binary_file_path): yield (cue_ids, outcome_ids) -def to_bytes(int_): +def to_bytes(int_: int) -> bytes: return int_.to_bytes(4, 'little') -def to_integer(byte_): +def to_integer(byte_: bytes) -> int: return int.from_bytes(byte_, "little") -def write_events(events, filename, *, start=0, stop=4294967295, remove_duplicates=None): +def write_events(events: Iterator[types.IdCollectionEvent], + filename: types.Path, *, + start=0, stop=4294967295, remove_duplicates=None) -> int: """ Write out a list of events to a disk file in binary format. 
@@ -658,7 +705,10 @@ def write_events(events, filename, *, start=0, stop=4294967295, remove_duplicate return n_events -def event_generator(event_file, cue_id_map, outcome_id_map, *, sort_within_event=False): +def event_generator(event_file: types.Path, + cue_id_map: Dict[types.Cue, types.Id], + outcome_id_map: Dict[types.Outcome, types.Id], *, + sort_within_event=False) -> Iterator[types.IdCollectionEvent]: with gzip.open(event_file, "rt") as in_file: # skip header in_file.readline() @@ -683,31 +733,31 @@ def event_generator(event_file, cue_id_map, outcome_id_map, *, sort_within_event def _job_binary_event_file(*, - file_name, - event_file, - cue_id_map, - outcome_id_map, - sort_within_event, - start, - stop, - remove_duplicates): + file_name: types.Path, + event_file: types.Path, + cue_id_map: Dict[types.Cue, types.Id], + outcome_id_map: Dict[types.Outcome, types.Id], + sort_within_event: bool, + start: int, + stop: int, + remove_duplicates: Optional[bool]): # create generator which is not pickable events = event_generator(event_file, cue_id_map, outcome_id_map, sort_within_event=sort_within_event) n_events = write_events(events, file_name, start=start, stop=stop, remove_duplicates=remove_duplicates) return n_events -def create_binary_event_files(event_file, - path_name, - cue_id_map, - outcome_id_map, +def create_binary_event_files(event_file: types.Path, + path_name: types.Path, + cue_id_map: Dict[types.Cue, types.Id], + outcome_id_map: Dict[types.Outcome, types.Id], *, sort_within_event=False, number_of_processes=2, events_per_file=10000000, overwrite=False, - remove_duplicates=None, - verbose=False): + remove_duplicates: Optional[bool] = None, + verbose=False) -> int: """ Creates the binary event files for a tabular cue outcome corpus. diff --git a/pyndl/types.py b/pyndl/types.py new file mode 100644 index 0000000..8d11d2e --- /dev/null +++ b/pyndl/types.py @@ -0,0 +1,31 @@ +from typing import Dict, Iterator, Tuple, TypeVar + +from numpy import ndarray +from xarray.core.dataarray import DataArray + +try: + from typing import Collection +except ImportError: # Python 3.5 fallback + from typing import Union, Sequence, Set + T = TypeVar('T') + + # ignore typing because mypy thinks Collection is already a defined type. + Collection = Union[Sequence[T], Set[T]] # type: ignore + +Path = str +Cue = str +Outcome = str +Id = int + +IdCollection = Collection[Id] +CueCollection = Collection[Cue] +AnyCues = TypeVar('AnyCues', ndarray, CueCollection) +OutcomeCollection = Collection[Outcome] +AnyOutcomes = TypeVar('AnyOutcomes', ndarray, OutcomeCollection) +CollectionEvent = Tuple[CueCollection, OutcomeCollection] +IdCollectionEvent = Tuple[IdCollection, IdCollection] +StringEvent = Tuple[str, str] +AnyEvent = Tuple[AnyCues, AnyOutcomes] +AnyEvents = TypeVar('AnyEvents', Path, Iterator[AnyEvent]) +WeightDict = Dict[str, Dict[str, float]] +AnyWeights = TypeVar('AnyWeights', DataArray, WeightDict) diff --git a/stubs/numpy/__init__.pyi b/stubs/numpy/__init__.pyi index f43aec2..05373f3 100644 --- a/stubs/numpy/__init__.pyi +++ b/stubs/numpy/__init__.pyi @@ -446,6 +446,7 @@ def asanyarray(a: Any, dtype: DtypeType=None, order: str=None) -> ndarray[Any]: def asmatrix(data: Any, dtype: DtypeType=None) -> Any: ... # TODO define matrix def ascontiguousarray(a: Any, dtype: DtypeType=None) -> ndarray[Any]: ... def copy(a: Any, order: str=None) -> ndarray[Any]: ... +def concatenate(a: Sequence[_ArrayLike], axis=0, out: Optional[ndarray[Any]]=None) -> ndarray[Any]: ... 
def empty(shape: ShapeType, dtype: DtypeType=float, order: str='C') -> ndarray[Any]: ... def empty_like(a: Any, dtype: Any=None, order: str='K', subok: bool=True) -> ndarray[Any]: ... def eye(N: int, M: int=None, k: int=0, dtype: DtypeType=float) -> ndarray[Any]: ... @@ -467,4 +468,6 @@ def loadtxt(fname: Any, dtype: DtypeType=float, comments: Union[str, Sequence[st def ones(shape: ShapeType, dtype: Optional[DtypeType]=..., order: str='C') -> ndarray[Any]: ... def ones_like(a: Any, dtype: Any=None, order: str='K', subok: bool=True) -> ndarray[Any]: ... def zeros(shape: ShapeType, dtype: DtypeType=float, order: str='C') -> ndarray[Any]: ... -def zeros_like(a: Any, dtype: Any=None, order: str='K', subok: bool=True) -> ndarray[Any]: ... \ No newline at end of file +def zeros_like(a: Any, dtype: Any=None, order: str='K', subok: bool=True) -> ndarray[Any]: ... + +__version__: Any = ... \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index cbabdc9..2b9163a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ Configuration for py.test-3. ''' +import pytest def pytest_addoption(parser): @@ -9,4 +10,14 @@ def pytest_addoption(parser): adds custom option to the pytest parser """ parser.addoption("--runslow", action="store_true", - help="run slow tests") + default=False, help="run slow tests") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runslow"): + # --runslow given in cli: do not skip slow tests + return + skip_slow = pytest.mark.skip(reason="need --runslow option to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) diff --git a/tests/test_activation.py b/tests/test_activation.py index b99b363..8136280 100644 --- a/tests/test_activation.py +++ b/tests/test_activation.py @@ -15,8 +15,6 @@ from pyndl import ndl from pyndl.activation import activation -slow = pytest.mark.skipif(not pytest.config.getoption("--runslow"), # pylint: disable=invalid-name - reason="need --runslow option to run") TEST_ROOT = os.path.join(os.path.pardir, os.path.dirname(__file__)) FILE_PATH_SIMPLE = os.path.join(TEST_ROOT, "resources/event_file_simple.tab.gz") @@ -140,7 +138,7 @@ def test_ignore_missing_cues_dict(): assert np.allclose(reference_activations[outcome], activation_list) -@slow +@pytest.mark.slow def test_activation_matrix_large(): """ Test with a lot of data. Better run only with at least 12GB free RAM. diff --git a/tox.ini b/tox.ini index 54566e3..aa77892 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py{35,36}-test, checkstyle, documentation +envlist = py{35,36}-test, checkstyle, checktypes, documentation [testenv] usedevelop = True @@ -52,7 +52,6 @@ deps = mypy setenv = MYPYPATH=./stubs/ commands = mypy --ignore-missing-imports pyndl -ignore_outcome = True [testenv:documentation] usedevelop = True
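
Usage sketches (illustrative only, not part of the patch above; all file names, cue/outcome labels, and numeric values below are invented for illustration).

The dict branch of activation() in pyndl/activation.py accumulates one activation value per event and outcome. A minimal sketch, assuming pyndl is installed:

from pyndl.activation import activation

# invented outcome -> cue -> association strengths
weights = {
    'plural': {'s#': 0.25, 'a#': -0.05},
    'singular': {'s#': -0.10, 'a#': 0.20},
}

# events given in memory as (cues, outcomes) tuples instead of an event file path
events = [
    (['s#', 'a#'], ['plural']),
    (['a#'], ['singular']),
]

# with dict weights, a defaultdict mapping each outcome to a numpy array
# (one activation per event) is returned; remove_duplicates=True avoids the
# ValueError raised for repeated cues when remove_duplicates is None
activations = activation(events, weights, remove_duplicates=True)
print(activations['plural'])    # array of length 2, one value per event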
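
A sketch of writing and re-reading events with the annotated helpers in pyndl/io.py; the file name 'events.tab.gz' is arbitrary:

from pyndl import io

events = [
    (['cue1', 'cue2'], ['outcome1']),
    (['cue3'], ['outcome2', 'outcome3']),
]

# writes a gzipped, tab-separated file with a "cues<TAB>outcomes" header,
# joining the members of each list with underscores
io.events_to_file(events, 'events.tab.gz')

# yields one (cues, outcomes) tuple of lists per event
for cues, outcomes in io.events_from_file('events.tab.gz'):
    print(cues, outcomes)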
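
Calling the annotated ndl.ndl entry point; 'event_file.tab.gz' is a placeholder for an existing gzipped event file, and the learning parameters are arbitrary:

from pyndl import ndl

# placeholder path; alpha, betas and lambda_ are arbitrary learning parameters
weights = ndl.ndl('event_file.tab.gz',
                  alpha=0.1, betas=(0.1, 0.1), lambda_=1.0,
                  method='threading', number_of_threads=2)

# the result is an xarray.DataArray of outcome-cue weights
print(weights.dims, weights.shape)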
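
The tests/conftest.py change adopts the standard pytest opt-in pattern for slow tests; a hypothetical test module using it:

import pytest


@pytest.mark.slow
def test_expensive_path():
    # skipped unless pytest is invoked with --runslow
    assert sum(range(10000000)) > 0


def test_cheap_path():
    # always collected and run
    assert 1 + 1 == 2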