# Device selection and accelerator-specific data-loading options.
#
# NOTE(review): `args`, `train_kwargs` and `test_kwargs` are defined outside
# this chunk. `args` must expose `no_accel` (presumably an argparse
# namespace), and the two kwargs dicts are presumably forwarded to
# `torch.utils.data.DataLoader` for the train/test splits -- confirm
# against the caller.
use_accel = not args.no_accel and torch.accelerator.is_available()

# Shuffling belongs to the training split, not to the hardware: enable it
# for training on every device (previously only accelerator runs shuffled,
# so CPU runs trained on a fixed order) and leave evaluation unshuffled,
# since sample order does not affect aggregate test metrics.
train_kwargs['shuffle'] = True

if use_accel:
    device = torch.accelerator.current_accelerator()
    # Loader options that only matter when batches are moved to an
    # accelerator; applied to both splits.
    accel_kwargs = {'num_workers': 1,
                    # Requires num_workers > 0; reuses the worker between
                    # epochs instead of respawning it.
                    'persistent_workers': True,
                    # Page-locked host memory speeds host-to-device copies.
                    'pin_memory': True}
    train_kwargs.update(accel_kwargs)
    test_kwargs.update(accel_kwargs)
else:
    device = torch.device("cpu")