Skip to content

Commit 91f36dc

Browse files
committed
fix: gan train error
1 parent 1b115f1 commit 91f36dc

1 file changed

Lines changed: 1 addition & 284 deletions

File tree

src/modelinversion/attack/base.py

Lines changed: 1 addition & 284 deletions
Original file line numberDiff line numberDiff line change
@@ -1,286 +1,3 @@
1-
import os
2-
from abc import ABCMeta, abstractmethod
3-
from collections import OrderedDict
4-
from dataclasses import dataclass, field
51

6-
import torch
7-
from tqdm import tqdm
8-
import yaml
9-
10-
from ..models import *
11-
from ..foldermanager import FolderManager
12-
from ..metrics.base import *
13-
from ..utils import DictAccumulator, Accumulator
142
from ..enums import TqdmStrategy
15-
16-
@dataclass
17-
class BaseAttackConfig:
18-
19-
# classifier
20-
target_name: str
21-
eval_name: str
22-
23-
# folders
24-
ckpt_dir: str
25-
result_dir: str
26-
dataset_dir: str
27-
cache_dir: str
28-
defense_ckpt_dir: str = None
29-
30-
# dataset
31-
dataset_name: str = 'celeba'
32-
33-
# misc
34-
defense_type: str = 'no_defense'
35-
device: str = 'cpu'
36-
37-
class BaseAttacker(metaclass=ABCMeta):
38-
39-
def __init__(self, config: BaseAttackConfig) -> None:
40-
self.config = config
41-
tag = self.get_tag()
42-
cache_dir = os.path.join(config.cache_dir, tag)
43-
result_dir = os.path.join(config.result_dir, tag)
44-
45-
self.folder_manager = FolderManager(config.ckpt_dir, config.dataset_dir, cache_dir, result_dir, config.defense_ckpt_dir, config.defense_type)
46-
47-
self.prepare_classifiers()
48-
49-
print('--------------- config --------------')
50-
print(config)
51-
print('-------------------------------------')
52-
53-
def register_dirs(self, dirs: dict):
54-
for k, v in dirs.items():
55-
os.makedirs(v, exist_ok=True)
56-
setattr(self.folder_manager.config, k, v)
57-
58-
@abstractmethod
59-
def get_tag(self) -> str:
60-
raise NotImplementedError()
61-
62-
@abstractmethod
63-
def prepare_attack(self):
64-
raise NotImplementedError()
65-
66-
@abstractmethod
67-
def attack_step(self, iden) -> dict:
68-
raise NotImplementedError()
69-
70-
def prepare_classifiers(self):
71-
72-
config = self.config
73-
folder_manager = self.folder_manager
74-
75-
self.T = get_model(config.target_name, config.dataset_name, device=config.device, defense_type=config.defense_type)
76-
folder_manager.load_target_model_state_dict(self.T, config.dataset_name, config.target_name, device=config.device, defense_type=config.defense_type)
77-
78-
self.E = get_model(config.eval_name, config.dataset_name, device=config.device)
79-
folder_manager.load_target_model_state_dict(self.E, config.dataset_name, config.eval_name, device=config.device)
80-
81-
self.T.eval()
82-
self.E.eval()
83-
84-
85-
def attack(self, batch_size: int, target_labels: list):
86-
87-
self.batch_size = batch_size
88-
self.target_labels = target_labels
89-
90-
self.prepare_attack()
91-
92-
config = self.config
93-
94-
print("=> Begin attacking ...")
95-
# aver_acc, aver_acc5, aver_var, aver_var5 = 0, 0, 0, 0
96-
97-
total_num = len(target_labels)
98-
99-
accumulator = DictAccumulator()
100-
101-
if total_num > 0:
102-
for idx in range((total_num - 1) // batch_size + 1):
103-
print("--------------------- Attack batch [%s]------------------------------" % idx)
104-
iden = torch.tensor(
105-
target_labels[idx * batch_size: min((idx+1)*batch_size, total_num)], device=config.device, dtype=torch.long
106-
)
107-
108-
update_dict = self.attack_step(iden)
109-
110-
bs = len(iden)
111-
112-
accumulator.add(update_dict)
113-
114-
for key, val in accumulator.avg().items():
115-
print(f'average {key}: {val:.6f}')
116-
117-
def evaluation(self, batch_size, transform=None, knn=True, feature_distance=True, fid=False):
118-
eval_metrics = []
119-
120-
if knn:
121-
eval_metrics.append(KnnDistanceMetric(self.folder_manager, device=self.config.device, model=self.E))
122-
123-
if feature_distance:
124-
eval_metrics.append(FeatureDistanceMetric(self.folder_manager, self.config.device, model=self.E))
125-
126-
if fid:
127-
eval_metrics.append(FIDMetric(self.folder_manager, device=self.config.device, model=None))
128-
129-
for metric in eval_metrics:
130-
metric: BaseMetric
131-
print(f'calculate {metric.get_metric_name()}')
132-
metric.evaluation(self.config.dataset_name, batch_size, transform)
133-
134-
class BaseSingleLabelAttacker(BaseAttacker):
135-
136-
def __init__(self, config: BaseAttackConfig) -> None:
137-
super().__init__(config)
138-
139-
def attack_step(self, target) -> dict:
140-
return super().attack_step(target)
141-
142-
def attack(self, batch_size: int, target_labels: list):
143-
144-
self.batch_size = batch_size
145-
self.target_labels = target_labels
146-
147-
self.prepare_attack()
148-
149-
config = self.config
150-
151-
print("=> Begin attacking ...")
152-
153-
total_num = len(target_labels)
154-
155-
accumulator = DictAccumulator()
156-
157-
if total_num > 0:
158-
for target in target_labels:
159-
print(f"--------------------- Attack label [{target}]------------------------------")
160-
161-
update_dict = self.attack_step(target)
162-
163-
164-
accumulator.add(update_dict)
165-
166-
for key, val in accumulator.avg().items():
167-
print(f'average {key}: {val:.6f}')
168-
169-
@dataclass
170-
class BaseGANTrainArgs:
171-
172-
dataset_name: str
173-
batch_size: int
174-
epoch_num: int
175-
176-
dis_gen_update_rate: int = 5
177-
tqdm_strategy: TqdmStrategy = TqdmStrategy.ITER
178-
defense_type: str = 'no_defense'
179-
device: str = 'cpu'
180-
181-
182-
class BaseGANTrainer(metaclass=ABCMeta):
183-
184-
def __init__(self, args: BaseGANTrainArgs, folder_manager: FolderManager, **kwargs) -> None:
185-
self.args = args
186-
self.folder_manager = folder_manager
187-
188-
self.method_name = self.get_method_name()
189-
self.tag = self.get_tag()
190-
191-
@abstractmethod
192-
def get_method_name(self) -> str:
193-
raise NotImplementedError()
194-
195-
@abstractmethod
196-
def get_tag(self) -> str:
197-
raise NotImplementedError()
198-
199-
@abstractmethod
200-
def prepare_training(self):
201-
# raise NotImplementedError()
202-
self.G = None
203-
self.D = None
204-
205-
@abstractmethod
206-
def get_trainloader(self) -> DataLoader:
207-
raise NotImplementedError()
208-
209-
@abstractmethod
210-
def train_gen_step(self, batch) -> OrderedDict:
211-
raise NotImplementedError()
212-
213-
@abstractmethod
214-
def train_dis_step(self, batch) -> OrderedDict:
215-
raise NotImplementedError()
216-
217-
def before_train(self):
218-
pass
219-
220-
def after_train(self):
221-
pass
222-
223-
def before_gen_train_step(self):
224-
# self.model.train()
225-
self.G.train()
226-
# self.D.eval()
227-
228-
def before_dis_train_step(self):
229-
# self.G.eval()
230-
self.D.train()
231-
232-
233-
234-
def save_state_dict(self):
235-
self.folder_manager.save_state_dict(self.G, [self.method_name, f'{self.tag}_G.pt'], self.args.defense_type)
236-
self.folder_manager.save_state_dict(self.D, [self.method_name, f'{self.tag}_D.pt'], self.args.defense_type)
237-
238-
def loss_update(self, loss, optimizer):
239-
optimizer.zero_grad()
240-
loss.backward()
241-
optimizer.step()
242-
243-
def _train_loop(self, dataloader: DataLoader, epoch):
244-
self.before_train()
245-
246-
gen_accumulator = DictAccumulator()
247-
dis_accumulator = DictAccumulator()
248-
249-
for iter_time, batch in enumerate(dataloader, start=1):
250-
# print(f'len batch: {len(batch)}')
251-
self.before_dis_train_step()
252-
dis_ret = self.train_dis_step(batch)
253-
dis_accumulator.add(dis_ret)
254-
255-
if iter_time % self.args.dis_gen_update_rate == 0:
256-
self.before_gen_train_step()
257-
gen_ret = self.train_gen_step(batch)
258-
gen_accumulator.add(gen_ret)
259-
260-
gen_avg = gen_accumulator.avg()
261-
dis_avg = dis_accumulator.avg()
262-
263-
print_context = OrderedDict(
264-
epoch = epoch,
265-
generator = gen_avg,
266-
discriminator = dis_avg
267-
)
268-
print(yaml.dump(print_context))
269-
270-
def train(self):
271-
272-
self.prepare_training()
273-
274-
trainloader = self.get_trainloader()
275-
276-
epochs = range(self.args.epoch_num)
277-
if self.args.tqdm_strategy == TqdmStrategy.EPOCH:
278-
epochs = tqdm(epochs)
279-
280-
for epoch in epochs:
281-
if self.args.tqdm_strategy == TqdmStrategy.ITER:
282-
trainloader = tqdm(trainloader)
283-
284-
self._train_loop(trainloader, epoch)
285-
286-
self.save_state_dict()
3+
from ..trainer import BaseGANTrainArgs, BaseGANTrainer

0 commit comments

Comments
 (0)