source: mds-and-trees/tree-genealogy.py @ 594

Last change on this file since 594 was 594, checked in by konrad, 8 years ago

Fixed drawing the spine and the skeleton for the tree-genealogy.py

File size: 19.7 KB
RevLine 
[562]1# Draws a genealogical tree (generates a SVG file) based on parent-child relationship information.
2
3import json
4import random
5import math
6import argparse
7
# ------ run-time options (main() overwrites these from the command line) ------
TIME = ""       # BIRTHS / GENERATIONAL / REAL
BALANCE = ""    # RANDOM / MIN / DENSITY
DOT_STYLE = ""  # NONE / TYPE / anything else draws the plain ("clear") dot
COLORING = ""   # NONE / TYPE / anything else colours lines by depth (IMPORTANCE)
JITTER = False  # boolean flag; the layout code tests it for truth

# ------SVG---------
# Handle of the output file; main() replaces the placeholder with an open file.
svg_file = 0

# Attribute strings pasted verbatim into the generated SVG elements.
svg_line_style = 'stroke="rgb(90%,10%,16%)" stroke-width="1" stroke-opacity="0.7"'
# The two styles below get their stroke colour appended by main().
svg_mutation_line_style = 'stroke-width="1"'
svg_crossover_line_style = 'stroke-width="1"'
svg_spine_line_style = 'stroke="rgb(0%,90%,40%)" stroke-width="2" stroke-opacity="1"'
svg_scale_line_style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'

svg_dot_style = 'r="2" stroke="black" stroke-width="0.2" fill="red"'
svg_clear_dot_style = 'r="2" stroke="black" stroke-width="0.4" fill="none"'
svg_spine_dot_style = 'r="1" stroke="black" stroke-width="0.2" fill="rgb(50%,50%,100%)"'

svg_scale_text_style = 'style="font-family: Arial; font-size: 12; fill: #000000;"'
29
def hex_to_style(hex):
    """Translate a hex colour into SVG stroke attributes.

    Accepts "RRGGBB" or "RRGGBBAA", optionally preceded by a single '#'.
    The colour channels are emitted as percentages; the alpha byte, when
    present, becomes the stroke opacity (otherwise 0.5 is used).

    Malformed input (empty string, wrong length, non-hex characters) never
    raises: a message is printed and the default black style is returned.

    Returns a string of the form
    ' stroke="rgb(R%,G%,B%)" stroke-opacity="A" ' (note the padding spaces).
    """
    default_style = ' stroke="black" stroke-opacity="0.5" '
    hex_digits = "0123456789abcdefABCDEF"

    # startswith() is safe on an empty string, unlike indexing hex[0].
    if hex.startswith("#"):
        hex = hex[1:]

    if len(hex) not in (6, 8):
        print("Wrong number of digits in the color's hex #" + hex + "! Assuming black.")
        return default_style

    # Check digit by digit: int(text, 16) alone also accepts forms such as
    # "0x1234" or "+12345", which would then fail when sliced into channels.
    if any(ch not in hex_digits for ch in hex):
        print("Wrong characters in the color's hex #" + hex + "! Assuming black.")
        return default_style

    red = 100*int(hex[0:2], 16)/255
    green = 100*int(hex[2:4], 16)/255
    blue = 100*int(hex[4:6], 16)/255
    opacity = int(hex[6:8], 16)/255 if len(hex) == 8 else 0.5
    return ' stroke="rgb(' + str(red) + '%,' + str(green) + '%,' + str(blue) + '%)" stroke-opacity="' + str(opacity) + '" '
[585]52
def svg_add_line(from_pos, to_pos, style=svg_line_style):
    """Write one SVG <line> element to svg_file.

    from_pos and to_pos are indexable pairs whose element 0 is x and element 1
    is y; style is a ready-made string of SVG attributes.
    """
    pieces = [
        '<line ', style,
        ' x1="', str(from_pos[0]),
        '" x2="', str(to_pos[0]),
        '" y1="', str(from_pos[1]),
        '" y2="', str(to_pos[1]),
        '"  fill="none"/>',
    ]
    svg_file.write("".join(pieces))
