
class NormalCase

Bases: object

Normal training case: train a clean model on clean data.

Basic structure:

  1. Configure the arguments and save_path, and fix the random seed.

  2. Set up the clean training data and clean test data.

  3. Set the device, model, criterion, optimizer, and training schedule.

  4. Train the model.

  5. Save the clean model.
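The five steps above can be sketched end to end in plain Python. This is a minimal, dependency-free illustration under stated assumptions, not the NormalCase implementation: the `Config` dataclass, `make_clean_data`, `mse`, `normal_case`, the synthetic linear-regression data, and the JSON "checkpoint" are all hypothetical stand-ins for the real arguments, datasets, PyTorch model, criterion, optimizer, and saved model.

```python
import json
import os
import random
import tempfile
from dataclasses import asdict, dataclass


@dataclass
class Config:
    # Hypothetical subset of the NormalCase parameters (illustrative defaults).
    lr: float = 0.1
    epochs: int = 20
    batch_size: int = 8
    steplr_stepsize: int = 10
    steplr_gamma: float = 0.1
    random_seed: int = 0
    save_path: str = "clean_run"


def make_clean_data(n, rng):
    # Hypothetical "clean" dataset: y = 2x + 1 with small Gaussian noise.
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return [(x, 2.0 * x + 1.0 + rng.gauss(0.0, 0.01)) for x in xs]


def mse(w, b, data):
    # Criterion: mean squared error of the linear model w * x + b.
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)


def normal_case(cfg):
    # 1. Config: create the save path and fix the random seed.
    os.makedirs(cfg.save_path, exist_ok=True)
    rng = random.Random(cfg.random_seed)

    # 2. Clean training data and clean test data.
    train_data = make_clean_data(200, rng)
    test_data = make_clean_data(50, rng)

    # 3. Model (w, b); criterion is mse(); optimizer is plain mini-batch SGD;
    #    the schedule decays lr by steplr_gamma every steplr_stepsize epochs.
    w, b = 0.0, 0.0

    # 4. Train the model.
    for epoch in range(cfg.epochs):
        lr = cfg.lr * cfg.steplr_gamma ** (epoch // cfg.steplr_stepsize)
        rng.shuffle(train_data)
        for i in range(0, len(train_data), cfg.batch_size):
            batch = train_data[i:i + cfg.batch_size]
            # Residuals are computed with the pre-update weights.
            errors = [(w * x + b - y, x) for x, y in batch]
            grad_w = sum(2 * e * x for e, x in errors) / len(batch)
            grad_b = sum(2 * e for e, _ in errors) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b

    # 5. Save the clean model (a JSON file here instead of a real checkpoint).
    path = os.path.join(cfg.save_path, "clean_model.json")
    with open(path, "w") as f:
        json.dump({"w": w, "b": b, "config": asdict(cfg)}, f)
    return w, b, mse(w, b, test_data), path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        w, b, test_mse, path = normal_case(Config(save_path=tmp))
        print(f"w={w:.3f} b={b:.3f} test_mse={test_mse:.5f}")
```

Fixing the seed in step 1 is what makes the run reproducible: in this sketch, two calls with the same `random_seed` draw the same data, shuffle in the same order, and end with identical weights.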

normal_train_process = NormalCase()

Parameters:
  • num_workers (int) – number of worker processes used by the dataloader

  • pin_memory (bool) – whether the dataloader uses pinned memory

  • non_blocking (bool) – whether tensors are moved to the GPU with non-blocking transfers

  • prefetch (bool) – whether to use prefetching (still experimental; defaults to False and is not recommended)

  • amp (bool) – whether to use automatic mixed precision (AMP)

  • device (str) – device used for training

  • lr_scheduler (str) – learning-rate scheduler used for training

  • epochs (int) – number of training epochs

  • dataset (str) – dataset used for training

  • dataset_path (str) – location of the dataset (the default data-folder path is added later, so only the folder name needs to be provided)

  • batch_size (int) – batch size for training

  • lr (float) – learning rate for training

  • steplr_stepsize (int) – step size for the StepLR scheduler; used only when StepLR is selected

  • steplr_gamma (float) – decay factor gamma for the StepLR scheduler; used when StepLR is selected

  • sgd_momentum (float) – momentum for the SGD optimizer

  • wd (float) – weight decay for the optimizer

  • steplr_milestones (list) – epoch milestones for a milestone-based scheduler such as MultiStepLR (StepLR itself takes no milestones)

  • client_optimizer (int) – which optimizer is used for training

  • random_seed (int) – random seed for training

  • frequency_save (int) – how often to save the model; 0 means never

  • model (str) – model architecture used for training

  • attack_save_path (str) – where to save the model

  • yaml_path (str) – path to the YAML file containing the training settings
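One way a subset of these parameters could be collected is with `argparse`, letting the file at `yaml_path` supply defaults. This is a sketch under assumptions: `build_parser`, `str2bool`, and `merge_with_yaml` are hypothetical helpers, the argument types and the rule that command-line values override YAML values are assumptions, and the YAML file is represented by a plain dict (with PyYAML it would come from `yaml.safe_load`).

```python
import argparse


def str2bool(text):
    # Hypothetical helper: parse "true"/"false" style flags from the command line.
    return text.lower() in ("true", "1", "yes")


def build_parser():
    # Hypothetical parser; names follow the parameter list, types are assumptions.
    p = argparse.ArgumentParser(description="Normal (clean) training case")
    p.add_argument("--yaml_path", type=str)
    p.add_argument("--dataset", type=str)
    p.add_argument("--dataset_path", type=str)
    p.add_argument("--model", type=str)
    p.add_argument("--device", type=str)
    p.add_argument("--epochs", type=int)
    p.add_argument("--batch_size", type=int)
    p.add_argument("--lr", type=float)
    p.add_argument("--lr_scheduler", type=str)
    p.add_argument("--steplr_stepsize", type=int)
    p.add_argument("--steplr_gamma", type=float)
    p.add_argument("--steplr_milestones", type=int, nargs="+")
    p.add_argument("--sgd_momentum", type=float)
    p.add_argument("--wd", type=float)
    p.add_argument("--num_workers", type=int)
    p.add_argument("--pin_memory", type=str2bool)
    p.add_argument("--non_blocking", type=str2bool)
    p.add_argument("--amp", type=str2bool)
    p.add_argument("--random_seed", type=int)
    p.add_argument("--frequency_save", type=int)
    p.add_argument("--attack_save_path", type=str)
    return p


def merge_with_yaml(args, yaml_defaults):
    # Assumption: a value given on the command line wins; the YAML settings
    # only fill in arguments that were left unset (None).
    merged = vars(args).copy()
    for key, value in yaml_defaults.items():
        if merged.get(key) is None:
            merged[key] = value
    return argparse.Namespace(**merged)


if __name__ == "__main__":
    yaml_defaults = {"epochs": 100, "lr": 0.01, "batch_size": 128}
    args = build_parser().parse_args(["--epochs", "5", "--pin_memory", "true"])
    cfg = merge_with_yaml(args, yaml_defaults)
    print(cfg.epochs, cfg.lr, cfg.batch_size, cfg.pin_memory)
```

Every parser argument stays present in the merged namespace, so a setting missing from both the command line and the YAML file reads as `None` rather than raising `AttributeError`.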
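Assuming `lr_scheduler` selects one of PyTorch's `torch.optim.lr_scheduler` classes, the learning rate in a given epoch follows the closed forms below: StepLR multiplies the base rate by `steplr_gamma` once every `steplr_stepsize` epochs, while MultiStepLR multiplies it by gamma at each epoch listed in `steplr_milestones`. The helper names are hypothetical; the functions only compute the schedule and do not touch an optimizer.

```python
from bisect import bisect_right


def steplr_lr(base_lr, epoch, step_size, gamma):
    # StepLR: one decay by gamma for every completed block of step_size epochs.
    return base_lr * gamma ** (epoch // step_size)


def multisteplr_lr(base_lr, epoch, milestones, gamma):
    # MultiStepLR: one decay by gamma for each milestone already reached
    # (bisect_right counts milestones <= epoch in the sorted list).
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


if __name__ == "__main__":
    for epoch in (0, 49, 50, 74, 75, 99):
        print(epoch, steplr_lr(0.1, epoch, 30, 0.1), multisteplr_lr(0.1, epoch, [50, 75], 0.1))
```

For example, with a base rate of 0.1, gamma 0.1 and milestones [50, 75], MultiStepLR uses 0.1 for epochs 0 to 49, 0.01 for epochs 50 to 74, and 0.001 from epoch 75 onward.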