Skip to content

Commit 12c8739

Browse files
committed
Update default hyper-parameters
1 parent 707b0e7 commit 12c8739

File tree

1 file changed

+42
-30
lines changed

1 file changed

+42
-30
lines changed

awe/training/params.py

+42-30
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ class AttentionNormalization(str, enum.Enum):
3939
vector = 'vector'
4040
softmax = 'softmax'
4141

42+
def _freeze(value):
43+
"""
44+
Creates field with a mutable default value.
45+
46+
Python doesn't like this being done directly. But `Params` are never
47+
mutated anyway, so this workaround is easiest.
48+
"""
49+
50+
return dataclasses.field(default_factory=lambda: value)
51+
4252
@dataclasses.dataclass
4353
class Params:
4454
"""
@@ -57,14 +67,14 @@ class Params:
5767
Only considered when using the SWDE dataset for now.
5868
"""
5969

60-
label_keys: list[str] = ('name', 'price', 'shortDescription', 'images')
70+
label_keys: list[str] = ()
6171
"""
6272
Set of keys to select from the dataset.
6373
6474
Only considered when using the Apify dataset for now.
6575
"""
6676

67-
train_website_indices: list[int] = (0, 3, 4, 5, 7)
77+
train_website_indices: list[int] = (0, 1, 2, 3, 4)
6878
"""
6979
Indices of websites to put in the training set.
7080
@@ -74,23 +84,23 @@ class Params:
7484
exclude_websites: list[str] = ()
7585
"""Website names to exclude from loading."""
7686

77-
train_subset: Optional[int] = 2000
87+
train_subset: Optional[int] = 100
7888
"""Number of pages per website to use for training."""
7989

80-
val_subset: Optional[int] = 50
90+
val_subset: Optional[int] = 5
8191
"""
8292
Number of pages per website to use for validation (evaluation after each
8393
training epoch).
8494
"""
8595

86-
test_subset: Optional[int] = None
96+
test_subset: Optional[int] = 250
8797
"""
8898
Number of pages per website to use for testing (evaluation of each
8999
cross-validation run).
90100
"""
91101

92102
# Trainer
93-
epochs: int = 5
103+
epochs: int = 3
94104
"""
95105
Number of epochs (passes over all training samples) to train the model for.
96106
"""
@@ -103,13 +113,13 @@ class Params:
103113
restore_num: Optional[int] = None
104114
"""Existing version number to restore."""
105115

106-
batch_size: int = 16
116+
batch_size: int = 32
107117
"""Number of samples to have in a mini-batch during training/evaluation."""
108118

109-
save_every_n_epochs: Optional[int] = 1
119+
save_every_n_epochs: Optional[int] = None
110120
"""How often a checkpoint should be saved."""
111121

112-
save_better_val_loss_checkpoint: bool = True
122+
save_better_val_loss_checkpoint: bool = False
113123
"""
114124
Save checkpoint after each epoch when better validation loss is achieved.
115125
"""
@@ -124,12 +134,12 @@ class Params:
124134
the storage space is not wasted by saving all checkpoints.
125135
"""
126136

127-
log_every_n_steps: int = 10
137+
log_every_n_steps: int = 100
128138
"""
129139
How often to evaluate and write metrics to TensorBoard during training.
130140
"""
131141

132-
eval_every_n_steps: Optional[int] = 50
142+
eval_every_n_steps: Optional[int] = None
133143
"""
134144
How often to execute full evaluation pass on the validation subset during
135145
training.
@@ -150,13 +160,13 @@ class Params:
150160
"""
151161

152162
# Sampling
153-
load_visuals: bool = False
163+
load_visuals: bool = True
154164
"""
155165
When loading HTML for pages, also load JSON visuals, parse visual attributes
156166
and attach them to DOM nodes in memory.
157167
"""
158168

159-
classify_only_text_nodes: bool = False
169+
classify_only_text_nodes: bool = True
160170
"""Sample only text fragments."""
161171

162172
classify_only_variable_nodes: bool = False
@@ -179,10 +189,10 @@ class Params:
179189
validate_data: bool = True
180190
"""Validate sampled page DOMs."""
181191

182-
ignore_invalid_pages: bool = False
192+
ignore_invalid_pages: bool = True
183193
"""Of sampled and validated pages, ignore those that are invalid."""
184194

185-
none_cutoff: Optional[int] = None
195+
none_cutoff: Optional[int] = 30_000
186196
"""
187197
From 0 to 100,000. The higher, the more non-target nodes will be sampled.
188198
"""
@@ -195,16 +205,16 @@ class Params:
195205
"""Number of friends in the friend cycle."""
196206

