Skip to content

Commit

Permalink
Prompt updates for forward
Browse files Browse the repository at this point in the history
  • Loading branch information
LouisCastricato committed Jan 19, 2022
1 parent 3b60191 commit f701ca8
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 159 deletions.
52 changes: 35 additions & 17 deletions continuations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

def continue_story_svo_customprompt_past_tense(model, story, question, verbose = False, width = 10):
instructions = "The following are results from a children's reasoning test. \nThe child will be read a short story and then answer a question about what happened before the story. \nThe answer cannot contradict the story. \nWe have collected the results below:\n"
story1="Story 1: Jim sat by the swings as Tom slowly approached. Tom gave the book to Jim during recess.\n"
question1="Question: Why did Jim receive the book?\n"
answer1="Answers:\n1. Tom desired to return the book to his friend.\n2. Jim bought the book from Tom.\n3. Tom noticed Jim dropped his book.\n4. Jim needed a book to study for his exam.\n5. Tom wanted to give Jim his favorite book.\n\n"
story2 = "Story 2: Mary was happy to finally cross the street.\n"
question2 = "Question: Why was Mary running?\n"
answer2="Answers:\n1. Mary was running from a monster.\n2. Mary was running a marathon.\n3. Mary wanted to get away from her parents.\n\n"
story1="Story 1: Jane sat by the swings as John slowly approached. John gave the book to Jane during recess.\n"
question1="Question: Why did Jane receive the book?\n"
answer1="Answers:\n1. John desired to return the book to his friend.\n2. Jane bought the book from John.\n3. John noticed Jane dropped his book.\n4. Jane needed a book to study for his exam.\n5. John wanted to give Jane his favorite book.\n\n"
story2 = "Story 2: Jane was happy to finally cross the street.\n"
question2 = "Question: Why was Jane running?\n"
answer2="Answers:\n1. Jane was running from a monster.\n2. Jane was running a marathon.\n3. Jane wanted to get away from her parents.\n\n"
story3 = "Story 3: " + story + "\n"
question3 = "Question: " + question +"\n"
inp = instructions + story1 + question1 + answer1 + story2 + question2 + answer2 + story3 + question3
Expand All @@ -18,7 +18,7 @@ def continue_story_svo_customprompt_past_tense(model, story, question, verbose =

#con = construct(inp, " A.")
#print("construct is ", con, "\n"*2)
out = generate(model, con, max_length=850, horizon=story, horizon_penalty=1.8, beams=5, repetition_penalty=2.8, do_beams=True)
out = generate(model, con, max_length=50, horizon=story, horizon_penalty=1.8, beams=5, repetition_penalty=2.8, do_beams=True)

#return out
out = out.split("Story 3")[1]
Expand Down Expand Up @@ -60,15 +60,23 @@ def continue_story_svo_customprompt_past_tense(model, story, question, verbose =
except:
continue
return responses


# Words that must not appear at the start of a generated future-tense
# continuation (question scaffolding, answer numbering, quote marks, ...).
continuations_bad_words_future_tense = [
    "Why", "Story", "You", "I",
    "Think", "Erica", "Answer-", "A-", "A", "1", "2", "3", "4", "Answers",
    "\"", "\'",
]
# Expand every word into all of its surface permutations (casing/spacing
# variants produced by `permute_string`), flattened into one list.
continuations_bad_words_future_tense = [
    variant
    for word in continuations_bad_words_future_tense
    for variant in permute_string(word)
]
# Pre-compute the token-id sequences once so generation can pass them
# straight through as `bad_words_ids`.
continuations_bad_words_ids_future_tense = [
    tokenizer(variant)["input_ids"]
    for variant in continuations_bad_words_future_tense
]


def continue_story_svo_customprompt_future_tense(model, story, question, verbose = False, width = 10):
instructions = "The following are results from a children's reasoning test. \nThe child will be read a short story and then answer a question about what happened before the story. \nThe answer cannot contradict the story. \nWe have collected the results below:\n"
story1="Story 1: Jim sat by the swings as Tom slowly approached. Tom gave the book to Jim during recess.\n"
question1="Question: What did Jim do after recieving the book?\n"
answer1="Answers:\n1. Jim jumped for joy!\n2. Jim threw the book away.\n3. Tom noticed that Jim looked unhappy.\n\n"
story2 = "Story 2: Mary was happy to finally cross the street.\n"
question2 = "Question: What happened after Mary crossed the street?\n"
answer2="Answers:\n1. She sighed in relief.\n2. Her smile quickly faded when she remembered she left her wallet at the restaurant.\n3. She looked over her shoulder to make sure she was not being followed.\n\n"
instructions = "Please continue the stories below. \nThe answer cannot contradict the story. \nWe have collected the results below:\n"
story1="Story 1: John looked up from his lap and saw his friend in the distance. Jane sat by the swings as John slowly approached. John gave the book to Jane during recess.\n"
question1="Question: What did Jane do after recieving the book?\n"
answer1="Answers:\n1. Jane jumped for joy!\n2. Jane threw the book away.\n3. John noticed that Jane looked unhappy.\n\n"
story2 = "Story 2: Jane was happy to finally cross the street.\n"
question2 = "Question: What happened after Jane crossed the street?\n"
answer2="Answers:\n1. She sighed in relief and brushed the dust off her pants. \n2. Her smile quickly faded when she remembered she left her wallet at the restaurant.\n3. She looked over her shoulder to make sure she was not being followed.\n\n"
story3 = "Story 3: " + story + "\n"
question3 = "Question: " + question +"\n"
inp = instructions + story1 + question1 + answer1 + story2 + question2 + answer2 + story3 + question3
Expand All @@ -79,7 +87,8 @@ def continue_story_svo_customprompt_future_tense(model, story, question, verbose

#con = construct(inp, " A.")
#print("construct is ", con, "\n"*2)
out = generate(model, con, max_length=850, horizon=story, horizon_penalty=1.8, beams=5, repetition_penalty=1.8, do_beams=True)
out = generate(model, con, max_length=50, horizon=story, horizon_penalty=1.4, beams=5,
repetition_penalty=1.8, do_sample=False, extra_bad_words=continuations_bad_words_ids_future_tense)

#return out
out = out.split("Story 3")[1]
Expand All @@ -105,6 +114,7 @@ def continue_story_svo_customprompt_future_tense(model, story, question, verbose
break
#Take responses that are long enough
responses = list(filter(lambda x: len(x) > 5, responses))

#print("og responses are ", responses)
#Remove first space
for i in range(len(responses)):
Expand All @@ -120,5 +130,13 @@ def continue_story_svo_customprompt_future_tense(model, story, question, verbose
responses[i] = clean_story(responses[i])
except:
continue
return responses
# take only the first sentence

def prune_to_first_k_sentence(story, k=1):
    """Return the first ``k`` sentences of ``story`` as a single string.

    Sentences are found with NLTK's ``sent_tokenize`` and re-joined with
    single spaces.

    Parameters
    ----------
    story : str
        The text to prune.
    k : int, optional
        Number of leading sentences to keep. Values below 1 are treated
        as 1, matching the original single-sentence branch.

    Returns
    -------
    str
        The first ``k`` sentences, or ``""`` when ``story`` tokenizes to
        no sentences (the previous version raised ``IndexError`` on
        empty/whitespace-only input via ``sent_tokenize(story)[0]``).
    """
    sentences = sent_tokenize(story)
    # A single slice covers both former branches: for a non-empty list,
    # " ".join(sentences[:1]) equals sentences[0]; for an empty list it
    # safely yields "" instead of crashing.
    return " ".join(sentences[:max(k, 1)])

return list(map(lambda string: prune_to_first_k_sentence(string, 2), responses))

91 changes: 32 additions & 59 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,33 +32,6 @@ def hasSVOExpand(sent):
total_sentences_checked = 0
num_stories = 5

#for inp_story_full in tqdm(stories[:num_stories], position = 0, leave = True):
for i in tqdm(r[:num_stories]):
inp_story_full = stories[i]

inp_story = inp_story_full[-1]
if backwards:
q = get_questions_past_tense(model, inp_story)[0]
print(q)
starts = continue_story_svo_customprompt_past_tense(model, inp_story, q)

else:
q = get_questions_future_tense(model, inp_story)[0]
print(q)
starts = continue_story_svo_customprompt_future_tense(model, inp_story, q)



svoability = ['x'for start in starts if hasSVOExpand(start)]

if svoability:
at_least_one_svoable_output += 1
total_svoable_percents += len(svoability)
total_sentences_checked += len(starts)

print("at least one output ", at_least_one_svoable_output/num_stories)
print("total svoable percentages ", total_svoable_percents/total_sentences_checked)


import json
expansionfilenames = [
Expand Down Expand Up @@ -128,6 +101,8 @@ def not_contradict(sentA, sentB):
choice = np.argmax(outputs)
if choice == 0:
return -100
print(outputs[2])
return 100
return outputs[2]

print(not_contradict("The sky is blue.", "The sky is not blue.")) #scores w/ BERT
Expand All @@ -147,7 +122,7 @@ def beam_search(beams, width=20, diversity_width=2, graph=None, reverse_rank=Tru
if backwards:
questions = get_questions_past_tense(model, " ".join(story_sents[:min(len(story_sents),3)]), q_prev)
else:
questions = get_questions_future_tense(model, " ".join(story_sents[min(len(story_sents),3):]), q_prev)
questions = get_questions_future_tense(model, " ".join(story_sents[max(0, len(story_sents) - 3):]), q_prev)

print("accumulated questions are ", questions)
extensions = list()
Expand All @@ -156,7 +131,7 @@ def beam_search(beams, width=20, diversity_width=2, graph=None, reverse_rank=Tru
continuation = continue_story_svo_customprompt_past_tense(model, story, q, width=-1)
else:
continuation = continue_story_svo_customprompt_future_tense(model, story, q, width=-1)

print(continuation)
if q_prev != ['']:
q_cur = [[q] + q_prev]*len(continuation)
else:
Expand All @@ -166,12 +141,12 @@ def beam_search(beams, width=20, diversity_width=2, graph=None, reverse_rank=Tru
if backwards:
implication_story = " ".join(story_sents[0:min(len(story_sents), 1)]) #The first k sentences sliding window
else:
implication_story = " ".join(story_sents[min(len(story_sents), 1):]) #The first k sentences sliding window
#" ".join(story_sents[max(0, len(story_sents) - 3):])
implication_story = " ".join(story_sents[-3:]) #The last k sentences sliding window
extensions = list(filter(lambda x: hasSVOExpand(x[0]), extensions)) #Filter on if there is an SVO tuple
extensions_ranks = list(map(lambda x: not_contradict(x[0], implication_story) + score, extensions)) #Rank. Include prior score via sum. This keeps track of beam scores over time.
extensions_zip = list(filter(lambda x: x[1] > -100, zip(extensions, extensions_ranks))) #Zip
extensions_zip = sorted(extensions_zip, key=lambda x: x[1], reverse=reverse_rank)

extensions_zip = extensions_zip[:min(diversity_width, len(extensions_zip))] #Take top k
if backwards:
extensions = list(map(lambda x: (x[0][0]+" "+story, x[0][1], x[1]), extensions_zip)) #Sort
Expand Down Expand Up @@ -205,41 +180,39 @@ def remove_duplicates(beams):
return to_ret

reverse = True
inp_story="The battle had been raging on for hours. William set his phasers to kill, he knew he had to make amends."
inp_story="The battle had been raging on for hours. John set his phasers to kill, he knew he had to make amends."
f=None

beams, f = beam_search([(inp_story, [""], 0.0)], width=10, diversity_width=2, graph=f, reverse_rank=reverse)
print("beams are ", beams)
print("graph is ", f)

for i in range(3):
beams = remove_duplicates(beams)
print("beams are ", beams)
print("\nSTEP: " + str(i) + "\n\n")
print("\n".join(list(map(lambda x: x[0], beams))))
beams, f = beam_search(beams, width=5, diversity_width=10,graph=f, reverse_rank=reverse)
#beams, f = beam_search([(inp_story, [""], 0.0)], width=10, diversity_width=2, graph=f, reverse_rank=reverse)
#print("beams are ", beams)
#print("graph is ", f)

#for i in range(3):
# beams = remove_duplicates(beams)
# print("beams are ", beams)
# print("\nSTEP: " + str(i) + "\n\n")
# print("\n".join(list(map(lambda x: x[0], beams))))
# beams, f = beam_search(beams, width=5, diversity_width=10,graph=f, reverse_rank=reverse)

def gen_story_from_last_sentence(last_sentence, story_length=10):
reverse = True
inp_story=last_sentence
f=None

beams, f = beam_search([(inp_story, [""], 0.0)], width=10, diversity_width=2, graph=f, reverse_rank=reverse)
def gen_story_from_last_sentence(last_sentence, story_length=3):
reverse = True
inp_story=last_sentence
f=None
beams, f = beam_search([(inp_story, [""], 0.0)], width=10, diversity_width=2, graph=f, reverse_rank=reverse)
print("beams are ", beams)
print("graph is ", f)
for i in range(story_length - 1):
beams = remove_duplicates(beams)
print("beams are ", beams)
print("graph is ", f)

for i in range(story_length - 1):
beams = remove_duplicates(beams)
print("beams are ", beams)
print("\nSTEP: " + str(i) + "\n\n")
print("\n".join(list(map(lambda x: x[0], beams))))
beams, f = beam_search(beams, width=5, diversity_width=10,graph=f, reverse_rank=reverse)

return beams
print("\nSTEP: " + str(i) + "\n\n")
print("\n".join(list(map(lambda x: x[0], beams))))
beams, f = beam_search(beams, width=5, diversity_width=10,graph=f, reverse_rank=reverse)

return beams


story1 = "Hansel's hand still trembles as he pushes open the twice-cooked door. \n The last time he saw the house he was glancing back over his shoulder as he and his sister fled into the trees. \n"
story1 = "John's hand still trembles as he pushes open the twice-cooked door. \n The last time he saw the house he was glancing back over his shoulder as he and his sister fled into the trees. \n"
story2 = "Stella loves to eat peanuts! She goes to the market every day, and is best friends with the owner."
story3 = "There was a princess with long hair locked in a tower far away. \n A prince came to save her by climbing her hair and taking her out of the tower."
story4 = "I woke up realizing that I became a cat. \n I caught a mouse and at the next moment I realized that I'm a human again!"
Expand All @@ -251,7 +224,7 @@ def gen_story_from_last_sentence(last_sentence, story_length=10):

import pickle
for idx, sentence in enumerate(last_sentences):
story_beams = gen_story_from_last_sentence(sentence, story_length = 10)
story_beams = gen_story_from_last_sentence(sentence, story_length = 3)

with open('storypickle' + str(idx) + '.pickle', 'wb') as handle:
pickle.dump(story_beams, handle, protocol=pickle.HIGHEST_PROTOCOL)
Expand Down
Loading

0 comments on commit f701ca8

Please sign in to comment.