source: mds-and-trees/tree-genealogy.py @ 679

Last change on this file since 679 was 679, checked in by konrad, 7 years ago

compute_adepth no longer uses recursion (so it should work better overall with the same results)

File size: 32.8 KB
Line 
1import json
2import math
3import random
4import argparse
5import bisect
6import time as timelib
7from PIL import Image, ImageDraw, ImageFont
8from scipy import stats
9import numpy as np
10
class LoadingError(Exception):
    """Raised when the input log contains a malformed record."""
13
class Drawer:
    """Rendering logic shared by the SVG and PNG drawers.

    Subclasses provide ``add_dot``, ``add_line``, ``add_dashed_line``,
    ``add_text``, ``compute_dot_style`` and ``compute_line_style``.
    """

    def __init__(self, design, config_file, w=600, h=800, w_margin=10, h_margin=20):
        """Store the layout to draw and the drawing settings.

        Args:
            design: object exposing ``positions``, ``props``, ``tree`` and
                ``TIME`` (a ``Designer`` in this script).
            config_file: path of a JSON file that is deep-merged over the
                default settings, or "" to keep the defaults.
            w, h: size of the output image.
            w_margin, h_margin: horizontal and vertical margins.
        """
        self.design = design
        self.width = w
        self.height = h
        self.w_margin = w_margin
        self.h_margin = h_margin
        self.w_no_margs = w - 2 * w_margin
        self.h_no_margs = h - 2 * h_margin

        # Colour components are percentages (0-100).
        self.colors = {
            'black':   {'r': 0,   'g': 0,   'b': 0},
            'red':     {'r': 100, 'g': 0,   'b': 0},
            'green':   {'r': 0,   'g': 100, 'b': 0},
            'blue':    {'r': 0,   'g': 0,   'b': 100},
            'yellow':  {'r': 100, 'g': 100, 'b': 0},
            'magenta': {'r': 100, 'g': 0,   'b': 100},
            'cyan':    {'r': 0,   'g': 100, 'b': 100},
            'orange':  {'r': 100, 'g': 50,  'b': 0},
            'purple':  {'r': 50,  'g': 0,   'b': 100}
        }

        # For each visual channel: which design property drives it
        # ('meaning'), the value range ('start'..'end') and the exponent of
        # the bias transform applied to the normalized property.
        self.settings = {
            'colors_of_kinds': ['red', 'green', 'blue', 'magenta', 'yellow', 'cyan', 'orange', 'purple'],
            'dots': {
                'color': {
                    'meaning': 'Lifespan',
                    'start': 'red',
                    'end': 'green',
                    'bias': 1
                    },
                'size': {
                    'meaning': 'EnergyEaten',
                    'start': 1,
                    'end': 6,
                    'bias': 0.5
                    },
                'opacity': {
                    'meaning': 'EnergyEaten',
                    'start': 0.2,
                    'end': 1,
                    'bias': 1
                    }
            },
            'lines': {
                'color': {
                    'meaning': 'adepth',
                    'start': 'black',
                    'end': 'red',
                    'bias': 3
                    },
                'width': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 4,
                    'bias': 3
                    },
                'opacity': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 0.8,
                    'bias': 5
                    }
            }
        }

        def merge(source, destination):
            """Recursively copy *source* into *destination*; nested dicts are merged."""
            for key, value in source.items():
                if isinstance(value, dict):
                    node = destination.setdefault(key, {})
                    merge(value, node)
                else:
                    destination[key] = value
            return destination

        if config_file != "":
            with open(config_file) as config:
                c = json.load(config)
            self.settings = merge(c, self.settings)

    def _canvas_point(self, pos, min_width, max_width, max_height):
        """Map a layout position ({'x', 'y'}) to canvas coordinates inside the margins.

        A zero horizontal span (all nodes share one x) centres the point, and a
        zero vertical span puts it on the top margin. Both cases used to raise
        ZeroDivisionError, e.g. for a single-node tree.
        """
        width_span = max_width - min_width
        if width_span:
            x = self.w_margin + self.w_no_margs * (pos['x'] - min_width) / width_span
        else:
            x = self.w_margin + self.w_no_margs / 2
        if max_height:
            y = self.h_margin + self.h_no_margs * pos['y'] / max_height
        else:
            y = self.h_margin
        return (x, y)

    def draw_dots(self, file, min_width, max_width, max_height):
        """Draw one dot per positioned node (nodes without an 'x' are skipped)."""
        for index, node in enumerate(self.design.positions):
            if 'x' not in node:
                continue
            dot_style = self.compute_dot_style(node=index)
            self.add_dot(file, self._canvas_point(node, min_width, max_width, max_height), dot_style)

    def draw_lines(self, file, min_width, max_width, max_height):
        """Draw a line from every positioned node to each of its positioned children."""
        for parent, par_pos in enumerate(self.design.positions):
            if 'x' not in par_pos:
                continue
            start = self._canvas_point(par_pos, min_width, max_width, max_height)
            for child in self.design.tree.children[parent]:
                chi_pos = self.design.positions[child]
                if 'x' not in chi_pos:
                    continue
                line_style = self.compute_line_style(parent, child)
                self.add_line(file, start,
                              self._canvas_point(chi_pos, min_width, max_width, max_height),
                              line_style)

    def draw_scale(self, file, filename):
        """Draw the "Generated from" caption and the labels of the time axis.

        Only the base name of *filename* is shown; both "/" and "\\"
        are treated as path separators.
        """
        base_name = filename.replace("\\", "/").rsplit("/", 1)[-1]
        self.add_text(file, "Generated from " + base_name, (5, 5), "start")

        start_text = ""
        end_text = ""
        if self.design.TIME == "BIRTHS":
            start_text = "Birth #0"
            end_text = "Birth #" + str(len(self.design.positions) - 1)
        elif self.design.TIME == "REAL":
            start_text = "Time " + str(min(self.design.tree.time))
            end_text = "Time " + str(max(self.design.tree.time))
        elif self.design.TIME == "GENERATIONAL":
            start_text = "Depth " + str(self.design.props['adepth_min'])
            end_text = "Depth " + str(self.design.props['adepth_max'])

        self.add_dashed_line(file, (self.width * 0.7, self.h_margin), (self.width, self.h_margin))
        self.add_text(file, start_text, (self.width, self.h_margin), "end")
        self.add_dashed_line(file, (self.width * 0.7, self.height - self.h_margin),
                             (self.width, self.height - self.h_margin))
        self.add_text(file, end_text, (self.width, self.height - self.h_margin), "end")

    def compute_property(self, part, prop, node):
        """Return the styled value of channel *prop* of *part* ('dots'/'lines') for *node*.

        The driving design property defaults to 0 when it is not computed,
        or when the node's value is None.
        """
        setting = self.settings[part][prop]
        props = self.design.props
        value = props[setting['meaning']][node] if setting['meaning'] in props else 0
        if value is None:
            value = 0
        if prop == "color":
            return self.compute_color(setting['start'], setting['end'], value, setting['bias'])
        return self.compute_value(setting['start'], setting['end'], value, setting['bias'])

    def compute_color(self, start, end, value, bias=1):
        """Return an (r, g, b) percentage triple for *value*.

        A string *value* is a kind index. It selects a colour from
        ``settings['colors_of_kinds']`` and wraps around when there are more
        kinds than colours. A numeric *value* (expected in [0, 1])
        interpolates between the named *start* and *end* colours after the
        bias transform ``1 - (1 - value) ** bias``.
        """
        if isinstance(value, str):
            palette = self.settings['colors_of_kinds']
            color = self.colors[palette[int(value) % len(palette)]]
            return (color['r'], color['g'], color['b'])
        start_color = self.colors[start]
        end_color = self.colors[end]
        value = 1 - (1 - value) ** bias
        r = start_color['r'] * (1 - value) + end_color['r'] * value
        g = start_color['g'] * (1 - value) + end_color['g'] * value
        b = start_color['b'] * (1 - value) + end_color['b'] * value
        return (r, g, b)

    def compute_value(self, start, end, value, bias=1):
        """Interpolate between *start* and *end* at ``1 - (1 - value) ** bias``."""
        value = 1 - (1 - value) ** bias
        return start * (1 - value) + end * value
