source: mds-and-trees/tree-genealogy.py @ 1216

Last change on this file since 1216 was 713, checked in by Maciej Komosinski, 7 years ago

Made error message more specific

File size: 40.0 KB
Line 
1import json
2import math
3import random
4import argparse
5import bisect
6import copy
7import time as timelib
8from PIL import Image, ImageDraw, ImageFont
9from scipy import stats
10from matplotlib import colors
11import numpy as np
12
class LoadingError(Exception):
    """Signals a failure to load genealogy input data."""
15
class Drawer:
    """Base class for renderers of a laid-out genealogy tree.

    Holds the canvas geometry and the visual settings, and maps the per-node
    measures of the design (``design.props``) onto visual attributes (color,
    size/width, opacity). Subclasses supply the output-specific primitives
    ``add_dot``, ``add_line``, ``add_dashed_line``, ``add_text``,
    ``compute_dot_style`` and ``compute_line_style``.
    """

    def __init__(self, design, config_file, w=600, h=800, w_margin=10, h_margin=20):
        """Create a drawer for *design*.

        :param design: laid-out design exposing ``positions``, ``props``, ``tree``
            and ``TIME`` (as the ``Designer`` in this file does).
        :param config_file: path of a JSON file deep-merged over the default
            settings; an empty string (or ``None``) means "defaults only".
        :param w: canvas width in pixels.
        :param h: canvas height in pixels.
        :param w_margin: left/right margin in pixels.
        :param h_margin: top/bottom margin in pixels.
        """
        self.design = design
        self.width = w
        self.height = h
        self.w_margin = w_margin
        self.h_margin = h_margin
        self.w_no_margs = w - 2 * w_margin
        self.h_no_margs = h - 2 * h_margin

        self.color_converter = colors.ColorConverter()

        # Each visual attribute is driven by one design property ("meaning"):
        # its normalized value v is biased as 1-(1-v)**bias and interpolated
        # between "start" and "end" (or looked up in "cmap" for colors).
        self.settings = {
            'colors_of_kinds': ['red', 'green', 'blue', 'magenta', 'yellow', 'cyan', 'orange', 'purple'],
            'dots': {
                'color': {
                    'meaning': 'Lifespan',
                    'normalize_cmap': False,
                    'cmap': {},
                    'start': 'red',
                    'end': 'green',
                    'bias': 1
                    },
                'size': {
                    'meaning': 'EnergyEaten',
                    'start': 1,
                    'end': 6,
                    'bias': 0.5
                    },
                'opacity': {
                    'meaning': 'EnergyEaten',
                    'start': 0.2,
                    'end': 1,
                    'bias': 1
                    }
            },
            'lines': {
                'color': {
                    'meaning': 'adepth',
                    'normalize_cmap': False,
                    'cmap': {},
                    'start': 'black',
                    'end': 'red',
                    'bias': 3
                    },
                'width': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 4,
                    'bias': 3
                    },
                'opacity': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 0.8,
                    'bias': 5
                    }
            }
        }

        def merge(source, destination):
            """Recursively copy *source* into *destination*: nested dicts are merged, other values overwrite."""
            for key, value in source.items():
                if isinstance(value, dict):
                    node = destination.setdefault(key, {})
                    merge(value, node)
                else:
                    destination[key] = value
            return destination

        if config_file:
            with open(config_file) as config:
                self.settings = merge(json.load(config), self.settings)

        self.compile_cmaps()

    @staticmethod
    def _normalize_cmap_positions(cmap):
        """Rescale, in place, the anchor x positions of every channel of *cmap* to [0, 1].

        *cmap* maps channel names to lists of ``[x, y0, y1]`` anchors ordered by x.
        The first/last x are captured before the loop: overwriting the first
        anchor while iterating would otherwise shift the reference range and
        mis-scale every later anchor whenever the first x is not 0.

        :returns: the same (mutated) *cmap*.
        """
        for anchors in cmap.values():
            lo = anchors[0][0]
            hi = anchors[-1][0]
            for anchor in anchors:
                anchor[0] = (anchor[0] - lo) / (hi - lo)
        return cmap

    def compile_cmaps(self):
        """Replace user-supplied segment-data colormaps in the settings by compiled colormaps.

        For 'dots' and 'lines', a non-empty ``settings[part]['color']['cmap']``
        (channel -> ``[[x, y0, y1], ...]``) becomes a ``LinearSegmentedColormap``.
        With ``normalize_cmap`` set, each channel is first extended/clipped to the
        raw [min, max] of the design property named by ``meaning``.
        """
        def normalize_and_compile_cmap(cmap):
            return colors.LinearSegmentedColormap('Custom', self._normalize_cmap_positions(cmap))

        # position of each channel in the RGBA tuple returned by a colormap
        channel_index = {'red': 0, 'green': 1, 'blue': 2, 'alpha': 3}

        for part in ['dots', 'lines']:
            color_settings = self.settings[part]['color']
            if not color_settings['cmap']:
                continue
            if color_settings['normalize_cmap']:
                cmap = color_settings['cmap']
                lo = self.design.props[color_settings['meaning'] + "_min"]
                hi = self.design.props[color_settings['meaning'] + "_max"]

                # extend every channel so that its anchors cover at least [lo, hi]
                for key in cmap:
                    if cmap[key][0][0] > lo:
                        cmap[key].insert(0, cmap[key][0][:])
                        cmap[key][0][0] = lo
                    if cmap[key][-1][0] < hi:
                        cmap[key].append(cmap[key][-1][:])
                        cmap[key][-1][0] = hi

                og_cmap = normalize_and_compile_cmap(copy.deepcopy(cmap))

                for key in cmap:
                    # clip the end anchors to [lo, hi], taking their values from the
                    # unclipped colormap; the n's should be the same for all keys
                    first_x = cmap[key][0][0]
                    last_x = cmap[key][-1][0]
                    n_min = (lo - first_x) / (last_x - first_x)
                    n_max = (hi - first_x) / (last_x - first_x)
                    idx = channel_index[key]
                    lo_val = og_cmap(n_min)[idx]
                    hi_val = og_cmap(n_max)[idx]
                    cmap[key][0] = [lo, lo_val, lo_val]
                    cmap[key][-1] = [hi, hi_val, hi_val]
            print(color_settings['cmap'])
            color_settings['cmap'] = normalize_and_compile_cmap(color_settings['cmap'])

    def _project(self, pos, min_width, max_width, max_height):
        """Map a layout position (dict with 'x' and 'y') to canvas coordinates.

        When all nodes share one x the node is centred horizontally, and a zero
        *max_height* keeps it on the top margin, instead of dividing by zero.
        """
        if max_width != min_width:
            x = self.w_margin + self.w_no_margs * (pos['x'] - min_width) / (max_width - min_width)
        else:
            x = self.w_margin + self.w_no_margs * 0.5
        y = self.h_margin + (self.h_no_margs * pos['y'] / max_height if max_height else 0)
        return x, y

    def draw_dots(self, file, min_width, max_width, max_height):
        """Draw one dot for every node that received a position (has an 'x')."""
        for node, pos in enumerate(self.design.positions):
            if 'x' not in pos:
                continue
            self.add_dot(file, self._project(pos, min_width, max_width, max_height),
                         self.compute_dot_style(node=node))

    def draw_lines(self, file, min_width, max_width, max_height):
        """Draw a segment from every positioned node to each of its positioned children."""
        for parent, par_pos in enumerate(self.design.positions):
            if 'x' not in par_pos:
                continue
            start = self._project(par_pos, min_width, max_width, max_height)
            for child in self.design.tree.children[parent]:
                chi_pos = self.design.positions[child]
                if 'x' not in chi_pos:
                    continue
                self.add_line(file, start, self._project(chi_pos, min_width, max_width, max_height),
                              self.compute_line_style(parent, child))

    def draw_scale(self, file, filenames):
        """Write the source-file caption and label the top and bottom of the time axis.

        :param filenames: non-empty list of input file names; the base name of
            the first one is shown, followed by the count of the remaining ones.
        """
        # strip the directory for both Windows and POSIX separators
        caption = "Generated from " + filenames[0].replace("\\", "/").split("/")[-1]
        if len(filenames) > 1:
            caption += " and " + str(len(filenames) - 1) + " more"
        self.add_text(file, caption, (5, 5), "start")

        start_text = ""
        end_text = ""
        if self.design.TIME == "BIRTHS":
            start_text = "Birth #0"
            end_text = "Birth #" + str(len(self.design.positions) - 1)
        elif self.design.TIME == "REAL":
            start_text = "Time " + str(min(self.design.tree.time))
            end_text = "Time " + str(max(self.design.tree.time))
        elif self.design.TIME == "GENERATIONAL":
            start_text = "Depth " + str(self.design.props['adepth_min'])
            end_text = "Depth " + str(self.design.props['adepth_max'])

        self.add_dashed_line(file, (self.width * 0.7, self.h_margin), (self.width, self.h_margin))
        self.add_text(file, start_text, (self.width, self.h_margin), "end")
        self.add_dashed_line(file, (self.width * 0.7, self.height - self.h_margin),
                             (self.width, self.height - self.h_margin))
        self.add_text(file, end_text, (self.width, self.height - self.h_margin), "end")

    def compute_property(self, part, prop, node):
        """Return the visual value of attribute *prop* of *part* ('dots'/'lines') for *node*.

        The design property named by the setting's ``meaning`` supplies the
        value (0 when the design has no such property). Colors are returned as
        an (r, g, b) tuple of percentages, other attributes as numbers.
        """
        setting = self.settings[part][prop]
        meaning = setting['meaning']
        value = self.design.props[meaning][node] if meaning in self.design.props else 0
        if prop == "color":
            if setting['cmap']:
                return self.compute_color_from_cmap(setting['cmap'], value, setting['bias'])
            return self.compute_color(setting['start'], setting['end'], value, setting['bias'])
        return self.compute_value(setting['start'], setting['end'], value, setting['bias'])

    def compute_color_from_cmap(self, cmap, value, bias=1):
        """Look up the biased *value* in the compiled colormap; returns percent RGB."""
        value = 1 - (1 - value) ** bias
        rgba = cmap(value)
        return (100 * rgba[0], 100 * rgba[1], 100 * rgba[2])

    def compute_color(self, start, end, value, bias=1):
        """Return an (r, g, b) tuple of percentages for *value*.

        A string *value* (as stored for 'kind') is an index into
        ``colors_of_kinds``, wrapping around when there are more kinds than
        colors. A numeric *value* is biased and interpolated linearly between
        the *start* and *end* colors.
        """
        if isinstance(value, str):
            palette = self.settings['colors_of_kinds']
            r, g, b = self.color_converter.to_rgb(palette[int(value) % len(palette)])
        else:
            start_color = self.color_converter.to_rgb(start)
            end_color = self.color_converter.to_rgb(end)
            value = 1 - (1 - value) ** bias
            r, g, b = (s * (1 - value) + e * value for s, e in zip(start_color, end_color))
        return (100 * r, 100 * g, 100 * b)

    def compute_value(self, start, end, value, bias=1):
        """Interpolate linearly between *start* and *end* at the biased position 1-(1-value)**bias."""
        value = 1 - (1 - value) ** bias
        return start * (1 - value) + end * value
