# This is a CamiTK python action
#
# Test frames and transformations
# This action will add, update and remove a transformation between two components.
# For an image component, you can choose to use either its data frame or its main frame.

import camitk
import numpy as np
from scipy.spatial.transform import Rotation

def init(self: camitk.Action):
    """Initialisation hook of the action; there is nothing to set up.

    Author's note on the data used with this action (open separately):
      - amos_0336/amos_0336_mesh6.obj
      - amos_0336/amos_0336.mhd
    """
    return None

def process(self: camitk.Action):
    """Exercise frame and transformation handling on the first mesh and image.

    The sequence is:
      1. put the mesh in the image data frame and check that it took effect;
      2. give the mesh a brand new frame of reference ("new mesh frame");
      3. add a transformation from the mesh frame to the image data frame,
         update it twice with random transforms (once addressed by its
         from/to frames, once through the transformation object), then
         reset it to the identity matrix;
      4. remove the transformation again.

    Returns:
        True when both a mesh and an image component were found and the
        sequence ran to completion, False when either one is missing.
    """
    # first mesh and first image among the opened components (None if absent)
    mesh, image = getComponents()
    if mesh is None or image is None:
        # both components are needed to run the sequence
        return False

    print(f"found {mesh.getName()} and {image.getName()}")

    # step 1: move the mesh to the image data frame
    mesh.setFrame(image.getDataFrame())
    assert mesh.getFrame() == image.getDataFrame(), f"mesh frame should be equals to image data frame {image.getDataFrame().getName()} not {mesh.getFrame().getName()}"
    self.refreshApplication()

    # step 2: back to the original settings, the mesh has its own independent frame
    manager = camitk.TransformationManager
    independent_frame = manager.addFrameOfReference("new mesh frame")
    mesh.setFrame(independent_frame)
    self.refreshApplication()

    # step 3: add identity transformation from amos_0336_mesh6 to amos_0336 data
    transformation = manager.addTransformation(mesh.getFrame(), image.getDataFrame())

    # ... update it by designating the transformation with its from/to frames
    manager.updateTransformation(
        mesh.getFrame(),
        image.getDataFrame(),
        random_homogeneous_transform(),
    )

    # ... update it again, this time through the transformation object itself
    manager.updateTransformation(transformation, random_homogeneous_transform())

    world_frame = manager.getWorldFrame()
    print(f"world frame is called {world_frame.getName()}")

    # ... reset to identity: the mesh should be displayed as if it was in
    # the image data frame
    manager.updateTransformation(transformation, np.eye(4))

    # step 4: remove the transformation, designated here by its from/to frames
    # (alternative: pass the transformation object to removeTransformation)
    manager.removeTransformation(mesh.getFrame(), image.getDataFrame())

    self.refreshApplication()
    return True

def targetDefined(self: camitk.Action):
    """Fill the "From" and "To" properties with the available frame names.

    The names of all frames of reference returned by the transformation
    manager are set as the "enumNames" attribute of both properties.
    """
    frame_names = [
        frame.getName()
        for frame in camitk.TransformationManager.getFramesOfReference()
    ]
    # both properties offer the same list of choices
    for property_name in ("From", "To"):
        self.getProperty(property_name).setAttribute("enumNames", frame_names)
    return None

def parameterChanged(self: camitk.Action, name: str):
    """Hook receiving the name of a changed parameter; intentionally a no-op."""
    return None

def getComponents():
    """Return the first mesh and the first image among top-level components.

    Returns:
        A ``(mesh, image)`` tuple. Each entry is the first top-level
        component that is an instance of ``camitk.MeshComponent`` /
        ``camitk.ImageComponent`` respectively, or None when there is none.
    """
    top_level = camitk.Application.getTopLevelComponents()

    def first_instance_of(kind):
        # first component of the requested type, in top-level order
        for candidate in top_level:
            if isinstance(candidate, kind):
                return candidate
        return None

    first_mesh = first_instance_of(camitk.MeshComponent)
    first_image = first_instance_of(camitk.ImageComponent)
    return first_mesh, first_image

def random_homogeneous_transform():
    """Build a random 4x4 homogeneous transformation matrix.

    The upper-left 3x3 block is a random rotation drawn with
    ``scipy.spatial.transform.Rotation.random``, the last column holds a
    translation whose components are drawn from ``np.random.rand`` (values
    in [0, 1)), and the bottom row is left as ``[0, 0, 0, 1]``.

    Returns:
        A ``(4, 4)`` numpy array.
    """
    # draw the rotation first, then the translation (same order of random draws)
    rotation_block = Rotation.random().as_matrix()  # shape (3, 3)
    offset = np.random.rand(3)

    # start from the identity so the bottom row is already [0, 0, 0, 1]
    matrix = np.identity(4)
    matrix[:3, 3] = offset
    matrix[:3, :3] = rotation_block
    return matrix