import argparse
import bisect
import html
import json
import math
import random
import time as timelib
from collections import deque

import numpy as np
from matplotlib import colors
from PIL import Image, ImageDraw, ImageFont
from scipy import stats

class LoadingError(Exception):
    """Raised when an input log contains a record that cannot be interpreted."""

class Drawer:
    """Base class of the tree renderers.

    Holds the canvas geometry and the visual-encoding settings: which node
    property drives colour/size/opacity of dots and colour/width/opacity of
    lines. It also holds the drawing loops shared by all outputs. Subclasses
    provide the output-specific primitives ``add_dot``, ``add_line``,
    ``add_dashed_line``, ``add_text``, ``compute_dot_style`` and
    ``compute_line_style``.
    """

    def __init__(self, design, config_file, w=600, h=800, w_margin=10, h_margin=20):
        """Set up geometry and rendering settings.

        :param design: object exposing ``positions``, ``props``, ``tree`` and
            ``TIME`` (a ``Designer``).
        :param config_file: path to a JSON file deep-merged over the default
            settings; ``""`` keeps the defaults.
        :param w: output width.
        :param h: output height.
        :param w_margin: horizontal margin.
        :param h_margin: vertical margin.
        """
        self.design = design
        self.width = w
        self.height = h
        self.w_margin = w_margin
        self.h_margin = h_margin
        # size of the drawable area between the margins
        self.w_no_margs = w - 2 * w_margin
        self.h_no_margs = h - 2 * h_margin

        self.color_converter = colors.ColorConverter()

        # Every channel reads the node property named by 'meaning' (a key of
        # design.props) and interpolates from 'start' to 'end'; 'bias' bends
        # the interpolation curve (see compute_value).
        self.settings = {
            'colors_of_kinds': ['red', 'green', 'blue', 'magenta', 'yellow', 'cyan', 'orange', 'purple'],
            'dots': {
                'color': {
                    'meaning': 'Lifespan',
                    'start': 'red',
                    'end': 'green',
                    'bias': 1
                    },
                'size': {
                    'meaning': 'EnergyEaten',
                    'start': 1,
                    'end': 6,
                    'bias': 0.5
                    },
                'opacity': {
                    'meaning': 'EnergyEaten',
                    'start': 0.2,
                    'end': 1,
                    'bias': 1
                    }
            },
            'lines': {
                'color': {
                    'meaning': 'adepth',
                    'start': 'black',
                    'end': 'red',
                    'bias': 3
                    },
                'width': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 4,
                    'bias': 3
                    },
                'opacity': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 0.8,
                    'bias': 5
                    }
            }
        }

        if config_file != "":
            with open(config_file, encoding="utf-8") as config:
                overrides = json.load(config)
            self.settings = self._merge_settings(overrides, self.settings)

    @staticmethod
    def _merge_settings(source, destination):
        """Recursively copy ``source`` into ``destination`` and return ``destination``.

        Nested dicts are merged key by key. Any other value, including a dict
        that replaces a non-dict, overwrites the destination entry.
        """
        for key, value in source.items():
            if isinstance(value, dict) and isinstance(destination.get(key), dict):
                Drawer._merge_settings(value, destination[key])
            else:
                destination[key] = value
        return destination

    def _map_pos(self, pos, min_width, max_width, max_height):
        """Map a layout position ``{'x', 'y'}`` to canvas coordinates.

        x is stretched so that ``[min_width, max_width]`` spans the drawable
        width, and y is scaled so that ``[0, max_height]`` spans the drawable
        height. A degenerate range (e.g. a single-node tree) would divide by
        zero, so nodes are then centred horizontally or placed at the top.
        """
        x_span = max_width - min_width
        if x_span:
            x = self.w_margin + self.w_no_margs * (pos['x'] - min_width) / x_span
        else:
            x = self.w_margin + self.w_no_margs / 2
        if max_height:
            y = self.h_margin + self.h_no_margs * pos['y'] / max_height
        else:
            y = self.h_margin
        return (x, y)

    def draw_dots(self, file, min_width, max_width, max_height):
        """Draw one dot for every node that received a position."""
        for i, node in enumerate(self.design.positions):
            if 'x' not in node:
                continue
            dot_style = self.compute_dot_style(node=i)
            self.add_dot(file, self._map_pos(node, min_width, max_width, max_height), dot_style)

    def draw_lines(self, file, min_width, max_width, max_height):
        """Draw a line from every positioned parent to each of its positioned children."""
        for parent, par_pos in enumerate(self.design.positions):
            if 'x' not in par_pos:
                continue
            for child in self.design.tree.children[parent]:
                chi_pos = self.design.positions[child]
                if 'x' not in chi_pos:
                    continue
                line_style = self.compute_line_style(parent, child)
                self.add_line(file,
                              self._map_pos(par_pos, min_width, max_width, max_height),
                              self._map_pos(chi_pos, min_width, max_width, max_height),
                              line_style)

    def draw_scale(self, file, filename):
        """Add the source-file caption and the labelled top/bottom time marks."""
        # accept both Windows and POSIX separators when showing the bare file name
        basename = filename.replace("\\", "/").rsplit("/", 1)[-1]
        self.add_text(file, "Generated from " + basename, (5, 5), "start")

        start_text = ""
        end_text = ""
        if self.design.TIME == "BIRTHS":
            start_text = "Birth #0"
            end_text = "Birth #" + str(len(self.design.positions)-1)
        elif self.design.TIME == "REAL":
            start_text = "Time " + str(min(self.design.tree.time))
            end_text = "Time " + str(max(self.design.tree.time))
        elif self.design.TIME == "GENERATIONAL":
            start_text = "Depth " + str(self.design.props['adepth_min'])
            end_text = "Depth " + str(self.design.props['adepth_max'])

        self.add_dashed_line(file, (self.width*0.7, self.h_margin), (self.width, self.h_margin))
        self.add_text(file, start_text, (self.width, self.h_margin), "end")
        self.add_dashed_line(file, (self.width*0.7, self.height-self.h_margin), (self.width, self.height-self.h_margin))
        self.add_text(file, end_text, (self.width, self.height-self.h_margin), "end")

    def compute_property(self, part, prop, node):
        """Return the value of channel ``prop`` ('color', 'size', ...) of ``part`` for ``node``.

        A 'meaning' that is not a key of ``design.props`` counts as 0.
        """
        spec = self.settings[part][prop]
        meaning = spec['meaning']
        value = self.design.props[meaning][node] if meaning in self.design.props else 0
        if prop == "color":
            return self.compute_color(spec['start'], spec['end'], value, spec['bias'])
        return self.compute_value(spec['start'], spec['end'], value, spec['bias'])

    def compute_color(self, start, end, value, bias=1):
        """Return an (r, g, b) tuple in percent (0-100 per channel).

        A string ``value`` is treated as a kind index into
        ``settings['colors_of_kinds']`` (wrapping around when the palette is
        shorter). Otherwise the colour is interpolated between ``start`` and
        ``end`` as in :meth:`compute_value`.
        """
        if isinstance(value, str):
            palette = self.settings['colors_of_kinds']
            r, g, b = self.color_converter.to_rgb(palette[int(value) % len(palette)])
        else:
            start_color = self.color_converter.to_rgb(start)
            end_color = self.color_converter.to_rgb(end)
            value = 1 - (1-value)**bias
            r = start_color[0]*(1-value)+end_color[0]*value
            g = start_color[1]*(1-value)+end_color[1]*value
            b = start_color[2]*(1-value)+end_color[2]*value
        return (100*r, 100*g, 100*b)

    def compute_value(self, start, end, value, bias=1):
        """Interpolate between ``start`` and ``end`` using ``1 - (1 - value) ** bias``.

        ``bias == 1`` is linear; a larger ``bias`` approaches ``end`` sooner.
        """
        value = 1 - (1-value)**bias
        return start*(1-value) + end*value

