56 lines
1.4 KiB
Python
56 lines
1.4 KiB
Python
"""
|
|
Helper functions for PyTorch model training: checkpointing, CUDA device setup and optimizer construction.
|
|
"""
|
|
|
|
import torch
|
|
from torch.optim import Optimizer
|
|
|
|
|
|
# Checkpoints
|
|
def save_checkpoint(model, model_dir):
    """Write ``model``'s state dict to the file path ``model_dir``.

    Only the parameters/buffers returned by ``model.state_dict()`` are
    serialized (via ``torch.save``), not the module object itself.
    """
    weights = model.state_dict()
    torch.save(weights, model_dir)
|
|
|
|
|
|
def resume_checkpoint(model, model_dir, device_id):
    """Load a state dict from the file ``model_dir`` into ``model``.

    Args:
        model: Module whose parameters are overwritten in place via
            ``load_state_dict``.
        model_dir: Path of a file holding a state dict, e.g. one written by
            ``save_checkpoint``.
        device_id: Where the loaded tensors are mapped. An ``int`` selects
            ``cuda:<device_id>`` (the original behaviour); ``None`` maps to
            the CPU; a ``str`` or ``torch.device`` is passed to
            ``torch.load`` unchanged.

    Returns:
        ``model`` itself, so the call can be chained.

    Errors raised by ``torch.load`` or ``load_state_dict`` (missing file,
    mismatched keys, unavailable device, ...) propagate to the caller.
    """
    if device_id is None:
        # Allow restoring on machines without a GPU.
        map_location = "cpu"
    elif isinstance(device_id, (str, torch.device)):
        map_location = device_id
    else:
        map_location = f"cuda:{device_id}"

    state_dict = torch.load(model_dir, map_location=map_location)
    model.load_state_dict(state_dict)
    return model
|
|
|
|
|
|
# Hyper params
|
|
def use_cuda(enabled, device_id=0):
    """Select CUDA device ``device_id`` as current when ``enabled`` is truthy.

    Does nothing when ``enabled`` is falsy.

    Raises:
        AssertionError: If ``enabled`` is truthy but
            ``torch.cuda.is_available()`` is False.
    """
    if not enabled:
        return
    if not torch.cuda.is_available():
        # Explicit raise instead of ``assert`` so the check still runs under
        # ``python -O``; AssertionError is kept for existing callers.
        raise AssertionError("CUDA is not available")
    torch.cuda.set_device(device_id)
|
|
|
|
|
|
def use_optimizer(
    optimizer_name: str,
    network: torch.nn.Module,
    learning_rate: float,
    momentum: float = 0,
    weight_decay: float = 0,
    alpha: float = 0.99,
) -> Optimizer:
    """Build a ``torch.optim`` optimizer over ``network.parameters()``.

    Args:
        optimizer_name: ``"sgd"``, ``"adam"`` or ``"rmsprop"``
            (case-insensitive).
        network: Module whose parameters are optimized.
        learning_rate: Passed to the optimizer as ``lr``.
        momentum: Used by SGD and RMSprop; ignored by Adam, which takes no
            ``momentum`` argument.
        weight_decay: Passed to all three optimizers.
        alpha: RMSprop smoothing constant; ignored by SGD and Adam.

    Returns:
        The configured optimizer.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    # Normalise case so "SGD"/"Adam"/"RMSprop" are accepted; non-strings
    # fall through to the ValueError below, as before.
    key = optimizer_name.lower() if isinstance(optimizer_name, str) else None

    if key == "sgd":
        return torch.optim.SGD(
            network.parameters(),
            lr=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay,
        )
    if key == "adam":
        return torch.optim.Adam(
            network.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
    if key == "rmsprop":
        # weight_decay was previously dropped here, unlike SGD/Adam.
        return torch.optim.RMSprop(
            network.parameters(),
            lr=learning_rate,
            alpha=alpha,
            momentum=momentum,
            weight_decay=weight_decay,
        )
    raise ValueError(f"Optimizer '{optimizer_name}' is not supported")
|