#!/usr/bin/python3
#
from __future__ import division, print_function
from vtkplotter import Plotter, ProgressBar, printc, humansort, __version__
import vtk
import sys, argparse, os

# Command-line interface. Options declared with metavar='' take a value;
# the others are plain on/off flags.
pr = argparse.ArgumentParser(description="version "+str(__version__)+""" -                              
                             check out home page https://github.com/marcomusy/vtkplotter""")
pr.add_argument('files', nargs='*',             help="Input filename(s)")
pr.add_argument("-c", "--color", type=str,      help="mesh color [integer or color name]", default=None, metavar='')
# A default of -1 marks "not given"; the modes below only honour values in (0, 1].
pr.add_argument("-a", "--alpha",    type=float, help="alpha value [0-1]", default=-1, metavar='')
pr.add_argument("-w", "--wireframe",            help="use wireframe representation", action="store_true")
# A default of -1 marks "not given"; only positive sizes are applied.
pr.add_argument("-p", "--point-size", type=float, help="specify point size", default=-1, metavar='')
pr.add_argument("-k", "--show-scalars",         help="use scalars as colors", action="store_true")
pr.add_argument("-x", "--axes-type", type=int,  help="specify axes type [0-5]", default=4, metavar='')
pr.add_argument("-i", "--no-camera-share",      help="do not share camera in renderers", action="store_true")
pr.add_argument("-l", "--legend-off",           help="do not show legends", action="store_true")
pr.add_argument("-f", "--full-screen",          help="full screen mode", action="store_true")
pr.add_argument("-bg","--background", type=str, help="background color [integer or color name]", default='blackboard', metavar='')
pr.add_argument("-z", "--zoom", type=float,     help="zooming factor", default=1, metavar='')
# store_true: the flag must switch quiet mode ON. With the former
# store_false/default=False pair the value was False whether or not -q was given.
pr.add_argument("-q", "--quiet",                help="quiet mode, less verbose", action="store_true")
pr.add_argument("-n", "--multirenderer-mode",   help="Multi renderer Mode: files go to separate renderers", action="store_true")
pr.add_argument("-s", "--scrolling-mode",       help="Scrolling Mode: use arrows to scroll files", action="store_true")
pr.add_argument("-g", "--ray-cast-mode",        help="GPU Ray-casting Mode for 3D image files", action="store_true")
pr.add_argument("-gz", "--z-spacing", type=float, help="Volume z-spacing factor [1]", default=None, metavar='')
pr.add_argument("-gy", "--y-spacing", type=float, help="Volume y-spacing factor [1]", default=None, metavar='')
pr.add_argument("--slicer",                     help="Slicer Mode for 3D image files", action="store_true")
args = pr.parse_args()

# Order the input paths (humansort presumably sorts the list in place, since
# its return value is not used) and stop silently when none were given.
humansort(args.files)
nfiles = len(args.files)
if not nfiles:
    sys.exit()

# Window size keyword handed to the Plotter.
wsize = "full" if args.full_screen else "auto"

# Volume rendering swaps the default dark background for a white one.
if (args.ray_cast_mode or args.z_spacing or args.y_spacing) and args.background == "blackboard":
    args.background = "white"

# Axes type 4 is replaced by type 8 when scrolling.
# NOTE(review): the original remark said "types 4 and 5 are not good for
# scrolling", but the test only ever matched 4 -- confirm the intent.
if args.scrolling_mode and args.axes_type == 4:
    args.axes_type = 8

if args.multirenderer_mode:
    # One renderer per file, capped at 200.
    if nfiles > 200:
        printc("~lightning Warning: option '-n' allows a maximum of 200 files", c=1)
        printc("         you are trying to load ", nfiles, " files.\n", c=1)
    N = min(nfiles, 200)
    vp = Plotter(size=wsize, N=N, bg=args.background)
    if args.axes_type == 1:
        vp.axes = 0
else:
    N = nfiles
    vp = Plotter(size=wsize, bg=args.background)
    vp.axes = args.axes_type

vp.verbose = not args.quiet
vp.sharecam = not args.no_camera_share

# Defaults shared by the display modes further down.
alpha = 1
leg = not (args.legend_off or nfiles == 1)  # no legend for a single file
wire = args.wireframe

# Multi-renderer mode wins over scrolling mode when both are requested.
if args.multirenderer_mode:
    args.scrolling_mode = False

