Source code for dosma.utils.img_utils

import itertools

import numpy as np
import seaborn as sns

from dosma import defaults

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Public API of this module: the two helpers defined below.
__all__ = ["downsample_slice", "write_regions"]


def downsample_slice(img_array, ds_factor, is_mask=False):
    """Downsample a 3D array along the z-axis by summing groups of slices.

    Consecutive groups of ``ds_factor`` slices along the last axis are summed
    into one output slice. If ``zres`` is not divisible by ``ds_factor``, the
    final group holds the remaining slices, so the output has
    ``ceil(zres / ds_factor)`` slices.

    Args:
        img_array (np.ndarray): 3D numpy array (xres x yres x zres).
        ds_factor (int): Downsampling factor, i.e. the number of consecutive
            slices summed into each output slice. Must be >= 1.
        is_mask (:obj:`bool`, optional): If truthy, ``img_array`` is treated
            as a mask and the output is binarized (``1`` where the summed
            value is ``>= 1``, ``0`` elsewhere). Defaults to ``False``.

    Returns:
        np.ndarray: 3D numpy array of shape
        (xres x yres x ceil(zres / ds_factor)).

    Raises:
        ValueError: If ``img_array`` is not 3D or ``ds_factor < 1``.

    Examples:
        >>> input_image = np.random.rand(4, 4, 4)
        >>> input_mask = (input_image > 0.5) * 1.0
        >>> output_image = downsample_slice(input_image, ds_factor=2, is_mask=False)
        >>> output_mask = downsample_slice(input_mask, ds_factor=2, is_mask=True)
        >>> output_image.shape
        (4, 4, 2)
    """
    img_array = np.asarray(img_array)
    if img_array.ndim != 3:
        raise ValueError("`img_array` must be a 3D array, got %dD" % img_array.ndim)
    if ds_factor < 1:
        raise ValueError("`ds_factor` must be >= 1, got %r" % (ds_factor,))

    # One 2D (xres x yres) slice per z index.
    slices = list(np.transpose(img_array, (2, 0, 1)))

    # Repeating the *same* iterator makes zip_longest emit consecutive groups
    # of ``ds_factor`` slices. The last group is padded with scalar 0s, which
    # leave the sum unchanged.
    iterators = [iter(slices)] * ds_factor
    groups = itertools.zip_longest(*iterators, fillvalue=0)
    summed = [sum(group) for group in groups]

    if summed:
        final = np.stack(summed, axis=-1)
    else:
        # zres == 0: nothing to sum; keep an empty z-axis instead of crashing.
        final = np.zeros(img_array.shape[:2] + (0,), dtype=img_array.dtype)

    # Binarize if it is a mask (truthiness, so e.g. np.bool_(True) also works).
    if is_mask:
        final = (final >= 1) * 1

    return final
def write_regions(file_path, arr, plt_dict=None):
    """Write a 2D array to a region image where colors correspond to regions.

    Each unique finite value in ``arr`` is a separate region. It is drawn in
    its own color from seaborn's ``"pastel"`` palette and listed in a legend
    anchored below the image. Non-finite pixels (``nan``/``inf``) are drawn
    as white.

    Finite values are expected to be ``>= 1``. A value of exactly ``0``
    raises :class:`ValueError`.

    Args:
        file_path (str): File path to save image.
        arr (np.ndarray): The 2D numpy array to convert to a region image.
        plt_dict (:obj:`dict`, optional): Values to use when plotting with
            ``matplotlib.pyplot``. Recognized keys are ``"xlabel"``,
            ``"ylabel"`` and ``"title"`` (default ``""``), and ``"labels"``:
            a sequence of legend names, one per unique finite value of
            ``arr`` in ascending order of value (defaults to the values
            themselves).

    Raises:
        ValueError: If ``arr`` is not 2D, contains ``0``, or ``labels`` does
            not have exactly one entry per unique finite value.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2:
        raise ValueError("`arr` must be a 2D numpy array")

    unique_vals = np.unique(arr)
    if np.any(unique_vals == 0):
        raise ValueError("All finite values in `arr` must be >=1")
    # nan/inf are not regions; those pixels are never painted and stay white.
    unique_vals = unique_vals[np.isfinite(unique_vals)]
    num_unique_vals = len(unique_vals)

    opts = {"xlabel": "", "ylabel": "", "title": "", "labels": None}
    if plt_dict:
        opts.update(plt_dict)

    labels = opts["labels"]
    if labels is None:
        labels = list(unique_vals)
    if len(labels) != num_unique_vals:
        raise ValueError(
            "len(labels) != num_unique_vals - %d != %d" % (len(labels), num_unique_vals)
        )

    cpal = sns.color_palette("pastel", num_unique_vals)

    # Start from an all-white RGB image and paint each region in its color.
    # Comparing against the raw array means nan/inf pixels can never equal a
    # finite region value, so no sentinel substitution (and no np.max over a
    # possibly empty array) is needed.
    arr_rgb = np.ones(arr.shape + (3,))
    custom_lines = []
    for value, color in zip(unique_vals, cpal):
        arr_rgb[arr == value] = np.asarray(color)
        custom_lines.append(
            Line2D([], [], color=color, marker="o", linestyle="None", markersize=5)
        )

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.set_xlabel(opts["xlabel"])
        ax.set_ylabel(opts["ylabel"])
        ax.set_title(opts["title"])
        lgd = ax.legend(
            custom_lines,
            labels,
            loc="upper center",
            bbox_to_anchor=(0.5, -defaults.DEFAULT_TEXT_SPACING),
            fancybox=True,
            shadow=True,
            ncol=3,
        )
        ax.imshow(arr_rgb)
        # Include the out-of-axes legend in the saved bounding box.
        fig.savefig(file_path, bbox_extra_artists=(lgd,), bbox_inches="tight")
    finally:
        # Release the figure; otherwise every call leaves one open in pyplot.
        plt.close(fig)