Build Your Own Backdoor Defense
This is a simple demonstration of how to build a backdoor defense using our framework.
We use a mitigation defense, fine-tuning (ft), as the running example. To create your own defense, inherit from the base class in ./defense/base.py.
Hyperparameter setting and basic configuration
Start by copying the ft class, which already defines the basic training hyperparameters in args. You can add more arguments to the parser for your specific needs, including your own defense parameters, and then initialize the defense module:
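```python
import argparse
import sys

# defense_name is a placeholder for your defense class (e.g. your copy of the ft class)
parser = argparse.ArgumentParser(description=sys.argv[0])
defense_name.add_arguments(parser)        # register training and defense-specific arguments
args = parser.parse_args()
defense_name_method = defense_name(args)  # initialize the defense with the parsed arguments
```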
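The pattern above expects your defense class to expose an add_arguments(parser) method and to accept the parsed args in its constructor. A minimal sketch of that shape follows. The class MyDefense and the arguments it registers (including the defense-specific --my_threshold) are hypothetical; a real defense would inherit from the base class in ./defense/base.py and reuse the arguments defined by ft.

```python
import argparse


class MyDefense:
    """Hypothetical defense skeleton; a real one inherits from ./defense/base.py."""

    def __init__(self, args: argparse.Namespace):
        # Keep the parsed hyperparameters; later steps read them via self.args
        self.args = args

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # Basic training hyperparameters (illustrative names and defaults)
        parser.add_argument('--epochs', type=int, default=10)
        parser.add_argument('--batch_size', type=int, default=128)
        parser.add_argument('--lr', type=float, default=0.01)
        # A hypothetical defense-specific hyperparameter
        parser.add_argument('--my_threshold', type=float, default=0.5,
                            help='example parameter specific to this defense')
        return parser


# Usage: parse an explicit list here instead of sys.argv so the example is reproducible
parser = argparse.ArgumentParser(description='my defense')
MyDefense.add_arguments(parser)
args = parser.parse_args(['--epochs', '20', '--my_threshold', '0.3'])
defense = MyDefense(args)
print(defense.args.epochs, defense.args.lr, defense.args.my_threshold)  # 20 0.01 0.3
```

Making add_arguments a static method lets you register the arguments on the class before any instance exists, which matches the defense_name.add_arguments(parser) call above.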
Load the backdoor attack result and prepare your own logger
First, load the backdoor attack result and set up your log and checkpoint directories in the set_result() function:
```python
import os

# load_attack_result is provided by the framework
def set_result(self, result_file):
    attack_file = 'record/' + result_file
    save_path = 'record/' + result_file + '/defense/ft/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    self.args.save_path = save_path
    # Fall back to default checkpoint/log folders only if none were given
    if self.args.checkpoint_save is None:
        self.args.checkpoint_save = save_path + 'checkpoint/'
        if not os.path.exists(self.args.checkpoint_save):
            os.makedirs(self.args.checkpoint_save)
    if self.args.log is None:
        self.args.log = save_path + 'log/'
        if not os.path.exists(self.args.log):
            os.makedirs(self.args.log)
    self.result = load_attack_result(attack_file + '/attack_result.pt')
```
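The set_result() code above hard-codes ft in save_path, so a new defense would otherwise write into the fine-tuning folder. A compact sketch of the same directory layout, parameterized by the defense name and using os.makedirs(..., exist_ok=True), is shown below. The helper prepare_defense_dirs, its defense and root parameters, and the result folder name badnet_demo are hypothetical; the attack-result loading step is omitted.

```python
import os
import tempfile
from types import SimpleNamespace


def prepare_defense_dirs(args, result_file: str, defense: str, root: str = 'record') -> None:
    """Hypothetical helper mirroring set_result(): fill in save_path, checkpoint_save and log."""
    # Trailing separator so later code can append file names, e.g. args.log + 'index.txt'
    save_path = os.path.join(root, result_file, 'defense', defense) + os.sep
    os.makedirs(save_path, exist_ok=True)
    args.save_path = save_path
    # Only fill in defaults the user did not already set
    if args.checkpoint_save is None:
        args.checkpoint_save = save_path + 'checkpoint' + os.sep
        os.makedirs(args.checkpoint_save, exist_ok=True)
    if args.log is None:
        args.log = save_path + 'log' + os.sep
        os.makedirs(args.log, exist_ok=True)


# Usage: a temporary root keeps the example from writing into ./record
root = tempfile.mkdtemp()
args = SimpleNamespace(checkpoint_save=None, log=None)
prepare_defense_dirs(args, 'badnet_demo', 'my_defense', root=root)
print(args.log)
```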
Next, set up your logger in the set_logger() function:
```python
import logging
import time
from pprint import pformat

# get_git_info is provided by the framework
def set_logger(self):
    args = self.args
    logFormatter = logging.Formatter(
        fmt='%(asctime)s [%(levelname)-8s] [%(filename)s:%(lineno)d] %(message)s',
        datefmt='%Y-%m-%d:%H:%M:%S',
    )
    logger = logging.getLogger()

    # Write to a timestamped log file in args.log ...
    fileHandler = logging.FileHandler(
        args.log + '/' + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + '.log'
    )
    fileHandler.setFormatter(logFormatter)
    logger.addHandler(fileHandler)

    # ... and to the console
    consoleHandler = logging.StreamHandler()
    consoleHandler.setFormatter(logFormatter)
    logger.addHandler(consoleHandler)

    logger.setLevel(logging.INFO)
    logging.info(pformat(args.__dict__))  # record all hyperparameters

    try:
        logging.info(pformat(get_git_info()))
    except Exception:
        logging.info('Getting git info fails.')
```
Prepare your training dataset, model and trainer
If your defense needs some clean samples to mitigate the backdoor, you can build the clean training subset and the test loaders as follows:
```python
import numpy as np
import torch

# Runs inside your defense class (uses self.args and self.result);
# get_transform, prepro_cls_DatasetBD_v2 and choose_index are provided by the framework.

# Clean training subset used for mitigation
train_tran = get_transform(self.args.dataset, self.args.input_height, self.args.input_width, train=True)
clean_dataset = prepro_cls_DatasetBD_v2(self.result['clean_train'].wrapped_dataset)
data_all_length = len(clean_dataset)
ran_idx = choose_index(self.args, data_all_length)  # indices of the clean samples to keep
log_index = self.args.log + 'index.txt'
np.savetxt(log_index, ran_idx, fmt='%d')            # record the chosen indices
clean_dataset.subset(ran_idx)
data_set_without_tran = clean_dataset
data_set_o = self.result['clean_train']
data_set_o.wrapped_dataset = data_set_without_tran
data_set_o.wrap_img_transform = train_tran
trainloader = torch.utils.data.DataLoader(
    data_set_o, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
    shuffle=True, pin_memory=self.args.pin_memory,
)

# Backdoor (poisoned) test set and clean test set, both with test-time transforms
test_tran = get_transform(self.args.dataset, self.args.input_height, self.args.input_width, train=False)
data_bd_testset = self.result['bd_test']
data_bd_testset.wrap_img_transform = test_tran
data_bd_loader = torch.utils.data.DataLoader(
    data_bd_testset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
    drop_last=False, shuffle=True, pin_memory=self.args.pin_memory,
)
data_clean_testset = self.result['clean_test']
data_clean_testset.wrap_img_transform = test_tran
data_clean_loader = torch.utils.data.DataLoader(
    data_clean_testset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
    drop_last=False, shuffle=True, pin_memory=self.args.pin_memory,
)
```
Here, pratio and p_num let you specify poisoning either as a fraction of the data or as an exact number of samples.
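Selecting a subset either by fraction or by exact count can be sketched with a small helper. The function select_indices and its ratio, num and seed parameters are hypothetical; this is not the framework's choose_index, only an illustration of the two ways of sizing a subset.

```python
import random
from typing import List, Optional


def select_indices(total: int, ratio: Optional[float] = None,
                   num: Optional[int] = None, seed: int = 0) -> List[int]:
    """Hypothetical helper: pick a random subset of range(total) by fraction or by exact count."""
    if (ratio is None) == (num is None):
        raise ValueError("specify exactly one of ratio or num")
    # Fraction -> truncate to an integer count; exact count is used as-is
    k = int(total * ratio) if ratio is not None else num
    if not 0 <= k <= total:
        raise ValueError(f"subset size {k} is out of range for {total} samples")
    rng = random.Random(seed)  # local RNG so the selection is reproducible
    return sorted(rng.sample(range(total), k))


# Usage
by_fraction = select_indices(50000, ratio=0.05)  # 5% of 50000 samples
by_count = select_indices(50000, num=500)        # exactly 500 samples
print(len(by_fraction), len(by_count))           # 2500 500
```

Fixing the seed (and saving the indices, as the np.savetxt call in the data-preparation code does) makes it possible to rerun a defense on exactly the same subset.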
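Then set up the model, optimizer, scheduler, criterion and trainer:
```python
# generate_cls_model, argparser_opt_scheduler and argparser_criterion are provided by the framework
model = generate_cls_model(self.args.model, self.args.num_classes)
model.load_state_dict(self.result['model'])  # start from the backdoored model
model.to(self.args.device)
optimizer, scheduler = argparser_opt_scheduler(model, self.args)
self.set_trainer(model)
criterion = argparser_criterion(self.args)
```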
In the set_trainer() function, you can use PureCleanModelTrainer or your own trainer for mitigation:
```python
def set_trainer(self, model):
    self.trainer = PureCleanModelTrainer(
        model,
    )
```
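Training to mitigate the backdoor
Run the training process to mitigate the backdoor. train_with_test_each_epoch_on_mix trains on the clean subset and, as its name suggests, evaluates on both the clean and the backdoor test loaders after each epoch.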
```python
self.trainer.train_with_test_each_epoch_on_mix(
    trainloader,
    data_clean_loader,
    data_bd_loader,
    self.args.epochs,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=self.args.device,
    frequency_save=self.args.frequency_save,
    save_folder_path=self.args.save_path,
    save_prefix='ft',  # prefix for saved files; use your defense's name
    amp=self.args.amp,
    prefetch=self.args.prefetch,
    prefetch_transform_attr_name="ori_image_transform_in_loading",  # since we use the preprocess_bd_dataset
    non_blocking=self.args.non_blocking,
)
```
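The two numbers usually tracked during this training are clean accuracy (on data_clean_loader) and attack success rate, ASR (on data_bd_loader). A successful mitigation keeps clean accuracy close to that of the original model while driving ASR down. The sketch below computes both from predictions, assuming the backdoor test set is labeled with the attacker's target class, so that plain accuracy on it equals ASR. The accuracy function and the toy predictions are hypothetical and independent of the framework.

```python
from typing import Sequence


def accuracy(preds: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of predictions that match the given labels."""
    if len(preds) != len(labels) or len(labels) == 0:
        raise ValueError("preds and labels must be non-empty and the same length")
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


# Clean test set: labels are the true classes -> clean accuracy
clean_acc = accuracy([0, 1, 2, 1], [0, 1, 2, 2])
# Backdoor test set: labels are the attacker's target class -> ASR
target = 0
asr = accuracy([0, 3, 1, 1], [target] * 4)
print(clean_acc, asr)  # 0.75 0.25
```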
If your defense needs to test several thresholds for mitigating the backdoor, refer to the anp defense for an example.
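Once each candidate threshold has been evaluated, you still need a rule for choosing one. The sketch below shows one simple rule, not necessarily the one anp uses: among thresholds whose clean accuracy stays within a tolerance of the original model's, pick the one with the lowest ASR. The function pick_threshold, its max_acc_drop parameter and the numbers in the usage example are hypothetical.

```python
from typing import Dict, Tuple


def pick_threshold(results: Dict[float, Tuple[float, float]], base_acc: float,
                   max_acc_drop: float = 0.02) -> float:
    """Hypothetical rule: lowest ASR among thresholds whose clean accuracy drop is acceptable.

    results maps threshold -> (clean accuracy, ASR); ties in ASR go to higher clean accuracy.
    """
    ok = {t: (acc, asr) for t, (acc, asr) in results.items() if base_acc - acc <= max_acc_drop}
    if not ok:
        raise ValueError("no threshold keeps clean accuracy within the allowed drop")
    return min(ok, key=lambda t: (ok[t][1], -ok[t][0]))


# threshold -> (clean accuracy, ASR), made-up numbers for illustration
results = {0.1: (0.93, 0.40), 0.2: (0.92, 0.05), 0.3: (0.85, 0.01)}
print(pick_threshold(results, base_acc=0.93))  # 0.2
```

Threshold 0.3 has the lowest ASR but costs 8 points of clean accuracy, so with the default 2-point tolerance the rule settles on 0.2.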
Saving your defense result
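Saving is handled by save_defense_result(). Pass it the basic settings (model name and number of classes) together with the model weights and the save path, so that the defense result can be loaded again later.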
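```python
# save_defense_result is provided by the framework
save_defense_result(
    model_name=self.args.model,
    num_classes=self.args.num_classes,
    model=model.cpu().state_dict(),  # move the weights to CPU before saving
    save_path=self.args.save_path,
)
```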