Training Procedure in VIFY
The training procedure in the vify repository is implemented in the script scripts/train.py. This script trains a predefined or hyperparameter-tuned model using the configuration specified in config/train_config.yaml. The sections below explain the training process, its components, and how to run it.
Overview
The training script is designed to:
- Load configurations using Hydra and OmegaConf.
- Prepare the training environment, including logging and output directories.
- Train a model using the specified dataset and hyperparameters.
- Optionally perform hyperparameter tuning.
- Save the trained model and relevant artifacts for evaluation and prediction.
Configuration
The training script relies on a configuration file located at config/train_config.yaml. This file defines all the parameters required for training. A detailed description of each parameter can be found in this config.
Training Workflow
The training process follows these steps:
- Load Configuration:
  - The configuration is loaded using Hydra and converted into a typed Python object, so that editors and type checkers can help with attribute access.
- Prepare Output Directory:
  - A unique output directory is created based on the current timestamp, project name, and task name.
- Set Up Logging:
  - If use_clearml is enabled, a ClearML logger is initialized for experiment tracking.
  - Otherwise, only standard Python logging is used.
- Feature Extraction:
  - If extract_features is set to true, features are extracted from the dataset using the extract_features_from_studies function; otherwise, the features are loaded from disk.
- Hyperparameter Tuning (Optional):
  - If do_hp_tuning is enabled, the script performs hyperparameter tuning using the specified method.
- Model Evaluation:
  - If do_performance_eval is enabled, the model architecture is evaluated using the evaluate function, which runs a repeated cross-validation.
- Model Saving:
  - If train_final_model is enabled, the model is trained using the train_and_save_model function, which saves the trained model and its metadata.
Logging
The script supports two logging mechanisms:
- ClearML:
  - If use_clearml is enabled, the script initializes a ClearML task and logger using the get_task_and_logger function.
  - This allows for detailed experiment tracking, including metrics, hyperparameters, and artifacts.
- Standard Logging:
  - If ClearML is not used, the script sets up standard Python logging with the specified logging level.
Output Structure
The outputs of the training process are organized as follows:
<output_base_path>/<project_name>/<task_name>/<timestamp>/
├── <model_name>.pkl # Trained model file
├── ... plots and feature files
How to Run the Training Script
To run the training script, use the following command:
python scripts/train.py
Optional Arguments
- You can override configuration parameters directly from the command line. For example:
python scripts/train.py train_final_model=True
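To make the effect of such an override concrete, the sketch below imitates the merge with plain dictionaries. The apply_overrides helper and the sample default values are hypothetical and exist only for illustration; Hydra's own override parser supports considerably more syntax than this.

```python
import ast
from copy import deepcopy
from typing import Any


def apply_overrides(config: dict[str, Any], overrides: list[str]) -> dict[str, Any]:
    """Return a copy of `config` with `key=value` overrides applied.

    Dotted keys address nested dictionaries. This is a simplified
    illustration, not Hydra's parser.
    """
    result = deepcopy(config)
    for item in overrides:
        key, _, raw_value = item.partition("=")
        try:
            # Interpret Python-style literals such as True, 5 or 0.1.
            value: Any = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            # Anything else is kept as a plain string.
            value = raw_value
        *parents, leaf = key.split(".")
        node = result
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return result


# Hypothetical defaults standing in for config/train_config.yaml.
defaults = {"train_final_model": False, "num_cv_folds": 5}
overridden = apply_overrides(defaults, ["train_final_model=True", "num_cv_folds=10"])
```

The defaults are left untouched and only the returned copy carries the overridden values, which mirrors the idea that the YAML file stays the baseline while a single run deviates from it.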
Code Walkthrough
Below is a detailed explanation of the key components in scripts/train.py:
1. Configuration Loading
@hydra.main(
    version_base=None,
    config_path="../config",
    config_name="train_config",
)
def run_training(omega_config: OmegaConf) -> None:
    config = cast(BaseConfig, OmegaConf.to_object(omega_config))
- The @hydra.main decorator loads the configuration file.
- The configuration is converted into a Python object for easier manipulation.
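The benefit of a typed object over a raw mapping can be illustrated with a standard-library dataclass. This is a minimal sketch, not the repository's BaseConfig: the class TrainConfigSketch, the helper config_from_dict, and the default values are assumptions, and only the field names are taken from this guide.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class TrainConfigSketch:
    """Hypothetical stand-in for BaseConfig with a few fields named in this guide."""

    project_name: str
    task_name: str
    use_clearml: bool = False
    train_final_model: bool = False
    num_cv_folds: int = 5  # assumed default, for illustration only


def config_from_dict(raw: dict[str, Any]) -> TrainConfigSketch:
    """Build the typed config and reject keys the dataclass does not declare."""
    known = {f.name for f in fields(TrainConfigSketch)}
    unknown = set(raw) - known
    if unknown:
        raise KeyError(f"Unknown config keys: {sorted(unknown)}")
    return TrainConfigSketch(**raw)


config = config_from_dict(
    {"project_name": "demo", "task_name": "baseline", "use_clearml": True}
)
```

With a typed object, a misspelled key fails at load time instead of silently falling back to a default later in the run.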
2. Output Directory Creation
time = datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_path = os.path.join(
    config.output_base_path, config.project_name, config.task_name, time
)
os.makedirs(experiment_path, exist_ok=True)
- A unique output directory is created based on the current timestamp, project name, and task name.
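The same directory layout can be reproduced in isolation with pathlib. The sketch below is an equivalent formulation for experimentation, assuming a hypothetical make_experiment_dir helper and a temporary base directory; it is not the code used in scripts/train.py.

```python
import tempfile
from datetime import datetime
from pathlib import Path


def make_experiment_dir(base: str, project: str, task: str) -> Path:
    """Create <base>/<project>/<task>/<timestamp>/ and return its path."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    path = Path(base) / project / task / timestamp
    # exist_ok=True means two runs started within the same second share a directory.
    path.mkdir(parents=True, exist_ok=True)
    return path


# Use a temporary base directory so the example leaves nothing behind.
with tempfile.TemporaryDirectory() as base_dir:
    experiment_path = make_experiment_dir(base_dir, "demo_project", "demo_task")
    created = experiment_path.is_dir()
    relative_parts = experiment_path.relative_to(base_dir).parts
```

Because the timestamp has a resolution of one second, the directory is unique only as long as two runs with the same project and task name are not started within the same second.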
3. Logging Setup
if config.use_clearml:
    _, clearml_logger = get_task_and_logger(
        project_name=config.project_name,
        task_name=config.task_name,
        add_datetime=False,
        logging_level=config.logging_level,
    )
else:
    clearml_logger = None
    logging.basicConfig(level=config.logging_level)
- If ClearML is enabled, a ClearML logger is initialized.
- Otherwise, standard Python logging is configured.
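The branching can be tried without ClearML installed. In the sketch below, setup_logging and its tracker_factory argument are hypothetical names; the factory stands in for get_task_and_logger, whose behavior is not reproduced here.

```python
import logging
from typing import Callable, Optional


def setup_logging(
    use_tracker: bool,
    level: str,
    tracker_factory: Optional[Callable[[], object]] = None,
) -> Optional[object]:
    """Return a tracker object when tracking is on, else configure stdlib logging."""
    if use_tracker:
        if tracker_factory is None:
            raise ValueError("tracker_factory is required when use_tracker is True")
        return tracker_factory()
    # logging.basicConfig accepts level names such as "INFO" as well as ints.
    logging.basicConfig(level=level)
    return None


tracker = setup_logging(use_tracker=False, level="INFO")
logging.getLogger(__name__).info("Standard logging is active")
```

Since the logger is None when tracking is disabled, any function that receives it has to handle that case before reporting metrics.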
4. Feature Extraction
if config.extract_features:
    for image_type in config.image_types:
        extract_features_from_studies(
            studies_path=config.studies_path,
            patient_csv=config.patient_csv,
            output_path=experiment_path,
            segmentation_method=config.lesion_segmentation_method.id,
            queries=config.data_queries.kwargs[image_type],
            image_type=image_type,
            feature_extraction_methods=config.feature_extraction_methods,
            lesion_cutoff=config.lesion_segmentation_method.softmax_cut_off,
            dilation_factor=config.lesion_segmentation_method.dilation_factor,
        )
- Features are extracted using either mirp or simple features, as specified in feature_extraction_methods.
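The extract_features flag amounts to an extract-or-load pattern. The sketch below shows that pattern with a CSV cache; compute_features is a made-up placeholder rather than mirp, and the file name and column names are assumptions rather than the format vify writes.

```python
import csv
import tempfile
from pathlib import Path


def compute_features(study_ids: list[str]) -> list[dict[str, str]]:
    """Placeholder extractor: one made-up feature per study (not mirp)."""
    return [{"study_id": sid, "id_length": str(len(sid))} for sid in study_ids]


def get_features(
    csv_path: Path, study_ids: list[str], extract: bool
) -> list[dict[str, str]]:
    """Extract and cache features when `extract` is set, else read the cached CSV."""
    if extract:
        rows = compute_features(study_ids)
        with csv_path.open("w", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=["study_id", "id_length"])
            writer.writeheader()
            writer.writerows(rows)
        return rows
    with csv_path.open(newline="") as handle:
        return list(csv.DictReader(handle))


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "features.csv"
    extracted = get_features(cache, ["s1", "study22"], extract=True)
    loaded = get_features(cache, [], extract=False)
```

Separating extraction from loading lets an expensive extraction run once, with later training runs reusing the stored features.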
5. Hyperparameter Search
if config.hp_tuning_method == "grid":
    parameter_space: dict[str, Any] = (
        PARAM_SPACES[model_type]["grid"]["general"]
        | PARAM_SPACES[model_type]["grid"][image_type]
    )
    results = run_grid_search(
        parameter_space=parameter_space, obj_func=obj_func
    )
else:
    parameter_space: dict[str, Any] = (
        PARAM_SPACES[model_type]["advanced"]["general"]
        | PARAM_SPACES[model_type]["advanced"][image_type]
    )
    results = run_advanced_search(
        parameter_space=parameter_space,
        obj_func=obj_func,
        max_eval=config.max_eval,
        search_algo=config.hp_tuning_method,
        model_type=model_type,
    )
- Grid, Bayesian, and random search can be run over the parameter space.
- Parameter search spaces are defined in src/vify/_config/parameter_search_configs.py.
- After the search, the model is run with the optimized parameters.
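A grid search over a merged parameter space can be sketched in a few lines of standard-library Python. Everything named here is hypothetical (PARAM_SPACES_SKETCH, grid_search_sketch, toy_objective, and the parameter names), and the sketch assumes the objective is minimized; run_grid_search in vify may differ in both interface and direction of optimization.

```python
from itertools import product
from typing import Any, Callable

# Hypothetical search spaces; the real ones live in parameter_search_configs.py.
PARAM_SPACES_SKETCH: dict[str, dict[str, list[Any]]] = {
    "general": {"n_estimators": [50, 100]},
    "image_specific": {"max_depth": [2, 4, 8]},
}


def grid_search_sketch(
    parameter_space: dict[str, list[Any]],
    obj_func: Callable[[dict[str, Any]], float],
) -> tuple[dict[str, Any], float]:
    """Evaluate every combination and return the one with the lowest objective."""
    names = list(parameter_space)
    best_params: dict[str, Any] = {}
    best_score = float("inf")
    for values in product(*(parameter_space[name] for name in names)):
        params = dict(zip(names, values))
        score = obj_func(params)
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score


def toy_objective(params: dict[str, Any]) -> float:
    """Made-up loss with its minimum at n_estimators=100, max_depth=4."""
    return abs(params["n_estimators"] - 100) + abs(params["max_depth"] - 4)


# The dict union operator (Python 3.9+) merges the two spaces into one.
space = PARAM_SPACES_SKETCH["general"] | PARAM_SPACES_SKETCH["image_specific"]
best_params, best_score = grid_search_sketch(space, toy_objective)
```

With the dict union operator, a key present in both spaces takes its value from the right-hand operand, so an image-type-specific entry would replace a general one of the same name.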
6. Evaluation
if config.do_performance_eval:
    # Evaluate the model
    evaluate(
        csv=features_df,
        output_path=experiment_path,
        image_type=image_type,  # type: ignore
        feature_extraction_method=feature_extraction_method,
        model=model,
        model_type=model_type,
        clearml_logger=clearml_logger,
        features_used=features_used,
        use_individual_lesions=config.model_params.use_individual_lesions,
        num_cv_repeats=config.num_cv_repeats,
        exclude_num_first_columns=config.exclude_num_first_columns,
        num_cv_folds=config.num_cv_folds,
        num_workers=config.num_workers,
        plot_dict=config.plots_params.kwargs,
    )
- Evaluates the model using repeated cross-validation, controlled by num_cv_repeats and num_cv_folds.
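Repeated cross-validation reshuffles the data and runs a full k-fold split several times, which gives num_cv_repeats × num_cv_folds train/test splits in total. The sketch below generates such splits with the standard library only; repeated_kfold_indices is a hypothetical helper, it does not stratify or group samples, and it is not the splitting logic used by evaluate.

```python
import random


def repeated_kfold_indices(
    num_samples: int, num_folds: int, num_repeats: int, seed: int = 0
) -> list[tuple[list[int], list[int]]]:
    """Return (train, test) index lists for `num_repeats` shuffled k-fold rounds."""
    rng = random.Random(seed)
    splits = []
    for _ in range(num_repeats):
        indices = list(range(num_samples))
        rng.shuffle(indices)  # a fresh shuffle per repeat gives different folds
        for fold in range(num_folds):
            # Every num_folds-th sample, offset by the fold number, is held out.
            test = indices[fold::num_folds]
            test_set = set(test)
            train = [i for i in indices if i not in test_set]
            splits.append((train, test))
    return splits


splits = repeated_kfold_indices(num_samples=10, num_folds=5, num_repeats=3)
```

Within each repeat every sample is held out exactly once, so averaging the metric over all splits reduces the dependence on any single random partition.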
7. Model Saving
if config.train_final_model:
    train_and_save_model(
        csv=features_df,
        model=model,
        model_type=model_type,
        features_extraction_method=feature_extraction_method,
        image_type=image_type,
        output_path=experiment_path,
        features_used=features_used,
        use_individual_lesions=config.model_params.use_individual_lesions,
        exclude_num_first_columns=config.exclude_num_first_columns,
    )
- Trains the final model and saves it together with its metadata.
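Since the output directory contains a <model_name>.pkl file, a pickle round trip is a reasonable way to picture the saving step. The sketch below is an assumption about the general approach only: save_model_sketch, load_model_sketch, and the bundle layout are hypothetical, and how train_and_save_model actually stores its metadata is not documented here.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any


def save_model_sketch(
    model: Any, metadata: dict[str, Any], output_dir: Path, name: str
) -> Path:
    """Pickle the model together with its metadata as <name>.pkl."""
    target = output_dir / f"{name}.pkl"
    with target.open("wb") as handle:
        pickle.dump({"model": model, "metadata": metadata}, handle)
    return target


def load_model_sketch(path: Path) -> tuple[Any, dict[str, Any]]:
    """Load a bundle written by save_model_sketch. Only unpickle trusted files."""
    with path.open("rb") as handle:
        bundle = pickle.load(handle)
    return bundle["model"], bundle["metadata"]


with tempfile.TemporaryDirectory() as tmp:
    # A dict of coefficients stands in for a fitted estimator.
    toy_model = {"coefficients": [0.5, -1.2]}
    saved_path = save_model_sketch(
        toy_model, {"features_used": ["f1", "f2"]}, Path(tmp), "toy_model"
    )
    restored_model, restored_metadata = load_model_sketch(saved_path)
```

Keeping the list of features next to the model makes it possible to check at prediction time that the input columns match what the model was trained on.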
...
This documentation provides a comprehensive guide to the training procedure in the vify repository. For further details, refer to the source code in scripts/train.py and the configuration file at config/train_config.yaml.