source: mds-and-trees/mds_plot.py @ 568

Last change on this file since 568 was 565, checked in by oriona, 8 years ago

Added a script for plotting the results of MDS (multidimensional scaling).

File size: 2.7 KB
Line 
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3import sys
4import numpy as np
5from sklearn import manifold
6import matplotlib.pyplot as plt
7from mpl_toolkits.mplot3d import Axes3D
8from matplotlib import cm
9import argparse
10
def rand_jitter(arr):
    """Return ``arr`` plus Gaussian noise drawn from NumPy's global RNG.

    One standard-normal sample is drawn per element and scaled by
    ``2 * arr.max() / 100``, so the jitter grows with the largest value.
    The input array is not modified; a new array is returned.
    """
    scale = arr.max() / 100.
    noise = np.random.randn(len(arr))
    return arr + noise * scale * 2
14
15
def read_file(fname, separator):
    """Load a dissimilarity matrix from a delimited text file.

    Parameters
    ----------
    fname : str or file-like
        Path or open text stream accepted by ``numpy.genfromtxt``.
    separator : str
        Column delimiter (e.g. a tab or a comma).

    Returns
    -------
    numpy.ndarray
        Always 2-D, even for a single-row file. If every row ends with a
        separator, ``genfromtxt`` yields an extra all-NaN column, which is
        dropped.
    """
    # atleast_2d: genfromtxt returns a 1-D array for a single-row file.
    distances = np.atleast_2d(np.genfromtxt(fname, delimiter=separator))
    # A separator after the last element of each row produces an empty (NaN)
    # trailing column. Only drop it when the whole column is NaN, so a real
    # missing value in the first row does not throw away a data column.
    if distances.shape[1] > 1 and np.isnan(distances[:, -1]).all():
        distances = distances[:, :-1]
    return distances
21
22
def compute_mds(distance_matrix, dim):
    """Embed a precomputed dissimilarity matrix with metric MDS.

    ``dim`` may be an int or a numeric string; it is passed through ``int``.
    A fixed ``RandomState(3)`` makes repeated runs reproducible. Returns the
    fitted model's ``embedding_`` array.
    """
    rng = np.random.RandomState(seed=3)
    model = manifold.MDS(
        n_components=int(dim),
        metric=True,
        max_iter=3000,
        eps=1e-9,
        random_state=rng,
        dissimilarity="precomputed",
    )
    model.fit(distance_matrix)
    return model.embedding_
28
29
def compute_variances(embed):
    """Return the cumulative share of total variance per embedding dimension.

    Parameters
    ----------
    embed : array-like
        2-D array with one row per point and one column per dimension.

    Returns
    -------
    list of float
        Element ``i`` is the fraction of the total variance carried by
        dimensions ``0..i``; the last element is 1.0. When the total variance
        is zero (all points identical) the share is undefined and every entry
        is NaN.
    """
    embed = np.asarray(embed, dtype=float)
    # Column-wise variance, then running totals in one pass instead of
    # re-summing every prefix.
    variances = np.var(embed, axis=0)
    cumulative = np.cumsum(variances)
    total = cumulative[-1] if cumulative.size else 0.0
    if total == 0:
        # Avoid a 0/0 RuntimeWarning; the proportion is undefined here.
        return [float('nan')] * len(variances)
    return (cumulative / total).tolist()
36
37
def plot(coordinates, dimensions, jitter=0, outname=""):
    """Scatter-plot embedding coordinates on a 2-D or 3-D axis.

    Parameters
    ----------
    coordinates : array-like
        One row per point, one column per embedding dimension. A single
        column is padded with a leading zero column so it can be drawn on a
        2-D axis (all points then lie on x = 0).
    dimensions : int or str
        Values below 3 give a 2-D plot; 3 or more give a 3-D plot. Only the
        first two or three columns are drawn, respectively.
    jitter : int or str
        1 (or the string "1" as delivered by the command line) adds random
        jitter from ``rand_jitter`` to every plotted coordinate.
    outname : str
        File name without extension. Empty shows the plot interactively;
        otherwise the figure is saved as ``<outname>.pdf`` and closed.
    """
    coordinates = np.asarray(coordinates)
    if coordinates.ndim == 1:
        coordinates = coordinates[:, np.newaxis]
    if coordinates.shape[1] == 1:
        coordinates = np.insert(coordinates, 0, 0, axis=1)

    # Extra columns must not reach scatter(): a 2-D Axes treats a third
    # positional array as marker sizes, Axes3D treats a fourth as ``zdir``.
    n_axes = 3 if int(dimensions) >= 3 else 2

    fig = plt.figure()
    if n_axes == 3:
        ax = fig.add_subplot(111, projection='3d')
    else:
        ax = fig.add_subplot(111)

    columns = [coordinates[:, i] for i in range(min(n_axes, coordinates.shape[1]))]
    if jitter == 1 or jitter == "1":
        columns = [rand_jitter(col) for col in columns]

    ax.scatter(*columns, alpha=0.5)

    plt.title('Phenotypes distances')
    plt.tight_layout()
    plt.axis('tight')

    if not outname:
        plt.show()
    else:
        plt.savefig(outname + ".pdf")
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
62
63
def main(filename, dimensions=3, outname="", jitter=0, separator='\t'):
    """Run MDS on a dissimilarity file, print variance shares, and plot.

    Parameters
    ----------
    filename : str
        Path of the delimited dissimilarity matrix.
    dimensions : int or str
        Number of MDS output dimensions (numeric strings are accepted).
    outname : str
        Plot file name without extension; empty shows the plot on screen.
    jitter : int or str
        1 adds random jitter to the plotted points; numeric strings are
        accepted.
    separator : str
        Column delimiter of the input file.
    """
    dimensions = int(dimensions)
    jitter = int(jitter)

    distances = read_file(filename, separator)
    embed = compute_mds(distances, dimensions)

    variances_perc = compute_variances(embed)
    for i, vc in enumerate(variances_perc):
        print(i + 1, "dimension:", vc)

    if dimensions == 1:
        # A 1-D embedding has no second axis; prepend a zero x-coordinate
        # so every point is drawn on the line x = 0.
        embed = np.insert(embed, 0, 0, axis=1)

    # Forward the output options so --out and --j actually take effect.
    plot(embed, dimensions, jitter=jitter, outname=outname)
77
78
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--in', dest='input', required=True, help='input file with dissimilarity matrix')
    parser.add_argument('--out', dest='output', required=False, default="", help='output file name without extension')
    # type=int so numeric options arrive as numbers, not strings
    # (a string "1" would never equal the jitter flag value 1).
    parser.add_argument('--dim', required=False, type=int, default=3, help='number of dimensions of the new space')
    parser.add_argument('--sep', required=False, default="\t", help='separator of the source file')
    parser.add_argument('--j', required=False, type=int, default=0, help='for j=1 random jitter is added to the plot')

    args = parser.parse_args()
    # A tab is awkward to type on a shell; accept the two-character "\t" too.
    separator = '\t' if args.sep == '\\t' else args.sep
    main(args.input, args.dim, args.output, args.j, separator)
Note: See TracBrowser for help on using the repository browser.