Skip to content

Commit b4d84aa

Browse files
committed
fix lanpa#5. Unified API
1 parent 3f9532f commit b4d84aa

File tree

4 files changed

+71
-144
lines changed

4 files changed

+71
-144
lines changed

demo_embedding.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from tensorboard import SummaryWriter
88
from datetime import datetime
99
from torch.utils.data import TensorDataset,DataLoader
10-
from tensorboard.embedding import EmbeddingWriter
1110
import os
1211

1312
#EMBEDDING VISUALIZATION FOR A TWO-CLASSES PROBLEM
@@ -52,37 +51,38 @@ def get_data(value,shape):
5251
optimizer = Adam(params=m.parameters())
5352
#settings for train and log
5453
num_epochs = 20
55-
num_batches = len(gen)
5654
embedding_log = 5
57-
#WE NEED A WRITER! BECAUSE TB LOOK FOR IT!
5855
writer_name = datetime.now().strftime('%B%d %H:%M:%S')
5956
writer = SummaryWriter(os.path.join("runs",writer_name))
60-
#our brand new embwriter in the same dir
61-
embedding_writer = EmbeddingWriter(os.path.join("runs",writer_name))
57+
6258
#TRAIN
63-
for i in range(num_epochs):
59+
for epoch in range(num_epochs):
6460
for j,sample in enumerate(gen):
61+
n_iter = (epoch*len(gen))+j
6562
#reset grad
6663
m.zero_grad()
6764
optimizer.zero_grad()
6865
#get batch data
69-
data_batch = Variable(sample[0],requires_grad=True).float()
70-
label_batch = Variable(sample[1],requires_grad=False).long()
66+
data_batch = Variable(sample[0], requires_grad=True).float()
67+
label_batch = Variable(sample[1], requires_grad=False).long()
7168
#FORWARD
7269
out = m(data_batch)
73-
loss_value = loss(out,label_batch)
70+
loss_value = loss(out, label_batch)
7471
#BACKWARD
7572
loss_value.backward()
7673
optimizer.step()
7774
#LOGGING
75+
writer.add_scalar('loss', loss_value.data[0], n_iter)
76+
7877
if j % embedding_log == 0:
7978
print("loss_value:{}".format(loss_value.data[0]))
8079
#we need 3 dimension for tensor to visualize it!
8180
out = torch.cat((out,torch.ones(len(out),1)),1)
8281
#write the embedding for the timestep
83-
embedding_writer.add_embedding(out.data,metadata=label_batch.data,label_img=data_batch.data,timestep=(i*num_batches)+j)
82+
writer.add_embedding(out.data, metadata=label_batch.data, label_img=data_batch.data, global_step=n_iter)
8483

8584
writer.close()
8685

8786
#tensorboard --logdir runs
88-
#you should now see a dropdown list with all the timestep, latest timestep should have a visible separation between the two classes
87+
#you should now see a dropdown list with all the timesteps,
88+
# last timestep should have a visible separation between the two classes

docs/tensorboard.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,3 @@ tensorboard-pytorch
66
:members:
77

88
.. automethod:: __init__
9-
.. autofunction:: tensorboard.embedding.add_embedding

tensorboard/embedding.py

Lines changed: 8 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,17 @@ def make_sprite(label_img, save_path):
2828
else:
2929
torchvision.utils.save_image(label_img, os.path.join(save_path, 'sprite.png'), nrow=nrow, padding=0)
3030

