Skip to content

Commit

Permalink
factorized the architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
MaloOLIVIER committed Dec 2, 2024
1 parent 88f9df2 commit 3c8a293
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 148 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from hungarian_net.generate_hnet_training_data import load_obj
from generate_hnet_training_data import load_obj


class HungarianDataset(Dataset):
Expand Down Expand Up @@ -101,6 +101,7 @@ def compute_weighted_accuracy(self, n1star, n0star):

return WA

#TODO: factorize HungarianDataset in HungarianDataModule

class HungarianDataModule(LightningDataModule):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.metrics import f1_score
from torch import optim
from torchmetrics import MetricCollection
from hungarian_net.torch_modules.hnet_gru import HNetGRU


class HNetGRULightning(L.LightningModule):
Expand Down Expand Up @@ -238,77 +239,3 @@ def common_logging(

# F1-score
self.log(f"{stage}_f1", f1, on_step=False, on_epoch=True, prog_bar=False)

# Log metrics
# predicted_phonemes = self.get_phonemes_from_logits(outputs["logits"])
# target_phonemes = batch["phonemes_str"]
# metrics_to_log = self.metrics(predicted_phonemes, target_phonemes)
# metrics_to_log = {f"{stage}/{k}": v for k, v in metrics_to_log.items()}

# self.log_dict(dictionary=metrics_to_log, sync_dist=True, prog_bar=True)


class AttentionLayer(nn.Module):
def __init__(self, in_channels, out_channels, key_channels):
super(AttentionLayer, self).__init__()
self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size=1, bias=False)
self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size=1, bias=False)
self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)

def forward(self, x):
Q = self.conv_Q(x)
K = self.conv_K(x)
V = self.conv_V(x)
A = Q.permute(0, 2, 1).matmul(K).softmax(2)
x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)
return x

def __repr__(self):
return (
self._get_name()
+ "(in_channels={}, out_channels={}, key_channels={})".format(
self.conv_Q.in_channels,
self.conv_V.out_channels,
self.conv_K.out_channels,
)
)


class HNetGRU(nn.Module):
def __init__(self, max_len=4, hidden_size=128):
super().__init__()
self.nb_gru_layers = 1
self.max_len = max_len
self.gru = nn.GRU(max_len, hidden_size, self.nb_gru_layers, batch_first=True)
self.attn = AttentionLayer(hidden_size, hidden_size, hidden_size)
self.fc1 = nn.Linear(hidden_size, max_len)

def forward(self, query):
# query - batch x seq x feature

out, _ = self.gru(query)
# out - batch x seq x hidden

out = out.permute((0, 2, 1))
# out - batch x hidden x seq

out = self.attn.forward(out)
# out - batch x hidden x seq

out = out.permute((0, 2, 1))
out = torch.tanh(out)
# out - batch x seq x hidden

out = self.fc1(out)
# out - batch x seq x feature

out1 = out.view(out.shape[0], -1)
# out1 - batch x (seq x feature)

out2, _ = torch.max(out, dim=-1)
# out2 - batch x seq x 1

out3, _ = torch.max(out, dim=-2)
# out3 - batch x 1 x feature

return out1.squeeze(), out2.squeeze(), out3.squeeze()
70 changes: 0 additions & 70 deletions hungarian_net/plot_f1_score.py

This file was deleted.

36 changes: 36 additions & 0 deletions hungarian_net/torch_modules/attention_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from functools import partial
from typing import Any, Dict, Tuple

import lightning as L
import torch
import torch.nn as nn
import torchmetrics
from lightning.pytorch.utilities.types import STEP_OUTPUT
from sklearn.metrics import f1_score
from torch import optim
from torchmetrics import MetricCollection

class AttentionLayer(nn.Module):
def __init__(self, in_channels, out_channels, key_channels):
super(AttentionLayer, self).__init__()
self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size=1, bias=False)
self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size=1, bias=False)
self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)

def forward(self, x):
Q = self.conv_Q(x)
K = self.conv_K(x)
V = self.conv_V(x)
A = Q.permute(0, 2, 1).matmul(K).softmax(2)
x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)
return x

def __repr__(self):
return (
self._get_name()
+ "(in_channels={}, out_channels={}, key_channels={})".format(
self.conv_Q.in_channels,
self.conv_V.out_channels,
self.conv_K.out_channels,
)
)
51 changes: 51 additions & 0 deletions hungarian_net/torch_modules/hnet_gru.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from functools import partial
from typing import Any, Dict, Tuple

import lightning as L
import torch
import torch.nn as nn
import torchmetrics
from lightning.pytorch.utilities.types import STEP_OUTPUT
from sklearn.metrics import f1_score
from torch import optim
from torchmetrics import MetricCollection
from hungarian_net.torch_modules.attention_layer import AttentionLayer

class HNetGRU(nn.Module):
def __init__(self, max_len=4, hidden_size=128):
super().__init__()
self.nb_gru_layers = 1
self.max_len = max_len
self.gru = nn.GRU(max_len, hidden_size, self.nb_gru_layers, batch_first=True)
self.attn = AttentionLayer(hidden_size, hidden_size, hidden_size)
self.fc1 = nn.Linear(hidden_size, max_len)

def forward(self, query):
# query - batch x seq x feature

out, _ = self.gru(query)
# out - batch x seq x hidden

out = out.permute((0, 2, 1))
# out - batch x hidden x seq

out = self.attn.forward(out)
# out - batch x hidden x seq

out = out.permute((0, 2, 1))
out = torch.tanh(out)
# out - batch x seq x hidden

out = self.fc1(out)
# out - batch x seq x feature

out1 = out.view(out.shape[0], -1)
# out1 - batch x (seq x feature)

out2, _ = torch.max(out, dim=-1)
# out2 - batch x seq x 1

out3, _ = torch.max(out, dim=-2)
# out3 - batch x 1 x feature

return out1.squeeze(), out2.squeeze(), out3.squeeze()
4 changes: 2 additions & 2 deletions hungarian_net/train_hnet.py → run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader

from hungarian_net.dataset import HungarianDataModule, HungarianDataset
from hungarian_net.models import HNetGRULightning
from hungarian_net.lightning_datamodules.hungarian_datamodule import HungarianDataModule, HungarianDataset
from hungarian_net.lightning_modules.hnet_gru_lightning import HNetGRULightning


# @hydra.main(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import matplotlib.pyplot as plot
import torch
from IPython import embed
from train_hnet import HNetGRU, HungarianDataset
from hungarian_net.lightning_datamodules import HungarianDataset
from hungarian_net.torch_modules import HNetGRU

use_cuda = False
max_len = 2
Expand Down

0 comments on commit 3c8a293

Please sign in to comment.