213
class PngDrawer(Drawer):
    """Raster renderer that produces an image file through Pillow.

    Everything is drawn on a canvas ``multi`` times larger than the target
    size and then shrunk with a Lanczos filter, which anti-aliases the result.
    """

    # geometry attributes multiplied by ``multi`` while drawing
    _GEOMETRY = ('width', 'height', 'w_margin', 'h_margin', 'h_no_margs', 'w_no_margs')

    def scale_up(self):
        """Enlarge the canvas geometry by ``self.multi``, remembering the original values."""
        self._saved_geometry = {name: getattr(self, name) for name in self._GEOMETRY}
        for name in self._GEOMETRY:
            setattr(self, name, getattr(self, name) * self.multi)

    def scale_down(self):
        """Undo ``scale_up``.

        The saved values are restored verbatim, so integer sizes stay integers.
        Dividing them back would turn them into floats, which then reach
        ``Image.new`` on a later ``draw_design`` call.
        """
        saved = getattr(self, '_saved_geometry', None)
        if saved is None:
            # no matching scale_up: fall back to the arithmetic inverse
            for name in self._GEOMETRY:
                setattr(self, name, getattr(self, name) / self.multi)
            return
        for name, value in saved.items():
            setattr(self, name, value)
        self._saved_geometry = None

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Render the design and save it to *filename*.

        :param filename: output image path.
        :param input_filename: list of input file names, used for the caption.
        :param multi: supersampling factor.
        :param scale: "SIMPLE" draws the caption and time-axis labels.
        """
        print("Drawing...")

        self.multi = multi
        self.scale_up()
        try:
            back = Image.new('RGBA', (int(self.width), int(self.height)), (255, 255, 255, 0))

            laid_out = [pos for pos in self.design.positions if 'x' in pos]
            min_width = min(pos['x'] for pos in laid_out)
            max_width = max(pos['x'] for pos in laid_out)
            max_height = max(pos['y'] for pos in self.design.positions if 'y' in pos)

            self.draw_lines(back, min_width, max_width, max_height)
            self.draw_dots(back, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(back, input_filename)
        finally:
            # restore the target geometry even if drawing failed
            self.scale_down()

        # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10)
        back.thumbnail((int(self.width), int(self.height)), Image.LANCZOS)

        back.save(filename)

    def add_dot(self, file, pos, style):
        """Composite a filled circle centred at *pos* onto *file*.

        *style* holds 'r' (radius before supersampling), 'color' (percent RGB)
        and 'opacity' (0..1).
        """
        x, y = int(pos[0]), int(pos[1])
        r = style['r'] * self.multi
        diameter = 2 * int(r)
        if diameter < 2:
            # a sub-pixel radius would yield the inverted box (1, 1, -1, -1),
            # which newer Pillow versions reject; there is nothing visible to draw
            return
        offset = (int(x - r), int(y - r))

        c = style['color']
        fill = (int(2.55 * c[0]), int(2.55 * c[1]), int(2.55 * c[2]), int(255 * style['opacity']))

        img = Image.new('RGBA', (diameter, diameter))
        ImageDraw.Draw(img).ellipse((1, 1, diameter - 1, diameter - 1), fill)
        file.paste(img, offset, mask=img)

    def add_line(self, file, from_pos, to_pos, style):
        """Composite a straight segment between two canvas points onto *file*.

        *style* holds 'width' (before supersampling), 'color' (percent RGB) and
        'opacity' (0..1). The segment is drawn on a small RGBA patch padded by
        the line width and pasted using its own alpha as the mask.
        """
        fx, fy, tx, ty = int(from_pos[0]), int(from_pos[1]), int(to_pos[0]), int(to_pos[1])
        # Scale before truncating so fractional widths survive supersampling, and
        # keep at least 1 px: a zero width gave horizontal/vertical segments a
        # zero-sized patch, and they were silently skipped.
        w = max(1, int(style['width'] * self.multi))

        offset = (min(fx, tx) - w, min(fy, ty) - w)
        size = (abs(fx - tx) + 2 * w, abs(fy - ty) + 2 * w)

        c = style['color']
        fill = (int(2.55 * c[0]), int(2.55 * c[1]), int(2.55 * c[2]), int(255 * style['opacity']))

        # use the patch diagonal whose direction matches the segment
        if (fx - tx) * (fy - ty) > 0:
            xy = (w, w, size[0] - w, size[1] - w)
        else:
            xy = (size[0] - w, w, w, size[1] - w)

        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).line(xy, fill, w)
        file.paste(img, offset, mask=img)

    def add_dashed_line(self, file, from_pos, to_pos):
        """Draw a thin black dashed line made of 50 dashes separated by equal gaps."""
        style = {'color': (0, 0, 0), 'width': 1, 'opacity': 1}
        sublines = 50
        # TODO could be faster: compute delta and only add delta each time (but currently we do not use it often)
        normdiv = 2 * sublines - 1
        for i in range(sublines):
            from_pos_sub = (self.compute_value(from_pos[0], to_pos[0], 2 * i / normdiv, 1),
                            self.compute_value(from_pos[1], to_pos[1], 2 * i / normdiv, 1))
            to_pos_sub = (self.compute_value(from_pos[0], to_pos[0], (2 * i + 1) / normdiv, 1),
                          self.compute_value(from_pos[1], to_pos[1], (2 * i + 1) / normdiv, 1))
            self.add_line(file, from_pos_sub, to_pos_sub, style)

    def add_text(self, file, text, pos, anchor, style=''):
        """Draw black *text* at *pos*; with an anchor other than "start" the text ends at *pos*.

        *style* is accepted for interface parity with ``SvgDrawer`` and ignored.
        """
        try:
            font = ImageFont.truetype("Vera.ttf", int(16 * self.multi))
        except OSError:
            # the bundled font is not found: fall back to Pillow's built-in font
            font = ImageFont.load_default()

        img = Image.new('RGBA', (int(self.width), int(self.height)))
        draw = ImageDraw.Draw(img)
        if anchor != "start":
            if hasattr(draw, "textbbox"):
                text_width = draw.textbbox((0, 0), text, font=font)[2]
            else:
                # Pillow < 8 has no textbbox (and textsize is gone in Pillow 10)
                text_width = draw.textsize(text, font=font)[0]
            pos = (pos[0] - text_width, pos[1])
        draw.text(pos, text, (0, 0, 0), font=font)
        file.paste(img, (0, 0), mask=img)

    def compute_line_style(self, parent, child):
        """Return the style dict for the segment leading to *child* (styled by the child)."""
        return {'color': self.compute_property('lines', 'color', child),
                'width': self.compute_property('lines', 'width', child),
                'opacity': self.compute_property('lines', 'opacity', child)}

    def compute_dot_style(self, node):
        """Return the style dict ('color', 'r', 'opacity') for the dot of *node*."""
        return {'color': self.compute_property('dots', 'color', node),
                'r': self.compute_property('dots', 'size', node),
                'opacity': self.compute_property('dots', 'opacity', node)}
317
class SvgDrawer(Drawer):
    """Vector renderer that writes the design as a standalone SVG document."""

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Render the design into the SVG file *filename*.

        :param input_filename: list of input file names, used for the caption.
        :param multi: accepted for interface parity with ``PngDrawer``; unused.
        :param scale: "SIMPLE" draws the caption and time-axis labels.
        """
        print("Drawing...")

        laid_out = [pos for pos in self.design.positions if 'x' in pos]
        min_width = min(pos['x'] for pos in laid_out)
        max_width = max(pos['x'] for pos in laid_out)
        max_height = max(pos['y'] for pos in self.design.positions if 'y' in pos)

        # 'with' closes the file even if drawing raises; UTF-8 is what XML
        # readers assume for a document without an encoding declaration
        with open(filename, "w", encoding="utf-8") as file:
            file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" '
                       'xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
                       'width="' + str(self.width) + '" height="' + str(self.height) + '">')

            self.draw_lines(file, min_width, max_width, max_height)
            self.draw_dots(file, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(file, input_filename)

            file.write("</svg>")

    def add_text(self, file, text, pos, anchor, style=''):
        """Write a <text> element; *text* is XML-escaped so names with '&' or '<' keep the SVG valid."""
        from xml.sax.saxutils import escape

        style = (style if style != '' else 'style="font-family: Arial; font-size: 12; fill: #000000;"')
        # the baseline is moved 12 units down assuming the default font size 12;
        # a custom style with another size is not taken into account
        file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) + '" y="' + str(pos[1]+12) + '" >' + escape(text) + '</text>')

    def add_dot(self, file, pos, style):
        """Write a <circle> centred at *pos*; *style* is a string of SVG attributes."""
        file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')

    def add_line(self, file, from_pos, to_pos, style):
        """Write a <line> between two points; *style* is a string of SVG attributes."""
        file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
                       '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')

    def add_dashed_line(self, file, from_pos, to_pos):
        """Write a thin black dashed <line>."""
        style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
        self.add_line(file, from_pos, to_pos, style)

    def compute_line_style(self, parent, child):
        """Return the stroke attributes for the segment leading to *child* (styled by the child)."""
        return self.compute_stroke_color('lines', child) + ' ' \
               + self.compute_stroke_width('lines', child) + ' ' \
               + self.compute_stroke_opacity(child)

    def compute_dot_style(self, node):
        """Return the radius, fill-opacity and fill attributes for the dot of *node*."""
        return self.compute_dot_size(node) + ' ' \
               + self.compute_fill_opacity(node) + ' ' \
               + self.compute_dot_fill(node)

    def compute_stroke_color(self, part, node):
        color = self.compute_property(part, 'color', node)
        return 'stroke="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'

    def compute_stroke_width(self, part, node):
        return 'stroke-width="' + str(self.compute_property(part, 'width', node)) + '"'

    def compute_stroke_opacity(self, node):
        return 'stroke-opacity="' + str(self.compute_property('lines', 'opacity', node)) + '"'

    def compute_fill_opacity(self, node):
        return 'fill-opacity="' + str(self.compute_property('dots', 'opacity', node)) + '"'

    def compute_dot_size(self, node):
        return 'r="' + str(self.compute_property('dots', 'size', node)) + '"'

    def compute_dot_fill(self, node):
        color = self.compute_property('dots', 'color', node)
        return 'fill="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'
