source: mds-and-trees/tree-genealogy.py @ 633

Last change on this file since 633 was 633, checked in by konrad, 7 years ago

Added measure of progress (based on regression of time-distances between next children), some minor bug fixes

File size: 32.4 KB
Line 
import argparse
import bisect
import json
import math
import random
import time as timelib
from collections import deque
from xml.sax.saxutils import escape as xml_escape

import numpy as np
from PIL import Image, ImageDraw, ImageFont
from scipy import stats
10
class LoadingError(Exception):
    """Raised when the input file cannot be turned into a genealogy tree
    (e.g. an [OFFSPRING] record without a "FromIDs" field)."""
13
class Drawer:
    """Base class for renderers of a laid-out genealogical tree.

    Maps node positions computed by a ``Designer`` onto a canvas with margins
    and turns node properties into visual styles according to ``settings``.
    Subclasses implement the output primitives used here (``add_dot``,
    ``add_line``, ``add_dashed_line``, ``add_text``) and the style builders
    (``compute_dot_style``, ``compute_line_style``).
    """

    def __init__(self, design, config_file, w=600, h=800, w_margin=10, h_margin=20):
        """Set up the canvas geometry, the palette and the style settings.

        :param design: a ``Designer`` whose ``positions``, ``props``, ``tree``
            and ``TIME`` attributes are read while drawing.
        :param config_file: path of a JSON file that is recursively merged over
            the default ``settings``; ``""`` keeps the defaults.
        :param w: canvas width.
        :param h: canvas height.
        :param w_margin: left and right margin.
        :param h_margin: top and bottom margin.
        """
        self.design = design
        self.width = w
        self.height = h
        self.w_margin = w_margin
        self.h_margin = h_margin
        self.w_no_margs = w - 2* w_margin
        self.h_no_margs = h - 2* h_margin

        # Color components are percentages (0-100) of full intensity.
        self.colors = {
            'black' :   {'r':0,     'g':0,      'b':0},
            'red' :     {'r':100,   'g':0,      'b':0},
            'green' :   {'r':0,     'g':100,    'b':0},
            'blue' :    {'r':0,     'g':0,      'b':100},
            'yellow' :  {'r':100,   'g':100,    'b':0},
            'magenta' : {'r':100,   'g':0,      'b':100},
            'cyan' :    {'r':0,     'g':100,    'b':100},
            'orange':   {'r':100,   'g':50,     'b':0},
            'purple':   {'r':50,    'g':0,      'b':100}
        }

        # Each style entry maps the node property named by 'meaning' (a key of
        # design.props) onto the range start..end; 'bias' is the exponent used
        # by compute_value / compute_color.  'colors_of_kinds' is the palette
        # used for string-valued properties.
        self.settings = {
            'colors_of_kinds': ['red', 'green', 'blue', 'magenta', 'yellow', 'cyan', 'orange', 'purple'],
            'dots': {
                'color': {
                    'meaning': 'Lifespan',
                    'start': 'red',
                    'end': 'green',
                    'bias': 1
                    },
                'size': {
                    'meaning': 'EnergyEaten',
                    'start': 1,
                    'end': 6,
                    'bias': 0.5
                    },
                'opacity': {
                    'meaning': 'EnergyEaten',
                    'start': 0.2,
                    'end': 1,
                    'bias': 1
                    }
            },
            'lines': {
                'color': {
                    'meaning': 'adepth',
                    'start': 'black',
                    'end': 'red',
                    'bias': 3
                    },
                'width': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 4,
                    'bias': 3
                    },
                'opacity': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 0.8,
                    'bias': 5
                    }
            }
        }

        def merge(source, destination):
            """Recursively copy ``source`` into ``destination`` (nested dicts are
            merged key by key) and return ``destination``."""
            for key, value in source.items():
                if isinstance(value, dict):
                    node = destination.setdefault(key, {})
                    merge(value, node)
                else:
                    destination[key] = value

            return destination

        if config_file != "":
            with open(config_file) as config:
                c = json.load(config)
            self.settings = merge(c, self.settings)

    def _to_canvas(self, pos, min_width, max_width, max_height):
        """Map a layout position ``{'x': .., 'y': ..}`` to canvas coordinates.

        A degenerate span (every node at the same x, or every node at y == 0,
        e.g. a one-node tree) is treated as a span of 1 instead of dividing
        by zero.
        """
        x_span = (max_width - min_width) or 1
        y_span = max_height or 1
        return (self.w_margin + self.w_no_margs * (pos['x'] - min_width) / x_span,
                self.h_margin + self.h_no_margs * pos['y'] / y_span)

    def draw_dots(self, file, min_width, max_width, max_height):
        """Draw one dot per positioned node (nodes without an 'x' are skipped)."""
        for node_id, node in enumerate(self.design.positions):
            if 'x' not in node:
                continue
            dot_style = self.compute_dot_style(node=node_id)
            self.add_dot(file, self._to_canvas(node, min_width, max_width, max_height), dot_style)

    def draw_lines(self, file, min_width, max_width, max_height):
        """Draw a line from every positioned node to each of its positioned children."""
        positions = self.design.positions
        for parent, par_pos in enumerate(positions):
            if 'x' not in par_pos:
                continue
            for child in self.design.tree.children[parent]:
                chi_pos = positions[child]
                if 'x' not in chi_pos:
                    continue
                line_style = self.compute_line_style(parent, child)
                self.add_line(file,
                              self._to_canvas(par_pos, min_width, max_width, max_height),
                              self._to_canvas(chi_pos, min_width, max_width, max_height),
                              line_style)

    def draw_scale(self, file, filename):
        """Write the source file name and mark the first/last time step of the layout.

        The labels depend on ``design.TIME`` (BIRTHS, REAL or GENERATIONAL).
        Dashed lines are drawn at the top and bottom margins of the right part
        of the canvas.
        """
        # Show only the base name; accept both Windows ("\") and POSIX ("/") separators.
        base_name = filename.replace("\\", "/").split("/")[-1]
        self.add_text(file, "Generated from " + base_name, (5, 5), "start")

        start_text = ""
        end_text = ""
        if self.design.TIME == "BIRTHS":
            start_text = "Birth #0"
            end_text = "Birth #" + str(len(self.design.positions)-1)
        if self.design.TIME == "REAL":
            start_text = "Time " + str(min(self.design.tree.time))
            end_text = "Time " + str(max(self.design.tree.time))
        if self.design.TIME == "GENERATIONAL":
            start_text = "Depth " + str(self.design.props['adepth_min'])
            end_text = "Depth " + str(self.design.props['adepth_max'])

        self.add_dashed_line(file, (self.width*0.7, self.h_margin), (self.width, self.h_margin))
        self.add_text(file, start_text, (self.width, self.h_margin), "end")
        self.add_dashed_line(file, (self.width*0.7, self.height-self.h_margin), (self.width, self.height-self.h_margin))
        self.add_text(file, end_text, (self.width, self.height-self.h_margin), "end")

    def compute_property(self, part, prop, node):
        """Return the style value of ``prop`` for ``node``.

        :param part: top-level settings section, 'dots' or 'lines'.
        :param prop: style entry within the section ('color', 'size', 'width', 'opacity').
        :param node: node id.
        The node's value of the configured 'meaning' property is read from
        ``design.props``; 0 is used when the design has no such property or
        the node's value is ``None``.
        """
        spec = self.settings[part][prop]
        meaning = spec['meaning']
        value = self.design.props[meaning][node] if meaning in self.design.props else 0
        if value is None:
            # e.g. a JSON null in the input; would otherwise fail in (1 - value).
            value = 0
        if prop == "color":
            return self.compute_color(spec['start'], spec['end'], value, spec['bias'])
        return self.compute_value(spec['start'], spec['end'], value, spec['bias'])

    def compute_color(self, start, end, value, bias=1):
        """Return an (r, g, b) tuple of percentages.

        A string ``value`` (as produced for the 'kind' property) is an integer
        index into the 'colors_of_kinds' palette; indices wrap around so kinds
        beyond the palette size still get a color.  A numeric ``value``
        (expected to be normalized to [0, 1]) interpolates between the
        ``start`` and ``end`` palette colors after applying ``bias``.
        """
        if isinstance(value, str):
            palette = self.settings['colors_of_kinds']
            color = self.colors[palette[int(value) % len(palette)]]
            r, g, b = color['r'], color['g'], color['b']
        else:
            start_color = self.colors[start]
            end_color = self.colors[end]
            value = 1 - (1-value)**bias
            r = start_color['r']*(1-value)+end_color['r']*value
            g = start_color['g']*(1-value)+end_color['g']*value
            b = start_color['b']*(1-value)+end_color['b']*value
        return (r, g, b)

    def compute_value(self, start, end, value, bias=1):
        """Interpolate between ``start`` and ``end`` at ``1 - (1 - value) ** bias``."""
        value = 1 - (1-value)**bias
        return start*(1-value) + end*value
