source: mds-and-trees/tree-genealogy.py @ 694

Last change on this file since 694 was 694, checked in by konrad, 7 years ago

Some improvements in behavior for big trees

File size: 34.6 KB
Line 
1import json
2import math
3import random
4import argparse
5import bisect
6import time as timelib
7from PIL import Image, ImageDraw, ImageFont
8from scipy import stats
9import numpy as np
10
class LoadingError(Exception):
    """Raised by the loader when an [OFFSPRING] record lacks the 'FromIDs' field."""
13
class Drawer:
    """Back-end independent part of drawing a laid-out genealogy.

    The class walks over ``design.positions`` / ``design.tree.children`` and
    delegates the actual output to methods it does not define itself:
    ``add_dot``, ``add_line``, ``add_dashed_line``, ``add_text``,
    ``compute_dot_style`` and ``compute_line_style``.  Concrete drawers are
    expected to provide them.
    """

    def __init__(self, design, config_file, w=600, h=800, w_margin=10, h_margin=20):
        """Store the canvas geometry and build the style settings.

        :param design: object exposing ``positions``, ``tree``, ``props`` and ``TIME``
        :param config_file: path to a JSON file whose content is merged over the
            default settings; an empty string means "use the defaults"
        :param w: canvas width
        :param h: canvas height
        :param w_margin: horizontal margin on each side
        :param h_margin: vertical margin on each side
        """
        self.design = design
        self.width = w
        self.height = h
        self.w_margin = w_margin
        self.h_margin = h_margin
        self.w_no_margs = w - 2 * w_margin
        self.h_no_margs = h - 2 * h_margin

        # named colors; channels are on a 0..100 scale
        self.colors = {
            'white':   {'r': 100, 'g': 100, 'b': 100},
            'black':   {'r': 0,   'g': 0,   'b': 0},
            'red':     {'r': 100, 'g': 0,   'b': 0},
            'green':   {'r': 0,   'g': 100, 'b': 0},
            'blue':    {'r': 0,   'g': 0,   'b': 100},
            'yellow':  {'r': 100, 'g': 100, 'b': 0},
            'magenta': {'r': 100, 'g': 0,   'b': 100},
            'cyan':    {'r': 0,   'g': 100, 'b': 100},
            'orange':  {'r': 100, 'g': 50,  'b': 0},
            'purple':  {'r': 50,  'g': 0,   'b': 100}
        }

        # For every visual property: 'meaning' names the design property that
        # drives it, 'start'/'end' are the two extremes that are interpolated
        # between, and 'bias' is the exponent used to bend the interpolation.
        self.settings = {
            'colors_of_kinds': ['red', 'green', 'blue', 'magenta', 'yellow', 'cyan', 'orange', 'purple'],
            'dots': {
                'color': {
                    'meaning': 'Lifespan',
                    'start': 'red',
                    'end': 'green',
                    'bias': 1
                    },
                'size': {
                    'meaning': 'EnergyEaten',
                    'start': 1,
                    'end': 6,
                    'bias': 0.5
                    },
                'opacity': {
                    'meaning': 'EnergyEaten',
                    'start': 0.2,
                    'end': 1,
                    'bias': 1
                    }
            },
            'lines': {
                'color': {
                    'meaning': 'adepth',
                    'start': 'black',
                    'end': 'red',
                    'bias': 3
                    },
                'width': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 4,
                    'bias': 3
                    },
                'opacity': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 0.8,
                    'bias': 5
                    }
            }
        }

        def merge(source, destination):
            # recursively copy 'source' over 'destination', descending into nested dicts
            for key, value in source.items():
                if isinstance(value, dict):
                    node = destination.setdefault(key, {})
                    merge(value, node)
                else:
                    destination[key] = value

            return destination

        if config_file != "":
            with open(config_file) as config:
                c = json.load(config)
            self.settings = merge(c, self.settings)

    def _to_canvas(self, pos, min_width, max_width, max_height):
        """Map a design position (dict with 'x' and 'y') to canvas coordinates.

        A zero horizontal or vertical extent (e.g. a single-node tree) is
        treated as an extent of 1 so that the mapping never divides by zero.
        """
        x_span = (max_width - min_width) or 1
        y_span = max_height or 1
        return (self.w_margin + self.w_no_margs * (pos['x'] - min_width) / x_span,
                self.h_margin + self.h_no_margs * pos['y'] / y_span)

    def draw_dots(self, file, min_width, max_width, max_height):
        """Emit one dot for every node that has been given a position."""
        for i, node in enumerate(self.design.positions):
            if 'x' not in node:
                # node was not laid out
                continue
            dot_style = self.compute_dot_style(node=i)
            self.add_dot(file, self._to_canvas(node, min_width, max_width, max_height), dot_style)

    def draw_lines(self, file, min_width, max_width, max_height):
        """Emit one line for every parent-child pair in which both ends have a position."""
        for parent, par_pos in enumerate(self.design.positions):
            if 'x' not in par_pos:
                continue
            for child in self.design.tree.children[parent]:
                chi_pos = self.design.positions[child]
                if 'x' not in chi_pos:
                    continue
                line_style = self.compute_line_style(parent, child)
                self.add_line(file,
                              self._to_canvas(par_pos, min_width, max_width, max_height),
                              self._to_canvas(chi_pos, min_width, max_width, max_height),
                              line_style)

    def draw_scale(self, file, filename):
        """Write the caption and the labelled top/bottom marks of the vertical axis.

        :param filename: name of the input file; only its last path component
            (after either '/' or a backslash) is shown
        """
        base_name = filename.replace("\\", "/").rsplit("/", 1)[-1]
        self.add_text(file, "Generated from " + base_name, (5, 5), "start")

        start_text = ""
        end_text = ""
        if self.design.TIME == "BIRTHS":
            start_text = "Birth #0"
            end_text = "Birth #" + str(len(self.design.positions)-1)
        if self.design.TIME == "REAL":
            start_text = "Time " + str(min(self.design.tree.time))
            end_text = "Time " + str(max(self.design.tree.time))
        if self.design.TIME == "GENERATIONAL":
            start_text = "Depth " + str(self.design.props['adepth_min'])
            end_text = "Depth " + str(self.design.props['adepth_max'])

        self.add_dashed_line(file, (self.width*0.7, self.h_margin), (self.width, self.h_margin))
        self.add_text(file, start_text, (self.width, self.h_margin), "end")
        self.add_dashed_line(file, (self.width*0.7, self.height-self.h_margin), (self.width, self.height-self.h_margin))
        self.add_text(file, end_text, (self.width, self.height-self.h_margin), "end")

    def compute_property(self, part, prop, node):
        """Return the value of visual property ``prop`` of ``part`` ('dots' or 'lines') for ``node``.

        The driving value is looked up in ``design.props`` under the configured
        'meaning'; when the design has no such property, 0 is used.
        """
        setting = self.settings[part][prop]
        meaning = setting['meaning']
        value = self.design.props[meaning][node] if meaning in self.design.props else 0
        if prop == "color":
            return self.compute_color(setting['start'], setting['end'], value, setting['bias'])
        return self.compute_value(setting['start'], setting['end'], value, setting['bias'])

    def compute_color(self, start, end, value, bias=1):
        """Return an (r, g, b) tuple for ``value``.

        A string ``value`` is interpreted as a kind number and selects an entry
        of ``settings['colors_of_kinds']``; numbers beyond the palette wrap
        around instead of raising IndexError.  Any other ``value`` is used to
        interpolate between the named colors ``start`` and ``end``.
        """
        if isinstance(value, str):
            palette = self.settings['colors_of_kinds']
            kind_color = self.colors[palette[int(value) % len(palette)]]
            return (kind_color['r'], kind_color['g'], kind_color['b'])

        start_color = self.colors[start]
        end_color = self.colors[end]
        value = 1 - (1-value)**bias
        return tuple(start_color[ch]*(1-value) + end_color[ch]*value for ch in ('r', 'g', 'b'))

    def compute_value(self, start, end, value, bias=1):
        """Interpolate between ``start`` and ``end`` at ``value`` bent by exponent ``bias``."""
        value = 1 - (1-value)**bias
        return start*(1-value) + end*value