[562]56
def svg_add_text(text, pos, anchor, style=svg_scale_text_style):
    """Write one SVG <text> element to svg_file.

    text   -- the label; XML special characters (&, <, >) are escaped so that
              arbitrary strings such as file names cannot break the document
    pos    -- indexable pair, element 0 is x and element 1 is y
    anchor -- value for the text-anchor attribute (callers here pass
              "start" or "end")
    style  -- ready-made string of SVG attributes
    """
    # Local import keeps this block self-contained.
    from xml.sax.saxutils import escape

    svg_file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) + '" y="' + str(pos[1]) + '" >'
                   + escape(text) + '</text>')
59
def svg_add_dot(pos, style=svg_dot_style):
    """Write one SVG <circle> element centred at pos (x at index 0, y at index 1) to svg_file."""
    cx = str(pos[0])
    cy = str(pos[1])
    svg_file.write("".join(['<circle ', style, ' cx="', cx, '" cy="', cy, '" />']))
62
def svg_generate_line_style(percent):
    """Build the SVG stroke attributes for a line of relative importance `percent`.

    Opacity runs linearly from 0.2 to 1.0 and stroke width from 1 to 3 as
    `percent` goes from 0 to 1.  The colour is blended between the two ends of
    the palette using the eased value 1 - (1 - percent)**20.
    """
    # "hotdog" palette, channels in percent
    start_rgb = (100, 70, 0)
    end_rgb = (60, 0, 0)
    # other palettes that were tried:
    #   lava: (100, 80, 0)  -> (100, 0, 0)
    #   neon: (30, 200, 255) -> (240, 0, 220)

    def blend(a, b, t):
        # linear interpolation between a (t=0) and b (t=1)
        return a*(1-t) + b*t

    opacity = blend(0.2, 1.0, percent)
    width = blend(1, 3, percent)

    # the colour uses an eased copy of percent; opacity and width above do not
    eased = 1 - ((1-percent)**20)
    channels = [str(blend(lo, hi, eased)) for lo, hi in zip(start_rgb, end_rgb)]

    return ('stroke="rgb(' + '%,'.join(channels) + '%)" stroke-width="' + str(width)
            + '" stroke-opacity="' + str(opacity) + '"')
[562]87
def svg_generate_dot_style(kind):
    """Build the SVG attributes for a dot representing an individual of the given kind.

    kind -- integer index into an 8-colour palette; values outside the palette
            wrap around instead of raising IndexError.

    The radius is min(2500 / number of nodes, 10) and the outline width is a
    tenth of the radius.  An empty `nodes` dict is treated as a single node so
    the radius computation cannot divide by zero.
    """
    kinds = ["red", "lawngreen", "royalblue", "magenta", "yellow", "cyan", "white", "black"]

    r = min(2500/max(len(nodes), 1), 10)
    colour = kinds[kind % len(kinds)]

    return 'fill="' + colour + '" r="' + str(r) + '" stroke="black" stroke-width="' + str(r/10) + '" fill-opacity="1.0" ' \
           'stroke-opacity="1.0"'
[564]95
[562]96# -------------------
97
def load_data(dir):
    """Read a structured log file and fill the global genealogy tables.

    Only lines of the form "[OFFSPRING] <json>" are used.  For every such
    record that has a "FromIDs" list:
      * nodes[ID] becomes a dict mapping each parent ID to its contribution
        (currently always 1);
      * firstnode is set to the record's first parent whenever that parent
        has no entry of its own in nodes yet.
    "Time" and "Kind" fields, when present, are stored in the time / kind
    dicts.  Finally inv_nodes (parent -> list of children) is built from nodes.

    dir -- path of the input file (the name is kept for compatibility).

    A second record for an already known ID prints a message and terminates
    the program with exit status 1.
    """
    global firstnode, nodes, inv_nodes, time

    # `with` guarantees the input file is closed, also on errors.
    with open(dir) as f:
        for line in f:
            sline = line.split(' ', 1)
            if len(sline) != 2 or sline[0] != "[OFFSPRING]":
                continue

            creature = json.loads(sline[1])
            #print("B" +str(creature))

            if "FromIDs" in creature:
                if creature["ID"] in nodes:
                    print("Doubled entry for " + str(creature["ID"]))
                    # quit() is meant for interactive sessions; exit explicitly
                    # and with a non-zero status because this is an error.
                    raise SystemExit(1)

                parents = creature["FromIDs"]
                # we assign to each parent its contribution to the genotype of the child
                # (fixed at 1 for now; the "Inherited" field is not used yet)
                nodes[creature["ID"]] = {parent: 1 for parent in parents}

                # an empty FromIDs list has no first parent to inspect
                if parents and parents[0] not in nodes:
                    firstnode = parents[0]

            if "Time" in creature:
                time[creature["ID"]] = creature["Time"]

            if "Kind" in creature:
                kind[creature["ID"]] = creature["Kind"]

    for child, parents in sorted(nodes.items()):
        for parent in sorted(parents):
            inv_nodes.setdefault(parent, []).append(child)