385
class Designer:
    """Computes per-node measures and a 2-D layout for a genealogy tree.

    ``tree`` is expected to expose, per node index, ``children`` (list of child
    indices), ``parents`` (dict parent index -> inheritance weight), ``time``
    and ``kind``, plus a dict ``props`` of extra per-node property lists. Node 0
    is treated as the root. Measures are stored in ``self.props``; numeric ones
    are min-max normalized to [0, 1], with the raw extremes kept under
    ``<name>_min`` / ``<name>_max``.
    """

    def __init__(self, tree, jitter=False, time="GENERATIONAL", balance="DENSITY"):
        """
        :param tree: genealogy data (see the class docstring).
        :param jitter: add small Gaussian noise to horizontal positions.
        :param time: vertical axis: "GENERATIONAL" (one row per generation),
            "BIRTHS" (node index) or "REAL" (``tree.time``).
        :param balance: side-selection strategy for single-parent children:
            "RANDOM", "MIN" or "DENSITY".
        :raises ValueError: for an unknown *balance*.
        """
        self.props = {}

        self.tree = tree

        self.TIME = time
        self.JITTER = jitter

        strategies = {
            "RANDOM": self.xmin_crowd_random,
            "MIN": self.xmin_crowd_min,
            "DENSITY": self.xmin_crowd_density,
        }
        if balance not in strategies:
            raise ValueError("Error, the value of BALANCE does not match any expected value.")
        self.xmin_crowd = strategies[balance]

    def calculate_measures(self):
        """Compute all per-node measures (order matters: adepth needs maxdepth, progress needs time)."""
        print("Calculating measures...")
        self.compute_depth()
        self.compute_maxdepth()
        self.compute_adepth()
        self.compute_children()
        self.compute_kind()
        self.compute_time()
        self.compute_progress()
        self.compute_custom()

    def xmin_crowd_random(self, x1, x2, y):
        """Pick *x1* or *x2* with equal probability (the layout is ignored)."""
        return x1 if random.randrange(2) == 0 else x2

    def xmin_crowd_min(self, x1, x2, y):
        """Pick the candidate whose nearest placed node with y within +-3 of *y* is farther away.

        Ties, including the case of no such nodes, go to *x2*.
        """
        x1_closest = 999999
        x2_closest = 999999
        i = bisect.bisect_left(self.y_sorted, y - 3)
        while i < len(self.positions_sorted) and self.positions_sorted[i]['y'] <= y + 3:
            pos_x = self.positions_sorted[i]['x']
            x1_closest = min(x1_closest, abs(x1 - pos_x))
            x2_closest = min(x2_closest, abs(x2 - pos_x))
            i += 1
        return x1 if x1_closest > x2_closest else x2

    def xmin_crowd_density(self, x1, x2, y):
        """Pick *x1* or *x2* by comparing distances to nearby and to distant placed nodes.

        Placed nodes with y within +-CONST_WINDOW_SIZE of *y* are considered: all
        of them when fewer than 10 qualify, otherwise 10 random draws (with
        replacement). Nodes horizontally close to the candidates' midpoint form
        the "local" group, far ones the "global" group; *x1* is returned when
        the local per-node distance difference (x1 minus x2) exceeds the global one.
        """
        # TODO experimental - requires further work to make it less 'jumpy' and more predictable
        CONST_LOCAL_AREA_RADIUS = 5
        CONST_GLOBAL_AREA_RADIUS = 10
        CONST_WINDOW_SIZE = 20000 #TODO should depend on the maxY ?
        x1_dist_loc = 0
        x2_dist_loc = 0
        count_loc = 1
        x1_dist_glob = 0
        x2_dist_glob = 0
        count_glob = 1
        miny = y-CONST_WINDOW_SIZE
        maxy = y+CONST_WINDOW_SIZE
        i_left = bisect.bisect_left(self.y_sorted, miny)
        i_right = bisect.bisect_right(self.y_sorted, maxy)
        #TODO test: maxy=y should give the same results, right?

        def include_pos(pos):
            nonlocal x1_dist_loc, x2_dist_loc, x1_dist_glob, x2_dist_glob, count_loc, count_glob

            dysq = (pos['y']-y)**2 + 1 #+1 so 1/dysq is at most 1
            dx1 = math.fabs(pos['x']-x1)
            dx2 = math.fabs(pos['x']-x2)

            d = math.fabs(pos['x'] - (x1+x2)/2)

            if d < CONST_LOCAL_AREA_RADIUS:
                x1_dist_loc += math.sqrt(dx1/dysq + dx1**2)
                x2_dist_loc += math.sqrt(dx2/dysq + dx2**2)
                count_loc += 1
            elif d > CONST_GLOBAL_AREA_RADIUS:
                x1_dist_glob += math.sqrt(dx1/dysq + dx1**2)
                x2_dist_glob += math.sqrt(dx2/dysq + dx2**2)
                count_glob += 1

        # all nodes in the window are used if there are fewer than 10, otherwise 10 are sampled
        if len(self.positions_sorted) > i_left:
            if i_right - i_left < 10:
                for j in range(i_left, i_right):
                    include_pos(self.positions_sorted[j])
            else:
                for j in range(10):
                    pos = self.positions_sorted[random.randrange(i_left, i_right)]
                    include_pos(pos)

        return (x1 if (x1_dist_loc-x2_dist_loc)/count_loc-(x1_dist_glob-x2_dist_glob)/count_glob > 0  else x2)

    def calculate_node_positions(self, ignore_last=0):
        """Assign an {'id', 'x', 'y'} position to every node, stored in ``self.positions``.

        Requires ``calculate_measures`` (uses 'maxdepth', 'adepth', 'adepth_max').
        Nodes whose normalized adepth is below ``ignore_last / adepth_max`` keep
        an empty dict. A single-parent child is placed ``dissimilarity`` to the
        left or right of its parent (side chosen by ``self.xmin_crowd``); a
        multi-parent child at the inheritance-weighted mean x of its parents.
        """
        print("Calculating positions...")

        def add_node(node):
            # keep positions_sorted ordered by y so crowding queries can bisect on y_sorted
            index = bisect.bisect_left(self.y_sorted, node['y'])
            self.y_sorted.insert(index, node['y'])
            self.positions_sorted.insert(index, node)
            self.positions[node['id']] = node

        self.positions_sorted = [{'x': 0, 'y': 0, 'id': 0}]
        self.y_sorted = [0]
        self.positions = [{} for _ in range(len(self.tree.parents))]
        self.positions[0] = {'x': 0, 'y': 0, 'id': 0}

        # ordering by maximum depth guarantees that no child is placed before its parents
        visiting_order = sorted(range(len(self.tree.parents)),
                                key=lambda q: 0 if q == 0 else self.props["maxdepth"][q])

        # adepth is normalized, hence the division; an edgeless tree has
        # adepth_max == 0 and nothing to ignore
        adepth_max = self.props['adepth_max']
        threshold = ignore_last / adepth_max if adepth_max else 0

        start_time = timelib.time()

        for node_counter, child in enumerate(visiting_order, start=1):
            # progress report: percent done, node count, seconds since the last report
            if node_counter % 100000 == 0:
                print("%d%%\t%d\t%g" % (node_counter*100/len(self.tree.parents), node_counter, timelib.time()-start_time))
                start_time = timelib.time()

            if self.props['adepth'][child] < threshold:
                continue

            parents = self.tree.parents[child]

            ypos = 0
            if self.TIME == "BIRTHS":
                ypos = child
            elif self.TIME == "GENERATIONAL":
                # one row past the parent with the largest y
                ypos = max(self.positions[par]['y'] for par in parents) + 1 if parents else 0
            elif self.TIME == "REAL":
                ypos = self.tree.time[child]

            if len(parents) == 1:
                parent, similarity = next(iter(parents.items()))
                if self.JITTER:
                    dissimilarity = (1-similarity) + random.gauss(0, 0.01) + 0.001
                else:
                    dissimilarity = (1-similarity) + 0.001
                parent_x = self.positions[parent]['x']
                add_node({'id': child, 'y': ypos,
                          'x': self.xmin_crowd(parent_x - dissimilarity, parent_x + dissimilarity, ypos)})
            else:
                # position weighted by the degree of inheritance from each parent
                total_inheritance = sum(parents.values())
                xpos = sum([self.positions[k]['x']*v/total_inheritance for k, v in parents.items()])
                if self.JITTER:
                    xpos = xpos + random.gauss(0, 0.1)
                add_node({'id': child, 'y': ypos, 'x': xpos})

    def compute_custom(self):
        """Copy every property list of ``tree.props`` (first len(children) values) and normalize it."""
        node_count = len(self.tree.children)
        for prop in self.tree.props:
            self.props[prop] = [self.tree.props[prop][i] for i in range(node_count)]
            self.normalize_prop(prop)

    def compute_time(self):
        """Copy ``tree.time`` into the 'time' measure and normalize it."""
        self.props["time"] = [self.tree.time[i] for i in range(len(self.tree.children))]
        self.normalize_prop('time')

    def compute_kind(self):
        """Copy ``tree.kind`` as strings into the 'kind' measure (left unnormalized)."""
        self.props["kind"] = [str(self.tree.kind[i]) for i in range(len(self.tree.children))]

    def compute_depth(self):
        """Compute 'depth': shortest distance from the root (breadth-first search).

        Nodes unreachable from node 0 keep the sentinel 999999999.
        """
        from collections import deque

        node_count = len(self.tree.children)
        depth = [999999999] * node_count
        visited = [False] * node_count
        depth[0] = 0
        visited[0] = True
        queue = deque([0])
        while queue:
            node = queue.popleft()
            for child in self.tree.children[node]:
                if not visited[child]:
                    visited[child] = True
                    queue.append(child)
                    depth[child] = depth[node] + 1
        self.props["depth"] = depth

        self.normalize_prop('depth')

    def compute_maxdepth(self):
        """Compute 'maxdepth': longest distance from the root, by repeated relaxation.

        A node is re-queued whenever a longer path to it is found (unless it is
        already queued). Nodes unreachable from node 0 keep the sentinel 999999999.
        """
        from collections import deque

        node_count = len(self.tree.children)
        maxdepth = [999999999] * node_count
        visited = [False] * node_count
        maxdepth[0] = 0
        visited[0] = True
        queue = deque([0])
        queued = {0}  # mirrors the queue contents for O(1) membership tests
        while queue:
            node = queue.popleft()
            for child in self.tree.children[node]:
                candidate = maxdepth[node] + 1
                if not visited[child]:
                    visited[child] = True
                    maxdepth[child] = candidate
                    queue.append(child)
                    queued.add(child)
                elif maxdepth[child] < candidate:
                    maxdepth[child] = candidate
                    if child not in queued:
                        queue.append(child)
                        queued.add(child)
            # the node counts as queued until its children have been processed
            queued.discard(node)
        self.props["maxdepth"] = maxdepth

        self.normalize_prop('maxdepth')

    def compute_adepth(self):
        """Compute 'adepth': length of the longest path from a node down to a leaf (leaves get 0)."""
        adepth = [0] * len(self.tree.children)

        # reverse maxdepth order guarantees that every child is evaluated before its parent
        visiting_order = sorted(range(len(self.tree.parents)), key=lambda q: self.props["maxdepth"][q])[::-1]

        for node in visiting_order:
            children = self.tree.children[node]
            if children:
                adepth[node] = max(adepth[child] for child in children) + 1
        self.props["adepth"] = adepth
        self.normalize_prop('adepth')

    def compute_children(self):
        """Compute 'children': number of direct children of each node."""
        self.props["children"] = [len(kids) for kids in self.tree.children]
        self.normalize_prop('children')

    def compute_progress(self):
        """Compute 'progress': trend of the gaps between successive children's times.

        For nodes with more than four children, the sorted child times (the
        normalized 'time' measure, times 100000) are differenced and a linear
        regression is fitted to the gaps; its slope (0 if not finite) is stored.
        Then, five times, the current minimum and maximum entries are reset to 0
        (a crude outlier trim), every 0 is replaced by the remaining minimum, and
        the result is normalized. Requires 'time' to be computed.
        """
        progress = [0] * len(self.tree.children)
        for node in range(len(self.props['children'])):
            kids = self.tree.children[node]
            times = sorted(self.props["time"][kid] * 100000 for kid in kids)
            if len(times) > 4:
                gaps = [later - earlier for earlier, later in zip(times, times[1:])]
                slope, _intercept, _r_value, _p_value, _std_err = stats.linregress(range(len(gaps)), gaps)
                progress[node] = slope if np.isfinite(slope) else 0

        for _ in range(5):
            progress[progress.index(min(progress))] = 0
            progress[progress.index(max(progress))] = 0

        floor = min(progress)
        self.props['progress'] = [floor if value == 0 else value for value in progress]

        self.normalize_prop('progress')

    def normalize_prop(self, prop):
        """Min-max normalize, in place, the numeric entries of ``self.props[prop]`` to [0, 1].

        Non-numeric entries (None, strings, lists, ...) are left untouched and
        ignored when finding the extremes; if all numeric values are equal they
        become 0. The raw extremes are stored as ``<prop>_min`` / ``<prop>_max``.
        Nothing happens when there is no numeric entry.
        """
        import numbers

        values = self.props[prop]
        numeric = [v for v in values if isinstance(v, numbers.Real)]
        if not numeric:
            return
        max_val = max(numeric)
        min_val = min(numeric)
        print("%s: [%g, %g]" % (prop, min_val, max_val))
        self.props[prop + '_max'] = max_val
        self.props[prop + '_min'] = min_val
        for i, value in enumerate(values):
            if isinstance(value, numbers.Real):
                values[i] = 0 if max_val == min_val else (value - min_val) / (max_val - min_val)
