Changeset 624


Ignore:
Timestamp:
10/21/16 19:45:07 (8 years ago)
Author:
konrad
Message:

New, faster, more flexible and refactored version of the tree-genealogy script. Supports drawing trees as both raster graphics (bmp/png/jpg) and vector graphics (svg).

Location:
mds-and-trees
Files:
1 added
1 edited

Legend:

Unmodified
Added
Removed
  • mds-and-trees/tree-genealogy.py

    r623 r624  
    1 # Draws a genealogical tree (generates a SVG file) based on parent-child relationship information.
    2 # Supports files generated by Framsticks experiments.
    3 
    41import json
     2import math
    53import random
    6 import math
    74import argparse
    8 import time as ttime
    9 
    10 TIME = "" # BIRTHS / GENERATIONAL / REAL
    11 BALANCE = "" # MIN / DENSITY
    12 
    13 DOT_STYLE = "" # NONE / NORMAL / CLEAR
    14 
    15 JITTER = "" #
    16 
    17 # ------SVG---------
    18 svg_file = 0
    19 
    20 svg_line_style = 'stroke="rgb(90%,10%,16%)" stroke-width="1" stroke-opacity="0.7"'
    21 svg_mutation_line_style = 'stroke-width="1"'
    22 svg_crossover_line_style = 'stroke-width="1"'
    23 svg_spine_line_style = 'stroke="rgb(0%,90%,40%)" stroke-width="2" stroke-opacity="1"'
    24 svg_scale_line_style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
    25 
    26 svg_dot_style = 'r="2" stroke="black" stroke-width="0.2" fill="red"'
    27 svg_clear_dot_style = 'r="2" stroke="black" stroke-width="0.4" fill="none"'
    28 svg_spine_dot_style = 'r="1" stroke="black" stroke-width="0.2" fill="rgb(50%,50%,100%)"'
    29 
    30 svg_scale_text_style = 'style="font-family: Arial; font-size: 12; fill: #000000;"'
    31 
    32 def hex_to_style(hex):
    33     default_style = ' stroke="black" stroke-opacity="0.5" '
    34 
    35     if hex[0] == "#":
    36         hex = hex[1:]
    37 
    38     if len(hex) == 6 or len(hex) == 8:
    39         try:
    40             int(hex, 16)
    41         except:
    42             print("Invalid characters in the color's hex #" + hex + "! Assuming black.")
    43             return default_style
    44         red = 100*int(hex[0:2], 16)/255
    45         green = 100*int(hex[2:4], 16)/255
    46         blue = 100*int(hex[4:6], 16)/255
    47         opacity = 0.5
    48         if len(hex) == 8:
    49             opacity = int(hex[6:8], 16)/255
    50         return ' stroke="rgb(' +str(red)+ '%,' +str(green)+ '%,' +str(blue)+ '%)" stroke-opacity="' +str(opacity)+ '" '
    51     else:
    52         print("Invalid number of digits in the color's hex #" + hex + "! Assuming black.")
    53         return default_style
    54 
    55 def svg_add_line(from_pos, to_pos, style=svg_line_style):
    56     svg_file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
    57                    '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')
    58 
    59 def svg_add_text(text, pos, anchor, style=svg_scale_text_style):
    60     svg_file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) + '" y="' + str(pos[1]) + '" >' + text + '</text>')
    61 
    62 def svg_add_dot(pos, style=svg_dot_style):
    63     svg_file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')
    64 
    65 def svg_generate_line_style(percent):
    66     # hotdog
    67     from_col = [100, 70, 0]
    68     to_col = [60, 0, 0]
    69     # lava
    70     # from_col = [100, 80, 0]
    71     # to_col = [100, 0, 0]
    72     # neon
    73     # from_col = [30, 200, 255]
    74     # to_col = [240, 0, 220]
    75 
    76     from_opa = 0.2
    77     to_opa = 1.0
    78     from_stroke = 1
    79     to_stroke = 3
    80 
    81     opa = from_opa*(1-percent) + to_opa*percent
    82     stroke = from_stroke*(1-percent) + to_stroke*percent
    83 
    84     percent = 1 - ((1-percent)**20)
    85 
    86     return 'stroke="rgb(' + str(from_col[0]*(1-percent) + to_col[0]*percent) + '%,' \
    87            + str(from_col[1]*(1-percent) + to_col[1]*percent) + '%,' \
    88            + str(from_col[2]*(1-percent) + to_col[2]*percent) + '%)" stroke-width="' + str(stroke) + '" stroke-opacity="' + str(opa) + '"'
    89 
    90 def svg_generate_dot_style(kind):
    91     kinds = ["red", "lawngreen", "royalblue", "magenta", "yellow", "cyan", "white", "black"]
    92 
    93     r = min(2500/len(nodes), 10)
    94 
    95     return 'fill="' + kinds[kind] + '" r="' + str(r) + '" stroke="black" stroke-width="' + str(r/10) + '" fill-opacity="1.0" ' \
    96            'stroke-opacity="1.0"'
    97 
    98 # -------------------
    99 
    100 def load_data(dir):
    101     global firstnode, nodes, inv_nodes, time
    102     f = open(dir)
    103     loaded = 0
    104 
    105     for line in f:
    106         sline = line.split(' ', 1)
    107         if len(sline) == 2:
    108             if sline[0] == "[OFFSPRING]":
    109                 creature = json.loads(sline[1])
    110                 #print("B" +str(creature))
    111                 if "FromIDs" in creature:
    112                     if not creature["ID"] in nodes:
    113                         nodes[creature["ID"]] = {}
     5import bisect
     6import time as timelib
     7from PIL import Image, ImageDraw, ImageFont
     8
     9class LoadingError(Exception):
     10    pass
     11
     12class Drawer:
     13
     14    def __init__(self, design, config_file, w=600, h=800, w_margin=10, h_margin=20):
     15        self.design = design
     16        self.width = w
     17        self.height = h
     18        self.w_margin = w_margin
     19        self.h_margin = h_margin
     20        self.w_no_margs = w - 2* w_margin
     21        self.h_no_margs = h - 2* h_margin
     22
     23        self.colors = {
     24            'black' :   {'r':0,     'g':0,      'b':0},
     25            'red' :     {'r':100,   'g':0,      'b':0},
     26            'green' :   {'r':0,     'g':100,    'b':0},
     27            'blue' :    {'r':0,     'g':0,      'b':100},
     28            'yellow' :  {'r':100,   'g':100,    'b':0},
     29            'magenta' : {'r':100,   'g':0,      'b':100},
     30            'cyan' :    {'r':0,     'g':100,    'b':100},
     31            'orange':   {'r':100,   'g':50,     'b':0},
     32            'purple':   {'r':50,    'g':0,      'b':100}
     33        }
     34
     35        self.settings = {
     36            'colors_of_kinds': ['red', 'green', 'blue', 'magenta', 'yellow', 'cyan', 'orange', 'purple'],
     37            'dots': {
     38                'color': {
     39                    'meaning': 'depth',
     40                    'start': 'purple',
     41                    'end': 'green',
     42                    'bias': 1
     43                    },
     44                'size': {
     45                    'meaning': 'children',
     46                    'start': 1,
     47                    'end': 5,
     48                    'bias': 0.5
     49                    },
     50                'opacity': {
     51                    'meaning': 'children',
     52                    'start': 0.3,
     53                    'end': 0.8,
     54                    'bias': 1
     55                    }
     56            },
     57            'lines': {
     58                'color': {
     59                    'meaning': 'adepth',
     60                    'start': 'black',
     61                    'end': 'red',
     62                    'bias': 3
     63                    },
     64                'width': {
     65                    'meaning': 'adepth',
     66                    'start': 1,
     67                    'end': 4,
     68                    'bias': 3
     69                    },
     70                'opacity': {
     71                    'meaning': 'adepth',
     72                    'start': 0.1,
     73                    'end': 0.8,
     74                    'bias': 5
     75                    }
     76            }
     77        }
     78
     79        def merge(source, destination):
     80            for key, value in source.items():
     81                if isinstance(value, dict):
     82                    node = destination.setdefault(key, {})
     83                    merge(value, node)
     84                else:
     85                    destination[key] = value
     86
     87            return destination
     88
     89        if config_file != "":
     90            with open(config_file) as config:
     91                c = json.load(config)
     92            self.settings = merge(c, self.settings)
     93            #print(json.dumps(self.settings, indent=4, sort_keys=True))
     94
     95    def draw_dots(self, file, min_width, max_width, max_height):
     96        for i in range(len(self.design.positions)):
     97            node = self.design.positions[i]
     98            if 'x' not in node:
     99                continue
     100            dot_style = self.compute_dot_style(node=i)
     101            self.add_dot(file, (self.w_margin+self.w_no_margs*(node['x']-min_width)/(max_width-min_width),
     102                               self.h_margin+self.h_no_margs*node['y']/max_height), dot_style)
     103
     104    def draw_lines(self, file, min_width, max_width, max_height):
     105        for parent in range(len(self.design.positions)):
     106            par_pos = self.design.positions[parent]
     107            if not 'x' in par_pos:
     108                continue
     109            for child in self.design.tree.children[parent]:
     110                chi_pos = self.design.positions[child]
     111                if 'x' not in chi_pos:
     112                    continue
     113                line_style = self.compute_line_style(parent, child)
     114                self.add_line(file, (self.w_margin+self.w_no_margs*(par_pos['x']-min_width)/(max_width-min_width),
     115                                  self.h_margin+self.h_no_margs*par_pos['y']/max_height),
     116                                  (self.w_margin+self.w_no_margs*(chi_pos['x']-min_width)/(max_width-min_width),
     117                                  self.h_margin+self.h_no_margs*chi_pos['y']/max_height), line_style)
     118
     119    def draw_scale(self, file, filename):
     120        self.add_text(file, "Generated from " + filename.split("\\")[-1], (5, 15), "start")
     121        self.add_dashed_line(file, (self.width*0.7, self.h_margin), (self.width, self.h_margin))
     122
     123        start_text = ""
     124        end_text = ""
     125        if self.design.TIME == "BIRTHS":
     126           start_text = "Birth #0"
     127           end_text = "Birth #" + str(len(self.design.positions)-1)
     128        if self.design.TIME == "REAL":
     129           start_text = "Time " + str(min(self.design.tree.time))
     130           end_text = "Time " + str(max(self.design.tree.time))
     131        if self.design.TIME == "GENERATIONAL":
     132           start_text = "Depth " + str(self.design.props['adepth']['min'])
     133           end_text = "Depth " + str(self.design.props['adepth']['max'])
     134
     135        self.add_text(file, start_text, (self.width, self.h_margin + 15), "end")
     136        self.add_dashed_line(file, (self.width*0.7, self.height-self.h_margin), (self.width, self.height-self.h_margin))
     137        self.add_text(file, end_text, (self.width, self.height-self.h_margin + 15), "end")
     138
     139    def compute_property(self, part, prop, node):
     140        start = self.settings[part][prop]['start']
     141        end = self.settings[part][prop]['end']
     142        value = (self.design.props[self.settings[part][prop]['meaning']][node]
     143                 if self.settings[part][prop]['meaning'] in self.design.props else 0 )
     144        bias = self.settings[part][prop]['bias']
     145        if prop == "color":
     146            return self.compute_color(start, end, value, bias)
     147        else:
     148            return self.compute_value(start, end, value, bias)
     149
     150    def compute_color(self, start, end, value, bias=1):
     151        if isinstance(value, str):
     152            value = int(value)
     153            r = self.colors[self.settings['colors_of_kinds'][value]]['r']
     154            g = self.colors[self.settings['colors_of_kinds'][value]]['g']
     155            b = self.colors[self.settings['colors_of_kinds'][value]]['b']
     156        else:
     157            start_color = self.colors[start]
     158            end_color = self.colors[end]
     159            value = 1 - (1-value)**bias
     160            r = start_color['r']*(1-value)+end_color['r']*value
     161            g = start_color['g']*(1-value)+end_color['g']*value
     162            b = start_color['b']*(1-value)+end_color['b']*value
     163        return (r, g, b)
     164
     165    def compute_value(self, start, end, value, bias=1):
     166        value = 1 - (1-value)**bias
     167        return start*(1-value) + end*value
     168
     169class PngDrawer(Drawer):
     170    def draw_design(self, filename, input_filename, scale="SIMPLE"):
     171        print("Drawing...")
     172
     173        back = Image.new('RGBA', (self.width, self.height), (255,255,255,0))
     174
     175        min_width = min([x['x'] for x in self.design.positions if 'x' in x])
     176        max_width = max([x['x'] for x in self.design.positions if 'x' in x])
     177        max_height = max([x['y'] for x in self.design.positions if 'y' in x])
     178
     179        self.draw_lines(back, min_width, max_width, max_height)
     180        self.draw_dots(back, min_width, max_width, max_height)
     181
     182        if scale == "SIMPLE":
     183            self.draw_scale(back, input_filename)
     184
     185        back.show()
     186        back.save(filename)
     187
     188    def add_dot(self, file, pos, style):
     189        x, y = int(pos[0]), int(pos[1])
     190        r = style['r']
     191        offset = (int(x - r), int(y - r))
     192        size = (2*int(r), 2*int(r))
     193
     194        c = style['color']
     195
     196        img = Image.new('RGBA', size)
     197        ImageDraw.Draw(img).ellipse((1, 1, size[0]-1, size[1]-1),
     198                                    (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])))
     199        file.paste(img, offset, mask=img)
     200
     201    def add_line(self, file, from_pos, to_pos, style):
     202        fx, fy, tx, ty = int(from_pos[0]), int(from_pos[1]), int(to_pos[0]), int(to_pos[1])
     203        w = int(style['width'])
     204
     205        offset = (min(fx-w, tx-w), min(fy-w, ty-w))
     206        size = (abs(fx-tx)+2*w, abs(fy-ty)+2*w)
     207
     208        c = style['color']
     209
     210        img = Image.new('RGBA', size)
     211        ImageDraw.Draw(img).line((w, w, size[0]-w, size[1]-w) if (fx-tx)*(fy-ty)>0 else (size[0]-w, w, w, size[1]-w),
     212                                  (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])), int(style['width']))
     213        file.paste(img, offset, mask=img)
     214
     215    def add_dashed_line(self, file, from_pos, to_pos):
     216        style = {'color': (0,0,0), 'width': 1, 'opacity': 1}
     217        sublines = 50
     218        # TODO could be faster: compute delta and only add delta each time (but currently we do not use it often)
     219        for i in range(sublines):
     220            from_pos_sub = (self.compute_value(from_pos[0], to_pos[0], 2*i/(2*sublines-1), 1),
     221                            self.compute_value(from_pos[1], to_pos[1], 2*i/(2*sublines-1), 1))
     222            to_pos_sub = (self.compute_value(from_pos[0], to_pos[0], (2*i+1)/(2*sublines-1), 1),
     223                          self.compute_value(from_pos[1], to_pos[1], (2*i+1)/(2*sublines-1), 1))
     224            self.add_line(file, from_pos_sub, to_pos_sub, style)
     225
     226    def add_text(self, file, text, pos, anchor, style=''):
     227        font = ImageFont.truetype("Vera.ttf", 16)
     228
     229        img = Image.new('RGBA', (self.width, self.height))
     230        draw = ImageDraw.Draw(img)
     231        txtsize = draw.textsize(text, font=font)
     232        pos = pos if anchor == "start" else (pos[0]-txtsize[0], pos[1]-txtsize[1])
     233        draw.text(pos, text, (0,0,0), font=font)
     234        file.paste(img, (0,0), mask=img)
     235
     236    def compute_line_style(self, parent, child):
     237        return {'color': self.compute_property('lines', 'color', child),
     238                'width': self.compute_property('lines', 'width', child),
     239                'opacity': self.compute_property('lines', 'opacity', child)}
     240
     241    def compute_dot_style(self, node):
     242        return {'color': self.compute_property('dots', 'color', node),
     243                'r': self.compute_property('dots', 'size', node),
     244                'opacity': self.compute_property('dots', 'opacity', node)}
     245
     246class SvgDrawer(Drawer):
     247    def draw_design(self, filename, input_filename, scale="SIMPLE"):
     248        print("Drawing...")
     249        file = open(filename, "w")
     250
     251        min_width = min([x['x'] for x in self.design.positions if 'x' in x])
     252        max_width = max([x['x'] for x in self.design.positions if 'x' in x])
     253        max_height = max([x['y'] for x in self.design.positions if 'y' in x])
     254
     255        file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" '
     256                   'xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
     257                   'width="' + str(self.width) + '" height="' + str(self.height) + '">')
     258
     259        self.draw_lines(file, min_width, max_width, max_height)
     260        self.draw_dots(file, min_width, max_width, max_height)
     261
     262        if scale == "SIMPLE":
     263            self.draw_scale(file, input_filename)
     264
     265        file.write("</svg>")
     266        file.close()
     267
     268    def add_text(self, file, text, pos, anchor, style=''):
     269        style = (style if style != '' else 'style="font-family: Arial; font-size: 12; fill: #000000;"')
     270        file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) + '" y="' + str(pos[1]) + '" >' + text + '</text>')
     271
     272    def add_dot(self, file, pos, style):
     273        file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')
     274
     275    def add_line(self, file, from_pos, to_pos, style):
     276        file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
     277                       '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')
     278
     279    def add_dashed_line(self, file, from_pos, to_pos):
     280        style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
     281        self.add_line(file, from_pos, to_pos, style)
     282
     283    def compute_line_style(self, parent, child):
     284        return self.compute_stroke_color('lines', child) + ' ' \
     285               + self.compute_stroke_width('lines', child) + ' ' \
     286               + self.compute_stroke_opacity(child)
     287
     288    def compute_dot_style(self, node):
     289        return self.compute_dot_size(node) + ' ' \
     290               + self.compute_fill_opacity(node) + ' ' \
     291               + self.compute_dot_fill(node)
     292
     293    def compute_stroke_color(self, part, node):
     294        color = self.compute_property(part, 'color', node)
     295        return 'stroke="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'
     296
     297    def compute_stroke_width(self, part, node):
     298        return 'stroke-width="' + str(self.compute_property(part, 'width', node)) + '"'
     299
     300    def compute_stroke_opacity(self, node):
     301        return 'stroke-opacity="' + str(self.compute_property('lines', 'opacity', node)) + '"'
     302
     303    def compute_fill_opacity(self, node):
     304        return 'fill-opacity="' + str(self.compute_property('dots', 'opacity', node)) + '"'
     305
     306    def compute_dot_size(self, node):
     307        return 'r="' + str(self.compute_property('dots', 'size', node)) + '"'
     308
     309    def compute_dot_fill(self, node):
     310        color = self.compute_property('dots', 'color', node)
     311        return 'fill="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'
     312
     313class Designer:
     314
     315    def __init__(self, tree, jitter=False, time="GENERATIONAL", balance="DENSITY"):
     316        self.props = {}
     317
     318        self.tree = tree
     319
     320        self.TIME = time
     321        self.JITTER = jitter
     322
     323        if balance == "RANDOM":
     324            self.xmin_crowd = self.xmin_crowd_random
     325        elif balance == "MIN":
     326            self.xmin_crowd = self.xmin_crowd_min
     327        elif balance == "DENSITY":
     328            self.xmin_crowd = self.xmin_crowd_density
     329        else:
     330            raise ValueError("Error, the value of BALANCE does not match any expected value.")
     331
     332    def calculate_measures(self):
     333        print("Calculating measures...")
     334        self.compute_adepth()
     335        self.compute_depth()
     336        self.compute_children()
     337        self.compute_kind()
     338        self.compute_time()
     339        self.compute_custom()
     340
     341    def xmin_crowd_random(self, x1, x2, y):
     342        return (x1 if random.randrange(2) == 0 else x2)
     343
     344    def xmin_crowd_min(self, x1, x2, y):
     345        x1_closest = 999999
     346        x2_closest = 999999
     347        miny = y-3
     348        maxy = y+3
     349        i = bisect.bisect_left(self.y_sorted, miny)
     350        while True:
     351            if len(self.positions_sorted) <= i or self.positions_sorted[i]['y'] > maxy:
     352                break
     353            pos = self.positions_sorted[i]
     354
     355            x1_closest = min(x1_closest, abs(x1-pos['x']))
     356            x2_closest = min(x2_closest, abs(x2-pos['x']))
     357
     358            i += 1
     359        return (x1 if x1_closest > x2_closest else x2)
     360
     361    def xmin_crowd_density(self, x1, x2, y):
     362        x1_dist = 0
     363        x2_dist = 0
     364        miny = y-500
     365        maxy = y+500
     366        i_left = bisect.bisect_left(self.y_sorted, miny)
     367        i_right = bisect.bisect_right(self.y_sorted, maxy)
     368        # print("i " + str(i) + " len " + str(len(self.positions)))
     369        #
     370        # i = bisect.bisect_left(self.y_sorted, y)
     371        # i_left = max(0, i - 25)
     372        # i_right = min(len(self.y_sorted), i + 25)
     373
     374        def include_pos(pos):
     375            nonlocal x1_dist, x2_dist
     376
     377            dysq = (pos['y']-y)**2
     378            dx1 = pos['x']-x1
     379            dx2 = pos['x']-x2
     380
     381            x1_dist += math.sqrt(dysq + dx1**2)
     382            x2_dist += math.sqrt(dysq + dx2**2)
     383
     384        # optimized to draw from all the nodes, if less than 10 nodes in the range
     385        if len(self.positions_sorted) > i_left:
     386            if i_right - i_left < 10:
     387                for j in range(i_left, i_right):
     388                    include_pos(self.positions_sorted[j])
     389            else:
     390                for j in range(10):
     391                    pos = self.positions_sorted[random.randrange(i_left, i_right)]
     392                    include_pos(pos)
     393
     394        return (x1 if x1_dist > x2_dist else x2)
     395        #print(x1_dist, x2_dist)
     396        #x1_dist = x1_dist**2
     397        #x2_dist = x2_dist**2
     398        #return x1 if x1_dist+x2_dist==0 else (x1*x1_dist + x2*x2_dist) / (x1_dist+x2_dist) + random.gauss(0, 0.01)
     399        #return (x1 if random.randint(0, int(x1_dist+x2_dist)) < x1_dist else x2)
     400
     401    def calculate_node_positions(self, ignore_last=0):
     402        print("Calculating positions...")
     403
     404        current_node = 0
     405
     406        def add_node(node):
     407            nonlocal current_node
     408            index = bisect.bisect_left(self.y_sorted, node['y'])
     409            self.y_sorted.insert(index, node['y'])
     410            self.positions_sorted.insert(index, node)
     411            self.positions[node['id']] = node
     412
     413        self.positions_sorted = [{'x':0, 'y':0, 'id':0}]
     414        self.y_sorted = [0]
     415        self.positions = [{} for x in range(len(self.tree.parents))]
     416        self.positions[0] = {'x':0, 'y':0, 'id':0}
     417
     418        nodes_to_visit = [0]
     419        visited = [False] * len(self.tree.parents)
     420        visited[0] = True
     421
     422        node_counter = 0
     423        start_time = timelib.time()
     424
     425        while True:
     426
     427            node_counter += 1
     428            if node_counter%1000 == 0:
     429                print(str(node_counter) + " " + str(timelib.time()-start_time))
     430                start_time = timelib.time()
     431
     432            current_node = nodes_to_visit[0]
     433
     434            for child in self.tree.children[current_node]:
     435                if not visited[child] and self.props['adepth'][child] >= ignore_last/self.props['adepth_max']:
     436                    nodes_to_visit.append(child)
     437                    visited[child] = True
     438
     439                    ypos = 0
     440                    if self.TIME == "BIRTHS":
     441                        ypos = child
     442                    elif self.TIME == "GENERATIONAL":
     443                        ypos = self.positions[current_node]['y']+1
     444                    elif self.TIME == "REAL":
     445                        ypos = self.tree.time[child]
     446
     447                    if len(self.tree.parents[child]) == 1:
     448                    # if current_node is the only parent
     449                        if self.JITTER:
     450                            dissimilarity = random.gauss(0, 0.5)
     451                        else:
     452                            dissimilarity = 1
     453                        add_node({'id':child, 'y':ypos, 'x':
     454                                 self.xmin_crowd(self.positions[current_node]['x']-dissimilarity,
     455                                  self.positions[current_node]['x']+dissimilarity, ypos)})
     456                    else:
     457                        total_inheretance = sum([v for k, v in self.tree.parents[child].items()])
     458                        xpos = sum([self.positions[k]['x']*v/total_inheretance
     459                                   for k, v in self.tree.parents[child].items()])
     460                        if self.JITTER:
     461                            add_node({'id':child, 'y':ypos, 'x':xpos + random.gauss(0, 0.1)})
     462                        else:
     463                            add_node({'id':child, 'y':ypos, 'x':xpos})
     464
     465            nodes_to_visit = nodes_to_visit[1:]
     466            # if none left, we can stop
     467            if len(nodes_to_visit) == 0:
     468                print("done")
     469                break
     470
     471    def compute_custom(self):
     472        for prop in self.tree.props:
     473            self.props[prop] = [None for x in range(len(self.tree.children))]
     474
     475            for i in range(len(self.props[prop])):
     476                self.props[prop][i] = self.tree.props[prop][i]
     477
     478            self.normalize_prop(prop)
     479
     480    def compute_time(self):
     481        # simple rewrite from the tree
     482        self.props["time"] = [0 for x in range(len(self.tree.children))]
     483
     484        for i in range(len(self.props['time'])):
     485            self.props['time'][i] = self.tree.time[i]
     486
     487        self.normalize_prop('time')
     488
     489    def compute_kind(self):
     490        # simple rewrite from the tree
     491        self.props["kind"] = [0 for x in range(len(self.tree.children))]
     492
     493        for i in range (len(self.props['kind'])):
     494            self.props['kind'][i] = str(self.tree.kind[i])
     495
     496    def compute_depth(self):
     497        self.props["depth"] = [999999999 for x in range(len(self.tree.children))]
     498
     499        nodes_to_visit = [0]
     500        self.props["depth"][0] = 0
     501        while True:
     502            for child in self.tree.children[nodes_to_visit[0]]:
     503                nodes_to_visit.append(child)
     504                self.props["depth"][child] = min([self.props["depth"][d] for d in self.tree.parents[child]])+1
     505            nodes_to_visit = nodes_to_visit[1:]
     506            if len(nodes_to_visit) == 0:
     507                break
     508
     509        self.normalize_prop('depth')
     510
     511    def compute_adepth(self):
     512        self.props["adepth"] = [0 for x in range(len(self.tree.children))]
     513
     514        def compute_local_adepth(node):
     515            my_adepth = 0
     516            for c in self.tree.children[node]:
     517                my_adepth = max(my_adepth, compute_local_adepth(c)+1)
     518            self.props["adepth"][node] = my_adepth
     519            return my_adepth
     520
     521        compute_local_adepth(0)
     522        self.normalize_prop('adepth')
     523
     524    def compute_children(self):
     525        self.props["children"] = [0 for x in range(len(self.tree.children))]
     526        for i in range (len(self.props['children'])):
     527            self.props['children'][i] = len(self.tree.children[i])
     528
     529        self.normalize_prop('children')
     530
     531    def normalize_prop(self, prop):
     532        noneless = [v for v in self.props[prop] if type(v)==int or type(v)==float]
     533        if len(noneless) > 0:
     534            max_val = max(noneless)
     535            min_val = min(noneless)
     536            self.props[prop +'_max'] = max_val
     537            self.props[prop +'_min'] = min_val
     538            for i in range(len(self.props[prop])):
     539                if self.props[prop][i] is not None:
     540                    self.props[prop][i] = (self.props[prop][i] - min_val) / max_val
     541
     542
     543class TreeData:
     544    simple_data = None
     545
     546    children = []
     547    parents = []
     548    time = []
     549    kind = []
     550
     551    def __init__(self): #, simple_data=False):
     552        #self.simple_data = simple_data
     553        pass
     554
     555    def load(self, filename, max_nodes=0):
     556        print("Loading...")
     557
     558        CLI_PREFIX = "Script.Message:"
     559        default_props = ["Time", "FromIDs", "ID", "Operation", "Inherited"]
     560
     561        ids = {}
     562        def get_id(id):
     563            if not id in ids:
     564                ids[id] = len(ids)
     565            return ids[id]
     566
     567        file = open(filename)
     568
     569        # counting the number of expected nodes
     570        nodes = 0
     571        for line in file:
     572            line_arr = line.split(' ', 1)
     573            if len(line_arr) == 2:
     574                if line_arr[0] == CLI_PREFIX:
     575                    line_arr = line_arr[1].split(' ', 1)
     576                if line_arr[0] == "[OFFSPRING]":
     577                    nodes += 1
     578
     579        nodes = min(nodes, max_nodes if max_nodes != 0 else nodes)+1
     580        self.parents = [{} for x in range(nodes)]
     581        self.children = [[] for x in range(nodes)]
     582        self.time = [0] * nodes
     583        self.kind = [0] * nodes
     584        self.props = {}
     585
     586        print(len(self.parents))
     587
     588        file.seek(0)
     589        loaded_so_far = 0
     590        lasttime = timelib.time()
     591        for line in file:
     592            line_arr = line.split(' ', 1)
     593            if len(line_arr) == 2:
     594                if line_arr[0] == CLI_PREFIX:
     595                    line_arr = line_arr[1].split(' ', 1)
     596                if line_arr[0] == "[OFFSPRING]":
     597                    creature = json.loads(line_arr[1])
     598                    if "FromIDs" in creature:
     599
     600                        # make sure that ID's of parents are lower than that of their children
     601                        for i in range(0, len(creature["FromIDs"])):
     602                            get_id(creature["FromIDs"][i])
     603
     604                        creature_id = get_id(creature["ID"])
     605
     606                        # debug
     607                        if loaded_so_far%1000 == 0:
     608                            #print(". " + str(creature_id) + " " + str(timelib.time() - lasttime))
     609                            lasttime = timelib.time()
     610
    114611                        # we assign to each parent its contribution to the genotype of the child
    115612                        for i in range(0, len(creature["FromIDs"])):
     613                            parent_id = get_id(creature["FromIDs"][i])
    116614                            inherited = 1 #(creature["Inherited"][i] if 'Inherited' in creature else 1) #ONLY FOR NOW
    117                             nodes[creature["ID"]][creature["FromIDs"][i]] = inherited
     615                            self.parents[creature_id][parent_id] = inherited
     616
     617                        if "Time" in creature:
     618                            self.time[creature_id] = creature["Time"]
     619
     620                        if "Kind" in creature:
     621                            self.kind[creature_id] = creature["Kind"]
     622
     623                        for prop in creature:
     624                            if prop not in default_props:
     625                                if prop not in self.props:
     626                                    self.props[prop] = [None for i in range(nodes)]
     627                                self.props[prop][creature_id] = creature[prop]
     628
     629                        loaded_so_far += 1
    118630                    else:
    119                         print("Duplicated entry for " + creature["ID"])
    120                         quit()
    121 
    122                     if not creature["FromIDs"][0] in nodes and firstnode == None:
    123                         firstnode = creature["FromIDs"][0]
    124 
    125                 if "Time" in creature:
    126                     time[creature["ID"]] = creature["Time"]
    127 
    128                 if "Kind" in creature:
    129                     kind[creature["ID"]] = creature["Kind"]
    130 
    131                 loaded += 1
    132         if loaded == max_nodes and max_nodes != 0:
    133             break
    134 
    135     for k, v in sorted(nodes.items()):
    136         for val in sorted(v):
    137             inv_nodes[val] = inv_nodes.get(val, [])
    138             inv_nodes[val].append(k)
    139 
    140     print(len(nodes))
    141 
    142 
    143 def load_simple_data(dir):
    144     global firstnode, nodes, inv_nodes
    145     f = open(dir)
    146     loaded = 0
    147 
    148     for line in f:
    149         sline = line.split()
    150         if len(sline) > 1:
    151             #if int(sline[0]) > 15000:
    152             #    break
    153             if sline[0] == firstnode:
    154                 continue
    155             nodes[sline[0]] = str(max(int(sline[1]), int(firstnode)))
    156         else:
    157             firstnode = sline[0]
    158 
    159         loaded += 1
    160         if loaded == max_nodes and max_nodes != 0:
    161             break
    162 
    163     for k, v in sorted(nodes.items()):
    164         inv_nodes[v] = inv_nodes.get(v, [])
    165         inv_nodes[v].append(k)
    166 
    167     #print(str(inv_nodes))
    168     #quit()
    169 
    170 def compute_depth(node):
    171     my_depth = 0
    172     if node in inv_nodes:
    173         for c in inv_nodes[node]:
    174             my_depth = max(my_depth, compute_depth(c)+1)
    175     depth[node] = my_depth
    176     return my_depth
    177 
    178 # ------------------------------------
    179 
    180 
    181 def xmin_crowd_random(x1, x2, y):
    182     return (x1 if random.randrange(2) == 0 else x2)
    183 
    184 def xmin_crowd_min(x1, x2, y):
    185     x1_closest = 999999
    186     x2_closest = 999999
    187     for pos in positions:
    188         pos = positions[pos]
    189         if pos[1] == y:
    190             x1_closest = min(x1_closest, abs(x1-pos[0]))
    191             x2_closest = min(x2_closest, abs(x2-pos[0]))
    192     return (x1 if x1_closest > x2_closest else x2)
    193 def xmin_crowd_density(x1, x2, y):
    194     x1_dist = 0
    195     x2_dist = 0
    196     ymin = y-10
    197     ymax = y+10
    198     for pos in positions:
    199         pos = positions[pos]
    200         if pos[1] > ymin or pos[1] < ymax:
    201             dysq = (pos[1]-y)**2
    202             dx1 = pos[0]-x1
    203             dx2 = pos[0]-x2
    204 
    205 
    206             x1_dist += math.sqrt(dysq + dx1**2)
    207             x2_dist += math.sqrt(dysq + dx2**2)
    208     return (x1 if x1_dist > x2_dist else x2)
    209 
    210 # ------------------------------------
    211 
    212 def prepos_children():
    213     global max_height, max_width, min_width, visited, TIME
    214 
    215     print("firstnode " + firstnode)
    216 
    217     if not bool(time):
    218         print("REAL time requested, but no real time data provided. Assuming BIRTHS time instead.")
    219         TIME = "BIRTHS"
    220 
    221     positions[firstnode] = [0, 0]
    222 
    223     xmin_crowd = None
    224     if BALANCE == "RANDOM":
    225         xmin_crowd =xmin_crowd_random
    226     elif BALANCE == "MIN":
    227         xmin_crowd = xmin_crowd_min
    228     elif BALANCE == "DENSITY":
    229         xmin_crowd = xmin_crowd_density
    230     else:
    231         raise ValueError("Error, the value of BALANCE does not match any expected value.")
    232 
    233     nodes_to_visit = [firstnode]
    234 
    235     node_counter = 0
    236     start_time = ttime.time()
    237 
    238     while True:
    239 
    240         node_counter += 1
    241         if node_counter%1000 == 0 :
    242             print(str(node_counter) + " "  + str(ttime.time()-start_time))
    243             start_time = ttime.time()
    244 
    245         current_node = nodes_to_visit[0]
    246 
    247         if current_node in inv_nodes:
    248             for c in inv_nodes[current_node]:
    249                 # we want to visit the node just once, after all of its parents
    250                 if c not in nodes_to_visit:
    251                     nodes_to_visit.append(c)
    252 
    253                     cy = 0
    254                     if TIME == "BIRTHS":
    255                         if c[0] == "c":
    256                             cy = int(c[1:])
    257                         else:
    258                             cy = int(c)
    259                     elif TIME == "GENERATIONAL":
    260                         cy = positions[current_node][1]+1
    261                     elif TIME == "REAL":
    262                         cy = time[c]
    263 
    264                     if len(nodes[c]) == 1:
    265                         dissimilarity = 0
    266                         if JITTER == True:
    267                             dissimilarity = random.gauss(0,1)
    268                         else:
    269                             dissimilarity = 1
    270                         positions[c] = [xmin_crowd(positions[current_node][0]-dissimilarity, positions[current_node][0]+dissimilarity, cy), cy]
    271                     else:
    272                         vsum = sum([v for k, v in nodes[c].items()])
    273                         cx = sum([positions[k][0]*v/vsum for k, v in nodes[c].items()])
    274 
    275                         if JITTER == True:
    276                             positions[c] = [cx + random.gauss(0, 0.1), cy]
    277                         else:
    278                             positions[c] = [cx, cy]
    279 
    280         nodes_to_visit = nodes_to_visit[1:]
    281         # if none left, we can stop
    282         if len(nodes_to_visit) == 0:
    283             break
    284 
    285 
    286    # prepos_children_reccurent(firstnode)
    287 
    288     for pos in positions:
    289         max_height = max(max_height, positions[pos][1])
    290         max_width = max(max_width, positions[pos][0])
    291         min_width = min(min_width, positions[pos][0])
    292 
    293 # ------------------------------------
    294 
    295 def all_parents_visited(node):
    296     apv = True
    297     for k, v in sorted(nodes[node].items()):
    298         if not k in visited:
    299             apv = False
    300             break
    301     return apv
    302 # ------------------------------------
    303 
    304 def draw_children():
    305     max_depth = 0
    306     for k, v in depth.items():
    307             max_depth = max(max_depth, v)
    308 
    309     nodes_to_visit = [firstnode]
    310     while True:
    311         current_node = nodes_to_visit[0]
    312 
    313         if current_node in inv_nodes:
    314             for c in inv_nodes[current_node]: # inv_node => p->c
    315 
    316                 if not c in nodes_to_visit:
    317                     nodes_to_visit.append(c)
    318 
    319                 line_style = ""
    320                 if COLORING == "NONE":
    321                     line_style = svg_line_style
    322                 elif COLORING == "TYPE":
    323                     line_style = (svg_mutation_line_style if len(nodes[c]) == 1 else svg_crossover_line_style)
    324                 else: # IMPORTANCE, default
    325                     line_style = svg_generate_line_style(depth[c]/max_depth)
    326 
    327                 svg_add_line( (w_margin+w_no_margs*(positions[current_node][0]-min_width)/(max_width-min_width), h_margin+h_no_margs*positions[current_node][1]/max_height),
    328                         (w_margin+w_no_margs*(positions[c][0]-min_width)/(max_width-min_width), h_margin+h_no_margs*positions[c][1]/max_height), line_style)
    329 
    330         # we want to draw the node just once
    331         if DOT_STYLE == "NONE":
    332             continue
    333         elif DOT_STYLE == "TYPE":
    334             dot_style = svg_generate_dot_style(kind[current_node] if current_node in kind else 0) #type
    335         else: # NORMAL, default
    336             dot_style = svg_clear_dot_style #svg_generate_dot_style(depth[c]/max_depth)
    337         svg_add_dot( (w_margin+w_no_margs*(positions[current_node][0]-min_width)/(max_width-min_width), h_margin+h_no_margs*positions[current_node][1]/max_height), dot_style)
    338         #svg_add_text( str(depth[current_node]), (w_margin+w_no_margs*(positions[current_node][0]-min_width)/(max_width-min_width),
    339         # h_margin+h_no_margs*positions[current_node][1]/max_height), "end")
    340 
    341         # we remove the current node from the list
    342         nodes_to_visit = nodes_to_visit[1:]
    343         # if none left, we can stop
    344         if len(nodes_to_visit) == 0:
    345             break
    346 
    347 def draw_spine():
    348     nodes_to_visit = [firstnode]
    349     while True:
    350         current_node = nodes_to_visit[0]
    351 
    352         if current_node in inv_nodes:
    353             for c in inv_nodes[current_node]: # inv_node => p->c
    354                 if depth[c] == depth[current_node] - 1:
    355                     if not c in nodes_to_visit:
    356                         nodes_to_visit.append(c)
    357                     line_style = svg_spine_line_style
    358                     svg_add_line( (w_margin+w_no_margs*(positions[current_node][0]-min_width)/(max_width-min_width), h_margin+h_no_margs*positions[current_node][1]/max_height),
    359                         (w_margin+w_no_margs*(positions[c][0]-min_width)/(max_width-min_width), h_margin+h_no_margs*positions[c][1]/max_height), line_style)
    360 
    361         # we remove the current node from the list
    362         nodes_to_visit = nodes_to_visit[1:]
    363         # if none left, we can stop
    364         if len(nodes_to_visit) == 0:
    365             break
    366 
    367 def draw_skeleton():
    368     nodes_to_visit = [firstnode]
    369     while True:
    370         current_node = nodes_to_visit[0]
    371 
    372         if current_node in inv_nodes:
    373             for c in inv_nodes[current_node]: # inv_node => p->c
    374                 if depth[c] >= min_skeleton_depth:
    375                     if not c in nodes_to_visit:
    376                         nodes_to_visit.append(c)
    377                     line_style = svg_spine_line_style
    378                     svg_add_line( (w_margin+w_no_margs*(positions[current_node][0]-min_width)/(max_width-min_width), h_margin+h_no_margs*positions[current_node][1]/max_height),
    379                         (w_margin+w_no_margs*(positions[c][0]-min_width)/(max_width-min_width), h_margin+h_no_margs*positions[c][1]/max_height), line_style)
    380 
    381         # we remove the current node from the list
    382         nodes_to_visit = nodes_to_visit[1:]
    383         # if none left, we can stop
    384         if len(nodes_to_visit) == 0:
    385             break
    386 
    387 # ------------------------------------
    388 
    389 def draw_scale(filename ,type):
    390 
    391     svg_add_text("Generated from " + filename.split("\\")[-1], (5, 15), "start")
    392 
    393     svg_add_line( (w*0.7, h_margin), (w, h_margin), svg_scale_line_style)
    394     start_text = ""
    395     if TIME == "BIRTHS":
    396        start_text = "Birth #" + str(min([int(k[1:]) for k, v in nodes.items()]))
    397     if TIME == "REAL":
    398        start_text = "Time " + str(min([v for k, v in time.items()]))
    399     if TIME == "GENERATIONAL":
    400        start_text = "Depth " + str(min([v for k, v in depth.items()]))
    401     svg_add_text( start_text, (w, h_margin + 15), "end")
    402 
    403     svg_add_line( (w*0.7, h-h_margin), (w, h-h_margin), svg_scale_line_style)
    404     end_text = ""
    405     if TIME == "BIRTHS":
    406        end_text = "Birth #" + str(max([int(k[1:]) for k, v in nodes.items()]))
    407     if TIME == "REAL":
    408        end_text = "Time " + str(max([v for k, v in time.items()]))
    409     if TIME == "GENERATIONAL":
    410        end_text = "Depth " + str(max([v for k, v in depth.items()]))
    411     svg_add_text( end_text, (w, h-h_margin + 15), "end")
    412 
    413 
    414 ##################################################### main #####################################################
    415 
    416 args = 0
    417 
    418 h = 800
    419 w = 600
    420 h_margin = 20
    421 w_margin = 10
    422 h_no_margs = h - 2* h_margin
    423 w_no_margs = w - 2* w_margin
    424 
    425 max_height = 0
    426 max_width = 0
    427 min_width = 9999999999
    428 
    429 min_skeleton_depth = 0
    430 max_nodes = 0
    431 
    432 firstnode = None
    433 nodes = {}
    434 inv_nodes = {}
    435 positions = {}
    436 visited= {}
     631                        raise LoadingError("[OFFSPRING] misses the 'FromIDs' field!")
     632
     633            if loaded_so_far >= max_nodes and max_nodes != 0:
     634                break
     635
     636        for k in range(len(self.parents)):
     637            v = self.parents[k]
     638            for val in self.parents[k]:
     639                self.children[val].append(k)
     640
    437641depth = {}
    438 time = {}
    439642kind = {}
    440643
    441644def main():
    442     global svg_file, min_skeleton_depth, max_nodes, args, \
    443         TIME, BALANCE, DOT_STYLE, COLORING, JITTER, \
    444         svg_mutation_line_style, svg_crossover_line_style
    445 
    446     parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship information from a text file. Supports files generated by Framsticks experiments.')
     645
     646    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship '
     647                                                 'information from a text file. Supports files generated by Framsticks experiments.')
    447648    parser.add_argument('-i', '--in', dest='input', required=True, help='input file name with stuctured evolutionary data')
    448     parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG format)')
    449     draw_tree_parser = parser.add_mutually_exclusive_group(required=False)
    450     draw_tree_parser.add_argument('--draw-tree', dest='draw_tree', action='store_true', help='whether drawing the full tree should be skipped')
    451     draw_tree_parser.add_argument('--no-draw-tree', dest='draw_tree', action='store_false')
    452 
    453     draw_skeleton_parser = parser.add_mutually_exclusive_group(required=False)
    454     draw_skeleton_parser.add_argument('--draw-skeleton', dest='draw_skeleton', action='store_true', help='whether the skeleton of the tree should be drawn')
    455     draw_skeleton_parser.add_argument('--no-draw-skeleton', dest='draw_skeleton', action='store_false')
    456 
    457     draw_spine_parser = parser.add_mutually_exclusive_group(required=False)
    458     draw_spine_parser.add_argument('--draw-spine', dest='draw_spine', action='store_true', help='whether the spine of the tree should be drawn')
    459     draw_spine_parser.add_argument('--no-draw-spine', dest='draw_spine', action='store_false')
     649    parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG/PNG/JPG/BMP)')
     650    parser.add_argument('-c', '--config', dest='config', default="", help='config file name ')
    460651
    461652    #TODO: better names for those parameters
     653    parser.add_argument('-W', '--width', default=600, type=int, dest='width', help='width of the output image (600 by default)')
     654    parser.add_argument('-H', '--height', default=800, type=int, dest='height', help='heigt of the output image (800 by default)')
     655
    462656    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL(d)/REAL); '
    463657                                                                      'BIRTHS: time measured as the number of births since the beginning; '
     
    465659                                                                      'REAL: real time of the simulation')
    466660    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing nodes in the tree (RANDOM/MIN/DENSITY(d))')
    467     parser.add_argument('-s', '--scale', default='NONE', dest='scale', help='type of timescale added to the tree (NONE(d)/SIMPLE)')
    468     parser.add_argument('-c', '--coloring', default='IMPORTANCE', dest="coloring", help='method of coloring the tree (NONE/IMPORTANCE(d)/TYPE)')
    469     parser.add_argument('-d', '--dots', default='TYPE', dest='dots', help='method of drawing dots (individuals) (NONE/NORMAL/TYPE(d))')
     661    parser.add_argument('-s', '--scale', default='SIMPLE', dest='scale', help='type of timescale added to the tree (NONE(d)/SIMPLE)')
    470662    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')
    471 
    472     parser.add_argument('--color-mut', default="#000000", dest="color_mut", help='color of clone/mutation lines in rgba (e.g. #FF60B240) for TYPE coloring')
    473     parser.add_argument('--color-cross', default="#660198", dest="color_cross", help='color of crossover lines in rgba (e.g. #FF60B240) for TYPE coloring')
    474 
    475     parser.add_argument('--min-skeleton-depth', type=int, default=2, dest='min_skeleton_depth', help='minimal distance from the leafs for the nodes in the skeleton')
     663    parser.add_argument('-p', '--skip', dest="skip", type=int, default=0, help='skip last P levels of the tree (0 by default)')
     664    parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    476665    parser.add_argument('--seed', type=int, dest='seed', help='seed for the random number generator (-1 for random)')
    477 
    478     parser.add_argument('--simple-data', type=bool, dest='simple_data', help='input data are given in a simple format (#child #parent)')
    479 
    480 
    481     parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    482666
    483667    parser.set_defaults(draw_tree=True)
     
    491675    TIME = args.time.upper()
    492676    BALANCE = args.balance.upper()
    493     DOT_STYLE = args.dots.upper()
    494     COLORING = args.coloring.upper()
    495677    SCALE = args.scale.upper()
    496678    JITTER = args.jitter
    497679    if not TIME in ['BIRTHS', 'GENERATIONAL', 'REAL']\
    498680        or not BALANCE in ['RANDOM', 'MIN', 'DENSITY']\
    499         or not DOT_STYLE in ['NONE', 'NORMAL', 'TYPE']\
    500         or not COLORING in ['NONE', 'IMPORTANCE', 'TYPE']\
    501681        or not SCALE in ['NONE', 'SIMPLE']:
    502682        print("Incorrect value of one of the parameters! Closing the program.") #TODO don't be lazy, figure out which parameter is wrong...
    503683        return
    504684
    505 
    506     svg_mutation_line_style += hex_to_style(args.color_mut)
    507     svg_crossover_line_style += hex_to_style(args.color_cross)
    508 
    509685    dir = args.input
    510     min_skeleton_depth = args.min_skeleton_depth
    511     max_nodes = args.max_nodes
    512686    seed = args.seed
    513687    if seed == -1:
     
    516690    print("seed:", seed)
    517691
    518     if args.simple_data:
    519         load_simple_data(dir)
     692    tree = TreeData()
     693    tree.load(dir, max_nodes=args.max_nodes)
     694
     695    designer = Designer(tree, jitter=JITTER, time=TIME, balance=BALANCE)
     696    designer.calculate_measures()
     697    designer.calculate_node_positions(ignore_last=args.skip)
     698
     699    if args.output.endswith(".svg"):
     700        drawer = SvgDrawer(designer, args.config, w=args.width, h=args.height)
    520701    else:
    521         load_data(dir)
    522 
    523     compute_depth(firstnode)
    524 
    525     svg_file = open(args.output, "w")
    526     svg_file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
    527                    'width="' + str(w) + '" height="' + str(h) + '">')
    528 
    529     prepos_children()
    530 
    531     if args.draw_tree:
    532         draw_children()
    533     if args.draw_skeleton:
    534         draw_skeleton()
    535     if args.draw_spine:
    536         draw_spine()
    537 
    538     draw_scale(dir, SCALE)
    539 
    540     svg_file.write("</svg>")
    541     svg_file.close()
     702        drawer = PngDrawer(designer, args.config, w=args.width, h=args.height)
     703    drawer.draw_design(args.output, args.input, scale=SCALE)
     704
    542705
    543706main()
Note: See TracChangeset for help on using the changeset viewer.