|
28 | 28 | "metadata": {}, |
29 | 29 | "outputs": [], |
30 | 30 | "source": [ |
31 | | - "import pandas as pd\n", |
32 | | - "import uuid\n", |
| 31 | + "import uuid # noqa\n", |
33 | 32 | "import datetime\n", |
34 | | - "import pickle\n", |
| 33 | + "import pickle # noqa\n", |
35 | 34 | "import json\n", |
36 | | - "import torch\n", |
37 | | - "import math\n", |
| 35 | + "import torch # noqa\n", |
38 | 36 | "import torch.nn as nn\n", |
39 | 37 | "import torch.optim as optim\n", |
40 | 38 | "import numpy as np\n", |
41 | 39 | "import urllib.request\n", |
42 | | - "import time\n", |
| 40 | + "import pandas as pd # noqa\n", |
43 | 41 | "from torch.utils.data import Dataset, DataLoader" |
44 | 42 | ] |
45 | 43 | }, |
|
65 | 63 | "characters = list(\"*+abcdefghijklmnopqrstuvwxyz-. \")\n", |
66 | 64 | "str_len = 8\n", |
67 | 65 | "\n", |
| 66 | + "\n", |
68 | 67 | "def format_training_data(pet_names, device=None):\n", |
69 | 68 | " def get_substrings(in_str):\n", |
70 | 69 | " # add the stop character to the end of the name, then generate all the partial names\n", |
71 | 70 | " in_str = in_str + \"+\"\n", |
72 | 71 | " res = [in_str[0: j] for j in range(1, len(in_str) + 1)]\n", |
73 | 72 | " return res\n", |
74 | | - " pet_names_expanded = [get_substrings(name) for name in pet_names]\n", |
| 73 | + " pet_names_expanded = [get_substrings(name) for name in pet_names]\n", |
75 | 74 | " pet_names_expanded = [item for sublist in pet_names_expanded for item in sublist]\n", |
76 | 75 | " pet_names_characters = [list(name) for name in pet_names_expanded]\n", |
77 | 76 | " pet_names_padded = [name[-(str_len + 1):] for name in pet_names_characters]\n", |
78 | | - " pet_names_padded = [list((str_len + 1- len(characters)) * \"*\") + characters for characters in pet_names_padded]\n", |
| 77 | + " pet_names_padded = [list((str_len + 1 - len(characters)) * \"*\") + characters for characters in pet_names_padded]\n", |
79 | 78 | " pet_names_numeric = [[characters.index(char) for char in name] for name in pet_names_padded]\n", |
80 | 79 | "\n", |
81 | 80 | " # the final x and y data to use for training the model. Note that the x data needs to be one-hot encoded\n", |
82 | 81 | " if device is None:\n", |
83 | 82 | " y = torch.tensor([name[1:] for name in pet_names_numeric])\n", |
84 | 83 | " x = torch.tensor([name[:-1] for name in pet_names_numeric])\n", |
85 | 84 | " else:\n", |
86 | | - " y = torch.tensor([name[1:] for name in pet_names_numeric], device = device)\n", |
87 | | - " x = torch.tensor([name[:-1] for name in pet_names_numeric], device = device)\n", |
88 | | - " x = torch.nn.functional.one_hot(x, num_classes = len(characters)).float()\n", |
| 85 | + " y = torch.tensor([name[1:] for name in pet_names_numeric], device=device)\n", |
| 86 | + " x = torch.tensor([name[:-1] for name in pet_names_numeric], device=device)\n", |
| 87 | + " x = torch.nn.functional.one_hot(x, num_classes=len(characters)).float()\n", |
89 | 88 | " return x, y\n", |
90 | 89 | "\n", |
| 90 | + "\n", |
91 | 91 | "class OurDataset(Dataset):\n", |
92 | 92 | " def __init__(self, pet_names, device=None):\n", |
93 | 93 | " self.x, self.y = format_training_data(pet_names, device)\n", |
94 | 94 | " self.permute()\n", |
95 | | - " \n", |
| 95 | + "\n", |
96 | 96 | " def __getitem__(self, idx):\n", |
97 | 97 | " idx = self.permutation[idx]\n", |
98 | 98 | " return self.x[idx], self.y[idx]\n", |
99 | | - " \n", |
| 99 | + "\n", |
100 | 100 | " def __len__(self):\n", |
101 | 101 | " return len(self.x)\n", |
102 | | - " \n", |
| 102 | + "\n", |
103 | 103 | " def permute(self):\n", |
104 | 104 | " self.permutation = torch.randperm(len(self.x))\n", |
105 | 105 | "\n", |
| 106 | + "\n", |
106 | 107 | "class Model(nn.Module):\n", |
107 | 108 | " def __init__(self):\n", |
108 | 109 | " super(Model, self).__init__()\n", |
|
115 | 116 | " dropout=0.1,\n", |
116 | 117 | " )\n", |
117 | 118 | " self.fc = nn.Linear(self.lstm_size, len(characters))\n", |
| 119 | + "\n", |
118 | 120 | " def forward(self, x):\n", |
119 | 121 | " output, state = self.lstm(x)\n", |
120 | 122 | " logits = self.fc(output)\n", |
|
138 | 140 | "def train():\n", |
139 | 141 | " device = torch.device(0)\n", |
140 | 142 | "\n", |
141 | | - " dataset = OurDataset(pet_names, device = device)\n", |
142 | | - " loader = DataLoader(dataset, batch_size=batch_size,shuffle=True, num_workers=0)\n", |
143 | | - " \n", |
| 143 | + " dataset = OurDataset(pet_names, device=device)\n", |
| 144 | + " loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n", |
| 145 | + "\n", |
144 | 146 | " model = Model()\n", |
145 | 147 | " model = model.to(device)\n", |
146 | | - " \n", |
| 148 | + "\n", |
147 | 149 | " criterion = nn.CrossEntropyLoss()\n", |
148 | 150 | " optimizer = optim.Adam(model.parameters(), lr=0.001)\n", |
149 | | - " \n", |
| 151 | + "\n", |
150 | 152 | " for epoch in range(num_epochs):\n", |
151 | 153 | " dataset.permute()\n", |
152 | 154 | " for i, (batch_x, batch_y) in enumerate(loader):\n", |
153 | 155 | " optimizer.zero_grad()\n", |
154 | 156 | " batch_y_pred = model(batch_x)\n", |
155 | | - " \n", |
| 157 | + "\n", |
156 | 158 | " loss = criterion(batch_y_pred.transpose(1, 2), batch_y)\n", |
157 | 159 | " loss.backward()\n", |
158 | 160 | " optimizer.step()\n", |
|
206 | 208 | " in_progress_name_padded = in_progress_name[-str_len:]\n", |
207 | 209 | " in_progress_name_padded = list((str_len - len(in_progress_name_padded)) * \"*\") + in_progress_name_padded\n", |
208 | 210 | " in_progress_name_numeric = [characters.index(char) for char in in_progress_name_padded]\n", |
209 | | - " in_progress_name_tensor = torch.tensor(in_progress_name_numeric, device = device)\n", |
210 | | - " in_progress_name_tensor = torch.nn.functional.one_hot(in_progress_name_tensor, num_classes = len(characters)).float()\n", |
| 211 | + " in_progress_name_tensor = torch.tensor(in_progress_name_numeric, device=device)\n", |
| 212 | + " in_progress_name_tensor = torch.nn.functional.one_hot(in_progress_name_tensor, num_classes=len(characters)).float()\n", |
211 | 213 | " in_progress_name_tensor = torch.unsqueeze(in_progress_name_tensor, 0)\n", |
212 | | - " \n", |
| 214 | + "\n", |
213 | 215 | " # get the probabilities of each possible next character by running the model\n", |
214 | 216 | " with torch.no_grad():\n", |
215 | 217 | " next_letter_probabilities = model(in_progress_name_tensor)\n", |
216 | | - " \n", |
217 | | - " next_letter_probabilities = next_letter_probabilities[0,-1,:]\n", |
| 218 | + "\n", |
| 219 | + " next_letter_probabilities = next_letter_probabilities[0, -1, :]\n", |
218 | 220 | " next_letter_probabilities = torch.nn.functional.softmax(next_letter_probabilities, dim=0).detach().cpu().numpy()\n", |
219 | 221 | " next_letter_probabilities = next_letter_probabilities[1:]\n", |
220 | | - " next_letter_probabilities = [p/sum(next_letter_probabilities) for p in next_letter_probabilities]\n", |
221 | | - " \n", |
| 222 | + " next_letter_probabilities = [p / sum(next_letter_probabilities) for p in next_letter_probabilities]\n", |
| 223 | + "\n", |
222 | 224 | " # determine what the actual letter is\n", |
223 | | - " next_letter = characters[np.random.choice(len(characters)-1, p=next_letter_probabilities) + 1]\n", |
| 225 | + " next_letter = characters[np.random.choice(len(characters) - 1, p=next_letter_probabilities) + 1]\n", |
224 | 226 | " if(next_letter != \"+\"):\n", |
225 | 227 | " # if the next character isn't stop add the latest generated character to the name and continue\n", |
226 | 228 | " in_progress_name.append(next_letter)\n", |
|
243 | 245 | "outputs": [], |
244 | 246 | "source": [ |
245 | 247 | "# Generate 50 names then filter out existing ones\n", |
246 | | - "generated_names = [generate_name(model, characters, str_len) for i in range(0,50)]\n", |
| 248 | + "generated_names = [generate_name(model, characters, str_len) for i in range(0, 50)]\n", |
247 | 249 | "generated_names = [name for name in generated_names if name not in pet_names]\n", |
248 | 250 | "print(generated_names)" |
249 | 251 | ] |
|
0 commit comments