Changeset 602
 Timestamp:
 08/28/16 17:00:05 (8 years ago)
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

mdsandtrees/mds_plot.py
r600 r602 4 4 import sys 5 5 import numpy as np 6 from sklearn import manifold 6 #from sklearn import manifold #was needed for manifold MDS http://scikitlearn.org/stable/auto_examples/manifold/plot_compare_methods.html 7 7 8 8 #to make it work in console, http://stackoverflow.com/questions/2801882/generatingapngwithmatplotlibwhendisplayisundefined … … 70 70 def read_file(fname, separator): 71 71 distances = np.genfromtxt(fname, delimiter=separator) 72 if np.isnan(distances[0][len(distances[0])1]):#separator after the last element in row 73 distances = np.array([row[:1] for row in distances]) 74 return distances 72 if (distances.shape[0]!=distances.shape[1]): 73 print("Matrix is not square:",distances.shape) 74 minsize = min(distances.shape[0],distances.shape[1]) 75 distances = np.array([row[:minsize] for row in distances]) #this can only fix matrices with more columns than rows 76 print("Making it square:",distances.shape) 77 78 try: #maybe the file has more columns than rows, and the extra column has labels? 79 labels = np.genfromtxt(fname, delimiter=separator, usecols=distances.shape[0],dtype=[('label','S10')]) 80 labels = [label[0].decode("utf8") for label in labels] 81 except ValueError: 82 labels = None #no labels 83 84 return distances,labels 75 85 76 86 … … 89 99 90 100 91 def plot(coordinates, dimensions, jitter=0, outname=""):101 def plot(coordinates, labels, dimensions, jitter=0, outname=""): 92 102 fig = plt.figure() 93 103 … … 102 112 y_dim = len(coordinates) 103 113 104 ax.scatter(*[add_jitter(coordinates[:, i]) for i in range(x_dim)], alpha=0.5) 114 points = [add_jitter(coordinates[:, i]) for i in range(x_dim)] 115 116 if labels is not None and dimensions==2: 117 ax.scatter(*points, alpha=0.1) #barely visible points, because we will show labels anyway 118 labelconvert={'vel':'V','vpp':'P','vpa':'A'} #use this if you want to replace long names with short IDs 119 #for point in points: 120 # print(point) 121 for label, x, y in zip(labels, points[0], points[1]): 122 #if label not in knownlabels: 123 # knownlabels.append(label) 124 # colors.append('#ff0000') 125 for key in labelconvert: 126 if label.startswith(key): 127 label=labelconvert[key] 128 plt.annotate( 129 label, 130 xy = (x, y), xytext = (0, 0), 131 textcoords = 'offset points', ha = 'center', va = 'center', 132 #bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5), 133 #arrowprops = dict(arrowstyle = '>', connectionstyle = 'arc3,rad=0') 134 ) 135 else: 136 ax.scatter(*points, alpha=0.5) 137 105 138 106 139 plt.title('Phenotypes distances') … … 118 151 def main(filename, dimensions=3, outname="", jitter=0, separator='\t'): 119 152 dimensions = int(dimensions) 120 distances = read_file(filename, separator)153 distances,labels = read_file(filename, separator) 121 154 embed = compute_mds(distances, dimensions) 122 155 … … 124 157 embed = np.array([np.insert(e, 0, 0, axis=0) for e in embed]) 125 158 126 plot(embed, dimensions, jitter, outname)159 plot(embed, labels, dimensions, jitter, outname) 127 160 128 161
Note: See TracChangeset
for help on using the changeset viewer.