# PyTorch checkpoints 相关

[TOC]

## checkpoints 导出

```python
torch.save({'model_state_dict': model.state_dict()}, '/path/to/checkpoints')
```

## checkpoints 导入

```python
state = torch.load('/path/to/checkpoints')
model.load_state_dict(state['model_state_dict'])
```

## model.state_dict() 描述

```python
(method) def state_dict(
    *,
    prefix: str = ...,
    keep_vars: bool = ...
) -> Dict[str, Any]
```

## model.state_dict().keys()

```python
print(model)
```
```text
Sequential(
  (0): Conv2d(3, 16, kernel_size=(11, 11), stride=(3, 3))
  (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ReLU(inplace=True)
  (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (8): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (10): ReLU(inplace=True)
  (11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (13): ReLU(inplace=True)
  (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (15): Flatten(start_dim=1, end_dim=-1)
  (16): Linear(in_features=3136, out_features=2048, bias=True)
  (17): ReLU(inplace=True)
  (18): Linear(in_features=2048, out_features=1, bias=True)
  (19): Sigmoid()
)
```
```python
print(model.state_dict().keys())
```
```text
odict_keys(['0.weight', '0.bias', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked', '4.weight', '4.bias', '5.weight', '5.bias', '5.running_mean', '5.running_var', '5.num_batches_tracked', '8.weight', '8.bias', '9.weight', '9.bias', '9.running_mean', '9.running_var', '9.num_batches_tracked', '11.weight', '11.bias', '12.weight', '12.bias', '12.running_mean', '12.running_var', '12.num_batches_tracked', '16.weight', '16.bias', '18.weight', '18.bias'])
```

可以看到,key 的前缀是该层在 `Sequential` 中的下标;只有带参数或 buffer 的层(Conv2d、BatchNorm2d、Linear)才会出现在 `state_dict` 中,而 ReLU、MaxPool2d、Flatten、Sigmoid 没有参数,因此没有对应的 key。