# Opacity values of the two alpha sliders used in ray-casting mode (defaults).
_alphaslider1 = 0.4
_alphaslider2 = 0.9

########################################################################
def _showVoxelImage():
    """Show the input volumes with GPU ray casting plus colour/opacity sliders.

    Each path in ``args.files`` is read with a VTK reader chosen from its
    name (TIFF, SLC or VTI). The image spacing is optionally stretched by
    ``args.z_spacing`` / ``args.y_spacing``, then the image is wrapped in a
    ``Volume`` driven by a maximum-intensity ``vtkGPUVolumeRayCastMapper``
    and added to the module-level plotter ``vp``. A missing file terminates
    the program; unsupported or unreadable files are skipped with a message.

    Three 2D sliders are then created. Their callbacks close over the
    transfer functions and scalar range of the last volume that was loaded,
    so with several inputs only that volume reacts to them. The two opacity
    sliders store their values in the module globals ``_alphaslider1`` and
    ``_alphaslider2``.

    The function returns once the interactive ``vp.show`` call returns; if
    no volume could be loaded it returns right after reporting that.
    """
    from vtkplotter import Volume
    import numpy as np

    printc("GPU Ray-casting Mode:", c="b", invert=1)
    vp.show(interactive=0)

    nloaded = 0
    for filename in args.files:
        if not os.path.isfile(filename):
            printc("~times File not found:", filename, c=1)
            sys.exit()
        printc("...loading", filename, c="b", bold=0)

        # Pick the reader from the (case-insensitive) file name.
        fl = filename.lower()
        if ".tif" in fl:
            reader = vtk.vtkTIFFReader()
        elif ".slc" in fl:
            reader = vtk.vtkSLCReader()
            if not reader.CanReadFile(filename):
                printc("Bad SLC file:", filename, c=1)
                continue
        elif ".vti" in fl:
            reader = vtk.vtkXMLImageDataReader()
        else:
            printc("~noentry Use -g option only with VTI, SLC or TIFF files.", c=1)
            continue
        reader.SetFileName(filename)
        reader.Update()
        img = reader.GetOutput()

        # Optional anisotropic rescaling of the voxel spacing.
        if args.z_spacing:
            ispa = img.GetSpacing()
            img.SetSpacing(ispa[0], ispa[1], ispa[2] * args.z_spacing)
        if args.y_spacing:
            ispa = img.GetSpacing()
            img.SetSpacing(ispa[0], ispa[1] * args.y_spacing, ispa[2])

        volumeMapper = vtk.vtkGPUVolumeRayCastMapper()
        volumeMapper.SetBlendModeToMaximumIntensity()
        volumeMapper.UseJitteringOn()
        volumeMapper.SetInputConnection(reader.GetOutputPort())
        printc("scalar range is ", img.GetScalarRange(), c="b", bold=0)
        smin, smax = img.GetScalarRange()
        if smax > 1e10:
            # Implausibly large maximum: fall back to a range derived from smin.
            print("Warning, high scalar range detected:", smax)
            smax = abs(10 * smin) + 0.1
            print("         reset to:", smax)
        printc("voxel spacing is", img.GetSpacing(), c="b", bold=0)

        # Transfer function mapping scalar value to colour. The map is
        # fixed (red -> dark green -> blue); args.color is not applied here.
        colorTransferFunction = vtk.vtkColorTransferFunction()
        colorTransferFunction.AddRGBPoint(smin, 1, 0, 0)
        colorTransferFunction.AddRGBPoint((smin+smax)/3.0, 0.01, .33, 0)
        colorTransferFunction.AddRGBPoint(smax, 0, 0, 1)

        # Scalar positions whose opacity is driven by the two alpha sliders.
        x1alpha = smin + (smax - smin) * 0.33
        x2alpha = smin + (smax - smin) * 0.66
        # Transfer function mapping scalar value to opacity.
        opacityTransferFunction = vtk.vtkPiecewiseFunction()
        opacityTransferFunction.AddPoint(smin, 0.0)
        opacityTransferFunction.AddPoint(smin + (smax - smin) * 0.1, 0.0)
        opacityTransferFunction.AddPoint(x1alpha, _alphaslider1)
        opacityTransferFunction.AddPoint(x2alpha, _alphaslider2)
        opacityTransferFunction.AddPoint(smax, 1.0)

        # The property describes how the data will look.
        volumeProperty = vtk.vtkVolumeProperty()
        volumeProperty.SetColor(colorTransferFunction)
        volumeProperty.SetScalarOpacity(opacityTransferFunction)
        volumeProperty.SetInterpolationTypeToLinear()

        # The volume holds the mapper and the property.
        volume = Volume(img)
        volume.SetMapper(volumeMapper)
        volume.SetProperty(volumeProperty)

        vp.add(volume)
        nloaded += 1

    if not nloaded:
        # Every input was skipped: the sliders below would have nothing to
        # act on (smin, smax and the transfer functions were never created).
        printc("~noentry No volume could be loaded, nothing to show.", c=1)
        return

    ############################## color intensity slider
    def sliderColor(widget, event):
        # Move the middle colour point to the slider value; its colour is a
        # blend of three gaussians centred at x = 0, 0.5 and 1.
        value = widget.GetRepresentation().GetValue()
        colorTransferFunction.RemoveAllPoints()
        span = smax - smin
        # Normalise to [0, 1] over the scalar range (the divisor used to be
        # smax+smin, which is only right when smin == 0).
        x = (value - smin) / span if span else 0.0
        s = 40  # sharpness
        r = np.exp(-x*x*s)
        g = np.exp(-(x-.5)*(x-.5)*s)
        b = np.exp(-(x-1)*(x-1)*s)

        colorTransferFunction.AddRGBPoint(smin,  1, 0, 0)
        colorTransferFunction.AddRGBPoint(value, r, g, b)
        colorTransferFunction.AddRGBPoint(smax,  0, 0, 1)

    w1 = vp.addSlider2D(sliderColor, smin, smax, value=(smin+smax)/3.0,
                        pos=4, title='color brightness')

    ############################## alpha sliders
    def _refreshOpacity():
        # Rebuild the opacity ramp from the current slider values.
        opacityTransferFunction.RemoveAllPoints()
        opacityTransferFunction.AddPoint(smin, 0.0)
        opacityTransferFunction.AddPoint(smin + (smax - smin) * 0.1, 0.0)
        opacityTransferFunction.AddPoint(x1alpha, _alphaslider1)
        opacityTransferFunction.AddPoint(x2alpha, _alphaslider2)
        opacityTransferFunction.AddPoint(smax, 1.0)

    def sliderA1(widget, event):
        global _alphaslider1
        _alphaslider1 = widget.GetRepresentation().GetValue()
        _refreshOpacity()

    w2 = vp.addSlider2D(
        sliderA1,
        0,
        1,
        value=_alphaslider1,
        pos=[(0.93, 0.1), (0.93, 0.44)],
        title="",
        showValue=0,
    )

    def sliderA2(widget, event):
        global _alphaslider2
        _alphaslider2 = widget.GetRepresentation().GetValue()
        _refreshOpacity()

    w3 = vp.addSlider2D(
        sliderA2,
        0,
        1,
        value=_alphaslider2,
        pos=[(0.96, 0.1), (0.96, 0.44)],
        title="alpha thresholds",
        showValue=0,
    )

    def CheckAbort(obj, event):
        # Abort the render in progress as soon as another event is pending.
        if obj.GetEventPending() != 0:
            obj.SetAbortRender(1)

    vp.window.AddObserver("AbortCheckEvent", CheckAbort)

    printc("Press r to reset camera", c="b")
    printc("      q to exit.", c="b")
    vp.show(zoom=1.2, interactive=1)
    # Interaction is over: switch the slider widgets off.
    w1.SetEnabled(0)
    w2.SetEnabled(0)
    w3.SetEnabled(0)


