source: mds-and-trees/tree-genealogy.py @ 612

Last change on this file since 612 was 594, checked in by konrad, 8 years ago

Fixed drawing the spine and the skeleton for the tree-genealogy.py

File size: 19.7 KB
Line 
1# Draws a genealogical tree (generates a SVG file) based on parent-child relationship information.
2
3import json
4import random
5import math
6import argparse
7
# ------ run-time options ------
# All four are placeholders here; main() overwrites them from the command line.
TIME = ""       # vertical axis: BIRTHS / GENERATIONAL / REAL
BALANCE = ""    # horizontal placement of a child: RANDOM / MIN / DENSITY
DOT_STYLE = ""  # dots for individuals: NONE / TYPE / anything else means NORMAL
JITTER = ""     # main() stores the boolean --jitter flag here

# ------ SVG output ------
# Placeholder; main() replaces it with the opened output file object.
svg_file = 0

# Line styles (raw SVG attribute strings pasted into <line> elements).
svg_line_style = ('stroke="rgb(90%,10%,16%)" '
                  'stroke-width="1" stroke-opacity="0.7"')
# main() appends the stroke colour/opacity to the next two.
svg_mutation_line_style = 'stroke-width="1"'
svg_crossover_line_style = 'stroke-width="1"'
svg_spine_line_style = ('stroke="rgb(0%,90%,40%)" '
                        'stroke-width="2" stroke-opacity="1"')
svg_scale_line_style = ('stroke="black" stroke-width="0.5" '
                        'stroke-opacity="1" stroke-dasharray="5, 5"')

# Dot styles (raw SVG attribute strings pasted into <circle> elements).
svg_dot_style = 'r="2" stroke="black" stroke-width="0.2" fill="red"'
svg_clear_dot_style = 'r="2" stroke="black" stroke-width="0.4" fill="none"'
svg_spine_dot_style = ('r="1" stroke="black" stroke-width="0.2" '
                       'fill="rgb(50%,50%,100%)"')

# Text style used for the scale labels.
svg_scale_text_style = 'style="font-family: Arial; font-size: 12; fill: #000000;"'
29
def hex_to_style(hex):
    """Convert an RRGGBB or RRGGBBAA hex colour into SVG stroke attributes.

    hex -- colour string, optionally prefixed with "#". With 8 digits the
           last two are the opacity (00..FF); with 6 digits opacity is 0.5.

    Returns a string of the form
    ' stroke="rgb(R%,G%,B%)" stroke-opacity="A" ' (note the surrounding
    spaces, so it can be appended to another attribute string).
    On malformed input a message is printed and a black, half-transparent
    default style is returned instead.
    """
    default_style = ' stroke="black" stroke-opacity="0.5" '

    # startswith() is safe for an empty string, unlike hex[0].
    if hex.startswith("#"):
        hex = hex[1:]

    if len(hex) not in (6, 8):
        print("Wrong number of digits in the color's hex #" + hex + "! Assuming black.")
        return default_style

    # Check every character explicitly: int(hex, 16) on the whole string also
    # accepts forms such as "0x12ab" or "+12345", which then fail when the
    # individual two-character channels are parsed.
    if any(ch not in "0123456789abcdefABCDEF" for ch in hex):
        print("Wrong characters in the color's hex #" + hex + "! Assuming black.")
        return default_style

    red = 100*int(hex[0:2], 16)/255
    green = 100*int(hex[2:4], 16)/255
    blue = 100*int(hex[4:6], 16)/255
    opacity = 0.5
    if len(hex) == 8:
        opacity = int(hex[6:8], 16)/255
    return ' stroke="rgb(' +str(red)+ '%,' +str(green)+ '%,' +str(blue)+ '%)" stroke-opacity="' +str(opacity)+ '" '
52
def svg_add_line(from_pos, to_pos, style=svg_line_style):
    """Write one <line> element to the global svg_file.

    from_pos, to_pos -- (x, y) pairs; only indices 0 and 1 are read.
    style            -- raw SVG attribute string inserted into the tag.
    """
    start_x, start_y = from_pos[0], from_pos[1]
    end_x, end_y = to_pos[0], to_pos[1]
    element = ''.join((
        '<line ', style,
        ' x1="', str(start_x),
        '" x2="', str(end_x),
        '" y1="', str(start_y),
        '" y2="', str(end_y),
        '"  fill="none"/>',
    ))
    svg_file.write(element)
