diff --git a/src/cedalion/plots.py b/src/cedalion/plots.py new file mode 100644 index 0000000..b1f25f8 --- /dev/null +++ b/src/cedalion/plots.py @@ -0,0 +1,25 @@ +import matplotlib.pyplot as p +import xarray as xr + + +def plot_montage3D(amp: xr.DataArray, geo3d: xr.DataArray): + f = p.figure() + ax = f.add_subplot(projection="3d") + colors = ["r", "b", "gray"] + sizes = [20, 20, 2] + for i, (type, x) in enumerate(geo3d.groupby("type")): + ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=colors[i], s=sizes[i]) + + for i in range(amp.sizes["channel"]): + src = geo3d.loc[amp.source[i], :] + det = geo3d.loc[amp.detector[i], :] + ax.plot([src[0], det[0]], [src[1], det[1]], [src[2], det[2]], c="k") + + # if available mark Nasion in yellow + if "Nz" in geo3d.label: + ax.scatter( + geo3d.loc["Nz", 0], geo3d.loc["Nz", 1], geo3d.loc["Nz", 2], c="y", s=25 + ) + + ax.view_init(elev=30, azim=145) + p.tight_layout()