170
class PngDrawer(Drawer):
    """Renders the tree as a raster image with Pillow.

    Drawing happens on a canvas enlarged ``multi`` times (supersampling); the
    result is downscaled to the requested size before saving.  Pillow picks
    the output format from the file extension.
    """

    def scale_up(self):
        """Multiply all canvas dimensions by ``self.multi``."""
        self.width *= self.multi
        self.height *= self.multi
        self.w_margin *= self.multi
        self.h_margin *= self.multi
        self.h_no_margs *= self.multi
        self.w_no_margs *= self.multi

    def scale_down(self):
        """Divide all canvas dimensions by ``self.multi`` (undoes ``scale_up``)."""
        self.width /= self.multi
        self.height /= self.multi
        self.w_margin /= self.multi
        self.h_margin /= self.multi
        self.h_no_margs /= self.multi
        self.w_no_margs /= self.multi

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Draw the whole tree and save it to ``filename``.

        :param filename: output image path.
        :param input_filename: source data file, shown by the "SIMPLE" scale.
        :param multi: supersampling factor.
        :param scale: "SIMPLE" draws the timescale, anything else omits it.
        """
        print("Drawing...")

        self.multi = multi
        self.scale_up()
        try:
            # int(): after an earlier draw_design the dimensions are floats
            # (scale_down uses true division) and Image.new requires integers.
            back = Image.new('RGBA', (int(self.width), int(self.height)), (255, 255, 255, 0))

            positioned = [p for p in self.design.positions if 'x' in p]
            min_width = min(p['x'] for p in positioned)
            max_width = max(p['x'] for p in positioned)
            max_height = max([p['y'] for p in self.design.positions if 'y' in p])

            self.draw_lines(back, min_width, max_width, max_height)
            self.draw_dots(back, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(back, input_filename)
        finally:
            # Always restore the original geometry, even if drawing failed.
            self.scale_down()

        # LANCZOS is the filter formerly exposed as Image.ANTIALIAS (removed in Pillow 10).
        back.thumbnail((int(self.width), int(self.height)), Image.LANCZOS)

        back.save(filename)

    def add_dot(self, file, pos, style):
        """Paste a filled circle of radius ``style['r']`` (scaled by multi) centred at ``pos``."""
        x, y = int(pos[0]), int(pos[1])
        r = style['r']*self.multi
        offset = (int(x - r), int(y - r))
        # At least 2 px: a smaller image would give ellipse() an inverted box.
        diameter = max(2, 2*int(r))
        size = (diameter, diameter)

        c = style['color']

        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).ellipse((1, 1, size[0]-1, size[1]-1),
                                    (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])))
        file.paste(img, offset, mask=img)

    def add_line(self, file, from_pos, to_pos, style):
        """Paste a straight line between two canvas points using ``style``."""
        fx, fy, tx, ty = int(from_pos[0]), int(from_pos[1]), int(to_pos[0]), int(to_pos[1])
        # Scale before truncating and keep at least 1 px, so fractional widths
        # (the default range starts at 0.1) are neither dropped nor ignored by
        # supersampling; a 0 px pad would make vertical/horizontal lines vanish.
        w = max(1, int(style['width']*self.multi))

        offset = (min(fx-w, tx-w), min(fy-w, ty-w))
        size = (abs(fx-tx)+2*w, abs(fy-ty)+2*w)

        c = style['color']

        img = Image.new('RGBA', size)
        # Pick the diagonal of the bounding box matching the line's direction.
        ImageDraw.Draw(img).line((w, w, size[0]-w, size[1]-w) if (fx-tx)*(fy-ty)>0 else (size[0]-w, w, w, size[1]-w),
                                  (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])), w)
        file.paste(img, offset, mask=img)

    def add_dashed_line(self, file, from_pos, to_pos):
        """Draw a black dashed line as 50 short segments separated by equal gaps."""
        style = {'color': (0,0,0), 'width': 1, 'opacity': 1}
        sublines = 50
        # TODO could be faster: compute delta and only add delta each time (but currently we do not use it often)
        for i in range(sublines):
            from_pos_sub = (self.compute_value(from_pos[0], to_pos[0], 2*i/(2*sublines-1), 1),
                            self.compute_value(from_pos[1], to_pos[1], 2*i/(2*sublines-1), 1))
            to_pos_sub = (self.compute_value(from_pos[0], to_pos[0], (2*i+1)/(2*sublines-1), 1),
                          self.compute_value(from_pos[1], to_pos[1], (2*i+1)/(2*sublines-1), 1))
            self.add_line(file, from_pos_sub, to_pos_sub, style)

    def add_text(self, file, text, pos, anchor, style=''):
        """Draw black text at ``pos``; anchor "start" puts its left edge there,
        any other anchor puts its right edge there.  ``style`` is ignored."""
        font = ImageFont.truetype("Vera.ttf", 16*self.multi)

        img = Image.new('RGBA', (int(self.width), int(self.height)))
        draw = ImageDraw.Draw(img)
        if hasattr(draw, "textbbox"):
            # Pillow >= 8.0; ImageDraw.textsize was removed in Pillow 10.
            text_width = draw.textbbox((0, 0), text, font=font)[2]
        else:
            text_width = draw.textsize(text, font=font)[0]
        if anchor != "start":
            pos = (pos[0]-text_width, pos[1])
        draw.text(pos, text, (0,0,0), font=font)
        file.paste(img, (0,0), mask=img)

    def compute_line_style(self, parent, child):
        """Return the style dict of the line to ``child`` (styled by the child's properties)."""
        return {'color': self.compute_property('lines', 'color', child),
                'width': self.compute_property('lines', 'width', child),
                'opacity': self.compute_property('lines', 'opacity', child)}

    def compute_dot_style(self, node):
        """Return the style dict (color, radius 'r', opacity) of ``node``'s dot."""
        return {'color': self.compute_property('dots', 'color', node),
                'r': self.compute_property('dots', 'size', node),
                'opacity': self.compute_property('dots', 'opacity', node)}
271
class SvgDrawer(Drawer):
    """Renders the tree as an SVG document written as plain text."""

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Write the whole tree as SVG to ``filename``.

        :param filename: output path.
        :param input_filename: source data file, shown by the "SIMPLE" scale.
        :param multi: ignored (supersampling only applies to raster output).
        :param scale: "SIMPLE" draws the timescale, anything else omits it.
        """
        print("Drawing...")

        positioned = [p for p in self.design.positions if 'x' in p]
        min_width = min(p['x'] for p in positioned)
        max_width = max(p['x'] for p in positioned)
        max_height = max([p['y'] for p in self.design.positions if 'y' in p])

        # "with" closes the file even if drawing raises.
        with open(filename, "w") as file:
            file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" '
                       'xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
                       'width="' + str(self.width) + '" height="' + str(self.height) + '">')

            self.draw_lines(file, min_width, max_width, max_height)
            self.draw_dots(file, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(file, input_filename)

            file.write("</svg>")

    def add_text(self, file, text, pos, anchor, style=''):
        """Write a <text> element; ``anchor`` is an SVG text-anchor value and
        ``style`` a raw attribute string (a default style is used when empty)."""
        style = (style if style != '' else 'style="font-family: Arial; font-size: 12; fill: #000000;"')
        # assuming font size 12, it should be taken from the style string!
        # The text is XML-escaped so names containing "&" or "<" keep the SVG valid.
        file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) + '" y="' + str(pos[1]+12) + '" >' + xml_escape(text) + '</text>')

    def add_dot(self, file, pos, style):
        """Write a <circle> centred at ``pos``; ``style`` is a string of SVG attributes."""
        file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')

    def add_line(self, file, from_pos, to_pos, style):
        """Write a <line> between two points; ``style`` is a string of SVG attributes."""
        file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
                       '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')

    def add_dashed_line(self, file, from_pos, to_pos):
        """Write a thin black dashed line."""
        style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
        self.add_line(file, from_pos, to_pos, style)

    def compute_line_style(self, parent, child):
        """Return the SVG stroke attributes of the line to ``child`` (styled by the child)."""
        return self.compute_stroke_color('lines', child) + ' ' \
               + self.compute_stroke_width('lines', child) + ' ' \
               + self.compute_stroke_opacity(child)

    def compute_dot_style(self, node):
        """Return the SVG radius/fill attributes of ``node``'s dot."""
        return self.compute_dot_size(node) + ' ' \
               + self.compute_fill_opacity(node) + ' ' \
               + self.compute_dot_fill(node)

    def compute_stroke_color(self, part, node):
        """Return a stroke="rgb(r%,g%,b%)" attribute for ``node``."""
        color = self.compute_property(part, 'color', node)
        return 'stroke="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'

    def compute_stroke_width(self, part, node):
        """Return a stroke-width attribute for ``node``."""
        return 'stroke-width="' + str(self.compute_property(part, 'width', node)) + '"'

    def compute_stroke_opacity(self, node):
        """Return a stroke-opacity attribute for the line to ``node``."""
        return 'stroke-opacity="' + str(self.compute_property('lines', 'opacity', node)) + '"'

    def compute_fill_opacity(self, node):
        """Return a fill-opacity attribute for ``node``'s dot."""
        return 'fill-opacity="' + str(self.compute_property('dots', 'opacity', node)) + '"'

    def compute_dot_size(self, node):
        """Return the radius attribute r for ``node``'s dot."""
        return 'r="' + str(self.compute_property('dots', 'size', node)) + '"'

    def compute_dot_fill(self, node):
        """Return a fill="rgb(r%,g%,b%)" attribute for ``node``'s dot."""
        color = self.compute_property('dots', 'color', node)
        return 'fill="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'
339
class Designer:
    """Computes per-node measures and a 2-D layout for a genealogy tree.

    Node ids index ``tree.parents`` / ``tree.children``; node 0 is the root
    from which every traversal starts.  Measures are stored in ``self.props``
    (numeric ones normalized to [0, 1], with the raw extremes kept as
    ``<name>_max`` / ``<name>_min``); layout positions in ``self.positions``.
    """

    # Value types normalize_prop treats as numbers (bool is an int subclass).
    _NUMERIC_TYPES = (int, float, np.integer, np.floating)

    def __init__(self, tree, jitter=False, time="GENERATIONAL", balance="DENSITY"):
        """
        :param tree: loaded tree; its ``parents``, ``children``, ``time``,
            ``kind`` and ``props`` attributes are read.
        :param jitter: draw horizontal offsets of children from a normal distribution.
        :param time: vertical axis, "BIRTHS", "GENERATIONAL" or "REAL".
        :param balance: placement strategy for single-parent children,
            "RANDOM", "MIN" or "DENSITY".
        :raises ValueError: if ``balance`` is not one of the strategies.
        """
        self.props = {}

        self.tree = tree

        self.TIME = time
        self.JITTER = jitter

        strategies = {
            "RANDOM": self.xmin_crowd_random,
            "MIN": self.xmin_crowd_min,
            "DENSITY": self.xmin_crowd_density,
        }
        if balance not in strategies:
            raise ValueError("Error, the value of BALANCE does not match any expected value.")
        self.xmin_crowd = strategies[balance]

    def calculate_measures(self):
        """Compute all node measures into ``self.props`` (time must precede progress)."""
        print("Calculating measures...")
        self.compute_adepth()
        self.compute_depth()
        self.compute_children()
        self.compute_kind()
        self.compute_time()
        self.compute_progress()
        self.compute_custom()

    def xmin_crowd_random(self, x1, x2, y):
        """Pick x1 or x2 at random."""
        return (x1 if random.randrange(2) == 0 else x2)

    def xmin_crowd_min(self, x1, x2, y):
        """Pick whichever of x1/x2 is farther from its nearest placed node with
        y within +-3 of ``y``; ties go to x2."""
        x1_closest = 999999
        x2_closest = 999999
        i = bisect.bisect_left(self.y_sorted, y-3)
        while i < len(self.positions_sorted) and self.positions_sorted[i]['y'] <= y+3:
            pos = self.positions_sorted[i]
            x1_closest = min(x1_closest, abs(x1-pos['x']))
            x2_closest = min(x2_closest, abs(x2-pos['x']))
            i += 1
        return (x1 if x1_closest > x2_closest else x2)

    def xmin_crowd_density(self, x1, x2, y):
        """Pick whichever of x1/x2 has the larger summed distance to placed nodes
        with y within +-20 of ``y``; ties (including an empty band) go to x2.

        With fewer than 10 nodes in the band all are summed; otherwise 10 are
        sampled with replacement via ``random.randrange``.
        """
        x1_dist = 0
        x2_dist = 0
        i_left = bisect.bisect_left(self.y_sorted, y-20)
        i_right = bisect.bisect_right(self.y_sorted, y+20)

        def include_pos(pos):
            nonlocal x1_dist, x2_dist

            dysq = (pos['y']-y)**2
            dx1 = pos['x']-x1
            dx2 = pos['x']-x2

            x1_dist += math.sqrt(dysq + dx1**2)
            x2_dist += math.sqrt(dysq + dx2**2)

        if len(self.positions_sorted) > i_left:
            if i_right - i_left < 10:
                for j in range(i_left, i_right):
                    include_pos(self.positions_sorted[j])
            else:
                for _ in range(10):
                    include_pos(self.positions_sorted[random.randrange(i_left, i_right)])

        return (x1 if x1_dist > x2_dist else x2)

    def calculate_node_positions(self, ignore_last=0):
        """Lay out nodes breadth-first from the root into ``self.positions``.

        y is the birth index (BIRTHS), parent's y + 1 (GENERATIONAL) or the
        node's time (REAL).  A single-parent child is placed at parent x +- 1
        (or +- a gauss(0, 0.5) offset with jitter), the side chosen by
        ``xmin_crowd``; a multi-parent child at the inheritance-weighted mean x
        of its already placed parents.  Children whose normalized 'adepth' is
        below ``ignore_last / adepth_max`` are skipped (drops the last
        ``ignore_last`` levels).  Requires ``compute_adepth`` to have run.
        Unplaced nodes keep an empty dict.
        """
        print("Calculating positions...")

        def add_node(node):
            # Keep positions_sorted ordered by y so xmin_crowd_* can bisect it.
            index = bisect.bisect_left(self.y_sorted, node['y'])
            self.y_sorted.insert(index, node['y'])
            self.positions_sorted.insert(index, node)
            self.positions[node['id']] = node

        self.positions_sorted = [{'x':0, 'y':0, 'id':0}]
        self.y_sorted = [0]
        self.positions = [{} for _ in range(len(self.tree.parents))]
        self.positions[0] = {'x':0, 'y':0, 'id':0}

        # deque: popleft is O(1), unlike re-slicing a list on every step.
        nodes_to_visit = deque([0])
        visited = [False] * len(self.tree.parents)
        visited[0] = True

        node_counter = 0
        start_time = timelib.time()

        while True:

            node_counter += 1
            if node_counter%1000 == 0:
                print(str(node_counter) + " " + str(timelib.time()-start_time))
                start_time = timelib.time()

            current_node = nodes_to_visit.popleft()
            parent_x = self.positions[current_node]['x']

            for child in self.tree.children[current_node]:
                # The division is only evaluated for unvisited children (a leaf-only
                # tree has adepth_max == 0).
                if not visited[child] and self.props['adepth'][child] >= ignore_last/self.props['adepth_max']:
                    nodes_to_visit.append(child)
                    visited[child] = True

                    ypos = 0
                    if self.TIME == "BIRTHS":
                        ypos = child
                    elif self.TIME == "GENERATIONAL":
                        ypos = self.positions[current_node]['y']+1
                    elif self.TIME == "REAL":
                        ypos = self.tree.time[child]

                    parents = self.tree.parents[child]
                    if len(parents) == 1:
                        # current_node is the only parent
                        dissimilarity = random.gauss(0, 0.5) if self.JITTER else 1
                        add_node({'id':child, 'y':ypos, 'x':
                                 self.xmin_crowd(parent_x-dissimilarity, parent_x+dissimilarity, ypos)})
                    else:
                        # Only parents that already have a position can contribute; a
                        # parent later in BFS order (or unreachable from the root) has none.
                        placed = {k: v for k, v in parents.items() if 'x' in self.positions[k]}
                        total_inheritance = sum([v for v in placed.values()])
                        xpos = sum([self.positions[k]['x']*v/total_inheritance
                                    for k, v in placed.items()])
                        if self.JITTER:
                            add_node({'id':child, 'y':ypos, 'x':xpos + random.gauss(0, 0.1)})
                        else:
                            add_node({'id':child, 'y':ypos, 'x':xpos})

            # if none left, we can stop
            if not nodes_to_visit:
                print("done")
                break

    def compute_custom(self):
        """Copy every extra per-node property of the tree and normalize it."""
        for prop in self.tree.props:
            self.props[prop] = [self.tree.props[prop][i] for i in range(len(self.tree.children))]
            self.normalize_prop(prop)

    def compute_time(self):
        """Copy the tree's per-node times into props['time'] and normalize them."""
        self.props["time"] = [self.tree.time[i] for i in range(len(self.tree.children))]
        self.normalize_prop('time')

    def compute_kind(self):
        """Copy the tree's per-node kinds as strings (strings select palette colors)."""
        self.props["kind"] = [str(self.tree.kind[i]) for i in range(len(self.tree.children))]

    def compute_depth(self):
        """Shortest number of generations from the root (BFS distance), normalized.

        Nodes unreachable from the root keep the sentinel 999999999 before
        normalization.  Each node is enqueued once, so DAGs with many shared
        descendants are processed in linear time.
        """
        children = self.tree.children
        depth = [999999999 for _ in range(len(children))]
        self.props["depth"] = depth
        depth[0] = 0

        seen = [False] * len(children)
        seen[0] = True
        queue = deque([0])
        while queue:
            node = queue.popleft()
            for child in children[node]:
                if not seen[child]:
                    # First discovery in BFS order is along a shortest path.
                    seen[child] = True
                    depth[child] = depth[node] + 1
                    queue.append(child)

        self.normalize_prop('depth')

    def compute_adepth(self):
        """Length of the longest downward path to a leaf for each node reachable
        from the root (unreachable nodes keep 0), normalized.

        Iterative post-order traversal with memoization: deep lineages do not
        hit Python's recursion limit and shared descendants are evaluated once.
        Assumes the child graph is acyclic.
        """
        children = self.tree.children
        adepth = [0 for _ in range(len(children))]
        self.props["adepth"] = adepth

        finished = [False] * len(children)
        stack = [0]
        while stack:
            node = stack[-1]
            if finished[node]:
                stack.pop()
                continue
            pending = [c for c in children[node] if not finished[c]]
            if pending:
                stack.extend(pending)
            else:
                adepth[node] = max((adepth[c] + 1 for c in children[node]), default=0)
                finished[node] = True
                stack.pop()

        self.normalize_prop('adepth')

    def compute_children(self):
        """Number of children of every node, normalized."""
        self.props["children"] = [len(kids) for kids in self.tree.children]
        self.normalize_prop('children')

    def compute_progress(self):
        """Regression-based 'progress' measure, normalized (needs props['time']).

        For each node with more than four children: sort the children's
        normalized times (scaled by 100000), take the gaps between consecutive
        births and fit gap vs. index with ``stats.linregress``; the slope is
        the raw value (NaN/inf -> 0).  Then, in five rounds, the current
        minimum and maximum are set to 0 (outlier trimming), and finally every
        0 is replaced by the remaining minimum.
        """
        progress = [0 for _ in range(len(self.tree.children))]
        self.props["progress"] = progress
        times_norm = self.props["time"]

        for node, kids in enumerate(self.tree.children):
            times = sorted([times_norm[kid]*100000 for kid in kids])
            if len(times) > 4:
                gaps = [later - earlier for earlier, later in zip(times, times[1:])]
                slope, intercept, r_value, p_value, std_err = stats.linregress(range(len(gaps)), gaps)
                progress[node] = slope if not np.isnan(slope) and not np.isinf(slope) else 0

        for _ in range(0, 5):
            progress[progress.index(min(progress))] = 0
            progress[progress.index(max(progress))] = 0

        mini = min(progress)
        for k, value in enumerate(progress):
            if value == 0:
                progress[k] = mini

        self.normalize_prop('progress')

    def normalize_prop(self, prop):
        """Rescale the numeric values of props[prop] to [0, 1] in place.

        Non-numeric entries (None, strings, ...) are left untouched and ignored
        when finding the extremes, which are stored as ``<prop>_max`` /
        ``<prop>_min``.  If all numeric values are equal they become 0.  Does
        nothing when there are no numeric values.
        """
        values = self.props[prop]
        numeric = [v for v in values if isinstance(v, self._NUMERIC_TYPES)]
        if len(numeric) > 0:
            max_val = max(numeric)
            min_val = min(numeric)
            print(prop, max_val, min_val)
            self.props[prop +'_max'] = max_val
            self.props[prop +'_min'] = min_val
            for i, v in enumerate(values):
                if isinstance(v, self._NUMERIC_TYPES):
                    values[i] = 0 if max_val == min_val else (v - min_val) / (max_val - min_val)
596
class TreeData:
    """Genealogy read from a log of [OFFSPRING] / [DIED] JSON records.

    After ``load``, node ``i`` has ``parents[i]`` (dict: parent node id ->
    inherited share, currently always 1), ``children[i]`` (list of child node
    ids), ``time[i]``, ``kind[i]`` and every non-default record field in
    ``props[name][i]``.  ``ids`` maps the creature IDs used in the file to node
    ids; unknown parents are all mapped to one "virtual_parent" node.
    """

    # Optional leading token of a record line ("Script.Message: [OFFSPRING] {...}").
    CLI_PREFIX = "Script.Message:"
    # Record fields that are not copied into ``props``.
    DEFAULT_PROPS = ("Time", "FromIDs", "ID", "Operation", "Inherited")

    simple_data = None

    def __init__(self): #, simple_data=False):
        #self.simple_data = simple_data
        # Per-instance containers; class-level lists would be shared by all trees.
        self.children = []
        self.parents = []
        self.time = []
        self.kind = []
        self.life_lenght = []
        self.props = {}
        self.ids = {}

    @classmethod
    def _parse_line(cls, line):
        """Split a log line into ``(tag, payload)``, skipping a leading CLI_PREFIX.

        Returns ``(None, None)`` when the line has no payload after the tag.
        The payload keeps any trailing newline (json.loads accepts it).
        """
        parts = line.split(' ', 1)
        if len(parts) == 2 and parts[0] == cls.CLI_PREFIX:
            parts = parts[1].split(' ', 1)
        if len(parts) != 2:
            return None, None
        return parts[0], parts[1]

    def _store_props(self, creature, creature_id, nodes):
        """Copy the non-default fields of a record into ``props`` for ``creature_id``.

        A property seen for the first time gets a list of ``nodes`` zeros.
        """
        for prop, value in creature.items():
            if prop not in self.DEFAULT_PROPS:
                if prop not in self.props:
                    self.props[prop] = [0 for _ in range(nodes)]
                self.props[prop][creature_id] = value

    def load(self, filename, max_nodes=0):
        """Read the genealogy from the text log ``filename``.

        :param filename: file with "[OFFSPRING] {json}" / "[DIED] {json}" lines,
            optionally prefixed by CLI_PREFIX; other lines are ignored.
        :param max_nodes: stop after this many [OFFSPRING] records (0 = no limit).
        :raises LoadingError: if an [OFFSPRING] record has no "FromIDs" field.
        :raises ValueError: (json.JSONDecodeError) if a record payload is not valid JSON.
        """
        print("Loading...")

        self.ids = {}
        def get_id(key, create=True):
            """Return the node id of ``key``; unknown keys get the next id,
            or ``None`` when ``create`` is False."""
            if key not in self.ids:
                if not create:
                    return None
                self.ids[key] = len(self.ids)
            return self.ids[key]

        with open(filename) as file:
            # Pass 1: count the records so the per-node arrays can be preallocated.
            offspring = sum(1 for line in file if self._parse_line(line)[0] == "[OFFSPRING]")

            # One extra slot for the "virtual_parent" node.
            nodes = min(offspring, max_nodes if max_nodes != 0 else offspring)+1
            self.parents = [{} for _ in range(nodes)]
            self.children = [[] for _ in range(nodes)]
            self.time = [0] * nodes
            self.kind = [0] * nodes
            self.life_lenght = [0] * nodes
            self.props = {}

            print(len(self.parents))

            # Pass 2: build the nodes.
            file.seek(0)
            loaded_so_far = 0
            for line in file:
                tag, payload = self._parse_line(line)
                if tag == "[OFFSPRING]":
                    creature = json.loads(payload)
                    if "FromIDs" not in creature:
                        raise LoadingError("[OFFSPRING] misses the 'FromIDs' field!")
                    from_ids = creature["FromIDs"]

                    # make sure that ID's of parents are lower than that of their children:
                    # the virtual parent standing in for unknown parents is created first
                    if any(parent not in self.ids for parent in from_ids):
                        get_id("virtual_parent")

                    creature_id = get_id(creature["ID"])

                    # we assign to each parent its contribution to the genotype of the child
                    for parent in from_ids:
                        parent_id = get_id(parent if parent in self.ids else "virtual_parent")
                        self.parents[creature_id][parent_id] = 1  # "Inherited" shares are not used for now

                    if "Time" in creature:
                        self.time[creature_id] = creature["Time"]

                    if "Kind" in creature:
                        self.kind[creature_id] = creature["Kind"]

                    self._store_props(creature, creature_id, nodes)
                    loaded_so_far += 1
                elif tag == "[DIED]":
                    creature = json.loads(payload)
                    creature_id = get_id(creature["ID"], False)
                    if creature_id is not None:
                        self._store_props(creature, creature_id, nodes)

                if max_nodes != 0 and loaded_so_far >= max_nodes:
                    break

        # Derive the child lists from the parent maps.
        for child, parent_map in enumerate(self.parents):
            for parent in parent_map:
                self.children[parent].append(child)
713
# Module-level dictionaries; nothing in the code shown here reads or writes them.
depth, kind = {}, {}
716
def main():
    """Command-line entry point: parse arguments, load the data, lay out and draw the tree.

    The output format follows the output file extension: ".svg" (any case)
    produces SVG, anything else a raster image.
    """

    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship '
                                                 'information from a text file. Supports files generated by Framsticks experiments.')
    parser.add_argument('-i', '--in', dest='input', required=True, help='input file name with structured evolutionary data')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG/PNG/JPG/BMP)')
    parser.add_argument('-c', '--config', dest='config', default="", help='config file name ')

    parser.add_argument('-W', '--width', default=600, type=int, dest='width', help='width of the output image (600 by default)')
    parser.add_argument('-H', '--height', default=800, type=int, dest='height', help='height of the output image (800 by default)')
    parser.add_argument('-m', '--multi', default=1, type=int, dest='multi', help='multisampling factor (applicable only for raster images)')

    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL(d)/REAL); '
                                                                      'BIRTHS: time measured as the number of births since the beginning; '
                                                                      'GENERATIONAL: time measured as number of ancestors; '
                                                                      'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing nodes in the tree (RANDOM/MIN/DENSITY(d))')
    parser.add_argument('-s', '--scale', default='SIMPLE', dest='scale', help='type of timescale added to the tree (NONE/SIMPLE(d))')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')
    parser.add_argument('-p', '--skip', dest="skip", type=int, default=0, help='skip last P levels of the tree (0 by default)')
    parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    parser.add_argument('--seed', type=int, dest='seed', help='seed for the random number generator (-1 for random)')

    parser.set_defaults(draw_tree=True)
    parser.set_defaults(draw_skeleton=False)
    parser.set_defaults(draw_spine=False)

    parser.set_defaults(seed=-1)

    args = parser.parse_args()

    TIME = args.time.upper()
    BALANCE = args.balance.upper()
    SCALE = args.scale.upper()
    JITTER = args.jitter

    # Report every parameter with an unexpected value, not just that one was wrong.
    allowed_values = (
        ('time', TIME, ('BIRTHS', 'GENERATIONAL', 'REAL')),
        ('balance', BALANCE, ('RANDOM', 'MIN', 'DENSITY')),
        ('scale', SCALE, ('NONE', 'SIMPLE')),
    )
    invalid = [(name, value, choices) for name, value, choices in allowed_values if value not in choices]
    if invalid:
        for name, value, choices in invalid:
            print("Incorrect value of --%s: '%s' (expected one of: %s)." % (name, value, "/".join(choices)))
        print("Incorrect value of one of the parameters! Closing the program.")
        return

    seed = args.seed
    if seed == -1:
        seed = random.randint(0, 10000)
    random.seed(seed)
    print("seed:", seed)

    tree = TreeData()
    tree.load(args.input, max_nodes=args.max_nodes)

    designer = Designer(tree, jitter=JITTER, time=TIME, balance=BALANCE)
    designer.calculate_measures()
    designer.calculate_node_positions(ignore_last=args.skip)

    if args.output.lower().endswith(".svg"):
        drawer = SvgDrawer(designer, args.config, w=args.width, h=args.height)
    else:
        drawer = PngDrawer(designer, args.config, w=args.width, h=args.height)
    drawer.draw_design(args.output, args.input, multi=args.multi, scale=SCALE)
777
778
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.