[562]131
132
def load_simple_data(dir):
    """Read the simple "<child> <parent>" format and fill the global genealogy tables.

    A line with a single token names the root (firstnode).  A line with two or
    more tokens records a child and its parent; the parent number is clamped
    so that it is never smaller than the root's number.  Lines whose child is
    the root itself are ignored, and so are blank lines.

    nodes[child] is stored as {parent: 1}, the same shape load_data produces,
    because the layout and drawing code iterates nodes[c].items().

    dir -- path of the input file (the name is kept for compatibility).

    Raises ValueError when a child/parent line appears before any root line,
    or when an ID is not an integer.
    """
    global firstnode, nodes, inv_nodes

    with open(dir) as f:
        for line in f:
            sline = line.split()
            if not sline:
                # blank line: nothing to record
                continue

            if len(sline) == 1:
                firstnode = sline[0]
                continue

            #if int(sline[0]) > 15000:
            #    break
            if sline[0] == firstnode:
                continue
            if firstnode == "":
                raise ValueError("simple data: a '<child> <parent>' line appeared before the root line")

            parent = str(max(int(sline[1]), int(firstnode)))
            nodes[sline[0]] = {parent: 1}

    for child, parents in sorted(nodes.items()):
        for parent in sorted(parents):
            inv_nodes.setdefault(parent, []).append(child)
150
151    #print(str(inv_nodes))
152    #quit()
153
def compute_depth(node):
    """Return the length of the longest chain of descendants below `node`.

    The result (0 for a node without children in inv_nodes) is stored in the
    global depth dict for `node` and for every descendant reached.

    Values already present in depth are reused, so a descendant shared by
    several parents is evaluated once instead of once per path leading to it.
    """
    if node in depth:
        return depth[node]

    my_depth = max((compute_depth(c) + 1 for c in inv_nodes.get(node, ())), default=0)
    depth[node] = my_depth
    return my_depth
161
162# ------------------------------------
163
def xmin_crowd(x1, x2, y):
    """Choose which of two candidate x positions (x1 or x2) a new node at height y gets.

    The strategy is selected by the global BALANCE:
      RANDOM  -- pick either candidate with equal probability;
      MIN     -- among nodes already placed exactly at height y, pick the
                 candidate whose nearest neighbour is farther away;
      DENSITY -- among nodes placed within 10 units of y, pick the candidate
                 with the larger summed Euclidean distance to them.
    Ties go to x2.

    Raises ValueError for any other BALANCE value (previously None was
    returned and the failure surfaced later in the layout arithmetic).
    """
    if BALANCE == "RANDOM":
        return (x1 if random.randrange(2) == 0 else x2)

    if BALANCE == "MIN":
        x1_closest = 999999
        x2_closest = 999999
        for px, py in positions.values():
            if py == y:
                x1_closest = min(x1_closest, abs(x1-px))
                x2_closest = min(x2_closest, abs(x2-px))
        return (x1 if x1_closest > x2_closest else x2)

    if BALANCE == "DENSITY":
        x1_dist = 0
        x2_dist = 0
        for px, py in positions.values():
            # Both bounds must hold; the former `or` made this test always
            # true, so every placed node was counted regardless of height.
            if y-10 < py < y+10:
                dy = py-y
                x1_dist += math.sqrt(dy**2 + (px-x1)**2)
                x2_dist += math.sqrt(dy**2 + (px-x2)**2)
        return (x1 if x1_dist > x2_dist else x2)

    raise ValueError("Unknown balance method: " + repr(BALANCE) + " (expected RANDOM, MIN or DENSITY)")