56
def svg_add_text(text, pos, anchor, style=svg_scale_text_style):
    """Write one <text> element to the global svg_file.

    text   -- the label; it is XML-escaped before being written, because the
              callers pass free text (e.g. the input file name) and a raw
              "&" or "<" would make the SVG invalid.
    pos    -- (x, y) pair; only indices 0 and 1 are read.
    anchor -- value for the text-anchor attribute (e.g. "start", "end").
    style  -- raw SVG attribute string inserted into the tag.
    """
    from xml.sax.saxutils import escape

    svg_file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) + '" y="' + str(pos[1]) + '" >' + escape(text) + '</text>')
59
def svg_add_dot(pos, style=svg_dot_style):
    """Write one <circle> element to the global svg_file.

    pos   -- (x, y) centre; only indices 0 and 1 are read.
    style -- raw SVG attribute string inserted into the tag.
    """
    centre_x, centre_y = pos[0], pos[1]
    element = ''.join((
        '<circle ', style,
        ' cx="', str(centre_x),
        '" cy="', str(centre_y),
        '" />',
    ))
    svg_file.write(element)
62
def svg_generate_line_style(percent):
    """Build the SVG stroke attributes for a line of given importance.

    percent -- number interpolating between the "from" end (0) and the
               "to" end (1) of the palette.

    Stroke width and opacity are interpolated linearly; the colour uses
    1 - (1 - percent)**20 as its interpolation factor instead.
    Returns a string 'stroke="rgb(R%,G%,B%)" stroke-width="W" stroke-opacity="O"'.
    """
    def blend(start, end, amount):
        # Linear interpolation, written exactly as start*(1-t) + end*t.
        return start*(1-amount) + end*amount

    # Active palette: "hotdog".
    # Alternatives that were tried: "lava" [100, 80, 0] -> [100, 0, 0],
    # "neon" [30, 200, 255] -> [240, 0, 220].
    colour_start = [100, 70, 0]
    colour_end = [60, 0, 0]

    opacity = blend(0.2, 1.0, percent)
    width = blend(1, 3, percent)

    colour_amount = 1 - ((1-percent)**20)
    channels = [str(blend(lo, hi, colour_amount))
                for lo, hi in zip(colour_start, colour_end)]

    return ('stroke="rgb(' + '%,'.join(channels) + '%)"'
            + ' stroke-width="' + str(width) + '"'
            + ' stroke-opacity="' + str(opacity) + '"')
87
def svg_generate_dot_style(kind):
    """Build the SVG attributes for a dot representing an individual.

    kind -- integer index into the colour table; values outside the table
            wrap around (kind % 8), so an unexpected kind no longer raises
            IndexError. Negative indices select the same colour as before.

    The radius shrinks as the global `nodes` dict grows (2500/len(nodes)),
    capped at 10; an empty `nodes` yields the cap instead of dividing by zero.
    """
    kinds = ["red", "lawngreen", "royalblue", "magenta", "yellow", "cyan", "white", "black"]

    r = min(2500/max(len(nodes), 1), 10)
    colour = kinds[kind % len(kinds)]

    return 'fill="' + colour + '" r="' + str(r) + '" stroke="black" stroke-width="' + str(r/10) + '" fill-opacity="1.0" ' \
           'stroke-opacity="1.0"'
