Source code for dosma.models.oaiunet2d

"""
@author: Arjun Desai
        (C) Stanford University, 2019
"""

import os
from copy import deepcopy

import numpy as np

try:
    from keras.layers import BatchNormalization as BN
    from keras.layers import Concatenate, Conv2D, Conv2DTranspose, Dropout, Input, MaxPooling2D
    from keras.models import Model

    _SUPPORTS_KERAS = True
except ImportError:  # pragma: no-cover
    _SUPPORTS_KERAS = False  # pragma: no-cover

from dosma.core.med_volume import MedicalVolume
from dosma.core.orientation import SAGITTAL
from dosma.models.seg_model import KerasSegModel, whiten_volume

# Public names exported by this module on ``import *``: the three segmentation model classes.
__all__ = ["OAIUnet2D", "IWOAIOAIUnet2D", "IWOAIOAIUnet2DNormalized"]


class OAIUnet2D(KerasSegModel):
    """Model trained in Chaudhari et al. IWOAI 2018.

    Original Github: https://github.com/akshaysc/msk_segmentation

    The network is a 6-level 2D U-Net ending in a single-channel sigmoid
    output. :meth:`generate_mask` binarizes that output with
    :attr:`sigmoid_threshold`.
    """

    ALIASES = ["oai-unet2d", "oai_unet2d"]

    # Sigmoid outputs strictly greater than this value are marked as foreground.
    sigmoid_threshold = 0.5

    def __load_keras_model__(self, input_shape, force_weights=False):
        """Build the 2D U-Net Keras model (architecture only, no weights).

        Args:
            input_shape (tuple): Input size in the format ``(height, width, 1)``.
            force_weights (bool, optional): Unused by this method.

        Returns:
            A Keras ``Model`` with one input of shape ``input_shape`` and one
            single-channel sigmoid output.

        Raises:
            ImportError: If keras could not be imported.
            ValueError: If ``input_shape`` is not a tuple of length 3 whose
                last entry is 1.
        """
        if not _SUPPORTS_KERAS:
            raise ImportError(
                "`oaiunet2d` segmentation models depend on tensorflow/keras backends. "
                "Install them with `pip install tensorflow; pip install keras`"
            )

        if not isinstance(input_shape, tuple) or len(input_shape) != 3 or input_shape[2] != 1:
            raise ValueError("input_size must be a tuple of size (height, width, 1)")

        # Filters per level: 32, 64, 128, 256, 512, 1024.
        nfeatures = [32 * 2 ** level for level in range(6)]
        bottleneck = len(nfeatures) - 1

        inputs = Input(input_shape)

        # Contracting path. The output of every level is kept for the skip connections.
        skips = []
        tensor = inputs
        for level, filters in enumerate(nfeatures):
            tensor = self._conv_block(tensor, filters)
            skips.append(tensor)

            # The deepest level (bottleneck) is not pooled.
            if level < bottleneck:
                tensor = MaxPooling2D(pool_size=self._resampling_size(tensor))(tensor)

        # Expanding path. Walk back up from the level just above the bottleneck.
        for level in range(bottleneck - 1, -1, -1):
            skip = skips[level]
            # Upsample by the same factor this level was pooled with on the way down.
            stride = self._resampling_size(skip)

            # Layers are instantiated in the same order as before
            # (Concatenate first, then Conv2DTranspose).
            merge = Concatenate(axis=3)
            upsampled = Conv2DTranspose(
                nfeatures[level], (3, 3), padding="same", strides=stride
            )(tensor)
            tensor = self._conv_block(merge([upsampled, skip]), nfeatures[level])

        # Combine features into a single sigmoid channel.
        recon = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(tensor)

        return Model(inputs=[inputs], outputs=[recon])

    @staticmethod
    def _conv_block(tensor, filters):
        """Apply two 3x3 ReLU convolutions, batch norm and a no-op (rate 0) dropout."""
        for _ in range(2):
            tensor = Conv2D(
                filters,
                (3, 3),
                padding="same",
                activation="relu",
                kernel_initializer="he_normal",
            )(tensor)
        tensor = BN(axis=-1, momentum=0.95, epsilon=0.001)(tensor)
        return Dropout(rate=0.0)(tensor)

    @staticmethod
    def _resampling_size(tensor):
        """Return ``(2, 2)`` if dimension 1 of ``tensor`` is even, else ``(3, 3)``."""
        return (2, 2) if tensor.shape.as_list()[1] % 2 == 0 else (3, 3)

    def generate_mask(self, volume: MedicalVolume):
        """Segment ``volume`` and return the binary mask as a new volume.

        The input is deep-copied, so ``volume`` itself is not modified.

        Args:
            volume (MedicalVolume): Volume to segment.

        Returns:
            MedicalVolume: Copy of ``volume`` whose data is the ``uint8`` mask,
            reformatted back to the orientation of ``volume``.
        """
        vol_copy = deepcopy(volume)

        # reorient to the sagittal plane
        vol_copy.reformat(SAGITTAL, inplace=True)

        vol = self.__preprocess_volume__(vol_copy.volume)

        # reshape volumes to be (slice, x, y, 1)
        v = np.expand_dims(np.transpose(vol, (2, 0, 1)), axis=-1)

        mask = self.seg_model.predict(v, batch_size=self.batch_size, verbose=1)
        mask = (mask > self.sigmoid_threshold).astype(np.uint8)

        # reshape mask to be (x, y, slice)
        vol_copy.volume = np.transpose(np.squeeze(mask, axis=-1), (1, 2, 0))

        # reorient to match with original volume
        vol_copy.reformat(volume.orientation, inplace=True)

        return vol_copy

    def __preprocess_volume__(self, volume: np.ndarray):
        """Return ``volume`` passed through ``whiten_volume`` with ``eps=1e-8``."""
        # TODO: Remove epsilon if the difference in performance is not large.
        return whiten_volume(volume, eps=1e-8)