class PngDrawer(Drawer):
    """Raster renderer built on Pillow (any format ``Image.save`` accepts).

    With ``multi > 1`` the picture is drawn ``multi`` times larger and then
    downscaled with a Lanczos filter. This anti-aliases dots and lines
    (multisampling).
    """

    def scale_up(self):
        """Multiply all canvas dimensions by ``self.multi``, remembering the originals."""
        self._unscaled = (self.width, self.height, self.w_margin, self.h_margin,
                          self.w_no_margs, self.h_no_margs)
        self.width *= self.multi
        self.height *= self.multi
        self.w_margin *= self.multi
        self.h_margin *= self.multi
        self.h_no_margs *= self.multi
        self.w_no_margs *= self.multi

    def scale_down(self):
        """Undo :meth:`scale_up`.

        Restores the exact values saved by ``scale_up``. Dividing instead
        would turn the integer image size into floats, which Pillow may
        reject the next time an image is created.
        """
        saved = getattr(self, '_unscaled', None)
        if saved is None:
            # scale_up was never called: fall back to plain division
            self.width /= self.multi
            self.height /= self.multi
            self.w_margin /= self.multi
            self.h_margin /= self.multi
            self.h_no_margs /= self.multi
            self.w_no_margs /= self.multi
            return
        (self.width, self.height, self.w_margin, self.h_margin,
         self.w_no_margs, self.h_no_margs) = saved
        self._unscaled = None

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Render the design and save it to ``filename`` (format chosen by Pillow from the name).

        :param input_filename: shown in the caption when ``scale == "SIMPLE"``.
        :param multi: integer multisampling factor.
        :param scale: "SIMPLE" adds the time scale and caption; other values omit them.
        """
        print("Drawing...")

        self.multi = multi
        self.scale_up()
        try:
            back = Image.new('RGBA', (self.width, self.height), (255, 255, 255, 0))

            positioned = [p for p in self.design.positions if 'x' in p]
            min_width = min(p['x'] for p in positioned)
            max_width = max(p['x'] for p in positioned)
            max_height = max(p['y'] for p in self.design.positions if 'y' in p)

            self.draw_lines(back, min_width, max_width, max_height)
            self.draw_dots(back, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(back, input_filename)
        finally:
            # leave the drawer at its nominal size even if drawing failed
            self.scale_down()

        # Image.ANTIALIAS (removed in Pillow 10) was an alias of LANCZOS
        back.thumbnail((self.width, self.height), Image.LANCZOS)

        back.save(filename)

    def add_dot(self, file, pos, style):
        """Paste a filled circle of radius ``style['r']`` (scaled by ``multi``) centred at ``pos``."""
        x, y = int(pos[0]), int(pos[1])
        r = style['r']*self.multi
        offset = (int(x - r), int(y - r))
        # at least 2 px so the ellipse box (1, 1, d-1, d-1) is never inverted
        diameter = max(2*int(r), 2)
        size = (diameter, diameter)

        c = style['color']

        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).ellipse((1, 1, size[0]-1, size[1]-1),
                                    (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])))
        file.paste(img, offset, mask=img)

    def add_line(self, file, from_pos, to_pos, style):
        """Paste a straight line from ``from_pos`` to ``to_pos``.

        The width is ``style['width']`` scaled by ``multi``, at least 1 px.
        """
        fx, fy, tx, ty = int(from_pos[0]), int(from_pos[1]), int(to_pos[0]), int(to_pos[1])
        # Scale before truncating so fractional widths benefit from multisampling.
        # A zero width would give a zero-height canvas for horizontal lines,
        # which would then vanish.
        w = max(1, int(style['width']*self.multi))

        offset = (min(fx-w, tx-w), min(fy-w, ty-w))
        size = (abs(fx-tx)+2*w, abs(fy-ty)+2*w)

        c = style['color']

        img = Image.new('RGBA', size)
        # draw top-left -> bottom-right when x and y change in the same direction,
        # otherwise top-right -> bottom-left
        ImageDraw.Draw(img).line((w, w, size[0]-w, size[1]-w) if (fx-tx)*(fy-ty)>0 else (size[0]-w, w, w, size[1]-w),
                                  (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])), w)
        file.paste(img, offset, mask=img)

    def add_dashed_line(self, file, from_pos, to_pos):
        """Draw a black dashed line made of 50 dashes with equal gaps between them."""
        style = {'color': (0,0,0), 'width': 1, 'opacity': 1}
        sublines = 50
        normdiv = 2*sublines-1
        for i in range(sublines):
            from_pos_sub = (self.compute_value(from_pos[0], to_pos[0], 2*i/normdiv, 1),
                            self.compute_value(from_pos[1], to_pos[1], 2*i/normdiv, 1))
            to_pos_sub = (self.compute_value(from_pos[0], to_pos[0], (2*i+1)/normdiv, 1),
                          self.compute_value(from_pos[1], to_pos[1], (2*i+1)/normdiv, 1))
            self.add_line(file, from_pos_sub, to_pos_sub, style)

    def add_text(self, file, text, pos, anchor, style=''):
        """Draw black ``text`` at ``pos``.

        ``anchor == "start"`` left-aligns; any other anchor right-aligns at
        ``pos``. ``style`` exists for signature compatibility and is ignored.
        """
        try:
            font = ImageFont.truetype("Vera.ttf", 16*self.multi)
        except OSError:
            # Vera.ttf is looked up by name only; fall back to Pillow's built-in font
            font = ImageFont.load_default()

        img = Image.new('RGBA', (self.width, self.height))
        draw = ImageDraw.Draw(img)
        if anchor != "start":
            if hasattr(draw, "textbbox"):
                text_width = draw.textbbox((0, 0), text, font=font)[2]
            else:  # Pillow < 8.0 has no textbbox
                text_width = draw.textsize(text, font=font)[0]
            pos = (pos[0]-text_width, pos[1])
        draw.text(pos, text, (0,0,0), font=font)
        file.paste(img, (0,0), mask=img)

    def compute_line_style(self, parent, child):
        """Style dict (color/width/opacity) of the line leading to ``child``."""
        return {'color': self.compute_property('lines', 'color', child),
                'width': self.compute_property('lines', 'width', child),
                'opacity': self.compute_property('lines', 'opacity', child)}

    def compute_dot_style(self, node):
        """Style dict (color/r/opacity) of the dot drawn for ``node``."""
        return {'color': self.compute_property('dots', 'color', node),
                'r': self.compute_property('dots', 'size', node),
                'opacity': self.compute_property('dots', 'opacity', node)}

class SvgDrawer(Drawer):
    """Vector renderer that writes a standalone SVG document."""

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Render the design into the SVG file ``filename`` (UTF-8).

        :param input_filename: shown in the caption when ``scale == "SIMPLE"``.
        :param multi: accepted for signature compatibility with PngDrawer; unused.
        :param scale: "SIMPLE" adds the time scale and caption; other values omit them.
        """
        print("Drawing...")

        positioned = [p for p in self.design.positions if 'x' in p]
        min_width = min(p['x'] for p in positioned)
        max_width = max(p['x'] for p in positioned)
        max_height = max(p['y'] for p in self.design.positions if 'y' in p)

        # 'with' closes the file even if drawing raises; XML defaults to UTF-8
        with open(filename, "w", encoding="utf-8") as file:
            file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" '
                       'xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
                       'width="' + str(self.width) + '" height="' + str(self.height) + '">')

            self.draw_lines(file, min_width, max_width, max_height)
            self.draw_dots(file, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(file, input_filename)

            file.write("</svg>")

    def add_text(self, file, text, pos, anchor, style=''):
        """Write a <text> element; ``anchor`` is used as SVG ``text-anchor``."""
        style = (style if style != '' else 'style="font-family: Arial; font-size: 12; fill: #000000;"')
        # SVG positions text by its baseline; +12 (assumed font size, not parsed
        # from the style string) makes pos roughly the top edge of the text.
        # The text (e.g. a file name) is escaped so '&' or '<' cannot break the XML.
        file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) + '" y="' + str(pos[1]+12) + '" >'
                   + html.escape(text, quote=False) + '</text>')

    def add_dot(self, file, pos, style):
        """Write a <circle> centred at ``pos``; ``style`` is a ready attribute string."""
        file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')

    def add_line(self, file, from_pos, to_pos, style):
        """Write a <line> element; ``style`` is a ready attribute string."""
        file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
                       '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')

    def add_dashed_line(self, file, from_pos, to_pos):
        """Write a thin black dashed line."""
        style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
        self.add_line(file, from_pos, to_pos, style)

    def compute_line_style(self, parent, child):
        """Attribute string (stroke colour, width, opacity) for the line leading to ``child``."""
        return self.compute_stroke_color('lines', child) + ' ' \
               + self.compute_stroke_width('lines', child) + ' ' \
               + self.compute_stroke_opacity(child)

    def compute_dot_style(self, node):
        """Attribute string (radius, fill opacity, fill colour) for the dot of ``node``."""
        return self.compute_dot_size(node) + ' ' \
               + self.compute_fill_opacity(node) + ' ' \
               + self.compute_dot_fill(node)

    def compute_stroke_color(self, part, node):
        color = self.compute_property(part, 'color', node)
        return 'stroke="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'

    def compute_stroke_width(self, part, node):
        return 'stroke-width="' + str(self.compute_property(part, 'width', node)) + '"'

    def compute_stroke_opacity(self, node):
        return 'stroke-opacity="' + str(self.compute_property('lines', 'opacity', node)) + '"'

    def compute_fill_opacity(self, node):
        return 'fill-opacity="' + str(self.compute_property('dots', 'opacity', node)) + '"'

    def compute_dot_size(self, node):
        return 'r="' + str(self.compute_property('dots', 'size', node)) + '"'

    def compute_dot_fill(self, node):
        color = self.compute_property('dots', 'color', node)
        return 'fill="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'

