galaxy.train
Module Contents
- galaxy.train.IS_CLUSTER_THRESHOLD = 0.5
- class galaxy.train.Trainer(model_name: str, model: torch.nn.Module, optimizer_name: str, optimizer: torch.optim.Optimizer, train_dataloader: torch.utils.data.DataLoader, val_dataloader: torch.utils.data.DataLoader, experiment: comet_ml.Experiment, criterion: torch.nn.Module | None = None, lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, lr_scheduler_type: str = 'per_epoch', batch_size: int = 128)
Trainer class for managing model training, validation, and testing.
- model_name
- model
- optimizer_name
- optimizer
- criterion = None
- lr_scheduler = None
- lr_scheduler_type = 'per_epoch'
- train_dataloader
- val_dataloader
- experiment
- batch_size = 128
- train_table_data: list = []
- val_table_data: list = []
- history
- device
- global_step = 0
- cache
- post_train_batch() → None
Post-processing after each training batch.
- post_val_batch()
- post_train_stage()
- post_val_stage()
- save_checkpoint()
- log_metrics(loss: float, acc: float, mode: str, step: int | None = None, epoch: int | None = None) → None
Logs metrics to Comet.ml.
- Args:
loss (float): Loss value.
acc (float): Accuracy value.
mode (str): Mode ('train' or 'val').
step (Optional[int]): Step number. Defaults to None.
epoch (Optional[int]): Epoch number. Defaults to None.
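A minimal sketch of the kind of call `log_metrics` wraps. It assumes metrics are logged under names such as `train_loss` and `val_acc`; that naming scheme is a hypothetical choice, not taken from the source. The stand-in class `RecordingExperiment` (also hypothetical) imitates the shape of Comet's `Experiment.log_metric(name, value, step=..., epoch=...)` so the sketch runs without a Comet account.

```python
from typing import List, Optional, Tuple


class RecordingExperiment:
    """Hypothetical stand-in mimicking comet_ml.Experiment.log_metric."""

    def __init__(self) -> None:
        self.records: List[Tuple[str, float, Optional[int], Optional[int]]] = []

    def log_metric(self, name: str, value: float, step: Optional[int] = None,
                   epoch: Optional[int] = None) -> None:
        self.records.append((name, value, step, epoch))


def log_metrics(experiment, loss: float, acc: float, mode: str,
                step: Optional[int] = None, epoch: Optional[int] = None) -> None:
    # Assumed naming scheme: "<mode>_loss" and "<mode>_acc".
    if mode not in ("train", "val"):
        raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")
    experiment.log_metric(f"{mode}_loss", loss, step=step, epoch=epoch)
    experiment.log_metric(f"{mode}_acc", acc, step=step, epoch=epoch)


exp = RecordingExperiment()
log_metrics(exp, loss=0.42, acc=0.9, mode="train", step=10, epoch=1)
print(exp.records)
```

Passing both `step` and `epoch` lets Comet plot the same metric against either axis; in the Trainer itself the experiment is held on the instance rather than passed in.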
- train(num_epochs: int) → Tuple[List[Tuple[int, float, float]], List[Tuple[int, float, float]]]
Trains the model for a specified number of epochs.
- Args:
num_epochs (int): Number of epochs to train.
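A minimal sketch of an epoch loop that matches the shape of the `train` return type. It assumes each history entry is `(epoch, mean_loss, mean_accuracy)` and that `lr_scheduler_type='per_epoch'` means the scheduler steps once per epoch; the `'per_batch'` value and the helper `run_epoch` are hypothetical. Batches here are `(images, labels)` tuples from a toy `TensorDataset`, whereas the real Trainer works with dict batches.

```python
from typing import List, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

History = List[Tuple[int, float, float]]


def run_epoch(model, loader, criterion, optimizer=None, scheduler=None,
              scheduler_type="per_epoch"):
    """One pass over `loader`; trains when an optimizer is given, else evaluates."""
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            logits = model(images)
            loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if scheduler is not None and scheduler_type == "per_batch":
                    scheduler.step()  # hypothetical per-batch mode
            # Weight by batch size so the epoch mean is a per-sample mean.
            total_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, correct / seen


def train(model, train_loader, val_loader, criterion, optimizer, num_epochs,
          scheduler=None, scheduler_type="per_epoch") -> Tuple[History, History]:
    train_hist: History = []
    val_hist: History = []
    for epoch in range(1, num_epochs + 1):
        tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer,
                                    scheduler, scheduler_type)
        va_loss, va_acc = run_epoch(model, val_loader, criterion)
        if scheduler is not None and scheduler_type == "per_epoch":
            scheduler.step()  # one scheduler step per epoch
        train_hist.append((epoch, tr_loss, tr_acc))
        val_hist.append((epoch, va_loss, va_acc))
    return train_hist, val_hist


torch.manual_seed(0)
x = torch.randn(64, 4)
y = (x[:, 0] > 0).long()  # toy binary labels
loader = DataLoader(TensorDataset(x, y), batch_size=16)
model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)
train_hist, val_hist = train(model, loader, loader, nn.CrossEntropyLoss(), opt,
                             num_epochs=3, scheduler=sched)
print(train_hist[-1], val_hist[-1])
```

The placement of `scheduler.step()` is the main design choice: epoch-based schedulers such as `StepLR` expect one call per epoch, while warm-up or one-cycle schedules are usually stepped after every optimizer step.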
- test(test_dataloader: torch.utils.data.DataLoader) → Tuple[pandas.DataFrame, List[float], List[float]]
Evaluates the model on a test dataset.
- Args:
test_dataloader (DataLoader): DataLoader containing the test dataset.
- Returns:
- Tuple[pd.DataFrame, List[float], List[float]]:
Predictions DataFrame with columns for true labels, predicted labels, probabilities, and metadata.
List of loss values for each batch.
List of accuracy values for each batch.
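Because `test` returns per-batch losses and accuracies, a plain mean of those lists weights every batch equally, so a smaller final batch counts as much as a full one. A sketch of a sample-weighted aggregate, assuming each per-batch value is already a mean over its batch (as with a `reduction='mean'` criterion) and that batches are consecutive with `drop_last=False`; the helpers `batch_sizes` and `weighted_mean` are hypothetical.

```python
from typing import List, Sequence


def batch_sizes(num_samples: int, batch_size: int) -> List[int]:
    """Sizes of consecutive batches when drop_last=False."""
    full, rest = divmod(num_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


def weighted_mean(values: Sequence[float], sizes: Sequence[int]) -> float:
    """Sample-weighted mean of per-batch mean metrics."""
    if len(values) != len(sizes):
        raise ValueError("one size per batch value is required")
    return sum(v * n for v, n in zip(values, sizes)) / sum(sizes)


sizes = batch_sizes(num_samples=300, batch_size=128)  # [128, 128, 44]
accs = [0.90, 0.80, 0.50]
print(sum(accs) / len(accs))       # unweighted mean over batches
print(weighted_mean(accs, sizes))  # mean over samples
```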
- compute_all(batch: dict) → tuple
Computes logits, loss, and accuracy for a batch.
- Args:
batch (dict): Input batch containing images and labels.
- Returns:
tuple: Logits, outputs, labels, loss, and accuracy.
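A minimal sketch of what a `compute_all`-style helper can compute. It assumes a multi-class classifier, `nn.CrossEntropyLoss` as the criterion, softmax probabilities as the "outputs", and batch keys `"image"` and `"label"`; the key names and the softmax choice are assumptions, since the source only says the batch holds images and labels. Unlike the method, the sketch takes the model, criterion, and device as explicit arguments.

```python
from typing import Dict, Tuple

import torch
from torch import nn


def compute_all(model: nn.Module, criterion: nn.Module,
                batch: Dict[str, torch.Tensor], device: torch.device
                ) -> Tuple[torch.Tensor, ...]:
    # Hypothetical batch keys; adapt to the real dataset's dict layout.
    images = batch["image"].to(device)
    labels = batch["label"].to(device)
    logits = model(images)                  # raw scores, shape (N, C)
    outputs = torch.softmax(logits, dim=1)  # class probabilities
    loss = criterion(logits, labels)        # CrossEntropyLoss expects logits
    acc = (outputs.argmax(dim=1) == labels).float().mean()
    return logits, outputs, labels, loss, acc


torch.manual_seed(0)
model = nn.Linear(8, 3)
batch = {"image": torch.randn(5, 8), "label": torch.tensor([0, 1, 2, 1, 0])}
logits, outputs, labels, loss, acc = compute_all(
    model, nn.CrossEntropyLoss(), batch, torch.device("cpu"))
print(loss.item(), acc.item())
```

Passing logits rather than probabilities to `CrossEntropyLoss` matters: the loss applies log-softmax internally, so feeding it softmax outputs would apply softmax twice.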
- cache_states() → dict
Caches the current states of the model and optimizer.
- Returns:
dict: Dictionary containing model and optimizer states.
- rollback_states() → None
Rolls back the model and optimizer to the cached states.
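A sketch of the cache-and-rollback pattern, assuming the cache holds deep copies of `state_dict()` for the model and optimizer (the source only says the states are cached). The deep copy matters: `state_dict()` returns tensors that share storage with the live parameters and optimizer state, so a shallow snapshot would drift as training continues. The free functions below are illustrative stand-ins for the methods.

```python
import copy

import torch
from torch import nn


def cache_states(model: nn.Module, optimizer: torch.optim.Optimizer) -> dict:
    # deepcopy detaches the snapshot from the live parameter storage
    return {
        "model": copy.deepcopy(model.state_dict()),
        "optimizer": copy.deepcopy(optimizer.state_dict()),
    }


def rollback_states(model: nn.Module, optimizer: torch.optim.Optimizer,
                    cache: dict) -> None:
    model.load_state_dict(cache["model"])
    optimizer.load_state_dict(cache["optimizer"])


torch.manual_seed(0)
model = nn.Linear(3, 1)
opt = torch.optim.Adam(model.parameters(), lr=0.01)
cache = cache_states(model, opt)
before = model.weight.detach().clone()

loss = model(torch.randn(4, 3)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()                                # parameters change here
rollback_states(model, opt, cache)
print(torch.equal(model.weight, before))  # weights restored
```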
- find_lr(min_lr: float = 1e-06, max_lr: float = 0.1, num_lrs: int = 20, smoothing_window: int = 30, smooth_beta: float = 0.8) → float
Finds the optimal learning rate using a range test.
- Args:
min_lr (float, optional): Minimum learning rate. Defaults to 1e-6.
max_lr (float, optional): Maximum learning rate. Defaults to 1e-1.
num_lrs (int, optional): Number of learning rates to test. Defaults to 20.
smoothing_window (int, optional): Window size for smoothing. Defaults to 30.
smooth_beta (float, optional): Beta value for loss smoothing. Defaults to 0.8.
- Returns:
float: Optimal learning rate.
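A sketch of a learning-rate range test under stated assumptions: candidate rates are spaced geometrically between `min_lr` and `max_lr`, the recorded losses are smoothed with a bias-corrected exponential moving average using `smooth_beta`, and the rate where the smoothed loss falls fastest is returned. That selection heuristic is one common choice, not necessarily the one `find_lr` uses, and the sketch leaves out `smoothing_window`, whose exact role the source does not describe. The callback `loss_at_lr`, which should train briefly at a given rate and return the loss, and the synthetic `toy_loss` are hypothetical.

```python
import math
from typing import Callable, List


def lr_range_test(loss_at_lr: Callable[[float], float], min_lr: float = 1e-6,
                  max_lr: float = 0.1, num_lrs: int = 20,
                  smooth_beta: float = 0.8) -> float:
    if num_lrs < 2:
        raise ValueError("num_lrs must be at least 2")
    # Geometric grid: equal steps in log(lr).
    ratio = (max_lr / min_lr) ** (1.0 / (num_lrs - 1))
    lrs = [min_lr * ratio ** i for i in range(num_lrs)]

    smoothed: List[float] = []
    avg = 0.0
    for i, lr in enumerate(lrs, start=1):
        avg = smooth_beta * avg + (1.0 - smooth_beta) * loss_at_lr(lr)
        smoothed.append(avg / (1.0 - smooth_beta ** i))  # bias correction

    # Steepest descent: most negative change between neighbouring rates.
    best = min(range(1, num_lrs), key=lambda i: smoothed[i] - smoothed[i - 1])
    return lrs[best]


def toy_loss(lr: float) -> float:
    """Synthetic curve: falls around lr ~ 1e-3, diverges above ~3e-2."""
    x = math.log10(lr)
    return 1.0 / (1.0 + math.exp(2.0 * (x + 3.0))) + (5.0 if x > -1.5 else 0.0)


print(lr_range_test(toy_loss))
```

The EMA lags the raw curve, so the returned rate sits somewhat past the raw steepest point; with `smooth_beta=0.8` the lag is a few grid steps. In a real test the model and optimizer would be restored afterwards, which is the kind of job `cache_states` and `rollback_states` are suited to.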
- class galaxy.train.Predictor(model: torch.nn.Module, device)
Predictor class for evaluating models on test data.
- model
- device
- predict(dataloader: torch.utils.data.DataLoader) → pandas.DataFrame
Generates predictions for the given DataLoader.
- Args:
dataloader (DataLoader): DataLoader containing test data.
- Returns:
pd.DataFrame: Predictions with true labels and additional metadata.
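A sketch of how per-batch outputs can be gathered into a predictions DataFrame. It assumes softmax outputs, batch keys `"image"`, `"label"` and `"id"`, and column names `true_label`, `pred_label` and `prob` (the probability of the predicted class); all of these names are hypothetical, since the source only says the frame holds true labels and additional metadata. A list of dict batches stands in for a DataLoader that yields dicts.

```python
from typing import Dict, Iterable

import pandas as pd
import torch
from torch import nn


@torch.no_grad()  # inference only: no autograd graph is built
def predict(model: nn.Module, batches: Iterable[Dict[str, torch.Tensor]],
            device: torch.device) -> pd.DataFrame:
    model.eval()
    rows = []
    for batch in batches:
        logits = model(batch["image"].to(device))
        probs = torch.softmax(logits, dim=1).cpu()
        preds = probs.argmax(dim=1)
        for sample_id, label, pred, p in zip(batch["id"].tolist(),
                                             batch["label"].tolist(),
                                             preds.tolist(), probs.tolist()):
            rows.append({"id": sample_id, "true_label": label,
                         "pred_label": pred, "prob": p[pred]})
    return pd.DataFrame(rows)


torch.manual_seed(0)
model = nn.Linear(4, 2)
batches = [
    {"image": torch.randn(3, 4), "label": torch.tensor([0, 1, 1]),
     "id": torch.tensor([10, 11, 12])},
    {"image": torch.randn(2, 4), "label": torch.tensor([1, 0]),
     "id": torch.tensor([13, 14])},
]
df = predict(model, batches, torch.device("cpu"))
print(df)
```

Collecting plain Python rows and building the frame once at the end avoids repeatedly concatenating DataFrames inside the loop.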
- compute_all(batch: dict) → tuple
Computes logits and outputs for a batch.
- Args:
batch (dict): Input batch.
- Returns:
tuple: Logits and outputs.