@@ -39,6 +39,16 @@ class AttentionNormalization(str, enum.Enum):
39
39
vector = 'vector'
40
40
softmax = 'softmax'
41
41
42
+ def _freeze (value ):
43
+ """
44
+ Creates a field with a mutable default value.
45
+
46
+ Python doesn't like this being done directly. But `Params` are never
47
+ mutated anyway, so this workaround is easiest.
48
+ """
49
+
50
+ return dataclasses .field (default_factory = lambda : value )
51
+
42
52
@dataclasses .dataclass
43
53
class Params :
44
54
"""
@@ -57,14 +67,14 @@ class Params:
57
67
Only considered when using the SWDE dataset for now.
58
68
"""
59
69
60
- label_keys : list [str ] = ('name' , 'price' , 'shortDescription' , 'images' )
70
+ label_keys : list [str ] = ()
61
71
"""
62
72
Set of keys to select from the dataset.
63
73
64
74
Only considered when using the Apify dataset for now.
65
75
"""
66
76
67
- train_website_indices : list [int ] = (0 , 3 , 4 , 5 , 7 )
77
+ train_website_indices : list [int ] = (0 , 1 , 2 , 3 , 4 )
68
78
"""
69
79
Indices of websites to put in the training set.
70
80
@@ -74,23 +84,23 @@ class Params:
74
84
exclude_websites : list [str ] = ()
75
85
"""Website names to exclude from loading."""
76
86
77
- train_subset : Optional [int ] = 2000
87
+ train_subset : Optional [int ] = 100
78
88
"""Number of pages per website to use for training."""
79
89
80
- val_subset : Optional [int ] = 50
90
+ val_subset : Optional [int ] = 5
81
91
"""
82
92
Number of pages per website to use for validation (evaluation after each
83
93
training epoch).
84
94
"""
85
95
86
- test_subset : Optional [int ] = None
96
+ test_subset : Optional [int ] = 250
87
97
"""
88
98
Number of pages per website to use for testing (evaluation of each
89
99
cross-validation run).
90
100
"""
91
101
92
102
# Trainer
93
- epochs : int = 5
103
+ epochs : int = 3
94
104
"""
95
105
Number of epochs (passes over all training samples) to train the model for.
96
106
"""
@@ -103,13 +113,13 @@ class Params:
103
113
restore_num : Optional [int ] = None
104
114
"""Existing version number to restore."""
105
115
106
- batch_size : int = 16
116
+ batch_size : int = 32
107
117
"""Number of samples to have in a mini-batch during training/evaluation."""
108
118
109
- save_every_n_epochs : Optional [int ] = 1
119
+ save_every_n_epochs : Optional [int ] = None
110
120
"""How often a checkpoint should be saved."""
111
121
112
- save_better_val_loss_checkpoint : bool = True
122
+ save_better_val_loss_checkpoint : bool = False
113
123
"""
114
124
Save checkpoint after each epoch when better validation loss is achieved.
115
125
"""
@@ -124,12 +134,12 @@ class Params:
124
134
the storage space is not wasted by saving all checkpoints.
125
135
"""
126
136
127
- log_every_n_steps : int = 10
137
+ log_every_n_steps : int = 100
128
138
"""
129
139
How often to evaluate and write metrics to TensorBoard during training.
130
140
"""
131
141
132
- eval_every_n_steps : Optional [int ] = 50
142
+ eval_every_n_steps : Optional [int ] = None
133
143
"""
134
144
How often to execute full evaluation pass on the validation subset during
135
145
training.
@@ -150,13 +160,13 @@ class Params:
150
160
"""
151
161
152
162
# Sampling
153
- load_visuals : bool = False
163
+ load_visuals : bool = True
154
164
"""
155
165
When loading HTML for pages, also load JSON visuals, parse visual attributes
156
166
and attach them to DOM nodes in memory.
157
167
"""
158
168
159
- classify_only_text_nodes : bool = False
169
+ classify_only_text_nodes : bool = True
160
170
"""Sample only text fragments."""
161
171
162
172
classify_only_variable_nodes : bool = False
@@ -179,10 +189,10 @@ class Params:
179
189
validate_data : bool = True
180
190
"""Validate sampled page DOMs."""
181
191
182
- ignore_invalid_pages : bool = False
192
+ ignore_invalid_pages : bool = True
183
193
"""Of sampled and validated pages, ignore those that are invalid."""
184
194
185
- none_cutoff : Optional [int ] = None
195
+ none_cutoff : Optional [int ] = 30_000
186
196
"""
187
197
From 0 to 100,000. The higher, the more non-target nodes will be sampled.
188
198
"""
@@ -195,16 +205,16 @@ class Params:
195
205
"""Number of friends in the friend cycle."""
196
206
197
207
# Visual neighbors
198
- visual_neighbors : bool = False
208
+ visual_neighbors : bool = True
199
209
"""Use visual neighbors as a feature when classifying nodes."""
200
210
201
- n_neighbors : int = 4
211
+ n_neighbors : int = 10
202
212
"""Number of visual neighbors (the closes ones)."""
203
213
204
214
neighbor_distance : VisualNeighborDistance = VisualNeighborDistance .rect
205
215
"""How to determine the closes visual neighbors."""
206
216
207
- neighbor_normalize : Optional [AttentionNormalization ] = AttentionNormalization .softmax
217
+ neighbor_normalize : Optional [AttentionNormalization ] = AttentionNormalization .vector
208
218
"""
209
219
How to normalize neighbor distances before feeding them to the attention
210
220
module.
@@ -217,16 +227,16 @@ class Params:
217
227
"""
218
228
219
229
# Ancestor chain
220
- ancestor_chain : bool = False
230
+ ancestor_chain : bool = True
221
231
"""Use DOM ancestors as a feature when classifying nodes."""
222
232
223
- n_ancestors : Optional [int ] = 5
233
+ n_ancestors : Optional [int ] = None
224
234
"""`None` to use all ancestors."""
225
235
226
236
ancestor_lstm_out_dim : int = 10
227
237
"""Output dimension of the LSTM aggregating ancestor features."""
228
238
229
- ancestor_lstm_args : Optional [dict [str ]] = None
239
+ ancestor_lstm_args : Optional [dict [str ]] = _freeze ({ 'bidirectional' : True })
230
240
"""
231
241
Additional keyword arguments to the LSTM layer aggregating ancestor
232
242
features.
@@ -243,7 +253,7 @@ class Params:
243
253
"""
244
254
245
255
# Word vectors
246
- tokenizer_family : TokenizerFamily = TokenizerFamily .custom
256
+ tokenizer_family : TokenizerFamily = TokenizerFamily .bert
247
257
"""Which tokenizer to use."""
248
258
249
259
tokenizer_id : str = ''
@@ -266,7 +276,7 @@ class Params:
266
276
"""
267
277
268
278
# HTML attributes
269
- tokenize_node_attrs : list [str ] = ()
279
+ tokenize_node_attrs : list [str ] = ('itemprop' , )
270
280
"""
271
281
DOM attributes to tokenize and use as a feature when classifying nodes and
272
282
also as a feature of each ancestor in the ancestor chain.
@@ -278,7 +288,7 @@ class Params:
278
288
"""Use the DOM attribute feature only for nodes of the ancestor chain."""
279
289
280
290
# LSTM
281
- word_vector_function : Optional [str ] = 'sum '
291
+ word_vector_function : Optional [str ] = 'lstm '
282
292
"""
283
293
How to aggregate word vectors.
284
294
@@ -288,7 +298,7 @@ class Params:
288
298
lstm_dim : int = 100
289
299
"""Output dimension of the LSTM aggregating word vectors."""
290
300
291
- lstm_args : Optional [dict [str ]] = None
301
+ lstm_args : Optional [dict [str ]] = _freeze ({ 'bidirectional' : True })
292
302
"""
293
303
Additional keyword arguments to the LSTM layer aggregating word vectors.
294
304
"""
@@ -313,7 +323,7 @@ class Params:
313
323
"""
314
324
315
325
# HTML DOM features
316
- tag_name_embedding : bool = False
326
+ tag_name_embedding : bool = True
317
327
"""
318
328
Whether to use HTML tag name as a feature when classifying nodes.
319
329
@@ -323,15 +333,17 @@ class Params:
323
333
tag_name_embedding_dim : int = 30
324
334
"""Dimension of the output vector of HTML tag name embedding."""
325
335
326
- position : bool = False
336
+ position : bool = True
327
337
"""
328
338
Whether to use visual position as a feature when classifying nodes.
329
339
330
340
See `awe.features.dom.Position`.
331
341
"""
332
342
333
343
# Visual features
334
- enabled_visuals : Optional [list [str ]] = None
344
+ enabled_visuals : Optional [list [str ]] = (
345
+ "font_size" , "font_style" , "font_weight" , "font_color"
346
+ )
335
347
"""
336
348
Filter visual attributes to only those in this list.
337
349
@@ -358,10 +370,10 @@ class Params:
358
370
layer_norm : bool = False
359
371
"""Use layer normalization in the classification head."""
360
372
361
- head_dims : list [int ] = (128 , 64 )
373
+ head_dims : list [int ] = (100 , 10 )
362
374
"""Dimensions of feed-forward layers in the classification head."""
363
375
364
- head_dropout : float = 0.5
376
+ head_dropout : float = 0.3
365
377
"""Dropout probability in the classification head."""
366
378
367
379
gradient_clipping : Optional [float ] = None
0 commit comments