import os
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Optional

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import yaml

from ..models import *
from ..foldermanager import FolderManager
from ..metrics.base import *
from ..utils import DictAccumulator, Accumulator
from ..enums import TqdmStrategy
15-
@dataclass
class BaseAttackConfig:
    """Configuration shared by all attack implementations.

    Bundles the classifier names, filesystem layout and runtime options
    consumed by ``BaseAttacker``.
    """

    # classifier
    target_name: str   # name of the target (attacked) classifier
    eval_name: str     # name of the evaluation classifier

    # folders
    ckpt_dir: str      # checkpoints root
    result_dir: str    # attack results root
    dataset_dir: str   # datasets root
    cache_dir: str     # intermediate-cache root
    # checkpoint dir for defended target models; None when no defense is used
    defense_ckpt_dir: Optional[str] = None

    # dataset
    dataset_name: str = 'celeba'

    # misc
    defense_type: str = 'no_defense'
    device: str = 'cpu'
36-
class BaseAttacker(metaclass=ABCMeta):
    """Abstract base class for model-inversion attackers.

    Subclasses implement :meth:`get_tag`, :meth:`prepare_attack` and
    :meth:`attack_step`; this base class handles folder layout, classifier
    loading, the batched attack loop and metric evaluation.
    """

    def __init__(self, config: BaseAttackConfig) -> None:
        """Create the folder manager, load classifiers and echo the config."""
        self.config = config
        tag = self.get_tag()
        # per-attack subdirectories keyed by the attack tag
        cache_dir = os.path.join(config.cache_dir, tag)
        result_dir = os.path.join(config.result_dir, tag)

        self.folder_manager = FolderManager(config.ckpt_dir, config.dataset_dir, cache_dir, result_dir, config.defense_ckpt_dir, config.defense_type)

        self.prepare_classifiers()

        print('--------------- config --------------')
        print(config)
        print('-------------------------------------')

    def register_dirs(self, dirs: dict):
        """Create each directory in ``dirs`` and expose it on the folder-manager config."""
        for k, v in dirs.items():
            os.makedirs(v, exist_ok=True)
            setattr(self.folder_manager.config, k, v)

    @abstractmethod
    def get_tag(self) -> str:
        """Return a short identifier for this attack run (used in cache/result paths)."""
        raise NotImplementedError()

    @abstractmethod
    def prepare_attack(self):
        """Load or build everything the attack needs before the first step."""
        raise NotImplementedError()

    @abstractmethod
    def attack_step(self, iden) -> dict:
        """Attack one batch of target labels; return a dict of metric values."""
        raise NotImplementedError()

    def prepare_classifiers(self):
        """Instantiate and load the target (T) and evaluation (E) classifiers."""
        config = self.config
        folder_manager = self.folder_manager

        self.T = get_model(config.target_name, config.dataset_name, device=config.device, defense_type=config.defense_type)
        folder_manager.load_target_model_state_dict(self.T, config.dataset_name, config.target_name, device=config.device, defense_type=config.defense_type)

        self.E = get_model(config.eval_name, config.dataset_name, device=config.device)
        folder_manager.load_target_model_state_dict(self.E, config.dataset_name, config.eval_name, device=config.device)

        # both classifiers are used for inference only
        self.T.eval()
        self.E.eval()

    def attack(self, batch_size: int, target_labels: list):
        """Attack ``target_labels`` in batches of ``batch_size`` and print averaged metrics."""
        self.batch_size = batch_size
        self.target_labels = target_labels

        self.prepare_attack()

        config = self.config

        print("=> Begin attacking ...")

        total_num = len(target_labels)

        accumulator = DictAccumulator()

        if total_num > 0:
            # ceil(total_num / batch_size) batches
            for idx in range((total_num - 1) // batch_size + 1):
                print("--------------------- Attack batch [%s]------------------------------" % idx)
                iden = torch.tensor(
                    target_labels[idx * batch_size: min((idx + 1) * batch_size, total_num)],
                    device=config.device, dtype=torch.long
                )

                update_dict = self.attack_step(iden)

                accumulator.add(update_dict)

        for key, val in accumulator.avg().items():
            print(f'average {key}: {val:.6f}')

    def evaluation(self, batch_size, transform=None, knn=True, feature_distance=True, fid=False):
        """Evaluate the attack results with the selected metrics.

        Each flag toggles one metric; FID builds its own feature extractor,
        hence ``model=None``.
        """
        eval_metrics = []

        if knn:
            eval_metrics.append(KnnDistanceMetric(self.folder_manager, device=self.config.device, model=self.E))

        if feature_distance:
            eval_metrics.append(FeatureDistanceMetric(self.folder_manager, self.config.device, model=self.E))

        if fid:
            eval_metrics.append(FIDMetric(self.folder_manager, device=self.config.device, model=None))

        for metric in eval_metrics:
            metric: BaseMetric
            print(f'calculate {metric.get_metric_name()}')
            metric.evaluation(self.config.dataset_name, batch_size, transform)
133-
class BaseSingleLabelAttacker(BaseAttacker):
    """Attacker that processes target labels one at a time instead of in batches."""

    def __init__(self, config: BaseAttackConfig) -> None:
        super().__init__(config)

    def attack_step(self, target) -> dict:
        """Attack a single target label; return a dict of metric values."""
        return super().attack_step(target)

    def attack(self, batch_size: int, target_labels: list):
        """Attack each label in ``target_labels`` individually and print averaged metrics.

        ``batch_size`` is stored for use by subclasses; the loop here always
        advances one label at a time.
        """
        self.batch_size = batch_size
        self.target_labels = target_labels

        self.prepare_attack()

        print("=> Begin attacking ...")

        total_num = len(target_labels)

        accumulator = DictAccumulator()

        if total_num > 0:
            for target in target_labels:
                print(f"--------------------- Attack label [{target}]------------------------------")

                update_dict = self.attack_step(target)

                accumulator.add(update_dict)

        for key, val in accumulator.avg().items():
            print(f'average {key}: {val:.6f}')
168-
@dataclass
class BaseGANTrainArgs:
    """Hyper-parameters for GAN training (consumed by ``BaseGANTrainer``)."""

    dataset_name: str  # dataset to train on
    batch_size: int
    epoch_num: int

    # number of discriminator updates per generator update
    dis_gen_update_rate: int = 5
    # progress-bar granularity: per-epoch or per-iteration
    tqdm_strategy: TqdmStrategy = TqdmStrategy.ITER
    defense_type: str = 'no_defense'
    device: str = 'cpu'
180-
181-
class BaseGANTrainer(metaclass=ABCMeta):
    """Abstract GAN trainer that alternates discriminator and generator steps.

    Subclasses supply the models (:meth:`prepare_training`), the data
    (:meth:`get_trainloader`) and the two update steps; this base class runs
    the epoch/iteration loop, accumulates logs and saves checkpoints.
    """

    def __init__(self, args: BaseGANTrainArgs, folder_manager: FolderManager, **kwargs) -> None:
        self.args = args
        self.folder_manager = folder_manager

        self.method_name = self.get_method_name()
        self.tag = self.get_tag()

    @abstractmethod
    def get_method_name(self) -> str:
        """Return the training-method name (used as a checkpoint subdirectory)."""
        raise NotImplementedError()

    @abstractmethod
    def get_tag(self) -> str:
        """Return a short run tag (used in checkpoint file names)."""
        raise NotImplementedError()

    @abstractmethod
    def prepare_training(self):
        """Build the models; subclasses must assign ``self.G`` and ``self.D``."""
        self.G = None
        self.D = None

    @abstractmethod
    def get_trainloader(self) -> DataLoader:
        raise NotImplementedError()

    @abstractmethod
    def train_gen_step(self, batch) -> OrderedDict:
        """One generator update; return a dict of losses/metrics for logging."""
        raise NotImplementedError()

    @abstractmethod
    def train_dis_step(self, batch) -> OrderedDict:
        """One discriminator update; return a dict of losses/metrics for logging."""
        raise NotImplementedError()

    def before_train(self):
        """Hook called at the start of every epoch's training loop."""
        pass

    def after_train(self):
        """Hook for subclasses; not invoked by the base loop."""
        pass

    def before_gen_train_step(self):
        # put the generator into training mode before its update
        self.G.train()

    def before_dis_train_step(self):
        # put the discriminator into training mode before its update
        self.D.train()

    def save_state_dict(self):
        """Save generator and discriminator checkpoints via the folder manager."""
        self.folder_manager.save_state_dict(self.G, [self.method_name, f'{self.tag}_G.pt'], self.args.defense_type)
        self.folder_manager.save_state_dict(self.D, [self.method_name, f'{self.tag}_D.pt'], self.args.defense_type)

    def loss_update(self, loss, optimizer):
        """Standard zero-grad / backward / step cycle for one loss."""
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def _train_loop(self, dataloader: DataLoader, epoch):
        """Run one epoch: a discriminator step per batch, a generator step
        every ``dis_gen_update_rate`` batches, then print averaged logs."""
        self.before_train()

        gen_accumulator = DictAccumulator()
        dis_accumulator = DictAccumulator()

        for iter_time, batch in enumerate(dataloader, start=1):
            self.before_dis_train_step()
            dis_ret = self.train_dis_step(batch)
            dis_accumulator.add(dis_ret)

            if iter_time % self.args.dis_gen_update_rate == 0:
                self.before_gen_train_step()
                gen_ret = self.train_gen_step(batch)
                gen_accumulator.add(gen_ret)

        gen_avg = gen_accumulator.avg()
        dis_avg = dis_accumulator.avg()

        print_context = OrderedDict(
            epoch=epoch,
            generator=gen_avg,
            discriminator=dis_avg
        )
        print(yaml.dump(print_context))

    def train(self):
        """Full training run: ``epoch_num`` epochs, then save checkpoints."""
        self.prepare_training()

        trainloader = self.get_trainloader()

        epochs = range(self.args.epoch_num)
        if self.args.tqdm_strategy == TqdmStrategy.EPOCH:
            epochs = tqdm(epochs)

        for epoch in epochs:
            if self.args.tqdm_strategy == TqdmStrategy.ITER:
                # wrap a fresh tqdm each epoch; the original rebound
                # `trainloader = tqdm(trainloader)`, stacking a new tqdm
                # wrapper around the previous one every epoch
                loader = tqdm(trainloader)
            else:
                loader = trainloader

            self._train_loop(loader, epoch)

        self.save_state_dict()