Skip to content

Commit 880fbaf

Browse files
committed
tensorboard-x refactor
1 parent 4bbb992 commit 880fbaf

File tree

8 files changed

+81
-92
lines changed

8 files changed

+81
-92
lines changed

demo.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import numpy as np
44
import torchvision.models as models
55
from torchvision import datasets
6-
from tensorboard import SummaryWriter
6+
from tensorboardX import SummaryWriter
7+
#import skimage
8+
#from skimage import data, io
9+
710
resnet18 = models.resnet18(False)
811
writer = SummaryWriter()
912
sample_rate = 44100
@@ -12,20 +15,22 @@
1215
for n_iter in range(100):
1316
s1 = torch.rand(1) # value to keep
1417
s2 = torch.rand(1)
15-
writer.add_scalar('data/scalar1', s1[0], n_iter) #data grouping by `slash`
16-
writer.add_scalar('data/scalar2', s2[0], n_iter)
18+
writer.add_scalar('data/scalar1', s1[0], n_iter) # data grouping by `slash`
19+
writer.add_scalar('data/scalar2', s2, n_iter) # passing Tensor is OK!
1720
x = torch.rand(32, 3, 64, 64) # output from network
1821
if n_iter%10==0:
1922
x = vutils.make_grid(x, normalize=True, scale_each=True)
20-
writer.add_image('Image', x, n_iter)
23+
writer.add_image('Image', x, n_iter) # Tensor
24+
#writer.add_image('astronaut', skimage.data.astronaut(), n_iter) # numpy
25+
#writer.add_image('imread', skimage.io.imread('screenshots/audio.png'), n_iter) # numpy
2126
x = torch.zeros(sample_rate*2)
2227
for i in range(x.size(0)):
2328
x[i] = np.cos(freqs[n_iter//10]*np.pi*float(i)/float(sample_rate)) # sound amplitude should in [-1, 1]
2429
writer.add_audio('myAudio', x, n_iter)
2530
writer.add_text('Text', 'text logged at step:'+str(n_iter), n_iter)
2631
writer.add_text('another Text', 'another text logged at step:'+str(n_iter), n_iter)
2732
for name, param in resnet18.named_parameters():
28-
writer.add_histogram(name, param.clone().cpu().data.numpy(), n_iter)
33+
writer.add_histogram(name, param, n_iter)
2934

3035
dataset = datasets.MNIST('mnist', train=False, download=True)
3136
images = dataset.test_data[:100].float()

demo_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.nn.functional as F
44
import os
55
from torch.autograd.variable import Variable
6-
from tensorboard import SummaryWriter
6+
from tensorboardX import SummaryWriter
77
from torch.utils.data import TensorDataset, DataLoader
88

99
#EMBEDDING VISUALIZATION FOR A TWO-CLASSES PROBLEM

demo_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import torch.nn.functional as F
66
import torchvision.models as models
7-
from tensorboard import SummaryWriter
7+
from tensorboardX import SummaryWriter
88

99
class Mnist(nn.Module):
1010
def __init__(self):

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
]
1919

2020
setup(
21-
name='tensorboard-pytorch',
22-
version='0.7.1',
23-
description='Log TensorBoard events with pytorch',
21+
name='tensorboardX',
22+
version='0.7.5',
23+
description='TensorBoardX lets you watch Tensors Flow without Tensorflow',
2424
long_description= history,
2525
author='Tzu-Wei Huang',
2626
author_email='huang.dexter@gmail.com',

tensorboardX/summary.py

Lines changed: 1 addition & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,7 @@ def scalar(name, scalar, collections=None):
8484
ValueError: If tensor has the wrong shape or type.
8585
"""
8686
name = _clean_tag(name)
87-
if not isinstance(scalar, float):
88-
# try conversion, if failed then need handle by user.
89-
scalar = float(scalar)
87+
scalar = float(scalar)
9088
return Summary(value=[Summary.Value(tag=name, simple_value=scalar)])
9189

9290

@@ -182,8 +180,6 @@ def make_image(tensor):
182180
encoded_image_string=image_string)
183181

184182
def audio(tag, tensor, sample_rate=44100):
185-
tensor = tensor.squeeze()
186-
assert tensor.dim()==1, 'input tensor should be 1 dimensional.'
187183
tensor_list = [int(32767.0*x) for x in tensor]
188184
import io
189185
import wave
@@ -212,69 +208,3 @@ def text(tag, text):
212208
tensor = TensorProto(dtype='DT_STRING', string_val=[text.encode(encoding='utf_8')])
213209
return Summary(value=[Summary.Value(node_name=tag, metadata=smd, tensor=tensor)])
214210

215-
216-
'''
217-
def merge(inputs, collections=None, name=None):
218-
# pylint: disable=line-too-long
219-
"""Merges summaries.
220-
This op creates a
221-
[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
222-
protocol buffer that contains the union of all the values in the input
223-
summaries.
224-
When the Op is run, it reports an `InvalidArgument` error if multiple values
225-
in the summaries to merge use the same tag.
226-
Args:
227-
inputs: A list of `string` `Tensor` objects containing serialized `Summary`
228-
protocol buffers.
229-
collections: Optional list of graph collections keys. The new summary op is
230-
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
231-
name: A name for the operation (optional).
232-
Returns:
233-
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
234-
buffer resulting from the merging.
235-
"""
236-
# pylint: enable=line-too-long
237-
name = _clean_tag(name)
238-
with _ops.name_scope(name, 'Merge', inputs):
239-
# pylint: disable=protected-access
240-
val = _gen_logging_ops._merge_summary(inputs=inputs, name=name)
241-
_collect(val, collections, [])
242-
return val
243-
244-
245-
def merge_all(key=_ops.GraphKeys.SUMMARIES):
246-
"""Merges all summaries collected in the default graph.
247-
Args:
248-
key: `GraphKey` used to collect the summaries. Defaults to
249-
`GraphKeys.SUMMARIES`.
250-
Returns:
251-
If no summaries were collected, returns None. Otherwise returns a scalar
252-
`Tensor` of type `string` containing the serialized `Summary` protocol
253-
buffer resulting from the merging.
254-
"""
255-
summary_ops = _ops.get_collection(key)
256-
if not summary_ops:
257-
return None
258-
else:
259-
return merge(summary_ops)
260-
261-
262-
def get_summary_description(node_def):
263-
"""Given a TensorSummary node_def, retrieve its SummaryDescription.
264-
When a Summary op is instantiated, a SummaryDescription of associated
265-
metadata is stored in its NodeDef. This method retrieves the description.
266-
Args:
267-
node_def: the node_def_pb2.NodeDef of a TensorSummary op
268-
Returns:
269-
a summary_pb2.SummaryDescription
270-
Raises:
271-
ValueError: if the node is not a summary op.
272-
"""
273-
274-
if node_def.op != 'TensorSummary':
275-
raise ValueError("Can't get_summary_description on %s" % node_def.op)
276-
description_str = _compat.as_str_any(node_def.attr['description'].s)
277-
summary_description = SummaryDescription()
278-
_json_format.Parse(description_str, summary_description)
279-
return summary_description
280-
'''

tensorboardX/writer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from .summary import scalar, histogram, image, audio, text
2929
from .graph import graph
3030
from .embedding import make_mat, make_sprite, make_tsv, append_pbtxt
31-
31+
from .x2num import makenp
3232

3333
class SummaryToEventTransformer(object):
3434
"""Abstractly implements the SummaryWriter API.
@@ -248,6 +248,8 @@ def add_scalar(self, tag, scalar_value, global_step=None):
248248
global_step (int): Global step value to record
249249
250250
"""
251+
scalar_value = makenp(scalar_value)
252+
assert(scalar_value.squeeze().ndim==0), 'input of add_scalar should be 0D'
251253
self.file_writer.add_summary(scalar(tag, scalar_value), global_step)
252254

253255
def add_histogram(self, tag, values, global_step=None, bins='tensorflow'):
@@ -262,6 +264,7 @@ def add_histogram(self, tag, values, global_step=None, bins='tensorflow'):
262264
"""
263265
if bins=='tensorflow':
264266
bins = self.default_bins
267+
values = makenp(values)
265268
self.file_writer.add_summary(histogram(tag, values, bins), global_step)
266269

267270
def add_image(self, tag, img_tensor, global_step=None):
@@ -274,6 +277,7 @@ def add_image(self, tag, img_tensor, global_step=None):
274277
Shape:
275278
img_tensor: :math:`(3, H, W)`. Use ``torchvision.utils.make_grid()`` to prepare it is a good idea.
276279
"""
280+
img_tensor = makenp(img_tensor, 'IMG')
277281
self.file_writer.add_summary(image(tag, img_tensor), global_step)
278282
def add_audio(self, tag, snd_tensor, global_step=None):
279283
"""Add audio data to summary.
@@ -286,6 +290,10 @@ def add_audio(self, tag, snd_tensor, global_step=None):
286290
Shape:
287291
snd_tensor: :math:`(1, L)`. The values should between [-1, 1]. The sample rate is currently fixed at 44100 KHz.
288292
"""
293+
snd_tensor = makenp(snd_tensor)
294+
snd_tensor = snd_tensor.squeeze()
295+
assert(snd_tensor.ndim==1), 'input tensor should be 1 dimensional.'
296+
289297
self.file_writer.add_summary(audio(tag, snd_tensor), global_step)
290298
def add_text(self, tag, text_string, global_step=None):
291299
"""Add text data to summary.

tensorboardX/x2num.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,44 @@
1-
def makenp(x, **kwargs):
1+
# DO NOT alter/distruct/free input object !
2+
3+
import numpy as np
4+
5+
def makenp(x, modality=None):
26
# if already numpy, return
3-
pass
7+
if isinstance(x, np.ndarray):
8+
if modality == 'IMG' and x.dtype == np.uint8:
9+
return x.astype(np.float32)/255.0
10+
return x
11+
if isinstance(x, float) or isinstance(x, int):
12+
return np.array([x])
13+
if 'torch' in str(type(x)):
14+
return pytorch_np(x, modality)
415

5-
def pytorch_np():
6-
pass
16+
def pytorch_np(x, modality):
17+
import torch
18+
if isinstance(x, torch.autograd.variable.Variable):
19+
x = x.data
20+
x = x.cpu()
21+
assert isinstance(x, torch.Tensor), 'invalid input type'
22+
if modality == 'IMG':
23+
assert x.dim()<4 and x.dim()>1, 'input tensor should be 3D for color image or 2D for gray image.'
24+
if x.dim()==2:
25+
x = x.unsqueeze(0)
26+
x = x.permute(1,2,0) #CHW to HWC
27+
return x.numpy()
728

8-
def torch_np():
9-
pass
1029

11-
def theano_np():
30+
def theano_np(x):
31+
import theano
1232
pass
1333

14-
def caffe2_np():
34+
def caffe2_np(x):
1535
pass
1636

17-
def mxnet_np():
37+
def mxnet_np(x):
1838
pass
1939

20-
def chainer_np():
40+
def chainer_np(x):
41+
import chainer
42+
#x = chainer.cuda.to_cpu(x.data)
2143
pass
2244

tests/test_pytorch_np.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from tensorboardX import x2num
2+
import torch
3+
import numpy as np
4+
shapes = [(3, 10, 10), (1, ), (1, 2, 3, 4, 5)]
5+
6+
def test_pytorch_np():
7+
8+
for shape in shapes:
9+
# regular tensor
10+
assert(isinstance(x2num.makenp(torch.Tensor(*shape)), np.ndarray))
11+
12+
# CUDA tensor
13+
assert(isinstance(x2num.makenp(torch.Tensor(*shape).cuda()), np.ndarray))
14+
15+
# regular variable
16+
assert(isinstance(x2num.makenp(torch.autograd.variable.Variable(torch.Tensor((*shape)))), np.ndarray))
17+
18+
# CUDA variable
19+
assert(isinstance(x2num.makenp(torch.autograd.variable.Variable(torch.Tensor((*shape))).cuda()), np.ndarray))
20+
21+
# python primitive type
22+
assert(isinstance(x2num.makenp(0), np.ndarray))
23+
assert(isinstance(x2num.makenp(0.1), np.ndarray))
24+

0 commit comments

Comments
 (0)