171
class PngDrawer(Drawer):
    """Raster back-end: draws with PIL at ``multi`` times the target size and
    then shrinks the image back to the target size before saving."""

    def scale_up(self):
        """Multiply all canvas dimensions by ``self.multi``."""
        self.width *= self.multi
        self.height *= self.multi
        self.w_margin *= self.multi
        self.h_margin *= self.multi
        self.h_no_margs *= self.multi
        self.w_no_margs *= self.multi

    def scale_down(self):
        """Divide all canvas dimensions by ``self.multi`` (the results are floats)."""
        self.width /= self.multi
        self.height /= self.multi
        self.w_margin /= self.multi
        self.h_margin /= self.multi
        self.h_no_margs /= self.multi
        self.w_no_margs /= self.multi

    @staticmethod
    def _rgba(color, opacity):
        """Convert a 0..100 (r, g, b) tuple and a 0..1 opacity to 0..255 RGBA integers."""
        return (int(2.55*color[0]), int(2.55*color[1]), int(2.55*color[2]), int(255*opacity))

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Render the design and save it to ``filename``.

        :param filename: output image file
        :param input_filename: name of the data file, shown in the caption
        :param multi: oversampling factor used while drawing
        :param scale: "SIMPLE" draws the caption and the vertical scale
        """
        print("Drawing...")

        self.multi = multi
        self.scale_up()
        try:
            # dimensions are cast to int because they may have become floats
            back = Image.new('RGBA', (int(self.width), int(self.height)), (255, 255, 255, 0))

            xs = [p['x'] for p in self.design.positions if 'x' in p]
            min_width = min(xs)
            max_width = max(xs)
            max_height = max([p['y'] for p in self.design.positions if 'y' in p])

            self.draw_lines(back, min_width, max_width, max_height)
            self.draw_dots(back, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(back, input_filename)
        finally:
            # restore the logical size even when drawing fails
            self.scale_down()

        # Image.LANCZOS is the filter formerly also exposed as Image.ANTIALIAS
        back.thumbnail((int(self.width), int(self.height)), Image.LANCZOS)

        back.save(filename)

    def add_dot(self, file, pos, style):
        """Paste a filled, semi-transparent circle centred at ``pos`` onto ``file``."""
        x, y = int(pos[0]), int(pos[1])
        r = style['r']*self.multi
        if int(r) < 1:
            # the bounding box would be empty - nothing to draw
            return
        offset = (int(x - r), int(y - r))
        size = (2*int(r), 2*int(r))

        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).ellipse((1, 1, size[0]-1, size[1]-1),
                                    self._rgba(style['color'], style['opacity']))
        file.paste(img, offset, mask=img)

    def add_line(self, file, from_pos, to_pos, style):
        """Paste a semi-transparent line from ``from_pos`` to ``to_pos`` onto ``file``.

        The line is drawn on a temporary image just big enough to hold it
        (padded by the line width) and then pasted using its own alpha as mask.
        """
        fx, fy, tx, ty = int(from_pos[0]), int(from_pos[1]), int(to_pos[0]), int(to_pos[1])
        # NOTE(review): widths below 1 truncate to 0 here, which also removes the padding - confirm intended
        w = int(style['width'])*self.multi

        offset = (min(fx-w, tx-w), min(fy-w, ty-w))
        size = (abs(fx-tx)+2*w, abs(fy-ty)+2*w)

        # pick the diagonal of the temporary image that matches the line direction
        if (fx-tx)*(fy-ty) > 0:
            coords = (w, w, size[0]-w, size[1]-w)
        else:
            coords = (size[0]-w, w, w, size[1]-w)

        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).line(coords, self._rgba(style['color'], style['opacity']), w)
        file.paste(img, offset, mask=img)

    def add_dashed_line(self, file, from_pos, to_pos):
        """Draw a black dashed line as 50 short segments separated by equal gaps."""
        style = {'color': (0,0,0), 'width': 1, 'opacity': 1}
        sublines = 50
        # TODO could be faster: compute delta and only add delta each time (but currently we do not use it often)
        normdiv = 2*sublines-1
        for i in range(sublines):
            seg_start = 2*i/normdiv
            seg_end = (2*i+1)/normdiv
            from_pos_sub = (self.compute_value(from_pos[0], to_pos[0], seg_start, 1),
                            self.compute_value(from_pos[1], to_pos[1], seg_start, 1))
            to_pos_sub = (self.compute_value(from_pos[0], to_pos[0], seg_end, 1),
                          self.compute_value(from_pos[1], to_pos[1], seg_end, 1))
            self.add_line(file, from_pos_sub, to_pos_sub, style)

    def add_text(self, file, text, pos, anchor, style=''):
        """Draw black ``text`` at ``pos``; any ``anchor`` other than "start" right-aligns it.

        ``style`` is accepted for interface compatibility and is not used.
        """
        font = ImageFont.truetype("Vera.ttf", 16*self.multi)

        img = Image.new('RGBA', (int(self.width), int(self.height)))
        draw = ImageDraw.Draw(img)
        if anchor != "start":
            if hasattr(draw, "textbbox"):
                # textsize() no longer exists in recent Pillow releases
                text_width = draw.textbbox((0, 0), text, font=font)[2]
            else:
                text_width = draw.textsize(text, font=font)[0]
            pos = (pos[0]-text_width, pos[1])
        draw.text(pos, text, (0,0,0), font=font)
        file.paste(img, (0,0), mask=img)

    def compute_line_style(self, parent, child):
        """Return the style dict of the line leading to ``child`` (``parent`` is not used)."""
        return {'color': self.compute_property('lines', 'color', child),
                'width': self.compute_property('lines', 'width', child),
                'opacity': self.compute_property('lines', 'opacity', child)}

    def compute_dot_style(self, node):
        """Return the style dict (color, radius 'r', opacity) of the dot of ``node``."""
        return {'color': self.compute_property('dots', 'color', node),
                'r': self.compute_property('dots', 'size', node),
                'opacity': self.compute_property('dots', 'opacity', node)}
273
class SvgDrawer(Drawer):
    """Vector back-end: writes the drawing as SVG markup to a text file.

    Styles are passed around as ready-made strings of SVG attributes.
    """

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Write the design to ``filename`` as an SVG document.

        :param filename: output file
        :param input_filename: name of the data file, shown in the caption
        :param multi: accepted for interface compatibility; not used here
        :param scale: "SIMPLE" draws the caption and the vertical scale
        """
        print("Drawing...")

        xs = [p['x'] for p in self.design.positions if 'x' in p]
        min_width = min(xs)
        max_width = max(xs)
        max_height = max([p['y'] for p in self.design.positions if 'y' in p])

        # the context manager closes the file even when drawing raises
        with open(filename, "w") as file:
            file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" '
                       'xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
                       'width="' + str(self.width) + '" height="' + str(self.height) + '">')

            self.draw_lines(file, min_width, max_width, max_height)
            self.draw_dots(file, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(file, input_filename)

            file.write("</svg>")

    def add_text(self, file, text, pos, anchor, style=''):
        """Write a <text> element; ``text`` is XML-escaped before it is written."""
        from xml.sax.saxutils import escape

        style = (style if style != '' else 'style="font-family: Arial; font-size: 12; fill: #000000;"')
        # assuming font size 12, it should be taken from the style string!
        file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0])
                   + '" y="' + str(pos[1]+12) + '" >' + escape(text) + '</text>')

    def add_dot(self, file, pos, style):
        """Write a <circle> element centred at ``pos`` with the given attribute string."""
        file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')

    def add_line(self, file, from_pos, to_pos, style):
        """Write a <line> element from ``from_pos`` to ``to_pos`` with the given attribute string."""
        file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
                   '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')

    def add_dashed_line(self, file, from_pos, to_pos):
        """Write a thin black dashed line."""
        style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
        self.add_line(file, from_pos, to_pos, style)

    def compute_line_style(self, parent, child):
        """Return the SVG attributes of the line leading to ``child`` (``parent`` is not used)."""
        return ' '.join((self.compute_stroke_color('lines', child),
                         self.compute_stroke_width('lines', child),
                         self.compute_stroke_opacity(child)))

    def compute_dot_style(self, node):
        """Return the SVG attributes (radius, fill opacity, fill color) of the dot of ``node``."""
        return ' '.join((self.compute_dot_size(node),
                         self.compute_fill_opacity(node),
                         self.compute_dot_fill(node)))

    @staticmethod
    def _rgb_percent(color):
        """Format an (r, g, b) tuple as an SVG ``rgb(r%,g%,b%)`` value."""
        return 'rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)'

    def compute_stroke_color(self, part, node):
        """Return the ``stroke`` attribute for ``node``."""
        return 'stroke="' + self._rgb_percent(self.compute_property(part, 'color', node)) + '"'

    def compute_stroke_width(self, part, node):
        """Return the ``stroke-width`` attribute for ``node``."""
        return 'stroke-width="' + str(self.compute_property(part, 'width', node)) + '"'

    def compute_stroke_opacity(self, node):
        """Return the ``stroke-opacity`` attribute for ``node``."""
        return 'stroke-opacity="' + str(self.compute_property('lines', 'opacity', node)) + '"'

    def compute_fill_opacity(self, node):
        """Return the ``fill-opacity`` attribute for ``node``."""
        return 'fill-opacity="' + str(self.compute_property('dots', 'opacity', node)) + '"'

    def compute_dot_size(self, node):
        """Return the ``r`` (radius) attribute for ``node``."""
        return 'r="' + str(self.compute_property('dots', 'size', node)) + '"'

    def compute_dot_fill(self, node):
        """Return the ``fill`` attribute for ``node``."""
        return 'fill="' + self._rgb_percent(self.compute_property('dots', 'color', node)) + '"'
