attack.NormalCase
- class NormalCase
Bases: object
Normal training case: train a clean model on clean data.
Basic structure:
1. Configure args and save_path, and fix the random seed.
2. Set up the clean training data and the clean test data.
3. Set up the device, model, criterion, optimizer, and training schedule.
4. Train the model.
5. Save the clean model.
Usage:
normal_train_process = NormalCase()
normal_train_process.attack()
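The five-step structure of NormalCase can be sketched as follows. This is a minimal, self-contained illustration, not the NormalCase implementation: the class NormalCaseSketch and its methods prepare, train, test_mse and save are hypothetical. A toy 1-D linear model trained with plain SGD on noiseless data stands in for the real dataset, network, criterion, device and training schedule, so the example runs with the Python standard library only. Only the entry-point name attack() is taken from the usage shown above.

```python
import json
import os
import random
import tempfile


class NormalCaseSketch:
    """Hypothetical stand-in mirroring the five steps of NormalCase.

    Uses a toy model y = w * x + b trained with plain SGD on clean
    (noiseless, unpoisoned) data, so it needs only the standard library.
    """

    def __init__(self, lr=0.05, epochs=50, random_seed=0, save_path=None):
        # Step 1: config args, save_path, fix the random seed
        self.lr = lr
        self.epochs = epochs
        self.save_path = save_path or tempfile.mkdtemp()
        random.seed(random_seed)

    def prepare(self):
        # Step 2: clean train/test data drawn from y = 2x + 1
        xs = [random.uniform(-1, 1) for _ in range(64)]
        data = [(x, 2 * x + 1) for x in xs]
        self.train_data, self.test_data = data[:48], data[48:]
        # Step 3: set up the model parameters (w, b); the squared-error
        # criterion and the SGD update are applied in train()
        self.w, self.b = 0.0, 0.0

    def train(self):
        # Step 4: one SGD update per sample, for a fixed number of epochs
        for _ in range(self.epochs):
            random.shuffle(self.train_data)
            for x, y in self.train_data:
                err = (self.w * x + self.b) - y  # gradient of 0.5 * err**2 w.r.t. the prediction
                self.w -= self.lr * err * x
                self.b -= self.lr * err

    def test_mse(self):
        # Mean squared error on the clean test split
        return sum((self.w * x + self.b - y) ** 2 for x, y in self.test_data) / len(self.test_data)

    def save(self):
        # Step 5: save the clean model
        path = os.path.join(self.save_path, "clean_model.json")
        with open(path, "w") as f:
            json.dump({"w": self.w, "b": self.b}, f)
        return path

    def attack(self):
        # Same entry-point name as NormalCase: runs steps 2 to 5 in order
        self.prepare()
        self.train()
        return self.save()
```

Calling NormalCaseSketch().attack() trains the toy model and returns the path of the saved JSON file; keeping each step in its own method makes it easy to see where the real class would plug in datasets, a network, and a scheduler.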
- Parameters:
num_workers (int) – number of worker processes used by the DataLoader
pin_memory (bool) – whether the DataLoader uses pinned memory
non_blocking (bool) – whether tensors are moved to the GPU with non_blocking=True
prefetch (bool) – whether to use data prefetching (still experimental; defaults to False and is not recommended)
amp (bool) – whether to use automatic mixed precision (AMP)
device (str) – device used for training
lr_scheduler (str) – learning-rate scheduler used for training
epochs (int) – number of training epochs
dataset (str) – dataset used for training
dataset_path (str) – location of the dataset; the default data folder path is added to it later, so only the folder name needs to be provided
batch_size (int) – batch size for training
lr (float) – learning rate for training
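steplr_stepsize (int) – period of learning-rate decay for the StepLR scheduler; used only when StepLR is selected
steplr_gamma (float) – multiplicative factor of learning-rate decay for the StepLR scheduler; used only when StepLR is selected
sgd_momentum (float) – momentum for the SGD optimizer
wd (float) – weight decay for the optimizer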
steplr_milestones (list) – milestones (epochs at which the learning rate is multiplied by steplr_gamma) for a milestone-based scheduler such as MultiStepLR
client_optimizer (int) – optimizer used for training
random_seed (int) – random seed for training
frequency_save (int) – how often to save the model; 0 means the model is never saved
model (str) – model used for training
attack_save_path (str) – path where the model is saved
yaml_path (str) – path to the YAML file containing the training settings
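Assuming that lr_scheduler selects one of PyTorch's torch.optim.lr_scheduler classes (an assumption; the accepted scheduler names are not listed here), steplr_stepsize and steplr_gamma would correspond to StepLR(optimizer, step_size, gamma), and steplr_milestones together with steplr_gamma to MultiStepLR(optimizer, milestones, gamma). The sketch below computes, in closed form and with the standard library only, the learning rate those two schedulers assign to a given epoch. The function names steplr_lr and multisteplr_lr are hypothetical.

```python
from bisect import bisect_right


def steplr_lr(lr, epoch, steplr_stepsize, steplr_gamma):
    """Learning rate at `epoch` under StepLR: multiply by gamma every step_size epochs."""
    return lr * steplr_gamma ** (epoch // steplr_stepsize)


def multisteplr_lr(lr, epoch, steplr_milestones, steplr_gamma):
    """Learning rate at `epoch` under MultiStepLR: multiply by gamma at each milestone reached."""
    # bisect_right counts how many milestones are <= epoch
    return lr * steplr_gamma ** bisect_right(sorted(steplr_milestones), epoch)


if __name__ == "__main__":
    # Illustrative values only: base lr 0.1, step size 30, milestones [50, 75], gamma 0.1
    for epoch in (0, 29, 30, 50, 75):
        print(epoch, steplr_lr(0.1, epoch, 30, 0.1), multisteplr_lr(0.1, epoch, [50, 75], 0.1))
```

In an actual PyTorch training loop, the equivalent objects would be torch.optim.lr_scheduler.StepLR(optimizer, step_size=steplr_stepsize, gamma=steplr_gamma) or torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steplr_milestones, gamma=steplr_gamma), with scheduler.step() called once per epoch.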
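The dataset_path and yaml_path parameters suggest a common configuration pattern, sketched below under stated assumptions: settings read from the YAML file fill in only the options not given on the command line, and a default data folder is joined in front of dataset_path. Both behaviours are assumptions, as are the helper build_args, the subset of options it declares, and the DEFAULT_DATA_ROOT value. A plain dict stands in for the parsed YAML (for example, the result of yaml.safe_load from PyYAML), so the example needs only the standard library.

```python
import argparse
import os

# Hypothetical default data folder; the real default is set by the project.
DEFAULT_DATA_ROOT = "./data"


def build_args(cli_argv, yaml_settings):
    """Merge command-line options with settings loaded from a YAML file (hypothetical helper)."""
    parser = argparse.ArgumentParser()
    # default=None marks "not given on the command line"
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--dataset_path", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--yaml_path", type=str, default=None)
    args = parser.parse_args(cli_argv)

    # Fill only the options the user did not set; command-line values win.
    for key, value in yaml_settings.items():
        if getattr(args, key, None) is None:
            setattr(args, key, value)

    # dataset_path is only a folder name, so prepend the default data folder
    args.dataset_path = os.path.join(DEFAULT_DATA_ROOT, args.dataset_path)
    return args
```

Letting command-line values take precedence over the YAML file means a single settings file can serve as the baseline while individual runs override, say, lr or epochs.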
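Apart from 0 disabling saving, the unit of frequency_save is not stated. Assuming it counts epochs, a check like the hypothetical should_save below decides whether to write a checkpoint after a given 0-based epoch.

```python
def should_save(epoch: int, frequency_save: int) -> bool:
    """Return True if the model should be saved after 0-based `epoch`.

    Assumes frequency_save counts epochs; 0 (or a negative value) disables saving.
    """
    return frequency_save > 0 and (epoch + 1) % frequency_save == 0


if __name__ == "__main__":
    # With frequency_save=3 over 10 epochs, saves happen after epochs 2, 5 and 8
    print([epoch for epoch in range(10) if should_save(epoch, 3)])
```

Using epoch + 1 ties the check to completed epochs, so the first save happens only after frequency_save full epochs of training.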