189
190# ------------------------------------
191
def prepos_children_reccurent(node):
    """Assign [x, y] positions to the descendants of `node`, depth first.

    A child is handled exactly once: the first time it is met after all of its
    parents have been visited.  (Without the `c in visited` test a child with
    several parents could be re-positioned, and its whole subtree traversed
    again, each time it was met through another parent.)

    y comes from the global TIME mode:
      BIRTHS       -- the number in the child's ID (a leading "c" is dropped);
      GENERATIONAL -- y of `node` plus one;
      REAL         -- time[c].
    x for a single-parent child is the parent's x shifted by +/- dissimilarity
    (1, or a gaussian sample when JITTER is on), the side being chosen by
    xmin_crowd.  For a child with several parents x is the average of the
    parents' x weighted by the values stored in nodes[c], plus a small
    gaussian offset when JITTER is on.
    """
    global visited
    for c in inv_nodes[node]:

        # we want to visit the node just once, after all of its parents
        if c in visited or not all_parents_visited(c):
            continue
        visited[c] = True

        cy = 0
        if TIME == "BIRTHS":
            cy = int(c[1:]) if c[0] == "c" else int(c)
        elif TIME == "GENERATIONAL":
            cy = positions[node][1]+1
        elif TIME == "REAL":
            cy = time[c]

        parents = nodes[c]
        if len(parents) == 1:
            dissimilarity = random.gauss(0, 1) if JITTER else 1
            px = positions[node][0]
            positions[c] = [xmin_crowd(px-dissimilarity, px+dissimilarity, cy), cy]
        else:
            vsum = sum(parents.values())
            cx = sum(positions[k][0]*v/vsum for k, v in parents.items())
            if JITTER:
                cx += random.gauss(0, 0.1)
            positions[c] = [cx, cy]

        if c in inv_nodes:
            prepos_children_reccurent(c)
232
def prepos_children():
    """Compute positions for the whole tree and record the extent of the layout.

    Places firstnode at [0, 0], lets prepos_children_reccurent position the
    rest, then updates the globals max_height, max_width and min_width from
    all assigned positions.

    If REAL time was requested but no time data was loaded, the global TIME
    is switched to BIRTHS and a notice is printed.
    """
    # TIME must be declared global, otherwise the fallback below would only
    # bind a local name and have no effect on the layout.
    global max_height, max_width, min_width, visited, TIME

    if TIME == "REAL" and not time:
        print("REAL time requested, but no real time data provided. Assuming BIRTHS time instead.")
        TIME = "BIRTHS"

    positions[firstnode] = [0, 0]

    visited = {firstnode: True}
    prepos_children_reccurent(firstnode)

    for x, y in positions.values():
        max_height = max(max_height, y)
        max_width = max(max_width, x)
        min_width = min(min_width, x)
250
251# ------------------------------------
252
def all_parents_visited(node):
    """Tell whether every parent listed in nodes[node] has an entry in visited."""
    return all(parent in visited for parent, _ in sorted(nodes[node].items()))
