Model Training

Run Training

To train a new affine / elastic registration network, the scripts ./scripts/run_affine_training.py and ./scripts/run_elastic_training.py are available. These scripts take a yaml config as input and use hydra to enable parameter overrides and multiruns from the command line. To train the default affine / elastic registration networks, use ./config/training/affine_config.yaml and ./config/training/elastic_config.yaml.
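
For example, using the same pixi environment as the commands below, the default networks can be trained with:

pixi run -e cuda python scripts/run_affine_training.py
pixi run -e cuda python scripts/run_elastic_training.py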

To update certain parameters, you can either change them in the yaml file or overwrite them from the command line. Furthermore, you can use hydra's full functionality, such as multiruns, to run your experiments. For example, to train multiple cross-validation folds and update the device, you can use:

pixi run -e cuda python scripts/run_affine_training.py -m fold=0,1 devices=[1]

Warning "Hydra initialises lazily" Hydra starts every run in the multirun with the code in the repository at that timepoint. This means that if you start a multirun, and then checkout to a different Git branch (or do any other changes), the second run of the multirun will be with the code from that branch. Thus, make sure to only use multirun if you don't do active development on the repo (i.e. over a weekend or night).

Config management

In this repo we use Hydra in combination with structured configs. A structured config is essentially a dataclass that defines types and default values for all possible parameters. This enables strong type checking in the code base and also provides a natural place for comments explaining config values. More information on using Hydra with structured configs can be found here, although I did not find those docs very clear.

To make a dataclass available to Hydra, it needs to be registered with a Hydra ConfigStore instance. For our BaseConfig dataclass, this can be done as follows:

from hydra.core.config_store import ConfigStore
from viseg.base_configs.base_config import BaseConfig
cs = ConfigStore.instance()
cs.store(name="base_config", node=BaseConfig)

This needs to run before the @hydra.main() decorated main function. To use this dataclass as the schema for our configuration, we add the name of the config to the defaults list in the config.yaml as follows:

defaults:
- base_config

batch_size: 4
num_epochs: 100
...

This will automatically compose a config from the BaseConfig and overwrite the default values (in this case batch_size and num_epochs) with the provided ones. All values can also be overridden from the command line.
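
For example, reusing the training script from above, both values could be overridden directly on the command line:

pixi run -e cuda python scripts/run_affine_training.py batch_size=8 num_epochs=200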

Nested Objects

Some of the config values (such as the learning rate scheduler and the model) are defined as class objects without a default value. This means that these must be defined in the yaml or from the command line. To select a class by a string name, there are two options:

  • YAML files: Use yaml files to define individual configurations. This is done by defining a dataclass (with or without default values) and creating .yaml files with the values for different configurations. For example, for the learning rate scheduler we have defined this class:

    from dataclasses import dataclass
    from omegaconf import OmegaConf

    @dataclass
    class LR_scheduler:
        type: str
        kwargs: dict

    @dataclass
    class BaseConfig(OmegaConf):
        # these are being set in the experiment config.yaml
        lr_scheduler: LR_scheduler

    We can now create individual yaml files that each describe a certain configuration. These need to be located at the same path as the training configs, nested inside a folder named after the variable (in this case lr_scheduler). For example, the configuration yaml file for the step_lr scheduler looks as follows:

    type: "StepLR"
    kwargs:
    step_size: 250
    gamma: 0.1

    We can now select this learning rate scheduler in our training config by its file name:

    defaults:
    - base_config
    - lr_scheduler: step_lr

    This option is in general easiest for small classes without any class nesting.

  • Python Classes: For more complex configurations (for example the model architecture) that require nested classes, it is most practical to implement the configurations as python classes. A simplified example for three model configs (two affine variants and one elastic diffeomorphic variant) could be defined as follows:

    from dataclasses import dataclass, field

    # RegModelArgs is the parent class of the three model-specific *Args
    # classes used below; all four are imported from elsewhere in the repo


    @dataclass
    class ModelConfig:
        model_cls: str
        model_args: RegModelArgs


    @dataclass
    class AffineParamsConfig(ModelConfig):
        model_cls: str = "RegistrationModel"
        model_args: AffineParamsConfigArgs = field(default_factory=AffineParamsConfigArgs)


    @dataclass
    class AffineMatrixConfig(ModelConfig):
        model_cls: str = "RegistrationModel"
        model_args: AffineMatrixConfigArgs = field(default_factory=AffineMatrixConfigArgs)


    @dataclass
    class ElasticDiffeomorphicConfig(ModelConfig):
        model_cls: str = "RegistrationModel"
        model_args: ElasticDiffeomorphicConfigArgs = field(
            default_factory=ElasticDiffeomorphicConfigArgs
        )

    To make these classes available by a string name, we also need to store them in Hydra's config store:

    cs.store(group="model_config", name="affine_params", node=AffineParamsConfig)
    cs.store(group="model_config", name="affine_matrix", node=AffineMatrixConfig)
    cs.store(group="model_config", name="elastic_diff", node=ElasticDiffeomorphicConfig)

    Now these configurations can be selected from the yaml config just as before, or from the command line (see the example after this list):

    defaults:
    - base_config
    - model_config: affine_params
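
As an example combining both options, the scheduler and model configuration could also be selected from the command line, using the group names stored above:

pixi run -e cuda python scripts/run_affine_training.py lr_scheduler=step_lr model_config=affine_params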

Even though everything could be implemented as python classes (option 2), the yaml configurations have the big advantage that you do not have to change any source code to add configurations. Therefore, yaml files are preferred in this repo, and we only use python dataclasses for the model configuration, as these have class nesting which would be difficult (if not impossible) to express in yaml.

One last note concerns the __post_init__ method of the dataclass. It is not executed by Hydra, as Hydra builds an OmegaConf object. However, we can convert the OmegaConf back to the underlying object (our BaseConfig class), which then also runs the __post_init__ of the BaseConfig class. This has the added advantage of stronger type checking:

import hydra
from typing import cast
from omegaconf import OmegaConf

from viseg.base_configs.base_config import BaseConfig


@hydra.main(
    version_base=None,
    config_path="../config/training_configs",
    config_name="anatomy_config",
)
def run_training(omega_config: OmegaConf) -> None:
    # Convert the OmegaConf to our BaseConfig object
    config = cast(BaseConfig, OmegaConf.to_object(omega_config))
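
As a minimal sketch of why this matters (the batch_size check below is hypothetical, not the repo's actual validation), a __post_init__ like the following only runs after OmegaConf.to_object:

from dataclasses import dataclass

@dataclass
class BaseConfig:
    batch_size: int = 4

    def __post_init__(self) -> None:
        # executed when OmegaConf.to_object builds the BaseConfig
        # instance, not when Hydra composes the raw OmegaConf
        if self.batch_size <= 0:
            raise ValueError("batch_size must be positive")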

Dataset Structure

For training the deep learning based affine networks, two datasets were used: the ProstateX and PRUS datasets. All T2 and DWI images were filtered out of these datasets and combined into a single ClearML dataset. Then viseg (task no. 9) was used to segment all T2 images, and these masks were also added to the dataset. Finally, the combined dataset was split into train (80%) and test (20%) based on patient ID.

The complete processing pipeline can be found in the Github clearml folder. To redo all processing, run create_datasets.sh. In the ClearML datasets, you can also find a tree of the data generation process, where each intermediate dataset contains a task that details how it was generated:

[Figure: tree of the data generation process in ClearML]

Volumes are expected to be in a subfolder images, and the corresponding masks (if available) in a subfolder masks. Furthermore, mask and volume should share the same name, with the mask carrying the prefix "seg_".
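
For illustration, the expected layout could look as follows (patient names and file extension are hypothetical):

dataset/
├── images/
│   ├── patient_001.nii.gz
│   └── patient_002.nii.gz
└── masks/
    ├── seg_patient_001.nii.gz
    └── seg_patient_002.nii.gz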

Training affine and elastic networks: deep learning method description

It is possible to perform affine, elastic, and joint registration using deep learning networks, either standalone or as initial parameter estimators (HyReg). The deep learning networks for affine and elastic registration are briefly described below:

Affine Registration

For affine registration, two approaches are possible: prediction of affine parameters, or prediction of an affine matrix. In both cases, the type of affine transformation (rigid, rigid + isotropic scaling, rigid + anisotropic scaling, affine) can be specified.

Network details: Each network is trained using synthetically generated image pairs, in which each image pair is created from a single prostate MR image. This image is heavily augmented twice with intensity augmentations, generating two versions of the same image. This is performed on-the-fly, creating a wide range of different image pairs for training. Subsequently, a set of transformation parameters is sampled, and their inverse is used to transform one image of the pair into the 'moving' image. The network is then trained to predict the parameters / matrix mapping the moving into the fixed space.
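
A minimal sketch of this pair generation, with the augmentation reduced to additive noise and the transform to a simple pixel shift (both stand-ins for the real pipeline):

import torch

image = torch.rand(1, 1, 128, 128)                   # one prostate MR image
fixed = image + 0.05 * torch.randn_like(image)       # intensity augmentation 1
moving = image + 0.05 * torch.randn_like(image)      # intensity augmentation 2
shift = int(torch.randint(-10, 11, ()))              # sampled transform parameter
moving = torch.roll(moving, shifts=-shift, dims=-1)  # apply the INVERSE transform
# the network is then trained to predict `shift` from (fixed, moving)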

Affine parameter estimation

Currently, for the network predicting the affine parameters, a ResNet-18 is used as the backbone architecture, but this could be optimized further. To pass the fixed and moving image through the network, they are either concatenated and given as a 2-channel input, or passed through the network individually and concatenated after feature extraction (i.e. concatenating the feature vectors before prediction). The network is trained with an L2 loss applied to the parameters, with rotation in radians, scaling as deviation from 1, and translation expressed as a ratio of the half image width (PyTorch convention).

Another feature that has been implemented is cascading the image multiple times through the network, where the predicted image of the first pass is given as input to the second pass, etc. This is based on link.

Once the parameters are estimated, an affine matrix is constructed which can be applied to the moving image to create the registered image.
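
A minimal sketch of this construction for the 2D case, assuming the parameter conventions named above (the helper name and the parameter values are hypothetical):

import torch
import torch.nn.functional as F

def params_to_matrix(theta, scale, trans):
    # rotation in radians, scale as deviation from 1, translation as a
    # ratio of the half image width (PyTorch convention)
    sx, sy = 1 + scale[0], 1 + scale[1]
    cos, sin = torch.cos(theta), torch.sin(theta)
    return torch.stack([
        torch.stack([sx * cos, -sy * sin, trans[0]]),
        torch.stack([sx * sin,  sy * cos, trans[1]]),
    ])

moving = torch.rand(1, 1, 128, 128)  # dummy moving image
A = params_to_matrix(torch.tensor(0.1), torch.tensor([0.05, -0.02]),
                     torch.tensor([0.10, 0.00]))
grid = F.affine_grid(A[None], moving.shape, align_corners=False)
registered = F.grid_sample(moving, grid, align_corners=False)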

Affine matrix estimation

For the network predicting the affine matrix directly, either a UNet or a ResNet-18 can be used as the backbone. The identified features on both fixed and moving images are weighted and used to solve a least squares problem, resulting in a full affine matrix.

[Figure: affine matrix estimation network]
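
A minimal sketch of the weighted least squares idea (the point correspondences and weights are dummies; the actual network derives them from image features):

import torch

src = torch.rand(100, 2)                          # points in the moving image
dst = torch.rand(100, 2)                          # matched points in the fixed image
w = torch.rand(100, 1).sqrt()                     # per-correspondence weights

X = torch.cat([src, torch.ones(100, 1)], dim=1)   # homogeneous coordinates
A = torch.linalg.lstsq(w * X, w * dst).solution   # (3, 2) full affine map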

As mentioned before, the affine registration network can be used standalone or to predict an initial estimate that is iteratively optimized at test time. This framework is referred to as HyReg.

Elastic Registration

For elastic registration, a network based on SynthMorph is used.