95
96# -------------------
97
def load_data(dir):
    """Load parent-child data from a log file into the module-level dicts.

    dir -- path of the input file. Only lines of the form
           "[OFFSPRING] <json object>" are used; all others are ignored.

    For each such object:
      * if it has "FromIDs", nodes[ID] becomes {parent_id: 1, ...}
        (every parent currently gets the same contribution, 1);
      * firstnode is set to the first parent whenever that parent has not
        itself been recorded as a child so far;
      * "Time" and "Kind", when present, are stored in time[ID] / kind[ID].
    Finally inv_nodes (parent -> list of children) is filled from nodes.

    Exits the program with status 1 if the same ID appears twice.
    """
    global firstnode, nodes, inv_nodes, time
    # "with" guarantees the file is closed, also when parsing fails.
    with open(dir) as f:
        for line in f:
            sline = line.split(' ', 1)
            if len(sline) != 2 or sline[0] != "[OFFSPRING]":
                continue

            creature = json.loads(sline[1])
            if "FromIDs" in creature:
                if creature["ID"] in nodes:
                    print("Doubled entry for " + str(creature["ID"]))
                    # quit() is only injected by the site module; raise directly
                    # and report failure through the exit status.
                    raise SystemExit(1)

                # we assign to each parent its contribution to the genotype of the child
                nodes[creature["ID"]] = {parent: 1 for parent in creature["FromIDs"]}

                # An empty parent list carries no information about the root.
                if creature["FromIDs"] and creature["FromIDs"][0] not in nodes:
                    firstnode = creature["FromIDs"][0]

            if "Time" in creature:
                time[creature["ID"]] = creature["Time"]

            if "Kind" in creature:
                kind[creature["ID"]] = creature["Kind"]

    for child in sorted(nodes):
        for parent in sorted(nodes[child]):
            inv_nodes.setdefault(parent, []).append(child)
131
132
def load_simple_data(dir):
    """Load data given in the simple "#child #parent" text format.

    dir -- path of the input file. A line with a single token sets
           firstnode (the root); a line with two or more tokens is an edge
           "child parent ...". Blank lines are skipped. Edges whose child is
           the current root are ignored, and a parent number lower than the
           root's is replaced by the root.

    nodes[child] is stored as {parent: 1}, the same shape load_data()
    produces, because the positioning and drawing code iterates
    nodes[c].items(). inv_nodes (parent -> list of children) is filled
    from nodes afterwards.
    """
    global firstnode, nodes, inv_nodes
    with open(dir) as f:
        for line in f:
            sline = line.split()
            if not sline:
                # blank line: nothing to read
                continue
            if len(sline) == 1:
                firstnode = sline[0]
                continue

            child = sline[0]
            #if int(child) > 15000:
            #    break
            if child == firstnode:
                continue

            parent = int(sline[1])
            # Clamp to the root only once a root is known; int("") would fail.
            if firstnode != "":
                parent = max(parent, int(firstnode))
            nodes[child] = {str(parent): 1}

    for child in sorted(nodes):
        for parent in sorted(nodes[child]):
            inv_nodes.setdefault(parent, []).append(child)
150
151    #print(str(inv_nodes))
152    #quit()
153
def compute_depth(node):
    """Compute, for `node` and everything below it, the distance to the
    furthest descendant leaf, store it in the global `depth` dict and
    return the value for `node`.

    A node without children (absent from inv_nodes) has depth 0; otherwise
    its depth is 1 + the largest depth among its children.

    The traversal is iterative and computes every node once, so deep trees
    do not hit the recursion limit and nodes reachable through several
    parents are not re-evaluated for each path.

    Raises ValueError if the parent-child relation contains a cycle.
    """
    computed = {}
    in_progress = set()
    # Each entry is (node, children_done): a node is pushed once to expand
    # its children and once more to be finalised after them.
    stack = [(node, False)]
    while stack:
        current, children_done = stack.pop()
        if children_done:
            best = 0
            for child in inv_nodes.get(current, ()):
                best = max(best, computed[child] + 1)
            computed[current] = best
            depth[current] = best
            in_progress.discard(current)
        elif current not in computed:
            if current in in_progress:
                # Reached again from one of its own descendants.
                raise ValueError("Cycle in parent-child data at node " + str(current))
            in_progress.add(current)
            stack.append((current, True))
            for child in inv_nodes.get(current, ()):
                if child not in computed:
                    stack.append((child, False))
    return computed[node]