260# ------------------------------------
261
def draw_children_recurrent(node, max_depth):
    """Draw the descendants of `node`: a line from every parent to each child, then the child's dot.

    A child is drawn exactly once: the first time it is met after all of its
    parents have been visited.  (Without the `c in visited` test a child with
    several parents could be drawn, together with its whole subtree, once per
    parent, stacking identical semi-transparent lines and dots.)

    The subtree of a child is drawn before the child itself.  Line style
    follows the global COLORING (NONE, TYPE, otherwise depth-based using
    depth[c]/max_depth); the dot follows DOT_STYLE (NONE draws no dot).

    max_depth -- largest value in depth, used to normalise depth[c].
    """
    global visited

    def canvas_xy(n):
        # map a layout position onto the drawing area inside the margins
        return (w_margin+w_no_margs*(positions[n][0]-min_width)/(max_width-min_width),
                h_margin+h_no_margs*positions[n][1]/max_height)

    for c in inv_nodes[node]:

        # we want to draw the node just once
        if c in visited or not all_parents_visited(c):
            continue
        visited[c] = True

        if c in inv_nodes:
            draw_children_recurrent(c, max_depth)

        if COLORING == "NONE":
            line_style = svg_line_style
        elif COLORING == "TYPE":
            line_style = (svg_mutation_line_style if len(nodes[c]) == 1 else svg_crossover_line_style)
        else: # IMPORTANCE, default
            line_style = svg_generate_line_style(depth[c]/max_depth)

        child_xy = canvas_xy(c)
        for parent in sorted(nodes[c]):
            svg_add_line(canvas_xy(parent), child_xy, line_style)

        if DOT_STYLE == "NONE":
            continue
        elif DOT_STYLE == "TYPE":
            dot_style = svg_generate_dot_style(kind[c] if c in kind else 0) #type
        else: # NORMAL, default
            dot_style = svg_clear_dot_style
        svg_add_dot(child_xy, dot_style)
def draw_children():
    """Draw the full tree: all lines and dots below firstnode, then the dot of firstnode itself."""
    global visited
    visited = {firstnode: True}

    # largest recorded depth, never below zero
    deepest = max([0] + list(depth.values()))
    draw_children_recurrent(firstnode, deepest)

    if DOT_STYLE == "NONE":
        return

    if DOT_STYLE == "TYPE":
        root_style = svg_generate_dot_style(kind[firstnode] if firstnode in kind else 0)
    else: # NORMAL, default
        root_style = svg_clear_dot_style

    root_x = w_margin+w_no_margs*(positions[firstnode][0]-min_width)/(max_width-min_width)
    root_y = h_margin+h_no_margs*positions[firstnode][1]/max_height
    svg_add_dot((root_x, root_y), root_style)
[562]313
def draw_spine_recurrent(node):
    """Draw spine lines from `node` to each child whose depth is exactly one less than the node's.

    A child is marked visited once all of its parents are; the recursion
    descends into such a child only when it lies on the spine.  The line from
    `node` to a spine child is drawn whether or not the child was descended into.
    """
    global visited
    for c in inv_nodes[node]:
        ready = all_parents_visited(c)
        if ready:
            # we want to draw the node just once
            visited[c] = True

        if depth[c] != depth[node] - 1:
            continue

        if ready and c in inv_nodes:
            draw_spine_recurrent(c)

        start = (w_margin+w_no_margs*(positions[node][0]-min_width)/(max_width-min_width),
                 h_margin+h_no_margs*positions[node][1]/max_height)
        end = (w_margin+w_no_margs*(positions[c][0]-min_width)/(max_width-min_width),
               h_margin+h_no_margs*positions[c][1]/max_height)
        svg_add_line(start, end, svg_spine_line_style)
def draw_spine():
    """Reset the visited table to contain only firstnode and draw the spine from there."""
    global visited
    visited = {firstnode: True}
    draw_spine_recurrent(firstnode)
338
def draw_skeleton_reccurent(node):
    """Draw skeleton lines from `node` to each child whose depth is at least min_skeleton_depth.

    A child is marked visited once all of its parents are; the recursion
    descends into such a child only when it belongs to the skeleton.  The line
    from `node` to a skeleton child is drawn whether or not the child was
    descended into.
    """
    global visited
    for c in inv_nodes[node]:
        ready = all_parents_visited(c)
        if ready:
            visited[c] = True

        if not depth[c] >= min_skeleton_depth:
            continue

        if ready and c in inv_nodes:
            draw_skeleton_reccurent(c)

        start = (w_margin+w_no_margs*(positions[node][0]-min_width)/(max_width-min_width),
                 h_margin+h_no_margs*positions[node][1]/max_height)
        end = (w_margin+w_no_margs*(positions[c][0]-min_width)/(max_width-min_width),
               h_margin+h_no_margs*positions[c][1]/max_height)
        svg_add_line(start, end, svg_spine_line_style)
