diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..5391d87
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,138 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..33b87bb
--- /dev/null
+++ b/README.md
@@ -0,0 +1,91 @@
+# AutoAttend: Automated Attention Representation Search
+
+_code implementation of paper [AutoAttend: Automated Attention Representation Search](http://proceedings.mlr.press/v139/guan21a.html)._
+
+Authors: [Chaoyu Guan](https://github.com/Frozenmad), Xin Wang, Wenwu Zhu
+
+## Brief Introduction
+
+
+
+data:image/s3,"s3://crabby-images/38599/385996b744919b87be8ef68e6b5ba096269d2a5e" alt=""
+
+
+We design an automated framework searching for the best self-attention models for given tasks. In which, we leverage functional layers to describe models with self-attention mechanisms, and propose context-aware parameter sharing to build and train the supernet, so that it can consider the specialty and functionality of parameters for different functional layers and inputs. More detailed algorithms can be found in our [paper](http://proceedings.mlr.press/v139/guan21a.html).
+
+## Usage
+
+### 1. Prepare the datasets
+
+First, please prepare the [SST dataset](https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip) and [pretrained glove embeddings](https://nlp.stanford.edu/data/glove.840B.300d.zip) under `./data` folder. The organization should be in following format:
+
+```
++ data
+ - glove.840B.300d.txt
+ + sst
+ + trees
+ - train.txt
+ - dev.txt
+ - test.text
+```
+
+### 2. Preprocess the datasets
+
+Run following command to prepare off-the-shelf datasets for speed up
+```
+python -m task.dataset.sst5
+```
+
+### 3. Train the supernet
+
+Train the whole supernet running the following command:
+```
+python -m task.text_classification.train_supernet --epoch x
+```
+You can set the epoch number as you wish. In our paper, it is set to 10.
+
+### 4. Search for the best architectures
+
+Get the best architectures using evolution algorithms running:
+```
+python -m task.text_classification.search_supernet --model_path ./searched/model_epoch_x.full --devices 0 1 2 3 0 1 2 3
+```
+Where x is the epoch number. In our case it is 10. You can set the device number you want to use, and you can pass repeated device number. The code will run in multiprocessing way according to the device number you passed.
+
+### 5. Retrain the architectures
+
+Re-evaluate the searched models using following cmd:
+```
+python -m task.text_classification.retrain --arch "xxxx"
+```
+
+The `xxxx` should be replaced with the architectures searched. The best architecture we find is
+```
+python -m task.text_classification.retrain --arch "[[4, 0, 1, -1, 0], [0, 0, 3, -1, 0], [1, 0, 3, 1, 1], [1, 1, 3, 1, 1], [0, 1, 4, -1, 0], [2, 5, 1, 3, 1], [3, 1, 2, 1, 1], [3, 1, 3, 2, 1], [1, 1, 1, -1, 0], [1, 1, 2, 1, 1], [1, 2, 3, 4, 1], [3, 6, 4, 0, 1], [3, 2, 0, -1, 0], [1, 0, 0, -1, 0], [1, 4, 4, 0, 1], [3, 15, 2, 1, 1], [3, 10, 1, -1, 0], [1, 14, 2, 3, 1], [3, 18, 1, 2, 1], [4, 9, 0, -1, 0], [1, 16, 1, -1, 0], [0, 12, 0, -1, 0], [4, 3, 3, 2, 1], [2, 0, 4, -1, 0]]"
+```
+
+You should derive a mean test accuracy around `0.5371`
+
+## Codes for Graph
+
+The codes for searching graph models will be published at our [AutoGL library](https://github.com/THUMNLab/AutoGL) soon!
+
+## Cite Us
+
+If you find our work helpful, please cite our paper as following:
+
+```
+@InProceedings{guan21autoattend,
+ title = {AutoAttend: Automated Attention Representation Search},
+ author = {Guan, Chaoyu and Wang, Xin and Zhu, Wenwu},
+ booktitle = {Proceedings of the 38th International Conference on Machine Learning},
+ pages = {3864--3874},
+ year = {2021},
+ editor = {Meila, Marina and Zhang, Tong},
+ volume = {139},
+ series = {Proceedings of Machine Learning Research},
+ month = {18--24 Jul},
+ publisher = {PMLR},
+ pdf = {http://proceedings.mlr.press/v139/guan21a/guan21a.pdf},
+ url = {http://proceedings.mlr.press/v139/guan21a.html}
+```
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/model/ops.py b/model/ops.py
new file mode 100644
index 0000000..c0b2379
--- /dev/null
+++ b/model/ops.py
@@ -0,0 +1,386 @@
+'''
+ops interface, borrowed from TextNAS
+'''
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+def get_length(mask):
+ length = torch.sum(mask, 1)
+ length = length.long()
+ return length
+
+INF = 1E10
+
+class Mask(nn.Module):
+ def forward(self, seq, mask):
+ # seq: (N, C, L)
+ # mask: (N, L)
+ seq_mask = torch.unsqueeze(mask, 2)
+ seq_mask = torch.transpose(seq_mask.repeat(1, 1, seq.size()[1]), 1, 2)
+ return seq.where(torch.eq(seq_mask, 1), torch.zeros_like(seq))
+
+
+class BatchNorm(nn.Module):
+ def __init__(self, num_features, pre_mask, post_mask, eps=1e-5, decay=0.9, affine=True):
+ super(BatchNorm, self).__init__()
+ self.mask_opt = Mask()
+ self.pre_mask = pre_mask
+ self.post_mask = post_mask
+ self.bn = nn.BatchNorm1d(num_features, eps=eps, momentum=1.0 - decay, affine=affine)
+
+ def forward(self, seq, mask):
+ if self.pre_mask:
+ seq = self.mask_opt(seq, mask)
+ seq = self.bn(seq)
+ if self.post_mask:
+ seq = self.mask_opt(seq, mask)
+ return seq
+
+
+class ConvBN(nn.Module):
+ def __init__(self, kernal_size, in_channels, out_channels, cnn_keep_prob,
+ pre_mask, post_mask, with_bn=True, with_relu=True, with_pre_norm=False):
+ super(ConvBN, self).__init__()
+ self.mask_opt = Mask()
+ self.pre_mask = pre_mask
+ self.post_mask = post_mask
+ self.with_bn = with_bn
+ self.with_relu = with_relu
+ self.with_pre_norm = with_pre_norm
+ self.conv = nn.Conv1d(in_channels, out_channels, kernal_size, 1, bias=True, padding=(kernal_size - 1) // 2)
+ self.dropout = nn.Dropout(p=(1 - cnn_keep_prob))
+
+ if with_bn:
+ self.bn = BatchNorm(out_channels, not post_mask, True)
+
+ if with_relu:
+ self.relu = nn.ReLU()
+
+ if with_pre_norm:
+ self.layerNorm = nn.LayerNorm(in_channels)
+
+ def forward(self, seq, mask):
+ if self.with_pre_norm:
+ seq = self.layerNorm(seq.transpose(1,2)).transpose(1,2)
+ if self.pre_mask:
+ seq = self.mask_opt(seq, mask)
+ seq = self.conv(seq)
+ if self.post_mask:
+ seq = self.mask_opt(seq, mask)
+ if self.with_bn:
+ seq = self.bn(seq, mask)
+ if self.with_relu:
+ seq = self.relu(seq)
+ seq = self.dropout(seq)
+ return seq
+
+
+class AvgPool(nn.Module):
+ def __init__(self, kernal_size, pre_mask, post_mask, with_pre_norm=False, dim=None):
+ super(AvgPool, self).__init__()
+ self.avg_pool = nn.AvgPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
+ self.pre_mask = pre_mask
+ self.post_mask = post_mask
+ self.mask_opt = Mask()
+ self.with_pre_norm = with_pre_norm
+ if self.with_pre_norm:
+ self.layerNorm = nn.LayerNorm(dim)
+
+ def forward(self, seq, mask):
+ if self.with_pre_norm:
+ seq = self.layerNorm(seq.transpose(1,2)).transpose(1,2)
+ if self.pre_mask:
+ seq = self.mask_opt(seq, mask)
+ seq = self.avg_pool(seq)
+ if self.post_mask:
+ seq = self.mask_opt(seq, mask)
+ return seq
+
+
+class MaxPool(nn.Module):
+ def __init__(self, kernal_size, pre_mask, post_mask, with_pre_norm=False, dim=None):
+ super(MaxPool, self).__init__()
+ self.max_pool = nn.MaxPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
+ self.pre_mask = pre_mask
+ self.post_mask = post_mask
+ self.mask_opt = Mask()
+ self.with_pre_norm = with_pre_norm
+ if self.with_pre_norm:
+ self.layerNorm = nn.LayerNorm(dim)
+
+ def forward(self, seq, mask):
+ if self.with_pre_norm:
+ seq = self.layerNorm(seq.transpose(1,2)).transpose(1,2)
+ if self.pre_mask:
+ seq = self.mask_opt(seq, mask)
+ seq = seq.contiguous()
+ seq = self.max_pool(seq)
+ if self.post_mask:
+ seq = self.mask_opt(seq, mask)
+ return seq
+
+class Attention(nn.Module):
+ def __init__(self, num_units, num_heads, keep_prob, is_mask, with_bn=True, with_pre_norm=False):
+ super(Attention, self).__init__()
+ self.num_heads = num_heads
+ self.keep_prob = keep_prob
+
+ self.linear_q = nn.Linear(num_units, num_units, bias=False)
+ self.linear_k = nn.Linear(num_units, num_units, bias=False)
+ self.linear_v = nn.Linear(num_units, num_units, bias=False)
+
+ self.o_net = nn.Linear(num_units, num_units, bias=False)
+
+ self.with_bn = with_bn
+ self.with_pre_norm = with_pre_norm
+ if self.with_bn:
+ self.bn = BatchNorm(num_units, True, is_mask)
+ if self.with_pre_norm:
+ self.layerNorm = nn.LayerNorm(num_units)
+ self.dropout = nn.Dropout(p=1 - self.keep_prob)
+
+ def forward(self, seq, mask):
+ if self.with_pre_norm:
+ seq = self.layerNorm(seq.transpose(1, 2)).transpose(1, 2)
+ in_c = seq.size()[1]
+ seq = torch.transpose(seq, 1, 2) # (N, L, C)
+ queries = seq
+ keys = seq
+ num_heads = self.num_heads
+
+ # T_q = T_k = L
+ Q = self.linear_q(seq) # (N, T_q, C)
+ K = self.linear_k(seq) # (N, T_k, C)
+ V = self.linear_v(seq) # (N, T_k, C)
+
+ # Split and concat
+ Q_ = torch.cat(torch.split(Q, in_c // num_heads, dim=2), dim=0) # (h*N, T_q, C/h)
+ K_ = torch.cat(torch.split(K, in_c // num_heads, dim=2), dim=0) # (h*N, T_k, C/h)
+ V_ = torch.cat(torch.split(V, in_c // num_heads, dim=2), dim=0) # (h*N, T_k, C/h)
+
+ # Multiplication
+ outputs = torch.matmul(Q_, K_.transpose(1, 2)) # (h*N, T_q, T_k)
+ # Scale
+ outputs = outputs / (K_.size()[-1] ** 0.5)
+ # Key Masking
+ key_masks = mask.repeat(num_heads, 1) # (h*N, T_k)
+ key_masks = torch.unsqueeze(key_masks, 1) # (h*N, 1, T_k)
+ key_masks = key_masks.repeat(1, queries.size()[1], 1) # (h*N, T_q, T_k)
+
+ paddings = torch.ones_like(outputs) * (-INF) # extremely small value
+ outputs = torch.where(torch.eq(key_masks, 0), paddings, outputs)
+
+ query_masks = mask.repeat(num_heads, 1) # (h*N, T_q)
+ query_masks = torch.unsqueeze(query_masks, -1) # (h*N, T_q, 1)
+ query_masks = query_masks.repeat(1, 1, keys.size()[1]).float() # (h*N, T_q, T_k)
+
+ att_scores = F.softmax(outputs, dim=-1) * query_masks # (h*N, T_q, T_k)
+ att_scores = self.dropout(att_scores)
+
+ # Weighted sum
+ x_outputs = torch.matmul(att_scores, V_) # (h*N, T_q, C/h)
+ # Restore shape
+ x_outputs = torch.cat(
+ torch.split(x_outputs, x_outputs.size()[0] // num_heads, dim=0),
+ dim=2) # (N, T_q, C)
+
+ # transform for the output
+ x_outputs = self.o_net(x_outputs)
+
+ x = torch.transpose(x_outputs, 1, 2) # (N, C, L)
+ if self.with_bn:
+ x = self.bn(x, mask)
+
+ return x
+
+class Attention_old(nn.Module):
+ def __init__(self, num_units, num_heads, keep_prob, is_mask):
+ super(Attention, self).__init__()
+ self.num_heads = num_heads
+ self.keep_prob = keep_prob
+
+ self.linear_q = nn.Linear(num_units, num_units)
+ self.linear_k = nn.Linear(num_units, num_units)
+ self.linear_v = nn.Linear(num_units, num_units)
+
+ self.bn = BatchNorm(num_units, True, is_mask)
+ self.dropout = nn.Dropout(p=1 - self.keep_prob)
+
+ def forward(self, seq, mask):
+ in_c = seq.size()[1]
+ seq = torch.transpose(seq, 1, 2) # (N, L, C)
+ queries = seq
+ keys = seq
+ num_heads = self.num_heads
+
+ # T_q = T_k = L
+ Q = F.relu(self.linear_q(seq)) # (N, T_q, C)
+ K = F.relu(self.linear_k(seq)) # (N, T_k, C)
+ V = F.relu(self.linear_v(seq)) # (N, T_k, C)
+
+ # Split and concat
+ Q_ = torch.cat(torch.split(Q, in_c // num_heads, dim=2), dim=0) # (h*N, T_q, C/h)
+ K_ = torch.cat(torch.split(K, in_c // num_heads, dim=2), dim=0) # (h*N, T_k, C/h)
+ V_ = torch.cat(torch.split(V, in_c // num_heads, dim=2), dim=0) # (h*N, T_k, C/h)
+
+ # Multiplication
+ outputs = torch.matmul(Q_, K_.transpose(1, 2)) # (h*N, T_q, T_k)
+ # Scale
+ outputs = outputs / (K_.size()[-1] ** 0.5)
+ # Key Masking
+ key_masks = mask.repeat(num_heads, 1) # (h*N, T_k)
+ key_masks = torch.unsqueeze(key_masks, 1) # (h*N, 1, T_k)
+ key_masks = key_masks.repeat(1, queries.size()[1], 1) # (h*N, T_q, T_k)
+
+ paddings = torch.ones_like(outputs) * (-INF) # extremely small value
+ outputs = torch.where(torch.eq(key_masks, 0), paddings, outputs)
+
+ query_masks = mask.repeat(num_heads, 1) # (h*N, T_q)
+ query_masks = torch.unsqueeze(query_masks, -1) # (h*N, T_q, 1)
+ query_masks = query_masks.repeat(1, 1, keys.size()[1]).float() # (h*N, T_q, T_k)
+
+ att_scores = F.softmax(outputs, dim=-1) * query_masks # (h*N, T_q, T_k)
+ att_scores = self.dropout(att_scores)
+
+ # Weighted sum
+ x_outputs = torch.matmul(att_scores, V_) # (h*N, T_q, C/h)
+ # Restore shape
+ x_outputs = torch.cat(
+ torch.split(x_outputs, x_outputs.size()[0] // num_heads, dim=0),
+ dim=2) # (N, T_q, C)
+
+ x = torch.transpose(x_outputs, 1, 2) # (N, C, L)
+ x = self.bn(x, mask)
+
+ return x
+
+
+class RNN(nn.Module):
+ def __init__(self, hidden_size, output_keep_prob, with_pre_norm):
+ super(RNN, self).__init__()
+ self.hidden_size = hidden_size
+ self.bid_rnn = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
+ self.output_keep_prob = output_keep_prob
+
+ self.out_dropout = nn.Dropout(p=(1 - self.output_keep_prob))
+
+ self.with_pre_norm = with_pre_norm
+ if self.with_pre_norm:
+ self.layerNorm = nn.LayerNorm(hidden_size)
+
+ def forward(self, seq, mask):
+ # seq: (N, C, L)
+ # mask: (N, L)
+ if self.with_pre_norm:
+ seq = self.layerNorm(seq.transpose(1, 2)).transpose(1, 2)
+ max_len = seq.size()[2]
+ length = get_length(mask).cpu()
+ seq = torch.transpose(seq, 1, 2) # to (N, L, C)
+ packed_seq = nn.utils.rnn.pack_padded_sequence(seq, length, batch_first=True,
+ enforce_sorted=False)
+ outputs, _ = self.bid_rnn(packed_seq)
+ outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True,
+ total_length=max_len)[0]
+ outputs = outputs.view(-1, max_len, 2, self.hidden_size).sum(2) # (N, L, C)
+ outputs = self.out_dropout(outputs) # output dropout
+ return torch.transpose(outputs, 1, 2) # back to: (N, C, L)
+
+
+class LinearCombine(nn.Module):
+ def __init__(self, layers_num, trainable=True, input_aware=False, word_level=False):
+ super(LinearCombine, self).__init__()
+ self.input_aware = input_aware
+ self.word_level = word_level
+
+ if input_aware:
+ raise NotImplementedError("Input aware is not supported.")
+ self.w = nn.Parameter(torch.full((layers_num, 1, 1, 1), 1.0 / layers_num),
+ requires_grad=trainable)
+
+ def forward(self, seq):
+ nw = F.softmax(self.w, dim=0)
+ seq = torch.mul(seq, nw)
+ seq = torch.sum(seq, dim=0)
+ return seq
+
+def get_param_map(dim = 256, att_head = 8, with_pre_norm = True, with_bn = False, print_func = lambda x: None):
+
+ drop_prob = 0.1
+ def conv_shortcut(kernel_size):
+ return ConvBN(kernel_size, dim, dim, 1 - drop_prob, False, True, with_pre_norm=with_pre_norm, with_bn=with_bn)
+
+ def get_edge_module(index):
+ if 0 == index:
+ return conv_shortcut(1)
+ if 1 == index:
+ return Attention(dim, att_head, 1 - drop_prob, True, with_bn=with_bn, with_pre_norm=with_pre_norm)
+ if 4 == index:
+ return conv_shortcut(3)
+ if 5 == index:
+ return conv_shortcut(5)
+ if 6 == index:
+ return conv_shortcut(7)
+ if 7 == index:
+ return AvgPool(3, False, True, with_pre_norm=with_pre_norm, dim=dim)
+ if 8 == index:
+ return MaxPool(3, False, True, with_pre_norm=with_pre_norm, dim=dim)
+ if 9 == index:
+ return RNN(dim, 1 - drop_prob, with_pre_norm=with_pre_norm)
+
+ param = {}
+
+ a = get_edge_module(0)
+ print_func("conv - 1: param: %d" % (sum([x.nelement() for x in a.parameters()])))
+ param[0] = sum([x.nelement() for x in a.parameters()])
+ a = get_edge_module(4)
+ print_func("conv - 3: param: %d" % (sum([x.nelement() for x in a.parameters()])))
+ param[4] = sum([x.nelement() for x in a.parameters()])
+ a = get_edge_module(5)
+ print_func("conv - 5: param: %d" % (sum([x.nelement() for x in a.parameters()])))
+ param[5] = sum([x.nelement() for x in a.parameters()])
+ a = get_edge_module(6)
+ print_func("conv - 7: param: %d" % (sum([x.nelement() for x in a.parameters()])))
+ param[6] = sum([x.nelement() for x in a.parameters()])
+ a = get_edge_module(1)
+ print_func("attn : param: %d" % (sum([x.nelement() for x in a.parameters()])))
+ param[1] = sum([x.nelement() for x in a.parameters()])
+ a = get_edge_module(9)
+ print_func("rnn : param: %d" % (sum([x.nelement() for x in a.parameters()])))
+ param[9] = sum([x.nelement() for x in a.parameters()])
+ a = get_edge_module(7)
+ print_func("avg pool: param: %d" % (sum([x.nelement() for x in a.parameters()])))
+ param[7] = sum([x.nelement() for x in a.parameters()])
+ a = get_edge_module(8)
+ print_func("max pool: param: %d" % (sum([x.nelement() for x in a.parameters()])))
+ param[8] = sum([x.nelement() for x in a.parameters()])
+
+ return param
+
+class Zero(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.dummy = nn.Parameter(torch.zeros(1), requires_grad=False)
+
+ def forward(self, x, *args, **kwargs):
+ return self.dummy * x
+
+class Identity(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.dummy = nn.Parameter(torch.ones(1), requires_grad=False)
+
+ def forward(self, x, *args, **kwargs):
+ return self.dummy * x
+
+OPS = {
+ 'ZERO': lambda dim, dropout, act=None, norm=None, pre=True: Zero(),
+ 'IDEN': lambda dim, dropout, act=None, norm=None, pre=True: Identity(),
+ 'CONV1': lambda dim, dropout, act=None, norm=None, pre=True: ConvBN(1, dim, dim, 1 - dropout, False, True, False, act==nn.ReLU, norm=='ln'),
+ 'CONV3': lambda dim, dropout, act=None, norm=None, pre=True: ConvBN(3, dim, dim, 1 - dropout, False, True, False, act==nn.ReLU, norm=='ln'),
+ 'MAX': lambda dim, dropout, act=None, norm=None, pre=True: MaxPool(3, False, True, norm == 'ln', dim),
+ 'GRU': lambda dim, dropout, act=None, norm=None, pre=True: RNN(dim, 1 - dropout, norm=='ln')
+}
+
+PRIMITIVES = list(OPS.keys())
diff --git a/model/supernet.py b/model/supernet.py
new file mode 100644
index 0000000..6461163
--- /dev/null
+++ b/model/supernet.py
@@ -0,0 +1,365 @@
+'''
+supernet for sentence encoder
+'''
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple
+from .ops import PRIMITIVES, OPS, ConvBN
+
+class MultiheadAttention(nn.Module):
+
+ def __init__(self, embed_dim, num_heads, dropout=0., add_zero_attn=False,
+ batch_first=False, device=None, dtype=None) -> None:
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super(MultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.kdim = embed_dim
+ self.vdim = embed_dim
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.batch_first = batch_first
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+
+ I = torch.eye(embed_dim)
+ I3 = torch.cat([torch.eye(embed_dim), torch.eye(embed_dim), torch.eye(embed_dim)], dim=0)
+
+ self.in_proj_weight = nn.Parameter(I3, requires_grad=False)
+ self.register_parameter('q_proj_weight', None)
+ self.register_parameter('k_proj_weight', None)
+ self.register_parameter('v_proj_weight', None)
+
+ self.register_parameter('in_proj_bias', None)
+ self.out_proj_weight = nn.Parameter(I, requires_grad=False)
+ self.out_proj_bias = nn.Parameter(torch.zeros(embed_dim), requires_grad=False)
+
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ def __setstate__(self, state):
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
+ if '_qkv_same_embed_dim' not in state:
+ state['_qkv_same_embed_dim'] = True
+
+ super(MultiheadAttention, self).__setstate__(state)
+
+ def forward(self, query, key, value, key_padding_mask = None, attn_mask = None):
+ # B D S -> S B D
+ query, key, value = [x.permute(2, 0, 1) for x in (query, key, value)]
+
+ attn_output, _ = F.multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj_weight, self.out_proj_bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask, need_weights=False,
+ attn_mask=attn_mask)
+ return attn_output.permute(1, 2, 0)
+
+def attention_origin(query, key, value, mask, head=8):
+ b, d, s = key.size()
+ d_h = d // head
+
+ # query, key, value. B S D
+ x1 = query.reshape(b * head, d_h, s)
+ x2 = key.reshape(b * head, d_h, s)
+ x3 = value.reshape(b * head, d_h, s)
+ # => B*H D_H S
+
+ attn = x2.permute(0, 2, 1) @ x1 # s_k x s_q
+ attn = attn / np.sqrt(d_h)
+ attn = attn.masked_fill(~mask[:,None,:,None].repeat(1, head, 1, 1).reshape(b * head, s, 1), -np.inf)
+ # attn = F.dropout(attn, p=self.dropout, training=self.training)
+ attn = F.softmax(attn, dim=1)
+
+ x = x3 @ attn
+ x = x.reshape(b, d, s)
+ return x
+
+class AggLayer(nn.Module):
+ def __init__(self, dim, head, dropout=0.0, norm='ln', aug_dropouts=[0.0, 0.1]):
+ super().__init__()
+ self.dim = dim
+ self.head = head
+ self.dropout = dropout
+ self.aug_dropouts = aug_dropouts
+ self.norm = norm
+ if self.norm == 'ln':
+ self.ln = nn.LayerNorm(self.dim)
+ elif self.norm == 'bn':
+ self.bn = nn.BatchNorm1d(self.dim)
+ # self.ln = nn.LayerNorm(self.dim)
+ self.attention = MultiheadAttention(dim, head, dropout=self.aug_dropouts[0])
+
+ def forward(self, x1, x2, x3, mask, type=0):
+ if type == 0:
+ x = x1 + x2
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ else:
+ x = self.attention(x1, x2, x3, ~mask)
+ '''
+ b, d, s = x1.size()
+ d_h = d // self.head
+
+ # query, key, value. B S D
+ x1 = x1.reshape(b * self.head, d_h, s)
+ x2 = x2.reshape(b * self.head, d_h, s)
+ x3 = x3.reshape(b * self.head, d_h, s)
+ # => B*H D_H S
+
+ attn = x2.permute(0, 2, 1) @ x1 # s_k x s_q
+ attn = attn / np.sqrt(d_h)
+ attn = attn.masked_fill(~mask[:,None,:,None].repeat(1, self.head, 1, 1).reshape(b * self.head, s, 1), -np.inf)
+ # attn = F.dropout(attn, p=self.dropout, training=self.training)
+ attn = F.softmax(attn, dim=1)
+
+ x = x3 @ attn
+ x = x.reshape(b, d, s)
+ '''
+ x = F.dropout(x, p=self.aug_dropouts[1], training=self.training)
+
+ if self.norm == 'bn':
+ x = self.bn(x)
+ elif self.norm == 'ln':
+ x = self.ln(x.permute(0, 2, 1)).permute(0, 2, 1)
+
+ #x = x.permute(0, 2, 1)
+ #x = self.ln(x)
+ #x = x.permute(0, 2, 1)
+ return x
+
+class Encoder(nn.Module):
+ def __init__(self, dim, head, layer, edgeops, nodeops, arch=None, dropout=0.1, context='fc', act='nn.ReLU', norm='ln', pre=True, aug_dropouts=[0.0, 0.1]):
+ super().__init__()
+ self.arch = arch
+ self.edgeops = edgeops
+ self.nodeops = nodeops
+ self.layer = layer if arch is None else len(arch)
+ self.dim = dim
+ self.head = head
+ self.dropout = dropout
+ self.context = context
+
+ op_map = {}
+
+ def O(idx):
+ return OPS[PRIMITIVES[idx]](self.dim, self.dropout, act=eval(act), norm=norm, pre=pre)
+
+ if arch is None:
+ # include all edges according to contexts
+ for i in range(1, layer + 1):
+ for j in range(i):
+ for op in edgeops:
+ for ftype in range(2):
+ for ttype in range(4):
+ op_map[self._get_path_name(j, i, op, ftype, ttype)] = O(op)
+ # agg layer
+ op_map[f'layer-{i}'] = AggLayer(self.dim, self.head, self.dropout, norm, aug_dropouts=aug_dropouts)
+
+ else:
+ for i, a in enumerate(arch):
+ o1, node, o2, o3, n = a
+ cur_id = i + 1
+ ftype_prev = None if node <= 0 else arch[node - 1][-1]
+ ftype_i = None if i == 0 else arch[i - 1][-1]
+ if n == 0:
+ op_map[self._get_path_name(i, cur_id, o1, ftype_i, 0)] = O(o1)
+ if node >= 0:
+ op_map[self._get_path_name(node, cur_id, o2, ftype_prev, 0)] = O(o2)
+ else:
+ op_map[self._get_path_name(i, cur_id, o1, ftype_i, 1)] = O(o1)
+ op_map[self._get_path_name(node, cur_id, o2, ftype_prev, 2)] = O(o2)
+ op_map[self._get_path_name(node, cur_id, o3, ftype_prev, 3)] = O(o3)
+
+ # agg layer
+ op_map[f'layer-{cur_id}'] = AggLayer(self.dim, self.head, self.dropout, norm, aug_dropouts=aug_dropouts)
+
+ self.op_map = nn.ModuleDict(op_map)
+
+ def _get_path_name(self, fid, tid, op, ftype=None, ttype=None):
+ if self.context == 'fc':
+ if fid == 0:
+ return f'0-{tid}-{op}-{ttype}'
+ return f'{fid}-{tid}-{op}-{ftype}-{ttype}'
+ elif self.context == 'tc':
+ return f'{fid}-{tid}-{op}-{ttype}'
+ elif self.context == 'sc':
+ if fid == 0:
+ return f'0-{tid}-{op}'
+ return f'{fid}-{tid}-{op}-{ftype}'
+ return f'{fid}-{tid}-{op}'
+
+ def get_path_parameters_name(self, arch):
+ key_set = set()
+ for i, a in enumerate(arch):
+ o1, prev, o2, o3, n = a
+ f_prev = None if prev <= 0 else arch[prev-1][-1]
+ f_i = None if i == 0 else arch[i - 1][-1]
+ if n == 0:
+ key_set.add(self._get_path_name(i, i + 1, o1, f_i, 0))
+ if prev >= 0:
+ key_set.add(self._get_path_name(prev, i + 1, o2, f_prev, 0))
+ else:
+ key_set.add(self._get_path_name(i, i + 1, o1, f_i, 1))
+ key_set.add(self._get_path_name(prev, i + 1, o2, f_prev, 2))
+ key_set.add(self._get_path_name(prev, i + 1, o3, f_prev, 3))
+ return key_set
+
+ def get_parameters_by_name(self, names):
+ param_list = []
+ for name in names:
+ param_list.extend(list(self.op_map[name].parameters()))
+ return param_list
+
+ def get_path_parameters(self, arch):
+ param_list = []
+ key_list = []
+ def add_param(path):
+ if path not in key_list:
+ key_list.append(path)
+ param_list.extend(list(self.op_map[path].parameters()))
+
+ for i, a in enumerate(arch):
+ o1, prev, o2, o3, n = a
+ ftype = None if prev <= 0 else arch[prev - 1][-1]
+ ftype_prev = None if prev <= 0 else arch[prev - 1][-1]
+ ftype_i = None if i == 0 else arch[i - 1][-1]
+ if n == 0:
+ add_param(self._get_path_name(i, i + 1, o1, ftype_i, 0))
+ if prev >= 0:
+ add_param(self._get_path_name(prev, i + 1, o2, ftype_prev, 0))
+ elif n == 1:
+ add_param(self._get_path_name(i, i + 1, o1, ftype_i, 1))
+ add_param(self._get_path_name(prev, i + 1, o2, ftype_prev, 2))
+ add_param(self._get_path_name(prev, i + 1, o3, ftype_prev, 3))
+
+ return param_list
+
+ def forward(self, x, mask, arch=None):
+
+ if arch is None or self.arch is not None:
+ arch = self.arch
+
+ x_list = [x] + [torch.zeros_like(x) for _ in range(len(arch))]
+
+ for i, a in enumerate(arch):
+ cur_idx = i + 1
+ o1, prev, o2, o3, n = a
+ ftype_prev = None if prev <= 0 else arch[prev - 1][-1]
+ ftype_i = None if i == 0 else arch[i - 1][-1]
+ if n == 0:
+ inp1 = x_list[i]
+ # inp1 = F.layer_norm(inp1, self.dim)
+ # inp1 = F.dropout(inp1, p=self.dropout, training=self.training)
+ feat1 = self.op_map[self._get_path_name(i, i + 1, o1, ftype_i, 0)](inp1, mask=mask)
+ if prev >= 0:
+ inp2 = x_list[prev]
+ # inp2 = F.layer_norm(inp2, self.dim)
+ # inp2 = F.dropout(inp2, p=self.dropout, training=self.training)
+ feat2 = self.op_map[self._get_path_name(prev, i + 1, o2, ftype_prev, 0)](inp2, mask=mask)
+ else:
+ feat2 = 0
+ feat3 = 0
+
+ else:
+ inp1 = x_list[i]
+ # inp1 = F.layer_norm(inp1, self.dim)
+ # inp1 = F.dropout(inp1, p=self.dropout, training=self.training)
+ inp2 = x_list[prev]
+ # inp2 = F.layer_norm(inp2, self.dim)
+ # inp2 = F.dropout(inp2, p=self.dropout, training=self.training)
+ inp3 = x_list[prev]
+ # inp3 = F.layer_norm(inp3, self.dim)
+ # inp3 = F.dropout(inp3, p=self.dropout, training=self.training)
+ feat1 = self.op_map[self._get_path_name(i, i + 1, o1, ftype_i, 1)](inp1, mask=mask)
+ feat2 = self.op_map[self._get_path_name(prev, i + 1, o2, ftype_prev, 2)](inp2, mask=mask)
+ feat3 = self.op_map[self._get_path_name(prev, i + 1, o3, ftype_prev, 3)](inp3, mask=mask)
+
+ x = self.op_map[f'layer-{cur_idx}'](feat1, feat2, feat3, mask, n)
+
+ x_list[cur_idx] = x
+
+ out = x_list[-1]
+ # out = F.layer_norm(out, self.dim)
+ # out = F.dropout(out, p=self.dropout, training=self.training)
+ return out
+
+class TextClassifier(nn.Module):
+ def __init__(self, embeddings, dim, head, nclass, layer=None, edgeops=None, nodeops=None, arch=None, dropout=0.0, context='fc', act='nn.ReLU', norm='ln', pre=True, freeze=True, pad_idx=0, aug_dropouts=[0.0, 0.1]) -> None:
+ super().__init__()
+ self.emb = nn.Embedding.from_pretrained(embeddings, padding_idx=pad_idx, freeze=freeze)
+ # self.stem = nn.Linear(embeddings.size(1), dim)
+ self.stem = ConvBN(1, embeddings.size(1), dim, 1 - dropout, False, True, False, act=='nn.ReLU', norm=='ln')
+ self.core = Encoder(dim, head, layer, edgeops, nodeops, arch=arch, dropout=dropout, context=context, act=act, norm=norm, pre=pre, aug_dropouts=aug_dropouts)
+ self.classifier = nn.Linear(dim, nclass)
+ nn.init.normal_(self.classifier.weight, 0, 0.02)
+ nn.init.constant_(self.classifier.bias, 0)
+ self.dropout = dropout
+
+ def forward(self, x, mask, arch=None, sliding=False, window_size=64, stride=32):
+ if not sliding:
+ x = self.emb(x)
+ x = x.permute(0, 2, 1)
+ # x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.stem(x, mask=mask)
+ # x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.core(x, mask, arch)
+ # x = F.dropout(x, p=self.dropout, training=self.training)
+ # max-pooling
+ x = x.masked_fill(~mask[:,None,:], -np.inf)
+ x = x.max(dim=2)[0]
+ x = self.classifier(x)
+ else:
+ x = self.emb(x)
+ x = x.permute(0, 2, 1)
+ b, d, s = x.size()
+ maxs = []
+ begin_idx = 0
+ while begin_idx < s:
+ x_sub = x[:,:,begin_idx:begin_idx + window_size]
+ masks = mask[:, begin_idx: begin_idx + window_size]
+ lengths = masks.sum(dim=1)
+ length_mask = lengths > 0
+ x_sub = x_sub[length_mask]
+ x_sub = self.stem(x_sub, mask=masks[length_mask])
+ x_sub = self.core(x_sub, masks[length_mask], arch)
+ x_sub = x_sub.masked_fill(~masks[length_mask,None,:], -np.inf)
+ x_sub = x_sub.max(dim=2)[0]
+ x_ful = torch.zeros(b, x_sub.size(1)).to(x_sub)
+ x_ful[length_mask] = x_sub
+ x_ful[~length_mask] = -np.inf
+ maxs.append(x_ful[:,:,None])
+ begin_idx += stride
+ x = torch.cat(maxs, dim=2).max(dim=2)[0]
+ x = self.classifier(x)
+ return x
+
+ def get_path_parameters(self, arch):
+ return self.core.get_path_parameters(arch)
+
+ def get_path_parameters_name(self, arch):
+ return self.core.get_path_parameters_name(arch)
+
+ def get_parameters_by_name(self, names):
+ return self.core.get_parameters_by_name(names)
+
+ @classmethod
+ def add_args(cls, parser):
+ parser.add_argument('--dim', type=int, default=64, help='dimension of model')
+ parser.add_argument('--head', type=int, default=4, help='attn head number of model')
+ parser.add_argument('--dropout', type=float, default=0.7, help='dropout of model')
+ parser.add_argument('--context', type=str, default='fc', choices=['fc', 'sc', 'tc', 'nc'], help='context constraint')
+ return parser
+
+ @classmethod
+ def build_from_args(cls, embeddings, nclass, args, edgeops=None, nodeops=None, layer=None, arch=None):
+ return cls(embeddings, args.dim, args.head, nclass, layer, edgeops, nodeops, arch, args.dropout, args.context)
diff --git a/model/utils.py b/model/utils.py
new file mode 100644
index 0000000..622ad75
--- /dev/null
+++ b/model/utils.py
@@ -0,0 +1,140 @@
+from copy import deepcopy
+import random
+
+def remove(source, element):
+ while element in source:
+ source.remove(element)
+ return source
+
+def check_valid(arch, primitives):
+
+ zero = primitives.index('ZERO')
+
+ for i, a in enumerate(arch):
+ o1, prev, o2, o3, n = a
+ if prev == -1:
+ return False
+ if n == 1 and (o2 == zero or o3 == zero):
+ return False
+ if o1 == zero:
+ return False
+ if (o2 == -1 or o3 == -1) and n == 1:
+ return False
+
+ return True
+
+def sample_valid_arch(node_num, edgeop, nodeop, primitives):
+ arch = []
+ zero = primitives.index('ZERO')
+ for i in range(node_num):
+ edges = deepcopy(edgeop)
+ idx_pool = list(range(i + 1))
+ n = random.choice(nodeop)
+ prev = random.choice(idx_pool)
+ edge_no_zero = deepcopy(edgeop)
+ remove(edge_no_zero, zero)
+
+ o1 = random.choice(edge_no_zero)
+ if n == 1:
+ o2 = random.choice(edge_no_zero)
+ o3 = random.choice(edge_no_zero)
+ else:
+ o2 = random.choice(edgeop)
+ o3 = -1
+ arch.append([o1, prev, o2, o3, n])
+ if check_valid(arch, primitives):
+ return arch
+ return sample_valid_arch(node_num, edgeop, nodeop, primitives)
+
+def sample_valid_archs(node_num, edgeop, nodeop, number, primitives):
+ total_arch = []
+ while len(total_arch) < number:
+ arch = sample_valid_arch(node_num, edgeop, nodeop, primitives)
+ if arch not in total_arch:
+ total_arch.append(arch)
+ return total_arch
+
+def get_edge_node_op(primitives, space=0):
+ # no att
+ if space == 0:
+ return list(range(len(primitives))), [0]
+ # att
+ elif space == 1:
+ return list(range(len(primitives))), [0, 1]
+
+def reduce(arch):
+ new_arch = deepcopy(arch)
+ for i in range(len(arch)):
+ if new_arch[i][-1] == 0:
+ new_arch[i][-2] = -1
+ return new_arch
+
+def mutate_arch(arch, edgeops, nodeops, primitives ,ratio=0.05):
+ new_arch = deepcopy(arch)
+ zero = primitives.index('ZERO')
+ iden = primitives.index('IDEN')
+ edge_no_zero = deepcopy(edgeops)
+ remove(edge_no_zero, zero)
+
+ for i in range(len(arch)):
+ o1, prev, o2, o3, n = new_arch[i]
+ mutate_all = False
+
+ # mutate node
+ if random.random() < ratio and len(nodeops) > 1:
+ n = 1 - n
+ mutate_all = True
+
+ # mutate prev
+ if mutate_all or random.random() < ratio:
+ prev = random.choice(list(range(i + 1)))
+
+ # mutate o1
+ if mutate_all or random.random() < ratio:
+ o1 = random.choice([x for x in edgeops if x != zero])
+
+ # mutate o2
+ if mutate_all or random.random() < ratio:
+ o2 = random.choice(edgeops if n == 0 else edge_no_zero)
+
+ # mutate o3
+ if mutate_all or random.random() < ratio:
+ o3 = -1 if n == 0 else random.choice(edge_no_zero)
+
+ new_arch[i] = [o1, prev, o2, o3, n]
+
+ if new_arch == reduce(arch) or not check_valid(new_arch, primitives):
+ return mutate_arch(arch, edgeops, nodeops, primitives, ratio)
+ return new_arch
+
+if __name__ == '__main__':
+ import os
+ import importlib
+ import argparse
+ from tqdm import tqdm
+
+ import torch
+
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--space', type=int, default=1, help='search space')
+ parser.add_argument('--number', type=int, help='number of archs to spawn')
+ parser.add_argument('--unique', action='store_true', help='whether spawned arch can overlap')
+ parser.add_argument('--repeat', type=int, default=1, help='repeat how many time of sampled arch')
+ parser.add_argument('--path', type=str, help='path to save spawned archs')
+ parser.add_argument('--node', type=int, default=2, help='number of layers')
+ parser.add_argument('--primitive', type=str, default='model.sent.arch', help='primitive archs pool')
+
+ args = parser.parse_args()
+ primitives = getattr(importlib.import_module(args.primitive), 'PRIMITIVES')
+
+ edgeop, nodeop = get_edge_node_op(primitives, args.space)
+ if args.unique:
+ arch = sample_valid_archs(args.node, edgeop, nodeop, args.number, primitives)
+ else:
+ arch = [sample_valid_arch(args.node, edgeop, nodeop, primitives) for _ in tqdm(range(args.number))]
+
+ arch *= args.repeat
+
+ # make dirs
+ os.makedirs(os.path.dirname(args.path), exist_ok=True)
+ torch.save(arch, args.path)
diff --git a/resources/workflow.png b/resources/workflow.png
new file mode 100644
index 0000000..3432f46
Binary files /dev/null and b/resources/workflow.png differ
diff --git a/task/__init__.py b/task/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/task/dataset/__init__.py b/task/dataset/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/task/dataset/sst5.py b/task/dataset/sst5.py
new file mode 100644
index 0000000..df05bdc
--- /dev/null
+++ b/task/dataset/sst5.py
@@ -0,0 +1,379 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import logging
+import os
+import pickle
+from collections import Counter
+
+import numpy as np
+import torch
+from torch.utils import data
+
+logger = logging.getLogger()
+
+
+class PTBTree:
+ WORD_TO_WORD_MAPPING = {
+ "{": "-LCB-",
+ "}": "-RCB-"
+ }
+
+ def __init__(self):
+ self.subtrees = []
+ self.word = None
+ self.label = ""
+ self.parent = None
+ self.span = (-1, -1)
+ self.word_vector = None # HOS, store dx1 RNN word vector
+ self.prediction = None # HOS, store Kx1 prediction vector
+
+ def is_leaf(self):
+ return len(self.subtrees) == 0
+
+ def set_by_text(self, text, pos=0, left=0):
+ depth = 0
+ right = left
+ for i in range(pos + 1, len(text)):
+ char = text[i]
+ # update the depth
+ if char == "(":
+ depth += 1
+ if depth == 1:
+ subtree = PTBTree()
+ subtree.parent = self
+ subtree.set_by_text(text, i, right)
+ right = subtree.span[1]
+ self.span = (left, right)
+ self.subtrees.append(subtree)
+ elif char == ")":
+ depth -= 1
+ if len(self.subtrees) == 0:
+ pos = i
+ for j in range(i, 0, -1):
+ if text[j] == " ":
+ pos = j
+ break
+ self.word = text[pos + 1:i]
+ self.span = (left, left + 1)
+
+ # we've reached the end of the category that is the root of this subtree
+ if depth == 0 and char == " " and self.label == "":
+ self.label = text[pos + 1:i]
+ # we've reached the end of the scope for this bracket
+ if depth < 0:
+ break
+
+ # Fix some issues with variation in output, and one error in the treebank
+ # for a word with a punctuation POS
+ self.standardise_node()
+
+ def standardise_node(self):
+ if self.word in self.WORD_TO_WORD_MAPPING:
+ self.word = self.WORD_TO_WORD_MAPPING[self.word]
+
+ def __repr__(self, single_line=True, depth=0):
+ ans = ""
+ if not single_line and depth > 0:
+ ans = "\n" + depth * "\t"
+ ans += "(" + self.label
+ if self.word is not None:
+ ans += " " + self.word
+ for subtree in self.subtrees:
+ if single_line:
+ ans += " "
+ ans += subtree.__repr__(single_line, depth + 1)
+ ans += ")"
+ return ans
+
+
+def read_tree(source):
+ cur_text = []
+ depth = 0
+ while True:
+ line = source.readline()
+ # Check if we are out of input
+ if line == "":
+ return None
+ # strip whitespace and only use if this contains something
+ line = line.strip()
+ if line == "":
+ continue
+ cur_text.append(line)
+ # Update depth
+ for char in line:
+ if char == "(":
+ depth += 1
+ elif char == ")":
+ depth -= 1
+ # At depth 0 we have a complete tree
+ if depth == 0:
+ tree = PTBTree()
+ tree.set_by_text(" ".join(cur_text))
+ return tree
+ return None
+
+
+def read_trees(source, max_sents=-1):
+ with open(source, encoding='utf8') as fp:
+ trees = []
+ while True:
+ tree = read_tree(fp)
+ if tree is None:
+ break
+ trees.append(tree)
+ if len(trees) >= max_sents > 0:
+ break
+ return trees
+
+
+class SSTDataset(data.Dataset):
+ def __init__(self, sents, mask, labels):
+ self.sents = sents
+ self.labels = labels
+ self.mask = mask
+
+ def __getitem__(self, index):
+ return self.sents[index], self.labels[index], self.mask[index]
+
+ def __len__(self):
+ return len(self.sents)
+
+
+def sst_get_id_input(content, word_id_dict, max_input_length):
+ words = content.split(" ")
+ sentence = [word_id_dict[""]] * max_input_length
+ mask = [0] * max_input_length
+ unknown = word_id_dict[""]
+ for i, word in enumerate(words[:max_input_length]):
+ sentence[i] = word_id_dict.get(word, unknown)
+ mask[i] = 1
+ return sentence, mask
+
+
+def sst_get_phrases(trees, sample_ratio=1.0, is_binary=False, only_sentence=False):
+ all_phrases = []
+ for tree in trees:
+ if only_sentence:
+ sentence = get_sentence_by_tree(tree)
+ label = int(tree.label)
+ pair = (sentence, label)
+ all_phrases.append(pair)
+ else:
+ phrases = get_phrases_by_tree(tree)
+ sentence = get_sentence_by_tree(tree)
+ pair = (sentence, int(tree.label))
+ all_phrases.append(pair)
+ all_phrases += phrases
+ if sample_ratio < 1.:
+ np.random.shuffle(all_phrases)
+ result_phrases = []
+ for pair in all_phrases:
+ if is_binary:
+ phrase, label = pair
+ if label <= 1:
+ pair = (phrase, 0)
+ elif label >= 3:
+ pair = (phrase, 1)
+ else:
+ continue
+ if sample_ratio == 1.:
+ result_phrases.append(pair)
+ else:
+ rand_portion = np.random.random()
+ if rand_portion < sample_ratio:
+ result_phrases.append(pair)
+ return result_phrases
+
+
+def get_phrases_by_tree(tree):
+ phrases = []
+ if tree is None:
+ return phrases
+ if tree.is_leaf():
+ pair = (tree.word, int(tree.label))
+ phrases.append(pair)
+ return phrases
+ left_child_phrases = get_phrases_by_tree(tree.subtrees[0])
+ right_child_phrases = get_phrases_by_tree(tree.subtrees[1])
+ phrases.extend(left_child_phrases)
+ phrases.extend(right_child_phrases)
+ sentence = get_sentence_by_tree(tree)
+ pair = (sentence, int(tree.label))
+ phrases.append(pair)
+ return phrases
+
+
+def get_sentence_by_tree(tree):
+ if tree is None:
+ return ""
+ if tree.is_leaf():
+ return tree.word
+ left_sentence = get_sentence_by_tree(tree.subtrees[0])
+ right_sentence = get_sentence_by_tree(tree.subtrees[1])
+ sentence = left_sentence + " " + right_sentence
+ return sentence.strip()
+
+
+def get_word_id_dict(word_num_dict, word_id_dict, min_count):
+ z = [k for k in sorted(word_num_dict.keys())]
+ for word in z:
+ count = word_num_dict[word]
+ if count >= min_count:
+ index = len(word_id_dict)
+ if word not in word_id_dict:
+ word_id_dict[word] = index
+ return word_id_dict
+
+
+def load_word_num_dict(phrases, word_num_dict):
+ for sentence, _ in phrases:
+ words = sentence.split(" ")
+ for cur_word in words:
+ word = cur_word.strip()
+ word_num_dict[word] += 1
+ return word_num_dict
+
+
+def init_trainable_embedding(embedding_path, word_id_dict, embed_dim=300):
+ word_embed_model = load_glove_model(embedding_path, embed_dim)
+ assert word_embed_model["pool"].shape[1] == embed_dim
+ embedding = np.random.random([len(word_id_dict), embed_dim]).astype(np.float32) / 2.0 - 0.25
+ embedding[0] = np.zeros(embed_dim) # PAD
+ embedding[1] = (np.random.rand(embed_dim) - 0.5) / 2 # UNK
+ for word in sorted(word_id_dict.keys()):
+ idx = word_id_dict[word]
+ if idx == 0 or idx == 1:
+ continue
+ if word in word_embed_model["mapping"]:
+ embedding[idx] = word_embed_model["pool"][word_embed_model["mapping"][word]]
+ else:
+ embedding[idx] = np.random.rand(embed_dim) / 2.0 - 0.25
+ return embedding
+
+
+def sst_get_trainable_data(phrases, word_id_dict, max_input_length):
+ texts, labels, mask = [], [], []
+
+ for phrase, label in phrases:
+ if not phrase.split():
+ continue
+ phrase_split, mask_split = sst_get_id_input(phrase, word_id_dict, max_input_length)
+ texts.append(phrase_split)
+ labels.append(int(label))
+ mask.append(mask_split) # field_input is mask
+ labels = np.array(labels, dtype=np.long)
+ texts = np.reshape(texts, [-1, max_input_length]).astype(np.long)
+ mask = np.reshape(mask, [-1, max_input_length]).astype(bool)
+
+ return SSTDataset(texts, mask, labels)
+
+
+def load_glove_model(filename, embed_dim):
+ if os.path.exists(filename + ".cache"):
+ logger.info("Found cache. Loading...")
+ with open(filename + ".cache", "rb") as fp:
+ return pickle.load(fp)
+ embedding = {"mapping": dict(), "pool": []}
+ with open(filename) as f:
+ for i, line in enumerate(f):
+ line = line.rstrip("\n")
+ vocab_word, *vec = line.rsplit(" ", maxsplit=embed_dim)
+ assert len(vec) == 300, "Unexpected line: '%s'" % line
+ embedding["pool"].append(np.array(list(map(float, vec)), dtype=np.float32))
+ embedding["mapping"][vocab_word] = i
+ embedding["pool"] = np.stack(embedding["pool"])
+ with open(filename + ".cache", "wb") as fp:
+ pickle.dump(embedding, fp)
+ return embedding
+
+
+def read_data_sst(data_path, max_input_length=64, min_count=1, train_with_valid=False,
+ train_ratio=1., valid_ratio=1., is_binary=False, only_sentence=False):
+ word_id_dict = dict()
+ word_num_dict = Counter()
+
+ sst_path = os.path.join(data_path, "sst")
+ logger.info("Reading SST data...")
+ train_file_name = os.path.join(sst_path, "trees", "train.txt")
+ valid_file_name = os.path.join(sst_path, "trees", "dev.txt")
+ test_file_name = os.path.join(sst_path, "trees", "test.txt")
+ train_trees = read_trees(train_file_name)
+ train_phrases = sst_get_phrases(train_trees, train_ratio, is_binary, only_sentence)
+ logger.info("Finish load train phrases.")
+ valid_trees = read_trees(valid_file_name)
+ valid_phrases = sst_get_phrases(valid_trees, valid_ratio, is_binary, only_sentence)
+ logger.info("Finish load valid phrases.")
+ if train_with_valid:
+ train_phrases += valid_phrases
+ test_trees = read_trees(test_file_name)
+ test_phrases = sst_get_phrases(test_trees, valid_ratio, is_binary, only_sentence=True)
+ logger.info("Finish load test phrases.")
+
+ # get word_id_dict
+ word_id_dict[""] = 0
+ word_id_dict[""] = 1
+ load_word_num_dict(train_phrases, word_num_dict)
+ logger.info("Finish load train words: %d.", len(word_num_dict))
+ load_word_num_dict(valid_phrases, word_num_dict)
+ load_word_num_dict(test_phrases, word_num_dict)
+ logger.info("Finish load valid+test words: %d.", len(word_num_dict))
+ word_id_dict = get_word_id_dict(word_num_dict, word_id_dict, min_count)
+ logger.info("After trim vocab length: %d.", len(word_id_dict))
+
+ logger.info("Loading embedding...")
+ embedding = init_trainable_embedding(os.path.join(data_path, "glove.840B.300d.txt"), word_id_dict)
+ logger.info("Finish initialize word embedding.")
+
+ dataset_train = sst_get_trainable_data(train_phrases, word_id_dict, max_input_length)
+ logger.info("Loaded %d training samples.", len(dataset_train))
+ dataset_valid = sst_get_trainable_data(valid_phrases, word_id_dict, max_input_length)
+ logger.info("Loaded %d validation samples.", len(dataset_valid))
+ dataset_test = sst_get_trainable_data(test_phrases, word_id_dict, max_input_length)
+ logger.info("Loaded %d test samples.", len(dataset_test))
+
+ return dataset_train, dataset_valid, dataset_test, torch.from_numpy(embedding)
+
+def convert_to_file(data_path, train_with_valid=False, train_ratio=1., valid_ratio=1., is_binary=False, only_sentence=False, folder=None):
+ if folder:
+ os.makedirs(folder, exist_ok=True)
+
+ sst_path = os.path.join(data_path, "sst")
+ logger.info("Reading SST data...")
+ train_file_name = os.path.join(sst_path, "trees", "train.txt")
+ valid_file_name = os.path.join(sst_path, "trees", "dev.txt")
+ test_file_name = os.path.join(sst_path, "trees", "test.txt")
+ train_trees = read_trees(train_file_name)
+ train_phrases = sst_get_phrases(train_trees, train_ratio, is_binary, only_sentence)
+ logger.info("Finish load train phrases.")
+ valid_trees = read_trees(valid_file_name)
+ valid_phrases = sst_get_phrases(valid_trees, valid_ratio, is_binary, only_sentence)
+ logger.info("Finish load valid phrases.")
+ if train_with_valid:
+ train_phrases += valid_phrases
+ test_trees = read_trees(test_file_name)
+ test_phrases = sst_get_phrases(test_trees, valid_ratio, is_binary, only_sentence=True)
+ logger.info("Finish load test phrases.")
+
+ def write_to_file(phrases, path):
+ with open(path, 'w') as f:
+ for sent, label in phrases:
+ if not sent.split(): continue
+ f.write('__label__{} {}\n'.format(label, sent))
+
+ write_to_file(train_phrases, os.path.join(folder or sst_path, 'fast-train.txt'))
+ if not train_with_valid:
+ write_to_file(valid_phrases, os.path.join(folder or sst_path, 'fast-val.txt'))
+ write_to_file(test_phrases, os.path.join(folder or sst_path, 'fast-test.txt'))
+
+if __name__ == '__main__':
+ # generate intervals
+ if os.path.exists('data/sst/train.pt'):
+ print('detect the dataset is already tranformed, please check path: data/sst/train.pt')
+ exit(0)
+ train, val, test, embedding = read_data_sst('data')
+ import torch
+ torch.save(train, 'data/sst/train.pt')
+ torch.save(val, 'data/sst/valid.pt')
+ torch.save(test, 'data/sst/test.pt')
+ torch.save(embedding, 'data/sst/embedding.pt')
diff --git a/task/text_classification/__init__.py b/task/text_classification/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/task/text_classification/retrain.py b/task/text_classification/retrain.py
new file mode 100644
index 0000000..45dbcce
--- /dev/null
+++ b/task/text_classification/retrain.py
@@ -0,0 +1,231 @@
+'''
+retrain models on sst5
+'''
+
+import numpy as np
+import os
+import yaml
+import argparse
+from tqdm import tqdm
+
+import torch
+import torch.optim
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+from model.supernet import TextClassifier
+from task.utils import logger, set_seed
+from .search_supernet import eval_model
+
+LOGGER = logger.get_logger('sst5-retrain')
+
+def retrain_arch_mini_batch(dim, head, arch, context, lr, wd, dr, repeat, device, path, epoch, eval_iter=-1, patience=-1, gradient_clip=1.0, dropout_attn=0.0, dropout_aggr=0.1):
+ log = logger.get_logger('sst5-retrain')
+ if path:
+ log.set_path(path.replace('.pt', '.log'))
+ train_dataset = torch.load('data/sst/train.pt')
+ valid_dataset = torch.load('data/sst/valid.pt')
+ test_dataset = torch.load('data/sst/test.pt')
+ embedding = torch.load('data/sst/embedding.pt')
+ train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
+ valid_dataloader = DataLoader(valid_dataset, batch_size=512, shuffle=False)
+ test_dataloader = DataLoader(test_dataset, batch_size=512, shuffle=False)
+ vals, tests = [], []
+ for r in range(repeat):
+ model = TextClassifier(embedding, dim, head, 5, arch=arch, dropout=dr, context=context, aug_dropouts=[dropout_attn, dropout_aggr]).to(device)
+ val_acc = []
+ test_acc = []
+ losses = []
+ opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
+ iter = 0
+ for e in range(epoch):
+ with tqdm(train_dataloader) as t:
+ for batch in t:
+ iter += 1
+ model.train()
+ opt.zero_grad()
+ logit = model(batch[0].to(device), batch[2].to(device))
+ loss = F.cross_entropy(logit, batch[1].to(device).long())
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clip)
+ opt.step()
+ losses = (losses + [loss.item()])[-100:]
+ t.set_postfix(loss='%.4f' % (np.mean(losses)), val=0.0 if val_acc == [] else '%.4f' % max(val_acc), test=0.0 if test_acc == [] else '%.4f' % max(test_acc))
+
+ if eval_iter >= 2 and iter % eval_iter == 0:
+ model.eval()
+ val = eval_model(model, None, valid_dataloader, device)
+ test = eval_model(model, None, test_dataloader, device)
+ val_acc.append(val)
+ test_acc.append(test)
+ # judge whether patience run out
+ if patience > 0:
+ idx_max = np.argmax(val_acc)
+ log.info('epoch', e, 'iter', iter, 'val / test', val, test, 'patience', len(val_acc) - idx_max)
+ if len(val_acc) - idx_max > patience:
+ break
+ else:
+ log.info('epoch', e, 'iter', iter, 'val / test', val, test)
+
+
+ if eval_iter <= 0:
+ model.eval()
+ val = eval_model(model, None, valid_dataloader, device)
+ test = eval_model(model, None, test_dataloader, device)
+ val_acc.append(val)
+ test_acc.append(test)
+ # judge whether patience run out
+ if patience > 0:
+ idx_max = np.argmax(val_acc)
+ log.info('epoch', e, 'val / test', val, test, 'patience', len(val_acc) - idx_max)
+ if len(val_acc) - idx_max > patience:
+ break
+ else:
+ log.info('epoch', e, 'val / test', val, test)
+
+ # record this repeat
+ vals.append(val_acc)
+ tests.append(test_acc)
+
+ info = {
+ 'mode': 'mini',
+ 'arch': arch,
+ 'val': vals,
+ 'test': tests,
+ 'config': {
+ 'dim': dim,
+ 'head': head,
+ 'context': context,
+ 'lr': lr,
+ 'wd': wd,
+ 'dr': dr,
+ 'eval_iter': eval_iter,
+ 'patience': patience,
+ 'repeat': repeat,
+ 'device': device
+ }
+ }
+
+ if path:
+ torch.save(info, path)
+
+ return info
+
+def retrain_arch(dim, head, arch, context, lr, wd, dr, repeat, device, path, epoch, eval_iter=-1, patience=-1, gradient_clip=1.0, dropout_attn=0.0, dropout_aggr=0.1):
+ train_dataset = torch.load('data/sst/train.pt')
+ valid_dataset = torch.load('data/sst/valid.pt')
+ test_dataset = torch.load('data/sst/test.pt')
+ embedding = torch.load('data/sst/embedding.pt')
+ train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
+ valid_dataloader = DataLoader(valid_dataset, batch_size=512, shuffle=False)
+ test_dataloader = DataLoader(test_dataset, batch_size=512, shuffle=False)
+ vals, tests = [], []
+ for r in range(repeat):
+ model = TextClassifier(embedding, dim, head, 5, arch=arch, dropout=dr, context=context, aug_dropouts=[dropout_attn, dropout_aggr]).to(device)
+ val_acc = []
+ test_acc = []
+ losses = []
+ opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
+ for e in range(epoch):
+ model.train()
+ with tqdm(train_dataloader) as t:
+ for batch in t:
+ opt.zero_grad()
+ logit = model(batch[0].to(device), batch[2].to(device))
+ loss = F.cross_entropy(logit, batch[1].to(device).long())
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clip)
+ opt.step()
+ losses = (losses + [loss.item()])[-100:]
+ t.set_postfix(loss='%.4f' % (np.mean(losses)))
+
+ model.eval()
+ val = eval_model(model, None, valid_dataloader, device)
+ test = eval_model(model, None, test_dataloader, device)
+ val_acc.append(val)
+ test_acc.append(test)
+ print('epoch', e, 'val:', val, 'test:', test)
+
+ # record this repeat
+ vals.append(val_acc)
+ tests.append(test_acc)
+
+ info = {
+ 'arch': arch,
+ 'val': vals,
+ 'test': tests,
+ 'config': {
+ 'dim': dim,
+ 'head': head,
+ 'context': context,
+ 'lr': lr,
+ 'wd': wd,
+ 'dr': dr,
+ 'repeat': repeat,
+ 'device': device
+ }
+ }
+
+ if path:
+ torch.save(info, path)
+
+ return info
+
+if __name__ == '__main__':
+
+ logger.LEVEL = logger.DEBUG
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--dim', type=int, default=256, help='dimension')
+ parser.add_argument('--head', type=int, default=8, help='default attn head number')
+ parser.add_argument('--arch', type=str, help='architectures (str)')
+ parser.add_argument('--context', choices=['sc', 'fc', 'tc', 'nc'], default='fc')
+
+ parser.add_argument('--epoch', type=int, help='total num of epoch to train the supernet', default=10)
+ parser.add_argument('--lr', type=float, help='learning rate', default=5e-4)
+ parser.add_argument('--weight_decay', type=float, help='weight decay', default=0.0)
+ parser.add_argument('--dropout', type=float, help='dropout', default=0.1)
+ parser.add_argument('--repeat', type=int, default=5)
+
+ parser.add_argument('--path', type=str, help='path to save logs & models', default='logs')
+
+ parser.add_argument('--seed', type=int, help='random seed', default=2021)
+ parser.add_argument('--no_mini', action='store_true', help='wether to disable mini-batch mode')
+ parser.add_argument('--patience', type=int, help='patience', default=-1)
+ parser.add_argument('--eval_iter', type=int, help='eval iteration', default=100)
+ parser.add_argument('--device', type=int, nargs='+', default=[0], help='the main progress')
+ parser.add_argument('--gradient_clip', type=float, default=1.0, help='gradient clip')
+ parser.add_argument('--dropout_attn', type=float, default=0.1, help='dropout applied to attention mask')
+ parser.add_argument('--dropout_aggr', type=float, default=0.1, help='dropout applied after attention aggregation')
+
+ args = parser.parse_args()
+ args.mini = not args.no_mini
+
+ set_seed(args.seed)
+
+ if args.path is not None:
+ os.makedirs(args.path, exist_ok=True)
+ LOGGER.set_path(os.path.join(args.path, 'log.txt'))
+
+ LOGGER.info('hyper parameters')
+ for k,v in args.__dict__.items():
+ LOGGER.info('{} - {}'.format(k, v))
+ LOGGER.info('end of hyper parameters')
+
+ input('hyperparameter confirmed')
+ func = retrain_arch_mini_batch if args.mini else retrain_arch
+
+ device = 'cuda:{}'.format(args.device[0])
+ if args.path is None: path_model = None
+ else:
+ path_model = os.path.join(
+ args.path,
+ f'{args.dim}_{args.head}_{args.epoch}_{args.lr}_{args.weight_decay}_{args.dropout}_{args.gradient_clip}_{args.dropout_attn}_{args.dropout_aggr}.pt'
+ )
+
+ result = func(
+ args.dim, args.head, eval(args.arch), args.context, args.lr, args.weight_decay, args.dropout, args.repeat, device,
+ path_model, args.epoch, args.eval_iter, args.patience, args.gradient_clip, args.dropout_attn, args.dropout_aggr,
+ )
+
+ print(result)
+ print('final test', sum(result['test']) / args.repeat)
diff --git a/task/text_classification/search_supernet.py b/task/text_classification/search_supernet.py
new file mode 100644
index 0000000..3c43df0
--- /dev/null
+++ b/task/text_classification/search_supernet.py
@@ -0,0 +1,144 @@
+'''
+search for best models given supernet using evolution algorithms
+'''
+
+import os
+
+import numpy as np
+from tqdm import tqdm
+from queue import Queue
+import threading
+
+import torch
+from torch.utils.data import DataLoader
+
+from model.utils import get_edge_node_op, sample_valid_archs, mutate_arch
+from model.supernet import PRIMITIVES
+from ..utils import run_exps
+
+def wrap_queue(lists):
+ queue = Queue()
+ for ele in lists: queue.put(ele)
+ queue.put(False)
+ return queue
+
+def eval_model(model, arch, loaders, device):
+ model.eval()
+ gt = []
+ pred = []
+ with torch.no_grad():
+ for batch in loaders:
+ logit = model(batch[0].to(device), batch[2].to(device), arch)
+ pred.extend(logit.argmax(1).detach().cpu().tolist())
+ gt.extend(batch[1].cpu().tolist())
+ return (np.array(gt) == np.array(pred)).mean()
+
+
+def eval_param_archs(model_path, archs, device, mask='val', output=None):
+ if mask == 'val':
+ loader = torch.load('data/sst/valid.pt')
+ else:
+ loader = torch.load('data/sst/test.pt')
+ loader = DataLoader(loader, batch_size=512, shuffle=False)
+
+ perf = []
+ # print('load model')
+ model = torch.load(model_path, map_location=device)
+
+ # print('testing valid scores')
+ for a in archs:
+ score = eval_model(model, a, loader, device)
+ perf.append([a, score])
+
+ return {
+ 'perf': perf,
+ 'pid': os.getpid(),
+ }
+
+if __name__ == '__main__':
+ import torch.multiprocessing as mp
+ mp.set_start_method('spawn')
+
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model_path', type=str)
+ parser.add_argument('--init_pop', type=int, default=500)
+ parser.add_argument('--mutate_number', type=int, default=100)
+ parser.add_argument('--mutate_epoch', type=int, default=5)
+ parser.add_argument('--devices', type=int, nargs='+')
+ parser.add_argument('--output', type=str)
+ parser.add_argument('--parallel', action='store_true')
+ parser.add_argument('--layer', type=int, default=24)
+ parser.add_argument('--space', type=int, default=6)
+ parser.add_argument('--chunk', type=int, default=25)
+
+ args = parser.parse_args()
+
+ if args.output:
+ os.makedirs(args.output, exist_ok=True)
+
+ if os.path.isdir(args.model_path):
+ name = os.listdir(args.model_path)
+ name = [os.path.join(args.model_path, n) for n in name if n.startswith('model_') and n.endswith('.full')]
+ else:
+ name = [args.model_path]
+
+ arch2performance = {}
+ devices = Queue()
+ for device in args.devices: devices.put(f'cuda:{device}')
+ rlock = threading.Lock()
+
+ progress = tqdm(total=args.init_pop + args.mutate_epoch * args.mutate_number)
+
+ progress.set_description('initial')
+
+ # init population
+ def process_result(res):
+ if res is None: return
+ os.system(f'kill -9 {res["pid"]}')
+ rlock.acquire()
+ for line in res['perf']:
+ arch2performance[str(line[0])] = line[1]
+ progress.update(1)
+ rlock.release()
+
+ edgeop, nodeop = get_edge_node_op(PRIMITIVES, args.space)
+ archs = sample_valid_archs(args.layer, edgeop, nodeop, args.init_pop, PRIMITIVES)
+
+ archs_passed = [[]]
+ for a in archs:
+ if len(archs_passed[-1]) == args.chunk:
+ archs_passed.append([])
+ archs_passed[-1].append(a)
+
+ run_exps(devices, wrap_queue([{
+ 'func': eval_param_archs,
+ 'kwargs': dict(model_path=args.model_path, archs=a),
+ 'callback': process_result
+ } for a in archs_passed]))
+
+ for i in range(args.mutate_epoch):
+ progress.set_description(f'epoch: {i}')
+
+ # mutate architectures
+ current_archs = list(arch2performance.items())
+ current_archs = sorted(current_archs, key=lambda x:-x[1])
+ mutated = current_archs[:args.mutate_number]
+ arch_new = [[]]
+ for arch in mutated:
+ arch = eval(arch[0])
+ a = mutate_arch(arch, edgeop, nodeop, PRIMITIVES)
+ while str(a) in arch2performance: a = mutate_arch(arch, edgeop, nodeop, PRIMITIVES)
+ if len(arch_new[-1]) == args.chunk: arch_new.append([])
+ arch_new[-1].append(a)
+
+ # run jobs
+ run_exps(devices, wrap_queue([{
+ 'func': eval_param_archs,
+ 'kwargs': dict(model_path=args.model_path, archs=a),
+ 'callback': process_result
+ } for a in arch_new]))
+
+ # derive final lists
+ archs = sorted(list(arch2performance.items()), key=lambda x:-x[1])
+ torch.save([[eval(x[0]), x[1]] for x in archs], os.path.join(args.output, 'performance.dict'))
diff --git a/task/text_classification/train_supernet.py b/task/text_classification/train_supernet.py
new file mode 100644
index 0000000..e188893
--- /dev/null
+++ b/task/text_classification/train_supernet.py
@@ -0,0 +1,107 @@
+'''
+train the supernet
+'''
+
+import numpy as np
+from model.ops import PRIMITIVES
+import os
+import yaml
+import argparse
+from tqdm import tqdm
+
+import torch
+import torch.optim
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+from model.utils import get_edge_node_op, sample_valid_archs
+from model.supernet import TextClassifier
+from task.utils import logger, set_seed
+
+LOGGER = logger.get_logger('sst5-search')
+
+if __name__ == '__main__':
+
+ logger.LEVEL = logger.INFO
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--dim', type=int, help='dimension', default=64)
+ parser.add_argument('--head', type=int, help='default attn head number', default=8)
+ parser.add_argument('--layer', type=int, help='total layer number', default=24)
+ parser.add_argument('--space', type=int, help='search space type', default=1)
+ parser.add_argument('--context', choices=['sc', 'fc', 'tc', 'nc'], default='fc')
+ parser.add_argument('--dataset', type=str, help='path to dataset', default='./data')
+
+ parser.add_argument('--arch_batch_size', type=int, help='batch size to train the base', default=16)
+ parser.add_argument('--epoch', type=int, help='total num of epoch to train the supernet', default=10)
+ parser.add_argument('--lr', type=float, help='learning rate', default=5e-4)
+ parser.add_argument('--weight_decay', type=float, help='weight decay', default=0)
+ parser.add_argument('--dropout', type=float, help='dropout', default=0.1)
+ parser.add_argument('--gradient_clip', type=float, help='gradient clip', default=5.0)
+
+ parser.add_argument('--path', type=str, help='path to save logs & models', default='./searched')
+
+ parser.add_argument('--seed', type=int, help='random seed', default=2021)
+ parser.add_argument('--device', type=int, default=0, help='the main progress')
+
+ args = parser.parse_args()
+
+ set_seed(args.seed)
+
+ if args.path is not None:
+ os.makedirs(args.path, exist_ok=True)
+ LOGGER.set_path(os.path.join(args.path, 'log.txt'))
+
+ LOGGER.info('hyper parameters')
+ for k,v in args.__dict__.items():
+ LOGGER.info('{} - {}'.format(k, v))
+ LOGGER.info('end of hyper parameters')
+
+ input('hyperparameter confirmed')
+
+ LOGGER.info('load dataset...')
+ edgeop, nodeop = get_edge_node_op(PRIMITIVES, args.space)
+ train_dataset = torch.load(os.path.join(args.dataset, 'sst/train.pt'))
+ valid_dataset = torch.load(os.path.join(args.dataset, 'sst/valid.pt'))
+ test_dataset = torch.load(os.path.join(args.dataset, 'sst/test.pt'))
+ embedding = torch.load(os.path.join(args.dataset, 'sst/embedding.pt'))
+ train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)
+ device = torch.device('cuda:{}'.format(args.device))
+
+ LOGGER.info('build model...')
+ model = TextClassifier(
+ embeddings=embedding,
+ dim=args.dim,
+ head=args.head,
+ nclass=5,
+ layer=args.layer,
+ edgeops=edgeop,
+ nodeops=nodeop,
+ dropout=args.dropout,
+ context=args.context, aug_dropouts=[0.1, 0.1]).to(device)
+
+ LOGGER.info('train supernet...')
+ model.train()
+ eval_id = 0
+ time_list = []
+ loss_item = []
+
+ opt = torch.optim.Adam(model.parameters(), lr=args.lr / args.arch_batch_size, weight_decay=args.weight_decay)
+
+ for idx in range(args.epoch):
+ t = tqdm(train_dataloader)
+ for batch in t:
+ opt.zero_grad()
+
+ for arch in sample_valid_archs(args.layer, edgeop, nodeop, args.arch_batch_size, PRIMITIVES):
+ logit = model(batch[0].to(device), batch[2].to(device), arch)
+ loss = F.cross_entropy(logit, batch[1].long().to(device))
+ loss.backward()
+ loss_item.append(loss.item())
+ loss_item = loss_item[-100:]
+
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clip)
+
+ opt.step()
+ t.set_postfix(loss=round(np.mean(loss_item), 4))
+ if args.path:
+ torch.save(model, os.path.join(args.path, 'model_epoch_{}.full'.format(idx)))
diff --git a/task/utils/__init__.py b/task/utils/__init__.py
new file mode 100644
index 0000000..ac431fa
--- /dev/null
+++ b/task/utils/__init__.py
@@ -0,0 +1,65 @@
+import torch
+import torch.cuda
+import numpy as np
+import random
+from .logger import get_logger
+import torch.multiprocessing as mp
+import threading
+import time
+
+def run_exps(devices, jobs, block=True, interval=0.1):
+ '''
+ jobs:
+ - func
+ - kwargs
+ - callback
+ '''
+ job_queue = jobs
+ dev_queue = devices
+
+ pool = mp.Pool()
+
+ def gen_callback(dev, calls):
+ def callback(res):
+ dev_queue.put(dev)
+ return calls(res)
+ return callback
+
+ def gen_err_callback(dev, func):
+ def callback(err):
+ dev_queue.put(dev)
+ return func(err)
+ return callback
+
+ def run_job():
+ while True:
+ job = job_queue.get()
+ if job == False:
+ break
+ dev = dev_queue.get()
+ pool.apply_async(
+ job['func'],
+ kwds={**job['kwargs'], 'device': dev},
+ callback=gen_callback(dev, job['callback']),
+ error_callback=gen_err_callback(dev, job['err'] if 'err' in job else lambda x: x)
+ )
+ time.sleep(interval)
+
+ th = threading.Thread(target=run_job)
+ th.daemon = True
+ th.start()
+ if block:
+ th.join()
+ pool.close()
+ pool.join()
+ return th, pool
+
+def set_seed(seed=2020):
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+
+ np.random.seed(seed)
+ random.seed(seed)
+
+__all__ = ['set_seed', 'get_logger', 'run_exps']
diff --git a/task/utils/logger.py b/task/utils/logger.py
new file mode 100644
index 0000000..9823924
--- /dev/null
+++ b/task/utils/logger.py
@@ -0,0 +1,52 @@
+import os
+from time import gmtime, strftime
+
+DEBUG=0
+INFO=1
+WARN=2
+ERROR=3
+
+LEVEL = ERROR
+
+_idx2str = ['D', 'I', 'W', 'E']
+
+get_logger = lambda x:Logger(x)
+
+class Logger():
+ def __init__(self, name='') -> None:
+ self.name = name
+ if self.name != '':
+ self.name = '[' + self.name + ']'
+ self.path = None
+
+ self.debug = self._generate_print_func(DEBUG)
+ self.info = self._generate_print_func(INFO)
+ self.warn = self._generate_print_func(WARN)
+ self.error = self._generate_print_func(ERROR)
+
+ def to_json(self):
+ return {
+ 'path': self.path,
+ 'name': self.name
+ }
+
+ @classmethod
+ def from_json(cls, js):
+ a = cls(js['name'])
+ a.set_path(js['path'])
+ return a
+
+ def set_path(self, path):
+ self.path = path
+
+ def _generate_print_func(self, level=DEBUG):
+ def prin(*args, end='\n'):
+ strs = ' '.join([str(a) for a in args])
+ str_time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
+ if level >= LEVEL:
+ print('[' + _idx2str[level] + '][' + str_time + ']' + self.name, strs, end=end)
+ if self.path is not None:
+ open(self.path, 'a').write(
+ '[' + _idx2str[level] + '][' + str_time + ']' + self.name + strs + end
+ )
+ return prin