##########################################################
# special case of SLC/TIFF volumes with -g option
if args.ray_cast_mode or args.z_spacing or args.y_spacing:
    # Axes types 1, 2 and 3 are replaced by type 4 in this mode.
    if args.axes_type in (1, 2, 3):
        vp.axes = 4
    # (A recomputation of wsize used to sit here; the plotter already exists
    # and nothing reads wsize afterwards, so it was dropped.)
    _showVoxelImage()
    # sys.exit() rather than exit(): the latter is only installed by `site`.
    sys.exit()

##########################################################
# special case of SLC/TIFF volumes with --slicer option
elif args.slicer:

    filename = args.files[0]
    if not os.path.isfile(filename):
        printc("~times File not found:", filename, c=1)
        exit()
    printc("...loading", filename, c="m", bold=0)

    if ".tif" in filename.lower():
        reader = vtk.vtkTIFFReader()

    elif ".slc" in filename.lower():
        reader = vtk.vtkSLCReader()
        if not reader.CanReadFile(filename):
            printc("~bomb Bad SLC file:", filename, c=1)

    elif ".vti" in filename.lower():
        reader = vtk.vtkXMLImageDataReader()

    else:
        printc("~sad Use --slicer option only with voxel data files [slc,tif,vti].", c=1)
    reader.SetFileName(filename)
    reader.Update()
    img = reader.GetOutput()

    ren1 = vtk.vtkRenderer()
    renWin = vtk.vtkRenderWindow()
    renWin.AddRenderer(ren1)
    iren = vtk.vtkRenderWindowInteractor()
    iren.SetRenderWindow(renWin)

    im = vtk.vtkImageResliceMapper()
    im.SetInputData(img)
    im.SliceFacesCameraOn()
    im.SliceAtFocalPointOn()
    im.BorderOn()

    ip = vtk.vtkImageProperty()
    ip.SetInterpolationTypeToLinear()

    ia = vtk.vtkImageSlice()
    ia.SetMapper(im)
    ia.SetProperty(ip)

    ren1.AddViewProp(ia)
    ren1.SetBackground(0.6, 0.6, 0.7)
    renWin.SetSize(900, 900)

    iren = vtk.vtkRenderWindowInteractor()
    style = vtk.vtkInteractorStyleImage()
    style.SetInteractionModeToImage3D()
    iren.SetInteractorStyle(style)
    renWin.SetInteractor(iren)

    renWin.Render()
    cam1 = ren1.GetActiveCamera()
    cam1.ParallelProjectionOn()
    ren1.ResetCameraClippingRange()
    cam1.Zoom(1.3)
    renWin.Render()

    printc("Slicer Mode:", invert=1, c="m")
    printc(
        """Press  SHIFT+Left mouse    to rotate the camera for oblique slicing
       SHIFT+Middle mouse  to slice perpendicularly through the image
       Left mouse and Drag to modify luminosity and contrast
       X                   to Reset to sagittal view
       Y                   to Reset to coronal view
       Z                   to Reset to axial view
       R                   to Reset the Window/Levels
       Q                   to Quit.""",
        c="m",
    )

    iren.Start()
    exit()