170
class PngDrawer(Drawer):
    """Raster renderer built on Pillow (PNG/JPG/BMP output).

    The tree is drawn on a canvas ``multi`` times larger than requested and
    then downsampled, which smooths lines and dots.
    """

    # Canvas dimensions that scale_up/scale_down multiply by ``multi``.
    _GEOMETRY = ('width', 'height', 'w_margin', 'h_margin', 'h_no_margs', 'w_no_margs')

    def scale_up(self):
        """Multiply all canvas dimensions by ``self.multi``, remembering the originals."""
        self._saved_geometry = {name: getattr(self, name) for name in self._GEOMETRY}
        for name in self._GEOMETRY:
            setattr(self, name, getattr(self, name) * self.multi)

    def scale_down(self):
        """Undo ``scale_up``.

        The saved values are restored instead of divided back, so integer
        sizes stay integers and the drawer can be reused for another
        ``draw_design`` call.
        """
        saved = getattr(self, '_saved_geometry', None)
        if saved is None:
            for name in self._GEOMETRY:
                setattr(self, name, getattr(self, name) / self.multi)
            return
        for name, value in saved.items():
            setattr(self, name, value)
        self._saved_geometry = None

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Render the tree and save it to *filename* (format taken from its extension).

        Args:
            filename: output image path.
            input_filename: shown in the caption when *scale* is "SIMPLE".
            multi: supersampling factor.
            scale: "SIMPLE" draws the time-axis labels, anything else omits them.
        """
        print("Drawing...")

        self.multi = multi
        self.scale_up()
        try:
            back = Image.new('RGBA', (self.width, self.height), (255, 255, 255, 0))

            min_width = min([x['x'] for x in self.design.positions if 'x' in x])
            max_width = max([x['x'] for x in self.design.positions if 'x' in x])
            max_height = max([x['y'] for x in self.design.positions if 'y' in x])

            self.draw_lines(back, min_width, max_width, max_height)
            self.draw_dots(back, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(back, input_filename)
        finally:
            self.scale_down()

        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        resample = getattr(Image, "LANCZOS", getattr(Image, "ANTIALIAS", None))
        back.thumbnail((self.width, self.height), resample)

        # JPEG has no alpha channel: flatten onto a white background first.
        if filename.lower().endswith((".jpg", ".jpeg")):
            flat = Image.new('RGB', back.size, (255, 255, 255))
            flat.paste(back, (0, 0), back)
            back = flat

        back.save(filename)

    @staticmethod
    def _rgba(color, opacity):
        """Convert a percentage (0-100) colour and a 0-1 opacity to 8-bit RGBA."""
        return (int(2.55 * color[0]), int(2.55 * color[1]), int(2.55 * color[2]), int(255 * opacity))

    @staticmethod
    def _text_width(draw, text, font):
        """Return the rendered width of *text*; ``textsize`` was removed in Pillow 10."""
        if hasattr(draw, "textbbox"):
            return draw.textbbox((0, 0), text, font=font)[2]
        return draw.textsize(text, font=font)[0]

    def add_dot(self, file, pos, style):
        """Paste a filled circle of radius ``style['r'] * multi`` centred at *pos*."""
        x, y = int(pos[0]), int(pos[1])
        r = style['r'] * self.multi
        offset = (int(x - r), int(y - r))
        # At least 2 px so the ellipse box below is never inverted.
        diameter = 2 * max(1, int(r))
        size = (diameter, diameter)

        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).ellipse((1, 1, size[0] - 1, size[1] - 1),
                                    self._rgba(style['color'], style['opacity']))
        file.paste(img, offset, mask=img)

    def add_line(self, file, from_pos, to_pos, style):
        """Paste a line segment from *from_pos* to *to_pos*.

        The width is scaled by ``multi`` before rounding, because truncating it
        first turned every sub-pixel width into 0. It is at least 1 px.
        """
        fx, fy, tx, ty = int(from_pos[0]), int(from_pos[1]), int(to_pos[0]), int(to_pos[1])
        w = max(1, int(round(style['width'] * self.multi)))

        offset = (min(fx, tx) - w, min(fy, ty) - w)
        size = (abs(fx - tx) + 2 * w, abs(fy - ty) + 2 * w)

        # Main diagonal when x and y grow together, otherwise the anti-diagonal
        # of the padded bounding box.
        if (fx - tx) * (fy - ty) > 0:
            coords = (w, w, size[0] - w, size[1] - w)
        else:
            coords = (size[0] - w, w, w, size[1] - w)

        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).line(coords, self._rgba(style['color'], style['opacity']), w)
        file.paste(img, offset, mask=img)

    def add_dashed_line(self, file, from_pos, to_pos):
        """Draw a black dashed line as 50 short solid segments."""
        style = {'color': (0, 0, 0), 'width': 1, 'opacity': 1}
        sublines = 50
        steps = 2 * sublines - 1
        for i in range(sublines):
            from_pos_sub = (self.compute_value(from_pos[0], to_pos[0], 2 * i / steps, 1),
                            self.compute_value(from_pos[1], to_pos[1], 2 * i / steps, 1))
            to_pos_sub = (self.compute_value(from_pos[0], to_pos[0], (2 * i + 1) / steps, 1),
                          self.compute_value(from_pos[1], to_pos[1], (2 * i + 1) / steps, 1))
            self.add_line(file, from_pos_sub, to_pos_sub, style)

    def add_text(self, file, text, pos, anchor, style=''):
        """Write black *text* at *pos*; an anchor other than "start" right-aligns it.

        Falls back to Pillow's built-in font when Vera.ttf cannot be loaded.
        *style* is accepted for interface parity with SvgDrawer and ignored.
        """
        try:
            font = ImageFont.truetype("Vera.ttf", 16 * self.multi)
        except OSError:
            font = ImageFont.load_default()

        img = Image.new('RGBA', (self.width, self.height))
        draw = ImageDraw.Draw(img)
        if anchor != "start":
            pos = (pos[0] - self._text_width(draw, text, font), pos[1])
        draw.text(pos, text, (0, 0, 0), font=font)
        file.paste(img, (0, 0), mask=img)

    def compute_line_style(self, parent, child):
        """Return the colour/width/opacity dict for the edge leading to *child*."""
        return {'color': self.compute_property('lines', 'color', child),
                'width': self.compute_property('lines', 'width', child),
                'opacity': self.compute_property('lines', 'opacity', child)}

    def compute_dot_style(self, node):
        """Return the colour/radius/opacity dict for *node*'s dot."""
        return {'color': self.compute_property('dots', 'color', node),
                'r': self.compute_property('dots', 'size', node),
                'opacity': self.compute_property('dots', 'opacity', node)}