197207
# Visual neighbors
198-
visual_neighbors: bool = False
208+
visual_neighbors: bool = True
199209
"""Use visual neighbors as a feature when classifying nodes."""
200210

201-
n_neighbors: int = 4
211+
n_neighbors: int = 10
202212
"""Number of visual neighbors (the closest ones)."""
203213

204214
neighbor_distance: VisualNeighborDistance = VisualNeighborDistance.rect
205215
"""How to determine the closest visual neighbors."""
206216

207-
neighbor_normalize: Optional[AttentionNormalization] = AttentionNormalization.softmax
217+
neighbor_normalize: Optional[AttentionNormalization] = AttentionNormalization.vector
208218
"""
209219
How to normalize neighbor distances before feeding them to the attention
210220
module.
@@ -217,16 +227,16 @@ class Params:
217227
"""
218228

219229
# Ancestor chain
220-
ancestor_chain: bool = False
230+
ancestor_chain: bool = True
221231
"""Use DOM ancestors as a feature when classifying nodes."""
222232

223-
n_ancestors: Optional[int] = 5
233+
n_ancestors: Optional[int] = None
224234
"""`None` to use all ancestors."""
225235

226236
ancestor_lstm_out_dim: int = 10
227237
"""Output dimension of the LSTM aggregating ancestor features."""
228238

229-
ancestor_lstm_args: Optional[dict[str]] = None
239+
ancestor_lstm_args: Optional[dict[str]] = _freeze({'bidirectional': True})
230240
"""
231241
Additional keyword arguments to the LSTM layer aggregating ancestor
232242
features.
@@ -243,7 +253,7 @@ class Params:
243253
"""
244254

245255
# Word vectors
246-
tokenizer_family: TokenizerFamily = TokenizerFamily.custom
256+
tokenizer_family: TokenizerFamily = TokenizerFamily.bert
247257
"""Which tokenizer to use."""
248258

249259
tokenizer_id: str = ''
@@ -266,7 +276,7 @@ class Params:
266276
"""
267277

268278
# HTML attributes
269-
tokenize_node_attrs: list[str] = ()
279+
tokenize_node_attrs: list[str] = ('itemprop',)
270280
"""
271281
DOM attributes to tokenize and use as a feature when classifying nodes and
272282
also as a feature of each ancestor in the ancestor chain.
@@ -278,7 +288,7 @@ class Params:
278288
"""Use the DOM attribute feature only for nodes of the ancestor chain."""
279289

280290
# LSTM
281-
word_vector_function: Optional[str] = 'sum'
291+
word_vector_function: Optional[str] = 'lstm'
282292
"""
283293
How to aggregate word vectors.
284294
@@ -288,7 +298,7 @@ class Params:
288298
lstm_dim: int = 100
289299
"""Output dimension of the LSTM aggregating word vectors."""
290300

291-
lstm_args: Optional[dict[str]] = None
301+
lstm_args: Optional[dict[str]] = _freeze({'bidirectional': True})
292302
"""
293303
Additional keyword arguments to the LSTM layer aggregating word vectors.
294304
"""
@@ -313,7 +323,7 @@ class Params:
313323
"""
314324

315325
# HTML DOM features
316-
tag_name_embedding: bool = False
326+
tag_name_embedding: bool = True
317327
"""
318328
Whether to use HTML tag name as a feature when classifying nodes.
319329
@@ -323,15 +333,17 @@ class Params:
323333
tag_name_embedding_dim: int = 30
324334
"""Dimension of the output vector of HTML tag name embedding."""
325335

326-
position: bool = False
336+
position: bool = True
327337
"""
328338
Whether to use visual position as a feature when classifying nodes.
329339
330340
See `awe.features.dom.Position`.
331341
"""
332342

333343
# Visual features
334-
enabled_visuals: Optional[list[str]] = None
344+
enabled_visuals: Optional[list[str]] = (
345+
"font_size", "font_style", "font_weight", "font_color"
346+
)
335347
"""
336348
Filter visual attributes to only those in this list.
337349
@@ -358,10 +370,10 @@ class Params:
358370
layer_norm: bool = False
359371
"""Use layer normalization in the classification head."""
360372

361-
head_dims: list[int] = (128, 64)
373+
head_dims: list[int] = (100, 10)
362374
"""Dimensions of feed-forward layers in the classification head."""
363375

364-
head_dropout: float = 0.5
376+
head_dropout: float = 0.3
365377
"""Dropout probability in the classification head."""
366378

367379
gradient_clipping: Optional[float] = None

0 commit comments

Comments
 (0)