class Designer:

    def __init__(self, tree, jitter=False, time="GENERATIONAL", balance="DENSITY"):
        self.props = {}

        self.tree = tree

        self.TIME = time
        self.JITTER = jitter

        if balance == "RANDOM":
            self.xmin_crowd = self.xmin_crowd_random
        elif balance == "MIN":
            self.xmin_crowd = self.xmin_crowd_min
        elif balance == "DENSITY":
            self.xmin_crowd = self.xmin_crowd_density
        else:
            raise ValueError("Error, the value of BALANCE does not match any expected value.")

    def calculate_measures(self):
        print("Calculating measures...")
        self.compute_depth()
        self.compute_adepth()
        self.compute_children()
        self.compute_kind()
        self.compute_time()
        self.compute_progress()
        self.compute_custom()

    def xmin_crowd_random(self, x1, x2, y):
        return (x1 if random.randrange(2) == 0 else x2)

    def xmin_crowd_min(self, x1, x2, y):
        x1_closest = 999999
        x2_closest = 999999
        miny = y-3
        maxy = y+3
        i = bisect.bisect_left(self.y_sorted, miny)
        while True:
            if len(self.positions_sorted) <= i or self.positions_sorted[i]['y'] > maxy:
                break
            pos = self.positions_sorted[i]

            x1_closest = min(x1_closest, abs(x1-pos['x']))
            x2_closest = min(x2_closest, abs(x2-pos['x']))

            i += 1
        return (x1 if x1_closest > x2_closest else x2)

    def xmin_crowd_density(self, x1, x2, y):
        # TODO experimental - requires further work to make it less 'jumpy' and more predictable
        CONST_LOCAL_AREA_RADIUS = 5
        CONST_GLOBAL_AREA_RADIUS = 10
        CONST_WINDOW_SIZE = 20000 #TODO should depend on the maxY ?
        x1_dist_loc = 0
        x2_dist_loc = 0
        count_loc = 1
        x1_dist_glob = 0
        x2_dist_glob = 0
        count_glob = 1
        miny = y-CONST_WINDOW_SIZE
        maxy = y+CONST_WINDOW_SIZE
        i_left = bisect.bisect_left(self.y_sorted, miny)
        i_right = bisect.bisect_right(self.y_sorted, maxy)
        #TODO test: maxy=y should give the same results, right?

        def include_pos(pos):
            nonlocal x1_dist_loc, x2_dist_loc, x1_dist_glob, x2_dist_glob, count_loc, count_glob

            dysq = (pos['y']-y)**2 + 1 #+1 so 1/dysq is at most 1
            dx1 = math.fabs(pos['x']-x1)
            dx2 = math.fabs(pos['x']-x2)

            d = math.fabs(pos['x'] - (x1+x2)/2)

            if d < CONST_LOCAL_AREA_RADIUS:
                x1_dist_loc += math.sqrt(dx1/dysq + dx1**2)
                x2_dist_loc += math.sqrt(dx2/dysq + dx2**2)
                count_loc += 1
            elif d > CONST_GLOBAL_AREA_RADIUS:
                x1_dist_glob += math.sqrt(dx1/dysq + dx1**2)
                x2_dist_glob += math.sqrt(dx2/dysq + dx2**2)
                count_glob += 1

        # optimized to draw from all the nodes, if less than 10 nodes in the range
        if len(self.positions_sorted) > i_left:
            if i_right - i_left < 10:
                for j in range(i_left, i_right):
                    include_pos(self.positions_sorted[j])
            else:
                for j in range(10):
                    pos = self.positions_sorted[random.randrange(i_left, i_right)]
                    include_pos(pos)

        return (x1 if (x1_dist_loc-x2_dist_loc)/count_loc-(x1_dist_glob-x2_dist_glob)/count_glob > 0  else x2)
        #return (x1 if x1_dist +random.gauss(0, 0.00001) > x2_dist +random.gauss(0, 0.00001)  else x2)
        #print(x1_dist, x2_dist)
        #x1_dist = x1_dist**2
        #x2_dist = x2_dist**2
        #return x1 if x1_dist+x2_dist==0 else (x1*x1_dist + x2*x2_dist) / (x1_dist+x2_dist) + random.gauss(0, 0.01)
        #return (x1 if random.randint(0, int(x1_dist+x2_dist)) < x1_dist else x2)

    def calculate_node_positions(self, ignore_last=0):
        print("Calculating positions...")

        def add_node(node):
            index = bisect.bisect_left(self.y_sorted, node['y'])
            self.y_sorted.insert(index, node['y'])
            self.positions_sorted.insert(index, node)
            self.positions[node['id']] = node

        self.positions_sorted = [{'x':0, 'y':0, 'id':0}]
        self.y_sorted = [0]
        self.positions = [{} for x in range(len(self.tree.parents))]
        self.positions[0] = {'x':0, 'y':0, 'id':0}

        # order by maximum depth of the parent guarantees that co child is evaluated before its parent
        visiting_order = [i for i in range(0, len(self.tree.parents))]
        visiting_order = sorted(visiting_order, key=lambda q:
                            0 if q == 0 else max([self.props["depth"][d] for d in self.tree.parents[q]]))

        start_time = timelib.time()

        # for each child of the current node
        for node_counter,child in enumerate(visiting_order, start=1):
            # debug info - elapsed time
            if node_counter % 100000 == 0:
               print("%d%%\t%d\t%g" % (node_counter*100/len(self.tree.parents), node_counter, timelib.time()-start_time))
               start_time = timelib.time()

            # using normalized adepth
            if self.props['adepth'][child] >= ignore_last/self.props['adepth_max']:

                ypos = 0
                if self.TIME == "BIRTHS":
                    ypos = child
                elif self.TIME == "GENERATIONAL":
                    # one more than its parent (what if more than one parent?)
                    ypos = max([self.positions[par]['y'] for par, v in self.tree.parents[child].items()])+1 \
                        if self.tree.parents[child] else 0
                elif self.TIME == "REAL":
                    ypos = self.tree.time[child]

                if len(self.tree.parents[child]) == 1:
                # if current_node is the only parent
                    parent, similarity = [(par, v) for par, v in self.tree.parents[child].items()][0]

                    if self.JITTER:
                        dissimilarity = (1-similarity) + random.gauss(0, 0.01) + 0.001
                    else:
                        dissimilarity = (1-similarity) + 0.001
                    add_node({'id':child, 'y':ypos, 'x':
                             self.xmin_crowd(self.positions[parent]['x']-dissimilarity,
                              self.positions[parent]['x']+dissimilarity, ypos)})
                else:
                    # position weighted by the degree of inheritence from each parent
                    total_inheretance = sum([v for k, v in self.tree.parents[child].items()])
                    xpos = sum([self.positions[k]['x']*v/total_inheretance
                               for k, v in self.tree.parents[child].items()])
                    if self.JITTER:
                        add_node({'id':child, 'y':ypos, 'x':xpos + random.gauss(0, 0.1)})
                    else:
                        add_node({'id':child, 'y':ypos, 'x':xpos})


    def compute_custom(self):
        for prop in self.tree.props:
            self.props[prop] = [None for x in range(len(self.tree.children))]

            for i in range(len(self.props[prop])):
                self.props[prop][i] = self.tree.props[prop][i]

            self.normalize_prop(prop)

    def compute_time(self):
        # simple rewrite from the tree
        self.props["time"] = [0 for x in range(len(self.tree.children))]

        for i in range(len(self.props['time'])):
            self.props['time'][i] = self.tree.time[i]

        self.normalize_prop('time')

    def compute_kind(self):
        # simple rewrite from the tree
        self.props["kind"] = [0 for x in range(len(self.tree.children))]

        for i in range (len(self.props['kind'])):
            self.props['kind'][i] = str(self.tree.kind[i])

    def compute_depth(self):
        self.props["depth"] = [999999999 for x in range(len(self.tree.children))]
        visited = [0 for x in range(len(self.tree.children))]

        nodes_to_visit = [0]
        visited[0] = 1
        self.props["depth"][0] = 0
        while True:
            current_node = nodes_to_visit[0]

            for child in self.tree.children[current_node]:
                if visited[child] == 0:
                    visited[child] = 1
                    nodes_to_visit.append(child)
                    self.props["depth"][child] = self.props["depth"][current_node]+1
            nodes_to_visit = nodes_to_visit[1:]
            if len(nodes_to_visit) == 0:
                break

        self.normalize_prop('depth')

    def compute_adepth(self):
        self.props["adepth"] = [0 for x in range(len(self.tree.children))]

        # order by maximum depth of the parent guarantees that co child is evaluated before its parent
        visiting_order = [i for i in range(0, len(self.tree.parents))]
        visiting_order = sorted(visiting_order, key=lambda q:
                            0 if q == 0 else max([self.props["depth"][d] for d in self.tree.parents[q]]))[::-1]

        for node in visiting_order:
            children = self.tree.children[node]
            if len(children) != 0:
                # 0 by default
                self.props["adepth"][node] = max([self.props["adepth"][child] for child in children])+1
        self.normalize_prop('adepth')

    def compute_children(self):
        self.props["children"] = [0 for x in range(len(self.tree.children))]
        for i in range (len(self.props['children'])):
            self.props['children'][i] = len(self.tree.children[i])

        self.normalize_prop('children')

    def compute_progress(self):
        self.props["progress"] = [0 for x in range(len(self.tree.children))]
        for i in range(len(self.props['children'])):
            times = sorted([self.props["time"][self.tree.children[i][j]]*100000 for j in range(len(self.tree.children[i]))])
            if len(times) > 4:
                times = [times[i+1] - times[i] for i in range(len(times)-1)]
                #print(times)
                slope, intercept, r_value, p_value, std_err = stats.linregress(range(len(times)), times)
                self.props['progress'][i] = slope if not np.isnan(slope) and not np.isinf(slope) else 0

        for i in range(0, 5):
            self.props['progress'][self.props['progress'].index(min(self.props['progress']))] = 0
            self.props['progress'][self.props['progress'].index(max(self.props['progress']))] = 0

        mini = min(self.props['progress'])
        maxi = max(self.props['progress'])
        for k in range(len(self.props['progress'])):
            if self.props['progress'][k] == 0:
                self.props['progress'][k] = mini

        #for k in range(len(self.props['progress'])):
        #        self.props['progress'][k] = 1-self.props['progress'][k]

        self.normalize_prop('progress')

    def normalize_prop(self, prop):
        noneless = [v for v in self.props[prop] if (type(v)!=str and type(v)!=list)]
        if len(noneless) > 0:
            max_val = max(noneless)
            min_val = min(noneless)
            print("%s: [%g, %g]" % (prop, min_val, max_val))
            self.props[prop +'_max'] = max_val
            self.props[prop +'_min'] = min_val
            for i in range(len(self.props[prop])):
                if self.props[prop][i] is not None:
                    qqq = self.props[prop][i]
                    self.props[prop][i] = 0 if max_val == min_val else (self.props[prop][i] - min_val) / (max_val - min_val)