341
class Designer:
    """Computes per-node measures and 2D positions for a genealogical tree.

    ``tree`` is expected to expose ``parents`` (one dict per node mapping
    parent index to inherited share), ``children`` (one list per node),
    ``time``, ``kind`` and ``props``.  Node 0 is treated as the root.
    """

    def __init__(self, tree, jitter=False, time="GENERATIONAL", balance="DENSITY"):
        """
        :param tree: the loaded tree data
        :param jitter: add random noise to horizontal positions
        :param time: meaning of the vertical axis: "BIRTHS", "GENERATIONAL" or "REAL"
        :param balance: how to choose the side on which a child is placed:
            "RANDOM", "MIN" or "DENSITY"
        :raises ValueError: when ``balance`` is none of the supported values
        """
        self.props = {}

        self.tree = tree

        self.TIME = time
        self.JITTER = jitter

        if balance == "RANDOM":
            self.xmin_crowd = self.xmin_crowd_random
        elif balance == "MIN":
            self.xmin_crowd = self.xmin_crowd_min
        elif balance == "DENSITY":
            self.xmin_crowd = self.xmin_crowd_density
        else:
            raise ValueError("Error, the value of BALANCE does not match any expected value.")

    def calculate_measures(self):
        """Compute all node properties (depth has to come first, others rely on it)."""
        print("Calculating measures...")
        self.compute_depth()
        self.compute_adepth()
        self.compute_children()
        self.compute_kind()
        self.compute_time()
        self.compute_progress()
        self.compute_custom()

    def _visiting_order(self):
        """Return all node indices sorted by the largest 'depth' among their parents (root first)."""
        depth = self.props["depth"]
        return sorted(range(len(self.tree.parents)), key=lambda q:
                      0 if q == 0 else max([depth[d] for d in self.tree.parents[q]]))

    def xmin_crowd_random(self, x1, x2, y):
        """Pick ``x1`` or ``x2`` at random."""
        return (x1 if random.randrange(2) == 0 else x2)

    def xmin_crowd_min(self, x1, x2, y):
        """Pick the candidate whose nearest already placed node (within 3 units of ``y``) is farther away."""
        x1_closest = 999999
        x2_closest = 999999
        miny = y-3
        maxy = y+3
        i = bisect.bisect_left(self.y_sorted, miny)
        while i < len(self.positions_sorted) and self.positions_sorted[i]['y'] <= maxy:
            pos = self.positions_sorted[i]

            x1_closest = min(x1_closest, abs(x1-pos['x']))
            x2_closest = min(x2_closest, abs(x2-pos['x']))

            i += 1
        return (x1 if x1_closest > x2_closest else x2)

    def xmin_crowd_density(self, x1, x2, y):
        """Pick ``x1`` or ``x2`` based on distances to (a sample of) already placed nodes.

        Nodes horizontally close to the midpoint of the candidates contribute
        to a "local" sum, nodes far from it to a "global" sum; the candidate is
        chosen from the difference of the two averaged sums.
        """
        # TODO experimental - requires further work to make it less 'jumpy' and more predictable
        CONST_LOCAL_AREA_RADIUS = 5
        CONST_GLOBAL_AREA_RADIUS = 10
        CONST_WINDOW_SIZE = 20000 #TODO should depend on the maxY ?
        x1_dist_loc = 0
        x2_dist_loc = 0
        count_loc = 1
        x1_dist_glob = 0
        x2_dist_glob = 0
        count_glob = 1
        miny = y-CONST_WINDOW_SIZE
        maxy = y+CONST_WINDOW_SIZE
        i_left = bisect.bisect_left(self.y_sorted, miny)
        i_right = bisect.bisect_right(self.y_sorted, maxy)
        #TODO test: maxy=y should give the same results, right?

        def include_pos(pos):
            nonlocal x1_dist_loc, x2_dist_loc, x1_dist_glob, x2_dist_glob, count_loc, count_glob

            dysq = (pos['y']-y)**2 + 1 #+1 so 1/dysq is at most 1
            dx1 = math.fabs(pos['x']-x1)
            dx2 = math.fabs(pos['x']-x2)

            d = math.fabs(pos['x'] - (x1+x2)/2)

            if d < CONST_LOCAL_AREA_RADIUS:
                x1_dist_loc += math.sqrt(dx1/dysq + dx1**2)
                x2_dist_loc += math.sqrt(dx2/dysq + dx2**2)
                count_loc += 1
            elif d > CONST_GLOBAL_AREA_RADIUS:
                x1_dist_glob += math.sqrt(dx1/dysq + dx1**2)
                x2_dist_glob += math.sqrt(dx2/dysq + dx2**2)
                count_glob += 1

        # optimized to draw from all the nodes, if less than 10 nodes in the range
        if len(self.positions_sorted) > i_left:
            if i_right - i_left < 10:
                for j in range(i_left, i_right):
                    include_pos(self.positions_sorted[j])
            else:
                # otherwise use a random sample of 10 nodes from the window
                for j in range(10):
                    pos = self.positions_sorted[random.randrange(i_left, i_right)]
                    include_pos(pos)

        return (x1 if (x1_dist_loc-x2_dist_loc)/count_loc-(x1_dist_glob-x2_dist_glob)/count_glob > 0  else x2)

    def calculate_node_positions(self, ignore_last=0):
        """Assign an 'x' and 'y' to the nodes and store them in ``self.positions``.

        Requires 'depth' and 'adepth' to be computed already.

        :param ignore_last: nodes whose normalized 'adepth' is below
            ``ignore_last / adepth_max`` are left without a position
        """
        print("Calculating positions...")

        def add_node(node):
            # keep y_sorted / positions_sorted ordered by y so they can be bisected
            index = bisect.bisect_left(self.y_sorted, node['y'])
            self.y_sorted.insert(index, node['y'])
            self.positions_sorted.insert(index, node)
            self.positions[node['id']] = node

        self.positions_sorted = [{'x':0, 'y':0, 'id':0}]
        self.y_sorted = [0]
        self.positions = [{} for x in range(len(self.tree.parents))]
        self.positions[0] = {'x':0, 'y':0, 'id':0}

        # order by maximum depth of the parent guarantees that no child is evaluated before its parent
        visiting_order = self._visiting_order()

        # using normalized adepth; a tree without any descendants has adepth_max == 0
        adepth_max = self.props['adepth_max']
        adepth_threshold = ignore_last/adepth_max if adepth_max else 0

        start_time = timelib.time()

        for node_counter, child in enumerate(visiting_order, start=1):
            # debug info - elapsed time
            if node_counter % 100000 == 0:
                print("%d%%\t%d\t%g" % (node_counter*100/len(self.tree.parents), node_counter, timelib.time()-start_time))
                start_time = timelib.time()

            if self.props['adepth'][child] < adepth_threshold:
                continue

            parents = self.tree.parents[child]

            ypos = 0
            if self.TIME == "BIRTHS":
                ypos = child
            elif self.TIME == "GENERATIONAL":
                # one more than its parent (what if more than one parent?)
                ypos = max([self.positions[par]['y'] for par in parents])+1 if parents else 0
            elif self.TIME == "REAL":
                ypos = self.tree.time[child]

            if len(parents) == 1:
                # a single parent: place the child to the left or right of it,
                # the more dissimilar the farther away
                parent, similarity = next(iter(parents.items()))

                if self.JITTER:
                    dissimilarity = (1-similarity) + random.gauss(0, 0.01) + 0.001
                else:
                    dissimilarity = (1-similarity) + 0.001
                add_node({'id':child, 'y':ypos, 'x':
                         self.xmin_crowd(self.positions[parent]['x']-dissimilarity,
                                         self.positions[parent]['x']+dissimilarity, ypos)})
            else:
                # position weighted by the degree of inheritance from each parent
                total_inheritance = sum(parents.values())
                xpos = sum([self.positions[k]['x']*v/total_inheritance
                            for k, v in parents.items()])
                if self.JITTER:
                    add_node({'id':child, 'y':ypos, 'x':xpos + random.gauss(0, 0.1)})
                else:
                    add_node({'id':child, 'y':ypos, 'x':xpos})

    def compute_custom(self):
        """Copy every custom property of the tree into ``self.props`` and normalize it."""
        node_count = len(self.tree.children)
        for prop in self.tree.props:
            self.props[prop] = [self.tree.props[prop][i] for i in range(node_count)]
            self.normalize_prop(prop)

    def compute_time(self):
        """Copy the times from the tree and normalize them."""
        self.props["time"] = [self.tree.time[i] for i in range(len(self.tree.children))]
        self.normalize_prop('time')

    def compute_kind(self):
        """Copy the kinds from the tree as strings (they are not normalized)."""
        self.props["kind"] = [str(self.tree.kind[i]) for i in range(len(self.tree.children))]

    def compute_depth(self):
        """Compute the breadth-first distance of every node from node 0, then normalize it.

        Nodes that cannot be reached from node 0 keep the placeholder 999999999.
        """
        from collections import deque

        node_count = len(self.tree.children)
        depth = [999999999 for x in range(node_count)]
        visited = [False for x in range(node_count)]
        self.props["depth"] = depth

        # a deque gives O(1) removal from the front of the queue
        nodes_to_visit = deque([0])
        visited[0] = True
        depth[0] = 0
        while nodes_to_visit:
            current_node = nodes_to_visit.popleft()
            for child in self.tree.children[current_node]:
                if not visited[child]:
                    visited[child] = True
                    nodes_to_visit.append(child)
                    depth[child] = depth[current_node]+1

        self.normalize_prop('depth')

    def compute_adepth(self):
        """Compute for every node the length of the longest chain of descendants
        below it (0 for childless nodes), then normalize it."""
        self.props["adepth"] = [0 for x in range(len(self.tree.children))]

        # reversed visiting order, so that children are evaluated before their parents
        for node in reversed(self._visiting_order()):
            children = self.tree.children[node]
            if len(children) != 0:
                # 0 by default
                self.props["adepth"][node] = max([self.props["adepth"][child] for child in children])+1
        self.normalize_prop('adepth')

    def compute_children(self):
        """Store the normalized number of children of every node."""
        self.props["children"] = [len(self.tree.children[i]) for i in range(len(self.tree.children))]
        self.normalize_prop('children')

    def compute_progress(self):
        """Compute the 'progress' property from the birth times of each node's children.

        For nodes with more than 4 children, the slope of a linear regression
        over the gaps between consecutive (sorted, scaled) child times is
        stored.  Requires 'time' and 'children' to be computed already.
        """
        progress = [0 for x in range(len(self.tree.children))]
        self.props["progress"] = progress
        for i in range(len(self.props['children'])):
            times = sorted([self.props["time"][child]*100000 for child in self.tree.children[i]])
            if len(times) > 4:
                gaps = [times[j+1] - times[j] for j in range(len(times)-1)]
                slope, intercept, r_value, p_value, std_err = stats.linregress(range(len(gaps)), gaps)
                progress[i] = slope if not np.isnan(slope) and not np.isinf(slope) else 0

        # discard the 5 smallest and the 5 largest values
        for i in range(0, 5):
            progress[progress.index(min(progress))] = 0
            progress[progress.index(max(progress))] = 0

        # all zeros (including the nodes without a value) become the minimum
        mini = min(progress)
        for k in range(len(progress)):
            if progress[k] == 0:
                progress[k] = mini

        self.normalize_prop('progress')

    def normalize_prop(self, prop):
        """Rescale the numeric values of ``self.props[prop]`` to 0..1 in place.

        ``None``, strings and lists are left untouched and do not take part in
        finding the range.  The original range is stored in
        ``self.props[prop + '_max']`` and ``self.props[prop + '_min']``.  When
        all numeric values are equal they all become 0.
        """
        def is_numeric(v):
            return v is not None and type(v) != str and type(v) != list

        values = self.props[prop]
        numeric = [v for v in values if is_numeric(v)]
        if len(numeric) > 0:
            max_val = max(numeric)
            min_val = min(numeric)
            print(prop, max_val, min_val)
            self.props[prop +'_max'] = max_val
            self.props[prop +'_min'] = min_val
            for i in range(len(values)):
                if is_numeric(values[i]):
                    values[i] = 0 if max_val == min_val else (values[i] - min_val) / (max_val - min_val)
