1818the `dm_control/tutorial.ipynb` file.
1919"""
2020
21- import os .path
22- import tempfile
23- from typing import Callable , Optional , Sequence , Tuple , Union
21+ from typing import Callable , Optional , Tuple , Union
2422
23+ from absl import logging
2524from acme .utils import paths
2625from acme .wrappers import base
2726import dm_env
2827
29- import matplotlib
30- matplotlib .use ('Agg' ) # Switch to headless 'Agg' to inhibit figure rendering.
31- import matplotlib .animation as anim # pylint: disable=g-import-not-at-top
32- import matplotlib .pyplot as plt
33- import numpy as np
34-
35- # Internal imports.
36- # Make sure you have FFMpeg configured.
37-
38- def make_animation (
39- frames : Sequence [np .ndarray ], frame_rate : float ,
40- figsize : Optional [Union [float , Tuple [int , int ]]]) -> anim .Animation :
41- """Generates a matplotlib animation from a stack of frames."""
42-
43- # Set animation characteristics.
44- if figsize is None :
45- height , width , _ = frames [0 ].shape
46- elif isinstance (figsize , tuple ):
47- height , width = figsize
48- else :
49- diagonal = figsize
50- height , width , _ = frames [0 ].shape
51- scale_factor = diagonal / np .sqrt (height ** 2 + width ** 2 )
52- width *= scale_factor
53- height *= scale_factor
54-
55- dpi = 70
56- interval = int (round (1e3 / frame_rate )) # Time (in ms) between frames.
57-
58- # Create and configure the figure.
59- fig , ax = plt .subplots (1 , 1 , figsize = (width / dpi , height / dpi ), dpi = dpi )
60- ax .set_axis_off ()
61- ax .set_aspect ('equal' )
62- ax .set_position ([0 , 0 , 1 , 1 ])
63-
64- # Initialize the first frame.
65- im = ax .imshow (frames [0 ])
66-
67- # Create the function that will modify the frame, creating an animation.
68- def update (frame ):
69- im .set_data (frame )
70- return [im ]
71-
72- return anim .FuncAnimation (
73- fig = fig ,
74- func = update ,
75- frames = frames ,
76- interval = interval ,
77- blit = True ,
78- repeat = False )
79-
8028
8129class VideoWrapper (base .EnvironmentWrapper ):
8230 """Wrapper which creates and records videos from generated observations.
@@ -101,77 +49,24 @@ def __init__(
10149 to_html : bool = True ,
10250 ):
10351 super (VideoWrapper , self ).__init__ (environment )
104- self ._path = process_path (path , 'videos' )
105- self ._filename = filename
106- self ._record_every = record_every
107- self ._frame_rate = frame_rate
108- self ._frames = []
109- self ._counter = 0
110- self ._figsize = figsize
111- self ._to_html = to_html
112-
113- def _render_frame (self , observation ):
114- """Renders a frame from the given environment observation."""
115- return observation
52+ logging .warning (
53+ 'VideoWrapper is deprecated and currently acts as a no-op in order to '
54+ 'avoid using ffmpeg directly. The old behavior can be restored by '
55+ 'replacing the direct call to ffmpeg within matplotlib.'
56+ )
11657
11758 def _write_frames (self ):
118- """Writes frames to video."""
119- if self ._counter % self ._record_every == 0 :
120- animation = make_animation (self ._frames , self ._frame_rate , self ._figsize )
121- path_without_extension = os .path .join (
122- self ._path , f'{ self ._filename } _{ self ._counter :04d} '
123- )
124- if self ._to_html :
125- path = path_without_extension + '.html'
126- video = animation .to_html5_video ()
127- with open (path , 'w' ) as f :
128- f .write (video )
129- else :
130- path = path_without_extension + '.m4v'
131- # Animation.save can save only locally. Save first and copy using
132- # gfile.
133- with tempfile .TemporaryDirectory () as tmp_dir :
134- tmp_path = os .path .join (tmp_dir , 'temp.m4v' )
135- animation .save (tmp_path )
136- with open (path , 'wb' ) as f :
137- with open (tmp_path , 'rb' ) as g :
138- f .write (g .read ())
139-
140- # Clear the frame buffer whether a video was generated or not.
141- self ._frames = []
142-
143- def _append_frame (self , observation ):
144- """Appends a frame to the sequence of frames."""
145- if self ._counter % self ._record_every == 0 :
146- self ._frames .append (self ._render_frame (observation ))
59+ # This is a no-op to preserve existing behavior.
60+ return
14761
14862 def step (self , action ) -> dm_env .TimeStep :
149- timestep = self .environment .step (action )
150- self ._append_frame (timestep .observation )
151- return timestep
63+ return self .environment .step (action )
15264
15365 def reset (self ) -> dm_env .TimeStep :
154- # If the frame buffer is nonempty, flush it and record video
155- if self ._frames :
156- self ._write_frames ()
157- self ._counter += 1
158- timestep = self .environment .reset ()
159- self ._append_frame (timestep .observation )
160- return timestep
161-
162- def make_html_animation (self ):
163- if self ._frames :
164- return make_animation (self ._frames , self ._frame_rate ,
165- self ._figsize ).to_html5_video ()
166- else :
167- raise ValueError ('make_html_animation should be called after running a '
168- 'trajectory and before calling reset().' )
66+ return self .environment .reset ()
16967
17068 def close (self ):
171- if self ._frames :
172- self ._write_frames ()
173- self ._frames = []
174- self .environment .close ()
69+ return self .environment .close ()
17570
17671
17772class MujocoVideoWrapper (VideoWrapper ):
@@ -215,42 +110,3 @@ def __init__(self,
215110 self ._camera_id = camera_id
216111 self ._height = height
217112 self ._width = width
218-
219- def _render_frame (self , unused_observation ):
220- del unused_observation
221-
222- # We've checked above that this attribute should exist. Pytype won't like
223- # it if we just try and do self.environment.physics, so we use the slightly
224- # grosser version below.
225- physics = getattr (self .environment , 'physics' )
226-
227- if self ._camera_id is not None :
228- frame = physics .render (
229- camera_id = self ._camera_id , height = self ._height , width = self ._width )
230- else :
231- # If camera_id is None, we create a minimal canvas that will accommodate
232- # physics.model.ncam frames, and render all of them on a grid.
233- num_cameras = physics .model .ncam
234- num_columns = int (np .ceil (np .sqrt (num_cameras )))
235- num_rows = int (np .ceil (float (num_cameras )/ num_columns ))
236- height = self ._height
237- width = self ._width
238-
239- # Make a black canvas.
240- frame = np .zeros ((num_rows * height , num_columns * width , 3 ), dtype = np .uint8 )
241-
242- for col in range (num_columns ):
243- for row in range (num_rows ):
244-
245- camera_id = row * num_columns + col
246-
247- if camera_id >= num_cameras :
248- break
249-
250- subframe = physics .render (
251- camera_id = camera_id , height = height , width = width )
252-
253- # Place the frame in the appropriate rectangle on the pixel canvas.
254- frame [row * height :(row + 1 )* height , col * width :(col + 1 )* width ] = subframe
255-
256- return frame
0 commit comments