161
162# ------------------------------------
163
def xmin_crowd(x1, x2, y):
    """Choose between two candidate x positions for a node placed at height y.

    The strategy is selected by the global BALANCE:
      RANDOM  -- pick x1 or x2 with equal probability;
      MIN     -- among already placed nodes with exactly the same y, pick
                 the candidate whose nearest neighbour is further away;
      DENSITY -- pick the candidate with the larger summed distance to the
                 already placed nodes whose y lies within 10 units of y.
    Ties go to x2.

    Raises ValueError for any other BALANCE value (previously None was
    returned and the failure surfaced later, far from its cause).
    """
    if BALANCE == "RANDOM":
        return (x1 if random.randrange(2) == 0 else x2)

    if BALANCE == "MIN":
        x1_closest = 999999
        x2_closest = 999999
        for pos in positions.values():
            if pos[1] == y:
                x1_closest = min(x1_closest, abs(x1-pos[0]))
                x2_closest = min(x2_closest, abs(x2-pos[0]))
        return (x1 if x1_closest > x2_closest else x2)

    if BALANCE == "DENSITY":
        x1_dist = 0
        x2_dist = 0
        for pos in positions.values():
            dy = pos[1]-y
            # Only nodes in the open band (y-10, y+10) count. The original
            # test joined the two bounds with "or", which is true for every
            # node and made the band meaningless.
            if -10 < dy < 10:
                x1_dist += math.hypot(dy, pos[0]-x1)
                x2_dist += math.hypot(dy, pos[0]-x2)
        return (x1 if x1_dist > x2_dist else x2)

    raise ValueError("Unknown balance method: " + str(BALANCE))
189
190# ------------------------------------
191
def prepos_children_reccurent(node):
    """Assign [x, y] positions to the descendants of `node`, depth first.

    A child is handled only when all of its parents are already marked in
    `visited`; it is then marked itself, positioned, and its own children
    are processed recursively.

    y depends on the global TIME: the number embedded in the child's ID
    (BIRTHS), the y of `node` plus one (GENERATIONAL), time[child] (REAL),
    or 0 for any other value.
    x: a child with exactly one parent is put to the left or right of
    `node` as decided by xmin_crowd(); otherwise x is the average of the
    parents' x weighted by their contributions in nodes[child].
    """
    global visited
    for child in inv_nodes[node]:

        # we want to visit the node just once, after all of its parents
        if not all_parents_visited(child):
            continue
        visited[child] = True

        # --- vertical coordinate ---
        if TIME == "BIRTHS":
            # IDs may carry a leading "c" before the birth number
            child_y = int(child[1:]) if child[0] == "c" else int(child)
        elif TIME == "GENERATIONAL":
            child_y = positions[node][1]+1
        elif TIME == "REAL":
            child_y = time[child]
        else:
            child_y = 0

        # --- horizontal coordinate ---
        contributions = nodes[child]
        if len(contributions) == 1:
            offset = random.gauss(0,1) if JITTER == True else 1
            left_x = positions[node][0]-offset
            right_x = positions[node][0]+offset
            positions[child] = [xmin_crowd(left_x, right_x, child_y), child_y]
        else:
            total = sum([share for parent, share in contributions.items()])
            child_x = sum([positions[parent][0]*share/total for parent, share in contributions.items()])
            if JITTER == True:
                child_x = child_x + random.gauss(0, 0.1)
            positions[child] = [child_x, child_y]

        if child in inv_nodes:
            prepos_children_reccurent(child)
232
def prepos_children():
    """Compute positions for the whole tree and the extents of the layout.

    Places firstnode at [0, 0], positions all of its descendants with
    prepos_children_reccurent(), and then updates the globals max_height,
    max_width and min_width from the resulting positions.

    If REAL time was requested but no time data was loaded, a message is
    printed and the global TIME is switched to BIRTHS.
    """
    # TIME must be declared global: without it the assignment below only
    # created a local variable and the fallback never took effect.
    global max_height, max_width, min_width, visited, TIME

    # Warn only when REAL time is what was actually asked for.
    if TIME == "REAL" and not time:
        print("REAL time requested, but no real time data provided. Assuming BIRTHS time instead.")
        TIME = "BIRTHS"

    positions[firstnode] = [0, 0]

    visited = {firstnode: True}
    prepos_children_reccurent(firstnode)

    for pos in positions.values():
        max_height = max(max_height, pos[1])
        max_width = max(max_width, pos[0])
        min_width = min(min_width, pos[0])
250
251# ------------------------------------
252
def all_parents_visited(node):
    """Return True when every parent listed in nodes[node] is a key of the
    global `visited` dict (also True when the node has no parents)."""
    return all(parent in visited for parent, _ in sorted(nodes[node].items()))