class TreeData:
    """Parent/child relationships and per-node data loaded from an evolution log.

    Node ids are assigned in order of first appearance. Unknown parents map
    to one shared "virtual_parent" node that is created before the child's
    own id, so parents normally get lower ids than their children. After
    :meth:`load`, there is one entry per node id in:

    * ``parents[i]``: dict ``{parent_id: inherited_fraction}`` (1 when the
      record has no "Inherited" field)
    * ``children[i]``: list of child ids
    * ``time[i]`` / ``kind[i]``: the record's "Time" / "Kind" fields (0 if absent)
    * ``props[name][i]``: every other field of [OFFSPRING]/[DIED] records
    """
    simple_data = None

    # optional prefix in front of the record tag, e.g. "Script.Message: [OFFSPRING] {...}"
    CLI_PREFIX = "Script.Message:"
    # record fields that are not copied into ``props``
    DEFAULT_PROPS = ("Time", "FromIDs", "ID", "Operation", "Inherited")

    def __init__(self):
        # per-instance containers (class-level lists would be shared by all instances)
        self.children = []
        self.parents = []
        self.time = []
        self.kind = []
        self.life_lenght = []
        self.props = {}
        self.ids = {}

    @classmethod
    def _parse_line(cls, line):
        """Split a log line into ``(tag, json_payload)``, or ``(None, None)`` if it has no payload."""
        parts = line.split(' ', 1)
        if len(parts) != 2:
            return None, None
        if parts[0] == cls.CLI_PREFIX:
            parts = parts[1].split(' ', 1)
            if len(parts) != 2:
                return None, None
        return parts[0], parts[1]

    def _store_props(self, record, node_id, nodes):
        """Copy the non-default fields of ``record`` into ``props[...][node_id]``."""
        for prop, value in record.items():
            if prop in self.DEFAULT_PROPS:
                continue
            if prop not in self.props:
                self.props[prop] = [0] * nodes
            self.props[prop][node_id] = value

    def load(self, filename, max_nodes=0):
        """Read [OFFSPRING] and [DIED] records from ``filename``.

        :param max_nodes: stop after this many [OFFSPRING] records (0 = no limit).
        :raises LoadingError: if an [OFFSPRING] record has no "FromIDs" field.
        """
        print("Loading...")

        self.ids = {}

        def get_id(key, create_on_error=True):
            if key not in self.ids:
                if not create_on_error:
                    return None
                self.ids[key] = len(self.ids)
            return self.ids[key]

        with open(filename) as file:
            # first pass: count the expected nodes to size the per-node arrays
            nodes = sum(1 for line in file if self._parse_line(line)[0] == "[OFFSPRING]")
            # +1: room for the shared "virtual_parent", which has no [OFFSPRING] record
            nodes = min(nodes, max_nodes if max_nodes != 0 else nodes)+1
            self.parents = [{} for _ in range(nodes)]
            self.children = [[] for _ in range(nodes)]
            self.time = [0] * nodes
            self.kind = [0] * nodes
            self.life_lenght = [0] * nodes
            self.props = {}

            print("nodes: %d" % len(self.parents))

            # second pass: read the records
            file.seek(0)
            loaded_so_far = 0
            for line in file:
                tag, payload = self._parse_line(line)
                if tag == "[OFFSPRING]":
                    try:
                        creature = json.loads(payload)
                    except ValueError:
                        # assumed to be a truncated last line; unused slots are trimmed below
                        print("Json format error - the line cannot be read. Breaking the loading loop.")
                        break
                    if "FromIDs" not in creature:
                        raise LoadingError("[OFFSPRING] misses the 'FromIDs' field!")

                    from_ids = creature["FromIDs"]
                    # create the virtual parent (if needed) before the child's own id,
                    # so that parent ids stay lower than child ids
                    if any(parent not in self.ids for parent in from_ids):
                        get_id("virtual_parent")

                    creature_id = get_id(creature["ID"])

                    # assign to each parent its contribution to the genotype of the child
                    for i, parent in enumerate(from_ids):
                        parent_id = get_id(parent) if parent in self.ids else get_id("virtual_parent")
                        self.parents[creature_id][parent_id] = (creature["Inherited"][i] if 'Inherited' in creature else 1)

                    if "Time" in creature:
                        self.time[creature_id] = creature["Time"]

                    if "Kind" in creature:
                        self.kind[creature_id] = creature["Kind"]

                    self._store_props(creature, creature_id, nodes)
                    loaded_so_far += 1
                elif tag == "[DIED]":
                    try:
                        creature = json.loads(payload)
                    except ValueError:
                        creature = None
                        print("Json format error - a [DIED] line cannot be read. Skipping it.")
                    if creature is not None:
                        creature_id = get_id(creature["ID"], False)
                        if creature_id is not None:
                            self._store_props(creature, creature_id, nodes)

                if max_nodes != 0 and loaded_so_far >= max_nodes:
                    break

        # Drop reserved slots that never received an id (no virtual parent was
        # needed, or loading stopped early); otherwise they would appear as
        # extra parentless nodes.
        used = min(nodes, max(len(self.ids), 1))
        del self.parents[used:]
        del self.children[used:]
        del self.time[used:]
        del self.kind[used:]
        del self.life_lenght[used:]
        for prop in self.props:
            self.props[prop] = self.props[prop][:used]

        for child, parent_ids in enumerate(self.parents):
            for parent in parent_ids:
                self.children[parent].append(child)

