1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
| import numpy as np
from matplotlib.widgets import PolygonSelector from matplotlib.path import Path
class SelectFromCollection(object):
    """Select indices from a matplotlib collection using `PolygonSelector`.

    Selected indices are saved in the `ind` attribute. This tool fades out the
    points that are not part of the selection (i.e., reduces their alpha
    values). If your collection has alpha < 1, this tool will permanently
    alter the alpha values.

    Note that this tool selects collection objects based on their *origins*
    (i.e., `offsets`).

    Parameters
    ----------
    ax : :class:`~matplotlib.axes.Axes`
        Axes to interact with.
    collection : :class:`matplotlib.collections.Collection` subclass
        Collection you want to select from.
    alpha_other : 0 <= float <= 1
        To highlight a selection, this tool sets all selected points to an
        alpha value of 1 and non-selected points to `alpha_other`.

    Raises
    ------
    ValueError
        If `alpha_other` is outside [0, 1], or if the collection has no
        face colors.
    """

    def __init__(self, ax, collection, alpha_other=0.3):
        # Validate before touching `ax`, so a bad value fails immediately
        # instead of surfacing later inside a selection callback.
        if not 0 <= alpha_other <= 1:
            raise ValueError('alpha_other must be between 0 and 1')
        self.canvas = ax.figure.canvas
        self.collection = collection
        self.alpha_other = alpha_other

        self.xys = collection.get_offsets()
        self.Npts = len(self.xys)

        # One RGBA row per point, so each point's alpha can be set on its own.
        self.fc = self._expand_facecolors(collection.get_facecolors(),
                                          self.Npts)

        self.poly = PolygonSelector(ax, self.onselect)
        self.ind = []

    @staticmethod
    def _expand_facecolors(fc, npts):
        """Return *fc* (an (k, 4) RGBA array) with exactly *npts* rows.

        An array that already has *npts* rows is returned as-is. Otherwise
        the rows are repeated in order (row ``i`` becomes ``fc[i % k]``) and
        truncated to *npts*. The original code only handled ``k == 1``; this
        also covers any ``k`` that differs from *npts*, which previously led
        to out-of-range row indexing during selection.

        Raises
        ------
        ValueError
            If *fc* has no rows.
        """
        if len(fc) == 0:
            raise ValueError('Collection must have a facecolor')
        if len(fc) == npts:
            return fc
        # np.resize repeats the flattened data. Because each row has the same
        # width, this repeats whole RGBA rows in order.
        return np.resize(fc, (npts, fc.shape[1]))

    def onselect(self, verts):
        """Highlight the points whose offsets lie inside polygon *verts*.

        Registered as the `PolygonSelector` callback. Stores the selected
        point indices in `self.ind`, sets their alpha to 1 and all other
        points' alpha to `alpha_other`, then requests a redraw.
        """
        path = Path(verts)
        self.ind = np.nonzero(path.contains_points(self.xys))[0]
        self.fc[:, -1] = self.alpha_other
        self.fc[self.ind, -1] = 1
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()

    def disconnect(self):
        """Stop reacting to selector events and reset every point's alpha to 1."""
        self.poly.disconnect_events()
        self.fc[:, -1] = 1
        self.collection.set_facecolors(self.fc)
        self.canvas.draw_idle()
if __name__ == '__main__':
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()

    # Lay out a side x side lattice of points. meshgrid + ravel yields the
    # same x order as tile(arange) and the same y order as repeat(arange).
    side = 5
    cols, rows = np.meshgrid(np.arange(side), np.arange(side))
    pts = ax.scatter(cols.ravel(), rows.ravel())

    selector = SelectFromCollection(ax, pts)

    usage_hints = (
        "Select points in the figure by enclosing them within a polygon.",
        "Press the 'esc' key to start a new polygon.",
        "Try holding the 'shift' key to move all of the vertices.",
        "Try holding the 'ctrl' key to move a single vertex.",
    )
    for hint in usage_hints:
        print(hint)

    plt.show()

    # The window has been closed: stop listening and restore opacity
    # before reporting what was chosen.
    selector.disconnect()

    print('\nSelected points:')
    print(selector.xys[selector.ind])
|