def draw_skeleton():
    """Reset the visited table to contain only firstnode and draw the skeleton from there."""
    global visited
    visited = {firstnode: True}
    draw_skeleton_reccurent(firstnode)
365
[576]366# ------------------------------------
[562]367
def draw_scale(filename ,type):
    """Add the caption and the time scale to the drawing.

    filename -- path of the input file; only its last component is shown
                (both "\\" and "/" are treated as separators)
    type     -- scale type from the command line; "NONE" draws nothing, any
                other value draws the simple scale.  (The parameter name
                shadows the builtin but is kept for compatibility.)

    The scale consists of a dashed line near the top and one near the bottom
    of the drawing, each labelled with the smallest / largest value of the
    quantity selected by the global TIME: birth numbers taken from the IDs in
    nodes, values of time, or values of depth.  A label is left empty when no
    values are available or TIME is unrecognised.
    """
    if type == "NONE":
        return

    def birth_number(node_id):
        # same ID convention as the layout code: optional leading "c", then the number
        return int(node_id[1:]) if node_id[0] == "c" else int(node_id)

    svg_add_text( "Generated from " + filename.replace("\\", "/").split("/")[-1], (5, 15), "start")

    if TIME == "BIRTHS":
        label, values = "Birth #", [birth_number(k) for k in nodes]
    elif TIME == "REAL":
        label, values = "Time ", list(time.values())
    elif TIME == "GENERATIONAL":
        label, values = "Depth ", list(depth.values())
    else:
        label, values = "", []

    start_text = (label + str(min(values))) if values else ""
    end_text = (label + str(max(values))) if values else ""

    svg_add_line( (w*0.7, h_margin), (w, h_margin), svg_scale_line_style)
    svg_add_text( start_text, (w, h_margin + 15), "end")

    svg_add_line( (w*0.7, h-h_margin), (w, h-h_margin), svg_scale_line_style)
    svg_add_text( end_text, (w, h-h_margin + 15), "end")
[576]391
392
[562]393##################################################### main #####################################################
394
# Parsed command-line arguments; main() replaces the placeholder.
args = 0

# Canvas size and margins, and the drawable area left inside the margins.
h, w = 800, 600
h_margin, w_margin = 20, 10
h_no_margs = h - 2*h_margin
w_no_margs = w - 2*w_margin

# Extent of the computed layout; updated by prepos_children().
max_height, max_width = 0, 0
min_width = 9999999999

# Overwritten by main() from --min-skeleton-depth.
min_skeleton_depth = 0

# Genealogy tables filled by the loaders and the layout code.
firstnode = ""   # ID of the node the traversals start from
nodes = {}       # child ID -> its parents
inv_nodes = {}   # parent ID -> list of child IDs
positions = {}   # node ID -> [x, y]
visited = {}     # node IDs already handled by the current traversal
depth = {}       # node ID -> value computed by compute_depth()
time = {}        # node ID -> "Time" field of its record
kind = {}        # node ID -> "Kind" field of its record
[562]418
def _parse_bool_flag(value):
    """Convert a command-line word to a bool (argparse `type=` callback).

    An empty string counts as False, matching what bool("") gave before.
    Raises argparse.ArgumentTypeError for words that are not recognised.
    """
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got " + repr(value))


