Source code for defense.mcr

# Example usage (--save_path points to the record directory of a previously run attack):
# python defense/mcr/mcr.py --save_path /workspace/chenhongrui/bdzoo2/record/t_914_badnet
'''
This file is adapted from the following source:
link: https://github.com/IBM/model-sanitization
The defense method is called mcr.

Since the model differs from the one used in the original paper, we changed the hyperparameters for PreActResNet18 on CIFAR-10 so that its performance matches the original results.

'''

from defense.base import defense


# def plot_hessian_eigenvalues(
#         model_visual,
#         data_loader,  # only use one batch
#         device,
#         save_path_for_hessian=None,  # xx/xx/xx.png
# ):
#     # save_path_for_hessian =
#     # data_loader =
#     # device =
#     # model_visual =
#
#     model_visual = (model_visual)
#     data_loader = (data_loader)
#     model_visual.to(device)
#
#     # !!! Important to set eval mode !!!
#     model_visual.eval()
#
#     criterion = torch.nn.CrossEntropyLoss()
#
#     batch_x, batch_y, *others = next(iter(data_loader))
#     batch_x = batch_x.to(device)
#     batch_y = batch_y.to(device)
#
#     if torch.__version__ > '1.8.1':
#         logging.info('Use self-defined function as an alternative for torch.eig since your torch>=1.9')
#
#         def old_torcheig(A, eigenvectors):
#             '''A temporary function as an alternative for torch.eig (torch<1.9)'''
#             vals, vecs = torch.linalg.eig(A)
#             if torch.is_complex(vals) or torch.is_complex(vecs):
#                 logging.info(
#                     'Warning: Complex values found in Eigenvalues/Eigenvectors. This is impossible for a real symmetric matrix like the Hessian. \n We only keep the real part.')
#
#             vals = torch.real(vals)
#             vecs = torch.real(vecs)
#
#             # The old torch.eig returned eigenvalues as an n x 2 matrix, so vals is reshaped to n x 2 here. See https://pytorch.org/docs/stable/generated/torch.eig.html
#             vals = vals.view(-1, 1) + torch.zeros(vals.size()[0], 2).to(vals.device)
#             if eigenvectors:
#                 return vals, vecs
#             else:
#                 return vals, torch.tensor([])
#
#         torch.eig = old_torcheig
#
#     # create the hessian computation module
#     hessian_comp = hessian(model_visual, criterion, data=(batch_x, batch_y), cuda=True)
#     # Now let's compute the top 2 eigenvalues and eigenvectors of the Hessian
#     top_eigenvalues, top_eigenvector = hessian_comp.eigenvalues(top_n=2, maxIter=1000)
#     logging.info("The top two eigenvalues of this model are: %.4f %.4f" % (top_eigenvalues[0], top_eigenvalues[1]))
#
#     if save_path_for_hessian is not None:
#
#         density_eigen, density_weight = hessian_comp.density()
#
#         def get_esd_plot(eigenvalues, weights):
#             density, grids = density_generate(eigenvalues, weights)
#             plt.semilogy(grids, density + 1.0e-7)
#             plt.ylabel('Density (Log Scale)', fontsize=14, labelpad=10)
#             plt.xlabel('Eigenvalue', fontsize=14, labelpad=10)
#             plt.xticks(fontsize=12)
#             plt.yticks(fontsize=12)
#             plt.axis([np.min(eigenvalues) - 1, np.max(eigenvalues) + 1, None, None])
#             return plt.gca()
#
#         def density_generate(eigenvalues,
#                              weights,
#                              num_bins=10000,
#                              sigma_squared=1e-5,
#                              overhead=0.01):
#             eigenvalues = np.array(eigenvalues)
#             weights = np.array(weights)
#
#             lambda_max = np.mean(np.max(eigenvalues, axis=1), axis=0) + overhead
#             lambda_min = np.mean(np.min(eigenvalues, axis=1), axis=0) - overhead
#
#             grids = np.linspace(lambda_min, lambda_max, num=num_bins)
#             sigma = sigma_squared * max(1, (lambda_max - lambda_min))
#
#             num_runs = eigenvalues.shape[0]
#             density_output = np.zeros((num_runs, num_bins))
#
#             for i in range(num_runs):
#                 for j in range(num_bins):
#                     x = grids[j]
#                     tmp_result = gaussian(eigenvalues[i, :], x, sigma)
#                     density_output[i, j] = np.sum(tmp_result * weights[i, :])
#             density = np.mean(density_output, axis=0)
#             normalization = np.sum(density) * (grids[1] - grids[0])
#             density = density / normalization
#             return density, grids
#
#         def gaussian(x, x0, sigma_squared):
#             return np.exp(-(x0 - x) ** 2 /
#                           (2.0 * sigma_squared)) / np.sqrt(2 * np.pi * sigma_squared)
#
#         ax = get_esd_plot(density_eigen, density_weight)
#
#         ax.set_title(f'Max Eigen Value: {top_eigenvalues[0]:.2f}')
#
#         plt.tight_layout()
#         plt.savefig(save_path_for_hessian)
#         plt.close()
#
#         logging.info(f'Save to {save_path_for_hessian}')
#
#     return top_eigenvalues


[docs]class mcr(defense): r'''Bridging mode connectivity in loss landscapes and adversarial robustness basic structure: 1. config args, save_path, fix random seed 2. load the backdoor attack data and backdoor test data 3. mcr a. use poisoned model and clean(finetuned from poison) model to form a curve in parameter space b. train curve with given subset of data, test with given t 4. test the result and get ASR, ACC, RC .. code-block:: python mcr = mcr() parser = argparse.ArgumentParser(description=sys.argv[0]) parser = mcr.set_args(parser) args = parser.parse_args() mcr.add_yaml_to_args(args) args = mcr.process_args(args) mcr.prepare(args) mcr.defense() .. Note:: @inproceedings{zhao2020bridging, title={BRIDGING MODE CONNECTIVITY IN LOSS LANDSCAPES AND ADVERSARIAL ROBUSTNESS}, author={Zhao, Pu and Chen, Pin-Yu and Das, Payel and Ramamurthy, Karthikeyan Natesan and Lin, Xue}, booktitle={International Conference on Learning Representations (ICLR 2020)}, year={2020}} Args: baisc args: in the base class train_curve_epochs(int): how many epochs to train the curve num_bends(int): number of bends in curve test_t(float): t of tested model on the curve (which points on the curve is used in test) curve(str): which curve is used ft_epochs(int): finetune epochs ft_lr_scheduler(str): finetune lr_scheduler ratio(float): the ratio of clean data loader acc_ratio(float): the tolerance ration of the clean accuracy test_curve_every(int): frequency of testing the models on curve load_other_model_path(str): instead of finetune the given poisoned model, we load other model from this part use_clean_subset(bool): use bd poison dataset as data poison for path training and BN update; or, use clean subset instead '''