class IWOAIOAIUnet2D(OAIUnet2D):
    """Model trained by Team 6 in the 2019 IWOAI Segmentation Challenge.

    Same U-Net layout as the parent class, but with one sigmoid output channel
    per category in ``_CATEGORIES`` and no input pre-processing.

    References:
        Desai, et al., "The International Workshop on Osteoarthritis Imaging
        Knee MRI Segmentation Challenge: A Multi-Institute Evaluation and
        Analysis Framework on a Standardized Dataset."
        arXiv preprint arXiv:2004.14003 (2020).
        `[link] <https://arxiv.org/abs/2004.14003>`_
    """

    ALIASES = ["iwoai-2019-t6"]

    # Only a weights file with exactly this basename is accepted (unless forced).
    _WEIGHTS_FILE = "iwoai-2019-unet2d_fc-tc-pc-men_weights.h5"

    # Output channels in order; also the keys of the dict returned by ``generate_mask``.
    _CATEGORIES = ("fc", "tc", "pc", "men")

    def __init__(self, input_shape, weights_path, force_weights=False):
        """
        Args:
            input_shape (tuple): Input size in the format ``(height, width, 1)``.
            weights_path (str): Path to the weights file.
            force_weights (bool, optional): If ``True``, skip the check that the
                basename of ``weights_path`` equals ``_WEIGHTS_FILE``.

        Raises:
            ValueError: If ``force_weights`` is false and the basename of
                ``weights_path`` differs from ``_WEIGHTS_FILE``.
        """
        if not force_weights and os.path.basename(weights_path) != self._WEIGHTS_FILE:
            raise ValueError(f"Weights {weights_path} not supported for {type(self)}")

        # NOTE(review): `force_weights` is not forwarded to the parent initializer;
        # confirm whether the base class expects it.
        super().__init__(input_shape, weights_path)

    def __load_keras_model__(self, input_shape, force_weights=False):
        """Build the multi-category 2D U-Net Keras model (architecture only).

        Args:
            input_shape (tuple): Input size in the format ``(height, width, 1)``.
            force_weights (bool, optional): Unused by this method. Present so the
                signature matches the method it overrides.

        Returns:
            A Keras ``Model`` with one input of shape ``input_shape`` and one
            sigmoid output with ``len(_CATEGORIES)`` channels.

        Raises:
            ImportError: If keras could not be imported.
            ValueError: If ``input_shape`` is not a tuple of length 3 whose
                last entry is 1.
        """
        # Fail with an actionable message instead of a NameError on `Input`.
        if not _SUPPORTS_KERAS:
            raise ImportError(
                "`oaiunet2d` segmentation models depend on tensorflow/keras backends. "
                "Install them with `pip install tensorflow; pip install keras`"
            )

        if not isinstance(input_shape, tuple) or len(input_shape) != 3 or input_shape[2] != 1:
            raise ValueError("input_size must be a tuple of size (height, width, 1)")

        def conv_block(tensor, filters):
            # Two 3x3 ReLU convolutions, batch norm, then a no-op (rate 0) dropout.
            for _ in range(2):
                tensor = Conv2D(
                    filters,
                    (3, 3),
                    padding="same",
                    activation="relu",
                    kernel_initializer="he_normal",
                )(tensor)
            tensor = BN(axis=-1, momentum=0.95, epsilon=0.001)(tensor)
            return Dropout(rate=0.0)(tensor)

        def resampling_size(tensor):
            # 2x2 when dimension 1 is even, 3x3 when it is odd.
            return (2, 2) if tensor.shape.as_list()[1] % 2 == 0 else (3, 3)

        # Filters per level: 32, 64, 128, 256, 512, 1024.
        nfeatures = [32 * 2 ** level for level in range(6)]
        bottleneck = len(nfeatures) - 1

        inputs = Input(input_shape)

        # Contracting path. The output of every level is kept for the skip connections.
        skips = []
        tensor = inputs
        for level, filters in enumerate(nfeatures):
            tensor = conv_block(tensor, filters)
            skips.append(tensor)

            # The deepest level (bottleneck) is not pooled.
            if level < bottleneck:
                tensor = MaxPooling2D(pool_size=resampling_size(tensor))(tensor)

        # Expanding path. Walk back up from the level just above the bottleneck.
        for level in range(bottleneck - 1, -1, -1):
            skip = skips[level]
            # Upsample by the same factor this level was pooled with on the way down.
            stride = resampling_size(skip)

            # Layers are instantiated in the same order as before
            # (Concatenate first, then Conv2DTranspose).
            merge = Concatenate(axis=3)
            upsampled = Conv2DTranspose(
                nfeatures[level], (3, 3), padding="same", strides=stride
            )(tensor)
            tensor = conv_block(merge([upsampled, skip]), nfeatures[level])

        # Combine features: one sigmoid channel per category.
        recon = Conv2D(
            len(self._CATEGORIES), (1, 1), padding="same", activation="sigmoid"
        )(tensor)

        return Model(inputs=[inputs], outputs=[recon])

    def generate_mask(self, volume: MedicalVolume):
        """Segment ``volume`` and return one binary mask volume per category.

        The input is deep-copied, so ``volume`` itself is not modified.

        Args:
            volume (MedicalVolume): Volume to segment.

        Returns:
            dict: Maps each name in ``_CATEGORIES`` to a ``MedicalVolume``
            holding that category's ``uint8`` mask, reformatted back to the
            orientation of ``volume``.
        """
        vol_copy = deepcopy(volume)

        # reorient to the sagittal plane
        vol_copy.reformat(SAGITTAL, inplace=True)

        vol = self.__preprocess_volume__(vol_copy.volume)

        # reshape volumes to be (slice, x, y, 1)
        v = np.expand_dims(np.transpose(vol, (2, 0, 1)), axis=-1)

        mask = self.seg_model.predict(v, batch_size=self.batch_size, verbose=1)
        mask = (mask > self.sigmoid_threshold).astype(np.uint8)

        # reshape mask to be (x, y, slice, classes)
        mask = np.transpose(mask, (1, 2, 0, 3))

        vols = {}
        for i, category in enumerate(self._CATEGORIES):
            vol_cp = deepcopy(vol_copy)
            vol_cp.volume = mask[..., i]

            # reorient to match with original volume
            vol_cp.reformat(volume.orientation, inplace=True)
            vols[category] = vol_cp

        return vols

    def __preprocess_volume__(self, volume: np.ndarray):
        """Return ``volume`` unchanged (this model applies no pre-processing)."""
        return volume
class IWOAIOAIUnet2DNormalized(IWOAIOAIUnet2D):
    """Normalized-input variant of the Team 6 model (2019 IWOAI Segmentation Challenge).

    The architecture is the same as :class:`IWOAIOAIUnet2D`. The only
    difference is that the input data is pre-processed by zero-mean, unit-std
    normalization before it is passed to the network.

    References:
        Desai, et al., "The International Workshop on Osteoarthritis Imaging
        Knee MRI Segmentation Challenge: A Multi-Institute Evaluation and
        Analysis Framework on a Standardized Dataset."
        arXiv preprint arXiv:2004.14003 (2020).
    """

    ALIASES = ("iwoai-2019-t6-normalized",)

    _WEIGHTS_FILE = "iwoai-2019-unet2d-normalized_fc-tc-pc-men_weights.h5"

    def __preprocess_volume__(self, volume: np.ndarray):
        """Return ``volume`` passed through ``whiten_volume`` with its default arguments."""
        normalized = whiten_volume(volume)
        return normalized