31-
def make_pbtxt(save_path, metadata, label_img):
32-
with open(os.path.join(save_path, 'projector_config.pbtxt'), 'w') as f:
31+
def append_pbtxt(metadata, label_img, save_path, global_step):
32+
with open(os.path.join(save_path, 'projector_config.pbtxt'), 'a') as f:
33+
#step = os.path.split(save_path)[-1]
3334
f.write('embeddings {\n')
34-
f.write('tensor_name: "embedding:0"\n')
35-
f.write('tensor_path: "tensors.tsv"\n')
35+
f.write('tensor_name: "embedding:{}"\n'.format(global_step))
36+
f.write('tensor_path: "{}"\n'.format(os.path.join(global_step,"tensors.tsv")))
3637
if metadata is not None:
37-
f.write('metadata_path: "metadata.tsv"\n')
38+
f.write('metadata_path: "{}"\n'.format(os.path.join(global_step,"metadata.tsv")))
3839
if label_img is not None:
3940
f.write('sprite {\n')
40-
f.write('image_path: "sprite.png"\n')
41+
f.write('image_path: "{}"\n'.format(os.path.join(global_step,"sprite.png")))
4142
f.write('single_image_dim: {}\n'.format(label_img.size(3)))
4243
f.write('single_image_dim: {}\n'.format(label_img.size(2)))
4344
f.write('}\n')
@@ -47,129 +48,4 @@ def make_mat(matlist, save_path):
4748
with open(os.path.join(save_path, 'tensors.tsv'), 'w') as f:
4849
for x in matlist:
4950
x = [str(i) for i in x]
50-
f.write('\t'.join(x) + '\n')
51-
52-
def add_embedding(mat, save_path, metadata=None, label_img=None):
53-
"""add embedding
54-
55-
Args:
56-
mat (torch.Tensor): A matrix which each row is the feature vector of the data point
57-
save_path (string): Save path (use ``writer.file_writer.get_logdir()`` to show embedding along with other summaries)
58-
metadata (list): A list of labels, each element will be convert to string
59-
label_img (torch.Tensor): Images correspond to each data point
60-
Shape:
61-
mat: :math:`(N, D)`, where N is number of data and D is feature dimension
62-
63-
label_img: :math:`(N, C, H, W)`
64-
65-
.. note::
66-
~~This function needs tensorflow installed. It invokes tensorflow to dump data. ~~
67-
Therefore I separate it from the SummaryWriter class. Please pass ``writer.file_writer.get_logdir()`` to ``save_path`` to prevent glitches.
68-
69-
If ``save_path`` is different than SummaryWritter's save path, you need to pass the leave directory to tensorboard's logdir argument,
70-
otherwise it cannot display anything. e.g. if ``save_path`` equals 'path/to/embedding',
71-
you need to call 'tensorboard --logdir=path/to/embedding', instead of 'tensorboard --logdir=path'.
72-
73-
74-
Examples::
75-
76-
from tensorboard.embedding import add_embedding
77-
import keyword
78-
import torch
79-
meta = []
80-
while len(meta)<100:
81-
meta = meta+keyword.kwlist # get some strings
82-
meta = meta[:100]
83-
84-
for i, v in enumerate(meta):
85-
meta[i] = v+str(i)
86-
87-
label_img = torch.rand(100, 3, 10, 32)
88-
for i in range(100):
89-
label_img[i]*=i/100.0
90-
91-
add_embedding(torch.randn(100, 5), 'embedding1', metadata=meta, label_img=label_img)
92-
add_embedding(torch.randn(100, 5), 'embedding2', label_img=label_img)
93-
add_embedding(torch.randn(100, 5), 'embedding3', metadata=meta)
94-
"""
95-
try:
96-
os.makedirs(save_path)
97-
except OSError:
98-
print('warning: dir exists')
99-
if metadata is not None:
100-
assert mat.size(0)==len(metadata), '#labels should equal with #data points'
101-
make_tsv(metadata, save_path)
102-
if label_img is not None:
103-
assert mat.size(0)==label_img.size(0), '#images should equal with #data points'
104-
make_sprite(label_img, save_path)
105-
assert mat.dim()==2, 'mat should be 2D, where mat.size(0) is the number of data points'
106-
make_mat(mat.tolist(), save_path)
107-
make_pbtxt(save_path, metadata, label_img)
108-
109-
def append_pbtxt(f, metadata, label_img,path):
110-
111-
f.write('embeddings {\n')
112-
f.write('tensor_name: "{}"\n'.format(os.path.join(path,"embedding")))
113-
f.write('tensor_path: "{}"\n'.format(os.path.join(path,"tensors.tsv")))
114-
if metadata is not None:
115-
f.write('metadata_path: "{}"\n'.format(os.path.join(path,"metadata.tsv")))
116-
if label_img is not None:
117-
f.write('sprite {\n')
118-
f.write('image_path: "{}"\n'.format(os.path.join(path,"sprite.png")))
119-
f.write('single_image_dim: {}\n'.format(label_img.size(3)))
120-
f.write('single_image_dim: {}\n'.format(label_img.size(2)))
121-
f.write('}\n')
122-
f.write('}\n')
123-
124-
125-
class EmbeddingWriter(object):
126-
"""
127-
Class to allow writing embeddings ad defined timestep
128-
129-
"""
130-
def __init__(self,save_path):
131-
"""
132-
133-
:param save_path: should be the same path of you SummaryWriter
134-
"""
135-
self.save_path = save_path
136-
#make dir if needed, it should not
137-
try:
138-
os.makedirs(save_path)
139-
except OSError:
140-
print('warning: dir exists')
141-
#create config file to store all embeddings conf
142-
self.f = open(os.path.join(save_path, 'projector_config.pbtxt'), 'w')
143-
144-
def add_embedding(self,mat, metadata=None, label_img=None,timestep=0):
145-
"""
146-
add an embedding at the defined timestep
147-
148-
:param mat:
149-
:param metadata:
150-
:param label_img:
151-
:param timestep:
152-
:return:
153-
"""
154-
# TODO make doc
155-
#path to the new subdir
156-
timestep_path = "{}".format(timestep)
157-
# TODO should this be handled?
158-
os.makedirs(os.path.join(self.save_path,timestep_path))
159-
#check other info
160-
#save all this metadata in the new subfolder
161-
if metadata is not None:
162-
assert mat.size(0) == len(metadata), '#labels should equal with #data points'
163-
make_tsv(metadata, os.path.join(self.save_path,timestep_path))
164-
if label_img is not None:
165-
assert mat.size(0) == label_img.size(0), '#images should equal with #data points'
166-
make_sprite(label_img, os.path.join(self.save_path,timestep_path))
167-
assert mat.dim() == 2, 'mat should be 2D, where mat.size(0) is the number of data points'
168-
make_mat(mat.tolist(), os.path.join(self.save_path,timestep_path))
169-
#new funcion to append to the config file a new embedding
170-
append_pbtxt(self.f, metadata, label_img,timestep_path)
171-
172-
173-
def __del__(self):
174-
#close the file at the end of the script
175-
self.f.close()
51+
f.write('\t'.join(x) + '\n')

