#!/usr/bin/env python3
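'''
This script generates liver and tumor features from a pre-contrast and an
arterial-phase NIfTI image plus raw liver and tumor mask files. It can
alternatively save a single slice of the enhancement image as a PNG.
'''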

import nibabel, argparse, tabulate, matplotlib, sys
import numpy as np
import matplotlib.pyplot as plt
from skimage import measure
from mpl_toolkits.mplot3d import Axes3D
from pylab import get_cmap

# NIfTI-1 packs the spatial and temporal unit codes into the single
# 'xyzt_units' header byte: the low three bits hold the spatial unit code,
# and spatial code 2 means millimeters.
_NIFTI_SPATIAL_UNITS_MASK = 0x07
_NIFTI_UNITS_MM = 2


def _load_binary_mask(mask_file, shape):
    '''
    Read a raw uint8 mask file and return it as a 0/1 array of the given shape.

    The file content is interpreted as one byte per voxel in Fortran
    (first-axis-fastest) order, the second axis is reversed, and every
    non-zero value is set to 1.

    Raises ValueError if the number of bytes does not match `shape`.
    '''
    mask = np.fromfile(mask_file, dtype='uint8')
    expected_size = int(np.prod(shape))
    if mask.size != expected_size:
        raise ValueError('Mask file %s holds %d voxels but the image shape %s needs %d.'
                         % (mask_file, mask.size, tuple(shape), expected_size))
    mask = np.reshape(mask, shape, order='F')
    # NOTE(review): the second axis is reversed, presumably to line the mask
    # up with the canonicalized images -- confirm against the mask producer.
    mask = mask[:, ::-1, :]
    mask[mask > 0] = 1
    return mask


def _max_pairwise_distance(points):
    '''
    Return the largest Euclidean distance between any two rows of `points`.

    Returns 0.0 when there are fewer than two points.
    '''
    max_squared = 0.0
    for i in range(len(points) - 1):
        # Distances from point i to all later points, computed in one
        # vectorized step instead of a Python-level inner loop.
        offsets = points[i + 1:] - points[i]
        squared = np.max(np.sum(offsets * offsets, axis=1))
        if squared > max_squared:
            max_squared = squared
    return np.sqrt(max_squared)


def _save_slice_image(volume, slice_index, norm, filename):
    '''
    Save slice `slice_index` (along the third axis) of `volume` as a
    borderless grayscale image at `filename`, using `norm` for intensities.
    '''
    image_array = volume[:, :, slice_index]
    aspect = float(image_array.shape[1]) / image_array.shape[0]
    width = 20
    height = int(aspect * width)

    image_array = np.rot90(image_array)

    fig = plt.figure(frameon=False)
    fig.set_size_inches(width, height)
    # A single axes covering the whole figure, with the axis decorations off.
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(image_array, interpolation='bilinear', norm=norm, cmap=plt.cm.gray, aspect='auto')

    fig.savefig(filename)
    plt.close(fig)  # Release the figure once it has been written.


def create_features(name, precontrast, arterial, liver_mask_file, tumor_mask_file, format_output, header_output, seg, slice_index):
    '''
    Print a row of liver/tumor features, or save one slice as a PNG and exit.

    The enhancement image is the arterial image minus the pre-contrast image,
    clipped at zero and rescaled to the 0-255 uint8 range.

    Parameters:
        name: label printed in the first column and used as the PNG prefix.
        precontrast: path of the pre-contrast NIfTI image.
        arterial: path of the arterial-phase NIfTI image.
        liver_mask_file: path of the raw uint8 liver mask.
        tumor_mask_file: path of the raw uint8 tumor mask.
        format_output: 'tabulate' for a tabulate table; anything else gives
            tab-separated values.
        header_output: whether to print the header row as well.
        seg: 'whole', 'liver' or 'tumor' to save a slice image instead of
            printing features; any other value prints features.
        slice_index: index along the third axis of the slice to save
            (only used when `seg` selects an image).

    Volumes are voxel counts multiplied by the pre-contrast voxel size, and
    are divided by 1000 (cubic millimeters to cubic centimeters) when the
    header reports millimeters or the raw liver volume exceeds 10000.

    Raises ValueError if the two images are not 3-D volumes of equal shape,
    if a mask does not match that shape, or if `seg` selects an image and
    `slice_index` is None.

    When `seg` selects an image, the function calls sys.exit(0) after saving
    it, so no features are printed.
    '''
    # Canonicalize the orientation.
    pre = nibabel.as_closest_canonical(nibabel.load(precontrast))
    art = nibabel.as_closest_canonical(nibabel.load(arterial))

    pre_dimx, pre_dimy, pre_dimz = pre.header['pixdim'][1:4]
    spatial_units = int(pre.header['xyzt_units']) & _NIFTI_SPATIAL_UNITS_MASK

    # Convert to float64 before subtracting, so the difference cannot wrap
    # around if the images are stored as unsigned integers.
    pre_data = np.asarray(pre.dataobj).astype(dtype='float64')
    art_data = np.asarray(art.dataobj).astype(dtype='float64')

    if pre_data.ndim != 3 or pre_data.shape != art_data.shape:
        raise ValueError('The pre-contrast %s and arterial %s images must be 3-D volumes of the same shape.'
                         % (pre_data.shape, art_data.shape))

    # NOTE(review): the first axis of both volumes is reversed, presumably to
    # match the orientation of the mask files -- confirm.
    pre_data = pre_data[::-1, :, :]
    art_data = art_data[::-1, :, :]

    diff = art_data - pre_data
    diff[diff < 0] = 0.0  # The pre-contrast should never be more than the arterial.
    max_diff = np.amax(diff)
    if max_diff > 0:
        diff = (255 * (diff / max_diff)).astype(dtype='uint8')  # Scale to 0-255 integer range.
    else:
        # No enhancement anywhere: avoid dividing by zero.
        diff = np.zeros(diff.shape, dtype='uint8')

    liver = _load_binary_mask(liver_mask_file, diff.shape)
    tumor = _load_binary_mask(tumor_mask_file, diff.shape)

    just_tumor = np.copy(diff)
    just_tumor[tumor <= 0] = 0

    if seg in ('whole', 'liver', 'tumor'):
        if slice_index is None:
            raise ValueError('A slice index is required when seg is %r.' % seg)

        if seg == 'whole':
            volume = diff
        elif seg == 'liver':
            volume = np.copy(diff)
            volume[liver <= 0] = 0
        else:
            volume = just_tumor

        cnorm = matplotlib.colors.Normalize(vmin=0, vmax=np.amax(diff))
        segment_image_filename = name + '_' + str(slice_index) + '_' + seg + '.png'
        _save_slice_image(volume, slice_index, cnorm, segment_image_filename)
        print('Segmentation saved as %s' % segment_image_filename)
        # Segmentation mode replaces feature generation entirely.
        sys.exit(0)

    # Features.
    def voxel_count_to_volume(voxel_count, convert_to_cc):
        volume = voxel_count * pre_dimx * pre_dimy * pre_dimz
        if convert_to_cc:
            volume /= 1000
        return volume

    raw_liver_volume = voxel_count_to_volume(np.sum(liver), False)
    # Millimeters are detected from the header; the size threshold is a
    # fallback for headers that do not declare their units.
    is_in_mm = spatial_units == _NIFTI_UNITS_MM or raw_liver_volume > 10000

    liver_volume = voxel_count_to_volume(np.sum(liver), is_in_mm)
    tumor_volume = voxel_count_to_volume(np.sum(tumor), is_in_mm)

    tumor_intensities = diff[tumor > 0]
    liver_intensities = diff[liver > 0]

    mean_tumor_intensity = np.mean(tumor_intensities)
    mean_liver_intensity = np.mean(liver_intensities)

    std_tumor_intensity = np.std(tumor_intensities)
    std_liver_intensity = np.std(liver_intensities)

    # Tumor voxels that enhance more than the average liver voxel.
    enhancing_voxels = np.sum(tumor[just_tumor > mean_liver_intensity])
    enhancing_tumor_volume = voxel_count_to_volume(enhancing_voxels, is_in_mm)

    tumor_contour_counts = []
    # Measured in find_contours coordinates; it is not scaled by the voxel size.
    max_tumor_contour_diameter = 0
    for z_index in range(tumor.shape[2]):
        tumor_slice = tumor[:, :, z_index]
        if np.sum(tumor_slice) < 1:
            continue  # Only slices that contain tumor contribute.

        contours = measure.find_contours(tumor_slice, 0.8)
        tumor_contour_counts.append(len(contours))
        for contour in contours:
            contour_diameter = _max_pairwise_distance(contour)
            if contour_diameter > max_tumor_contour_diameter:
                max_tumor_contour_diameter = contour_diameter

    median_tumor_contour_count = np.median(tumor_contour_counts)

    headers = ['Name', 'Liver volume', 'Mean liver intensity', 'STD liver intensity', 'Tumor volume', 'Mean tumor intensity', 'STD tumor intensity', 'Enhancing tumor volume', 'Median tumor contour count', 'Max tumor contour diameter']
    features = [name, liver_volume, mean_liver_intensity, std_liver_intensity, tumor_volume, mean_tumor_intensity, std_tumor_intensity, enhancing_tumor_volume, median_tumor_contour_count, max_tumor_contour_diameter]
    if format_output == 'tabulate':
        if header_output:
            results = tabulate.tabulate([features], headers=headers)
        else:
            results = tabulate.tabulate([features])
    else:
        features_row = '\t'.join(str(feature) for feature in features)
        if header_output:
            results = '\t'.join(headers) + '\n' + features_row
        else:
            results = features_row

    print(results)

def parse_arguments(argv=None):
    '''
    Parse and validate the command-line arguments.

    Parameters:
        argv: list of argument strings; None makes argparse read sys.argv[1:].

    Returns the argparse namespace. Exits with a usage error if -seg/--seg
    is given without -slice/--slice.
    '''
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument('-n', '--name', type=str, required=True, help='The patient name to prepend to the features row.')
    argument_parser.add_argument('-d', '--header_output', action='store_true', required=False, help='Whether or not to output the header row.')
    argument_parser.add_argument('-p', '--precontrast', type=str, required=True, help='The input precontrast image in Nifti (.nii or .nii.gz) format.')
    argument_parser.add_argument('-a', '--arterial', type=str, required=True, help='The input arterial_phase image in Nifti (.nii or .nii.gz) format.')
    argument_parser.add_argument('-l', '--liver_mask_file', type=str, required=True, help='The liver mask file.')
    argument_parser.add_argument('-t', '--tumor_mask_file', type=str, required=True, help='The tumor mask file.')
    argument_parser.add_argument('-f', '--format_output', type=str, choices=['tab', 'tabulate'], required=True, help='The format of the output.')
    argument_parser.add_argument('-seg', '--seg', type=str, choices=['whole', 'liver', 'tumor'], required=False, help='Do not generate features. Segment liver or tumor instead.')
    argument_parser.add_argument('-slice', '--slice', type=int, required=False, help='The slice to segment')
    arguments = argument_parser.parse_args(argv)

    # Compare against None explicitly: slice index 0 is a valid choice and
    # must not be mistaken for a missing argument.
    if arguments.seg and arguments.slice is None:
        argument_parser.error('If the -seg/--seg flag is given, the -slice/--slice parameter must be provided.')

    return arguments


def main(argv=None):
    '''Parse the command line and run create_features with the result.'''
    arguments = parse_arguments(argv)
    create_features(arguments.name, arguments.precontrast, arguments.arterial, arguments.liver_mask_file, arguments.tumor_mask_file, arguments.format_output, arguments.header_output, arguments.seg, arguments.slice)


if __name__ == "__main__":
    main()


