[TOC]
## checkpoints 导出

```python
torch.save({'model_state_dict': model.state_dict()}, '/path/to/checkpoints')
```
## checkpoints 导入

```python
state = torch.load('/path/to/checkpoints')
model.load_state_dict(state['model_state_dict'])
```
## `model.state_dict()` 描述

```python
(method) def state_dict(
```
## `model.state_dict().keys()`

```python
print(model)
```

```text
Sequential(
```

```python
print(model.state_dict().keys())
```

```text
odict_keys(['0.weight', '0.bias', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked', '4.weight', '4.bias', '5.weight', '5.bias', '5.running_mean', '5.running_var', '5.num_batches_tracked', '8.weight', '8.bias', '9.weight', '9.bias', '9.running_mean', '9.running_var', '9.num_batches_tracked', '11.weight', '11.bias', '12.weight', '12.bias', '12.running_mean', '12.running_var', '12.num_batches_tracked', '16.weight', '16.bias', '18.weight', '18.bias'])
```