Build Your Own Backdoor Defense

This is a simple demonstration of how to build a backdoor defense using our framework.

We take the default case of creating a backdoor-mitigation defense as an example: inherit from the base class in ./defense/base.py and build your own defense on top of it.

  • Hyperparameter setting and basic configuration

    • Start by copying from the ft class, which already provides the basic training hyperparameters in args. You can add more options to the parser for your specific needs, register your own defense parameters, and initialize the defense module (a sketch of the defense class itself follows the snippet below):

      import argparse
      import sys

      parser = argparse.ArgumentParser(description=sys.argv[0])
      defense_name.add_arguments(parser)  # 'defense_name' stands for your defense class
      args = parser.parse_args()
      defense_name_method = defense_name(args)  # initialize the defense module
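
    • For orientation, a minimal sketch of what such a defense class might look like; the class name my_defense, the --clean_ratio option, and the imported base-class name are hypothetical, with the base class living in ./defense/base.py as mentioned above:

      # minimal sketch; 'my_defense' and '--clean_ratio' are hypothetical names
      from defense.base import defense  # base class from ./defense/base.py; the class name here is an assumption

      class my_defense(defense):
          def __init__(self, args):
              self.args = args  # keep the parsed hyperparameters on the defense object

          @classmethod
          def add_arguments(cls, parser):
              # defense-specific hyperparameter: fraction of clean data the defense may use
              parser.add_argument('--clean_ratio', type=float, default=0.05)
              return parser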
      
  • Load the backdoor attack result and prepare your own logger

    • We first load the backdoor attack result and set the log and checkpoint directories in the set_result() function (the entries of the loaded result are sketched after this snippet):

      attack_file = 'record/' + result_file
      save_path = 'record/' + result_file + '/defense/ft/'
      os.makedirs(save_path, exist_ok=True)
      self.args.save_path = save_path
      if self.args.checkpoint_save is None:
          self.args.checkpoint_save = save_path + 'checkpoint/'
          os.makedirs(self.args.checkpoint_save, exist_ok=True)
      if self.args.log is None:
          self.args.log = save_path + 'log/'
          os.makedirs(self.args.log, exist_ok=True)
      self.result = load_attack_result(attack_file + '/attack_result.pt')
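
    • The loaded result behaves like a dict; the entries this walkthrough relies on are listed below, taken from how self.result is used later in this guide:

      # self.result behaves like a dict; keys used later in this walkthrough:
      #   'model'        state_dict of the backdoored model
      #   'clean_train'  wrapped clean training set (.wrapped_dataset, .wrap_img_transform)
      #   'bd_test'      poisoned test set, used to measure the attack success rate
      #   'clean_test'   clean test set, used to measure clean accuracy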
      
    • Set up your logger in the set_logger() function (a sketch of the full entry point follows this snippet):

      args = self.args
      logFormatter = logging.Formatter(
          fmt='%(asctime)s [%(levelname)-8s] [%(filename)s:%(lineno)d] %(message)s',
          datefmt='%Y-%m-%d:%H:%M:%S',
      )
      logger = logging.getLogger()
      
      fileHandler = logging.FileHandler(args.log + '/' + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + '.log')
      fileHandler.setFormatter(logFormatter)
      logger.addHandler(fileHandler)
      
      consoleHandler = logging.StreamHandler()
      consoleHandler.setFormatter(logFormatter)
      logger.addHandler(consoleHandler)
      
      logger.setLevel(logging.INFO)
      logging.info(pformat(args.__dict__))
      
      try:
          logging.info(pformat(get_git_info()))
      except Exception:
          logging.info('Failed to get git info.')
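
    • Putting the pieces so far together, a typical entry point might look like the following sketch; the result_file argument and the mitigation() method name are assumptions for illustration:

      if __name__ == '__main__':
          parser = argparse.ArgumentParser(description=sys.argv[0])
          defense_name.add_arguments(parser)
          args = parser.parse_args()

          defense_name_method = defense_name(args)
          defense_name_method.set_result(args.result_file)  # the 'result_file' flag is an assumption
          defense_name_method.set_logger()
          defense_name_method.mitigation()  # hypothetical name for the training entry point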
      
  • Prepare your training dataset, model, and trainer

    • If you need some clean samples to mitigate the backdoor, you can prepare them with the following code:

      # subsample a small clean training set and record the chosen indices
      train_tran = get_transform(self.args.dataset, *([self.args.input_height, self.args.input_width]), train=True)
      clean_dataset = prepro_cls_DatasetBD_v2(self.result['clean_train'].wrapped_dataset)
      data_all_length = len(clean_dataset)
      ran_idx = choose_index(self.args, data_all_length)
      log_index = self.args.log + 'index.txt'
      np.savetxt(log_index, ran_idx, fmt='%d')
      clean_dataset.subset(ran_idx)
      data_set_without_tran = clean_dataset
      data_set_o = self.result['clean_train']
      data_set_o.wrapped_dataset = data_set_without_tran
      data_set_o.wrap_img_transform = train_tran
      trainloader = torch.utils.data.DataLoader(data_set_o, batch_size=self.args.batch_size, num_workers=self.args.num_workers, shuffle=True, pin_memory=self.args.pin_memory)

      # backdoored and clean test sets, for attack success rate and clean accuracy
      test_tran = get_transform(self.args.dataset, *([self.args.input_height, self.args.input_width]), train=False)
      data_bd_testset = self.result['bd_test']
      data_bd_testset.wrap_img_transform = test_tran
      data_bd_loader = torch.utils.data.DataLoader(data_bd_testset, batch_size=self.args.batch_size, num_workers=self.args.num_workers, drop_last=False, shuffle=True, pin_memory=self.args.pin_memory)

      data_clean_testset = self.result['clean_test']
      data_clean_testset.wrap_img_transform = test_tran
      data_clean_loader = torch.utils.data.DataLoader(data_clean_testset, batch_size=self.args.batch_size, num_workers=self.args.num_workers, drop_last=False, shuffle=True, pin_memory=self.args.pin_memory)
      
      • Here, pratio and p_num can be used to select samples either by fraction or by an exact sample number (see the sketch below).
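
      • A minimal sketch of such an index-selection helper, assuming pratio selects by fraction and p_num by exact count; the real choose_index may differ:

        import numpy as np

        def choose_index(args, data_all_length):
            # pick an exact number of samples if p_num is set; otherwise fall
            # back to the fraction given by pratio (both behaviors are assumptions)
            if getattr(args, 'p_num', None) is not None:
                num_selected = int(args.p_num)
            else:
                num_selected = int(data_all_length * args.pratio)
            return np.random.choice(data_all_length, num_selected, replace=False)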

    • Set up your model and trainer:

      model = generate_cls_model(self.args.model, self.args.num_classes)
      model.load_state_dict(self.result['model'])  # load the backdoored weights
      model.to(self.args.device)

      optimizer, scheduler = argparser_opt_scheduler(model, self.args)

      self.set_trainer(model)
      criterion = argparser_criterion(self.args)
      
      • You can use PureCleanModelTrainer, or your own trainer, for mitigation in the set_trainer() function (shown in context in the sketch after this block):

        self.trainer = PureCleanModelTrainer(
            model,
        )
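
      • In context, set_trainer() can be as small as storing the trainer on the defense object:

        def set_trainer(self, model):
            # keep the trainer on the defense object so the training step can reach it
            self.trainer = PureCleanModelTrainer(
                model,
            )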
        
  • Training to mitigate the backdoor

    • Mitigate the backdoor with the training process:

      self.trainer.train_with_test_each_epoch_on_mix(
        trainloader,
        data_clean_loader,
        data_bd_loader,
        self.args.epochs,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=self.args.device,
        frequency_save=self.args.frequency_save,
        save_folder_path=self.args.save_path,
        save_prefix='ft',
        amp=self.args.amp,
        prefetch=self.args.prefetch,
        prefetch_transform_attr_name="ori_image_transform_in_loading",  # since we use the preprocess_bd_dataset
        non_blocking=self.args.non_blocking,
      )
      
    • If you need to test several thresholds for mitigating the backdoor, you can refer to the anp defense.

  • Saving your defense result

    • This is handled by save_defense_result; give it the basic setting information so the result can be loaded later (a reload sketch follows the snippet):

      save_defense_result(
        model_name=args.model,
        num_classes=args.num_classes,
        model=model.cpu().state_dict(),
        save_path=args.save_path,
      )
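
    • To reload the result later, a minimal sketch, assuming save_defense_result stores its keyword arguments as keys in a defense_result.pt file under save_path (both the filename and the key names are assumptions):

      # reload sketch; 'defense_result.pt' and the key names are assumptions
      defense_result = torch.load(os.path.join(args.save_path, 'defense_result.pt'))
      model = generate_cls_model(defense_result['model_name'], defense_result['num_classes'])
      model.load_state_dict(defense_result['model'])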