forked from facebookresearch/AugLy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcomposition.py
132 lines (100 loc) · 4.24 KB
/
composition.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import random
from typing import Any, Dict, List, Optional, Union
from augly.text.transforms import BaseTransform
"""
Composition Operators:

Compose: modeled on the Compose object provided by the torchvision library,
this class provides a similar experience for applying multiple
transformations, in order, onto text.

OneOf: the OneOf operator takes as input a list of transforms and
may apply (with probability p) one of the transforms in the list.
If a transform is applied, it is selected at random, weighted by the
probabilities `p` of the individual transforms (normalized to sum to 1).

Example:
    >>> Compose([
    ...     InsertPunctuationChars(),
    ...     ReplaceFunFonts(),
    ...     OneOf([
    ...         ReplaceSimilarChars(),
    ...         SimulateTypos(),
    ...     ]),
    ... ])
"""
class BaseComposition(BaseTransform):
    """Common base for transforms that are assembled from other transforms."""

    def __init__(self, transforms: List[BaseTransform], p: float = 1.0):
        """
        @param transforms: a list of transforms; every element must be an
            instance of BaseTransform
        @param p: the probability of the transform being applied; default value is 1.0
        """
        # Check the children first, before delegating `p` to the base class.
        rejected = [
            candidate
            for candidate in transforms
            if not isinstance(candidate, BaseTransform)
        ]
        assert (
            not rejected
        ), "Expected instances of type 'BaseTransform' for parameter 'transforms'"

        super().__init__(p)
        # Stored as given (not copied); subclasses read it when called.
        self.transforms = transforms
class Compose(BaseComposition):
    """Applies every transform in `self.transforms`, in order, to the text."""

    def __call__(
        self,
        texts: Union[str, List[str]],
        seed: Optional[int] = None,
        metadata: Optional[List[Dict[str, Any]]] = None,
        force: bool = False,
    ) -> Union[str, List[str]]:
        """
        Applies the list of transforms in order to the text

        @param texts: a string or a list of text documents to be augmented
        @param seed: if provided, the random seed will be set to this before calling
            the transform
        @param metadata: if set to be a list, metadata about the function execution
            including its name, the source & dest length, etc. will be appended to
            the inputted list. If set to None, no metadata will be appended or returned
        @param force: if set to True, the transforms are always applied. Otherwise,
            they are applied with probability `self.p`. Accepting this flag also
            allows a Compose to be nested inside a OneOf, which calls the chosen
            transform with `force=True`
        @returns: the list of augmented text documents
        """
        if seed is not None:
            random.seed(seed)

        # A single string is wrapped into a one-element list before transforming.
        texts = [texts] if isinstance(texts, str) else texts

        # Only consume a random number when p < 1: with the default p=1.0 the
        # composition always applies, and skipping the draw keeps seeded
        # pipelines producing exactly the same output as before.
        if not force and self.p < 1.0 and random.random() > self.p:
            return texts

        for transform in self.transforms:
            texts = transform(texts, metadata=metadata)
        return texts
class OneOf(BaseComposition):
    """With probability `p`, applies exactly one randomly chosen child transform."""

    def __init__(self, transforms: List[BaseTransform], p: float = 1.0):
        """
        @param transforms: a list of transforms to select from; one of which will
            be chosen to be applied to the text. Each transform's own `p` is used
            as its relative selection weight
        @param p: the probability of the transform being applied; default value is 1.0
        @raises ValueError: if `transforms` is empty, or if the `p` values of the
            transforms do not sum to a positive number (no selection weights can
            be derived in either case)
        """
        super().__init__(transforms, p)
        transform_probs = [t.p for t in transforms]
        probs_sum = sum(transform_probs)

        # Fail fast with a clear message instead of a ZeroDivisionError below.
        if not transform_probs:
            raise ValueError("OneOf requires at least one transform")
        if probs_sum <= 0:
            raise ValueError(
                "The probabilities 'p' of the transforms passed to OneOf must "
                "sum to a positive value"
            )

        # Normalize the children's `p` values into selection weights summing to 1.
        self.transform_probs = [t / probs_sum for t in transform_probs]

    def __call__(
        self,
        texts: Union[str, List[str]],
        force: bool = False,
        seed: Optional[int] = None,
        metadata: Optional[List[Dict[str, Any]]] = None,
    ) -> Union[str, List[str]]:
        """
        @param texts: a string or a list of text documents to be augmented
        @param force: if set to True, the transform will be applied. Otherwise,
            application is determined by the probability set
        @param seed: if provided, the random seed will be set to this before calling
            the transform
        @param metadata: if set to be a list, metadata about the function execution
            including its name, the source & dest length, etc. will be appended to
            the inputted list. If set to None, no metadata will be appended or returned
        @returns: the list of augmented text documents
        """
        if seed is not None:
            random.seed(seed)

        # A single string is wrapped into a one-element list before transforming.
        texts = [texts] if isinstance(texts, str) else texts

        # `force` bypasses the probability check entirely, as documented above.
        if not force and random.random() > self.p:
            return texts

        transform = random.choices(self.transforms, self.transform_probs)[0]
        # The child's `p` was already used as its selection weight, so
        # `force=True` asks it to apply unconditionally.
        return transform(texts, force=True, metadata=metadata)