260# ------------------------------------
261
def draw_children_recurrent(node, max_depth):
    """Draw the descendants of `node`: the lines to their parents and,
    unless DOT_STYLE is "NONE", a dot for each of them.

    node      -- node whose children (inv_nodes[node]) are processed.
    max_depth -- divisor used to normalise depth[c] for IMPORTANCE coloring.

    A child is drawn only when all of its parents are in `visited`. Its own
    descendants are drawn before it, so it is emitted after them in the file.
    """
    global visited

    def to_canvas(key):
        # Map the layout position of `key` into SVG coordinates.
        x = w_margin+w_no_margs*(positions[key][0]-min_width)/(max_width-min_width)
        y = h_margin+h_no_margs*positions[key][1]/max_height
        return (x, y)

    for child in inv_nodes[node]:

        # we want to draw the node just once
        if not all_parents_visited(child):
            continue
        visited[child] = True

        if child in inv_nodes:
            draw_children_recurrent(child, max_depth)

        if COLORING == "NONE":
            line_style = svg_line_style
        elif COLORING == "TYPE":
            # one parent: clone/mutation line, otherwise crossover line
            if len(nodes[child]) == 1:
                line_style = svg_mutation_line_style
            else:
                line_style = svg_crossover_line_style
        else: # IMPORTANCE, default
            line_style = svg_generate_line_style(depth[child]/max_depth)

        for parent, _ in sorted(nodes[child].items()):
            svg_add_line(to_canvas(parent), to_canvas(child), line_style)

        if DOT_STYLE == "NONE":
            continue
        if DOT_STYLE == "TYPE":
            dot_style = svg_generate_dot_style(kind.get(child, 0)) #type
        else: # NORMAL, default
            dot_style = svg_clear_dot_style
        svg_add_dot(to_canvas(child), dot_style)
def draw_children():
    """Draw the full tree: every line and dot below firstnode, then the dot
    of firstnode itself (skipped when DOT_STYLE is "NONE")."""
    global visited
    visited = {firstnode: True}

    # Largest value stored in depth, never less than 0.
    deepest = max([0] + list(depth.values()))
    draw_children_recurrent(firstnode, deepest)

    if DOT_STYLE == "NONE":
        return
    if DOT_STYLE == "TYPE":
        root_style = svg_generate_dot_style(kind.get(firstnode, 0))
    else: # NORMAL, default
        root_style = svg_clear_dot_style

    root = positions[firstnode]
    root_x = w_margin+w_no_margs*(root[0]-min_width)/(max_width-min_width)
    root_y = h_margin+h_no_margs*root[1]/max_height
    svg_add_dot((root_x, root_y), root_style)
313
def draw_spine_recurrent(node):
    """Draw spine lines from `node` to each child whose depth is exactly
    depth[node] - 1, following those children recursively.

    A child is marked in `visited`, and descended into, only once all of
    its parents are visited; the line from `node` to a qualifying child is
    drawn regardless of that.
    """
    global visited

    def to_canvas(key):
        # Map the layout position of `key` into SVG coordinates.
        x = w_margin+w_no_margs*(positions[key][0]-min_width)/(max_width-min_width)
        y = h_margin+h_no_margs*positions[key][1]/max_height
        return (x, y)

    for child in inv_nodes[node]:

        # we want to descend into the node just once
        ready = all_parents_visited(child)
        if ready:
            visited[child] = True

        on_spine = depth[child] == depth[node] - 1

        if ready and on_spine and child in inv_nodes:
            draw_spine_recurrent(child)

        if on_spine:
            svg_add_line(to_canvas(node), to_canvas(child), svg_spine_line_style)
def draw_spine():
    """Reset `visited` to contain only firstnode and draw the spine,
    starting from firstnode."""
    global visited
    visited = {firstnode: True}

    draw_spine_recurrent(firstnode)
