Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions mnist/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,20 @@ def __init__(self):
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run a forward pass and return class log-probabilities.

    Args:
        x: Input image batch; presumably MNIST-shaped ``(N, 1, 28, 28)`` so
           the flattened conv features match ``fc1``'s 9216 inputs — confirm
           with caller.

    Returns:
        Tensor of shape ``(N, 10)`` holding per-class log-probabilities
        (``log_softmax`` over dim 1).
    """
    # Two convolution layers, each followed by a ReLU non-linearity.
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    # 2x2 max-pool halves the spatial resolution.
    x = F.max_pool2d(x, 2)
    x = self.dropout1(x)
    # Collapse every non-batch dimension for the fully connected layers.
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = self.dropout2(x)
    return F.log_softmax(self.fc2(x), dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
def train(args, model: nn.Module, device, train_loader, optimizer, epoch: int) -> None:
"""Train for one epoch."""
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
Expand All @@ -50,7 +47,8 @@ def train(args, model, device, train_loader, optimizer, epoch):
break


def test(model, device, test_loader):
def test(model: nn.Module, device, test_loader) -> None:
"""Evaluate model on the test set."""
model.eval()
test_loss = 0
correct = 0
Expand All @@ -69,7 +67,12 @@ def test(model, device, test_loader):
100. * correct / len(test_loader.dataset)))


def main():
def _mps_available() -> bool:
"""Return True when MPS backend is available."""
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()


def main() -> None:
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
Expand All @@ -94,12 +97,22 @@ def main():
help='For Saving the current Model')
args = parser.parse_args()

use_accel = not args.no_accel and torch.accelerator.is_available()
use_accel = not args.no_accel and (
torch.accelerator.is_available() or torch.cuda.is_available() or _mps_available()
)

torch.manual_seed(args.seed)

if use_accel:
device = torch.accelerator.current_accelerator()
if torch.accelerator.is_available():
device = torch.accelerator.current_accelerator()
elif torch.cuda.is_available():
device = torch.device("cuda")
elif _mps_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
use_accel = False
else:
device = torch.device("cpu")

Expand Down