271
class SvgDrawer(Drawer):
    """Vector renderer that writes the tree as SVG <line>, <circle> and <text> elements."""

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Write the whole tree to *filename* as an SVG document.

        ``multi`` is accepted for interface parity with PngDrawer and ignored.
        The file is closed even when drawing fails.
        """
        print("Drawing...")

        min_width = min([x['x'] for x in self.design.positions if 'x' in x])
        max_width = max([x['x'] for x in self.design.positions if 'x' in x])
        max_height = max([x['y'] for x in self.design.positions if 'y' in x])

        with open(filename, "w") as file:
            file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" '
                       'xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
                       'width="' + str(self.width) + '" height="' + str(self.height) + '">')

            self.draw_lines(file, min_width, max_width, max_height)
            self.draw_dots(file, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(file, input_filename)

            file.write("</svg>")

    def add_text(self, file, text, pos, anchor, style=''):
        """Write a <text> element; *text* is XML-escaped (e.g. '&' in file names)."""
        from xml.sax.saxutils import escape

        style = (style if style != '' else 'style="font-family: Arial; font-size: 12; fill: #000000;"')
        # The baseline is shifted by 12 assuming the default 12 px font; a custom
        # style with another size would need a different offset.
        file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) +
                   '" y="' + str(pos[1] + 12) + '" >' + escape(text) + '</text>')

    def add_dot(self, file, pos, style):
        """Write a <circle> centred at *pos*; *style* holds its attribute string."""
        file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')

    def add_line(self, file, from_pos, to_pos, style):
        """Write a <line> between two points; *style* holds its attribute string."""
        file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
                   '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')

    def add_dashed_line(self, file, from_pos, to_pos):
        """Write a thin black dashed <line>."""
        style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
        self.add_line(file, from_pos, to_pos, style)

    def compute_line_style(self, parent, child):
        """Return the stroke attributes of the edge leading to *child*."""
        return self.compute_stroke_color('lines', child) + ' ' \
               + self.compute_stroke_width('lines', child) + ' ' \
               + self.compute_stroke_opacity(child)

    def compute_dot_style(self, node):
        """Return the radius, fill-opacity and fill attributes of *node*'s dot."""
        return self.compute_dot_size(node) + ' ' \
               + self.compute_fill_opacity(node) + ' ' \
               + self.compute_dot_fill(node)

    def compute_stroke_color(self, part, node):
        """Return a percentage-based ``stroke="rgb(...)"`` attribute."""
        color = self.compute_property(part, 'color', node)
        return 'stroke="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'

    def compute_stroke_width(self, part, node):
        """Return the ``stroke-width`` attribute."""
        return 'stroke-width="' + str(self.compute_property(part, 'width', node)) + '"'

    def compute_stroke_opacity(self, node):
        """Return the ``stroke-opacity`` attribute of a line."""
        return 'stroke-opacity="' + str(self.compute_property('lines', 'opacity', node)) + '"'

    def compute_fill_opacity(self, node):
        """Return the ``fill-opacity`` attribute of a dot."""
        return 'fill-opacity="' + str(self.compute_property('dots', 'opacity', node)) + '"'

    def compute_dot_size(self, node):
        """Return the circle radius attribute ``r``."""
        return 'r="' + str(self.compute_property('dots', 'size', node)) + '"'

    def compute_dot_fill(self, node):
        """Return a percentage-based ``fill="rgb(...)"`` attribute."""
        color = self.compute_property('dots', 'color', node)
        return 'fill="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'
