1+ import os
2+ from abc import ABCMeta , abstractmethod
3+ from collections import OrderedDict
4+ from dataclasses import dataclass , field
15
6+ import torch
7+ from tqdm import tqdm
8+ import yaml
9+
10+ from ..models import *
11+ from ..foldermanager import FolderManager
12+ from ..metrics .base import *
13+ from ..utils import DictAccumulator , Accumulator
214from ..enums import TqdmStrategy
3- from ..trainer import BaseGANTrainArgs , BaseGANTrainer
15+ from ..trainer import BaseGANTrainArgs , BaseGANTrainer
16+
17+ @dataclass
18+ class BaseAttackConfig :
19+
20+ # classifier
21+ target_name : str
22+ eval_name : str
23+
24+ # folders
25+ ckpt_dir : str
26+ result_dir : str
27+ dataset_dir : str
28+ cache_dir : str
29+ defense_ckpt_dir : str = None
30+
31+ # dataset
32+ dataset_name : str = 'celeba'
33+
34+ # misc
35+ defense_type : str = 'no_defense'
36+ device : str = 'cpu'
37+
38+ class BaseAttacker (metaclass = ABCMeta ):
39+
40+ def __init__ (self , config : BaseAttackConfig ) -> None :
41+ self .config = config
42+ tag = self .get_tag ()
43+ cache_dir = os .path .join (config .cache_dir , tag )
44+ result_dir = os .path .join (config .result_dir , tag )
45+
46+ self .folder_manager = FolderManager (config .ckpt_dir , config .dataset_dir , cache_dir , result_dir , config .defense_ckpt_dir , config .defense_type )
47+
48+ self .prepare_classifiers ()
49+
50+ print ('--------------- config --------------' )
51+ print (config )
52+ print ('-------------------------------------' )
53+
54+ def register_dirs (self , dirs : dict ):
55+ for k , v in dirs .items ():
56+ os .makedirs (v , exist_ok = True )
57+ setattr (self .folder_manager .config , k , v )
58+
59+ @abstractmethod
60+ def get_tag (self ) -> str :
61+ raise NotImplementedError ()
62+
63+ @abstractmethod
64+ def prepare_attack (self ):
65+ raise NotImplementedError ()
66+
67+ @abstractmethod
68+ def attack_step (self , iden ) -> dict :
69+ raise NotImplementedError ()
70+
71+ def prepare_classifiers (self ):
72+
73+ config = self .config
74+ folder_manager = self .folder_manager
75+
76+ self .T = get_model (config .target_name , config .dataset_name , device = config .device , defense_type = config .defense_type )
77+ folder_manager .load_target_model_state_dict (self .T , config .dataset_name , config .target_name , device = config .device , defense_type = config .defense_type )
78+
79+ self .E = get_model (config .eval_name , config .dataset_name , device = config .device )
80+ folder_manager .load_target_model_state_dict (self .E , config .dataset_name , config .eval_name , device = config .device )
81+
82+ self .T .eval ()
83+ self .E .eval ()
84+
85+
86+ def attack (self , batch_size : int , target_labels : list ):
87+
88+ self .batch_size = batch_size
89+ self .target_labels = target_labels
90+
91+ self .prepare_attack ()
92+
93+ config = self .config
94+
95+ print ("=> Begin attacking ..." )
96+ # aver_acc, aver_acc5, aver_var, aver_var5 = 0, 0, 0, 0
97+
98+ total_num = len (target_labels )
99+
100+ accumulator = DictAccumulator ()
101+
102+ if total_num > 0 :
103+ for idx in range ((total_num - 1 ) // batch_size + 1 ):
104+ print ("--------------------- Attack batch [%s]------------------------------" % idx )
105+ iden = torch .tensor (
106+ target_labels [idx * batch_size : min ((idx + 1 )* batch_size , total_num )], device = config .device , dtype = torch .long
107+ )
108+
109+ update_dict = self .attack_step (iden )
110+
111+ bs = len (iden )
112+
113+ accumulator .add (update_dict )
114+
115+ for key , val in accumulator .avg ().items ():
116+ print (f'average { key } : { val :.6f} ' )
117+
118+ def evaluation (self , batch_size , transform = None , knn = True , feature_distance = True , fid = False ):
119+ eval_metrics = []
120+
121+ if knn :
122+ eval_metrics .append (KnnDistanceMetric (self .folder_manager , device = self .config .device , model = self .E ))
123+
124+ if feature_distance :
125+ eval_metrics .append (FeatureDistanceMetric (self .folder_manager , self .config .device , model = self .E ))
126+
127+ if fid :
128+ eval_metrics .append (FIDMetric (self .folder_manager , device = self .config .device , model = None ))
129+
130+ for metric in eval_metrics :
131+ metric : BaseMetric
132+ print (f'calculate { metric .get_metric_name ()} ' )
133+ metric .evaluation (self .config .dataset_name , batch_size , transform )
134+
135+ class BaseSingleLabelAttacker (BaseAttacker ):
136+
137+ def __init__ (self , config : BaseAttackConfig ) -> None :
138+ super ().__init__ (config )
139+
140+ def attack_step (self , target ) -> dict :
141+ return super ().attack_step (target )
142+
143+ def attack (self , batch_size : int , target_labels : list ):
144+
145+ self .batch_size = batch_size
146+ self .target_labels = target_labels
147+
148+ self .prepare_attack ()
149+
150+ config = self .config
151+
152+ print ("=> Begin attacking ..." )
153+
154+ total_num = len (target_labels )
155+
156+ accumulator = DictAccumulator ()
157+
158+ if total_num > 0 :
159+ for target in target_labels :
160+ print (f"--------------------- Attack label [{ target } ]------------------------------" )
161+
162+ update_dict = self .attack_step (target )
163+
164+
165+ accumulator .add (update_dict )
166+
167+ for key , val in accumulator .avg ().items ():
168+ print (f'average { key } : { val :.6f} ' )
0 commit comments