def main():
    """Parse the command line, load the genealogy, lay it out and write the SVG file.

    Stores the chosen options in the module-level settings, seeds the random
    generator (printing the seed so a run can be repeated), loads the input in
    the requested format, computes depths and positions, and then writes the
    requested layers (tree, skeleton, spine, scale) to the output file.
    """
    global svg_file, min_skeleton_depth, args, \
        TIME, BALANCE, DOT_STYLE, COLORING, JITTER, \
        svg_mutation_line_style, svg_crossover_line_style

    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child '
                                                 'relationship information.')
    parser.add_argument('-i', '--in', dest='input', required=True, help='input file with stuctured evolutionary data')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file for the evolutionary tree')
    draw_tree_parser = parser.add_mutually_exclusive_group(required=False)
    draw_tree_parser.add_argument('--draw-tree', dest='draw_tree', action='store_true', help='whether the full tree should be drawn')
    draw_tree_parser.add_argument('--no-draw-tree', dest='draw_tree', action='store_false')

    draw_skeleton_parser = parser.add_mutually_exclusive_group(required=False)
    draw_skeleton_parser.add_argument('--draw-skeleton', dest='draw_skeleton', action='store_true', help='whether the skeleton of the tree should be drawn')
    draw_skeleton_parser.add_argument('--no-draw-skeleton', dest='draw_skeleton', action='store_false')

    draw_spine_parser = parser.add_mutually_exclusive_group(required=False)
    draw_spine_parser.add_argument('--draw-spine', dest='draw_spine', action='store_true', help='whether the spine of the tree should be drawn')
    draw_spine_parser.add_argument('--no-draw-spine', dest='draw_spine', action='store_false')

    #TODO: better names for those parameters
    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL/REAL); '
                                                                      'BIRTHS: time measured as the number of births since the beggining; '
                                                                      'GENERATIONAL: time measured as number of ancestors; '
                                                                      'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing node in the tree (RANDOM/MIN/DENSITY)')
    parser.add_argument('-s', '--scale', default='NONE', dest='scale', help='type of timescale added to the tree (NONE/SIMPLE)')
    parser.add_argument('-c', '--coloring', default='IMPORTANCE', dest="coloring", help='method of coloring the tree (NONE/IMPORTANCE/TYPE)')
    parser.add_argument('-d', '--dots', default='TYPE', dest='dots', help='method of drawing dots (individuals) (NONE/NORMAL/TYPE)')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')

    parser.add_argument('--color-mut', default="#000000", dest="color_mut", help='color of clone/mutation lines in rgba (e.g. #FF60B240) for TYPE coloring')
    parser.add_argument('--color-cross', default="#660198", dest="color_cross", help='color of crossover lines in rgba (e.g. #FF60B240) for TYPE coloring')

    parser.add_argument('--min-skeleton-depth', type=int, default=2, dest='min_skeleton_depth', help='minimal distance from the leafs for the nodes in the skeleton')
    parser.add_argument('--seed', type=int, dest='seed', help='seed for the random number generator (-1 for random)')

    # type=bool would turn every non-empty word, including "False", into True.
    # The flag may still be followed by a value (as before) or be given bare.
    parser.add_argument('--simple-data', type=_parse_bool_flag, nargs='?', const=True, default=False, dest='simple_data',
                        help='input data are given in a simple format (#child #parent)')

    parser.set_defaults(draw_tree=True, draw_skeleton=False, draw_spine=False, seed=-1)

    args = parser.parse_args()

    TIME = args.time
    BALANCE = args.balance
    DOT_STYLE = args.dots
    COLORING = args.coloring
    JITTER = args.jitter

    svg_mutation_line_style += hex_to_style(args.color_mut)
    svg_crossover_line_style += hex_to_style(args.color_cross)

    dir = args.input
    min_skeleton_depth = args.min_skeleton_depth
    seed = args.seed
    if seed == -1:
        seed = random.randint(0, 10000)
    random.seed(seed)
    print("seed:", seed)

    if args.simple_data:
        load_simple_data(dir)
    else:
        load_data(dir)

    compute_depth(firstnode)

    # The layout writes nothing, so do it before the output file is opened:
    # a failure here then leaves an existing output file untouched.
    prepos_children()

    # svg_file is declared global above, so the drawing helpers see this
    # handle; `with` closes it even if drawing fails.
    with open(args.output, "w") as svg_file:
        svg_file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
                       'width="' + str(w) + '" height="' + str(h) + '">')

        if args.draw_tree:
            draw_children()
        if args.draw_skeleton:
            draw_skeleton()
        if args.draw_spine:
            draw_spine()

        draw_scale(dir, args.scale)

        svg_file.write("</svg>")
507
# Run only when executed as a script, so that importing the module
# does not parse the command line or write files.
if __name__ == "__main__":
    main()
509
Note: See TracBrowser for help on using the repository browser.