Model Training
Run Training
To train a new nnUNet, use the `./scripts/train_runner.py` script. The script takes a YAML config as input and uses Hydra to allow parameter overrides and multiruns from the command line. To train the default anatomy segmentation network, use `anatomy_config.yaml`; for the lesion segmentation model, use `lesion_config.yaml`:
```bash
pixi run -e cuda-dev python scripts/train_runner.py --config-name=anatomy_config.yaml
pixi run -e cuda-dev python scripts/train_runner.py --config-name=lesion_config.yaml
```
To change parameters, you can either edit them in the YAML file or override them from the command line. You can also use Hydra's full functionality, such as multiruns, to run your experiments. For example, to train several cross-validation folds and set the device, use:
```bash
# quote the list override so that shells such as zsh do not treat the brackets as a glob
pixi run -e cuda python scripts/train_runner.py --config-name=anatomy_config.yaml -m fold=0,1 'devices=[1]'
```
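Hydra's default (basic) sweeper expands each comma-separated override into a list of values and launches one job per combination, which is the Cartesian product of those values. The plain-Python sketch below imitates that expansion so you can predict how many runs a multirun command will start. It is a simplified illustration that only handles plain values and square-bracket lists. It is not Hydra's actual override grammar or parser.

```python
from itertools import product
from typing import List


def split_top_level(value: str) -> List[str]:
    """Split on commas that are not inside square brackets."""
    parts, depth, current = [], 0, ""
    for ch in value:
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append(current)
            current = ""
        else:
            current += ch
    parts.append(current)
    return parts


def expand_sweep(overrides: List[str]) -> List[List[str]]:
    """Return one list of 'key=value' overrides per job (Cartesian product)."""
    choices = []
    for override in overrides:
        key, value = override.split("=", 1)
        choices.append([f"{key}={v}" for v in split_top_level(value)])
    return [list(job) for job in product(*choices)]


jobs = expand_sweep(["fold=0,1", "devices=[1]"])
for job in jobs:
    print(job)
# ['fold=0', 'devices=[1]']
# ['fold=1', 'devices=[1]']
```

Adding a third swept key with two values would double the number of jobs again. This is why sweeps grow quickly.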
!!! warning "Warning: Hydra initialises lazily"
    Hydra starts each run of a multirun with the code that is in the repository at the moment that run starts. If you start a multirun and then check out a different Git branch (or make any other changes), the second run of the multirun will use the code from that branch. Therefore, only use multiruns when you are not actively developing in the repo (e.g. overnight or over a weekend). Alternatively, consider Git worktrees, which let you check out different branches of the same repository in different folders.
Config management
In this repo we use Hydra in combination with structured configs. A structured config is essentially a dataclass that defines the types and default values of all possible parameters. This enables strong type checking throughout the code base and provides a natural place for comments that explain the config values. More information on using Hydra with structured configs can be found in the Hydra documentation, although I did not find those docs very clear.
To make a dataclass available to Hydra, it needs to be stored in Hydra's `ConfigStore` instance. For our `BaseConfig` dataclass, this is done as follows:
```python
from hydra.core.config_store import ConfigStore

from viseg.base_configs.base_config import BaseConfig

cs = ConfigStore.instance()
cs.store(name="base_config", node=BaseConfig)
```
This needs to run before the `@hydra.main()`-decorated main function runs. To use this dataclass as the schema of our configuration, add the name under which it was stored to the defaults list in `config.yaml`:
```yaml
defaults:
  - base_config
batch_size: 2
num_epochs: 1500
# ...
```
Hydra then composes a config from `BaseConfig` and overrides the corresponding default values (in this case `batch_size` and `num_epochs`) with the provided ones. All values can also be overridden from the command line.
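The layering described above has three levels. The dataclass defaults come first, then the values from the YAML file, then the command-line overrides. The stdlib-only sketch below illustrates this order. Here `BaseConfig` is a hypothetical, trimmed-down stand-in with only two fields. `compose()` is a simplified imitation of what Hydra and OmegaConf do, including rejecting unknown keys and mismatched types. It is not Hydra's implementation.

```python
from dataclasses import asdict, dataclass, fields


@dataclass
class BaseConfig:
    # Hypothetical, trimmed-down schema; the real BaseConfig has many more fields.
    batch_size: int = 8
    num_epochs: int = 1000


def compose(schema: type, *layers: dict) -> dict:
    """Apply each layer of overrides on top of the dataclass defaults, in order."""
    types = {f.name: f.type for f in fields(schema)}
    config = asdict(schema())
    for layer in layers:
        for key, value in layer.items():
            if key not in types:
                raise KeyError(f"unknown config key: {key}")
            if not isinstance(value, types[key]):
                raise TypeError(f"{key} must be of type {types[key].__name__}")
            config[key] = value
    return config


yaml_values = {"batch_size": 2, "num_epochs": 1500}  # values from config.yaml
cli_overrides = {"num_epochs": 10}  # e.g. `num_epochs=10` on the command line
print(compose(BaseConfig, yaml_values, cli_overrides))
# {'batch_size': 2, 'num_epochs': 10}
```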
Nested Objects
Some of the config values (such as the learning rate scheduler and the model) are defined as class objects without a default value. These must therefore be set in the YAML config or on the command line. There are two options for selecting such a class by a string name:
- YAML files: Define individual configurations in YAML files. To do so, define a dataclass (with or without default values) and create `.yaml` files that hold the values of the different configurations. For example, for the learning rate scheduler we have defined the following class:

    ```python
    from dataclasses import dataclass


    @dataclass
    class LR_scheduler:
        type: str
        kwargs: dict
        monitor: str | None = None


    @dataclass
    class BaseConfig:
        # these are being set in the experiment config.yaml
        lr_scheduler: LR_scheduler
    ```

    We can now create individual YAML files that each describe one configuration. These need to be located in the same directory as the training configs, inside a subfolder named after the variable (in this case `lr_scheduler`). For example, the configuration file for StepLR, `step_lr.yaml`, looks as follows:

    ```yaml
    type: "StepLR"
    kwargs:
      step_size: 250
      gamma: 0.1
    ```

    We can then select this learning rate scheduler in the training config by its file name:

    ```yaml
    defaults:
      - base_config
      - lr_scheduler: step_lr
    ```

    This option is generally the easiest for small classes without any class nesting.
- Python classes: For more complex configurations that require nested classes (for example the model architecture), it is most practical to implement the configurations as Python classes. A simplified example with two model configs, one for 2D and one for 3D, could be defined as follows:

    ```python
    from dataclasses import dataclass, field

    # PlainConvUnetArgs is the parent class of PlainConvUnet2DArgs and PlainConvUnet3DArgs;
    # all three come from the repo's model argument definitions:
    # from ... import PlainConvUnet2DArgs, PlainConvUnet3DArgs, PlainConvUnetArgs


    @dataclass
    class ModelConfig:
        model_cls: str
        model_args: PlainConvUnetArgs


    @dataclass
    class PlainConvUnet3DConfig(ModelConfig):
        model_cls: str = "PlainConvUNet"
        model_args: PlainConvUnet3DArgs = field(default_factory=PlainConvUnet3DArgs)


    @dataclass
    class PlainConvUnet2DConfig(ModelConfig):
        model_cls: str = "PlainConvUNet"
        model_args: PlainConvUnet2DArgs = field(default_factory=PlainConvUnet2DArgs)


    @dataclass
    class BaseConfig:
        # these are being set in the experiment config.yaml
        model: ModelConfig
    ```

    To make these classes available under a string name, we also need to store them in Hydra's config store:

    ```python
    cs.store(group="model", name="plain_conv_unet_3d", node=PlainConvUnet3DConfig)
    cs.store(group="model", name="plain_conv_unet_2d", node=PlainConvUnet2DConfig)
    ```

    These configurations can now be selected in the YAML config or on the command line, just as before:

    ```yaml
    defaults:
      - base_config
      - model: plain_conv_unet_3d
    ```
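Both options map a `group: name` entry from the defaults list to a config node. Examples are `lr_scheduler: step_lr` and `model: plain_conv_unet_3d`. The sketch below imitates this lookup in plain Python. The Python-class option uses a dictionary registry, and the YAML option uses a file path. The registry, the `n_stages` default and the stand-in classes are hypothetical. The `config/training_configs` directory mirrors the `config_path="../config/training_configs"` used by `train_runner.py`, resolved relative to `scripts/`. Hydra's `ConfigStore` and config search path are more general than this.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Tuple


@dataclass
class PlainConvUnetArgs:
    # Hypothetical stand-in for the repo's model argument classes.
    n_stages: int = 6


@dataclass
class PlainConvUnet3DConfig:
    model_cls: str = "PlainConvUNet"
    model_args: PlainConvUnetArgs = field(default_factory=PlainConvUnetArgs)


# Python-class option: nodes registered under (group, name),
# comparable to cs.store(group="model", name="plain_conv_unet_3d", node=...).
REGISTRY: Dict[Tuple[str, str], type] = {
    ("model", "plain_conv_unet_3d"): PlainConvUnet3DConfig,
}


def resolve_python_option(group: str, name: str) -> object:
    """Instantiate the dataclass registered for `group: name`."""
    return REGISTRY[(group, name)]()


def resolve_yaml_option(config_dir: Path, group: str, name: str) -> Path:
    """YAML option: `group: name` refers to <config_dir>/<group>/<name>.yaml."""
    return config_dir / group / f"{name}.yaml"


model_cfg = resolve_python_option("model", "plain_conv_unet_3d")
print(model_cfg.model_args.n_stages)  # 6

scheduler_file = resolve_yaml_option(Path("config/training_configs"), "lr_scheduler", "step_lr")
print(scheduler_file.as_posix())  # config/training_configs/lr_scheduler/step_lr.yaml
```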
Even though everything could be implemented as Python classes (option 2), YAML configurations have a big advantage: new configurations can be added without changing any source code. YAML configurations are therefore preferred in this repo. We only use Python dataclasses for the model configuration, because its class nesting would be difficult (if not impossible) to express in YAML.
One last note concerns the `__post_init__` method of the dataclass. Hydra does not call it, because Hydra builds an OmegaConf object rather than an instance of the dataclass. However, we can convert the OmegaConf object into the underlying object (our `BaseConfig` class) with `OmegaConf.to_object`. This instantiates the dataclass and therefore also runs its `__post_init__`. An added advantage is that the rest of the code then works with a typed `BaseConfig` object, which allows stronger type checking:
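```python
from typing import cast

import hydra
from omegaconf import DictConfig, OmegaConf

from viseg.base_configs.base_config import BaseConfig


@hydra.main(
    version_base=None,
    config_path="../config/training_configs",
    config_name="anatomy_config",
)
def run_training(omega_config: DictConfig) -> None:
    # Convert the DictConfig to our BaseConfig object (this also runs __post_init__)
    config = cast(BaseConfig, OmegaConf.to_object(omega_config))
```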
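`OmegaConf.to_object` builds the dataclass by calling its constructor. This means any checks in `__post_init__` run at that point and not while the config is still a `DictConfig`. The stdlib-only sketch below shows the effect with a hypothetical `BaseConfig`. Its made-up check requires `patch_size` and `spacing` to have the same number of dimensions. A plain dict plays the role of the composed config, and constructing the dataclass from it stands in for `OmegaConf.to_object`. The real `BaseConfig.__post_init__` may do something different.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class BaseConfig:
    # Hypothetical fields; the real BaseConfig has many more.
    patch_size: Tuple[int, ...]
    spacing: Tuple[float, ...]

    def __post_init__(self) -> None:
        # Runs whenever the dataclass is constructed (e.g. by OmegaConf.to_object).
        self.patch_size = tuple(self.patch_size)
        self.spacing = tuple(self.spacing)
        if len(self.patch_size) != len(self.spacing):
            raise ValueError("patch_size and spacing must have the same number of dimensions")


composed = {"patch_size": [128, 128, 128], "spacing": [1.0, 0.8, 0.8]}
config = BaseConfig(**composed)  # stand-in for OmegaConf.to_object(omega_config)
print(config.patch_size)  # (128, 128, 128)
```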
Dataset Structure
The expected dataset structure is currently loosely based on the nnUNet structure. It can easily be adapted to other formats by writing a new pathfinder class (for example `CustomPathFinder`) in the `pathfinders.py` file. This new class can inherit from the parent `PathFinder` class and needs to implement the `_get_all_paths()` and `_extract_channelnames_from_folder()` methods. The current default pathfinder, `DefaultNNUnetFormatPathFinder`, expects the data in the following format:
```text
basepath
├── img_foldername
│   ├── _.nii.gz
│   ├── _.nii.gz
│   ├── _.nii.gz
│   ├── _.nii.gz
│   ├── _.nii.gz
│   ├── _.nii.gz
│   ├── ...
├── seg_foldername
│   ├── .nii.gz
│   ├── .nii.gz
│   ├── .nii.gz
│   ├── ...
```
The `basepath` must always be provided in the config YAML or on the command line. By contrast, `img_foldername` and `seg_foldername` default to `images_tr` and `labelsTr` respectively, and can also be overridden in the config.
A pathfinder class can be selected by the pathfinder config value. This string value maps to a class using a dictionary also located in the pathfinders.py file. Any newly implemented pathfinder will also have to be added to this dictionary.
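A new pathfinder could look roughly like the sketch below. Only the names `PathFinder`, `CustomPathFinder`, `_get_all_paths` and `_extract_channelnames_from_folder` come from this page. The constructor, the `_case_dirs` helper, the return types and the `PATHFINDERS` dictionary name are assumptions made for illustration, so check `pathfinders.py` for the real interface. The example also assumes a hypothetical layout with one folder per case, containing `image_<channel>.nii.gz` files and a `seg.nii.gz` label file.

```python
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List


class PathFinder(ABC):
    """Hypothetical minimal parent class; the real one lives in pathfinders.py."""

    def __init__(self, basepath: str) -> None:
        self.basepath = Path(basepath)

    def _case_dirs(self) -> List[Path]:
        return sorted(p for p in self.basepath.iterdir() if p.is_dir())

    @abstractmethod
    def _get_all_paths(self) -> Dict[str, dict]:
        ...

    @abstractmethod
    def _extract_channelnames_from_folder(self) -> List[str]:
        ...


class CustomPathFinder(PathFinder):
    """Assumed layout: basepath/<case_id>/image_<channel>.nii.gz and basepath/<case_id>/seg.nii.gz."""

    def _get_all_paths(self) -> Dict[str, dict]:
        return {
            case_dir.name: {
                "images": sorted(case_dir.glob("image_*.nii.gz")),
                "seg": case_dir / "seg.nii.gz",
            }
            for case_dir in self._case_dirs()
        }

    def _extract_channelnames_from_folder(self) -> List[str]:
        first_case = self._case_dirs()[0]
        # "image_t2.nii.gz" -> "t2"
        return sorted(
            p.name[len("image_") : -len(".nii.gz")]
            for p in first_case.glob("image_*.nii.gz")
        )


# Hypothetical name -> class mapping; every new pathfinder must be added here.
PATHFINDERS = {"custom": CustomPathFinder}
```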
Dataset Fingerprinting
Similarly to nnUNet, we extract information from the dataset (a dataset fingerprint) to determine the network architecture. Some of these parameters can also be overridden from the command line or the YAML config. The following table gives an overview of these parameters:
| Variable Name | Description | Can be overridden from cmd/yaml |
|---|---|---|
| `foreground_labels` | the foreground class labels that we want to predict | ✓ |
| `patch_size` | the shape of the network input images | ✓* |
| `spacing` | the spacing to resample the images to | ✓ |
| `batch_size` | the batch size | ✓ |
| `transpose_forward` | the transpose operation applied to image volumes (to put the largest spacing first) | ✓ (updated together with `transpose_backward`) |
| `transpose_backward` | the inverse transpose | ✓ (updated together with `transpose_forward`) |
| `num_channels` | the number of input channels | ✗ |
| `num_classes` | the number of classes (= `len(foreground_labels) + 1`) | ✗ |
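As an example of fingerprint-derived parameters, the sketch below computes `transpose_forward` and `transpose_backward` from a spacing tuple and checks that they undo each other. It also derives `num_classes` from `foreground_labels`. The sketch assumes an nnU-Net-style rule: only the axis with the largest spacing moves to the front, the other axes keep their order, and ties go to the first such axis. The repository's implementation may differ in these details.

```python
from typing import List, Sequence, Tuple


def compute_transposes(spacing: Sequence[float]) -> Tuple[List[int], List[int]]:
    """Move the axis with the largest spacing to the front; keep the others in order."""
    axes = range(len(spacing))
    largest = max(axes, key=lambda axis: spacing[axis])
    forward = [largest] + [axis for axis in axes if axis != largest]
    # Inverse permutation: backward[i] is where original axis i ended up after `forward`.
    backward = [forward.index(axis) for axis in axes]
    return forward, backward


spacing = (0.7, 0.7, 3.0)  # an anisotropic scan, e.g. in mm
forward, backward = compute_transposes(spacing)
print(forward, backward)  # [2, 0, 1] [1, 2, 0]

transposed = [spacing[axis] for axis in forward]  # [3.0, 0.7, 0.7]
restored = tuple(transposed[axis] for axis in backward)  # back to (0.7, 0.7, 3.0)

foreground_labels = [1, 2, 3]
num_classes = len(foreground_labels) + 1  # + 1 for the background class
```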
The `patch_size` can be overridden from the command line (✓* in the table above). However, this value is treated as an initial patch size, which is then used to determine the closest valid patch size. Based on the `patch_size` and the `spacing`, the network parameters are then derived as well. None of the network parameters can be overridden from the command line, because they are all interdependent. In the config, they are all prefixed with `model.model_args` (e.g. `model.model_args.n_stages`):
| Variable Name | Description | Can be overridden from cmd/yaml |
|---|---|---|
| `model.model_args.n_stages` | the number of stages/blocks in the UNet architecture | ✗ |
| `model.model_args.n_conv_per_stage` | number of convolutions per block in the encoder | ✗ |
| `model.model_args.n_conv_per_stage_decoder` | number of convolutions per block in the decoder | ✗ |
| `model.model_args.features_per_stage` | number of features (in the conv layers) per block | ✗ |
| `model.model_args.strides` | strides of the max pool blocks in each block | ✗ |
| `model.model_args.kernel_sizes` | convolutional kernel sizes in each of the blocks | ✗ |
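The sketch below gives an intuition for why these parameters depend on each other. It derives per-stage strides from a patch size and spacing, and then rounds an initial patch size to the closest multiple of each axis's total downsampling factor. It is a deliberately simplified heuristic, loosely in the spirit of nnU-Net. An axis is halved only while its feature map stays at least `min_size` voxels and its spacing is at most twice the smallest current spacing. The thresholds, the stage limit and the rounding rule are all assumptions, not the repository's actual planning code.

```python
from math import prod
from typing import List, Sequence


def plan_strides(
    patch_size: Sequence[int], spacing: Sequence[float], min_size: int = 4, max_stages: int = 6
) -> List[List[int]]:
    """Simplified heuristic: one stride per axis and stage (the first stage never downsamples)."""
    size = list(patch_size)
    current_spacing = list(spacing)
    strides = [[1] * len(size)]
    while len(strides) < max_stages:
        smallest = min(current_spacing)
        # Downsample an axis only if it stays >= min_size and is not too anisotropic.
        stride = [
            2 if size[a] // 2 >= min_size and current_spacing[a] <= 2 * smallest else 1
            for a in range(len(size))
        ]
        if all(s == 1 for s in stride):
            break
        for a, s in enumerate(stride):
            size[a] //= s
            current_spacing[a] *= s
        strides.append(stride)
    return strides


def closest_valid_patch_size(patch_size: Sequence[int], strides: List[List[int]]) -> List[int]:
    """Round each axis to the nearest multiple of its total downsampling factor."""
    factors = [prod(stage[a] for stage in strides) for a in range(len(patch_size))]
    return [max(f, round(p / f) * f) for p, f in zip(patch_size, factors)]


spacing = (3.0, 0.7, 0.7)  # largest spacing first, i.e. after transpose_forward
strides = plan_strides((21, 150, 170), spacing)
n_stages = len(strides)
print(n_stages, strides)
print(closest_valid_patch_size((21, 150, 170), strides))  # [20, 160, 160]
```

The low-resolution first axis is downsampled less often, so it needs a smaller divisibility factor than the in-plane axes.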
Continue Training
To continue a training run that was aborted for some reason, the `train_runner` script also accepts a config containing a `continue_task_id`. If this key is present, the script retrieves the corresponding ClearML task and continues training from its latest checkpoint. Any other keys in the config override the corresponding parameters of the previous run. For example, you could change the device the model is trained on or increase the number of epochs. Only trainer and callback parameters take effect, because the LightningModule is restored from the checkpoint.
??? warning "Warning: Increasing the number of epochs"
    Some learning rate schedulers base their learning rate reduction on the total number of epochs. For these schedulers, increasing the number of epochs when continuing training gives different results than using the higher number of epochs from the start, because the learning rate scheduler is not reinitialised. For example, with the polynomial scheduler the learning rate reaches 0 at the last epoch of the original run. It then stays at 0 for all additional epochs, so the model stops learning.
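The effect described in this warning can be reproduced with the polynomial decay formula `lr = initial_lr * (1 - epoch / max_epochs) ** exponent`. The sketch below assumes that formula with an exponent of 0.9, which is an arbitrary choice here. It clamps the epoch at `max_epochs` to mimic a scheduler that was restored with its original total number of epochs. The repo's actual scheduler settings may differ.

```python
from typing import List


def poly_lr(epoch: int, initial_lr: float, max_epochs: int, exponent: float = 0.9) -> float:
    """Polynomial decay; epochs beyond max_epochs keep the final value (0)."""
    progress = min(epoch, max_epochs) / max_epochs
    return initial_lr * (1 - progress) ** exponent


initial_lr = 0.01
# Scheduler restored from a run planned for 1000 epochs, then continued up to 1500:
resumed: List[float] = [poly_lr(e, initial_lr, max_epochs=1000) for e in range(1500)]
# Scheduler created with 1500 epochs from the start:
fresh: List[float] = [poly_lr(e, initial_lr, max_epochs=1500) for e in range(1500)]

print(resumed[1000], resumed[1499])  # 0.0 0.0 -> no learning in the extra epochs
print(fresh[1000] > 0)  # True
```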
To support continued training, the full config must be saved to the task. The OmegaConf config that ClearML logs by default is not sufficient, because it does not contain any of the dataset-dependent parameters. Therefore, after updating the config with the dataset parameters, we save it to the task as follows (see `train_runner.py` for the full code):
```python
config = update_config_from_dataset(config, datadict)
task.set_configuration_object(
    config_text=OmegaConf.to_yaml(config),
    name="full_config",
    config_type="OmegaConf YAML",
)
```
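The full config needs to be saved because dataset-dependent values such as `patch_size` or `spacing` are only filled in after fingerprinting. The sketch below shows the round trip with a hypothetical, trimmed-down config dataclass and a hypothetical stand-in for `update_config_from_dataset`. It uses JSON from the standard library in place of the YAML text that is stored in ClearML. It therefore illustrates the idea and does not reproduce the actual `set_configuration_object` / `get_configuration_object` calls.

```python
import json
from dataclasses import asdict, dataclass, replace
from typing import List, Optional


@dataclass
class TrainConfig:
    # Hypothetical subset of the config; None means "not known until fingerprinting".
    num_epochs: int = 1000
    patch_size: Optional[List[int]] = None
    spacing: Optional[List[float]] = None


def fake_update_from_dataset(config: TrainConfig) -> TrainConfig:
    # Hypothetical stand-in for update_config_from_dataset(config, datadict).
    return replace(config, patch_size=[20, 160, 160], spacing=[3.0, 0.7, 0.7])


initial = TrainConfig()  # roughly what a config logged before fingerprinting contains
config = fake_update_from_dataset(initial)
stored_text = json.dumps(asdict(config))  # saved to the task as "full_config"

# When continuing training, rebuild the complete config from the stored text.
restored = TrainConfig(**json.loads(stored_text))
print(initial.patch_size, restored.patch_size)  # None [20, 160, 160]
```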
This allows us to load everything from a ClearML task and continue training with the same settings. We then use the following (abbreviated) code to continue training:
```python
from typing import cast

from clearml import InputModel
from omegaconf import DictConfig, OmegaConf, open_dict

from viseg.base_configs.base_config import BaseConfig
# get_task_and_logger is a helper defined in this repo (see train_runner.py)


def run_training(omega_config: DictConfig) -> None:
    if "continue_task_id" in omega_config:
        task, logger = get_task_and_logger(
            add_datetime=True,
            reuse_last_task_id=omega_config.continue_task_id,
        )
        params = task.get_configuration_object("full_config")
        previous_run_cfg = OmegaConf.create(params)  # parses the stored YAML text

        # update the config with the new parameters
        # this will only have an effect for any callback or trainer parameters
        # not for lightning module parameters as these are instantiated from ckpt
        # (open_dict allows deleting the key even if Hydra has set the struct flag)
        with open_dict(omega_config):
            del omega_config.continue_task_id
        previous_run_cfg.update(omega_config)

        # Create a base class instance from the previous run config
        schema = OmegaConf.structured(BaseConfig)
        new_cfg = OmegaConf.merge(schema, previous_run_cfg)
        config = cast(BaseConfig, OmegaConf.to_object(new_cfg))

        # Extract the last checkpoint
        last_ckpt_id = task.output_models_id["last"]
        ckpt_path = InputModel(last_ckpt_id).get_local_copy()

        # No sanity val steps as these are logged which leads to peaks in our loss curves
        num_sanity_val_steps = 0
        ...  # rest of the training setup, see train_runner.py
```
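One detail of the code above concerns how overrides are applied. `DictConfig.update` behaves like `dict.update`: it replaces top-level keys wholesale. `OmegaConf.merge`, by contrast, merges nested keys recursively. Flat overrides such as `devices: [1]` behave the same either way, but a nested override would replace the entire subtree. The stdlib sketch below illustrates the difference on plain dicts. `deep_merge` is a simplified stand-in for `OmegaConf.merge`, and the `trainer` keys are hypothetical.

```python
import copy


def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base` (simplified OmegaConf.merge)."""
    result = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = deep_merge(result[key], value)
        else:
            result[key] = copy.deepcopy(value)
    return result


previous_run = {"devices": [0], "trainer": {"max_epochs": 1000, "precision": "16-mixed"}}
overrides = {"devices": [1], "trainer": {"max_epochs": 1500}}

shallow = copy.deepcopy(previous_run)
shallow.update(overrides)
print(shallow["trainer"])  # {'max_epochs': 1500} -> precision was dropped

merged = deep_merge(previous_run, overrides)
print(merged["trainer"])  # {'max_epochs': 1500, 'precision': '16-mixed'}
```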
A config YAML for continuing training can be as simple as:
```yaml
continue_task_id: "afd42b6c212b4fc4806f63dcec49e071"
devices: [1]
```