1515from tqdm import tqdm
1616
1717
18- def create_labeled_video (video_file ,
19- ts_file ,
20- dlc_file ,
21- out_dir = None ,
22- save_images = False ,
23- cut = (0 , np .Inf ),
24- crop = None ,
25- cmap = 'bmy' ,
26- radius = 3 ,
27- lik_thresh = 0.5 ,
28- write_ts = False ,
29- write_scale = 2 ,
30- display = False ,
31- progress = True ,
32- label = True ):
18+ def create_labeled_video (
19+ data_dir ,
20+ out_dir = None ,
21+ dlc_online = True ,
22+ save_images = False ,
23+ cut = (0 , np .Inf ),
24+ crop = None ,
25+ cmap = "bmy" ,
26+ radius = 3 ,
27+ lik_thresh = 0.5 ,
28+ write_ts = False ,
29+ write_scale = 2 ,
30+ write_pos = "bottom-left" ,
31+ write_ts_offset = 0 ,
32+ display = False ,
33+ progress = True ,
34+ label = True ,
35+ ):
3336 """ Create a labeled video from DeepLabCut-live-GUI recording
3437
3538 Parameters
3639 ----------
37- video_file : str
38- path to video file
39- ts_file : str
40- path to timestamps file
41- dlc_file : str
42- path to DeepLabCut file
40+ data_dir : str
41+ path to data directory
42+ dlc_online : bool, optional
43+ flag indicating dlc keypoints from online tracking, using DeepLabCut-live-GUI, or offline tracking, using :func:`dlclive.benchmark_videos`
4344 out_file : str, optional
4445 path for output file. If None, output file will be "'video_file'_LABELED.avi". by default None. If NOn
4546 save_images : bool, optional
@@ -63,50 +64,86 @@ def create_labeled_video(video_file,
6364 if frames cannot be read from the video file
6465 """
6566
67+ base_dir = os .path .basename (data_dir )
68+ video_file = os .path .normpath (f"{ data_dir } /{ base_dir } _VIDEO.avi" )
69+ ts_file = os .path .normpath (f"{ data_dir } /{ base_dir } _TS.npy" )
70+ dlc_file = (
71+ os .path .normpath (f"{ data_dir } /{ base_dir } _DLC.hdf5" )
72+ if dlc_online
73+ else os .path .normpath (f"{ data_dir } /{ base_dir } _VIDEO_DLCLIVE_POSES.h5" )
74+ )
75+
6676 cap = cv2 .VideoCapture (video_file )
6777 cam_frame_times = np .load (ts_file )
6878 n_frames = cam_frame_times .size
6979
70-
7180 lab = "LABELED" if label else "UNLABELED"
7281 if out_dir :
73- out_file = f"{ out_dir } /{ os .path .splitext (os .path .basename (video_file ))[0 ]} _{ lab } .avi"
74- out_times_file = f"{ out_dir } /{ os .path .splitext (os .path .basename (ts_file ))[0 ]} _{ lab } .npy"
82+ out_file = (
83+ f"{ out_dir } /{ os .path .splitext (os .path .basename (video_file ))[0 ]} _{ lab } .avi"
84+ )
85+ out_times_file = (
86+ f"{ out_dir } /{ os .path .splitext (os .path .basename (ts_file ))[0 ]} _{ lab } .npy"
87+ )
7588 else :
7689 out_file = f"{ os .path .splitext (video_file )[0 ]} _{ lab } .avi"
7790 out_times_file = f"{ os .path .splitext (ts_file )[0 ]} _{ lab } .npy"
7891
7992 os .makedirs (os .path .normpath (os .path .dirname (out_file )), exist_ok = True )
80-
93+
8194 if save_images :
8295 im_dir = os .path .splitext (out_file )[0 ]
8396 os .makedirs (im_dir , exist_ok = True )
8497
85- im_size = (int (cap .get (cv2 .CAP_PROP_FRAME_WIDTH )), int (cap .get (cv2 .CAP_PROP_FRAME_HEIGHT )))
98+ im_size = (
99+ int (cap .get (cv2 .CAP_PROP_FRAME_WIDTH )),
100+ int (cap .get (cv2 .CAP_PROP_FRAME_HEIGHT )),
101+ )
86102 if crop is not None :
87- crop = np .max (np .vstack ((crop , [0 , im_size [1 ], 0 , im_size [0 ]])), axis = 0 )
88- im_size = (crop [3 ]- crop [2 ], crop [1 ]- crop [0 ])
103+ crop [0 ] = crop [0 ] if crop [0 ] > 0 else 0
104+ crop [1 ] = crop [1 ] if crop [1 ] > 0 else im_size [1 ]
105+ crop [2 ] = crop [2 ] if crop [2 ] > 0 else 0
106+ crop [3 ] = crop [3 ] if crop [3 ] > 0 else im_size [0 ]
107+ im_size = (crop [3 ] - crop [2 ], crop [1 ] - crop [0 ])
89108
90- fourcc = cv2 .VideoWriter_fourcc (* ' DIVX' )
109+ fourcc = cv2 .VideoWriter_fourcc (* " DIVX" )
91110 fps = cap .get (cv2 .CAP_PROP_FPS )
92111 vwriter = cv2 .VideoWriter (out_file , fourcc , fps , im_size )
93112 label_times = []
94-
113+
95114 if write_ts :
96115 ts_font = cv2 .FONT_HERSHEY_PLAIN
97- ts_w = 0 if crop is None else crop [0 ]
98- ts_h = im_size [1 ] if crop is None else crop [1 ]
116+
117+ if "left" in write_pos :
118+ ts_w = 0
119+ else :
120+ ts_w = (
121+ im_size [0 ] if crop is None else (crop [3 ] - crop [2 ]) - (55 * write_scale )
122+ )
123+
124+ if "bottom" in write_pos :
125+ ts_h = im_size [1 ] if crop is None else (crop [1 ] - crop [0 ])
126+ else :
127+ ts_h = 0 if crop is None else crop [0 ] + (12 * write_scale )
128+
99129 ts_coord = (ts_w , ts_h )
100130 ts_color = (255 , 255 , 255 )
101131 ts_size = 2
102132
103133 poses = pd .read_hdf (dlc_file )
104- pose_times = poses ['pose_time' ]
105- poses = poses .melt (id_vars = ['frame_time' , 'pose_time' ])
106- bodyparts = poses ['bodyparts' ].unique ()
134+ if dlc_online :
135+ pose_times = poses ["pose_time" ]
136+ else :
137+ poses ["frame_time" ] = cam_frame_times
138+ poses ["pose_time" ] = cam_frame_times
139+ poses = poses .melt (id_vars = ["frame_time" , "pose_time" ])
140+ bodyparts = poses ["bodyparts" ].unique ()
107141
108142 all_colors = getattr (cc , cmap )
109- colors = [ImageColor .getcolor (c , "RGB" )[::- 1 ] for c in all_colors [::int (len (all_colors )/ bodyparts .size )]]
143+ colors = [
144+ ImageColor .getcolor (c , "RGB" )[::- 1 ]
145+ for c in all_colors [:: int (len (all_colors ) / bodyparts .size )]
146+ ]
110147
111148 ind = 0
112149 vid_time = 0
@@ -116,47 +153,70 @@ def create_labeled_video(video_file,
116153 vid_time = cur_time - cam_frame_times [0 ]
117154 ret , frame = cap .read ()
118155 ind += 1
119-
156+
120157 if not ret :
121- raise Exception (f"Could not read frame = { ind + 1 } at time = { cur_time - cam_frame_times [0 ]} ." )
122-
123-
124- frame_times_sub = cam_frame_times [(cam_frame_times - cam_frame_times [0 ] > cut [0 ]) & (cam_frame_times - cam_frame_times [0 ] < cut [1 ])]
125- iterator = tqdm (range (ind , ind + frame_times_sub .size )) if progress else range (ind , ind + frame_times_sub .size )
158+ raise Exception (
159+ f"Could not read frame = { ind + 1 } at time = { cur_time - cam_frame_times [0 ]} ."
160+ )
161+
162+ frame_times_sub = cam_frame_times [
163+ (cam_frame_times - cam_frame_times [0 ] > cut [0 ])
164+ & (cam_frame_times - cam_frame_times [0 ] < cut [1 ])
165+ ]
166+ iterator = (
167+ tqdm (range (ind , ind + frame_times_sub .size ))
168+ if progress
169+ else range (ind , ind + frame_times_sub .size )
170+ )
126171 this_pose = np .zeros ((bodyparts .size , 3 ))
127172
128173 for i in iterator :
129174
130175 cur_time = cam_frame_times [i ]
131176 vid_time = cur_time - cam_frame_times [0 ]
132177 ret , frame = cap .read ()
133-
178+
134179 if not ret :
135- raise Exception (f"Could not read frame = { i + 1 } at time = { cur_time - cam_frame_times [0 ]} ." )
180+ raise Exception (
181+ f"Could not read frame = { i + 1 } at time = { cur_time - cam_frame_times [0 ]} ."
182+ )
136183
137- poses_before_index = np .where (pose_times < cur_time )[0 ]
138- if poses_before_index .size > 0 :
139- cur_pose_time = pose_times [poses_before_index [- 1 ]]
140- this_pose = poses [poses ['pose_time' ]== cur_pose_time ]
184+ if dlc_online :
185+ poses_before_index = np .where (pose_times < cur_time )[0 ]
186+ if poses_before_index .size > 0 :
187+ cur_pose_time = pose_times [poses_before_index [- 1 ]]
188+ this_pose = poses [poses ["pose_time" ] == cur_pose_time ]
189+ else :
190+ this_pose = poses [poses ["frame_time" ] == cur_time ]
141191
142192 if label :
143193 for j in range (bodyparts .size ):
144- this_bp = this_pose [this_pose ['bodyparts' ] == bodyparts [j ]]['value' ].values
194+ this_bp = this_pose [this_pose ["bodyparts" ] == bodyparts [j ]][
195+ "value"
196+ ].values
145197 if this_bp [2 ] > lik_thresh :
146198 x = int (this_bp [0 ])
147199 y = int (this_bp [1 ])
148200 frame = cv2 .circle (frame , (x , y ), radius , colors [j ], thickness = - 1 )
149-
201+
150202 if crop is not None :
151- frame = frame [crop [0 ]: crop [1 ], crop [2 ]: crop [3 ]]
203+ frame = frame [crop [0 ] : crop [1 ], crop [2 ] : crop [3 ]]
152204
153205 if write_ts :
154- frame = cv2 .putText (frame , f"{ vid_time :0.3f} " , ts_coord , ts_font , write_scale , ts_color , ts_size )
206+ frame = cv2 .putText (
207+ frame ,
208+ f"{ (vid_time - write_ts_offset ):0.3f} " ,
209+ ts_coord ,
210+ ts_font ,
211+ write_scale ,
212+ ts_color ,
213+ ts_size ,
214+ )
155215
156216 if display :
157- cv2 .imshow (' DLC Live Labeled Video' , frame )
217+ cv2 .imshow (" DLC Live Labeled Video" , frame )
158218 cv2 .waitKey (1 )
159-
219+
160220 vwriter .write (frame )
161221 label_times .append (cur_time )
162222 if save_images :
@@ -165,7 +225,7 @@ def create_labeled_video(video_file,
165225
166226 if display :
167227 cv2 .destroyAllWindows ()
168-
228+
169229 vwriter .release ()
170230 np .save (out_times_file , label_times )
171231
@@ -176,37 +236,40 @@ def main():
176236 import os
177237
178238 parser = argparse .ArgumentParser ()
179- parser .add_argument ('file' , type = str )
180- parser .add_argument ('-o' , '--out-dir' , type = str , default = None )
181- parser .add_argument ('-s' , '--save-images' , action = 'store_true' )
182- parser .add_argument ('-u' , '--cut' , nargs = '+' , type = float , default = [0 , np .Inf ])
183- parser .add_argument ('-c' , '--crop' , nargs = '+' , type = int , default = None )
184- parser .add_argument ('-m' , '--cmap' , type = str , default = 'bmy' )
185- parser .add_argument ('-r' , '--radius' , type = int , default = 3 )
186- parser .add_argument ('-l' , '--lik-thresh' , type = float , default = 0.5 )
187- parser .add_argument ('-w' , '--write-ts' , action = 'store_true' )
188- parser .add_argument ('--write-scale' , type = int , default = 2 )
189- parser .add_argument ('-d' , '--display' , action = 'store_true' )
190- parser .add_argument ('--no-progress' , action = 'store_false' )
191- parser .add_argument ('--no-label' , action = 'store_false' )
239+ parser .add_argument ("dir" , type = str )
240+ parser .add_argument ("-o" , "--out-dir" , type = str , default = None )
241+ parser .add_argument ("--dlc-offline" , action = "store_true" )
242+ parser .add_argument ("-s" , "--save-images" , action = "store_true" )
243+ parser .add_argument ("-u" , "--cut" , nargs = "+" , type = float , default = [0 , np .Inf ])
244+ parser .add_argument ("-c" , "--crop" , nargs = "+" , type = int , default = None )
245+ parser .add_argument ("-m" , "--cmap" , type = str , default = "bmy" )
246+ parser .add_argument ("-r" , "--radius" , type = int , default = 3 )
247+ parser .add_argument ("-l" , "--lik-thresh" , type = float , default = 0.5 )
248+ parser .add_argument ("-w" , "--write-ts" , action = "store_true" )
249+ parser .add_argument ("--write-scale" , type = int , default = 2 )
250+ parser .add_argument ("--write-pos" , type = str , default = "bottom-left" )
251+ parser .add_argument ("--write-ts-offset" , type = float , default = 0.0 )
252+ parser .add_argument ("-d" , "--display" , action = "store_true" )
253+ parser .add_argument ("--no-progress" , action = "store_false" )
254+ parser .add_argument ("--no-label" , action = "store_false" )
192255 args = parser .parse_args ()
193256
194- vid_file = os . path . normpath ( f" { args . file } _VIDEO.avi" )
195- ts_file = os . path . normpath ( f" { args .file } _TS.npy" )
196- dlc_file = os . path . normpath ( f" { args .file } _DLC.hdf5" )
197-
198- create_labeled_video ( vid_file ,
199- ts_file ,
200- dlc_file ,
201- out_dir = args .out_dir ,
202- save_images = args .save_images ,
203- cut = tuple ( args .cut ) ,
204- crop = args .crop ,
205- cmap = args .cmap ,
206- radius = args .radius ,
207- lik_thresh = args .lik_thresh ,
208- write_ts = args .write_ts ,
209- write_scale = args .write_scale ,
210- display = args .display ,
211- progress = args . no_progress ,
212- label = args . no_label )
257+ create_labeled_video (
258+ args .dir ,
259+ out_dir = args .out_dir ,
260+ dlc_online = ( not args . dlc_offline ),
261+ save_images = args . save_images ,
262+ cut = tuple ( args . cut ) ,
263+ crop = args . crop ,
264+ cmap = args .cmap ,
265+ radius = args .radius ,
266+ lik_thresh = args .lik_thresh ,
267+ write_ts = args .write_ts ,
268+ write_scale = args .write_scale ,
269+ write_pos = args .write_pos ,
270+ write_ts_offset = args .write_ts_offset ,
271+ display = args .display ,
272+ progress = args .no_progress ,
273+ label = args .no_label ,
274+ )
275+
0 commit comments