612
class TreeData:
    """Parent-child data read from a log file containing [OFFSPRING] and [DIED] records.

    After ``load()``: ``parents[i]`` maps the index of each parent of node i to
    its inherited share, ``children[i]`` lists the indices of the children of
    node i, ``time[i]`` / ``kind[i]`` hold the "Time" / "Kind" fields, and
    ``props`` maps every other field name to a per-node list of values.
    """

    simple_data = None

    def __init__(self): #, simple_data=False):
        #self.simple_data = simple_data
        # per-instance containers, so that instances never share the same lists
        self.children = []
        self.parents = []
        self.time = []
        self.kind = []

    @staticmethod
    def _split_record(line, cli_prefix):
        """Split a line into ``(tag, payload)``, skipping an optional leading ``cli_prefix``.

        Returns ``None`` for lines that do not have both a tag and a payload.
        """
        parts = line.split(' ', 1)
        if len(parts) == 2 and parts[0] == cli_prefix:
            parts = parts[1].split(' ', 1)
        if len(parts) != 2:
            return None
        return parts[0], parts[1]

    def _store_custom_props(self, creature, creature_id, default_props, nodes):
        """Copy every field of ``creature`` not listed in ``default_props`` into ``self.props``."""
        for prop in creature:
            if prop not in default_props:
                if prop not in self.props:
                    self.props[prop] = [0 for i in range(nodes)]
                self.props[prop][creature_id] = creature[prop]

    def load(self, filename, max_nodes=0):
        """Read the tree from ``filename``.

        :param filename: text file with one record per line
        :param max_nodes: stop after this many [OFFSPRING] records (0 = no limit)
        :raises LoadingError: when an [OFFSPRING] record has no 'FromIDs' field
        """
        print("Loading...")

        CLI_PREFIX = "Script.Message:"
        default_props = ["Time", "FromIDs", "ID", "Operation", "Inherited"]

        # maps the identifiers used in the file to consecutive node indices
        self.ids = {}
        def get_id(key, createOnError = True):
            if key not in self.ids:
                if not createOnError:
                    return None
                self.ids[key] = len(self.ids)
            return self.ids[key]

        with open(filename) as file:
            # counting the number of expected nodes
            nodes = 0
            for line in file:
                record = self._split_record(line, CLI_PREFIX)
                if record is not None and record[0] == "[OFFSPRING]":
                    nodes += 1

            # one extra slot, e.g. for the "virtual_parent" node
            nodes = min(nodes, max_nodes if max_nodes != 0 else nodes)+1
            self.parents = [{} for x in range(nodes)]
            self.children = [[] for x in range(nodes)]
            self.time = [0] * nodes
            self.kind = [0] * nodes
            self.life_lenght = [0] * nodes
            self.props = {}

            print("nodes: %d" % len(self.parents))

            file.seek(0)
            loaded_so_far = 0
            for line in file:
                record = self._split_record(line, CLI_PREFIX)
                if record is not None:
                    tag, payload = record
                    if tag == "[OFFSPRING]":
                        try:
                            creature = json.loads(payload)
                        except ValueError:
                            print("Json format error - the line cannot be read. Breaking the loading loop.")
                            # fixing arrays by removing the last element
                            # ! assuming that only the last line is broken !
                            self.parents.pop()
                            self.children.pop()
                            self.time.pop()
                            self.kind.pop()
                            self.life_lenght.pop()
                            nodes -= 1
                            break

                        if "FromIDs" not in creature:
                            raise LoadingError("[OFFSPRING] misses the 'FromIDs' field!")

                        # make sure that ID's of parents are lower than that of their children
                        for from_id in creature["FromIDs"]:
                            if from_id not in self.ids:
                                get_id("virtual_parent")

                        creature_id = get_id(creature["ID"])

                        # we assign to each parent its contribution to the genotype of the child
                        for i, from_id in enumerate(creature["FromIDs"]):
                            if from_id in self.ids:
                                parent_id = get_id(from_id)
                            else:
                                # unknown parents are all mapped to one shared node
                                parent_id = get_id("virtual_parent")
                            inherited = (creature["Inherited"][i] if 'Inherited' in creature else 1)
                            self.parents[creature_id][parent_id] = inherited

                        if "Time" in creature:
                            self.time[creature_id] = creature["Time"]

                        if "Kind" in creature:
                            self.kind[creature_id] = creature["Kind"]

                        self._store_custom_props(creature, creature_id, default_props, nodes)

                        loaded_so_far += 1
                    elif tag == "[DIED]":
                        try:
                            creature = json.loads(payload)
                        except ValueError:
                            # a truncated last line may just as well be a [DIED] record
                            print("Json format error - the line cannot be read. Breaking the loading loop.")
                            break
                        # records of creatures that were never born in this file are ignored
                        creature_id = get_id(creature["ID"], False)
                        if creature_id is not None:
                            self._store_custom_props(creature, creature_id, default_props, nodes)

                if loaded_so_far >= max_nodes and max_nodes != 0:
                    break

        # derive the children lists from the parent maps
        for child_id, parent_map in enumerate(self.parents):
            for parent_id in parent_map:
                self.children[parent_id].append(child_id)
