galaxy.train
============

.. py:module:: galaxy.train


Classes
-------

.. autoapisummary::

   galaxy.train.Trainer
   galaxy.train.Predictor


Module Contents
---------------

.. py:class:: Trainer(model_name: str, model: torch.nn.Module, optimizer_name: str, optimizer: torch.optim.Optimizer, train_dataloader: torch.utils.data.DataLoader, val_dataloader: torch.utils.data.DataLoader, experiment: comet_ml.Experiment, criterion: Optional[torch.nn.Module] = None, lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, lr_scheduler_type: str = 'per_epoch', batch_size: int = 128)

   Trainer class for managing model training, validation, and testing.


   .. py:attribute:: model_name


   .. py:attribute:: model


   .. py:attribute:: optimizer_name


   .. py:attribute:: optimizer


   .. py:attribute:: criterion


   .. py:attribute:: lr_scheduler


   .. py:attribute:: lr_scheduler_type


   .. py:attribute:: train_dataloader


   .. py:attribute:: val_dataloader


   .. py:attribute:: experiment


   .. py:attribute:: batch_size


   .. py:attribute:: train_table_data
      :type:  list
      :value: []


   .. py:attribute:: val_table_data
      :type:  list
      :value: []


   .. py:attribute:: history


   .. py:attribute:: device


   .. py:attribute:: global_step
      :value: 0


   .. py:attribute:: cache


   .. py:method:: post_train_batch() -> None

      Post-processing after each training batch.


   .. py:method:: post_val_batch()


   .. py:method:: post_train_stage()


   .. py:method:: post_val_stage()


   .. py:method:: save_checkpoint()


   .. py:method:: log_metrics(loss: float, acc: float, mode: str, step: Optional[int] = None, epoch: Optional[int] = None) -> None

      Logs metrics to Comet.ml.

      Args:
          loss (float): Loss value.
          acc (float): Accuracy value.
          mode (str): Mode ('train' or 'val').
          step (Optional[int]): Step number. Defaults to None.
          epoch (Optional[int]): Epoch number. Defaults to None.


   .. py:method:: train(num_epochs: int) -> None

      Trains the model for a specified number of epochs.

      Args:
          num_epochs (int): Number of epochs to train.


   .. py:method:: test(test_dataloader: torch.utils.data.DataLoader) -> Tuple[pandas.DataFrame, List[float], List[float]]

      Evaluates the model on a test dataset.

      Args:
          test_dataloader (DataLoader): DataLoader containing the test dataset.

      Returns:
          Tuple[pd.DataFrame, List[float], List[float]]:
              - Predictions DataFrame with columns for true labels, predicted labels, probabilities, and metadata.
              - List of loss values for each batch.
              - List of accuracy values for each batch.


   .. py:method:: compute_all(batch: dict) -> tuple

      Computes logits, loss, and accuracy for a batch.

      Args:
          batch (dict): Input batch containing images and labels.

      Returns:
          tuple: Logits, outputs, labels, loss, and accuracy.


   .. py:method:: cache_states() -> dict

      Caches the current states of the model and optimizer.

      Returns:
          dict: Dictionary containing model and optimizer states.


   .. py:method:: rollback_states() -> None

      Rolls back the model and optimizer to the cached states.


   .. py:method:: find_lr(min_lr: float = 1e-06, max_lr: float = 0.1, num_lrs: int = 20, smoothing_window: int = 30, smooth_beta: float = 0.8) -> float

      Finds the optimal learning rate using a range test.

      Args:
          min_lr (float, optional): Minimum learning rate. Defaults to 1e-6.
          max_lr (float, optional): Maximum learning rate. Defaults to 1e-1.
          num_lrs (int, optional): Number of learning rates to test. Defaults to 20.
          smoothing_window (int, optional): Window size for smoothing. Defaults to 30.
          smooth_beta (float, optional): Beta value for loss smoothing. Defaults to 0.8.

      Returns:
          float: Optimal learning rate.


.. py:class:: Predictor(model: torch.nn.Module, device)

   Predictor class for evaluating models on test data.


   .. py:attribute:: model


   .. py:attribute:: device


   .. py:method:: predict(dataloader: torch.utils.data.DataLoader) -> pandas.DataFrame

      Generates predictions for the given DataLoader.

      Args:
          dataloader (DataLoader): DataLoader containing test data.

      Returns:
          pd.DataFrame: Predictions with true labels and additional metadata.


   .. py:method:: compute_all(batch: dict) -> tuple

      Computes logits and outputs for a batch.

      Args:
          batch (dict): Input batch.

      Returns:
          tuple: Logits and outputs.
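

The ``find_lr`` parameters documented for ``Trainer`` (``min_lr``, ``max_lr``, ``num_lrs``, ``smooth_beta``) are typical of a learning-rate range test. The sketch below is not the ``galaxy.train`` implementation, which this reference does not show. It is a dependency-free illustration that assumes log-spaced candidate rates, bias-corrected exponential smoothing of the recorded losses, and a "steepest descent" selection rule. The names ``geometric_lrs``, ``smooth``, ``pick_lr``, and ``toy_loss`` are hypothetical, and ``smoothing_window`` is not modeled.

```python
import math
from typing import List


def geometric_lrs(min_lr: float, max_lr: float, num_lrs: int) -> List[float]:
    """Return num_lrs candidate learning rates evenly spaced on a log scale."""
    if num_lrs < 2:
        return [min_lr]
    ratio = (max_lr / min_lr) ** (1.0 / (num_lrs - 1))
    return [min_lr * ratio ** i for i in range(num_lrs)]


def smooth(losses: List[float], beta: float) -> List[float]:
    """Exponential moving average of losses with bias correction."""
    avg = 0.0
    smoothed = []
    for step, loss in enumerate(losses, start=1):
        avg = beta * avg + (1.0 - beta) * loss
        smoothed.append(avg / (1.0 - beta ** step))
    return smoothed


def pick_lr(lrs: List[float], losses: List[float], beta: float = 0.8) -> float:
    """Pick the rate at which the smoothed loss falls fastest per unit log(lr)."""
    smoothed = smooth(losses, beta)
    best_index, best_slope = 0, 0.0
    for i in range(1, len(lrs)):
        slope = (smoothed[i] - smoothed[i - 1]) / (math.log(lrs[i]) - math.log(lrs[i - 1]))
        if slope < best_slope:
            best_index, best_slope = i, slope
    return lrs[best_index]


def toy_loss(lr: float) -> float:
    """Hypothetical stand-in for one training step: the loss drops as lr grows, then diverges."""
    x = math.log10(lr)
    return 1.0 / (1.0 + math.exp(2.0 * (x + 4.0))) + math.exp(4.0 * (x + 1.5))


# Mirror the documented defaults: min_lr=1e-6, max_lr=0.1, num_lrs=20, smooth_beta=0.8.
lrs = geometric_lrs(1e-6, 0.1, 20)
losses = [toy_loss(lr) for lr in lrs]
suggested = pick_lr(lrs, losses, beta=0.8)
print(f"suggested learning rate: {suggested:.2e}")
```

In a real range test the losses would come from short training runs at each candidate rate, and the model and optimizer would presumably be restored between runs, which is the kind of job ``cache_states`` and ``rollback_states`` describe.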