Source code for mct.reconstruction

import collections
import collections.abc
import logging
import time
import timeit
from contextlib import contextmanager

import numpy as np

import mdt
import mot
from mct.__version__ import VERSION
from mct.utils import UnzippedNiftis


# Module authorship, contact and licensing metadata.
__author__ = 'Robbert Harms'
__date__ = '2017-09-09'
__maintainer__ = 'Robbert Harms'
__email__ = 'robbert.harms@maastrichtuniversity.nl'
__licence__ = 'LGPL v3'


class ReconstructionMethod:
    """Interface for methods that reconstruct volume(s) out of multiple channels.

    This base class carries no state; subclasses are expected to override :meth:`reconstruct`,
    which raises ``NotImplementedError`` here.
    """

    # Informational text describing the method; this base class only provides a placeholder.
    command_line_info = ''' No info defined for this method. '''

    def reconstruct(self, output_directory, volumes=None):
        """Reconstruct the given channels and store the result in a subdirectory of the given directory.

        Args:
            output_directory (str): the location for the output files
            volumes (list of int): the (0-based) indices of the volumes we want to reconstruct.

        Returns:
            dict: the set of results from this reconstruction method

        Raises:
            NotImplementedError: always, since this base class provides no implementation.
        """
        raise NotImplementedError()
