
Training Procedure in VIFY

The training procedure in the vify repository is implemented in scripts/train.py. The script trains either a predefined model or a hyperparameter-tuned model, using the configuration in config/train_config.yaml. The sections below explain the training process, its components, and how to run it.


Overview

The training script is designed to:

  1. Load configurations using Hydra and OmegaConf.
  2. Prepare the training environment, including logging and output directories.
  3. Train a model using the specified dataset and hyperparameters.
  4. Optionally perform hyperparameter tuning.
  5. Save the trained model and relevant artifacts for evaluation and prediction.

Configuration

The training script relies on a configuration file located at config/train_config.yaml. This file defines all the parameters required for training. A detailed description of each parameter can be found in this config.


Training Workflow

The training process follows these steps:

  1. Load Configuration:

    • The configuration is loaded using Hydra and converted into a typed Python object, so that editors and type checkers can help with attribute names and types.
  2. Prepare Output Directory:

    • A unique output directory is created based on the current timestamp, project name, and task name.
  3. Set Up Logging:

    • If use_clearml is enabled, a ClearML logger is initialized for experiment tracking.
    • Otherwise, only standard Python logging is used.
  4. Feature Extraction:

    • If extract_features is set to true, features are extracted from the dataset using the extract_features_from_studies function. Otherwise, the features are loaded from disk.
  5. Hyperparameter Tuning (Optional):

    • If do_hp_tuning is enabled, the script performs hyperparameter tuning using the specified method.
  6. Model Evaluation:

    • If do_performance_eval is enabled, the evaluate function assesses the model architecture using repeated cross-validation.
  7. Model Saving:

    • If train_final_model is enabled, the train_and_save_model function trains the model and saves it together with its metadata.
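The way the configuration flags gate the workflow can be sketched as follows. This is only an illustration: planned_steps is a hypothetical helper, not part of scripts/train.py, and the step labels are made up for readability. The flag names are the ones used in this document.

```python
from __future__ import annotations


def planned_steps(
    extract_features: bool,
    do_hp_tuning: bool,
    do_performance_eval: bool,
    train_final_model: bool,
) -> list[str]:
    """Return the workflow steps that would run for a given set of flags.

    Hypothetical helper for illustration; the real script executes the
    steps directly instead of building a list.
    """
    # The first three steps always run.
    steps = ["load configuration", "prepare output directory", "set up logging"]
    # Features are either recomputed or read back from disk.
    steps.append("extract features" if extract_features else "load features from disk")
    if do_hp_tuning:
        steps.append("tune hyperparameters")
    if do_performance_eval:
        steps.append("evaluate with repeated cross-validation")
    if train_final_model:
        steps.append("train and save final model")
    return steps
```

For example, planned_steps(False, False, True, True) skips feature extraction and tuning but still evaluates and saves the model.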

Logging

The script supports two logging mechanisms:

  1. ClearML:

    • If use_clearml is enabled, the script initializes a ClearML task and logger using the get_task_and_logger function.
    • This allows for detailed experiment tracking, including metrics, hyperparameters, and artifacts.
  2. Standard Logging:

    • If ClearML is not used, the script sets up standard Python logging with the specified logging level.

Output Structure

The outputs of the training process are organized as follows:

<output_base_path>/<project_name>/<task_name>/<timestamp>/
├── <model_name>.pkl    # Trained model file
└── ...                 # Plots and feature files
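Because each run directory is named after a timestamp in the form YYYYMMDD_HHMMSS, the names sort chronologically, so the most recent run of a task can be found by sorting. The sketch below assumes the layout shown above; latest_run_dir is a hypothetical helper and does not exist in the repository.

```python
from __future__ import annotations

from pathlib import Path


def latest_run_dir(
    output_base_path: str, project_name: str, task_name: str
) -> Path | None:
    """Return the newest run directory of a task, or None if there is none.

    Hypothetical helper; assumes run directories are named YYYYMMDD_HHMMSS,
    so lexicographic order equals chronological order.
    """
    task_dir = Path(output_base_path) / project_name / task_name
    if not task_dir.is_dir():
        return None
    # Keep only directories; stray files next to the runs are ignored.
    runs = sorted(p for p in task_dir.iterdir() if p.is_dir())
    return runs[-1] if runs else None
```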

How to Run the Training Script

To run the training script, use the following command:

python scripts/train.py

Optional Arguments

  • You can override configuration parameters directly from the command line. For example:
python scripts/train.py train_final_model=True

Code Walkthrough

Below is a detailed explanation of the key components in scripts/train.py:

1. Configuration Loading

@hydra.main(
    version_base=None,
    config_path="../config",
    config_name="train_config",
)
def run_training(omega_config: OmegaConf) -> None:
    config = cast(BaseConfig, OmegaConf.to_object(omega_config))

  • The @hydra.main decorator loads the configuration file and passes it to run_training.
  • OmegaConf.to_object converts the configuration into a plain Python object, and cast tells the type checker to treat it as a BaseConfig.
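The benefit of a typed configuration object can be illustrated with a standard-library dataclass. This is a minimal sketch, not the repository's BaseConfig: TrainConfigSketch covers only a few of the fields mentioned in this document, its default values are placeholders, and config_from_dict is a hypothetical stand-in for what Hydra and OmegaConf do.

```python
from __future__ import annotations

from dataclasses import dataclass, fields
from typing import Any


@dataclass
class TrainConfigSketch:
    """Hypothetical subset of the training configuration (not BaseConfig)."""

    output_base_path: str
    project_name: str
    task_name: str
    use_clearml: bool = False       # placeholder default
    logging_level: str = "INFO"     # placeholder default
    train_final_model: bool = False  # placeholder default


def config_from_dict(raw: dict[str, Any]) -> TrainConfigSketch:
    """Build the typed config and reject keys the dataclass does not know."""
    known = {f.name for f in fields(TrainConfigSketch)}
    unknown = set(raw) - known
    if unknown:
        raise KeyError(f"Unknown config keys: {sorted(unknown)}")
    return TrainConfigSketch(**raw)


config = config_from_dict(
    {"output_base_path": "outputs", "project_name": "demo", "task_name": "train"}
)
```

With an object like this, a misspelled attribute such as config.task_nam is flagged by the editor or type checker instead of failing at run time.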

2. Output Directory Creation

time = datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_path = os.path.join(
    config.output_base_path, config.project_name, config.task_name, time
)
os.makedirs(experiment_path, exist_ok=True)

  • A unique output directory is created based on the current timestamp, project name, and task name.

3. Logging Setup

if config.use_clearml:
    _, clearml_logger = get_task_and_logger(
        project_name=config.project_name,
        task_name=config.task_name,
        add_datetime=False,
        logging_level=config.logging_level,
    )
else:
    clearml_logger = None
    logging.basicConfig(level=config.logging_level)

  • If ClearML is enabled, a ClearML logger is initialized.
  • Otherwise, standard Python logging is configured.
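The fallback branch can be tried in isolation with the standard library alone. In the sketch below, configure_standard_logging and the logger name "vify.train" are hypothetical. The force=True argument is an addition that replaces any handlers already attached to the root logger; the script itself does not use it.

```python
from __future__ import annotations

import logging


def configure_standard_logging(logging_level: str | int) -> logging.Logger:
    """Configure root logging and return a named logger.

    basicConfig accepts a level name such as "INFO" or a numeric level.
    force=True (an assumption, not used in the script) resets existing
    handlers so that repeated calls take effect.
    """
    logging.basicConfig(level=logging_level, force=True)
    return logging.getLogger("vify.train")  # hypothetical logger name


logger = configure_standard_logging("DEBUG")
logger.debug("Logging configured without ClearML.")
```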

4. Feature Extraction

if config.extract_features:
    for image_type in config.image_types:
        extract_features_from_studies(
            studies_path=config.studies_path,
            patient_csv=config.patient_csv,
            output_path=experiment_path,
            segmentation_method=config.lesion_segmentation_method.id,
            queries=config.data_queries.kwargs[image_type],
            image_type=image_type,
            feature_extraction_methods=config.feature_extraction_methods,
            lesion_cutoff=config.lesion_segmentation_method.softmax_cut_off,
            dilation_factor=config.lesion_segmentation_method.dilation_factor,
        )

  • Features are extracted once per image type, using either mirp or simple features as specified in feature_extraction_methods.

5. Hyperparameter Tuning

if config.hp_tuning_method == "grid":
    parameter_space: dict[str, Any] = (
        PARAM_SPACES[model_type]["grid"]["general"]
        | PARAM_SPACES[model_type]["grid"][image_type]
    )
    results = run_grid_search(
        parameter_space=parameter_space, obj_func=obj_func
    )
else:
    parameter_space: dict[str, Any] = (
        PARAM_SPACES[model_type]["advanced"]["general"]
        | PARAM_SPACES[model_type]["advanced"][image_type]
    )
    results = run_advanced_search(
        parameter_space=parameter_space,
        obj_func=obj_func,
        max_eval=config.max_eval,
        search_algo=config.hp_tuning_method,
        model_type=model_type,
    )

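  • Grid search, Bayesian search, and random search can be run over the parameter space. Any hp_tuning_method other than "grid" is passed to run_advanced_search as search_algo.
  • Parameter search spaces are defined in src/vify/_config/parameter_search_configs.py.
  • After the search, the model is run with the optimized parameters.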
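The two ideas in this step, merging a general space with an image-type-specific space and then searching it exhaustively, can be sketched with the standard library. Everything named here is hypothetical: PARAM_SPACES_SKETCH, the "random_forest" and "pet" keys, the parameter values, and grid_search_sketch, which is not the repository's run_grid_search. The sketch also assumes that a higher objective value is better.

```python
from __future__ import annotations

from itertools import product
from typing import Any, Callable

# Hypothetical search space; the real one lives in parameter_search_configs.py.
PARAM_SPACES_SKETCH = {
    "random_forest": {
        "grid": {
            "general": {"n_estimators": [100, 300], "max_depth": [3, 5]},
            "pet": {"max_depth": [3, 5, 8]},
        }
    }
}


def grid_search_sketch(
    parameter_space: dict[str, list[Any]],
    obj_func: Callable[[dict[str, Any]], float],
) -> tuple[dict[str, Any], float]:
    """Evaluate every combination and return the best one (higher is better)."""
    names = list(parameter_space)
    best_params: dict[str, Any] = {}
    best_score = float("-inf")
    for values in product(*(parameter_space[name] for name in names)):
        params = dict(zip(names, values))
        score = obj_func(params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


# With the | operator, entries from the right-hand dict win on key
# collisions, so the image-type space overrides the general max_depth grid.
space = (
    PARAM_SPACES_SKETCH["random_forest"]["grid"]["general"]
    | PARAM_SPACES_SKETCH["random_forest"]["grid"]["pet"]
)


def toy_objective(params: dict[str, Any]) -> float:
    """Stand-in objective that peaks at n_estimators=300, max_depth=8."""
    return -abs(params["n_estimators"] - 300) - abs(params["max_depth"] - 8)


best_params, best_score = grid_search_sketch(space, toy_objective)
```

A grid search evaluates every combination, so its cost grows with the product of the list lengths. That is presumably why the advanced search takes a max_eval budget instead.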

6. Evaluation

if config.do_performance_eval:
    # Evaluate the model
    evaluate(
        csv=features_df,
        output_path=experiment_path,
        image_type=image_type,  # type: ignore
        feature_extraction_method=feature_extraction_method,
        model=model,
        model_type=model_type,
        clearml_logger=clearml_logger,
        features_used=features_used,
        use_individual_lesions=config.model_params.use_individual_lesions,
        num_cv_repeats=config.num_cv_repeats,
        exclude_num_first_columns=config.exclude_num_first_columns,
        num_cv_folds=config.num_cv_folds,
        num_workers=config.num_workers,
        plot_dict=config.plots_params.kwargs,
    )

  • Evaluates the model using repeated cross-validation, with num_cv_folds folds repeated num_cv_repeats times.
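Repeated cross-validation reshuffles the data before each repeat and then runs an ordinary k-fold split, which gives num_cv_folds × num_cv_repeats train/test splits in total. The sketch below shows only the index bookkeeping. repeated_kfold_indices is a hypothetical function, and the repository's evaluate function may split differently, for example with stratification or by grouping lesions per patient.

```python
from __future__ import annotations

import random
from typing import Iterator


def repeated_kfold_indices(
    n_samples: int,
    num_cv_folds: int,
    num_cv_repeats: int,
    seed: int = 0,
) -> Iterator[tuple[list[int], list[int]]]:
    """Yield (train_indices, test_indices) for each fold of each repeat."""
    rng = random.Random(seed)
    for _ in range(num_cv_repeats):
        # A fresh shuffle per repeat gives different fold memberships.
        indices = list(range(n_samples))
        rng.shuffle(indices)
        for fold in range(num_cv_folds):
            # Every num_cv_folds-th sample, starting at `fold`, is held out.
            test = indices[fold::num_cv_folds]
            held_out = set(test)
            train = [i for i in indices if i not in held_out]
            yield train, test


splits = list(
    repeated_kfold_indices(n_samples=10, num_cv_folds=5, num_cv_repeats=3)
)
```

Within one repeat, every sample is held out exactly once. Averaging the metric over all splits reduces the dependence on any single shuffle.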

7. Model Saving

if config.train_final_model:
    train_and_save_model(
        csv=features_df,
        model=model,
        model_type=model_type,
        features_extraction_method=feature_extraction_method,
        image_type=image_type,
        output_path=experiment_path,
        features_used=features_used,
        use_individual_lesions=config.model_params.use_individual_lesions,
        exclude_num_first_columns=config.exclude_num_first_columns,
    )

  • If train_final_model is enabled, trains the final model and saves it, together with its metadata, to the experiment directory.
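Since the output directory contains a <model_name>.pkl file, saving and reloading can be sketched with the standard pickle module. The on-disk layout used here, a dictionary that bundles the model with its metadata, is an assumption made for illustration. The actual format written by train_and_save_model is not documented here, and save_model_sketch and load_model_sketch are hypothetical names.

```python
from __future__ import annotations

import pickle
from pathlib import Path
from typing import Any


def save_model_sketch(
    model: Any, model_name: str, output_path: str, metadata: dict[str, Any]
) -> Path:
    """Pickle the model and its metadata to <output_path>/<model_name>.pkl."""
    out_dir = Path(output_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    model_file = out_dir / f"{model_name}.pkl"
    with model_file.open("wb") as handle:
        # Assumed layout: one dict holding both the model and its metadata.
        pickle.dump({"model": model, "metadata": metadata}, handle)
    return model_file


def load_model_sketch(model_file: Path) -> tuple[Any, dict[str, Any]]:
    """Load a file written by save_model_sketch."""
    with model_file.open("rb") as handle:
        payload = pickle.load(handle)
    return payload["model"], payload["metadata"]
```

Unpickling can execute arbitrary code, so only load model files from a trusted source.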

...

This documentation provides a comprehensive guide to the training procedure in the vify repository. For further details, refer to the source code in scripts/train.py and the configuration file at config/train_config.yaml.