# Module-level dictionaries; they appear unused in this module (kept for any external users).
depth, kind = {}, {}

def main():
    """Command-line entry point: parse arguments, load the log, lay out and draw the tree."""
    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship '
                                                 'information from a text file. Supports files generated by Framsticks experiments.')
    parser.add_argument('-i', '--in', dest='input', required=True, help='input file name with structured evolutionary data')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG/PNG/JPG/BMP)')
    parser.add_argument('-c', '--config', dest='config', default="", help='config file name ')

    parser.add_argument('-W', '--width', default=600, type=int, dest='width', help='width of the output image (600 by default)')
    parser.add_argument('-H', '--height', default=800, type=int, dest='height', help='height of the output image (800 by default)')
    parser.add_argument('-m', '--multi', default=1, type=int, dest='multi', help='multisampling factor (applicable only for raster images)')

    # Enumerated options are case-insensitive: type=str.upper normalizes the value
    # before argparse checks it against `choices` and reports the offending option.
    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', type=str.upper,
                        choices=['BIRTHS', 'GENERATIONAL', 'REAL'],
                        help='values on vertical axis (BIRTHS/GENERATIONAL(d)/REAL); '
                             'BIRTHS: time measured as the number of births since the beginning; '
                             'GENERATIONAL: time measured as number of ancestors; '
                             'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', type=str.upper,
                        choices=['RANDOM', 'MIN', 'DENSITY'],
                        help='method of placing nodes in the tree (RANDOM/MIN/DENSITY(d))')
    parser.add_argument('-s', '--scale', default='SIMPLE', dest='scale', type=str.upper,
                        choices=['NONE', 'SIMPLE'],
                        help='type of timescale added to the tree (NONE/SIMPLE(d))')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')
    parser.add_argument('-p', '--skip', dest="skip", type=int, default=0, help='skip last P levels of the tree (0 by default)')
    parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    parser.add_argument('--seed', type=int, dest='seed', default=-1, help='seed for the random number generator (-1 for random)')

    args = parser.parse_args()

    seed = args.seed
    if seed == -1:
        seed = random.randint(0, 10000)
    random.seed(seed)
    print("randomseed:", seed)

    tree = TreeData()
    tree.load(args.input, max_nodes=args.max_nodes)

    designer = Designer(tree, jitter=args.jitter, time=args.time, balance=args.balance)
    designer.calculate_measures()
    designer.calculate_node_positions(ignore_last=args.skip)

    # the extension check is case-insensitive so "tree.SVG" is also written as SVG
    drawer_class = SvgDrawer if args.output.lower().endswith(".svg") else PngDrawer
    drawer = drawer_class(designer, args.config, w=args.width, h=args.height)
    drawer.draw_design(args.output, args.input, multi=args.multi, scale=args.scale)


# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main()
