Build Your Own Backdoor Dataset

This is a simple script to build your own backdoor dataset. Specifically, this file includes two parts:

  1. How to add a new dataset into our framework.

  2. A further demonstration of how the class prepro_cls_DatasetBD_v2 works.

  • How to add a new dataset into our framework

    • Step 1: Download the dataset and put it into the folder ./data/

    • Step 2: Add the dataset name into the file ./data/__init__.py

    • Step 3: Add the dataset class into the file ./data/{dataset_name}.py

      from torch.utils.data import Dataset
      
      class YourDatasetName(Dataset):
      
        def __init__(self):
          ...
      
        def __len__(self):
          return ...
      
        def __getitem__(self, idx):
          return ...
      
      • Your custom dataset should follow the format of torch.utils.data.Dataset.

      • The __getitem__ function should return a tuple (img, label).
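
As a minimal sketch of Step 3, the class below wraps an in-memory list of (img, label) pairs. The name ToyListDataset and its samples argument are hypothetical, and the fallback to a plain base class when PyTorch is not installed exists only so that the snippet runs anywhere; a real dataset would typically load its images from ./data/.

```python
# Hypothetical example: ToyListDataset is not part of the framework.
try:
    from torch.utils.data import Dataset
except ImportError:  # fallback so the sketch still runs without PyTorch
    Dataset = object


class ToyListDataset(Dataset):
    """Wraps a sequence of (img, label) pairs held in memory."""

    def __init__(self, samples):
        # samples: list of (img, label) pairs; img can be any image-like object
        self.samples = list(samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img, label = self.samples[idx]
        # No transform is applied here; the raw sample is returned as-is.
        return img, label


toy = ToyListDataset([("img_a", 0), ("img_b", 1), ("img_c", 0)])
print(len(toy), toy[1])
```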

    • Step 4: In backdoorbench/utils/aggregate_block/dataset_and_transform_generate.py, write how to load your dataset and how to get num_classes, input_shape, normalization, and other basic information. Take CIFAR-10 as an example.

      def dataset_and_transform_generate(args):
      
        ...
      
        elif args.dataset == 'cifar10':
          from torchvision.datasets import CIFAR10
          train_dataset_without_transform = CIFAR10(
            args.dataset_path,
            train=True,
            transform=None,
            download=True,
          )
          test_dataset_without_transform = CIFAR10(
            args.dataset_path,
            train=False,
            transform=None,
            download=True,
          )
      
        ...
        return train_dataset_without_transform, \
           train_img_transform, \
           train_label_transform, \
           test_dataset_without_transform, \
           test_img_transform, \
           test_label_transform
      
      • The train_dataset_without_transform and test_dataset_without_transform should be the dataset without any transform.

      • The train_img_transform and test_img_transform should be the transform for the input image.

      • The train_label_transform and test_label_transform should be the transform for the input label.
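
The return contract described above can be sketched with a toy branch. Everything here is hypothetical: the dataset name 'toy', the helper toy_dataset_and_transform_generate, and the identity transforms stand in for what a real branch in dataset_and_transform_generate.py would provide.

```python
from argparse import Namespace


def identity(x):
    """Placeholder transform that returns its input unchanged."""
    return x


def toy_dataset_and_transform_generate(args):
    # Hypothetical branch for a dataset called 'toy'.
    if args.dataset == "toy":
        # Datasets are plain lists of (img, label) pairs, with no transform applied.
        train_dataset_without_transform = [("train_img_0", 0), ("train_img_1", 1)]
        test_dataset_without_transform = [("test_img_0", 1)]
    else:
        raise ValueError(f"unknown dataset: {args.dataset}")

    # Image and label transforms are returned separately from the datasets.
    train_img_transform = test_img_transform = identity
    train_label_transform = test_label_transform = identity

    return (
        train_dataset_without_transform,
        train_img_transform,
        train_label_transform,
        test_dataset_without_transform,
        test_img_transform,
        test_label_transform,
    )


outputs = toy_dataset_and_transform_generate(Namespace(dataset="toy"))
print(len(outputs))
```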

  • Further demonstration of how the class prepro_cls_DatasetBD_v2 works (the file is at backdoorbench/utils/bd_dataset_v2.py)

    • The basic idea is that this dataset takes a clean dataset without transformation. We use poison_indicator to indicate whether each sample is poisoned. If a sample is poisoned, we apply bd_image_pre_transform and bd_label_pre_transform to it. Otherwise, the sample is left unchanged.

    • To save space, poisonedCLSDataContainer stores only the poisoned samples, either on disk or in RAM.

    • original_index_array records the index of each sample in the original dataset. For example, if the original dataset is the CIFAR-10 training set, original_index_array holds the integers 0 to 49999. Note that this array also works with the subset functionality.
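
To illustrate how original_index_array relates positions in a subset back to the original dataset, here is a small sketch with plain Python lists. The subset selection shown is an assumption made for illustration; the class's actual subset method is not reproduced here.

```python
# Toy dataset of 6 samples; the real class stores these as NumPy arrays.
original_index_array = list(range(6))   # [0, 1, 2, 3, 4, 5]
poison_indicator = [0, 1, 0, 0, 1, 0]   # indexed by ORIGINAL index

# Assumed subset behaviour: keep only some positions of original_index_array.
keep_positions = [1, 3, 4]
subset_index_array = [original_index_array[p] for p in keep_positions]

# Position i in the subset resolves to an original index, which is then
# used to look up the poison flag.
subset_flags = [poison_indicator[orig] for orig in subset_index_array]
print(subset_index_array, subset_flags)
```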

    • When this class is instantiated:

      • Step 1: We first store all the given arguments as attributes.

        self.dataset = full_dataset_without_transform
        
        if poison_indicator is None:
          poison_indicator = np.zeros(len(full_dataset_without_transform))
        self.poison_indicator = poison_indicator
        
        assert len(full_dataset_without_transform) == len(poison_indicator)
        
        self.bd_image_pre_transform = bd_image_pre_transform
        self.bd_label_pre_transform = bd_label_pre_transform
        
        self.save_folder_path = save_folder_path # since when we want to save this dataset, this may cause problem
        
        self.original_index_array = np.arange(len(full_dataset_without_transform))
        
        self.bd_data_container = poisonedCLSDataContainer(self.save_folder_path, ".png")
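
One possible way to build the poison_indicator passed in Step 1 is to mark a random fraction of the samples. The poisoning ratio, the seed, and the helper name make_poison_indicator are assumptions for illustration; plain lists are used so the snippet has no dependencies, whereas the class itself works with NumPy arrays.

```python
import random


def make_poison_indicator(num_samples, poison_ratio, seed=0):
    """Return a 0/1 list with round(num_samples * poison_ratio) ones."""
    rng = random.Random(seed)
    num_poison = round(num_samples * poison_ratio)
    poisoned = set(rng.sample(range(num_samples), num_poison))
    return [1 if i in poisoned else 0 for i in range(num_samples)]


indicator = make_poison_indicator(num_samples=20, poison_ratio=0.1)
# The indicator must have exactly one entry per sample in the clean dataset.
assert len(indicator) == 20
print(sum(indicator))
```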
        
      • Step 2: We check whether any entry of poison_indicator is 1; if so, we perform trigger injection.

        if sum(self.poison_indicator) >= 1:
          self.prepro_backdoor()
        
        def prepro_backdoor(self):
          for selected_index in tqdm(self.original_index_array, desc="prepro_backdoor"):
            if self.poison_indicator[selected_index] == 1:
              img, label = self.dataset[selected_index]
              img = self.bd_image_pre_transform(img, target=label, image_serial_id=selected_index)
              bd_label = self.bd_label_pre_transform(label)
              self.set_one_bd_sample(
                selected_index, img, bd_label, label
              )
        
        def set_one_bd_sample(self, selected_index, img, bd_label, label):
        
          '''
          1. To pil image
          2. set the image to container
          3. change the poison_index.
        
          logic is that no matter by the prepro_backdoor or not, after we set the bd sample,
          This method will automatically change the poison index to 1.
        
          :param selected_index: The index of bd sample
          :param img: The converted img that want to put in the bd_container
          :param bd_label: The label bd_sample has
          :param label: The original label bd_sample has
        
          '''
        
          # we need to save the bd img, so we turn it into PIL
          if (not isinstance(img, Image.Image)) and self.save_folder_path is not None:
            if isinstance(img, np.ndarray):
              img = img.astype(np.uint8)
            img = ToPILImage()(img)
          self.bd_data_container.setitem(
            key=selected_index,
            value=(img, bd_label, label),
            relative_loc_to_save_folder_name=f"{label}",
          )
          self.poison_indicator[selected_index] = 1
        
        • prepro_backdoor processes every sample marked in poison_indicator, while set_one_bd_sample stores a single bd sample.
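
The two callables used in prepro_backdoor can be sketched as follows. The call signatures mirror the calls shown above; the patch trigger, the all-to-one target label, and the nested-list image format are assumptions chosen to keep the example dependency-free (real transforms would typically operate on PIL images or NumPy arrays).

```python
import copy


def toy_bd_image_pre_transform(img, target=None, image_serial_id=None):
    """Stamp a 2x2 white patch in the bottom-right corner of a grayscale image.

    img is a list of rows of pixel values; target and image_serial_id are
    accepted to match the call signature but are unused by this trigger.
    """
    out = copy.deepcopy(img)  # do not modify the clean sample in place
    for row in out[-2:]:
        row[-2:] = [255, 255]
    return out


def toy_bd_label_pre_transform(label, target_label=0):
    """All-to-one attack: every poisoned sample gets the same target label."""
    return target_label


clean_img = [[0] * 4 for _ in range(4)]
bd_img = toy_bd_image_pre_transform(clean_img, target=3, image_serial_id=7)
bd_label = toy_bd_label_pre_transform(3)
print(bd_img[-1], bd_label)
```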

      • Step 3: We often need more than the image and label, for example the sample's original index and original label. When getitem_all is True, __getitem__ returns these extra fields as well; getitem_all_switch swaps the positions of label and original_target in the returned tuple.

        self.getitem_all = True
        self.getitem_all_switch = False
        
        def __getitem__(self, index):
        
          original_index = self.original_index_array[index]
          if self.poison_indicator[original_index] == 0:
            # clean
            img, label = self.dataset[original_index]
            original_target = label
            poison_or_not = 0
          else:
            # bd
            img, label, original_target = self.bd_data_container[original_index]
            poison_or_not = 1
        
          if not isinstance(img, Image.Image):
            img = ToPILImage()(img)
        
          if self.getitem_all:
            if self.getitem_all_switch:
              # use this when you want the original targets but do not want to change your testing process
              return img, \
                   original_target, \
                   original_index, \
                   poison_or_not, \
                   label
        
            else: # the order here should correspond to the order in the bd trainer
              return img, \
                   label, \
                   original_index, \
                   poison_or_not, \
                   original_target
          else:
            return img, label
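
The effect of the two flags on the returned tuple can be summarized with a small stand-alone function. select_outputs is a hypothetical helper that reproduces only the branching of __getitem__ shown above; it is not part of the class.

```python
def select_outputs(img, label, original_index, poison_or_not, original_target,
                   getitem_all=True, getitem_all_switch=False):
    """Mirror the return branches of __getitem__ for one already-loaded sample."""
    if not getitem_all:
        return img, label
    if getitem_all_switch:
        # original_target takes the label slot; the backdoor label moves last
        return img, original_target, original_index, poison_or_not, label
    return img, label, original_index, poison_or_not, original_target


# A poisoned sample: original class 3, backdoor label 0, original index 42.
sample = dict(img="bd_img", label=0, original_index=42,
              poison_or_not=1, original_target=3)

default_item = select_outputs(**sample)
switched_item = select_outputs(**sample, getitem_all_switch=True)
short_item = select_outputs(**sample, getitem_all=False)
print(default_item, switched_item, short_item)
```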
        
      • Step 4 (Optional): We can also save the dataset for later use. retrieve_state and set_state save and load all of the dataset's information.

        def retrieve_state(self):
          return {
            "bd_data_container" : self.bd_data_container.retrieve_state(),
            "getitem_all":self.getitem_all,
            "getitem_all_switch":self.getitem_all_switch,
            "original_index_array": self.original_index_array,
            "poison_indicator": self.poison_indicator,
            "save_folder_path": self.save_folder_path,
          }
        
        def set_state(self, state_file):
          self.bd_data_container = poisonedCLSDataContainer()
          self.bd_data_container.set_state(
            state_file['bd_data_container']
          )
          self.getitem_all = state_file['getitem_all']
          self.getitem_all_switch = state_file['getitem_all_switch']
          self.original_index_array = state_file["original_index_array"]
          self.poison_indicator = state_file["poison_indicator"]
          self.save_folder_path = state_file["save_folder_path"]
        
        • Saving and loading are also used in save_attack_result and load_attack_result; take a look at them for details.
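
A round trip through retrieve_state and set_state might look like the sketch below. ToyState is a hypothetical stand-in that keeps only a few of the fields listed above, and pickle is used purely as an assumed serialization format; see save_attack_result and load_attack_result for what the framework actually does.

```python
import pickle


class ToyState:
    """Hypothetical stand-in holding a subset of the dataset's state fields."""

    def __init__(self):
        self.getitem_all = True
        self.getitem_all_switch = False
        self.original_index_array = [0, 1, 2, 3]
        self.poison_indicator = [0, 1, 0, 0]

    def retrieve_state(self):
        return {
            "getitem_all": self.getitem_all,
            "getitem_all_switch": self.getitem_all_switch,
            "original_index_array": self.original_index_array,
            "poison_indicator": self.poison_indicator,
        }

    def set_state(self, state_file):
        self.getitem_all = state_file["getitem_all"]
        self.getitem_all_switch = state_file["getitem_all_switch"]
        self.original_index_array = state_file["original_index_array"]
        self.poison_indicator = state_file["poison_indicator"]


source = ToyState()
source.poison_indicator[2] = 1                # change state before saving
blob = pickle.dumps(source.retrieve_state())  # serialize the state to bytes

restored = ToyState()
restored.set_state(pickle.loads(blob))
print(restored.poison_indicator)
```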