338
def draw_skeleton_reccurent(node):
    """Draw skeleton lines from `node` to each child whose depth is at least
    the global min_skeleton_depth, following those children recursively.

    A child is marked in `visited`, and descended into, only once all of
    its parents are visited; the line from `node` to a qualifying child is
    drawn regardless of that.
    """
    global visited

    def to_canvas(key):
        # Map the layout position of `key` into SVG coordinates.
        x = w_margin+w_no_margs*(positions[key][0]-min_width)/(max_width-min_width)
        y = h_margin+h_no_margs*positions[key][1]/max_height
        return (x, y)

    for child in inv_nodes[node]:

        ready = all_parents_visited(child)
        if ready:
            visited[child] = True

        in_skeleton = depth[child] >= min_skeleton_depth

        if ready and in_skeleton and child in inv_nodes:
            draw_skeleton_reccurent(child)

        if in_skeleton:
            svg_add_line(to_canvas(node), to_canvas(child), svg_spine_line_style)
def draw_skeleton():
    """Reset `visited` to contain only firstnode and draw the skeleton,
    starting from firstnode."""
    global visited
    visited = {firstnode: True}

    draw_skeleton_reccurent(firstnode)
365
366# ------------------------------------
367
def draw_scale(filename ,type):
    """Draw the time scale: a caption with the input file name and two dashed
    marker lines (top and bottom) labelled with the first and last value of
    the vertical axis.

    filename -- path of the input file; only its last component is shown.
    type     -- scale type from the command line; "NONE" draws nothing
                (the parameter used to be ignored, so the scale appeared
                even when none was requested).

    The labels depend on the global TIME: birth numbers taken from the node
    IDs (BIRTHS), values of `time` (REAL) or values of `depth`
    (GENERATIONAL). When there is no data for the chosen axis the labels
    are left empty.
    """
    if type == "NONE":
        return

    # Accept both "\" and "/" as path separators.
    base_name = filename.replace("\\", "/").split("/")[-1]
    svg_add_text( "Generated from " + base_name, (5, 15), "start")

    def birth_number(node_id):
        # Same rule as the positioning code: an optional leading "c".
        return int(node_id[1:]) if node_id[0] == "c" else int(node_id)

    if TIME == "BIRTHS":
        label, values = "Birth #", [birth_number(k) for k in nodes]
    elif TIME == "REAL":
        label, values = "Time ", list(time.values())
    elif TIME == "GENERATIONAL":
        label, values = "Depth ", list(depth.values())
    else:
        label, values = "", []

    start_text = label + str(min(values)) if values else ""
    end_text = label + str(max(values)) if values else ""

    svg_add_line( (w*0.7, h_margin), (w, h_margin), svg_scale_line_style)
    svg_add_text( start_text, (w, h_margin + 15), "end")

    svg_add_line( (w*0.7, h-h_margin), (w, h-h_margin), svg_scale_line_style)
    svg_add_text( end_text, (w, h-h_margin + 15), "end")
391
392
393##################################################### main #####################################################
394
# Parsed command-line arguments; placeholder until main() runs.
args = 0

# Canvas size and margins. The *_no_margs values are the drawable area.
w = 600
h = 800
w_margin = 10
h_margin = 20
w_no_margs = w - 2*w_margin
h_no_margs = h - 2*h_margin

# Extents of the layout, updated by prepos_children().
min_width = 9999999999
max_width = 0
max_height = 0

# Threshold used by the skeleton drawing; main() overwrites it.
min_skeleton_depth = 0

# Tree data shared by all functions.
firstnode = ""   # ID of the root
nodes = {}       # child -> its parents
inv_nodes = {}   # parent -> list of its children
positions = {}   # node -> [x, y]
visited = {}     # nodes already handled by the current traversal
depth = {}       # node -> value computed by compute_depth()
time = {}        # node -> "Time" value from the input
kind = {}        # node -> "Kind" value from the input
418
def _parse_bool(value):
    """Convert a command-line string such as "true"/"no"/"1" to a bool."""
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got '" + value + "'")