########################################################################
# normal mode for single VOXEL file with Isosurface Slider
elif nfiles == 1 and (
    ".slc" in args.files[0] or ".vti" in args.files[0] or ".tif" in args.files[0]
):
    from vtkplotter.vtkio import loadImageData
    from vtkplotter import Actor

    image = loadImageData(args.files[0])
    scrange = image.GetScalarRange()
    cf = vtk.vtkContourFilter()
    cf.SetInputData(image)
    cf.UseScalarTreeOn()
    cf.ComputeScalarsOff()
    ic = "gold"
    if args.color is not None:
        if args.color.isdigit():
            ic = int(args.color)
        else:
            ic = args.color
    if args.show_scalars:
        ic = None

    threshold = (2 * scrange[0] + scrange[1]) / 3.0
    cf.SetValue(0, threshold)
    cf.Update()
    act = Actor(cf.GetOutput(), ic, alpha=abs(args.alpha), wire=args.wireframe)
    act.phong()

    ############################## threshold slider
    def sliderThres(widget, event):
        cf.SetValue(0, widget.GetRepresentation().GetValue())
        cf.Update()
        poly = cf.GetOutput()
        a = Actor(poly, ic, alpha=act.alpha(), wire=args.wireframe)
        a.phong()
        vp.actors = []
        vp.renderer.RemoveActor(vp.getActors()[0])
        vp.renderer.AddActor(a)
        vp.renderer.Render()

    dr = scrange[1] - scrange[0]
    vp.addSlider2D(
        sliderThres,
        scrange[0] + 0.025 * dr,
        scrange[1] - 0.025 * dr,
        value=threshold,
        title="isosurface threshold",
    )

    def CheckAbort(obj, event):
        if obj.GetEventPending() != 0:
            obj.SetAbortRender(1)

    vp.window.AddObserver("AbortCheckEvent", CheckAbort)

    act.legend(leg)
    vp.show(act, zoom=args.zoom)


