Skip to content

Commit 63d6112

Browse files
committed
fix: attacker
1 parent 91f36dc commit 63d6112

2 files changed

Lines changed: 166 additions & 1 deletion

File tree

src/modelinversion/attack/Lokt/__init__.py

Whitespace-only changes.

src/modelinversion/attack/base.py

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,168 @@
1+
import os
2+
from abc import ABCMeta, abstractmethod
3+
from collections import OrderedDict
4+
from dataclasses import dataclass, field
15

6+
import torch
7+
from tqdm import tqdm
8+
import yaml
9+
10+
from ..models import *
11+
from ..foldermanager import FolderManager
12+
from ..metrics.base import *
13+
from ..utils import DictAccumulator, Accumulator
214
from ..enums import TqdmStrategy
3-
from ..trainer import BaseGANTrainArgs, BaseGANTrainer
15+
from ..trainer import BaseGANTrainArgs, BaseGANTrainer
16+
17+
@dataclass
18+
class BaseAttackConfig:
19+
20+
# classifier
21+
target_name: str
22+
eval_name: str
23+
24+
# folders
25+
ckpt_dir: str
26+
result_dir: str
27+
dataset_dir: str
28+
cache_dir: str
29+
defense_ckpt_dir: str = None
30+
31+
# dataset
32+
dataset_name: str = 'celeba'
33+
34+
# misc
35+
defense_type: str = 'no_defense'
36+
device: str = 'cpu'
37+
38+
class BaseAttacker(metaclass=ABCMeta):
39+
40+
def __init__(self, config: BaseAttackConfig) -> None:
41+
self.config = config
42+
tag = self.get_tag()
43+
cache_dir = os.path.join(config.cache_dir, tag)
44+
result_dir = os.path.join(config.result_dir, tag)
45+
46+
self.folder_manager = FolderManager(config.ckpt_dir, config.dataset_dir, cache_dir, result_dir, config.defense_ckpt_dir, config.defense_type)
47+
48+
self.prepare_classifiers()
49+
50+
print('--------------- config --------------')
51+
print(config)
52+
print('-------------------------------------')
53+
54+
def register_dirs(self, dirs: dict):
55+
for k, v in dirs.items():
56+
os.makedirs(v, exist_ok=True)
57+
setattr(self.folder_manager.config, k, v)
58+
59+
@abstractmethod
60+
def get_tag(self) -> str:
61+
raise NotImplementedError()
62+
63+
@abstractmethod
64+
def prepare_attack(self):
65+
raise NotImplementedError()
66+
67+
@abstractmethod
68+
def attack_step(self, iden) -> dict:
69+
raise NotImplementedError()
70+
71+
def prepare_classifiers(self):
72+
73+
config = self.config
74+
folder_manager = self.folder_manager
75+
76+
self.T = get_model(config.target_name, config.dataset_name, device=config.device, defense_type=config.defense_type)
77+
folder_manager.load_target_model_state_dict(self.T, config.dataset_name, config.target_name, device=config.device, defense_type=config.defense_type)
78+
79+
self.E = get_model(config.eval_name, config.dataset_name, device=config.device)
80+
folder_manager.load_target_model_state_dict(self.E, config.dataset_name, config.eval_name, device=config.device)
81+
82+
self.T.eval()
83+
self.E.eval()
84+
85+
86+
def attack(self, batch_size: int, target_labels: list):
87+
88+
self.batch_size = batch_size
89+
self.target_labels = target_labels
90+
91+
self.prepare_attack()
92+
93+
config = self.config
94+
95+
print("=> Begin attacking ...")
96+
# aver_acc, aver_acc5, aver_var, aver_var5 = 0, 0, 0, 0
97+
98+
total_num = len(target_labels)
99+
100+
accumulator = DictAccumulator()
101+
102+
if total_num > 0:
103+
for idx in range((total_num - 1) // batch_size + 1):
104+
print("--------------------- Attack batch [%s]------------------------------" % idx)
105+
iden = torch.tensor(
106+
target_labels[idx * batch_size: min((idx+1)*batch_size, total_num)], device=config.device, dtype=torch.long
107+
)
108+
109+
update_dict = self.attack_step(iden)
110+
111+
bs = len(iden)
112+
113+
accumulator.add(update_dict)
114+
115+
for key, val in accumulator.avg().items():
116+
print(f'average {key}: {val:.6f}')
117+
118+
def evaluation(self, batch_size, transform=None, knn=True, feature_distance=True, fid=False):
119+
eval_metrics = []
120+
121+
if knn:
122+
eval_metrics.append(KnnDistanceMetric(self.folder_manager, device=self.config.device, model=self.E))
123+
124+
if feature_distance:
125+
eval_metrics.append(FeatureDistanceMetric(self.folder_manager, self.config.device, model=self.E))
126+
127+
if fid:
128+
eval_metrics.append(FIDMetric(self.folder_manager, device=self.config.device, model=None))
129+
130+
for metric in eval_metrics:
131+
metric: BaseMetric
132+
print(f'calculate {metric.get_metric_name()}')
133+
metric.evaluation(self.config.dataset_name, batch_size, transform)
134+
135+
class BaseSingleLabelAttacker(BaseAttacker):
136+
137+
def __init__(self, config: BaseAttackConfig) -> None:
138+
super().__init__(config)
139+
140+
def attack_step(self, target) -> dict:
141+
return super().attack_step(target)
142+
143+
def attack(self, batch_size: int, target_labels: list):
144+
145+
self.batch_size = batch_size
146+
self.target_labels = target_labels
147+
148+
self.prepare_attack()
149+
150+
config = self.config
151+
152+
print("=> Begin attacking ...")
153+
154+
total_num = len(target_labels)
155+
156+
accumulator = DictAccumulator()
157+
158+
if total_num > 0:
159+
for target in target_labels:
160+
print(f"--------------------- Attack label [{target}]------------------------------")
161+
162+
update_dict = self.attack_step(target)
163+
164+
165+
accumulator.add(update_dict)
166+
167+
for key, val in accumulator.avg().items():
168+
print(f'average {key}: {val:.6f}')

0 commit comments

Comments
 (0)