Skip to content

Commit 68a2e80

Browse files
author
jamie-mcg
committed
Fixed a couple of bugs.
1 parent 23db98c commit 68a2e80

File tree

3 files changed

+79
-69
lines changed

3 files changed

+79
-69
lines changed

examples/autoencoder.py

Lines changed: 72 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,7 @@ def __getitem__(self, idx):
126126
if torch.is_tensor(idx):
127127
idx = idx.tolist()
128128

129-
return torch.tensor(self._images[idx]), torch.tensor(
130-
self._labels[idx]
131-
)
129+
return torch.tensor(self._images[idx]), torch.tensor(self._labels[idx])
132130

133131

134132
def _read32(bytestream):
@@ -225,19 +223,19 @@ class DataSets(object):
225223

226224
SOURCE_URL = "http://www.cs.toronto.edu/~jmartens/"
227225
TRAIN_IMAGES = "newfaces_rot_single.mat"
228-
226+
229227
local_file = maybe_download(SOURCE_URL, TRAIN_IMAGES, train_dir)
230228
print(f"Data read from {local_file}")
231-
232-
numpy_file = os.path.dirname(local_file) + '/faces.npy'
229+
230+
numpy_file = os.path.dirname(local_file) + "/faces.npy"
233231
if os.path.exists(numpy_file):
234232
images_ = np.load(numpy_file)
235233
else:
236234
import mat4py
237235

238236
images_ = mat4py.loadmat(local_file)
239237
images_ = np.asarray(images_["newfaces_single"])
240-
238+
241239
images_ = np.transpose(images_)
242240
np.save(numpy_file, images_)
243241
print(f"Data saved to {numpy_file}")
@@ -276,10 +274,9 @@ class DataSets(object):
276274
if __name__ == "__main__":
277275

278276
argparser = argparse.ArgumentParser()
279-
argparser.add_argument('--exp', type=str, help='which dataset', default='FACES')
277+
argparser.add_argument("--exp", type=str, help="which dataset", default="FACES")
280278
args = argparser.parse_args()
281279

282-
283280
seed = 13
284281
torch.manual_seed(seed)
285282
torch.backends.cudnn.benchmark = False
@@ -291,7 +288,7 @@ class DataSets(object):
291288
print("device", device)
292289

293290
## Hyperparams
294-
if args.exp == 'FACES':
291+
if args.exp == "FACES":
295292

296293
batch_size = 100
297294
epochs = 5
@@ -304,8 +301,8 @@ class DataSets(object):
304301
damping = 1.0
305302

306303
dataset = read_data_sets("FACES", "../data/", if_autoencoder=True)
307-
308-
if args.exp == 'MNIST':
304+
305+
if args.exp == "MNIST":
309306
batch_size = 100
310307
epochs = 10
311308
eta_adam = 1e-4
@@ -321,19 +318,23 @@ class DataSets(object):
321318
## Dataset
322319
train_dataset = dataset.train
323320
test_dataset = dataset.test
324-
if args.exp == 'FACES':
325-
likelihood = FISH_LIKELIHOODS['fixedgaussian'](sigma=1.0, device=device)
321+
if args.exp == "FACES":
322+
likelihood = FISH_LIKELIHOODS["fixedgaussian"](sigma=1.0, device=device)
323+
326324
def mse(model, data):
327325
data_x, data_y = data
328326
pred_y = model.forward(data_x)
329-
return torch.mean(torch.square(pred_y-data_y))
330-
if args.exp == 'MNIST':
331-
likelihood = FISH_LIKELIHOODS['bernoulli'](device=device)
327+
return torch.mean(torch.square(pred_y - data_y))
328+
329+
if args.exp == "MNIST":
330+
likelihood = FISH_LIKELIHOODS["bernoulli"](device=device)
331+
332332
def mse(model, data):
333333
data_x, data_y = data
334334
pred_y = model.forward(data_x)
335335
pred_y = torch.sigmoid(pred_y)
336-
return torch.mean(torch.square(pred_y-data_y))
336+
return torch.mean(torch.square(pred_y - data_y))
337+
337338
def nll(model, data):
338339
data_x, data_y = data
339340
pred_y = model.forward(data_x)
@@ -344,7 +345,6 @@ def draw(model, data):
344345
pred_y = model.forward(data_x)
345346
return (data_x, likelihood.draw(pred_y))
346347