########################################################################
# NORMAL mode for single or multiple files, or multiren mode 
elif (not args.scrolling_mode) or nfiles == 1:

    actors = []
    for i in range(N):
        f = args.files[i]

        if ".neutral" in f.lower() or ".xml" in f.lower() or ".gmsh" in f.lower():
            alpha = 0.05
            wire = True
        else:
            wire = False

        if 0 < args.alpha <= 1:
            alpha = args.alpha

        if args.wireframe:
            wire = True

        ic = i  # default here
        if args.color is not None:
            if args.color.isdigit():
                ic = int(args.color)
            else:
                ic = args.color

        if args.show_scalars:
            ic = None

        actor = vp.load(f, c=ic, alpha=alpha, wire=wire)

        if leg:
            actor.legend(os.path.basename(f))

        actors.append(actor)

        if args.point_size > 0:
            try:
                ps = actor.GetProperty().GetPointSize()
                actor.GetProperty().SetPointSize(args.point_size)
                actor.GetProperty().SetRepresentationToPoints()
            except AttributeError:
                print("cannot set point size for", f)

        if args.multirenderer_mode:
            actor._legend = None
            vp.show(actor, at=i, interactive=False, zoom=args.zoom)
            vp.actors = actors

    if args.multirenderer_mode:
        vp.interactor.Start()
    else:
        vp.show(interactive=True, zoom=args.zoom)

########################################################################
# scrolling mode  -s
else:
    import numpy

    if 0 < args.alpha <= 1:
        alpha = args.alpha

    n = len(args.files)
    pb = ProgressBar(0, n)

    actors = []
    for i, f in enumerate(args.files):
        pb.print("..loading")
        ic = "gold"
        if args.color is not None:
            if args.color.isdigit():
                ic = int(args.color)
            else:
                ic = args.color
        if args.show_scalars:
            ic = None

        actor = vp.load(f, c=ic, alpha=alpha, wire=wire)
        actor.legend(leg)
        actors.append(actor)
        if args.point_size > 0:
            try:
                ps = actor.GetProperty().GetPointSize()
                actor.GetProperty().SetPointSize(args.point_size)
                actor.GetProperty().SetRepresentationToPoints()
            except AttributeError:
                print("cannot set point size for", f)

    # calculate max actors bounds
    bns = []
    for a in actors:
        if a and a.GetPickable():
            b = a.GetBounds()
            if b:
                bns.append(b)
    if len(bns):
        max_bns = numpy.max(bns, axis=0)
        min_bns = numpy.min(bns, axis=0)
        vbb = (min_bns[0], max_bns[1], min_bns[2], max_bns[3], min_bns[4], max_bns[5])

    def _scroll(iren, event):  # observer
        global _kact
        actors[_kact].off()

        key = iren.GetKeySym()

        n = len(actors)
        if key == "Right" and _kact < n - 1:
            _kact += 1
        elif key == "Left" and _kact > 0:
            _kact -= 1

        step = int(n / 5)
        if key == "Up":
            _kact += step
            if _kact > n - 1:
                _kact = n - 1
        elif key == "Down":
            _kact -= step
            if _kact < 0:
                _kact = 0
        sliderRep.SetValue(_kact)
        if 0 <= _kact < n:
            vp.clickedActor = actors[_kact]

        actors[_kact].on()
        fn = args.files[_kact].split("/")[-1]
        printc("showing file #", _kact, "\t", fn, "\r", c="y", bold=0, end="")

    vp.interactor.AddObserver("KeyPressEvent", _scroll)

    vp.show(actors[0], interactive=False, zoom=args.zoom)

    if isinstance(vp.axes_exist[0], vtk.vtkCubeAxesActor):
        vp.axes_exist[0].SetBounds(vbb)

    cb = (1, 1, 1)
    if numpy.sum(vp.renderer.GetBackground()) > 1.5:
        cb = (0.1, 0.1, 0.1)

    def sliderf(widget, event):
        global _kact
        actors[_kact].off()
        _kact = int(widget.GetRepresentation().GetValue())
        actors[_kact].on()
        fn = args.files[_kact].split("/")[-1]
        printc("showing file #", _kact, "\t", fn, "\r", c="y", bold=0, end="")

    wid = vp.addSlider2D(sliderf, 0, n - 1, pos=4, c=cb, showValue=False)
    wid.SetAnimationModeToJump()
    sliderRep = wid.GetRepresentation()

    _kact = 0
    for a in actors:
        a.off()
    actors[0].on()

    printc("Scrolling Mode", c="y", invert=1, end="")
    printc(", press:", c="y")
    printc("- Right and Left keys to move through files", c="y")
    printc("- Up and Down keys to fast forward and backward", c="y")
    vp.show(actors, interactive=True, zoom=args.zoom)
    print()
