Files
2026-05-22 10:02:10 +02:00

56 lines
1.4 KiB
Python

"""
Handy helper functions for PyTorch model training.
"""
import torch
from torch.optim import Optimizer
# Checkpoints
def save_checkpoint(model, model_dir):
    """Write ``model``'s ``state_dict()`` to ``model_dir`` with ``torch.save``.

    Only the state dict is stored, not the module object itself. To restore
    it, load it into an already-constructed model of the same architecture.

    Args:
        model: module whose ``state_dict()`` is serialized.
        model_dir: destination handed straight to ``torch.save``. Despite the
            name, this is the checkpoint *file* path (or a file-like object),
            not a directory.
    """
    weights = model.state_dict()
    torch.save(weights, model_dir)
def resume_checkpoint(model, model_dir, device_id):
    """Load a state dict from ``model_dir`` into ``model`` in place.

    Args:
        model: module to restore. ``load_state_dict`` is called with its
            default ``strict=True``, so missing or unexpected keys raise.
        model_dir: checkpoint file path (or file-like object) accepted by
            ``torch.load``.
        device_id: where tensors are mapped while loading.
            * ``None`` -> ``"cpu"``. This lets a checkpoint be restored on a
              CPU-only host.
            * a ``torch.device`` -> used as-is.
            * anything else (typically an int) -> ``f"cuda:{device_id}"``,
              which is the original behaviour.
    """
    if device_id is None:
        device = "cpu"
    elif isinstance(device_id, torch.device):
        device = device_id
    else:
        device = f"cuda:{device_id}"
    # NOTE(review): torch.load unpickles its input; only load checkpoints
    # from trusted sources (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(model_dir, map_location=device)
    model.load_state_dict(state_dict)
# Device setup and optimizer (hyper-parameter) construction
def use_cuda(enabled, device_id=0):
    """Select CUDA device ``device_id`` as the current device when ``enabled``.

    Does nothing when ``enabled`` is falsy.

    Raises:
        AssertionError: if ``enabled`` is truthy but CUDA is not available.
            The error is raised explicitly instead of via ``assert``, so the
            check still runs under ``python -O``. The type stays
            ``AssertionError`` so existing callers keep working.
    """
    if not enabled:
        return
    if not torch.cuda.is_available():
        raise AssertionError("CUDA is not available")
    torch.cuda.set_device(device_id)
def use_optimizer(
    optimizer_name: str,
    network: torch.nn.Module,
    learning_rate: float,
    momentum: float = 0,
    weight_decay: float = 0,
    alpha: float = 0.99,
) -> Optimizer:
    """Build a ``torch.optim`` optimizer over ``network.parameters()``.

    Args:
        optimizer_name: one of ``"sgd"``, ``"adam"`` or ``"rmsprop"``
            (case-sensitive).
        network: module whose parameters are optimized.
        learning_rate: ``lr`` for the optimizer.
        momentum: used by SGD and RMSprop. Adam has no ``momentum``
            argument, so it is ignored there.
        weight_decay: L2 penalty, forwarded to every supported optimizer.
        alpha: RMSprop smoothing constant; used only by RMSprop.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if ``optimizer_name`` is not supported.
    """
    if optimizer_name == "sgd":
        return torch.optim.SGD(
            network.parameters(),
            lr=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay,
        )
    if optimizer_name == "adam":
        return torch.optim.Adam(
            network.parameters(), lr=learning_rate, weight_decay=weight_decay
        )
    if optimizer_name == "rmsprop":
        # weight_decay was previously dropped silently for RMSprop.
        return torch.optim.RMSprop(
            network.parameters(),
            lr=learning_rate,
            alpha=alpha,
            momentum=momentum,
            weight_decay=weight_decay,
        )
    raise ValueError(f"Optimizer '{optimizer_name}' is not supported")