347-
348348
train_loader = torch.utils.data.DataLoader(
349349
train_dataset, batch_size=batch_size, shuffle=True
350350
)
@@ -360,43 +360,42 @@ def draw(model, data):
360360
test_dataset, batch_size=1000, shuffle=False
361361
)
362362

363-
364-
if args.exp == 'FACES':
363+
if args.exp == "FACES":
365364
model = nn.Sequential(
366-
nn.Linear(625, 2000),
367-
nn.ReLU(),
368-
nn.Linear(2000, 1000),
369-
nn.ReLU(),
370-
nn.Linear(1000, 500),
371-
nn.ReLU(),
372-
nn.Linear(500, 30),
373-
nn.Linear(30, 500),
374-
nn.ReLU(),
375-
nn.Linear(500, 1000),
376-
nn.ReLU(),
377-
nn.Linear(1000, 2000),
378-
nn.ReLU(),
379-
nn.Linear(2000, 625),
365+
nn.Linear(625, 2000),
366+
nn.ReLU(),
367+
nn.Linear(2000, 1000),
368+
nn.ReLU(),
369+
nn.Linear(1000, 500),
370+
nn.ReLU(),
371+
nn.Linear(500, 30),
372+
nn.Linear(30, 500),
373+
nn.ReLU(),
374+
nn.Linear(500, 1000),
375+
nn.ReLU(),
376+
nn.Linear(1000, 2000),
377+
nn.ReLU(),
378+
nn.Linear(2000, 625),
380379
).to(device)
381-
382-
if args.exp == 'MNIST':
380+
381+
if args.exp == "MNIST":
383382
model = nn.Sequential(
384-
nn.Linear(784, 1000, dtype=torch.float32),
385-
nn.ReLU(),
386-
nn.Linear(1000, 500, dtype=torch.float32),
387-
nn.ReLU(),
388-
nn.Linear(500, 250, dtype=torch.float32),
389-
nn.ReLU(),
390-
nn.Linear(250, 30, dtype=torch.float32),
391-
nn.Linear(30, 250, dtype=torch.float32),
392-
nn.ReLU(),
393-
nn.Linear(250, 500, dtype=torch.float32),
394-
nn.ReLU(),
395-
nn.Linear(500, 1000, dtype=torch.float32),
396-
nn.ReLU(),
397-
nn.Linear(1000, 784, dtype=torch.float32),
383+
nn.Linear(784, 1000, dtype=torch.float32),
384+
nn.ReLU(),
385+
nn.Linear(1000, 500, dtype=torch.float32),
386+
nn.ReLU(),
387+
nn.Linear(500, 250, dtype=torch.float32),
388+
nn.ReLU(),
389+
nn.Linear(250, 30, dtype=torch.float32),
390+
nn.Linear(30, 250, dtype=torch.float32),
391+
nn.ReLU(),
392+
nn.Linear(250, 500, dtype=torch.float32),
393+
nn.ReLU(),
394+
nn.Linear(500, 1000, dtype=torch.float32),
395+
nn.ReLU(),
396+
nn.Linear(1000, 784, dtype=torch.float32),
398397
).to(device)
399-
398+
400399
model_adam = copy.deepcopy(model)
401400

402401
print("lr fl={}, lr sgd={}, lr aux={}".format(eta_fl, eta_sgd, aux_eta))
@@ -417,10 +416,15 @@ def draw(model, data):
417416
damping=damping,
418417
pre_aux_training=0,
419418
sgd_lr=eta_sgd,
420-
initialization='normal',
419+
initialization="normal",
421420
device=device,
422421
)
423422

423+
print(opt.__dict__["fish_lr"])
424+
print(opt.__dict__["beta"])
425+
print(opt.__dict__["aux_lr"])
426+
print(opt.__dict__["damping"])
427+
print(opt.__dict__["sgd_lr"])
424428

