CSC400 3DVis translate.py

From CSclasswiki
Jump to: navigation, search
# translate.py
# Kelly Dwan
# Version 1.0; 5.6

import numpy as N
from nifti import *
import re,math,sys

class points():
    def __init__(self,filename):
        #initialize dictionary of lists of references [used later in scale()]
        self.refs = {'sagittal':[],'coronal':[],'horizontal':[]}
        #these are the hardcoded widths of scansets from the Human Brain Atlas [see note in wiki]
        self.n = {'sagittal':2984.,'coronal':4170.,'horizontal':2340.}
        #The top left point in the bounding box- see wiki sketch for why.
        self.p0 = {'sagittal':[],'coronal':[],'horizontal':[]}
        #This program generates multiple nifti images, this helps name them.
        self.k = 0
        #point line regex - recognizes points of the form [x,y]
        self.pl = r'\[([A-Z0-9]*), ([A-Z0-9]*)\]'
        #name line regex - same as filename regex from start.py
        self.nl = r'[scan_]?(horizontal|sagittal|coronal)_([A-Z0-9]*)[_]?(C)?.(jpg|png|gif)'
        file = open(filename, 'r')
        lines = file.readlines()
        self.lines = lines
        self.getRefs()        #using the images marked 'C', find the bounding box. This info is stored in self.refs
        for i in range(len(lines)):
            namematch = re.search(self.nl, lines[i], re.IGNORECASE)
            if (namematch != None):        #this means the first line is a filename
                j = 0
                points=[]
                # Each orientation has a different point order to align it with MRICron -see wiki for more.
                if (namematch.group(1) == 'sagittal'):
                    while i+1 < len(lines) and re.search(self.pl, lines[i+1], re.IGNORECASE)!=None:
                        n = -1*(int(namematch.group(2))-1492)
                        i += 1
                        pointmatch = re.search(self.pl, lines[i], re.IGNORECASE)
                        x1 = int(pointmatch.group(1))
                        y1 = int(pointmatch.group(2))
                        points.append((x1,y1))
                        #the first four points are not actually input-points, they're reference points. skip them.
                        if j > 3:
                            # see wiki sketch for why we do this
                            x = self.refs['sagittal'][0] - (int(pointmatch.group(1))-self.p0['sagittal'][0])
                            y = self.refs['sagittal'][1] - (int(pointmatch.group(2))-self.p0['sagittal'][1])
                            point = (n,x,y,1)
                            self.trans('ch2better.nii', self.scale(point,self.refs['sagittal']))
                        j += 1
                elif (namematch.group(1) == "coronal"):
                    while i < len(lines)-1 and re.search(self.pl, lines[i+1], re.IGNORECASE)!=None:
                        n = -1*(int(namematch.group(2))-4170)
                        i += 1
                        pointmatch = re.search(self.pl, lines[i], re.IGNORECASE)
                        x1 = int(pointmatch.group(1))
                        y1 = int(pointmatch.group(2))
                        points.append((x1,y1))
                        #skip first four
                        if j > 3:
                            # see wiki sketch for why we do this
                            x = self.refs['coronal'][0] - (int(pointmatch.group(1))-self.p0['coronal'][0])
                            y = self.refs['coronal'][1] - (int(pointmatch.group(2))-self.p0['coronal'][1])
                            point = (x,n,y,1)
                            self.trans('ch2better.nii', self.scale(point,self.refs['coronal']))
                        j+=1
                elif (namematch.group(1) == "horizontal"):
                    while i < len(lines)-1 and re.search(self.pl, lines[i+1], re.IGNORECASE)!=None:
                        n = -1*(int(namematch.group(2))-2340)
                        i += 1
                        pointmatch = re.search(self.pl, lines[i], re.IGNORECASE)
                        x1 = int(pointmatch.group(1))
                        y1 = int(pointmatch.group(2))
                        points.append((x1,y1))
                        #skip first four
                        if j > 3:
                            # see wiki sketch for why we do this
                            x = self.refs['horizontal'][0] - (int(pointmatch.group(1))-self.p0['horizontal'][0])
                            y = self.refs['horizontal'][1] - (int(pointmatch.group(2))-self.p0['horizontal'][1])
                            point = (x,y,n,1)
                            self.trans('ch2better.nii', self.scale(point,self.refs['horizontal']))
                        j+=1

    # # #
    # Scale each point to the model coordinate system
    def scale(self,point,ref):
        x = (150./ref[0])
        y = (185./ref[1])
        z = (158./ref[2])
        #print x,y,z
        point = N.array(point)
        scale = N.array([x,y,z,1])
        offset = N.array([-75,-107,-70,0.])
        return scale*point + offset

    # # #
    # Translate a cube to the given 3D point
    def trans(self,model,point):
        base = NiftiImage(model)
        cube = N.ones((30,30,30))
        nim = NiftiImage(cube,base.header)
        matrix = base.header['sform']
        matrix[:,-1] = point
        #print matrix
        h = nim.header
        h['sform'] = matrix
        nim.header = h
        nim.save('cube'+str(k)+'.nii')
        self.k += 1
        return

    # # #
    # Define references using the filename
    def getRefs(self):
        lines = self.lines
        points = []
        for i in range(len(lines)):
            namematch = re.search(self.nl, lines[i], re.IGNORECASE)
            if namematch != None:
                name = namematch.group(1)
                points = []
                if namematch.group(3) == 'C':
                    i += 1
                    pointmatch = re.search(self.pl, lines[i], re.IGNORECASE)
                    while pointmatch != None and i < len(lines)-1:
                        x = int(pointmatch.group(1))
                        y = int(pointmatch.group(2))
                        points.append((x,y))
                        i += 1
                        pointmatch = re.search(self.pl, lines[i], re.IGNORECASE)
                    x = math.fabs(points[2][0]-points[0][0])
                    y = math.fabs(points[3][1]-points[1][1])
                    if name == 'sagittal':
                        self.p0[name] = (points[0][0],points[1][1])
                        self.refs[name] = [self.n[name],x,y,1]
                    elif name == 'coronal':
                        self.p0[name] = (points[0][0],points[1][1])
                        self.refs[name] = [x,self.n[name],y,1]
                    elif name == 'horizontal':
                        self.p0[name] = (points[0][0],points[1][1])
                        self.refs[name] = [x,y,self.n[name],1]
        return

if __name__ == "__main__":
    # Command-line entry point: translate.py <filename.txt>
    # `points` is already defined in this module, so there is no need to
    # re-import the file as `translate` (which only works under that name).
    if len(sys.argv) != 2:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit("Wrong number of arguments. Proper format is translate.py <filename.txt>")
    points(sys.argv[1])