Device selection

# Choose the compute device and the DataLoader options that go with it.
# The current torch.accelerator backend is used unless args.no_accel is set
# or no accelerator is available; otherwise everything runs on the CPU.
use_accel = not args.no_accel and torch.accelerator.is_available()

if use_accel:
    # is_available() was checked above, so this returns a real device.
    device = torch.accelerator.current_accelerator()
    # Options applied only on the accelerator path: one background loader
    # worker kept alive across epochs, and pinned host memory for the
    # host-to-device copies.
    accel_kwargs = {'num_workers'       : 1,
                    'persistent_workers': True,
                    'pin_memory'        : True}
    train_kwargs.update(accel_kwargs)
    test_kwargs.update(accel_kwargs)
else:
    device = torch.device("cpu")

# Shuffling is a training concern, not a device concern. When it lived in the
# accelerator-only options, CPU runs trained on unshuffled data and the test
# loader was shuffled for no benefit. Always shuffle the training set, and no
# longer force shuffling of the test set.
train_kwargs['shuffle'] = True