Source code for spacekit.analyzer.explore

# STANDARD and third-party libraries
import os
import numpy as np
import pandas as pd
from scipy.stats import iqr

# LOCAL spacekit modules
from spacekit.preprocessor.transform import PowerX
from spacekit.generator.augment import augment_image
from spacekit.logger.log import Logger

try:
    from keras.preprocessing.image import array_to_img
except ImportError:
    from tensorflow.keras.utils import array_to_img

try:
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    font_dict = {"family": "monospace", "size": 16}
    mpl.rc("font", **font_dict)
    styles = ["seaborn-bright", "seaborn-v0_8-bright"]
    valid_styles = [s for s in styles if s in plt.style.available]
    if len(valid_styles) > 0:
        try:
            plt.style.use(valid_styles[0])
        except OSError:
            pass
except ImportError:
    mpl = None
    plt = None

try:
    import plotly.graph_objects as go
    from plotly import subplots
    import plotly.offline as pyo
    import plotly.figure_factory as ff
    import plotly.express as px
except ImportError:
    go = None
    subplots = None
    pyo = None
    ff = None
    px = None


def check_viz_imports():
    return go is not None


class ImagePreviews:
    """Base parent class for rendering and displaying images as plots"""

    def __init__(self, X, labels, name="ImagePreviews", **log_kws):
        self.__name__ = name
        self.log = Logger(self.__name__, **log_kws).spacekit_logger()
        self.X = X
        self.y = labels
        if not check_viz_imports():
            self.log.error("plotly and/or matplotlib not installed.")
            raise ImportError(
                "You must install plotly (`pip install plotly`) "
                "and matplotlib<4 (`pip install matplotlib<4`) "
                "for the compute module to work."
                "\n\nInstall extra deps via `pip install spacekit[x]`"
            )
class SVMPreviews(ImagePreviews):
    """ImagePreviews subclass for previewing SVM images, primarily used to compare
    original images with their augmented versions.

    Parameters
    ----------
    ImagePreviews : class
        spacekit.analyzer.explore.ImagePreviews parent class
    """

    def __init__(
        self,
        X,
        labels=None,
        names=None,
        ndims=3,
        channels=3,
        w=128,
        h=128,
        figsize=(10, 10),
        **log_kws,
    ):
        """Instantiates an SVMPreviews class object.

        Parameters
        ----------
        X : ndarray
            ndimensional array of image pixel values
        labels : ndarray, optional
            target class labels for each image
        ndims : int, optional
            number of dimensions (frames) per image, by default 3
        channels : int, optional
            channels per image frame (rgb color is 3, gray/bw is 1), by default 3
        w : int, optional
            width of images, by default 128
        h : int, optional
            height of images, by default 128
        """
        super().__init__(X, labels, name="SVMPreviews", **log_kws)
        self.names = names
        self.n_images = len(X)
        self.ndims = ndims
        self.channels = channels
        self.w = w
        self.h = h
        self.figsize = figsize
    def select_image_from_array(self, i=None):
        if i is None:
            return self.X
        else:
            return self.X[i]
    def check_dimensions(self, Xi):
        if Xi.shape != (self.ndims, self.w, self.h, self.channels):
            try:
                Xi = Xi.reshape(self.ndims, self.w, self.h, self.channels)
            except Exception as e:
                print(e)
        return Xi
    def preview_image(self, Xi, dim=3, aug=False, show=False):
        if aug is True:
            # reshape handled by augment if needed
            Xi = augment_image(Xi)
            title = "Augmented"
        else:
            Xi = self.check_dimensions(Xi)
            title = "Original"
        frames = ["orig", "pt-seg", "gaia"]
        fig = px.imshow(
            Xi,
            facet_col=0,
            binary_string=True,
            labels={"facet_col": "frame"},
            facet_col_wrap=3,
        )
        for i, frame in enumerate(frames):
            fig.layout.annotations[i]["text"] = "%s" % frame
        fig.update_layout(
            title_text=f"{title} Image Slices",
            margin=dict(t=100),
            width=990,
            height=500,
            showlegend=False,
            paper_bgcolor="#242a44",
            plot_bgcolor="#242a44",
            font={
                "color": "#ffffff",
            },
        )
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)
        if show is True:
            fig.show()
        else:
            return fig
    def preview_image_mpl(self, Xi, dim=3, aug=False, show=False):
        if aug is True:
            # reshape handled by augment if needed
            Xi = augment_image(Xi)
        else:
            Xi = self.check_dimensions(Xi)
        fig = plt.figure(figsize=self.figsize)
        for n in range(dim):
            xi = array_to_img(Xi[n])
            ax = plt.subplot(dim, dim, n + 1)
            ax.imshow(xi)
            plt.axis("off")
        if show is True:
            plt.show()
        else:
            plt.close()
            return fig
    def get_synthetic_image(self, img_name, show=False, dim=3, aug=False):
        pairs = [i for i in self.names if img_name in i]
        if len(pairs) > 1:
            synth_name = pairs[np.argmax([len(p.split("_")) for p in pairs])]
            synth_num = np.where(self.names == synth_name)
            synth_img = self.select_image_from_array(synth_num)
            if show is True:
                self.preview_image(synth_img, dim=dim, aug=aug)
            return synth_name, synth_num, synth_img
        else:
            print("Synthetic version not found for the selected image")
            return None
    def preview_og_aug_pair(self, i=None, dim=3):
        """Plot frames of both original and augmented versions of n-dimensional images

        Parameters
        ----------
        i : int, optional
            index of image selected from array X, by default None
        dim : int, optional
            dimensions (number of frames per image), by default 3
        """
        Xi = self.select_image_from_array(i=i)
        self.preview_image(Xi, dim=dim, aug=False)
        self.preview_image(Xi, dim=dim, aug=True)
    def preview_og_syn_pair(self, img_name):
        pairs = [i for i in self.X if img_name in i]
        self.preview_image(pairs[0])
        self.preview_image(pairs[1])
    # def preview_corrupted_pairs(self):
    #     """Finds the matching positive class images from both image sets and
    #     displays them in a grid."""
    #     posA = self.X[-self.X_prime.shape[0] :][self.y[-self.X_prime.shape[0] :] == 1]
    #     posB = self.X_prime[self.y_prime == 1]
    #     plt.figure(figsize=(10, 10))
    #     for n in range(5):
    #         x = image.array_to_img(posA[n][0])
    #         ax = plt.subplot(5, 5, n + 1)
    #         ax.imshow(x)
    #         plt.axis("off")
    #     plt.show()
    #     plt.figure(figsize=(10, 10))
    #     for n in range(5):
    #         x = image.array_to_img(posB[n][0])
    #         ax = plt.subplot(5, 5, n + 1)
    #         ax.imshow(x)
    #         plt.axis("off")
    #     plt.show()
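
# Example usage (a minimal sketch, not part of the module API): previewing images
# from an SVM training array. `X_img` and `y_img` are hypothetical arrays with
# shapes (n_images, 3, 128, 128, 3) and (n_images,).
#
#   previews = SVMPreviews(X_img, labels=y_img, ndims=3, w=128, h=128)
#   fig = previews.preview_image(previews.select_image_from_array(0), show=False)
#   previews.preview_og_aug_pair(i=0)  # original vs. augmented frames
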
class DataPlots:
    """Parent class for drawing exploratory data analysis plots from a dataframe."""

    def __init__(
        self,
        df,
        width=1300,
        height=700,
        show=False,
        save_html=None,
        name="DataPlots",
        **log_kws,
    ):
        self.__name__ = name
        self.log = Logger(self.__name__, **log_kws).spacekit_logger()
        self.df = df
        self.width = width
        self.height = height
        self.show = show
        self.save_html = save_html
        self.target = None  # target (y) name e.g. "label", "memory", "wallclock"
        self.labels = None
        self.classes = None  # target classes e.g. [0,1] or [0,1,2,3]
        self.n_classes = None
        self.group = None  # e.g. "detector" or "instr"
        self.gkeys = None
        self.categories = None
        self.cmap = ["dodgerblue", "gold", "fuchsia", "lime"]
        self.continuous = None
        self.categorical = None
        self.feature_list = None
        self.telescope = None
        self.figures = None
        self.scatter = None
        self.bar = None
        self.groupedbar = None
        self.kde = None
        if not check_viz_imports():
            self.log.error("plotly and/or matplotlib not installed.")
            raise ImportError(
                "You must install plotly (`pip install plotly`) "
                "and matplotlib<4 (`pip install matplotlib<4`) "
                "for the compute module to work."
                "\n\nInstall extra deps via `pip install spacekit[x]`"
            )
    def group_keys(self):
        if self.group in ["instr", "instrument"]:
            keys = ["acs", "cos", "stis", "wfc3"]
        elif self.group in ["det", "detector"]:
            uniq = list(self.df[self.group].unique())
            if len(uniq) == 2:
                keys = ["wfc-uvis", "other"]
            else:
                keys = ["hrc", "ir", "sbc", "uvis", "wfc"]
        # TODO: target classification / "category"
        elif self.group in ["cat", "category"]:
            keys = [
                "calibration",
                "galaxy",
                "galaxy_cluster",
                "ISM",
                "star",
                "stellar_cluster",
                "unidentified",
            ]
        # TODO: filters
        group_keys = dict(enumerate(keys))
        return group_keys
    def map_data(self):
        """Instantiates grouped dataframes and plot colors for each group key

        Returns
        -------
        dict
            data_map dictionary of grouped dataframes and color map
        """
        if self.cmap is None:
            cmap = ["#119dff", "salmon", "#66c2a5", "fuchsia", "#f4d365"]
        else:
            cmap = self.cmap
        self.data_map = {}
        for key, name in self.gkeys.items():
            data = self.categories[name]
            self.data_map[name] = dict(data=data, color=cmap[key])
        return self.data_map
    def feature_subset(self):
        """Create a set of groups from a categorical feature (dataframe column).
        Used for plotting multiple traces on a figure

        Returns
        -------
        dictionary
            self.categories attribute containing key-value pairs: groups of
            observations (values) for each category (keys)
        """
        self.categories = {}
        feature_groups = self.df.groupby(self.group)
        for i in list(range(len(feature_groups))):
            dx = feature_groups.get_group(i)
            k = self.gkeys[i]
            self.categories[k] = dx
        return self.categories
    def feature_stats_by_target(self, feature):
        """Calculates statistical info (mean and standard error) for a feature
        within each target class.

        Parameters
        ----------
        feature : str
            dataframe column to get statistical calculations on

        Returns
        -------
        nested lists
            list of means and list of standard errors for a feature,
            subdivided for each target class.
        """
        means, errs = [], []
        for c in self.classes:
            mu, ste = [], []
            for k in list(self.gkeys.keys()):
                data = self.df[
                    (self.df[self.target] == c) & (self.df[self.group] == k)
                ][feature]
                mu.append(np.mean(data))
                ste.append(np.std(data) / np.sqrt(len(data)))
            means.append(mu)
            errs.append(ste)
        return means, errs
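    # Shape note (illustrative, assuming binary classes [0, 1] and five detector
    # groups): `means` and `errs` are nested lists laid out as
    # [n_classes][n_groups], e.g. means = [[mu_hrc_0, ..., mu_wfc_0],
    # [mu_hrc_1, ..., mu_wfc_1]], which is the per-class trace layout that
    # `bar_plots` expects for its Y and y_err arguments.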
    def make_subplots(self, figtype, xtitle, ytitle, data1, data2, name1, name2):
        fig = subplots.make_subplots(
            rows=1,
            cols=2,
            subplot_titles=(name1, name2),
            shared_yaxes=False,
            x_title=xtitle,
            y_title=ytitle,
        )
        fig.add_trace(data1.data[0], 1, 1)
        fig.add_trace(data1.data[1], 1, 1)
        fig.add_trace(data2.data[0], 1, 2)
        fig.add_trace(data2.data[1], 1, 2)
        fig.update_layout(
            title_text=f"{name1} vs {name2}",
            margin=dict(t=50, l=80),
            width=self.width,
            height=self.height,
            paper_bgcolor="#242a44",
            plot_bgcolor="#242a44",
            font={
                "color": "#ffffff",
            },
        )
        if self.show:
            fig.show()
        if self.save_html:
            if not os.path.exists(self.save_html):
                os.makedirs(self.save_html, exist_ok=True)
            pyo.plot(fig, filename=f"{self.save_html}/{figtype}_{name1}_vs_{name2}")
        return fig
    def make_scatter_figs(
        self,
        xaxis_name,
        yaxis_name,
        marker_size=15,
        cmap=["cyan", "fuchsia"],
        categories=None,
        target=None,
    ):
        if categories is None:
            categories = {"all": self.df}
        if target is None:
            target = self.target
        scatter_figs = []
        for key, data in categories.items():
            target_groups = data.groupby(target)
            traces = []
            for i in list(range(len(target_groups))):
                dx = target_groups.get_group(i)
                trace = go.Scatter(
                    x=dx[xaxis_name],
                    y=dx[yaxis_name],
                    text=dx.index,
                    mode="markers",
                    opacity=0.7,
                    marker={"size": marker_size, "color": cmap[i]},
                    name=self.labels[i],  # "aligned",
                )
                traces.append(trace)
            layout = go.Layout(
                xaxis={"title": xaxis_name},
                yaxis={"title": yaxis_name},
                title=key,
                # margin={'l': 40, 'b': 40, 't': 10, 'r': 0},
                hovermode="closest",
                paper_bgcolor="#242a44",
                plot_bgcolor="#242a44",
                font={"color": "#ffffff"},
                width=700,
                height=500,
            )
            fig = go.Figure(data=traces, layout=layout)
            if self.show:
                fig.show()
            if self.save_html:
                if not os.path.exists(self.save_html):
                    os.makedirs(self.save_html, exist_ok=True)
                pyo.plot(
                    fig,
                    filename=f"{self.save_html}/{key}-{xaxis_name}-{yaxis_name}-{target}-scatter.html",
                )
            scatter_figs.append(fig)
        return scatter_figs
    def make_target_scatter(self, target=None):
        if target is None:
            target = self.target
        target_figs = {}
        for f in self.feature_list:
            target_figs[f] = self.make_scatter_figs(f, target)
        return target_figs
    def bar_plots(
        self,
        X,
        Y,
        feature,
        y_err=[None, None],
        width=700,
        height=500,
        cmap=["dodgerblue", "fuchsia"],
    ):
        traces = []
        for i in self.classes:
            i = int(i)
            trace = go.Bar(
                x=X,
                y=Y[i],
                error_y=dict(type="data", array=y_err[i], color="white", thickness=0.5),
                name=self.labels[i],
                text=sorted(list(self.group_keys().values())),
                marker=dict(color=cmap[i]),
            )
            traces.append(trace)
        layout = go.Layout(
            title=f"{feature.upper()} average by {self.group.capitalize()}",
            xaxis={"title": self.group},
            yaxis={"title": f"{feature} (mean)"},
            paper_bgcolor="#242a44",
            plot_bgcolor="#242a44",
            font={"color": "#ffffff"},
            width=width,
            height=height,
        )
        fig = go.Figure(data=traces, layout=layout)
        if self.save_html:
            pyo.plot(fig, filename=f"{self.save_html}/{feature}-barplot.html")
        if self.show:
            fig.show()
        else:
            return fig
    def kde_plots(
        self,
        cols,
        norm=False,
        targets=False,
        hist=True,
        curve=True,
        binsize=0.2,  # [0.3, 0.2, 0.1]
        width=700,
        height=500,
        cmap=["#F66095", "#2BCDC1"],
    ):
        if norm is True:
            df = PowerX(self.df, cols=cols, join_data=True).Xt
            cols = [c + "_scl" for c in cols]
            tag = "-norm"
        else:
            df = self.df
            tag = ""
        if targets is True:
            hist_data = [df.loc[df[self.target] == c][cols[0]] for c in self.classes]
            group_labels = self.labels  # [f"{cols[0]}={i}" for i in self.labels]
            title = f"KDE {cols[0]} by target class ({self.target})"
            name = f"kde-targets-{cols[0]}{tag}.html"
        else:
            hist_data = [df[c] for c in cols]
            group_labels = cols
            title = f"KDE {group_labels[0]} vs {group_labels[1]}"
            name = f"kde-{group_labels[0]}-{group_labels[1]}{tag}.html"
        fig = ff.create_distplot(
            hist_data,
            group_labels,
            colors=cmap,
            bin_size=binsize,
            show_hist=hist,
            show_curve=curve,
        )
        fig.update_layout(
            title_text=title,
            paper_bgcolor="#242a44",
            plot_bgcolor="#242a44",
            font={"color": "#ffffff"},
            width=width,
            height=height,
        )
        if self.save_html:
            if not os.path.exists(self.save_html):
                os.makedirs(self.save_html, exist_ok=True)
            pyo.plot(fig, filename=f"{self.save_html}/{name}")
        if self.show:
            fig.show()
        return fig
    def scatter3d(self, x, y, z, mask=None, target=None):
        if mask is None:
            df = self.df
        else:
            df = mask
        if target is None:
            target = self.target
        traces = []
        for targ, group in df.groupby(target):
            trace = go.Scatter3d(
                x=group[x],
                y=group[y],
                z=group[z],
                name=targ,
                mode="markers",
                marker=dict(size=7, color=targ, colorscale="Plasma", opacity=0.8),
            )
            traces.append(trace)
        layout = go.Layout(
            title=f"3D Scatterplot: {x} - {y} - {z}",
            paper_bgcolor="#242a44",
            plot_bgcolor="#242a44",
            font={"color": "#ffffff"},
            legend_title_text=target,
        )
        fig = go.Figure(data=traces, layout=layout)
        fig.update_layout(scene=dict(xaxis_title=x, yaxis_title=y, zaxis_title=z))
        if self.save_html:
            pyo.plot(fig, filename=f"{self.save_html}/scatter3d.html")
        if self.show:
            fig.show()
        else:
            return fig
    def remove_outliers(self, y_data):
        q = y_data.quantile([0.25, 0.75]).values
        q1, q3 = q[0], q[1]
        lower_fence = q1 - 1.5 * iqr(y_data)
        upper_fence = q3 + 1.5 * iqr(y_data)
        y = y_data.loc[(y_data > lower_fence) & (y_data < upper_fence)]
        return y
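    # Worked example (hypothetical values): for quartiles q1=2.0 and q3=6.0 the
    # IQR is 4.0, so the fences are 2.0 - 1.5*4.0 = -4.0 and 6.0 + 1.5*4.0 = 12.0,
    # and only values strictly between those fences are kept.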
    def box_plots(self, cols=None, outliers=True):
        box = {}
        title_sfx = ""
        if cols is None:
            features = self.continuous
        else:
            features = cols
        for f in features:
            traces = []
            for i, name in enumerate(self.gkeys.values()):
                y_data = self.categories[name][f]
                if outliers is False:
                    y_data = self.remove_outliers(y_data)
                    title_sfx = "- no outliers"
                trace = go.Box(y=y_data, name=name, marker=dict(color=self.cmap[i]))
                traces.append(trace)
            layout = go.Layout(
                title=f"{f} by {self.group}{title_sfx}",
                hovermode="closest",
                paper_bgcolor="#242a44",
                plot_bgcolor="#242a44",
                font={"color": "#ffffff"},
            )
            fig = go.Figure(data=traces, layout=layout)
            box[f] = fig
        return box
    def grouped_barplot(self, target="label", cmap=None, save=False):
        df = self.df
        if cmap is None:
            cmap = ["red", "orange", "yellow", "purple", "blue"]
        groups = df.groupby([self.group])[target]
        traces = []
        for key, value in self.gkeys.items():
            dx = groups.get_group(key).value_counts()
            trace = go.Bar(
                x=dx.index, y=dx, name=value.upper(), marker=dict(color=cmap[key])
            )
            traces.append(trace)
        layout = go.Layout(title=f"{target.title()} by {self.group.title()}")
        fig = go.Figure(data=traces, layout=layout)
        if self.save_html:
            pyo.plot(fig, filename=f"{self.save_html}/grouped-bar.html")
        if self.show:
            fig.show()
        else:
            return fig
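
# Example usage (a minimal sketch): DataPlots is normally driven through a subclass
# that sets `target`, `group`, `gkeys`, `labels`, and the feature lists (see
# HstSvmPlots and HstCalPlots below). Assuming `df` is such a preprocessed dataframe:
#
#   dp = HstSvmPlots(df, show=False, save_html=None)
#   kde_fig = dp.kde_plots(["rms_ra", "rms_dec"])
#   boxes = dp.box_plots(cols=["rms_ra"], outliers=False)
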
class HstSvmPlots(DataPlots):
    """Instantiates an HstSvmPlots class

    Parameters
    ----------
    DataPlots : class
        spacekit.analyzer.explore.DataPlots parent class
    """

    def __init__(
        self,
        df,
        group="det",
        width=1300,
        height=700,
        show=False,
        save_html=None,
        **log_kws,
    ):
        super().__init__(
            df,
            width=width,
            height=height,
            show=show,
            save_html=save_html,
            name="HstSvmPlots",
            **log_kws,
        )
        self.group = group
        self.telescope = "HST"
        self.target = "label"
        self.classes = list(set(df[self.target].values))  # [0, 1]
        self.labels = ["aligned", "misaligned"]
        self.n_classes = len(set(self.labels))
        self.gkeys = super().group_keys()
        self.categories = self.feature_subset()
        self.continuous = ["rms_ra", "rms_dec", "gaia", "nmatches", "numexp"]
        self.categorical = ["det", "wcs", "cat"]
        self.feature_list = self.continuous + self.categorical
        self.cmap = ["#119dff", "salmon", "#66c2a5", "fuchsia", "#f4d365"]
        self.df_by_detector()
        self.bar = None
        self.scatter = None
        self.kde = None
    def draw_plots(self):
        self.bar = self.alignment_bars()
        self.scatter = self.alignment_scatters()
        self.kde = self.alignment_kde()
    def alignment_bars(self):
        self.bar = {}
        X = sorted(list(self.gkeys.keys()))
        for f in self.continuous:
            means, errs = self.feature_stats_by_target(f)
            bar = self.bar_plots(X, means, f, y_err=errs)
            self.bar[f] = bar
        return self.bar
    def alignment_scatters(self):
        rms_scatter = self.make_scatter_figs(
            "rms_ra", "rms_dec", categories=self.categories
        )
        source_scatter = self.make_scatter_figs(
            "point", "segment", categories=self.categories
        )
        self.scatter = {"rms_ra_dec": rms_scatter, "point_segment": source_scatter}
        return self.scatter
    def alignment_kde(self):
        cols = self.continuous
        self.kde = dict(rms=self.kde_plots(["rms_ra", "rms_dec"]), targ={}, norm={})
        targ = [self.kde_plots([c], targets=True) for c in cols]
        norm = [self.kde_plots([c], norm=True, targets=True) for c in cols]
        for i, c in enumerate(cols):
            self.kde["targ"][c] = targ[i]
            self.kde["norm"][c] = norm[i]
        return self.kde
    # def group_keys(self):
    #     if self.group in ["det", "detector"]:
    #         keys = ["hrc", "ir", "sbc", "uvis", "wfc"]
    #     elif self.group in ["cat", "category"]:
    #         keys = [
    #             "calibration",
    #             "galaxy",
    #             "galaxy_cluster",
    #             "ISM",
    #             "star",
    #             "stellar_cluster",
    #             "unidentified",
    #         ]
    #     group_keys = dict(enumerate(keys))
    #     return group_keys
    def df_by_detector(self):
        """Instantiates grouped dataframes for each detector

        Returns
        -------
        self
        """
        try:
            self.hrc = self.df.groupby("det").get_group(0)
            self.ir = self.df.groupby("det").get_group(1)
            self.sbc = self.df.groupby("det").get_group(2)
            self.uvis = self.df.groupby("det").get_group(3)
            self.wfc = self.df.groupby("det").get_group(4)
            self.instr_dict = {
                "hrc": [self.hrc, "#119dff"],  # lightblue
                "ir": [self.ir, "salmon"],
                "sbc": [self.sbc, "#66c2a5"],  # lightgreen
                "uvis": [self.uvis, "fuchsia"],
                "wfc": [self.wfc, "#f4d365"],  # softgold
            }
        except Exception as e:
            print(e)
        return self
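
# Example usage (a minimal sketch, assuming `df` is a preprocessed SVM alignment
# dataframe containing the continuous/categorical columns listed above plus "label"):
#
#   svm = HstSvmPlots(df, group="det", show=False)
#   svm.draw_plots()                 # populates svm.bar, svm.scatter, svm.kde
#   rms_ra_kde = svm.kde["targ"]["rms_ra"]
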
class HstCalPlots(DataPlots):
    def __init__(self, df, group="instr", **log_kws):
        super().__init__(df, name="HstCalPlots", **log_kws)
        self.telescope = "HST"
        self.target = "mem_bin"
        self.classes = [0, 1, 2, 3]
        self.group = group
        self.labels = ["2g", "8g", "16g", "64g"]
        self.gkeys = self.group_keys()
        self.categories = self.feature_subset()
        self.acs = None
        self.cos = None
        self.stis = None
        self.wfc3 = None
        self.instr_dict = None
        self.instruments = list(self.df["instr_key"].unique())
        self.continuous = ["n_files", "total_mb", "x_files", "x_size"]
        self.categorical = [
            "drizcorr",
            "pctecorr",
            "crsplit",
            "subarray",
            "detector",
            "dtype",
            "instr",
        ]
        self.feature_list = self.continuous + self.categorical
        self.cmap = ["dodgerblue", "gold", "fuchsia", "lime"]
        self.data_map = None
        self.scatter = None
        self.box = None
        self.scatter3 = None
    def df_by_instr(self):
        self.acs = self.df.groupby("instr").get_group(0)
        self.cos = self.df.groupby("instr").get_group(1)
        self.stis = self.df.groupby("instr").get_group(2)
        self.wfc3 = self.df.groupby("instr").get_group(3)
        self.instr_dict = {
            "acs": [self.acs, "#119dff"],
            "wfc3": [self.wfc3, "salmon"],
            "cos": [self.cos, "#66c2a5"],
            "stis": [self.stis, "fuchsia"],
        }
        return self
    def draw_plots(self):
        self.scatter = self.make_cal_scatterplots()
        self.box = self.box_plots()
        box_target = self.box_plots(cols=["memory", "wallclock"])
        box_fenced = self.box_plots(cols=["memory", "wallclock"], outliers=False)
        self.box["memory"] = box_target["memory"]
        self.box["wallclock"] = box_target["wallclock"]
        self.box["mem_fence"] = box_fenced["memory"]
        self.box["wall_fence"] = box_fenced["wallclock"]
        # self.scatter3 = self.make_cal_scatter3d()
        # self.bar
        # self.kde
    def make_cal_scatterplots(self):
        memory_figs, wallclock_figs = {}, {}
        for f in self.feature_list:
            memory_figs[f] = self.make_scatter_figs(f, "memory")
            wallclock_figs[f] = self.make_scatter_figs(f, "wallclock")
        self.scatter = dict(memory=memory_figs, wallclock=wallclock_figs)
        return self.scatter
    def make_cal_scatter3d(self):
        x, y = "memory", "wallclock"
        self.scatter3 = {}
        for z in self.continuous:
            data = self.df[[x, y, z, "instr_key"]]
            scat3d = super().scatter3d(x, y, z, mask=data, target="instr_key")
            self.scatter3[z] = scat3d
    def make_box_figs(self, vars):
        box_figs = []
        for v in vars:
            data = [
                go.Box(y=self.acs[v], name="acs"),
                go.Box(y=self.cos[v], name="cos"),
                go.Box(y=self.stis[v], name="stis"),
                go.Box(y=self.wfc3[v], name="wfc3"),
            ]
            layout = go.Layout(
                title=f"{v} by instrument",
                hovermode="closest",
                paper_bgcolor="#242a44",
                plot_bgcolor="#242a44",
                font={"color": "#ffffff"},
            )
            fig = go.Figure(data=data, layout=layout)
            box_figs.append(fig)
        return box_figs
    def make_scatter_figs(self, xaxis_name, yaxis_name):
        if self.data_map is None:
            self.map_data()
        scatter_figs = []
        for instr, datacolor in self.data_map.items():
            data = datacolor["data"]
            color = datacolor["color"]
            trace = go.Scatter(
                x=data[xaxis_name],
                y=data[yaxis_name],
                text=data.index,
                mode="markers",
                opacity=0.7,
                marker={"size": 15, "color": color},
                name=instr,
            )
            layout = go.Layout(
                xaxis={"title": xaxis_name},
                yaxis={"title": yaxis_name},
                title=instr,
                hovermode="closest",
                paper_bgcolor="#242a44",
                plot_bgcolor="#242a44",
                font={"color": "#ffffff"},
            )
            fig = go.Figure(data=trace, layout=layout)
            scatter_figs.append(fig)
        return scatter_figs
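
# Example usage (a minimal sketch, assuming `df` is a calibration dataframe with
# encoded "instr" / "instr_key" columns and the continuous features listed above):
#
#   cal = HstCalPlots(df, group="instr")
#   cal.draw_plots()                 # populates cal.scatter and cal.box
#   mem_box = cal.box["memory"]
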
class SignalPlots:
    @staticmethod
    def atomic_vector_plotter(
        signal,
        label_col=None,
        classes=None,
        class_names=None,
        figsize=(15, 5),
        y_units=None,
        x_units=None,
    ):
        """
        Plots scatter and line plots of time series signal values.

        **ARGS
        signal: pandas series or numpy array
        label_col: name of the label column if using labeled pandas series
            -use default None for numpy array or unlabeled series.
            -this is simply for customizing plot Title to include classification
        classes: (optional- req labeled data) tuple if binary, array if multiclass
        class_names: tuple or array of strings denoting what the classes mean
        figsize: size of the figures (default = (15,5))

        ******
        Ex1: Labeled timeseries passing 1st row of pandas dataframe
        > first create the signal:
        signal = x_train.iloc[0, :]
        > then plot:
        atomic_vector_plotter(signal, label_col='LABEL', classes=[1, 2],
                              class_names=['No Planet', 'Planet'], figsize=(15, 5))

        Ex2: numpy array without any labels
        > first create the signal:
        signal = x_train.iloc[0, :]
        > then plot:
        atomic_vector_plotter(signal, figsize=(15,5))
        """
        # pass None to label_col if unlabeled data, creates generic title
        if label_col is None:
            label = None
            title_scatter = "Scatterplot of Star Flux Signals"
            title_line = "Line Plot of Star Flux Signals"
            color = "black"
        else:
            # store target column as variable for labeled timeseries
            label = signal[label_col]
            if label == 1:
                cn = class_names[0]
                color = "red"
            elif label == 2:
                cn = class_names[1]
                color = "blue"
            # create appropriate title according to class_names
            title_scatter = f"Scatterplot for Star Flux Signal: {cn}"
            title_line = f"Line Plot for Star Flux Signal: {cn}"
        # Set x and y axis labels according to units;
        # default to "Flux" and "Time" if the units are unknown
        if y_units is None:
            y_units = "Flux"
        if x_units is None:
            x_units = "Time"
        # convert a raw numpy array into a pandas Series for plotting
        if isinstance(signal, np.ndarray):
            series_index = list(range(len(signal)))
            signal = pd.Series(signal.ravel(), index=series_index)
        # Scatter Plot
        plt.figure(figsize=figsize)
        plt.scatter(
            pd.Series([i for i in range(1, len(signal))]),
            signal[1:],
            marker=4,
            color=color,
        )
        plt.ylabel(y_units)
        plt.xlabel(x_units)
        plt.title(title_scatter)
        plt.show()
        # Line Plot
        plt.figure(figsize=figsize)
        plt.plot(pd.Series([i for i in range(1, len(signal))]), signal[1:], color=color)
        plt.ylabel(y_units)
        plt.xlabel(x_units)
        plt.title(title_line)
        plt.show()

    @staticmethod
    def flux_specs(
        signal,
        Fs=2,
        NFFT=256,
        noverlap=128,
        mode="psd",
        cmap=None,
        units=None,
        colorbar=False,
        save_for_ML=False,
        fname=None,
        num=None,
        **kwargs,
    ):
        """generate and save spectrograms of flux signal frequencies"""
        if cmap is None:
            cmap = "binary"
        # PIX: plots only the pixelgrids - ideal for image classification
        if save_for_ML is True:
            # turn off everything except pixel grid
            fig, ax = plt.subplots(figsize=(10, 10), frameon=False)
            spectrum, freqs, t, m = plt.specgram(
                signal, Fs=Fs, NFFT=NFFT, noverlap=noverlap, mode=mode, cmap=cmap
            )
            ax.axis(False)
            if fname is not None:
                try:
                    if num:
                        path = fname + num
                    else:
                        path = fname
                    plt.savefig(path, **kwargs)
                except Exception as e:
                    print("Something went wrong while saving the img file")
                    print(e)
        else:
            fig, ax = plt.subplots(figsize=(13, 11))
            spectrum, freqs, t, m = plt.specgram(
                signal, Fs=Fs, NFFT=NFFT, noverlap=noverlap, mode=mode, cmap=cmap
            )
            plt.colorbar()
            if units is None:
                units = ["Wavelength (λ)", "Frequency (ν)"]
            plt.xlabel(units[0])
            plt.ylabel(units[1])
            if num:
                title = f"Spectrogram_{num}"
            else:
                title = "Spectrogram"
            plt.title(title)
            plt.show()
        return spectrum, freqs, t, m

    @staticmethod
    def singal_phase_folder(file_list, fmt="kepler.fits", error=False, snr=False):
        """plots phase-folded light curve of a signal
        returns dataframe of transit timestamps for each light curve
        planet_hunter(f=files[9], fmt='kepler.fits')

        args:
        - fits_files = takes array of files or single .fits file

        kwargs:
        - format : 'kepler.fits' or 'tess.fits'
        - error: include SAP flux error (residuals) if available
        - snr: apply signal-to-noise-ratio to periodogram autopower calculation
        """
        from astropy import units as u
        from astropy.stats import sigma_clipped_stats
        from astropy.timeseries import (
            BoxLeastSquares,
            TimeSeries,
            aggregate_downsample,
        )

        # read in each file and estimate its period, transit time and folded light curve
        transits = {}
        for index, file in enumerate(file_list):
            res = {}
            if fmt == "kepler.fits":
                prefix = file.replace("ktwo", "")
                suffix = prefix.replace("_llc.fits", "")
                pair = suffix.split("-")
                obs_id = pair[0]
                campaign = pair[1]
            ts = TimeSeries.read(file, format=fmt)  # read in timeseries
            # add to meta dict
            res["obs_id"] = obs_id
            res["campaign"] = campaign
            res["lc_start"] = ts.time.jd[0]
            res["lc_end"] = ts.time.jd[-1]
            # use box least squares to estimate period
            if error is True:  # if error col data available
                periodogram = BoxLeastSquares.from_timeseries(
                    ts, "sap_flux", "sap_flux_err"
                )
            else:
                periodogram = BoxLeastSquares.from_timeseries(ts, "sap_flux")
            if snr is True:
                results = periodogram.autopower(0.2 * u.day, objective="snr")
            else:
                results = periodogram.autopower(0.2 * u.day)
            maxpower = np.argmax(results.power)
            period = results.period[maxpower]
            transit_time = results.transit_time[maxpower]
            res["maxpower"] = maxpower
            res["period"] = period
            res["transit"] = transit_time
            # fold the time series using the period
            ts_folded = ts.fold(period=period, epoch_time=transit_time)
            # normalize the flux by sigma-clipping the data to determine the baseline flux
            mean, median, stddev = sigma_clipped_stats(ts_folded["sap_flux"])
            ts_folded["sap_flux_norm"] = ts_folded["sap_flux"] / median
            res["mean"] = mean
            res["median"] = median
            res["stddev"] = stddev
            res["sap_flux_norm"] = ts_folded["sap_flux_norm"]
            # downsample the time series by binning the points into bins of equal time
            ts_binned = aggregate_downsample(ts_folded, time_bin_size=0.03 * u.day)
            # final result: folded and binned light curve plot
            fig = plt.figure(figsize=(11, 5))
            ax = fig.gca()
            ax.plot(ts_folded.time.jd, ts_folded["sap_flux_norm"], "k.", markersize=1)
            ax.plot(
                ts_binned.time_bin_start.jd,
                ts_binned["sap_flux_norm"],
                "r-",
                drawstyle="steps-post",
            )
            ax.set_xlabel("Time (days)")
            ax.set_ylabel("Normalized flux")
            ax.set_title(obs_id)
            ax.legend([np.round(period, 3)])
            plt.close()
            res["fig"] = fig
            transits[index] = res
        df = pd.DataFrame.from_dict(transits, orient="index")
        return df


# testing
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", type=str, help="path to dataframe (csv file)")
    parser.add_argument("index", type=str, default="index", help="index column name")
    parser.add_argument(
        "-e", "--example", type=str, choices=["svm", "cal"], help="run example demo"
    )
    args = parser.parse_args()
    dataset = args.dataset
    index = args.index
    example = args.example
    df = pd.read_csv(dataset, index_col=index)
    if example == "svm":
        # Drop extra columns in case raw / un-preprocessed dataset is loaded
        drops = ["category", "ra_targ", "dec_targ", "imgname"]
        df.drop([c for c in drops if c in df.columns], axis=1, inplace=True)
        svm = HstSvmPlots(df)
    else:
        print("More examples coming soon!")
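
# Example command-line usage (runs the SVM demo above on a saved dataframe; assumes
# the spacekit package and its plotting extras are installed):
#
#   python -m spacekit.analyzer.explore path/to/dataset.csv index -e svm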