def main():
    """Parse the command line, load the data, lay the tree out and write
    the SVG file.

    Stores the chosen options in the module-level globals read by the
    layout and drawing functions, seeds the random generator (printing the
    seed used) and calls the drawing steps selected on the command line.
    """
    global svg_file, min_skeleton_depth, args, \
        TIME, BALANCE, DOT_STYLE, COLORING, JITTER, \
        svg_mutation_line_style, svg_crossover_line_style

    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship information.')
    parser.add_argument('-i', '--in', dest='input', required=True, help='input file with stuctured evolutionary data')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file for the evolutionary tree')
    draw_tree_parser = parser.add_mutually_exclusive_group(required=False)
    draw_tree_parser.add_argument('--draw-tree', dest='draw_tree', action='store_true', help='whether the full tree should be drawn (this is the default)')
    draw_tree_parser.add_argument('--no-draw-tree', dest='draw_tree', action='store_false')

    draw_skeleton_parser = parser.add_mutually_exclusive_group(required=False)
    draw_skeleton_parser.add_argument('--draw-skeleton', dest='draw_skeleton', action='store_true', help='whether the skeleton of the tree should be drawn')
    draw_skeleton_parser.add_argument('--no-draw-skeleton', dest='draw_skeleton', action='store_false')

    draw_spine_parser = parser.add_mutually_exclusive_group(required=False)
    draw_spine_parser.add_argument('--draw-spine', dest='draw_spine', action='store_true', help='whether the spine of the tree should be drawn')
    draw_spine_parser.add_argument('--no-draw-spine', dest='draw_spine', action='store_false')

    #TODO: better names for those parameters
    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL/REAL); '
                                                                      'BIRTHS: time measured as the number of births since the beggining; '
                                                                      'GENERATIONAL: time measured as number of ancestors; '
                                                                      'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing node in the tree (RANDOM/MIN/DENSITY)')
    parser.add_argument('-s', '--scale', default='NONE', dest='scale', help='type of timescale added to the tree (NONE/SIMPLE)')
    parser.add_argument('-c', '--coloring', default='IMPORTANCE', dest="coloring", help='method of coloring the tree (NONE/IMPORTANCE/TYPE)')
    parser.add_argument('-d', '--dots', default='TYPE', dest='dots', help='method of drawing dots (individuals) (NONE/NORMAL/TYPE)')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')

    parser.add_argument('--color-mut', default="#000000", dest="color_mut", help='color of clone/mutation lines in rgba (e.g. #FF60B240) for TYPE coloring')
    parser.add_argument('--color-cross', default="#660198", dest="color_cross", help='color of crossover lines in rgba (e.g. #FF60B240) for TYPE coloring')

    parser.add_argument('--min-skeleton-depth', type=int, default=2, dest='min_skeleton_depth', help='minimal distance from the leafs for the nodes in the skeleton')
    parser.add_argument('--seed', type=int, dest='seed', help='seed for the random number generator (-1 for random)')

    # type=bool would turn every non-empty string, including "False", into
    # True. Parse the value explicitly and also allow the bare flag.
    parser.add_argument('--simple-data', type=_parse_bool, nargs='?', const=True, default=False, dest='simple_data',
                        help='input data are given in a simple format (#child #parent)')

    parser.set_defaults(draw_tree=True)
    parser.set_defaults(draw_skeleton=False)
    parser.set_defaults(draw_spine=False)

    parser.set_defaults(seed=-1)

    args = parser.parse_args()

    TIME = args.time
    BALANCE = args.balance
    DOT_STYLE = args.dots
    COLORING = args.coloring
    JITTER = args.jitter

    svg_mutation_line_style += hex_to_style(args.color_mut)
    svg_crossover_line_style += hex_to_style(args.color_cross)

    input_path = args.input
    min_skeleton_depth = args.min_skeleton_depth
    seed = args.seed
    if seed == -1:
        seed = random.randint(0, 10000)
    random.seed(seed)
    print("seed:", seed)

    if args.simple_data:
        load_simple_data(input_path)
    else:
        load_data(input_path)

    compute_depth(firstnode)

    svg_file = open(args.output, "w")
    try:
        svg_file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
                       'width="' + str(w) + '" height="' + str(h) + '">')

        prepos_children()

        if args.draw_tree:
            draw_children()
        if args.draw_skeleton:
            draw_skeleton()
        if args.draw_spine:
            draw_spine()

        draw_scale(input_path, args.scale)

        svg_file.write("</svg>")
    finally:
        # Close the output even when layout or drawing fails part-way.
        svg_file.close()


# Run only when executed as a script, so the module can be imported safely.
if __name__ == "__main__":
    main()
509
Note: See TracBrowser for help on using the repository browser.