339
class Designer:
    """Compute per-node measures and 2-D layout positions for a genealogy.

    ``tree`` must expose ``parents`` (per node a dict {parent_id: share of
    inheritance}), ``children`` (per node a list of child ids), ``time``,
    ``kind`` and ``props``. These are the attributes filled by
    ``TreeData.load``. Node 0 is the root of the layout.
    """

    # Depth given to nodes that cannot be reached from node 0.
    _UNREACHED = 999999999

    def __init__(self, tree, jitter=False, time="GENERATIONAL", balance="DENSITY"):
        """Args:
            tree: loaded genealogy (see class docstring).
            jitter: draw horizontal offsets from a normal distribution.
            time: vertical axis - "BIRTHS", "GENERATIONAL" or "REAL".
            balance: child placement strategy - "RANDOM", "MIN" or "DENSITY".

        Raises:
            ValueError: when *balance* is not a supported strategy.
        """
        self.props = {}

        self.tree = tree

        self.TIME = time
        self.JITTER = jitter

        if balance == "RANDOM":
            self.xmin_crowd = self.xmin_crowd_random
        elif balance == "MIN":
            self.xmin_crowd = self.xmin_crowd_min
        elif balance == "DENSITY":
            self.xmin_crowd = self.xmin_crowd_density
        else:
            raise ValueError("Error, the value of BALANCE does not match any expected value.")

    def calculate_measures(self):
        """Compute every per-node property used for layout and styling."""
        print("Calculating measures...")
        self.compute_depth()
        self.compute_adepth()
        self.compute_children()
        self.compute_kind()
        self.compute_time()
        self.compute_progress()
        self.compute_custom()

    def xmin_crowd_random(self, x1, x2, y):
        """Pick one of the two candidate x positions at random."""
        return (x1 if random.randrange(2) == 0 else x2)

    def xmin_crowd_min(self, x1, x2, y):
        """Pick the candidate whose nearest already-placed neighbour is farther away.

        Only nodes whose y lies in [y - 3, y + 3] are considered.
        """
        x1_closest = 999999
        x2_closest = 999999
        maxy = y + 3
        i = bisect.bisect_left(self.y_sorted, y - 3)
        while i < len(self.positions_sorted) and self.positions_sorted[i]['y'] <= maxy:
            pos = self.positions_sorted[i]
            x1_closest = min(x1_closest, abs(x1 - pos['x']))
            x2_closest = min(x2_closest, abs(x2 - pos['x']))
            i += 1
        return (x1 if x1_closest > x2_closest else x2)

    def xmin_crowd_density(self, x1, x2, y):
        """Pick the candidate with the larger summed distance to nearby placed nodes.

        Nodes with y in [y - 20, y + 20] are used. When the window holds 10 or
        more nodes, 10 of them are sampled at random (with replacement).
        """
        x1_dist = 0
        x2_dist = 0
        i_left = bisect.bisect_left(self.y_sorted, y - 20)
        i_right = bisect.bisect_right(self.y_sorted, y + 20)

        if len(self.positions_sorted) > i_left:
            if i_right - i_left < 10:
                sample = self.positions_sorted[i_left:i_right]
            else:
                sample = [self.positions_sorted[random.randrange(i_left, i_right)] for _ in range(10)]
            for pos in sample:
                dysq = (pos['y'] - y) ** 2
                x1_dist += math.sqrt(dysq + (pos['x'] - x1) ** 2)
                x2_dist += math.sqrt(dysq + (pos['x'] - x2) ** 2)

        return (x1 if x1_dist > x2_dist else x2)

    def _topological_order(self):
        """Return all node ids ordered so that every parent precedes its children.

        Nodes are sorted by the length of the longest path reaching them from
        a parentless node; ties keep ascending id order. For a single-parent
        tree this equals the old "depth of the parent" ordering. Unlike that
        key, it stays valid when a node has parents at different depths.
        """
        count = len(self.tree.parents)
        indegree = [len(p) for p in self.tree.parents]
        level = [0] * count
        queue = [node for node in range(count) if indegree[node] == 0]
        head = 0
        while head < len(queue):
            node = queue[head]
            head += 1
            for child in self.tree.children[node]:
                level[child] = max(level[child], level[node] + 1)
                indegree[child] -= 1
                if indegree[child] == 0:
                    queue.append(child)
        return sorted(range(count), key=lambda q: level[q])

    def calculate_node_positions(self, ignore_last=0):
        """Assign an {'id', 'x', 'y'} position to every node that is drawn.

        Args:
            ignore_last: skip nodes whose raw ancestor depth (adepth) is below
                this value, i.e. the last *ignore_last* levels of the tree.

        Fills ``self.positions`` (indexed by node id; {} for skipped nodes)
        plus the y-sorted helpers used by the crowding strategies.
        """
        print("Calculating positions...")

        def add_node(node):
            index = bisect.bisect_left(self.y_sorted, node['y'])
            self.y_sorted.insert(index, node['y'])
            self.positions_sorted.insert(index, node)
            self.positions[node['id']] = node

        self.positions_sorted = [{'x': 0, 'y': 0, 'id': 0}]
        self.y_sorted = [0]
        self.positions = [{} for _ in range(len(self.tree.parents))]
        self.positions[0] = {'x': 0, 'y': 0, 'id': 0}

        # adepth is normalized, so compare against ignore_last / adepth_max
        # (guarding single-node trees where adepth_max is 0).
        adepth_max = self.props['adepth_max']
        threshold = ignore_last / adepth_max if adepth_max else 0

        start_time = timelib.time()
        for node_counter, child in enumerate(self._topological_order(), start=1):
            # debug info - elapsed time
            if node_counter % 1000 == 0:
                print(str(node_counter) + " " + str(timelib.time() - start_time))
                start_time = timelib.time()

            if child == 0:
                continue  # the root is pre-placed at (0, 0)
            if self.props['adepth'][child] < threshold:
                continue

            parents = self.tree.parents[child]
            ypos = 0
            if self.TIME == "BIRTHS":
                ypos = child
            elif self.TIME == "GENERATIONAL":
                # One below the lowest parent; parentless nodes go to row 0.
                ypos = max((self.positions[par]['y'] for par in parents), default=-1) + 1
            elif self.TIME == "REAL":
                ypos = self.tree.time[child]

            if len(parents) == 1:
                (parent,) = parents
                dissimilarity = random.gauss(0, 0.5) if self.JITTER else 1
                parent_x = self.positions[parent]['x']
                xpos = self.xmin_crowd(parent_x - dissimilarity, parent_x + dissimilarity, ypos)
            else:
                # Position weighted by each parent's share of the inheritance.
                total_inheritance = sum(parents.values())
                xpos = sum(self.positions[k]['x'] * v / total_inheritance for k, v in parents.items())
                if self.JITTER:
                    xpos = xpos + random.gauss(0, 0.1)
            add_node({'id': child, 'y': ypos, 'x': xpos})

    def compute_custom(self):
        """Copy every custom per-node property from the tree and normalize it."""
        count = len(self.tree.children)
        for prop, values in self.tree.props.items():
            self.props[prop] = [values[i] for i in range(count)]
            self.normalize_prop(prop)

    def compute_time(self):
        """Copy the birth times from the tree and normalize them."""
        self.props["time"] = [self.tree.time[i] for i in range(len(self.tree.children))]
        self.normalize_prop('time')

    def compute_kind(self):
        """Copy the kinds from the tree as strings (not normalized)."""
        self.props["kind"] = [str(self.tree.kind[i]) for i in range(len(self.tree.children))]

    def compute_depth(self):
        """Shortest number of generations from node 0 (breadth-first), normalized.

        Nodes not reachable from node 0 keep the 999999999 sentinel before
        normalization.
        """
        depth = [self._UNREACHED] * len(self.tree.children)
        depth[0] = 0
        queue = [0]
        head = 0
        while head < len(queue):
            node = queue[head]
            head += 1
            for child in self.tree.children[node]:
                # In BFS the first discovery is via a shortest path.
                if depth[child] == self._UNREACHED:
                    depth[child] = depth[node] + 1
                    queue.append(child)
        self.props["depth"] = depth
        self.normalize_prop('depth')

    def compute_adepth(self):
        """Longest chain of descendants below each node (leaves have 0), normalized."""
        adepth = [0] * len(self.tree.children)
        # Children before parents: reverse of a parents-first topological order.
        for node in reversed(self._topological_order()):
            children = self.tree.children[node]
            if children:
                adepth[node] = max(adepth[child] for child in children) + 1
        self.props["adepth"] = adepth
        self.normalize_prop('adepth')

    def compute_children(self):
        """Number of direct children of each node, normalized."""
        self.props["children"] = [len(self.tree.children[i]) for i in range(len(self.tree.children))]
        self.normalize_prop('children')

    def compute_progress(self):
        """Slope of the gaps between consecutive children's birth times, normalized.

        Only nodes with more than 4 children get a slope. The 5 smallest and
        5 largest values are then zeroed as outliers, and all zeros are
        replaced by the remaining minimum.
        """
        count = len(self.tree.children)
        progress = [0] * count
        time = self.props["time"]
        for node in range(count):
            times = sorted(time[child] * 100000 for child in self.tree.children[node])
            if len(times) > 4:
                gaps = [later - earlier for earlier, later in zip(times, times[1:])]
                slope = stats.linregress(range(len(gaps)), gaps)[0]
                progress[node] = slope if not np.isnan(slope) and not np.isinf(slope) else 0

        for _ in range(5):
            progress[progress.index(min(progress))] = 0
            progress[progress.index(max(progress))] = 0

        mini = min(progress)
        self.props['progress'] = [mini if value == 0 else value for value in progress]
        self.normalize_prop('progress')

    def normalize_prop(self, prop):
        """Rescale the numeric values of ``props[prop]`` to [0, 1] in place.

        Stores the raw extremes as ``prop + '_max'`` / ``prop + '_min'``.
        None entries are ignored and kept. Other non-numeric entries
        (strings, lists, dicts) are set to 0, which the old code also did
        whenever max == min. A property with no numeric values is left
        untouched.
        """
        values = self.props[prop]
        numeric = [v for v in values if isinstance(v, (int, float))]
        if not numeric:
            return
        max_val = max(numeric)
        min_val = min(numeric)
        print(prop, max_val, min_val)
        self.props[prop + '_max'] = max_val
        self.props[prop + '_min'] = min_val
        for i, value in enumerate(values):
            if value is None:
                continue
            if not isinstance(value, (int, float)) or max_val == min_val:
                values[i] = 0
            else:
                values[i] = (value - min_val) / (max_val - min_val)