class SliceBySliceReconstructionMethod(ReconstructionMethod):
    """Reconstruction method that processes the input data one slice at a time.

    For every index along the slicing axis the data of all channels is loaded, handed to
    :meth:`_reconstruct_slice` (to be implemented by subclasses) and the per-slice results are
    stacked back together along the slicing axis.
    """

    def __init__(self, channels, **kwargs):
        """Create a basic reconstruction method initialized with the given data and settings.

        Args:
            channels (list): the list of input nifti files, one for each channel element. Every nifti file
                should be a 4d matrix with on the 4th dimension the time series. The length of this list
                equals the number of channels.
            slicing_axis (int): the (x,y,z) axis over which we will loop to reconstruct the volumes.
                0=x, 1=y, 2=z. Defaults to 2.
        """
        super().__init__()
        self._channels = channels
        self._logger = logging.getLogger(__name__)
        # Results are written to a subdirectory named after the concrete class.
        self._output_subdir = self.__class__.__name__
        self._slicing_axis = kwargs.get('slicing_axis', 2)

    def reconstruct(self, output_directory, volumes=None):
        """Reconstruct the channels and write the result(s) as nifti file(s).

        If the reconstruction yields a mapping, every item is written to ``<name>.nii.gz``; otherwise the
        single result is written to ``reconstruction.nii.gz``. All files are placed in a subdirectory
        (named after this class) of the output directory and use the header of the first input nifti.

        Args:
            output_directory (str): the location for the output files
            volumes (list of int or str): the (0-based) indices of the volumes to reconstruct, the string
                'odd' or 'even', or None to use all volumes.

        Returns:
            None: this implementation only writes the results to disk.
        """
        output_subdir = output_directory + '/' + self._output_subdir
        niftis = UnzippedNiftis(self._channels, output_subdir)
        combined = self._reconstruct(niftis, volumes=volumes)

        # collections.abc.Mapping: the alias collections.Mapping no longer exists in Python >= 3.10.
        if isinstance(combined, collections.abc.Mapping):
            for name, data in combined.items():
                mdt.write_nifti(data, output_subdir + '/{}.nii.gz'.format(name), niftis[0].get_header())
        else:
            mdt.write_nifti(combined, output_subdir + '/reconstruction.nii.gz', niftis[0].get_header())

    def _reconstruct(self, input_niftis, volumes=None):
        """Reconstruct all slices and stack the per-slice results along the slicing axis.

        Args:
            input_niftis (list of nifti): the list of nifti file objects, one per channel
            volumes (list of int or str): the volume indices to use, 'odd' (indices 0, 2, 4, ...),
                'even' (indices 1, 3, 5, ...) or None for all volumes.

        Returns:
            ndarray or dict: the stacked reconstruction, or, if the slice reconstruction returns mappings,
                a dictionary with one stacked array per key.
        """
        nifti_shape = input_niftis[0].shape
        # Loop over the configured slicing axis, in agreement with _get_slice_all_channels and the stacking.
        nmr_slices = nifti_shape[self._slicing_axis]

        self._logger.info('Using MCT version {}'.format(VERSION))
        self._logger.info('Using MOT version {}'.format(mot.__version__))

        # Only compare against the keywords for real strings; comparing an ndarray to a str is element-wise.
        keyword = volumes if isinstance(volumes, str) else None

        if volumes is None:
            volumes = list(range(nifti_shape[3]))
            self._logger.info('Reconstructing all volumes')
        elif keyword == 'odd':
            volumes = list(range(0, nifti_shape[3], 2))
            self._logger.info('Reconstructing all odd volumes {}'.format(list(volumes)))
        elif keyword == 'even':
            volumes = list(range(1, nifti_shape[3], 2))
            self._logger.info('Reconstructing all even volumes {}'.format(list(volumes)))
        else:
            self._logger.info('Reconstructing using the specified volumes {}'.format(list(volumes)))

        slice_results = []
        start_time = timeit.default_timer()

        for slice_index in range(nmr_slices):
            self._logger.info(self._get_batch_start_message(start_time, slice_index, nmr_slices))
            slice_data = self._get_slice_all_channels(input_niftis, slice_index, volumes)

            if slice_index == 0:
                # The first slice is processed with the regular 'mot' log output ...
                slice_results.append(self._reconstruct_slice(slice_data, slice_index))
            else:
                # ... all later slices with the 'mot' handlers temporarily raised to WARNING.
                with self._with_logging_to_debug():
                    slice_results.append(self._reconstruct_slice(slice_data, slice_index))

        self._logger.info('Computed all slices, now assembling results.')

        if isinstance(slice_results[0], collections.abc.Mapping):
            return {key: np.stack([el[key] for el in slice_results], axis=self._slicing_axis)
                    for key in slice_results[0]}
        return np.stack(slice_results, axis=self._slicing_axis)

    @contextmanager
    def _with_logging_to_debug(self):
        """Context manager that temporarily raises the level of the 'mot' logger handlers to WARNING.

        The handler levels that were active on entry are restored on exit, also when the body raises.
        """
        handlers = list(logging.getLogger('mot').handlers)
        previous_levels = [handler.level for handler in handlers]

        for handler in handlers:
            handler.setLevel(logging.WARNING)
        try:
            yield
        finally:
            for handler, level in zip(handlers, previous_levels):
                handler.setLevel(level)

    def _get_batch_start_message(self, start_time, slice_index, total_nmr_slices):
        """Create the progress message that is logged at the start of every slice.

        Args:
            start_time (float): the ``timeit.default_timer()`` value at which the processing started
            slice_index (int): the index of the slice we are about to process
            total_nmr_slices (int): the total number of slices

        Returns:
            str: a message with the progress percentage, the time spent and the estimated time left,
                both formatted as days:hours:minutes:seconds.
        """
        seconds_per_day = 24 * 60 * 60

        def calculate_run_days(runtime):
            # Parenthesised on purpose: ``runtime // 24 * 60 * 60`` would parse as ((runtime // 24) * 60) * 60.
            if runtime > seconds_per_day:
                return int(runtime // seconds_per_day)
            return 0

        def format_duration(seconds):
            # gmtime wraps around after 24 hours, the days are therefore prepended separately.
            return str(calculate_run_days(seconds)) + ':' + time.strftime('%H:%M:%S', time.gmtime(seconds))

        run_time = timeit.default_timer() - start_time
        current_percentage = slice_index / float(total_nmr_slices)

        if current_percentage > 0:
            # Linear extrapolation of the time spent so far.
            remaining_time = (run_time / current_percentage) - run_time
        else:
            remaining_time = None

        run_time_str = format_duration(run_time)
        remaining_time_str = format_duration(remaining_time) if remaining_time is not None else '?'

        return ('Processing slice {0} of {1}, we are at {2:.2%}, time spent: {3}, time left: {4} (d:h:m:s).'.
                format(slice_index, total_nmr_slices, current_percentage, run_time_str, remaining_time_str))

    def _reconstruct_slice(self, slice_data, slice_index):
        """Reconstruct the given slice.

        Args:
            slice_data (ndarray): a 4d array with the first two dimensions the remaining voxel locations and
                third the volumes and finally the channels
            slice_index (int): the slice index

        Returns:
            ndarray: the reconstruction of the given volumes for the given slice. This should be
                an array of shape (dim1, dim2, nmr_volumes) where dim1 and dim2 should be the same dimensions
                as the input (typically x and y if we are looping over z) and nmr_volumes equals the number
                of volumes.

        Raises:
            NotImplementedError: always, subclasses have to provide the implementation.
        """
        raise NotImplementedError()

    def _get_slice_all_channels(self, input_niftis, slice_index, volumes):
        """Get the requested slice over each of the input niftis.

        Args:
            input_niftis (list of nifti): the list of nifti file objects
            slice_index (int): the slice index we want
            volumes (list of int): the volumes to use for the reconstruction

        Returns:
            ndarray: a 4d array with the first two dimensions the remaining voxel locations and then the
                timeseries and then the channels
        """
        position = int(slice_index)

        # Any slicing axis other than 0 or 1 is treated as the z-axis.
        if self._slicing_axis == 0:
            slices = [nifti.dataobj[position, :, :][..., volumes] for nifti in input_niftis]
        elif self._slicing_axis == 1:
            slices = [nifti.dataobj[:, position, :][..., volumes] for nifti in input_niftis]
        else:
            slices = [nifti.dataobj[:, :, position][..., volumes] for nifti in input_niftis]

        # The channels end up on a new, last axis.
        return np.stack(slices, axis=-1)