galaxy.train

Attributes

IS_CLUSTER_THRESHOLD

Classes

Trainer

Trainer class for managing model training, validation, and testing.

Predictor

Predictor class for evaluating models on test data.

Module Contents

galaxy.train.IS_CLUSTER_THRESHOLD = 0.5
class galaxy.train.Trainer(model_name: str, model: torch.nn.Module, optimizer_name: str, optimizer: torch.optim.Optimizer, train_dataloader: torch.utils.data.DataLoader, val_dataloader: torch.utils.data.DataLoader, experiment: comet_ml.Experiment, criterion: torch.nn.Module | None = None, lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, lr_scheduler_type: str = 'per_epoch', batch_size: int = 128)

Trainer class for managing model training, validation, and testing.

model_name
model
optimizer_name
optimizer
criterion = None
lr_scheduler = None
lr_scheduler_type = 'per_epoch'
train_dataloader
val_dataloader
experiment
batch_size = 128
train_table_data: list = []
val_table_data: list = []
history
device
global_step = 0
cache
post_train_batch() -> None

Post-processing after each training batch.

post_val_batch()
post_train_stage()
post_val_stage()
save_checkpoint()
log_metrics(loss: float, acc: float, mode: str, step: int | None = None, epoch: int | None = None) -> None

Logs metrics to Comet.ml.

Args:

loss (float): Loss value.
acc (float): Accuracy value.
mode (str): Mode ('train' or 'val').
step (Optional[int]): Step number. Defaults to None.
epoch (Optional[int]): Epoch number. Defaults to None.
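The following is a minimal sketch of how a method with this signature might forward values to Comet. It is not the library's implementation. The metric names ("<mode>_loss", "<mode>_accuracy"), the mode check, and the RecordingExperiment stand-in are assumptions made for illustration. The only Comet call mirrored here is Experiment.log_metric(name, value, step=..., epoch=...).

```python
from typing import Optional


class RecordingExperiment:
    """Hypothetical stand-in for comet_ml.Experiment that records calls."""

    def __init__(self) -> None:
        self.logged = []

    def log_metric(self, name, value, step=None, epoch=None):
        self.logged.append((name, value, step, epoch))


def log_metrics_sketch(experiment, loss: float, acc: float, mode: str,
                       step: Optional[int] = None,
                       epoch: Optional[int] = None) -> None:
    # Assumption: the mode string is used as a prefix for the metric names.
    if mode not in ("train", "val"):
        raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")
    experiment.log_metric(f"{mode}_loss", loss, step=step, epoch=epoch)
    experiment.log_metric(f"{mode}_accuracy", acc, step=step, epoch=epoch)


exp = RecordingExperiment()
log_metrics_sketch(exp, loss=0.42, acc=0.81, mode="train", step=10)
```

Passing a real comet_ml.Experiment in place of the stand-in should work the same way, since only log_metric is called.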

train(num_epochs: int) -> Tuple[List[Tuple[int, float, float]], List[Tuple[int, float, float]]]

Trains the model for a specified number of epochs.

Args:

num_epochs (int): Number of epochs to train.
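The return annotation is a pair of lists of (int, float, float) tuples, but the docstring does not say what they contain. The sketch below assumes the first list is the training history, the second is the validation history, and each tuple is (epoch, loss, accuracy). The numbers are made up and the helper best_epoch is hypothetical.

```python
from typing import List, Tuple

History = List[Tuple[int, float, float]]


def best_epoch(val_history: History) -> Tuple[int, float, float]:
    """Return the entry with the lowest loss (assumes (epoch, loss, acc) order)."""
    if not val_history:
        raise ValueError("val_history is empty")
    return min(val_history, key=lambda entry: entry[1])


# Made-up numbers standing in for the two lists returned by train().
train_history: History = [(0, 0.90, 0.55), (1, 0.60, 0.70), (2, 0.45, 0.80)]
val_history: History = [(0, 0.95, 0.52), (1, 0.70, 0.66), (2, 0.75, 0.68)]

epoch, loss, acc = best_epoch(val_history)
```

Comparing the two histories this way is a common check for overfitting: here the training loss keeps falling while the validation loss rises after epoch 1.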

test(test_dataloader: torch.utils.data.DataLoader) -> Tuple[pandas.DataFrame, List[float], List[float]]

Evaluates the model on a test dataset.

Args:

test_dataloader (DataLoader): DataLoader containing the test dataset.

Returns:
Tuple[pd.DataFrame, List[float], List[float]]:
  • Predictions DataFrame with columns for true labels, predicted labels, probabilities, and metadata.

  • List of loss values for each batch.

  • List of accuracy values for each batch.

compute_all(batch: dict) -> tuple

Computes logits, loss, and accuracy for a batch.

Args:

batch (dict): Input batch containing images and labels.

Returns:

tuple: Logits, outputs, labels, loss, and accuracy.
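The docstring does not describe how outputs and accuracy are derived from the logits. One plausible reading, given the module constant IS_CLUSTER_THRESHOLD = 0.5, is a binary classifier whose sigmoid outputs are thresholded. The framework-free sketch below illustrates that reading only. The sigmoid step, the ">=" comparison, and the helper names are assumptions.

```python
import math
from typing import List, Tuple

IS_CLUSTER_THRESHOLD = 0.5  # value taken from the module constant


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def binary_accuracy(logits: List[float], labels: List[int],
                    threshold: float = IS_CLUSTER_THRESHOLD
                    ) -> Tuple[List[float], float]:
    """Return (probabilities, accuracy) for one batch of scalar logits."""
    probs = [sigmoid(z) for z in logits]
    # Assumption: a probability at or above the threshold counts as class 1.
    preds = [1 if p >= threshold else 0 for p in probs]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return probs, correct / len(labels)


probs, acc = binary_accuracy([2.0, -1.0, 0.3, -0.2], [1, 0, 0, 0])
```

In this toy batch the third sample (logit 0.3) is the only one classified incorrectly.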

cache_states() -> dict

Caches the current states of the model and optimizer.

Returns:

dict: Dictionary containing model and optimizer states.

rollback_states() -> None

Rolls back the model and optimizer to the cached states.
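The cache and rollback pair can be illustrated with a small framework-free sketch. It assumes the usual state_dict() / load_state_dict() protocol that PyTorch modules and optimizers expose. The TinyStateful class and the dictionary keys "model_state" and "optimizer_state" are hypothetical. Deep copies matter here because state_dict() in PyTorch returns references to the live tensors, so a shallow snapshot would change as training continues.

```python
import copy


class TinyStateful:
    """Hypothetical stand-in exposing state_dict()/load_state_dict()."""

    def __init__(self, **state):
        self._state = dict(state)

    def state_dict(self):
        return self._state

    def load_state_dict(self, state):
        self._state = copy.deepcopy(state)


def cache_states(model, optimizer) -> dict:
    # Deep copies so later updates do not mutate the snapshot.
    return {
        "model_state": copy.deepcopy(model.state_dict()),
        "optimizer_state": copy.deepcopy(optimizer.state_dict()),
    }


def rollback_states(model, optimizer, cache: dict) -> None:
    model.load_state_dict(cache["model_state"])
    optimizer.load_state_dict(cache["optimizer_state"])


model = TinyStateful(weight=[1.0, 2.0])
optimizer = TinyStateful(lr=0.01)

cache = cache_states(model, optimizer)
model.state_dict()["weight"][0] = 99.0  # simulate a training update
optimizer.state_dict()["lr"] = 0.1      # simulate a changed learning rate
rollback_states(model, optimizer, cache)
```

This pattern is typically useful around exploratory procedures such as a learning-rate search, where the model should return to its starting point afterwards.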

find_lr(min_lr: float = 1e-06, max_lr: float = 0.1, num_lrs: int = 20, smoothing_window: int = 30, smooth_beta: float = 0.8) -> float

Finds the optimal learning rate using a range test.

Args:

min_lr (float, optional): Minimum learning rate. Defaults to 1e-6.
max_lr (float, optional): Maximum learning rate. Defaults to 1e-1.
num_lrs (int, optional): Number of learning rates to test. Defaults to 20.
smoothing_window (int, optional): Window size for smoothing. Defaults to 30.
smooth_beta (float, optional): Beta value for loss smoothing. Defaults to 0.8.

Returns:

float: Optimal learning rate.
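The docstring does not spell out how the range test selects its result. The sketch below shows one common variant, under stated assumptions: candidate rates are spaced geometrically between min_lr and max_lr, losses are smoothed with a bias-corrected exponential moving average using smooth_beta, and the rate where the smoothed loss falls fastest is returned. The selection rule, the helper names, and the synthetic toy_loss curve are all assumptions. smoothing_window is left out because its role is not described.

```python
import math
from typing import List


def geometric_lrs(min_lr: float, max_lr: float, num_lrs: int) -> List[float]:
    """num_lrs learning rates spaced evenly on a log scale (num_lrs >= 2)."""
    ratio = (max_lr / min_lr) ** (1.0 / (num_lrs - 1))
    return [min_lr * ratio ** i for i in range(num_lrs)]


def smooth(losses: List[float], beta: float) -> List[float]:
    """Exponential moving average with bias correction."""
    avg, out = 0.0, []
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1.0 - beta) * loss
        out.append(avg / (1.0 - beta ** i))
    return out


def pick_lr(lrs: List[float], losses: List[float], beta: float = 0.8) -> float:
    """Return the LR where the smoothed loss drops fastest per log-LR step."""
    smoothed = smooth(losses, beta)
    slopes = [
        (smoothed[i + 1] - smoothed[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
        for i in range(len(lrs) - 1)
    ]
    steepest = min(range(len(slopes)), key=slopes.__getitem__)
    return lrs[steepest]


def toy_loss(lr: float) -> float:
    # Synthetic curve: flat for tiny rates, improving, then diverging.
    return 1.0 / (1.0 + lr / 1e-3) + 50.0 * lr ** 2


lrs = geometric_lrs(1e-6, 0.1, 20)
losses = [toy_loss(lr) for lr in lrs]  # a real test would train briefly per LR
best_lr = pick_lr(lrs, losses, beta=0.8)
```

In a real range test each loss would come from a few training steps at that rate, with the model restored to its cached state afterwards.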

class galaxy.train.Predictor(model: torch.nn.Module, device)

Predictor class for evaluating models on test data.

model
device
predict(dataloader: torch.utils.data.DataLoader) -> pandas.DataFrame

Generates predictions for the given DataLoader.

Args:

dataloader (DataLoader): DataLoader containing test data.

Returns:

pd.DataFrame: Predictions with true labels and additional metadata.

compute_all(batch: dict) -> tuple

Computes logits and outputs for a batch.

Args:

batch (dict): Input batch.

Returns:

tuple: Logits and outputs.