742
# Module-level dictionaries, both created empty.
# NOTE(review): nothing in the visible part of this file reads or writes
# them -- confirm whether they are still used before removing.
depth = dict()
kind = dict()
745
def main(argv=None):
    """Parse the command line, load the tree data and draw the genealogical tree.

    Args:
        argv: optional list of command-line arguments (without the program
            name). When None, argparse falls back to ``sys.argv[1:]``, so a
            plain ``main()`` call behaves exactly as before.

    Returns:
        None. If --time, --balance or --scale holds an unsupported value, a
        message naming each offending option is printed and the function
        returns without loading or drawing anything.
    """

    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship '
                                                 'information from a text file. Supports files generated by Framsticks experiments.')
    parser.add_argument('-i', '--in', dest='input', required=True, help='input file name with stuctured evolutionary data')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG/PNG/JPG/BMP)')
    parser.add_argument('-c', '--config', dest='config', default="", help='config file name ')

    parser.add_argument('-W', '--width', default=600, type=int, dest='width', help='width of the output image (600 by default)')
    parser.add_argument('-H', '--height', default=800, type=int, dest='height', help='height of the output image (800 by default)')
    parser.add_argument('-m', '--multi', default=1, type=int, dest='multi', help='multisampling factor (applicable only for raster images)')

    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL(d)/REAL); '
                                                                      'BIRTHS: time measured as the number of births since the beginning; '
                                                                      'GENERATIONAL: time measured as number of ancestors; '
                                                                      'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing nodes in the tree (RANDOM/MIN/DENSITY(d))')
    # The "(d)" marker now sits on SIMPLE, which is the value actually used as the default.
    parser.add_argument('-s', '--scale', default='SIMPLE', dest='scale', help='type of timescale added to the tree (NONE/SIMPLE(d))')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')
    parser.add_argument('-p', '--skip', dest="skip", type=int, default=0, help='skip last P levels of the tree (0 by default)')
    parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    parser.add_argument('--seed', type=int, dest='seed', help='seed for the random number generator (-1 for random)')

    # Extra namespace entries kept so that "args" exposes the same attributes as before.
    parser.set_defaults(draw_tree=True, draw_skeleton=False, draw_spine=False)
    parser.set_defaults(seed=-1)

    args = parser.parse_args(argv)

    # Option values are case-insensitive: normalize before validating.
    TIME = args.time.upper()
    BALANCE = args.balance.upper()
    SCALE = args.scale.upper()
    JITTER = args.jitter

    # (option name, normalized value, allowed values)
    checked_options = (
        ('time', TIME, ('BIRTHS', 'GENERATIONAL', 'REAL')),
        ('balance', BALANCE, ('RANDOM', 'MIN', 'DENSITY')),
        ('scale', SCALE, ('NONE', 'SIMPLE')),
    )
    invalid = [opt for opt in checked_options if opt[1] not in opt[2]]
    if invalid:
        # Tell the user exactly which option is wrong and what is accepted.
        for name, value, allowed in invalid:
            print("Incorrect value '{}' of the '{}' parameter! Allowed values: {}.".format(
                value, name, '/'.join(allowed)))
        return

    input_path = args.input
    seed = args.seed
    if seed == -1:
        # No seed requested: pick one and print it below so the run can be reproduced.
        seed = random.randint(0, 10000)
    random.seed(seed)
    print("randomseed:", seed)

    tree = TreeData()
    tree.load(input_path, max_nodes=args.max_nodes)

    designer = Designer(tree, jitter=JITTER, time=TIME, balance=BALANCE)
    designer.calculate_measures()
    designer.calculate_node_positions(ignore_last=args.skip)

    # The output file extension selects the drawer: ".svg" -> SVG, anything else -> raster.
    if args.output.endswith(".svg"):
        drawer = SvgDrawer(designer, args.config, w=args.width, h=args.height)
    else:
        drawer = PngDrawer(designer, args.config, w=args.width, h=args.height)
    drawer.draw_design(args.output, args.input, multi=args.multi, scale=SCALE)
807
808
# Run the tool only when this file is executed as a script, so that importing
# the module (e.g. to reuse its classes) does not parse sys.argv or draw anything.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.