# Model Training

## Run Training

To train a new nnUNet you can use the `./scripts/train_runner.py` script. This script takes a YAML config as input and uses Hydra to enable parameter overrides and multiruns from the command line. To train the default anatomy segmentation network use `anatomy_config.yaml`; for the lesion segmentation model use `lesion_config.yaml`:

```bash
pixi run -e cuda-dev python scripts/train_runner.py --config-name=anatomy_config.yaml
pixi run -e cuda-dev python scripts/train_runner.py --config-name=lesion_config.yaml
```

To update certain parameters you can either change them in the YAML file or override them from the command line. Furthermore, you can use Hydra's full functionality, such as multiruns, to run your experiments. For example, to train multiple cross-validation folds and update the device you can use:

```bash
pixi run -e cuda python scripts/train_runner.py --config-name=anatomy_config.yaml -m fold=0,1 devices=[1]
```

!!! warning "Warning: Hydra initialises lazily"
    Hydra starts every run in a multirun with the code that is in the repository at that point in time. This means that if you start a multirun and then check out a different Git branch (or make any other changes), the second run of the multirun will use the code from that branch. Thus, only use multiruns when you are not actively developing on the repo (e.g. over a weekend or overnight). Alternatively, you can look into Git worktrees, which let you keep different branches of the same repo in different folders.

## Config Management

In this repo we use Hydra in combination with structured configs. A structured config is essentially a dataclass that defines types and default values for all possible parameters. This enables strong type checking throughout the code base and provides a natural place for comments explaining config values. More information on using Hydra with structured configs can be found in the Hydra documentation, although I did not find those docs very clear.

In order to make a dataclass available to Hydra, it needs to be registered with a Hydra `ConfigStore` instance. For our `BaseConfig` dataclass, this can be done as follows:

```python
from hydra.core.config_store import ConfigStore
from viseg.base_configs.base_config import BaseConfig

cs = ConfigStore.instance()
cs.store(name="base_config", node=BaseConfig)
```

This needs to run before the `@hydra.main()`-decorated main function. To use this dataclass for our configuration, we add the name under which it was stored to the defaults list in the `config.yaml`:

```yaml
defaults:
  - base_config

batch_size: 2
num_epochs: 1500
...
```

This will automatically compose a config from `BaseConfig` and override its default values (in this case `batch_size` and `num_epochs`) with the provided ones. All values can also be overridden from the command line.
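
Under the hood this composition boils down to OmegaConf merging, which you can reproduce standalone. The trimmed-down `BaseConfig` below is hypothetical and only serves to illustrate the merge semantics:

```python
from dataclasses import dataclass

from omegaconf import OmegaConf

# hypothetical, trimmed-down stand-in for the real BaseConfig
@dataclass
class BaseConfig:
    batch_size: int = 4
    num_epochs: int = 1000

schema = OmegaConf.structured(BaseConfig)  # defaults from the dataclass
overrides = OmegaConf.create({"batch_size": 2, "num_epochs": 1500})
cfg = OmegaConf.merge(schema, overrides)   # yaml/cmd values win over the defaults
print(cfg.batch_size, cfg.num_epochs)      # -> 2 1500
```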

### Nested Objects

Some of the config values (such as the learning rate scheduler and the model) are defined as class objects without a default value. This means that they must be defined in the YAML or from the command line. In order to select a class by a string name, there are two options:

- **YAML files:** Use YAML files to define individual configurations. This is done by defining a dataclass (with or without default values) and creating `.yaml` files with the values for the different configurations. For example, for the learning rate scheduler we have defined this class as:

    ```python
    @dataclass
    class LR_scheduler:
        type: str
        kwargs: dict
        monitor: str | None = None

    @dataclass
    class BaseConfig(OmegaConf):
        # these are being set in the experiment config.yaml
        lr_scheduler: LR_scheduler
    ```

    We can now create individual YAML files that each describe a certain configuration. These need to be located at the same path as the training configs, nested inside a folder named after the variable (in this case `lr_scheduler`). For example, a configuration YAML file for `StepLR` looks as follows:

    ```yaml
    type: "StepLR"
    kwargs:
      step_size: 250
      gamma: 0.1
    ```

    We can now select this learning rate scheduler in our training config by its file name:

    ```yaml
    defaults:
      - base_config
      - lr_scheduler: step_lr
    ```

    This option is generally the easiest for small classes without any class nesting.

- **Python classes:** For more complex configurations that require nested classes (for example the model architecture), it is most practical to implement the configurations as Python classes. A simplified example with two model configs, one for 2D and one for 3D, could look as follows:

    ```python
    from dataclasses import dataclass, field

    # PlainConvUnetArgs is the parent class of the other two
    # (import path shortened for the example)
    from model_args import PlainConvUnet2DArgs, PlainConvUnet3DArgs, PlainConvUnetArgs

    @dataclass
    class ModelConfig:
        model_cls: str
        model_args: PlainConvUnetArgs

    @dataclass
    class PlainConvUnet3DConfig(ModelConfig):
        model_cls: str = "PlainConvUNet"
        model_args: PlainConvUnet3DArgs = field(default_factory=PlainConvUnet3DArgs)

    @dataclass
    class PlainConvUnet2DConfig(ModelConfig):
        model_cls: str = "PlainConvUNet"
        model_args: PlainConvUnet2DArgs = field(default_factory=PlainConvUnet2DArgs)

    @dataclass
    class BaseConfig(OmegaConf):
        # these are being set in the experiment config.yaml
        model: ModelConfig
    ```

    To make these classes available by a string name, we also need to store them in Hydra's `ConfigStore`:

    ```python
    cs.store(group="model", name="plain_conv_unet_3d", node=PlainConvUnet3DConfig)
    cs.store(group="model", name="plain_conv_unet_2d", node=PlainConvUnet2DConfig)
    ```

    Now these configurations can be selected from the YAML config or from the command line just as before:

    ```yaml
    defaults:
      - base_config
      - model: plain_conv_unet_3d
    ```

Even though everything could be implemented as Python classes (option 2), the YAML configurations have the big advantage that you do not have to change any source code to add a configuration. Therefore, YAML files are preferred in this repo; we only use Python dataclasses for the model configuration, as it requires class nesting, which would be difficult (if not impossible) to express in YAML.

One last note concerns the `__post_init__` method of a dataclass. Hydra does not run it, as Hydra builds an OmegaConf object rather than the dataclass itself. However, we can convert the OmegaConf object to the underlying class (our `BaseConfig`), which then also runs `__post_init__`. This has the added advantage of stronger type checking.

```python
@hydra.main(
    version_base=None,
    config_path="../config/training_configs",
    config_name="anatomy_config",
)
def run_training(omega_config: OmegaConf) -> None:
    # Convert the OmegaConf to our BaseConfig object
    config = cast(BaseConfig, OmegaConf.to_object(omega_config))
```
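
As an illustration of what this conversion buys us, a hypothetical `__post_init__` could validate values that a plain OmegaConf object would accept silently (the real validation logic lives in `base_config.py` and may differ):

```python
from dataclasses import dataclass

# hypothetical validation; the real BaseConfig.__post_init__ may differ
@dataclass
class BaseConfig:
    batch_size: int = 2
    num_epochs: int = 1500

    def __post_init__(self) -> None:
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.num_epochs < 1:
            raise ValueError(f"num_epochs must be >= 1, got {self.num_epochs}")
```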

## Dataset Structure

The expected dataset structure is currently loosely based on the nnUNet structure, but it can easily be adapted to other formats by writing a new `CustomPathFinder` class in the `pathfinders.py` file. This new pathfinder class can inherit from the parent `PathFinder` and needs to implement the `_get_all_paths()` and `_extract_channelnames_from_folder()` methods (see the sketch below). The current `DefaultNNUnetFormatPathFinder()` expects the data in the following format:

```
basepath
├── img_foldername
│   ├── <case_id>_<channel>.nii.gz
│   ├── <case_id>_<channel>.nii.gz
│   ├── <case_id>_<channel>.nii.gz
│   ├── ...
├── seg_foldername
│   ├── <case_id>.nii.gz
│   ├── <case_id>.nii.gz
│   ├── ...
```

Here `<case_id>` and `<channel>` are placeholders: image files carry a channel suffix, while segmentation files are named by case ID only.

The `basepath` always needs to be provided in the config YAML or on the command line, whereas `img_foldername` and `seg_foldername` default to `images_tr` and `labelsTr` respectively, but can also be overridden from the config.

A pathfinder class is selected via the `pathfinder` config value. This string value is mapped to a class using a dictionary, also located in the `pathfinders.py` file. Any newly implemented pathfinder also has to be added to this dictionary.
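
A minimal sketch of what a custom pathfinder could look like, assuming the base-class contract described above (the actual `PathFinder` signatures and the name of the registry dictionary live in `pathfinders.py` and may differ):

```python
from abc import ABC, abstractmethod
from pathlib import Path

# assumed shape of the parent class; check pathfinders.py for the real one
class PathFinder(ABC):
    def __init__(self, basepath: Path, img_foldername: str, seg_foldername: str) -> None:
        self.img_folder = basepath / img_foldername
        self.seg_folder = basepath / seg_foldername

    @abstractmethod
    def _get_all_paths(self) -> list[Path]: ...

    @abstractmethod
    def _extract_channelnames_from_folder(self) -> list[str]: ...

class CustomPathFinder(PathFinder):
    def _get_all_paths(self) -> list[Path]:
        # e.g. recursively collect every NIfTI volume under the image folder
        return sorted(self.img_folder.rglob("*.nii.gz"))

    def _extract_channelnames_from_folder(self) -> list[str]:
        # e.g. derive the channel names from the "<case_id>_<channel>" suffix
        channels = {p.name.removesuffix(".nii.gz").rsplit("_", 1)[-1] for p in self._get_all_paths()}
        return sorted(channels)

# hypothetical name for the string -> class mapping in pathfinders.py
PATHFINDERS = {"custom": CustomPathFinder}
```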

## Dataset Fingerprinting

Similarly to nnUNet, we extract information from the dataset to determine the network architecture. Some of these parameters can also be overridden from the command line or the YAML config. An overview is given in the following table:

| Variable Name | Description | Can be overwritten from cmd/yaml |
| --- | --- | --- |
| `foreground_labels` | the foreground class labels that we want to predict | |
| `patch_size` | the shape of the network input images | ✓* |
| `spacing` | the spacing to resample the images to | |
| `batch_size` | the batch size | |
| `transpose_forward` | the transpose operation applied to image volumes (to get the largest spacing first) | ✓ (updated together with `transpose_backward`) |
| `transpose_backward` | the inverse transpose | ✓ (updated together with `transpose_forward`) |
| `num_channels` | the number of input channels | |
| `num_classes` | the number of classes (= `len(foreground_labels) + 1`) | |

\*The `patch_size` can be overwritten from the command line, but the provided value is treated as an initial patch size that is then used to determine the closest valid patch size (see the sketch after the table below). Based on the `patch_size` and the `spacing`, the network parameters are then also derived. None of these can be overwritten from the command line, because they are all interdependent. In the config they are all prefixed with `model.model_args` (e.g. `model.model_args.n_stages`):

| Variable Name | Description | Can be overwritten from cmd/yaml |
| --- | --- | --- |
| `model.model_args.n_stages` | the number of stages/blocks in the UNet architecture | |
| `model.model_args.n_conv_per_stage` | number of convolutions per block in the encoder | |
| `model.model_args.n_conv_per_stage_decoder` | number of convolutions per block in the decoder | |
| `model.model_args.features_per_stage` | number of features (in the conv layers) per block | |
| `model.model_args.strides` | strides of the max pool operations in each block | |
| `model.model_args.kernel_sizes` | convolutional kernel sizes in each of the blocks | |
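
To make the patch-size snapping concrete, here is a hypothetical illustration; it assumes each axis has to be divisible by the network's total downsampling factor, which is a simplification of whatever rule the fingerprinting actually applies:

```python
# hypothetical: snap an initial patch size to the closest one where every
# axis is divisible by the total downsampling factor (product of strides)
def closest_valid_patch_size(initial: tuple[int, ...], total_stride: int = 32) -> tuple[int, ...]:
    return tuple(max(total_stride, round(s / total_stride) * total_stride) for s in initial)

print(closest_valid_patch_size((130, 170, 96)))  # -> (128, 160, 96)
```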

## Continue Training

In order to continue training a run that was aborted for some reason, the train_runner script has been set up so that it can also take as input a config containing a `continue_task_id`. If this key is present in the config, the script retrieves the corresponding ClearML task and continues training from the latest checkpoint. Any other keys in the config can be used to override parameters; for example, we could change the device the model is trained on, or increase the number of epochs.

??? warning "Warning: Increasing the number of epochs"
    Some learning rate schedulers derive their decay from the total number of epochs. For those, increasing the number of epochs when continuing training gives different results than training with the higher epoch count from the start, because the scheduler is not reinitialised. For example, with the polynomial scheduler the learning rate reaches 0 at the last original epoch and stays there for any additional epochs, so no further learning happens.
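
A minimal illustration of this pitfall, using the common polynomial decay formula (the repo's scheduler implementation may differ in its details):

```python
# polynomial decay, clamped so epochs beyond max_epochs keep the LR at 0
def poly_lr(epoch: int, initial_lr: float, max_epochs: int, exponent: float = 0.9) -> float:
    return initial_lr * max(0.0, 1 - epoch / max_epochs) ** exponent

print(poly_lr(999, 1e-2, 1000))   # tiny but non-zero: last "original" epoch
print(poly_lr(1000, 1e-2, 1000))  # 0.0
print(poly_lr(1200, 1e-2, 1000))  # still 0.0: extra epochs learn nothing
```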

In order to set up continued training, we have to make sure the full config is saved to the task. The OmegaConf that ClearML logs by default is not sufficient, as it does not contain any of the dataset-dependent parameters. Therefore, after updating the config with the dataset parameters, we save it to the task as follows (see `train_runner.py` for the full code):

```python
config = update_config_from_dataset(config, datadict)
task.set_configuration_object(
    config_text=OmegaConf.to_yaml(config),
    name="full_config",
    config_type="OmegaConf YAML",
)
```

This enables us to load everything needed to continue training from a ClearML task. We then use the following code to continue training:


```python
def run_training(omega_config: OmegaConf) -> None:
    if "continue_task_id" in omega_config:
        task, logger = get_task_and_logger(
            add_datetime=True,
            reuse_last_task_id=omega_config.continue_task_id,
        )
        params = task.get_configuration_object("full_config")
        previous_run_cfg = OmegaConf.create(params)

        # update the config with the new parameters;
        # this only affects callback or trainer parameters, not lightning
        # module parameters, as those are instantiated from the checkpoint
        del omega_config.continue_task_id
        previous_run_cfg.update(omega_config)

        # create a BaseConfig instance from the previous run's config
        schema = OmegaConf.structured(BaseConfig)
        new_cfg = OmegaConf.merge(schema, previous_run_cfg)
        config = cast(BaseConfig, OmegaConf.to_object(new_cfg))

        # extract the last checkpoint
        last_ckpt_id = task.output_models_id["last"]
        ckpt_path = InputModel(last_ckpt_id).get_local_copy()

        # no sanity val steps, as these are logged and lead to peaks in our loss curves
        num_sanity_val_steps = 0
```
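
Downstream, the extracted `ckpt_path` is what performs the actual resume. With PyTorch Lightning this looks roughly like the following sketch; the trainer arguments are abridged, and `lightning_module` and `datamodule` stand in for the repo's actual objects (see `train_runner.py`):

```python
import lightning.pytorch as pl  # or pytorch_lightning, depending on the installed package

# sketch only: the repo's Trainer is configured with more options
trainer = pl.Trainer(
    max_epochs=config.num_epochs,
    devices=config.devices,
    num_sanity_val_steps=num_sanity_val_steps,
)
# passing ckpt_path restores the model weights, optimizer and scheduler state
trainer.fit(lightning_module, datamodule=datamodule, ckpt_path=ckpt_path)
```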

A config YAML to continue training can look as simple as:

```yaml
continue_task_id: "afd42b6c212b4fc4806f63dcec49e071"
devices: [1]
```