592
class TreeData:
    """Genealogy loaded from a Framsticks-style event log.

    After ``load``:
    - ``parents[i]`` maps parent index -> inheritance share.
    - ``children[i]`` lists child indices.
    - ``time`` / ``kind`` hold the "Time" / "Kind" fields.
    - ``props[name]`` holds every other non-default field per node.
    """
    simple_data = None

    def __init__(self):
        # Per-instance containers. As mutable class attributes these lists
        # used to be shared by every TreeData object.
        self.children = []
        self.parents = []
        self.time = []
        self.kind = []
        self.props = {}
        self.ids = {}

    @staticmethod
    def _parse_event(line, cli_prefix):
        """Split a log line into (tag, payload) after stripping an optional *cli_prefix*.

        Returns (None, None) for lines without a space, and a None payload
        when nothing follows the tag.
        """
        parts = line.split(' ', 1)
        if len(parts) != 2:
            return None, None
        if parts[0] == cli_prefix:
            parts = parts[1].split(' ', 1)
        return parts[0], (parts[1] if len(parts) == 2 else None)

    def _get_id(self, key, create=True):
        """Map an external creature ID to a dense 0-based index.

        Unseen IDs get the next free index when *create* is true; otherwise
        None is returned for them.
        """
        if key not in self.ids:
            if not create:
                return None
            self.ids[key] = len(self.ids)
        return self.ids[key]

    def _store_props(self, creature, creature_id, nodes, default_props):
        """Record every field of *creature* not listed in *default_props* as a property."""
        for prop, value in creature.items():
            if prop not in default_props:
                if prop not in self.props:
                    self.props[prop] = [0] * nodes
                self.props[prop][creature_id] = value

    def _add_offspring(self, creature, nodes, default_props):
        """Register one [OFFSPRING] record: its index, parents, time, kind and properties.

        Raises:
            LoadingError: when the record has no "FromIDs" field.
        """
        if "FromIDs" not in creature:
            raise LoadingError("[OFFSPRING] misses the 'FromIDs' field!")
        from_ids = creature["FromIDs"]

        # Create the shared virtual parent *before* the child so that parent
        # indices stay lower than those of their children.
        if any(parent not in self.ids for parent in from_ids):
            self._get_id("virtual_parent")
        creature_id = self._get_id(creature["ID"])

        # Each parent's contribution to the child's genotype. It is fixed at 1
        # for now; the "Inherited" field is deliberately not used yet.
        for parent in from_ids:
            parent_id = self.ids[parent] if parent in self.ids else self._get_id("virtual_parent")
            self.parents[creature_id][parent_id] = 1

        if "Time" in creature:
            self.time[creature_id] = creature["Time"]
        if "Kind" in creature:
            self.kind[creature_id] = creature["Kind"]

        self._store_props(creature, creature_id, nodes, default_props)

    def load(self, filename, max_nodes=0):
        """Read [OFFSPRING] and [DIED] events from *filename*.

        Args:
            filename: path to the event log.
            max_nodes: stop after this many offspring (0 = no limit).

        Raises:
            LoadingError: when an [OFFSPRING] record lacks "FromIDs".
        """
        print("Loading...")

        CLI_PREFIX = "Script.Message:"
        default_props = ["Time", "FromIDs", "ID", "Operation", "Inherited"]

        self.ids = {}

        with open(filename) as file:
            # First pass: count offspring to size the arrays (+1 slot for the
            # virtual parent).
            nodes = sum(1 for line in file if self._parse_event(line, CLI_PREFIX)[0] == "[OFFSPRING]")
            nodes = min(nodes, max_nodes if max_nodes != 0 else nodes) + 1
            self.parents = [{} for _ in range(nodes)]
            self.children = [[] for _ in range(nodes)]
            self.time = [0] * nodes
            self.kind = [0] * nodes
            self.life_lenght = [0] * nodes
            self.props = {}

            print(len(self.parents))

            file.seek(0)
            loaded_so_far = 0
            for line in file:
                tag, payload = self._parse_event(line, CLI_PREFIX)
                if payload is not None:
                    if tag == "[OFFSPRING]":
                        self._add_offspring(json.loads(payload), nodes, default_props)
                        loaded_so_far += 1
                    elif tag == "[DIED]":
                        creature = json.loads(payload)
                        creature_id = self._get_id(creature["ID"], False)
                        if creature_id is not None:
                            self._store_props(creature, creature_id, nodes, default_props)

                if max_nodes != 0 and loaded_so_far >= max_nodes:
                    break

        for child, parent_map in enumerate(self.parents):
            for parent in parent_map:
                self.children[parent].append(child)