682
class TreeData:
    """Genealogy read from Framsticks-style evolutionary log files.

    After load(), index 0 is the artificial 'virtual_parent' node and every
    individual gets the next free integer index (mapping in self.ids). All
    per-node containers are index-aligned and have len(self.ids) entries:
      parents[i]  -- dict {parent index: 'Inherited' share (1 if absent)};
                     nodes without a known parent point to the virtual parent
      children[i] -- list of child indices, derived from parents
      time[i]     -- 'Time' of the entry plus the cumulative offset of its file
      kind[i]     -- 'Kind' of the [OFFSPRING] entry, 0 if absent
      props[name] -- value of every other JSON field, 0 if absent
    """
    simple_data = None

    def __init__(self):
        # Instance-level containers; as class attributes these lists would be
        # shared by every TreeData object until load() rebinds them.
        self.parents = []
        self.children = []
        self.time = []
        self.kind = []
        self.life_lenght = []  # (sic) allocated and trimmed by load(), never filled in this class
        self.props = {}
        self.ids = {}

    def load(self, filenames, max_nodes=0):
        """Build the tree from one or more log files.

        Recognized lines look like '[TAG] {json}', optionally preceded by
        'Script.Message: '; TAG is [BORN], [OFFSPRING] or [DIED]. Files are
        read in order; times in each file are shifted by the largest time seen
        in the previous ones. Reading a file stops at its first line whose JSON
        cannot be parsed.

        :param filenames: list of input file paths (each is read twice: once to
            count entries, once to load them).
        :param max_nodes: stop after this many distinct individuals; 0 = no limit.
        :raises LoadingError: if an [OFFSPRING] entry lacks the 'FromIDs' field.
        """
        print("Loading...")

        CLI_PREFIX = "Script.Message:"
        default_props = ["Time", "FromIDs", "ID", "Operation", "Inherited"]

        merged_with_virtual_parent = []  # parent IDs that could not be found (replaced by the virtual parent)

        self.ids = {}

        def get_id(id, createOnError=True):
            """Return the index of external ID `id`.

            Unknown IDs get the next free index, growing the per-node arrays if
            the first pass undercounted; with createOnError=False they yield None.
            """
            if id not in self.ids:
                if not createOnError:
                    return None
                self.ids[id] = len(self.ids)
                if self.ids[id] >= len(self.parents):
                    self.parents.append({})
                    self.children.append([])
                    self.time.append(0)
                    self.kind.append(0)
                    self.life_lenght.append(0)
                    for values in self.props.values():
                        values.append(0)
            return self.ids[id]

        def split_entry(line):
            """Split a log line into [tag, payload], dropping an optional
            CLI_PREFIX; None if the line has no payload."""
            parts = line.split(' ', 1)
            if len(parts) == 2 and parts[0] == CLI_PREFIX:
                parts = parts[1].split(' ', 1)
            return parts if len(parts) == 2 else None

        def try_to_load(input):
            """Parse the JSON payload of a log line; None if it is malformed."""
            try:
                return json.loads(input)
            except ValueError:
                print("Json format error: the line cannot be read. Breaking the loading loop.")
                return None

        def load_creature_props(creature):
            # Store every non-default field under self.props[field][index].
            creature_id = get_id(creature["ID"])
            for prop in creature:
                if prop not in default_props:
                    if prop not in self.props:
                        self.props[prop] = [0] * len(self.parents)
                    self.props[prop][creature_id] = creature[prop]

        def load_born_props(creature):
            # Record the (offset) birth time and track the maximum time seen.
            nonlocal max_time
            creature_id = get_id(creature["ID"])
            if "Time" in creature:
                self.time[creature_id] = creature["Time"] + time_offset
                max_time = max(self.time[creature_id], max_time)

        def load_offspring_props(creature):
            # Link the individual to its parents and record its 'Kind'.
            creature_id = get_id(creature["ID"])
            if "FromIDs" not in creature:
                raise LoadingError("[OFFSPRING] misses the 'FromIDs' field!")
            # Each parent gets its contribution to the child's genotype. Parents
            # not seen so far are replaced by the virtual parent, which is created
            # before any entry is read and so has a lower index than the child.
            for i, parent in enumerate(creature["FromIDs"]):
                if parent in self.ids:
                    parent_id = get_id(parent)
                else:
                    if parent not in merged_with_virtual_parent:
                        merged_with_virtual_parent.append(parent)
                    parent_id = get_id("virtual_parent")
                inherited = creature["Inherited"][i] if 'Inherited' in creature else 1
                self.parents[creature_id][parent_id] = inherited

            if "Kind" in creature:
                self.kind[creature_id] = creature["Kind"]

        def limit_reached():
            return max_nodes != 0 and loaded_so_far >= max_nodes

        # First pass: count [BORN]/[OFFSPRING] entries to pre-size the arrays.
        nodes_born, nodes_offspring = 0, 0
        for filename in filenames:
            with open(filename) as file:
                for line in file:
                    entry = split_entry(line)
                    if entry is None:
                        continue
                    if entry[0] == "[BORN]":
                        nodes_born += 1
                    elif entry[0] == "[OFFSPRING]":
                        nodes_offspring += 1
        # Exact when every individual has the same kinds of entries (BORN and/or
        # OFFSPRING); if this undercounts, get_id() grows the arrays.
        nodes = max(nodes_born, nodes_offspring)
        nodes = min(nodes, max_nodes if max_nodes != 0 else nodes) + 1  # +1: virtual parent

        self.parents = [{} for _ in range(nodes)]
        self.children = [[] for _ in range(nodes)]
        self.time = [0] * nodes
        self.kind = [0] * nodes
        self.life_lenght = [0] * nodes
        self.props = {}

        print("nodes: %d" % len(self.parents))

        get_id("virtual_parent")  # always index 0
        loaded_so_far = 0
        max_time = 0

        # Second pass: load the entries.
        for filename in filenames:
            with open(filename) as file:
                time_offset = max_time
                if max_time != 0:
                    print("NOTE: merging files, assuming cumulative time offset for '%s' to be %d" % (filename, time_offset))

                for line in file:
                    entry = split_entry(line)
                    if entry is not None and entry[0] in ("[BORN]", "[OFFSPRING]", "[DIED]"):
                        tag = entry[0]
                        creature = try_to_load(entry[1])
                        if not creature:
                            break  # assumes only the last line of a file can be broken

                        if tag == "[BORN]":
                            if get_id(creature["ID"], False) is None:
                                loaded_so_far += 1
                            load_born_props(creature)
                            load_creature_props(creature)
                        elif tag == "[OFFSPRING]":
                            if get_id(creature["ID"], False) is None:
                                loaded_so_far += 1
                                # load time only if there was no [BORN] yet
                                load_born_props(creature)
                            load_offspring_props(creature)
                        else:  # [DIED]
                            if get_id(creature["ID"], False) is not None:
                                load_creature_props(creature)
                            else:
                                print("NOTE: encountered [DIED] entry for individual '%s' before it was [BORN] or [OFFSPRING]" % creature["ID"])

                    if limit_reached():
                        break
            if limit_reached():
                break

        print("NOTE: all individuals with parent not provided or missing were connected to a single 'virtual parent' node: " + str(merged_with_virtual_parent))

        # Drop slots reserved by the first pass but never used (e.g. for a broken
        # last line), so every array has exactly one entry per loaded node.
        count = len(self.ids)
        for values in [self.parents, self.children, self.time, self.kind, self.life_lenght] + list(self.props.values()):
            del values[count:]

        virtual_parent = get_id("virtual_parent")
        for c_id in range(1, count):
            if not self.parents[c_id]:
                self.parents[c_id][virtual_parent] = 1

        for child, parent_shares in enumerate(self.parents):
            for parent in parent_shares:
                self.children[parent].append(child)
