
import os
import tarfile
+import hashlib

# Dependency imports


_DAILYMAIL_STORIES_DRIVE_URL = "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs"

+# Note: data generation here follows See et al. (2017); the URL lists below
+# define the reference train/dev/test splits.
+
+# Train/dev/test splits for the summarization data.
+_TRAIN_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt"
+_DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"
+_TEST_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt"

# End-of-sentence marker.
EOS = text_encoder.EOS_ID

+# Techniques for data prep from See et al. (2017).
+dm_single_close_quote = u'\u2019'  # unicode
+dm_double_close_quote = u'\u201d'
+END_TOKENS = [u'.', u'!', u'?', u'...', u"'", u"`", u'"',
+              dm_single_close_quote, dm_double_close_quote,
+              u")"]  # acceptable ways to end a sentence
+
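+# For example, a truncated line like u"He waved" ends with none of the
+# END_TOKENS, so fix_run_on_sents() below appends a period before the
+# sentences are joined.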

-def _maybe_download_corpora(tmp_dir):
+def _maybe_download_corpora(tmp_dir, is_training):
  """Download corpora if necessary and unzip them.

  Args:
    tmp_dir: directory containing dataset.
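+    is_training: whether we are downloading the training or the dev split
+      (this decides which URL list is fetched).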

  Returns:
-    filepath of the downloaded corpus file.
+    List of all story files found, plus the path to the file containing the
+      URL list for this split.
  """
  cnn_filename = "cnn_stories.tgz"
  cnn_finalpath = os.path.join(tmp_dir, "cnn/stories/")
@@ -66,29 +79,87 @@ def _maybe_download_corpora(tmp_dir):
        tmp_dir, dailymail_filename, _DAILYMAIL_STORIES_DRIVE_URL)
    with tarfile.open(dailymail_file, "r:gz") as dailymail_tar:
      dailymail_tar.extractall(tmp_dir)
-  return [cnn_finalpath, dailymail_finalpath]
-
-
-def story_generator(tmp_dir):
-  paths = _maybe_download_corpora(tmp_dir)
-  for path in paths:
-    for story_file in tf.gfile.Glob(path + "*"):
-      story = u""
-      for line in tf.gfile.Open(story_file, "rb"):
-        line = unicode(line, "utf-8") if six.PY2 else line.decode("utf-8")
-        story += line
-      yield story

+  cnn_files = tf.gfile.Glob(cnn_finalpath + "*")
+  dailymail_files = tf.gfile.Glob(dailymail_finalpath + "*")
+  all_files = cnn_files + dailymail_files
+
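+  # Fetch the URL list that defines this split. Only the train and dev lists
+  # are downloaded here; _TEST_URLS is declared above but not used yet.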
+  if is_training:
+    urls_path = generator_utils.maybe_download(
+        tmp_dir, "all_train.txt", _TRAIN_URLS)
+  else:
+    urls_path = generator_utils.maybe_download(
+        tmp_dir, "all_val.txt", _DEV_URLS)
+
+  return all_files, urls_path
+
+
+def example_splits(url_file, all_files):
+  def generate_hash(inp):
+    """Generate a sha1 hash to match the raw url to the filename extracted."""
+    h = hashlib.sha1()
+    h.update(inp)
+    return h.hexdigest()
+
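+  # The extracted story files are named "<sha1(url)>.story", so hashing each
+  # URL from the split list recovers its on-disk filename.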
+  all_files_map = {f.split("/")[-1]: f for f in all_files}
+
+  urls = []
+  for line in tf.gfile.Open(url_file):
+    urls.append(line.strip().encode('utf-8'))
+
+  filelist = []
+  for url in urls:
+    url_hash = generate_hash(url)
+    filename = url_hash + ".story"
+    if filename not in all_files_map:
+      tf.logging.info("Missing file: %s" % url)
+      continue
+    filelist.append(all_files_map[filename])
+
+  tf.logging.info("Found %d examples" % len(filelist))
+
+  return filelist
+
+
+def example_generator(tmp_dir, is_training, sum_token):
+  def fix_run_on_sents(line):
+    if u"@highlight" in line: return line
+    if line == "": return line
+    if line[-1] in END_TOKENS: return line
+    return line + u"."
+
+  all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
+  filelist = example_splits(urls_path, all_files)
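+  # Join article and highlights with an explicit " <summary> " marker only
+  # when sum_token is set; the vocab-building pass in the problem's
+  # generator() calls this with sum_token=False, presumably to keep the
+  # marker out of the subword vocabulary.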
+  story_summary_split_token = u" <summary> " if sum_token else " "
+
+  for story_file in filelist:
+    story = []
+    summary = []
+    reading_highlights = False
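+    # In the .story format, everything before the first "@highlight" is
+    # article text, and each "@highlight" marker precedes one summary
+    # sentence.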
+    for line in tf.gfile.Open(story_file, "rb"):
+      line = unicode(line.strip(), "utf-8") if six.PY2 else (
+          line.strip().decode("utf-8"))
+      line = fix_run_on_sents(line)
+      if line == "":
+        continue
+      elif line.startswith(u"@highlight"):
+        if len(story) == 0: break  # No article text.
+        reading_highlights = True
+      elif reading_highlights:
+        summary.append(line)
+      else:
+        story.append(line)
+
+    if len(story) == 0 or len(summary) == 0:
+      continue
+
+    yield " ".join(story) + story_summary_split_token + " ".join(summary)


def _story_summary_split(story):
-  end_pos = story.find("\n\n")  # Up to first empty line.
-  assert end_pos != -1
-  return story[:end_pos], story[end_pos:].strip()
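+  # Split the concatenated example back into (story, summary) around the
+  # " <summary> " marker inserted by example_generator().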
+  split_str = u" <summary> "
+  split_str_len = len(split_str)
+  split_pos = story.find(split_str)
+  return story[:split_pos], story[split_pos + split_str_len:]  # story, summary


@registry.register_problem
class SummarizeCnnDailymail32k(problem.Text2TextProblem):
-  """Summarize CNN and Daily Mail articles to their first paragraph."""
+  """Summarize CNN and Daily Mail articles to their summary highlights."""

  @property
  def is_character_level(self):
@@ -124,14 +195,14 @@ def targeted_vocab_size(self):

  @property
  def use_train_shards_for_dev(self):
-    return True
+    return False

-  def generator(self, data_dir, tmp_dir, _):
+  def generator(self, data_dir, tmp_dir, is_training):
    encoder = generator_utils.get_or_generate_vocab_inner(
        data_dir, self.vocab_file, self.targeted_vocab_size,
-        story_generator(tmp_dir))
-    for story in story_generator(tmp_dir):
-      summary, rest = _story_summary_split(story)
+        example_generator(tmp_dir, is_training, sum_token=False))
+    for example in example_generator(tmp_dir, is_training, sum_token=True):
+      story, summary = _story_summary_split(example)
      encoded_summary = encoder.encode(summary) + [EOS]
-      encoded_story = encoder.encode(rest) + [EOS]
+      encoded_story = encoder.encode(story) + [EOS]
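+      # Inputs are the full article; targets are the concatenated highlights.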
      yield {"inputs": encoded_story, "targets": encoded_summary}