709
# Module-level dictionaries; nothing in the visible code reads or writes them.
depth, kind = {}, {}
712
def main():
    """Command-line entry point: parse arguments, load the log, lay out and draw the tree.

    Returns early, naming the offending option(s), when --time, --balance or
    --scale has an unsupported value.
    """
    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship '
                                                 'information from a text file. Supports files generated by Framsticks experiments.')
    parser.add_argument('-i', '--in', dest='input', required=True, help='input file name with structured evolutionary data')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG/PNG/JPG/BMP)')
    parser.add_argument('-c', '--config', dest='config', default="", help='config file name ')

    parser.add_argument('-W', '--width', default=600, type=int, dest='width', help='width of the output image (600 by default)')
    parser.add_argument('-H', '--height', default=800, type=int, dest='height', help='height of the output image (800 by default)')
    parser.add_argument('-m', '--multi', default=1, type=int, dest='multi', help='multisampling factor (applicable only for raster images)')

    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL(d)/REAL); '
                                                                      'BIRTHS: time measured as the number of births since the beginning; '
                                                                      'GENERATIONAL: time measured as number of ancestors; '
                                                                      'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing nodes in the tree (RANDOM/MIN/DENSITY(d))')
    parser.add_argument('-s', '--scale', default='SIMPLE', dest='scale', help='type of timescale added to the tree (NONE/SIMPLE(d))')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')
    parser.add_argument('-p', '--skip', dest="skip", type=int, default=0, help='skip last P levels of the tree (0 by default)')
    parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    parser.add_argument('--seed', type=int, dest='seed', help='seed for the random number generator (-1 for random)')

    parser.set_defaults(draw_tree=True)
    parser.set_defaults(draw_skeleton=False)
    parser.set_defaults(draw_spine=False)

    parser.set_defaults(seed=-1)

    args = parser.parse_args()

    TIME = args.time.upper()
    BALANCE = args.balance.upper()
    SCALE = args.scale.upper()
    JITTER = args.jitter

    accepted = (
        ('--time', TIME, ['BIRTHS', 'GENERATIONAL', 'REAL']),
        ('--balance', BALANCE, ['RANDOM', 'MIN', 'DENSITY']),
        ('--scale', SCALE, ['NONE', 'SIMPLE']),
    )
    wrong = [option for option, value, choices in accepted if value not in choices]
    if wrong:
        print("Incorrect value of parameter(s) " + ", ".join(wrong) + "! Closing the program.")
        return

    input_file = args.input
    seed = args.seed
    if seed == -1:
        seed = random.randint(0, 10000)
    random.seed(seed)
    print("seed:", seed)

    tree = TreeData()
    tree.load(input_file, max_nodes=args.max_nodes)

    designer = Designer(tree, jitter=JITTER, time=TIME, balance=BALANCE)
    designer.calculate_measures()
    designer.calculate_node_positions(ignore_last=args.skip)

    # Case-insensitive so that e.g. "TREE.SVG" also produces vector output.
    if args.output.lower().endswith(".svg"):
        drawer = SvgDrawer(designer, args.config, w=args.width, h=args.height)
    else:
        drawer = PngDrawer(designer, args.config, w=args.width, h=args.height)
    drawer.draw_design(args.output, args.input, multi=args.multi, scale=SCALE)
773
774
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.