# Source code for dosma.models.stanford_qdess

"""
Models trained on the 2021 Stanford quantitative double echo
steady state (qDESS) knee dataset.
"""


from copy import deepcopy

import numpy as np

try:
    from keras.layers import BatchNormalization as BN
    from keras.layers import Concatenate, Conv2D, Conv2DTranspose, Dropout, Input, MaxPooling2D
    from keras.models import Model

    _SUPPORTS_KERAS = True
except ImportError:  # pragma: no-cover
    _SUPPORTS_KERAS = False  # pragma: no-cover

from dosma.core.med_volume import MedicalVolume
from dosma.core.orientation import SAGITTAL
from dosma.models.seg_model import KerasSegModel, whiten_volume

__all__ = ["StanfordQDessUNet2D"]


class StanfordQDessUNet2D(KerasSegModel):
    """
    Template for 2D U-Net models trained on the SKM-TEA dataset
    (previously *2021 Stanford qDESS Knee* dataset).

    This model segments patellar cartilage ("pc"), femoral cartilage ("fc"),
    tibial cartilage ("tc"), and the meniscus ("men") from quantitative
    double echo steady state (qDESS) knee scans. The segmentation is computed on
    the root-sum-of-squares (RSS) of the two echoes.

    There are a few weights files that are associated with this model.
    We provide a short description of each below:

        *   ``qDESS_2021_v1-rms-unet2d-pc_fc_tc_men_weights.h5``: This is the baseline
            model trained on the SKM-TEA dataset (v1.0.0).
        *   ``qDESS_2021_v0_0_1-rms-pc_fc_tc_men_weights.h5``: This model is trained on
            the RSS 2021 Stanford qDESS knee dataset (v0.0.1).
        *   ``qDESS_2021_v0_0_1-traintest-rms-pc_fc_tc_men_weights.h5``: This model is
            trained on both the train and test set of the 2021 Stanford qDESS knee
            dataset (v0.0.1).

    Examples:

        >>> # Create model based on the volume's shape (SI, AP, 1).
        >>> model = StanfordQDessUNet2D((256, 256, 1), "/path/to/weights")

        >>> # Generate mask from root-sum-of-squares (rss) volume.
        >>> model.generate_mask(rss)

        >>> # Generate mask from dual-echo volume `de_vol` - shape: (SI, AP, LR, 2)
        >>> model.generate_mask(de_vol)
    """

    ALIASES = ("stanford-qdess-2021-unet2d", "skm-tea-unet2d")

    # Output channel order of the network. These are also the keys of the dict
    # returned by ``generate_mask``. The final 1x1 conv has one channel per entry.
    _CATEGORIES = ("pc", "fc", "tc", "men")

    # Sigmoid outputs strictly greater than this value are labeled as foreground.
    sigmoid_threshold = 0.5

    @staticmethod
    def _static_height(tensor):
        """Return the static size of axis 1 (height) of a ``(batch, H, W, C)`` tensor.

        TF/Keras 2 tensors expose a ``TensorShape`` (which needs ``as_list()``).
        Newer Keras versions may expose a plain tuple; both are handled.
        """
        shape = tensor.shape
        dims = shape.as_list() if hasattr(shape, "as_list") else list(shape)
        return dims[1]

    @staticmethod
    def _unet_resample_size(size):
        """Return the pooling / transposed-conv stride for a feature map of ``size``.

        Even sizes are resampled by a factor of 2, odd sizes by a factor of 3.
        """
        return (2, 2) if size % 2 == 0 else (3, 3)

    @staticmethod
    def _unet_conv_block(x, num_filters):
        """Apply two 3x3 ReLU convolutions, then batch norm and dropout, to ``x``.

        Layers are created in the same order as in the original inline
        implementation: Conv2D, Conv2D, BatchNormalization, Dropout.
        """
        for _ in range(2):
            x = Conv2D(
                num_filters,
                (3, 3),
                padding="same",
                activation="relu",
                kernel_initializer="he_normal",
            )(x)
        x = BN(axis=-1, momentum=0.95, epsilon=0.001)(x)
        # rate=0.0 disables dropout. The layer is kept to preserve the architecture.
        x = Dropout(rate=0.0)(x)
        return x

    def __load_keras_model__(self, input_shape):
        """Build the 2D U-Net.

        Args:
            input_shape (tuple): Shape ``(height, width, 1)`` of a single input slice.

        Returns:
            keras.models.Model: Uncompiled model with one sigmoid output channel
            per entry in ``_CATEGORIES``.

        Raises:
            ValueError: If ``input_shape`` is not a 3-tuple whose last element is 1.
        """
        if not isinstance(input_shape, tuple) or len(input_shape) != 3 or input_shape[2] != 1:
            raise ValueError("input_shape must be a tuple of size (height, width, 1)")

        # 32, 64, ..., 1024 filters at each of the 6 resolution levels (plain Python ints).
        nfeatures = [32 * 2**level for level in range(6)]
        depth = len(nfeatures)
        skips = []

        # input layer
        inputs = Input(input_shape)

        # step down convolutional layers
        x = inputs
        for level in range(depth):
            conv = self._unet_conv_block(x, nfeatures[level])
            skips.append(conv)
            # Only maxpool till penultimate depth
            if level < depth - 1:
                # If size of input is odd, do a 3x3 max pool instead of 2x2.
                pool_size = self._unet_resample_size(self._static_height(conv))
                x = MaxPooling2D(pool_size=pool_size)(conv)

        # step up convolutional layers. ``conv`` starts as the deepest encoder output.
        for level in range(depth - 2, -1, -1):
            skip = skips[level]
            # Upsample by the same factor used when pooling this level's feature map.
            strides = self._unet_resample_size(self._static_height(skip))
            up = Concatenate(axis=3)(
                [
                    Conv2DTranspose(nfeatures[level], (3, 3), padding="same", strides=strides)(
                        conv
                    ),
                    skip,
                ]
            )
            conv = self._unet_conv_block(up, nfeatures[level])

        # combine features: one sigmoid channel per category
        recon = Conv2D(len(self._CATEGORIES), (1, 1), padding="same", activation="sigmoid")(conv)

        return Model(inputs=[inputs], outputs=[recon])

    def generate_mask(self, volume: MedicalVolume):
        """Segment tissues.

        Args:
            volume (MedicalVolume): The volume to segment. Either 3D or 4D.
                If the volume is 3D, it is assumed to be the root-sum-of-squares (RSS)
                of the two qDESS echoes. If 4D, volume must be of the shape ``(..., 2)``,
                where the last dimension corresponds to echo 1 and 2, respectively.

        Returns:
            dict: Maps each of ``"pc"``, ``"fc"``, ``"tc"``, ``"men"`` to a binary
            (``np.uint8``) ``MedicalVolume`` mask, reformatted to ``volume.orientation``.

        Raises:
            ValueError: If ``volume`` is neither 3D nor 4D.
        """
        ndim = volume.ndim
        if ndim not in (3, 4):
            raise ValueError("`volume` must either be 3D or 4D")

        # Work on a copy so the caller's volume is never reformatted in place.
        vol_copy = deepcopy(volume)
        if ndim == 4:
            # Combine the echoes (last axis) into a single root-sum-of-squares image.
            vol_copy = np.sqrt(np.sum(vol_copy**2, axis=-1))

        # reorient to the sagittal plane
        vol_copy.reformat(SAGITTAL, inplace=True)
        vol = self.__preprocess_volume__(vol_copy.volume)

        # reshape volumes to be (slice, x, y, 1)
        v = np.expand_dims(np.transpose(vol, (2, 0, 1)), axis=-1)

        mask = self.seg_model.predict(v, batch_size=self.batch_size, verbose=1)
        mask = (mask > self.sigmoid_threshold).astype(np.uint8)

        # reshape mask to be (x, y, slice, classes)
        mask = np.transpose(mask, (1, 2, 0, 3))

        vols = {}
        for i, category in enumerate(self._CATEGORIES):
            # Copy keeps the sagittal geometry/headers; only the voxel data is replaced.
            vol_cp = deepcopy(vol_copy)
            vol_cp.volume = mask[..., i]
            # reorient to match with original volume
            vol_cp.reformat(volume.orientation, inplace=True)
            vols[category] = vol_cp

        return vols

    def __preprocess_volume__(self, volume: np.ndarray):
        """Normalize ``volume`` with ``whiten_volume`` (``eps=1e-8``) before inference."""
        # TODO: Remove epsilon if difference in performance difference is not large.
        return whiten_volume(volume, eps=1e-8)