-
Notifications
You must be signed in to change notification settings - Fork 72
Expand file tree
/
Copy pathminimal_example.py
More file actions
77 lines (62 loc) · 2.44 KB
/
minimal_example.py
File metadata and controls
77 lines (62 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Example script to run attacks in this repository directly without simulation.
This can be useful if you want to check a model architecture and model gradients computed/defended in some shape or form
against some of the attacks implemented in this repository, without implementing your model into the simulation.
All caveats apply. Make sure not to leak any unexpected information.
"""
import torch
import torchvision
import breaching
from omegaconf import DictConfig
data_cfg_default = DictConfig({
"modality": "vision",
"task": "classification",
"size": 1_281_167,
"classes": 1000,
"shape": [3, 224, 224],
"normalize": True,
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
})
transforms = torchvision.transforms.Compose(
[
torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=data_cfg_default.mean, std=data_cfg_default.std),
]
)
def main():
setup = dict(device=torch.device("cpu"), dtype=torch.float)
# This could be your model:
model = torchvision.models.resnet152(pretrained=True)
model.eval()
loss_fn = torch.nn.CrossEntropyLoss()
# And your dataset:
dataset = torchvision.datasets.ImageNet(root="~/data/imagenet", split="val", transform=transforms)
datapoint, label = dataset[1200] # This is the owl, just for the sake of this experiment
labels = torch.as_tensor(label)[None, ...]
# This is the attacker:
cfg_attack = breaching.get_attack_config("invertinggradients")
attacker = breaching.attacks.prepare_attack(model, loss_fn, cfg_attack, setup)
# ## Simulate an attacked FL protocol
# Server-side computation:
server_payload = [
dict(
parameters=[p for p in model.parameters()], buffers=[b for b in model.buffers()], metadata=data_cfg_default
)
]
# User-side computation:
loss = loss_fn(model(datapoint[None, ...]), labels)
shared_data = [
dict(
gradients=torch.autograd.grad(loss, model.parameters()),
buffers=None,
metadata=dict(num_data_points=1, labels=labels, local_hyperparams=None,),
)
]
# Attack:
reconstructed_user_data, stats = attacker.reconstruct(server_payload, shared_data, {}, dryrun=False)
# Do some processing of your choice here. Maybe save the output image?
if __name__ == "__main__":
main()