425429
FL_time = []
426430
LOSS = []
@@ -429,7 +433,7 @@ def draw(model, data):
429433
iteration = 0
430434
for e in range(1, epochs + 1):
431435
print("######## EPOCH", e)
432-
for n, (batch_data, batch_labels) in enumerate(train_loader):
436+
for n, (batch_data, batch_labels) in enumerate(train_loader, start=1):
433437
iteration += 1
434438
batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
435439
opt.zero_grad()
@@ -440,16 +444,18 @@ def draw(model, data):
440444
if n % 50 == 0:
441445
FL_time.append(time.time() - st)
442446
LOSS.append(loss.detach().cpu().numpy())
443-
447+
444448
test_batch_data, test_batch_labels = next(iter(test_loader))
445-
test_batch_data, test_batch_labels = test_batch_data.to(device), test_batch_labels.to(device)
449+
test_batch_data, test_batch_labels = test_batch_data.to(
450+
device
451+
), test_batch_labels.to(device)
446452
test_loss = mse(opt.model, (test_batch_data, test_batch_labels))
447-
453+
448454
TEST_LOSS.append(test_loss.detach().cpu().numpy())
449455

450456
print(n, LOSS[-1], TEST_LOSS[-1])
451-
452-
fig, axs = plt.subplots(1,2, figsize=(10,5))
457+
458+
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
453459
axs[0].plot(FL_time, LOSS, label="Fishleg") # color=colors_group[i])
454460
axs[1].plot(
455461
FL_time, TEST_LOSS, label="Fishleg"
@@ -478,21 +484,23 @@ def draw(model, data):
478484
opt.step()
479485

480486
if n % 50 == 0:
481-
FL_time.append(time.time()-st)
487+
FL_time.append(time.time() - st)
482488
LOSS.append(loss.detach().cpu().numpy())
483489
test_batch_data, test_batch_labels = next(iter(test_loader_adam))
484-
test_batch_data, test_batch_labels = test_batch_data.to(device), test_batch_labels.to(device)
490+
test_batch_data, test_batch_labels = test_batch_data.to(
491+
device
492+
), test_batch_labels.to(device)
485493
test_loss = mse(model_adam, (test_batch_data, test_batch_labels))
486494
TEST_LOSS.append(test_loss.detach().cpu().numpy())
487495

488496
print(n, LOSS[-1], TEST_LOSS[-1])
489497

490498
axs[0].plot(FL_time, LOSS, label="Adam")
491499
axs[1].plot(FL_time, TEST_LOSS, label="Adam")
492-
500+
493501
axs[0].legend()
494502
axs[1].legend()
495503

496-
axs[0].set_title('Training Loss')
497-
axs[1].set_title('Test MSE')
504+
axs[0].set_title("Training Loss")
505+
axs[1].set_title("Test MSE")
498506
fig.savefig("result/result.png", dpi=300)

src/optim/FishLeg/fishleg.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def update_aux(self) -> None:
304304
self.aux_opt.zero_grad()
305305
with torch.no_grad():
306306
self.store_g = False
307-
samples = self.draw(self.model, data[0])
307+
samples = self.draw(self.model, data)
308308
self.store_g = True
309309

310310
g2 = 0.0
@@ -315,13 +315,14 @@ def update_aux(self) -> None:
315315
else:
316316
grad_norm = [0 * p.grad.data for p in group["params"]]
317317

318-
qg = group["Qv"](grad_norm)
318+
g_norm = torch.sqrt(g2)
319+
# print(g_norm)
319320

320321
self.zero_grad()
321322
# How to better implement this?
322323
# The hook is not updated here, locally, only the gradient to the parameters g.grad is updated
323324
self.store_g = False
324-
self.nll(self.model, samples, data[1]).backward()
325+
self.nll(self.model, samples).backward()
325326
self.store_g = True
326327

327328
gm_norm = 0.0

src/optim/FishLeg/fishleg_layers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def diagQ(self) -> Tensor:
146146
the Kronecker product.
147147
148148
"""
149-
L = torch.sqrt(self.fishleg_aux["scale"]) * self.fishleg_aux["L"]
150-
R = torch.sqrt(self.fishleg_aux["scale"]) * self.fishleg_aux["R"]
149+
L = self.fishleg_aux["L"]
150+
R = self.fishleg_aux["R"]
151+
print(L)
151152
return torch.kron(torch.sum(R * R, axis=1), torch.sum(L * L, axis=1))

0 commit comments

Comments
 (0)