tensorboard/writer.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .event_file_writer import EventFileWriter
2828
from .summary import scalar, histogram, image, audio, text
2929
from .graph import graph
30+
from .embedding import make_mat, make_sprite, make_tsv, append_pbtxt
3031

3132

3233
class SummaryToEventTransformer(object):
@@ -329,6 +330,57 @@ def add_graph(self, model, lastVar):
329330
return
330331
self.file_writer.add_graph(graph(model, lastVar))
331332

333+
def add_embedding(self, mat, metadata=None, label_img=None, global_step=None):
334+
"""add embedding
335+
336+
Args:
337+
mat (torch.Tensor): A matrix which each row is the feature vector of the data point
338+
metadata (list): A list of labels, each element will be convert to string
339+
label_img (torch.Tensor): Images correspond to each data point
340+
global_step (int): Global step value to record
341+
Shape:
342+
mat: :math:`(N, D)`, where N is number of data and D is feature dimension
343+
344+
label_img: :math:`(N, C, H, W)`
345+
346+
Examples::
347+
348+
import keyword
349+
import torch
350+
meta = []
351+
while len(meta)<100:
352+
meta = meta+keyword.kwlist # get some strings
353+
meta = meta[:100]
354+
355+
for i, v in enumerate(meta):
356+
meta[i] = v+str(i)
357+
358+
label_img = torch.rand(100, 3, 10, 32)
359+
for i in range(100):
360+
label_img[i]*=i/100.0
361+
362+
writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img)
363+
writer.add_embedding(torch.randn(100, 5), label_img=label_img)
364+
writer.add_embedding(torch.randn(100, 5), metadata=meta)
365+
"""
366+
if global_step == None:
367+
global_step = 0
368+
# clear pbtxt?
369+
save_path = os.path.join(self.file_writer.get_logdir(), str(global_step).zfill(5))
370+
try:
371+
os.makedirs(save_path)
372+
except OSError:
373+
print('warning: Embedding dir exists, did you set global_step for add_embedding()?')
374+
if metadata is not None:
375+
assert mat.size(0) == len(metadata), '#labels should equal with #data points'
376+
make_tsv(metadata, save_path)
377+
if label_img is not None:
378+
assert mat.size(0) == label_img.size(0), '#images should equal with #data points'
379+
make_sprite(label_img, save_path)
380+
assert mat.dim() == 2, 'mat should be 2D, where mat.size(0) is the number of data points'
381+
make_mat(mat.tolist(), save_path)
382+
#new funcion to append to the config file a new embedding
383+
append_pbtxt(metadata, label_img, self.file_writer.get_logdir(), str(global_step).zfill(5))
332384

333385
def close(self):
334386
self.file_writer.flush()

0 commit comments

Comments
 (0)