1
+ # _*_coding:utf-8 _*_
2
+ # Author : Tao
3
+ """
4
+ 用于加载tvsum anno数据
5
+ """
6
+
7
+ import os
8
+ import cv2
9
+ import csv
10
+ import numpy as np
11
+
12
+
13
+ def load_tvsum (data_root_path , scale = False , ):
14
+ """
15
+ :param data_root_path: tvsum data path
16
+ :param scale: scale to 0.0 ~ 1.0 or not
17
+ :return: dic
18
+ """
19
+ anno_file = os .path .join (data_root_path , 'data/ydata-tvsum50-anno.tsv' )
20
+ video_path = os .path .join (data_root_path , 'video' )
21
+
22
+ video_name , fps , nframes , user_score , avg_score = [], [], [], [], []
23
+
24
+ with open (anno_file ) as fd :
25
+ rd = csv .reader (fd , delimiter = "\t " , quotechar = '"' )
26
+ for us in rd : # user summary
27
+ if us [0 ] not in video_name :
28
+ video_name .append (us [0 ])
29
+ vidx = video_name .index (us [0 ])
30
+ capture = cv2 .VideoCapture (
31
+ os .path .join (video_path , us [0 ] + '.mp4' ))
32
+
33
+ # get fps
34
+ fps .append (int (capture .get (cv2 .CAP_PROP_FPS )))
35
+ # get frames
36
+ nframes .append (int (capture .get (cv2 .CAP_PROP_FRAME_COUNT )))
37
+ user_score .append ([])
38
+ user_score [vidx ].append (np .asarray (us [2 ].split (',' )).astype (float ))
39
+
40
+ for vidx in range (len (video_name )): # video key
41
+ avg_score .append (np .asarray (user_score [vidx ]).mean (axis = 0 ))
42
+ user_score [vidx ] = np .asarray (user_score [vidx ])
43
+
44
+ # scale to 0.0 ~ 1.0
45
+ if scale :
46
+ for i in range (len (avg_score )):
47
+ max_score = max (avg_score [i ])
48
+ min_score = min (avg_score [i ])
49
+ avg_s = (avg_score [i ] - min_score ) / (max_score - min_score )
50
+ avg_score [i ] = avg_s
51
+
52
+ dic = {}
53
+ for i , video in enumerate (video_name ):
54
+ dic [video ] = {'fps' : fps [i ],
55
+ 'frames' : nframes [i ],
56
+ 'user_score' : user_score [i ],
57
+ 'avg_score' : avg_score [i ]
58
+ }
59
+ return dic
60
+
61
+
62
+ def knapsack_dp (values , weights , n_items , capacity , return_all = False ):
63
+ check_inputs (values , weights , n_items , capacity )
64
+
65
+ table = np .zeros ((n_items + 1 , capacity + 1 ), dtype = np .float32 )
66
+ keep = np .zeros ((n_items + 1 , capacity + 1 ), dtype = np .float32 )
67
+
68
+ for i in range (1 , n_items + 1 ):
69
+ for w in range (0 , capacity + 1 ):
70
+ wi = weights [i - 1 ] # weight of current item
71
+ vi = values [i - 1 ] # value of current item
72
+ if (wi <= w ) and (vi + table [i - 1 , w - wi ] > table [i - 1 , w ]):
73
+ table [i , w ] = vi + table [i - 1 , w - wi ]
74
+ keep [i , w ] = 1
75
+ else :
76
+ table [i , w ] = table [i - 1 , w ]
77
+
78
+ picks = []
79
+ K = capacity
80
+
81
+ for i in range (n_items , 0 , - 1 ):
82
+ if keep [i , K ] == 1 :
83
+ picks .append (i )
84
+ K -= weights [i - 1 ]
85
+
86
+ picks .sort ()
87
+ picks = [x - 1 for x in picks ] # change to 0-index
88
+
89
+ if return_all :
90
+ max_val = table [n_items , capacity ]
91
+ return picks , max_val
92
+ return picks
93
+
94
+
95
+ def check_inputs (values , weights , n_items , capacity ):
96
+ # check variable type
97
+ assert (isinstance (values , list ))
98
+ assert (isinstance (weights , list ))
99
+ assert (isinstance (n_items , int ))
100
+ assert (isinstance (capacity , int ))
101
+ # check value type
102
+ assert (all (isinstance (val , int ) or isinstance (val , float ) for val in
103
+ values ))
104
+ assert (all (isinstance (val , int ) for val in weights ))
105
+ # check validity of value
106
+ assert (all (val >= 0 for val in weights ))
107
+ assert (n_items > 0 )
108
+ assert (capacity > 0 )
109
+
110
+
111
+ def get_summary (score , sum_rate ):
112
+ """
113
+ :param score: score list
114
+ :param sum_rate: summary rate
115
+ :return: summary mask, one hot
116
+ """
117
+ clip_scores = [x * 1000 for x in score ] # up scale
118
+ clip_scores = [int (round (x )) for x in clip_scores ]
119
+
120
+ n = len (clip_scores ) # 总帧数
121
+ W = int (n * sum_rate ) # summary帧总数
122
+ val = clip_scores #
123
+ wt = [1 for x in range (n )] # 全1
124
+
125
+ sum_ = knapsack_dp (val , wt , n , W )
126
+
127
+ summary = np .zeros ((1 ), dtype = np .float32 ) # this element should be deleted
128
+ for seg_idx in range (n ):
129
+ nf = wt [seg_idx ]
130
+ if seg_idx in sum_ :
131
+ tmp = np .ones ((nf ), dtype = np .float32 )
132
+ else :
133
+ tmp = np .zeros ((nf ), dtype = np .float32 )
134
+ summary = np .concatenate ((summary , tmp ))
135
+ summary = list (summary )
136
+
137
+ del summary [0 ]
138
+
139
+ return summary
0 commit comments