870
# Two independent, initially empty module-level mappings.
# NOTE(review): neither TreeData (which keeps its own self.kind) nor main()
# reads these; confirm whether other code in the file still relies on them.
depth, kind = dict(), dict()
873
def main():
    """Command-line entry point: parse arguments, load the input log(s), lay
    out the tree and render it to the output file.

    An output name ending in '.svg' (any letter case) is rendered with
    SvgDrawer; any other extension goes to PngDrawer.
    """
    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship '
                                                 'information from a text file. Supports files generated by Framsticks experiments.')
    parser.add_argument('-i', '--in', nargs='+', dest='input', required=True, help='input file name with structured evolutionary data (or a list of input files)')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG/PNG/JPG/BMP)')
    parser.add_argument('-c', '--config', dest='config', default="", help='config file name')

    parser.add_argument('-W', '--width', default=600, type=int, dest='width', help='width of the output image (600 by default)')
    parser.add_argument('-H', '--height', default=800, type=int, dest='height', help='height of the output image (800 by default)')
    parser.add_argument('-m', '--multi', default=1, type=int, dest='multi', help='multisampling factor (applicable only for raster images)')

    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL(d)/REAL); '
                                                                      'BIRTHS: time measured as the number of births since the beginning; '
                                                                      'GENERATIONAL: time measured as number of ancestors; '
                                                                      'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing nodes in the tree (RANDOM/MIN/DENSITY(d))')
    # '(d)' marks the default value, which for --scale is SIMPLE.
    parser.add_argument('-s', '--scale', default='SIMPLE', dest='scale', help='type of timescale added to the tree (NONE/SIMPLE(d))')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')
    parser.add_argument('-p', '--skip', dest="skip", type=int, default=0, help='skip last P levels of the tree (0 by default)')
    parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    parser.add_argument('--seed', type=int, dest='seed', help='seed for the random number generator (-1 for random)')

    parser.set_defaults(draw_tree=True)
    parser.set_defaults(draw_skeleton=False)
    parser.set_defaults(draw_spine=False)

    parser.set_defaults(seed=-1)

    args = parser.parse_args()

    TIME = args.time.upper()
    BALANCE = args.balance.upper()
    SCALE = args.scale.upper()
    JITTER = args.jitter

    # Validate each enumerated option separately so the message names the culprit.
    for option, value, allowed in (('time', TIME, ['BIRTHS', 'GENERATIONAL', 'REAL']),
                                   ('balance', BALANCE, ['RANDOM', 'MIN', 'DENSITY']),
                                   ('scale', SCALE, ['NONE', 'SIMPLE'])):
        if value not in allowed:
            print("Incorrect value of the '%s' parameter: '%s' (allowed: %s)." % (option, value, '/'.join(allowed)))
            return

    input_files = args.input

    seed = args.seed
    if seed == -1:
        seed = random.randint(0, 10000)
    random.seed(seed)
    print("randomseed:", seed)

    tree = TreeData()
    tree.load(input_files, max_nodes=args.max_nodes)

    designer = Designer(tree, jitter=JITTER, time=TIME, balance=BALANCE)
    designer.calculate_measures()
    designer.calculate_node_positions(ignore_last=args.skip)

    drawer_class = SvgDrawer if args.output.lower().endswith(".svg") else PngDrawer
    drawer = drawer_class(designer, args.config, w=args.width, h=args.height)
    drawer.draw_design(args.output, args.input, multi=args.multi, scale=SCALE)
936
937
# Run only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.