'''
This is the official implementation of the paper "Shared Adversarial Unlearning: Backdoor Mitigation by Unlearning Shared Adversarial Examples" (https://arxiv.org/pdf/2307.10562.pdf) in PyTorch.
Implementation by: Shaokui Wei (the first author of the paper)
basic structure of the defense method:
1. basic setting: args
2. attack result(model, train data, test data)
3. sau defense:
a. get some clean data
b. SAU:
b.1 generate the shared adversarial examples
b.2 unlearn the backdoor model using the perturbation
4. test the result and get ASR, ACC, RC
'''
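# Example usage (a minimal sketch; the script path and record name are assumptions
# that depend on the BackdoorBench layout and which attack was run first):
#
#   python ./defense/sau.py --result_file defense_test_badnet --yaml_path ./config/defense/sau/config.yaml --dataset cifar10
#
# The defense expects an attack record at record/<result_file>/attack_result.pt
# and writes its outputs under record/<result_file>/defense/sau/.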
import os
import sys
sys.path.append('../')
sys.path.append(os.getcwd())
from defense.base import defense
# from utils.aggregate_block.train_settings_generate import argparser_opt_scheduler
# from utils.trainer_cls import Metric_Aggregator, PureCleanModelTrainer, general_plot_for_epoch, given_dataloader_test
# from utils.choose_index import choose_index
# from utils.aggregate_block.fix_random import fix_random
# from utils.aggregate_block.model_trainer_generate import generate_cls_model
# from utils.log_assist import get_git_info
# from utils.aggregate_block.dataset_and_transform_generate import get_input_shape, get_num_classes, get_transform, get_dataset_normalization, get_dataset_denormalization
# from utils.save_load_attack import load_attack_result, save_defense_result
# from utils.bd_dataset_v2 import prepro_cls_DatasetBD_v2
# class Shared_PGD():
# def __init__(self, model, model_ref, beta_1 = 0.01, beta_2 = 1, norm_bound = 0.2, norm_type = 'L_inf', step_size = 0.2, num_steps = 5, init_type = 'max', loss_func = torch.nn.CrossEntropyLoss(), pert_func = None, verbose = False):
# '''
# PGD attack for generating shared adversarial examples.
# See "Shared Adversarial Unlearning: Backdoor Mitigation by Unlearning Shared Adversarial Examples" (https://arxiv.org/pdf/2307.10562.pdf) for more details.
# Implemented by Shaokui Wei (the first author of the paper) in PyTorch.
# The code is originally implemented as a part of BackdoorBench but is not dependent on BackdoorBench, and can be used independently.
#
# args:
# model: the model to be attacked
# model_ref: the reference model to be attacked
# beta_1: the weight of adversarial loss, e.g. 0.01
# beta_2: the weight of shared loss, e.g. 1
# norm_bound: the bound of the norm of perturbation, e.g. 0.2
# norm_type: the type of norm, choose from ['L_inf', 'L1', 'L2', 'Reg']
# step_size: the step size of PGD, e.g. 0.2
# num_steps: the number of steps of PGD, e.g. 5
# init_type: the type of initialization of perturbation, choose from ['zero', 'random', 'max', 'min']
# loss_func: the loss function, e.g. nn.CrossEntropyLoss()
# pert_func: the function to process the perturbation and image, e.g. add the perturbation to image
# verbose: whether to print the information of the attack
# '''
#
# self.model = model
# self.model_ref = model_ref
# self.beta_1 = beta_1
# self.beta_2 = beta_2
# self.norm_bound = norm_bound
# self.norm_type = norm_type
# self.step_size = step_size
# self.num_steps = num_steps
# self.init_type = init_type
# self.loss_func = loss_func
# self.verbose = verbose
#
# if pert_func is None:
# # simply add x to perturbation
# self.pert_func = lambda x, pert: x + pert
# else:
# self.pert_func = pert_func
#
# def projection(self, pert):
# if self.norm_type == 'L_inf':
# pert.data = torch.clamp(pert.data, -self.norm_bound , self.norm_bound)
# elif self.norm_type == 'L1':
# norm = torch.sum(torch.abs(pert), dim=(1, 2, 3), keepdim=True)
# for i in range(pert.shape[0]):
# if norm[i] > self.norm_bound:
# pert.data[i] = pert.data[i] * self.norm_bound / norm[i].item()
# elif self.norm_type == 'L2':
# norm = torch.sum(pert ** 2, dim=(1, 2, 3), keepdim=True) ** 0.5
# for i in range(pert.shape[0]):
# if norm[i] > self.norm_bound:
# pert.data[i] = pert.data[i] * self.norm_bound / norm[i].item()
# elif self.norm_type == 'Reg':
# pass
# else:
# raise NotImplementedError
# return pert
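# # Note (sketch): the per-sample L1/L2 loops above can be replaced by a batched
# # projection; a minimal equivalent for the L2 case, assuming pert has shape
# # (N, C, H, W):
# #
# #   norm = pert.data.flatten(1).norm(p=2, dim=1).clamp(min=1e-12).view(-1, 1, 1, 1)
# #   pert.data = pert.data * (self.norm_bound / norm).clamp(max=1.0)
# #
# # (use p=1 for the L1 case)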
#
# def init_pert(self, batch_pert):
# if self.init_type=='zero':
# batch_pert.data = batch_pert.data*0
# elif self.init_type=='random':
# batch_pert.data = torch.rand_like(batch_pert.data)
# elif self.init_type=='max':
# batch_pert.data = batch_pert.data + self.norm_bound
# elif self.init_type=='min':
# batch_pert.data = batch_pert.data - self.norm_bound
# else:
# raise NotImplementedError
#
# return self.projection(batch_pert)
#
# def attack(self, images, labels, max_eps = 1, min_eps = 0):
# # Set max_eps and min_eps to valid range
#
# model = self.model
# model_ref = self.model_ref
#
# batch_pert = torch.zeros_like(images, requires_grad=True)
# batch_pert = self.init_pert(batch_pert)
#
# for _ in range(self.num_steps):
# pert_image = self.pert_func(images, batch_pert)
# ori_lab = torch.argmax(model.forward(images), dim=1).long()
# ori_lab_ref = torch.argmax(model_ref.forward(images), dim=1).long()
#
# per_logits = model.forward(pert_image)
# per_logits_ref = model_ref.forward(pert_image)
#
# pert_label = torch.argmax(per_logits, dim=1)
# pert_label_ref = torch.argmax(per_logits_ref, dim=1)
#
# success_attack = pert_label != ori_lab
# success_attack_ref = pert_label_ref != ori_lab_ref
# common_attack = torch.logical_and(success_attack, success_attack_ref)
# shared_attack = torch.logical_and(common_attack, pert_label == pert_label_ref)
#
# # Adversarial loss
# # use early stop or loss clamp to avoid very large loss
# loss_adv = torch.tensor(0.0).to(images.device)
# if torch.logical_not(success_attack).sum()!=0:
# loss_adv += F.cross_entropy(per_logits, labels, reduction='none')[torch.logical_not(success_attack)].sum()
# if torch.logical_not(success_attack_ref).sum()!=0:
# loss_adv += F.cross_entropy(per_logits_ref, labels, reduction='none')[torch.logical_not(success_attack_ref)].sum()
# loss_adv = - loss_adv/2/images.shape[0]
#
# # Shared loss
# # JS divergence version (https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence)
# p_model = F.softmax(per_logits, dim=1).clamp(min=1e-8)
# p_ref = F.softmax(per_logits_ref, dim=1).clamp(min=1e-8)
# mix_p = 0.5*(p_model+p_ref)
# loss_js = 0.5*(p_model*p_model.log() + p_ref*p_ref.log()) - 0.5*(p_model*mix_p.log() + p_ref*mix_p.log())
# loss_cross = loss_js[torch.logical_not(shared_attack)].sum(dim=1).sum()/images.shape[0]
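# # For reference: with M = (P + Q)/2, JS(P, Q) = 0.5*KL(P||M) + 0.5*KL(Q||M)
# # = 0.5*sum_c(p_c*log p_c + q_c*log q_c) - 0.5*sum_c(p_c*log m_c + q_c*log m_c),
# # which is the element-wise expression above; loss_cross sums it over classes
# # for the samples that are not yet shared adversarial examples.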
#
# # Update pert
# batch_pert.grad = None
# loss_ae = self.beta_1 * loss_adv + self.beta_2 * loss_cross
# loss_ae.backward()
#
# batch_pert.data = batch_pert.data - self.step_size * batch_pert.grad.sign()
#
# # Projection
# batch_pert = self.projection(batch_pert)
#
# # Optional: project to S and clip to [min_eps, max_eps] to ensure the perturbation is valid. It is not necessary for backdoor defense, as observed in i-BAU.
# # Manually set min_eps and max_eps to match the dataset normalization.
# # batch_pert.data = torch.clamp(batch_pert.data, min_eps, max_eps)
#
# if torch.logical_not(shared_attack).sum()==0:
# break
# if self.verbose:
# print(f'Maximization End: \n Adv h: {success_attack.sum().item()}, Adv h_0: {success_attack_ref.sum().item()}, Adv Common: {common_attack.sum().item()}, Adv Share: {shared_attack.sum().item()}.\n Loss adv {loss_adv.item():.4f}, Loss share {loss_cross.item():.4f}, Loss total {loss_ae.item():.4f}.\n L1 norm: {batch_pert[0].abs().sum():.4f}, L2 norm: {torch.norm(batch_pert[0]):.4f}, Linf norm: {batch_pert[0].abs().max():.4f}')
#
# return batch_pert.detach()
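# A minimal usage sketch for Shared_PGD on its own (model and model_ref are any two
# classifiers with matching input/output shapes; the values shown mirror the
# constructor defaults above and are not prescriptive):
#
#   attacker = Shared_PGD(model=model, model_ref=model_ref, beta_1=0.01, beta_2=1,
#                         norm_bound=0.2, norm_type='L_inf', step_size=0.2,
#                         num_steps=5, init_type='max')
#   pert = attacker.attack(images, labels)   # detached shared perturbation
#   adv_images = images + pert               # the default pert_func simply adds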
class sau(defense):
r'''Shared adversarial unlearning: Backdoor mitigation by unlearning shared adversarial examples
basic structure of the defense method:
1. basic setting: args
2. attack result(model, train data, test data)
3. sau defense:
a. get some clean data
b. SAU:
1. generate the shared adversarial examples
2. unlearn the backdoor model using the perturbation
4. test the result and get ASR, ACC, RC
.. code-block:: python
parser = argparse.ArgumentParser(description=sys.argv[0])
sau.add_arguments(parser)
args = parser.parse_args()
sau_method = sau(args)
if "result_file" not in args.__dict__:
args.result_file = 'defense_test_badnet'
elif args.result_file is None:
args.result_file = 'defense_test_badnet'
result = sau_method.defense(args.result_file)
.. Note::
@inproceedings{wei2023shared,
title={Shared Adversarial Unlearning: Backdoor Mitigation by Unlearning Shared Adversarial Examples},
author={Wei, Shaokui and Zhang, Mingda and Zha, Hongyuan and Wu, Baoyuan},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023}}
Args:
basic args: in the base class
optim(str): type of outer loop optimizer utilized
n_rounds(int): the maximum number of unlearning rounds
outer_steps(int): steps for the outer loop, i.e., the number of unlearning updates per batch
lmd_1(float): weight of the clean loss L_cl
lmd_2(float): weight of the AT loss. By default, lmd_2 = 0 and AT is not used.
lmd_3(float): weight of the shared adversarial risk L_sar
beta_1(float): weight of the adversarial loss L_adv
beta_2(float): weight of the shared loss L_share
trigger_norm(float): norm threshold for PGD. Larger may not be good.
pgd_init(str): init type for PGD. zero|random|max|min
norm_type(str): type of norm used for generating the perturbation. L1|L2|L_inf|Reg
adv_lr(float): lr (step size) for PGD
adv_steps(int): number of steps for PGD
train_mode(bool): if set, use train mode (updating BN) during unlearning; fixing BN (the default) leads to higher ACC but also higher ASR.
'''
def __init__(self, args):
pass
# with open(args.yaml_path, 'r') as f:
# defaults = yaml.safe_load(f)
#
# defaults.update({k:v for k,v in args.__dict__.items() if v is not None})
#
# args.__dict__ = defaults
#
# args.terminal_info = sys.argv
#
# args.num_classes = get_num_classes(args.dataset)
# args.input_height, args.input_width, args.input_channel = get_input_shape(args.dataset)
# args.img_size = (args.input_height, args.input_width, args.input_channel)
# args.dataset_path = f"{args.dataset_path}/{args.dataset}"
#
# self.args = args
#
# if 'result_file' in args.__dict__ :
# if args.result_file is not None:
# self.set_result(args.result_file)
#
# def add_arguments(parser):
# parser.add_argument('--device', type=str, help='cuda, cpu')
# parser.add_argument("-pm","--pin_memory", type=lambda x: str(x) in ['True', 'true', '1'], help = "dataloader pin_memory")
# parser.add_argument("-nb","--non_blocking", type=lambda x: str(x) in ['True', 'true', '1'], help = ".to(), set the non_blocking = ?")
# parser.add_argument("-pf", '--prefetch', type=lambda x: str(x) in ['True', 'true', '1'], help='use prefetch')
# parser.add_argument('--amp', type=lambda x: str(x) in ['True','true','1'])
#
# parser.add_argument('--checkpoint_load', type=str, help='the location of load model')
# parser.add_argument('--checkpoint_save', type=str, help='the location of checkpoint where model is saved')
# parser.add_argument('--log', type=str, help='the location of log')
# parser.add_argument("--dataset_path", type=str, help='the location of data')
# parser.add_argument('--dataset', type=str, help='mnist, cifar10, cifar100, gtrsb, tiny')
# parser.add_argument('--result_file', type=str, help='the location of result')
#
# parser.add_argument('--epochs', type=int)
# parser.add_argument('--batch_size', type=int)
# parser.add_argument("--num_workers", type=float)
# parser.add_argument('--lr', type=float)
# parser.add_argument('--lr_scheduler', type=str, help='the scheduler of lr')
# parser.add_argument('--steplr_stepsize', type=int)
# parser.add_argument('--steplr_gamma', type=float)
# parser.add_argument('--steplr_milestones', type=list)
# parser.add_argument('--model', type=str, help='resnet18')
#
# parser.add_argument('--client_optimizer', type=str, help='optimizer name, e.g., sgd or adam')
# parser.add_argument('--sgd_momentum', type=float)
# parser.add_argument('--wd', type=float, help='weight decay of sgd')
# parser.add_argument('--frequency_save', type=int,
# help=' frequency_save, 0 is never')
#
# parser.add_argument('--random_seed', type=int, help='random seed')
# parser.add_argument('--yaml_path', type=str, default="./config/defense/sau/config.yaml", help='the path of yaml')
#
#
# ###### sau defense parameter ######
# # defense setting
# parser.add_argument('--ratio', type=float, help='the ratio of clean data loader')
# parser.add_argument('--index', type=str, help='index of clean data')
# # hyper params
# parser.add_argument('--optim', type=str, default='Adam', help='type of outer loop optimizer utilized')
# parser.add_argument('--n_rounds', type=int, help='the maximum number of unlearning rounds')
# ## Minimization part
# parser.add_argument('--outer_steps', default=1, type=int,help='steps for outer loop, the number of unlearning rounds')
# parser.add_argument('--lmd_1', type=float,help='clean acc, L_cl')
# parser.add_argument('--lmd_2', type=float,help='AT acc. By default, lmd_2 = 0 and AT is not used.')
# parser.add_argument('--lmd_3', type=float, help='shared adv risk, L_sar')
# ## Maximization part
# parser.add_argument('--beta_1', type=float,help='L_adv')
# parser.add_argument('--beta_2', type=float,help='L_share')
# ### PGD setting
# parser.add_argument('--trigger_norm', type=float,help='threshold for PGD. Larger may not be good.')
# parser.add_argument('--pgd_init', type=str, help='init type for pgd. zero|random|max|min')
# parser.add_argument('--norm_type', type=str,help='type of norm used for generating perturbation. L1|L2|L_inf|Reg')
# parser.add_argument('--adv_lr', type=float,help='lr for pgd')
# parser.add_argument('--adv_steps', type=int,help='number of steps for pgd')
# ## optimization setting
# parser.add_argument('--train_mode', action='store_true',default=False, help='Fix BN parameters or not. Fixing BN leads to higher ACC but also higher ASR.')
#
#
# def set_result(self, result_file):
# attack_file = 'record/' + result_file
# save_path = 'record/' + result_file + f'/defense/sau/'
# if not (os.path.exists(save_path)):
# os.makedirs(save_path)
# # assert(os.path.exists(save_path))
# self.args.save_path = save_path
# if self.args.checkpoint_save is None:
# self.args.checkpoint_save = save_path + 'checkpoint/'
# if not (os.path.exists(self.args.checkpoint_save)):
# os.makedirs(self.args.checkpoint_save)
# if self.args.log is None:
# self.args.log = save_path + 'log/'
# if not (os.path.exists(self.args.log)):
# os.makedirs(self.args.log)
# self.result = load_attack_result(attack_file + '/attack_result.pt')
#
#
# def set_logger(self):
# args = self.args
# logFormatter = logging.Formatter(
# fmt='%(asctime)s [%(levelname)-8s] [%(filename)s:%(lineno)d] %(message)s',
# datefmt='%Y-%m-%d:%H:%M:%S',
# )
# logger = logging.getLogger()
#
# fileHandler = logging.FileHandler(args.log + '/' + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + '.log')
# fileHandler.setFormatter(logFormatter)
# logger.addHandler(fileHandler)
#
# consoleHandler = logging.StreamHandler()
# consoleHandler.setFormatter(logFormatter)
# logger.addHandler(consoleHandler)
#
# logger.setLevel(logging.INFO)
# logging.info(pformat(args.__dict__))
#
# try:
# logging.info(pformat(get_git_info()))
# except Exception:
# logging.info('Getting git info fails.')
#
#
# def mitigation(self):
# args = self.args
# fix_random(args.random_seed)
#
# # initialize models
# model = generate_cls_model(self.args.model,self.args.num_classes)
# model.load_state_dict(self.result['model'])
#
# model_ref = generate_cls_model(self.args.model,self.args.num_classes)
# model_ref.load_state_dict(self.result['model'])
#
#
# if "," in self.args.device:
# model = torch.nn.DataParallel(model, device_ids=[int(i) for i in self.args.device[5:].split(",")])
# self.args.device = f'cuda:{model.device_ids[0]}'
# model.to(self.args.device)
#
# model_ref = torch.nn.DataParallel(model_ref, device_ids=[int(i) for i in self.args.device[5:].split(",")])
# self.args.device = f'cuda:{model_ref.device_ids[0]}'
# model_ref.to(self.args.device)
# else:
# model.to(self.args.device)
# model_ref.to(self.args.device)
#
# outer_opt, scheduler = argparser_opt_scheduler(model, self.args)
#
# # a. get some clean data
# logging.info("Fetch some samples from clean train dataset.")
#
# train_tran = get_transform(self.args.dataset, *([self.args.input_height,self.args.input_width]) , train = False)
#
# clean_dataset = prepro_cls_DatasetBD_v2(self.result['clean_train'].wrapped_dataset)
# data_all_length = len(clean_dataset)
# ran_idx = choose_index(self.args, data_all_length)
# log_index = self.args.log + 'index.txt'
# np.savetxt(log_index, ran_idx, fmt='%d')
#
# clean_dataset.subset(ran_idx)
#
# data_set_without_tran = clean_dataset
# data_set_o = self.result['clean_train']
# data_set_o.wrapped_dataset = data_set_without_tran
# data_set_o.wrap_img_transform = train_tran
#
# data_loader = torch.utils.data.DataLoader(data_set_o, batch_size=self.args.batch_size, num_workers=self.args.num_workers, shuffle=True, pin_memory=args.pin_memory)
# trainloader = data_loader
#
# ## set testing dataset
# test_tran = get_transform(self.args.dataset, *([self.args.input_height,self.args.input_width]) , train = False)
# data_bd_testset = self.result['bd_test']
# data_bd_testset.wrap_img_transform = test_tran
# data_bd_loader = torch.utils.data.DataLoader(data_bd_testset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,drop_last=False, shuffle=True,pin_memory=args.pin_memory)
#
# data_clean_testset = self.result['clean_test']
# data_clean_testset.wrap_img_transform = test_tran
# data_clean_loader = torch.utils.data.DataLoader(data_clean_testset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,drop_last=False, shuffle=True,pin_memory=args.pin_memory)
#
# clean_test_loss_list = []
# bd_test_loss_list = []
# ra_test_loss_list = []
# test_acc_list = []
# test_asr_list = []
# test_ra_list = []
#
# # b. unlearn the backdoor model by the pertubation
# logging.info("=> Conducting Defence..")
# model.eval()
# model_ref.eval()
#
# clean_test_loss_avg_over_batch, \
# bd_test_loss_avg_over_batch, \
# ra_test_loss_avg_over_batch, \
# test_acc, \
# test_asr, \
# test_ra = self.eval_step(
# model,
# data_clean_loader,
# data_bd_loader,
# args,
# )
#
# logging.info('Initial State: clean test loss: {:.4f}, bd test loss: {:.4f}, ra test loss: {:.4f}, test acc: {:.4f}, test asr: {:.4f}, test ra: {:.4f}'.format(clean_test_loss_avg_over_batch, bd_test_loss_avg_over_batch, ra_test_loss_avg_over_batch, test_acc, test_asr, test_ra))
#
#
# normalization = get_dataset_normalization(args.dataset)
# denormalization = get_dataset_denormalization(normalization)
#
#
#
# def get_perturbed_image(images, pert):
# images_wo_trans = denormalization(images) + pert
# images_with_trans = normalization(images_wo_trans)
# return images_with_trans
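# # Note: the perturbation lives in unnormalized pixel space and the result is
# # re-normalized, so trigger_norm is measured on [0, 1] pixels; this is also why
# # max_eps/min_eps below are computed from the denormalized images.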
#
# Shared_PGD_Attacker = Shared_PGD(model = model,
# model_ref = model_ref,
# beta_1 = args.beta_1,
# beta_2 = args.beta_2,
# norm_bound = args.trigger_norm,
# norm_type = args.norm_type,
# step_size = args.adv_lr,
# num_steps = args.adv_steps,
# init_type = args.pgd_init,
# loss_func = torch.nn.CrossEntropyLoss(),
# pert_func = get_perturbed_image,
# verbose = True)
#
# agg = Metric_Aggregator()
# for round in range(args.n_rounds):
#
# for images, labels, original_index, poison_indicator, original_targets in trainloader:
# images = images.to(args.device)
# labels = labels.to(args.device)
#
# max_eps = 1 - denormalization(images)
# min_eps = -denormalization(images)
#
# batch_pert = Shared_PGD_Attacker.attack(images, labels, max_eps, min_eps)
#
# for _ in range(args.outer_steps):
# pert_image = get_perturbed_image(images, batch_pert.detach())
#
# if args.train_mode:
# model.train()
#
# concat_images = torch.cat([images, pert_image], dim=0)
# concat_logits = model.forward(concat_images)
# logits, per_logits = torch.split(concat_logits, images.shape[0], dim=0)
# model.eval()
#
# logits_ref = model_ref(images)
# per_logits_ref = model_ref.forward(pert_image)
#
# # Get prediction
# ori_lab = torch.argmax(logits, dim=1).long()
# ori_lab_ref = torch.argmax(logits_ref, dim=1).long()
#
# pert_label = torch.argmax(per_logits, dim=1)
# pert_label_ref = torch.argmax(per_logits_ref, dim=1)
#
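# # Following the paper: an example is adversarial if the perturbed prediction
# # differs from the label (for the reference model, the prediction must also
# # change from its original prediction), a common adversarial example fools both
# # model and model_ref, and a shared adversarial example is a common one on which
# # both models agree on the same (wrong) label.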
# success_attack = pert_label != labels
# success_attack_ref = pert_label_ref != labels
# success_attack_ref = success_attack_ref & (pert_label_ref != ori_lab_ref)
# common_attack = torch.logical_and(success_attack, success_attack_ref)
# shared_attack = torch.logical_and(common_attack, pert_label == pert_label_ref)
#
# # Clean loss
# loss_cl = F.cross_entropy(logits, labels, reduction='mean')
#
# # AT loss
# loss_at = F.cross_entropy(per_logits, labels, reduction='mean')
#
#
# # Shared loss
# potential_poison = success_attack_ref
#
# if potential_poison.sum() == 0:
# loss_shared = torch.tensor(0.0).to(args.device)
# else:
# one_hot = F.one_hot(pert_label_ref, num_classes=args.num_classes)
#
# neg_one_hot = 1 - one_hot
# neg_p = (F.softmax(per_logits, dim = 1)*neg_one_hot).sum(dim = 1)[potential_poison]
# pos_p = (F.softmax(per_logits, dim = 1)*one_hot).sum(dim = 1)[potential_poison]
#
# # clamp very small values to avoid NaN, and discard samples with p < 1% from being counted as shared
# # Note: the equation below combines two terms that are identical in math. Although identical
# # mathematically, they differ numerically in implementation, and combining them reduces the numerical issue.
#
# loss_shared = (-torch.sum(torch.log(1e-6 + neg_p.clamp(max = 0.999))) - torch.sum(torch.log(1 + 1e-6 - pos_p.clamp(min = 0.001))))/2
# loss_shared = loss_shared/images.shape[0]
#
# # Total loss and update
#
# outer_opt.zero_grad()
#
# loss = args.lmd_1*loss_cl + args.lmd_2* loss_at + args.lmd_3*loss_shared
#
# loss.backward()
# outer_opt.step()
# model.eval()
#
# # delete the useless variable to save memory
# del logits, logits_ref, per_logits, per_logits_ref, loss_cl, loss_at, loss_shared, loss
#
# clean_test_loss_avg_over_batch, \
# bd_test_loss_avg_over_batch, \
# ra_test_loss_avg_over_batch, \
# test_acc, \
# test_asr, \
# test_ra = self.eval_step(
# model,
# data_clean_loader,
# data_bd_loader,
# args,
# )
#
# agg({
# "epoch": round,
#
# "clean_test_loss_avg_over_batch": clean_test_loss_avg_over_batch,
# "bd_test_loss_avg_over_batch": bd_test_loss_avg_over_batch,
# "ra_test_loss_avg_over_batch": ra_test_loss_avg_over_batch,
# "test_acc": test_acc,
# "test_asr": test_asr,
# "test_ra": test_ra,
# })
#
#
# clean_test_loss_list.append(clean_test_loss_avg_over_batch)
# bd_test_loss_list.append(bd_test_loss_avg_over_batch)
# ra_test_loss_list.append(ra_test_loss_avg_over_batch)
# test_acc_list.append(test_acc)
# test_asr_list.append(test_asr)
# test_ra_list.append(test_ra)
#
# general_plot_for_epoch(
# {
# "Test C-Acc": test_acc_list,
# "Test ASR": test_asr_list,
# "Test RA": test_ra_list,
# },
# save_path=f"{args.save_path}sau_acc_like_metric_plots.png",
# ylabel="percentage",
# )
#
# general_plot_for_epoch(
# {
# "Test Clean Loss": clean_test_loss_list,
# "Test Backdoor Loss": bd_test_loss_list,
# "Test RA Loss": ra_test_loss_list,
# },
# save_path=f"{args.save_path}sau_loss_metric_plots.png",
# ylabel="percentage",
# )
#
# agg.to_dataframe().to_csv(f"{args.save_path}sau_df.csv")
# agg.summary().to_csv(f"{args.save_path}sau_df_summary.csv")
#
# result = {}
# result['model'] = model
# save_defense_result(
# model_name=args.model,
# num_classes=args.num_classes,
# model=model.cpu().state_dict(),
# save_path=args.save_path,
# )
# return result
#
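# # Metrics: ACC = accuracy on the clean test set; ASR = accuracy on the backdoored
# # test set w.r.t. the attack target labels (attack success rate); RA = robust
# # accuracy, i.e., accuracy on backdoored samples w.r.t. their original labels
# # (toggled via getitem_all_switch).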
# def eval_step(
# self,
# netC,
# clean_test_dataloader,
# bd_test_dataloader,
# args,
# ):
# clean_metrics, clean_epoch_predict_list, clean_epoch_label_list = given_dataloader_test(
# netC,
# clean_test_dataloader,
# criterion=torch.nn.CrossEntropyLoss(),
# non_blocking=args.non_blocking,
# device=self.args.device,
# verbose=0,
# )
# clean_test_loss_avg_over_batch = clean_metrics['test_loss_avg_over_batch']
# test_acc = clean_metrics['test_acc']
# bd_metrics, bd_epoch_predict_list, bd_epoch_label_list = given_dataloader_test(
# netC,
# bd_test_dataloader,
# criterion=torch.nn.CrossEntropyLoss(),
# non_blocking=args.non_blocking,
# device=self.args.device,
# verbose=0,
# )
# bd_test_loss_avg_over_batch = bd_metrics['test_loss_avg_over_batch']
# test_asr = bd_metrics['test_acc']
#
# bd_test_dataloader.dataset.wrapped_dataset.getitem_all_switch = True # change to return the original label instead
# ra_metrics, ra_epoch_predict_list, ra_epoch_label_list = given_dataloader_test(
# netC,
# bd_test_dataloader,
# criterion=torch.nn.CrossEntropyLoss(),
# non_blocking=args.non_blocking,
# device=self.args.device,
# verbose=0,
# )
# ra_test_loss_avg_over_batch = ra_metrics['test_loss_avg_over_batch']
# test_ra = ra_metrics['test_acc']
# bd_test_dataloader.dataset.wrapped_dataset.getitem_all_switch = False # switch back
#
# return clean_test_loss_avg_over_batch, \
# bd_test_loss_avg_over_batch, \
# ra_test_loss_avg_over_batch, \
# test_acc, \
# test_asr, \
# test_ra
#
#
# def defense(self,result_file):
# self.set_result(result_file)
# self.set_logger()
# result = self.mitigation()
# return result
#
#
# def eval_attack(self, netC, net_ref, clean_test_dataloader, pert, args = None):
# total_success = 0
# total_success_ref = 0
# total_success_common = 0
# total_success_shared = 0
#
# total_samples = 0
# for images, labels, *other_info in clean_test_dataloader:
# images = images.to(self.args.device)
# labels = labels.to(self.args.device)
# pert_image = self.get_perturbed_image(images=images, pert=pert)
# outputs = netC(pert_image)
# outputs_ref = net_ref(pert_image)
# _, predicted = torch.max(outputs.data, 1)
# _, predicted_ref = torch.max(outputs_ref.data, 1)
# total_success += (predicted != labels).sum().item()
# total_success_ref += (predicted_ref != labels).sum().item()
# total_success_common += (torch.logical_and(predicted != labels, predicted_ref != labels)).sum().item()
# total_success_shared += (torch.logical_and(predicted != labels, predicted_ref == predicted)).sum().item()
# total_samples += labels.size(0)
#
# return total_success/total_samples, total_success_ref/total_samples, total_success_common/total_samples, total_success_shared/total_samples
#
# def eval_binary(self, netC, net_ref, bd_test_dataloader, pert, args = None):
# total_success = 0
# total_success_ref = 0
# total_success_common = 0
# total_success_shared = 0
#
# total_samples = 0
# for images, labels, *other_info in bd_test_dataloader:
# images = images.to(self.args.device)
# labels = labels.to(self.args.device)
# pert_image = self.get_perturbed_image(images=images, pert=pert)
# outputs = netC(pert_image)
# outputs_ref = net_ref(pert_image)
# _, predicted = torch.max(outputs.data, 1)
# _, predicted_ref = torch.max(outputs_ref.data, 1)
# total_success += (predicted != labels).sum().item()
# total_success_ref += (predicted_ref != labels).sum().item()
# total_success_common += (torch.logical_and(predicted != labels, predicted_ref != labels)).sum().item()
# total_success_shared += (torch.logical_and(predicted != labels, predicted_ref == predicted)).sum().item()
# total_samples += labels.size(0)
#
# return total_success/total_samples, total_success_ref/total_samples, total_success_common/total_samples, total_success_shared/total_samples
#
# if __name__ == '__main__':
# parser = argparse.ArgumentParser(description=sys.argv[0])
# sau.add_arguments(parser)
# args = parser.parse_args()
# sau_method = sau(args)
# if "result_file" not in args.__dict__:
# args.result_file = 'defense_test_badnet'
# elif args.result_file is None:
# args.result_file = 'defense_test_badnet'
# result = sau_method.defense(args.result_file)