@@ -37,7 +37,7 @@
     "import torch.optim as optim\n",
     "import numpy as np\n",
     "import urllib.request\n",
-    "import pandas as pd # noqa\n",
+    "import pandas as pd  # noqa\n",
     "from torch.utils.data import Dataset, DataLoader"
    ]
   },
@@ -61,13 +61,16 @@
     "def get_substrings(in_str):\n",
     "    # add the stop character to the end of the name, then generate all the partial names\n",
     "    in_str = in_str + \"+\"\n",
-    "    res = [in_str[0: j] for j in range(1, len(in_str) + 1)]\n",
+    "    res = [in_str[0:j] for j in range(1, len(in_str) + 1)]\n",
     "    return res\n",
+    "\n",
     "pet_names_expanded = [get_substrings(name) for name in pet_names]\n",
     "pet_names_expanded = [item for sublist in pet_names_expanded for item in sublist]\n",
     "pet_names_characters = [list(name) for name in pet_names_expanded]\n",
-    "pet_names_padded = [name[-(str_len + 1):] for name in pet_names_characters]\n",
-    "pet_names_padded = [list((str_len + 1 - len(characters)) * \"*\") + characters for characters in pet_names_padded]\n",
+    "pet_names_padded = [name[-(str_len + 1) :] for name in pet_names_characters]\n",
+    "pet_names_padded = [\n",
+    "    list((str_len + 1 - len(characters)) * \"*\") + characters for characters in pet_names_padded\n",
+    "]\n",
     "pet_names_numeric = [[characters.index(char) for char in name] for name in pet_names_padded]\n",
     "\n",
     "# the final x and y data to use for training the model. Note that the x data needs to be one-hot encoded\n",
@@ -180,7 +183,9 @@
     "\n",
     "    for epoch in range(num_epochs):\n",
     "        # the logger here logs to the dask log of each worker, for easy debugging\n",
-    "        logger.info(f\"Worker {worker_rank} - {datetime.datetime.now().isoformat()} - Beginning epoch {epoch}\")\n",
+    "        logger.info(\n",
+    "            f\"Worker {worker_rank} - {datetime.datetime.now().isoformat()} - Beginning epoch {epoch}\"\n",
+    "        )\n",
     "\n",
     "        # this ensures the data is reshuffled each epoch\n",
     "        sampler.set_epoch(epoch)\n",
@@ -195,15 +200,21 @@
     "            loss.backward()\n",
     "            optimizer.step()\n",
     "\n",
-    "            logger.info(f\"Worker {worker_rank} - {datetime.datetime.now().isoformat()} - epoch {epoch} - batch {i} - batch complete - loss {loss.item()}\")\n",
+    "            logger.info(\n",
+    "                f\"Worker {worker_rank} - {datetime.datetime.now().isoformat()} - epoch {epoch} - batch {i} - batch complete - loss {loss.item()}\"\n",
+    "            )\n",
     "\n",
     "        # the first rh call saves a json file with the loss from the worker at the end of the epoch\n",
     "        rh.submit_result(\n",
     "            f\"logs/data_{worker_rank}_{epoch}.json\",\n",
-    "            json.dumps({'loss': loss.item(),\n",
-    "                        'time': datetime.datetime.now().isoformat(),\n",
-    "                        'epoch': epoch,\n",
-    "                        'worker': worker_rank})\n",
+    "            json.dumps(\n",
+    "                {\n",
+    "                    \"loss\": loss.item(),\n",
+    "                    \"time\": datetime.datetime.now().isoformat(),\n",
+    "                    \"epoch\": epoch,\n",
+    "                    \"worker\": worker_rank,\n",
+    "                }\n",
+    "            ),\n",
     "        )\n",
     "        # this saves the model. We only need to do it for one worker (so we picked worker 0)\n",
     "        if worker_rank == 0:\n",
@@ -283,27 +294,37 @@
     "def generate_name(model, characters, str_len):\n",
     "    in_progress_name = []\n",
     "    next_letter = \"\"\n",
-    "    while(not next_letter == \"+\" and len(in_progress_name) < 30):\n",
+    "    while not next_letter == \"+\" and len(in_progress_name) < 30:\n",
     "        # prep the data to run in the model again\n",
     "        in_progress_name_padded = in_progress_name[-str_len:]\n",
-    "        in_progress_name_padded = list((str_len - len(in_progress_name_padded)) * \"*\") + in_progress_name_padded\n",
+    "        in_progress_name_padded = (\n",
+    "            list((str_len - len(in_progress_name_padded)) * \"*\") + in_progress_name_padded\n",
+    "        )\n",
     "        in_progress_name_numeric = [characters.index(char) for char in in_progress_name_padded]\n",
     "        in_progress_name_tensor = torch.tensor(in_progress_name_numeric)\n",
-    "        in_progress_name_tensor = torch.nn.functional.one_hot(in_progress_name_tensor, num_classes=len(characters)).float()\n",
+    "        in_progress_name_tensor = torch.nn.functional.one_hot(\n",
+    "            in_progress_name_tensor, num_classes=len(characters)\n",
+    "        ).float()\n",
     "        in_progress_name_tensor = torch.unsqueeze(in_progress_name_tensor, 0)\n",
     "\n",
     "        # get the probabilities of each possible next character by running the model\n",
     "        with torch.no_grad():\n",
     "            next_letter_probabilities = model(in_progress_name_tensor)\n",
     "\n",
     "        next_letter_probabilities = next_letter_probabilities[0, -1, :]\n",
-    "        next_letter_probabilities = torch.nn.functional.softmax(next_letter_probabilities, dim=0).detach().cpu().numpy()\n",
+    "        next_letter_probabilities = (\n",
+    "            torch.nn.functional.softmax(next_letter_probabilities, dim=0).detach().cpu().numpy()\n",
+    "        )\n",
     "        next_letter_probabilities = next_letter_probabilities[1:]\n",
-    "        next_letter_probabilities = [p / sum(next_letter_probabilities) for p in next_letter_probabilities]\n",
+    "        next_letter_probabilities = [\n",
+    "            p / sum(next_letter_probabilities) for p in next_letter_probabilities\n",
+    "        ]\n",
     "\n",
     "        # determine what the actual letter is\n",
-    "        next_letter = characters[np.random.choice(len(characters) - 1, p=next_letter_probabilities) + 1]\n",
-    "        if(next_letter != \"+\"):\n",
+    "        next_letter = characters[\n",
+    "            np.random.choice(len(characters) - 1, p=next_letter_probabilities) + 1\n",
+    "        ]\n",
+    "        if next_letter != \"+\":\n",
     "            # if the next character isn't stop add the latest generated character to the name and continue\n",
     "            in_progress_name.append(next_letter)\n",
     "    # turn the list of characters into a single string\n",
@@ -390,4 +411,4 @@
   },
  "nbformat": 4,
  "nbformat_minor": 4
-}
+}