Source code for spacekit.analyzer.explore

import os
import re
import warnings
import numpy as np
import pandas as pd
from scipy.stats import iqr
from spacekit.preprocessor.transform import PowerX
from spacekit.generator.augment import augment_image
from spacekit.logger.log import Logger

try:
    from keras.preprocessing.image import array_to_img
except ImportError:
    from tensorflow.keras.utils import array_to_img

try:
    import matplotlib as mpl
    import matplotlib.pyplot as plt

    font_dict = {"family": "monospace", "size": 16}
    mpl.rc("font", **font_dict)
    styles = ["seaborn-bright", "seaborn-v0_8-bright"]
    valid_styles = [s for s in styles if s in plt.style.available]
    if len(valid_styles) > 0:
        try:
            plt.style.use(valid_styles[0])
        except OSError:
            pass
except ImportError:
    mpl = None
    plt = None

try:
    import plotly.graph_objects as go
    from plotly import subplots
    import plotly.offline as pyo
    import plotly.figure_factory as ff
    import plotly.express as px
except ImportError:
    go = None
    subplots = None
    pyo = None
    ff = None
    px = None


try:
    from astropy.timeseries import TimeSeries, BoxLeastSquares, aggregate_downsample
    from astropy import units as u
    from astropy.stats import sigma_clipped_stats
    from astropy.io import fits
except ImportError:
    TimeSeries = None


def check_ast_imports():
    """Report whether the optional astropy imports succeeded.

    Returns
    -------
    bool
        False when the module-level ``TimeSeries`` name was replaced by None
        in the import fallback, True otherwise.
    """
    if TimeSeries is None:
        return False
    return True


def check_viz_imports():
    """Report whether the optional plotly imports succeeded.

    Returns
    -------
    bool
        False when the module-level ``go`` name was replaced by None in the
        import fallback, True otherwise.
    """
    if go is None:
        return False
    return True


def check_mpl_imports():
    """Report whether the optional matplotlib imports succeeded.

    Returns
    -------
    bool
        True only when neither ``mpl`` nor ``plt`` was replaced by None in the
        import fallback.
    """
    if mpl is None or plt is None:
        return False
    return True


class ImagePreviews:
    """Base parent class for rendering and displaying images as plots

    Parameters
    ----------
    X : ndarray
        image data, stored unchanged on the ``X`` attribute
    labels : ndarray
        target labels, stored unchanged on the ``y`` attribute
    name : str, optional
        name handed to the logger, by default "ImagePreviews"
    **log_kws
        extra keyword arguments forwarded to ``Logger``

    Raises
    ------
    ImportError
        when ``check_viz_imports()`` reports the plotting extras are missing
    """

    def __init__(self, X, labels, name="ImagePreviews", **log_kws):
        self.__name__ = name
        logger_factory = Logger(self.__name__, **log_kws)
        self.log = logger_factory.spacekit_logger()
        self.X, self.y = X, labels
        if check_viz_imports():
            return
        # Plotting extras are unavailable: log first, then fail loudly.
        self.log.error("plotly and/or matplotlib not installed.")
        install_hint = (
            "You must install plotly (`pip install plotly`) "
            "and matplotlib<4 (`pip install matplotlib<4`) "
            "for the compute module to work."
            "\n\nInstall extra deps via `pip install spacekit[x]`"
        )
        raise ImportError(install_hint)
class SVMPreviews(ImagePreviews):
    """ImagePreviews subclass for previewing SVM images. Primarily can be used to compare
    original with augmented versions.

    Parameters
    ----------
    ImagePreviews : class
        spacekit.analyzer.explore.ImagePreviews parent class
    """

    def __init__(
        self,
        X,
        labels=None,
        names=None,
        ndims=3,
        channels=3,
        w=128,
        h=128,
        figsize=(10, 10),
        **log_kws,
    ):
        """Instantiates an SVMPreviews class object.

        Parameters
        ----------
        X : ndarray
            ndimensional array of image pixel values
        labels : ndarray, optional
            target class labels for each image
        names : array-like of str, optional
            image names aligned with X (used to look up original/synthetic pairs)
        ndims : int, optional
            number of dimensions (frames) per image, by default 3
        channels : int, optional
            channels per image frame (rgb color is 3, gray/bw is 1), by default 3
        w : int, optional
            width of images, by default 128
        h : int, optional
            height of images, by default 128
        figsize : tuple, optional
            matplotlib figure size used by `preview_image_mpl`, by default (10, 10)
        """
        super().__init__(X, labels, name="SVMPreviews", **log_kws)
        self.names = names
        self.n_images = len(X)
        self.ndims = ndims
        self.channels = channels
        self.w = w
        self.h = h
        self.figsize = figsize

    def select_image_from_array(self, i=None):
        """Return image(s) at index `i` of X, or the whole array when `i` is None."""
        return self.X if i is None else self.X[i]

    def check_dimensions(self, Xi):
        """Ensure an image has shape (ndims, w, h, channels), reshaping if needed.

        Parameters
        ----------
        Xi : ndarray
            image array

        Returns
        -------
        ndarray or None
            `Xi` unchanged when the shape already matches, the reshaped array when
            reshaping succeeds, None (after logging the error) when it cannot be reshaped.
        """
        expected = (self.ndims, self.w, self.h, self.channels)
        if Xi.shape == expected:
            # Previously this path fell through and returned None.
            return Xi
        try:
            return Xi.reshape(expected)
        except (ValueError, TypeError, AttributeError) as e:
            # Best-effort: report the problem instead of raising, as before.
            self.log.error(f"Unable to reshape image {Xi.shape} to {expected}: {e}")
            return None

    def preview_image(self, Xi, dim=3, aug=False, show=False):
        """Draw the frames of one image as a faceted plotly figure.

        Parameters
        ----------
        Xi : ndarray
            image array
        dim : int, optional
            accepted for signature parity with `preview_image_mpl` (unused here), by default 3
        aug : bool, optional
            pass the image through `augment_image` before plotting, by default False
        show : bool, optional
            display the figure instead of returning it, by default False

        Returns
        -------
        go.Figure or None
            the figure when `show` is not True, otherwise None
        """
        if aug is True:
            # reshape handled by augment if needed
            Xi = augment_image(Xi)
            title = "Augmented"
        else:
            Xi = self.check_dimensions(Xi)
            title = "Original"
        frames = ["orig", "pt-seg", "gaia"]
        fig = px.imshow(
            Xi,
            facet_col=0,
            binary_string=True,
            labels={"facet_col": "frame"},
            facet_col_wrap=3,
        )
        # zip guards against images that produce fewer facets than frame names
        for annotation, frame in zip(fig.layout.annotations, frames):
            annotation["text"] = frame
        fig.update_layout(
            title_text=f"{title} Image Slices",
            margin=dict(t=100),
            width=990,
            height=500,
            showlegend=False,
            paper_bgcolor="#242a44",
            plot_bgcolor="#242a44",
            font={
                "color": "#ffffff",
            },
        )
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)
        if show is True:
            fig.show()
            return None
        return fig

    def preview_image_mpl(self, Xi, dim=3, aug=False, show=False):
        """Draw the first `dim` frames of one image with matplotlib.

        Parameters
        ----------
        Xi : ndarray
            image array
        dim : int, optional
            number of frames to draw (also the subplot grid size), by default 3
        aug : bool, optional
            pass the image through `augment_image` before plotting, by default False
        show : bool, optional
            display the figure instead of returning it, by default False

        Returns
        -------
        matplotlib.figure.Figure or None
            the (closed) figure when `show` is not True, otherwise None
        """
        if aug is True:
            # reshape handled by augment if needed
            Xi = augment_image(Xi)
        else:
            Xi = self.check_dimensions(Xi)
        fig = plt.figure(figsize=self.figsize)
        for n in range(dim):
            xi = array_to_img(Xi[n])
            ax = plt.subplot(dim, dim, n + 1)
            ax.imshow(xi)
            plt.axis("off")
        if show is True:
            plt.show()
            return None
        plt.close()
        return fig

    def get_synthetic_image(self, img_name, show=False, dim=3, aug=False):
        """Find the synthetic counterpart of an image by name.

        Among all names containing `img_name`, the one with the most
        underscore-separated parts is taken to be the synthetic version.

        Parameters
        ----------
        img_name : str
            (partial) name of the original image
        show : bool, optional
            display the synthetic image, by default False
        dim : int, optional
            forwarded to `preview_image`, by default 3
        aug : bool, optional
            forwarded to `preview_image`, by default False

        Returns
        -------
        tuple or None
            (synth_name, synth_num, synth_img) where synth_num is the `np.where`
            index tuple; None when fewer than two names match.
        """
        pairs = [name for name in self.names if img_name in name]
        if len(pairs) > 1:
            synth_name = pairs[np.argmax([len(p.split("_")) for p in pairs])]
            # asarray keeps the elementwise comparison working when names is a list
            synth_num = np.where(np.asarray(self.names) == synth_name)
            synth_img = self.select_image_from_array(synth_num)
            if show is True:
                # forward show=True so the figure is actually displayed
                self.preview_image(synth_img, dim=dim, aug=aug, show=True)
            return synth_name, synth_num, synth_img
        print("Synthetic version not found for the selected image")
        return None

    def preview_og_aug_pair(self, i=None, dim=3):
        """Plot frames of both original and augmented versions of n-dimensional images

        Parameters
        ----------
        i : int, optional
            index of image selected from array X, by default None
        dim : int, optional
            dimensions (number of frames per image), by default 3

        Returns
        -------
        tuple
            (original figure, augmented figure)
        """
        Xi = self.select_image_from_array(i=i)
        original = self.preview_image(Xi, dim=dim, aug=False)
        augmented = self.preview_image(Xi, dim=dim, aug=True)
        return original, augmented

    def preview_og_syn_pair(self, img_name):
        """Plot the original and synthetic versions of an image selected by name.

        Parameters
        ----------
        img_name : str
            (partial) image name; matched against the `names` attribute

        Returns
        -------
        tuple or None
            (original figure, synthetic figure), or None when `names` is unset or
            fewer than two names contain `img_name`.
        """
        if self.names is None:
            self.log.error("Image names are required to look up an original/synthetic pair.")
            return None
        names = list(self.names)
        # Match on names (strings) and map back to image indices; the previous
        # implementation searched the pixel arrays themselves.
        matches = [n for n, name in enumerate(names) if img_name in name]
        if len(matches) < 2:
            self.log.warning(f"Original/synthetic pair not found for {img_name}")
            return None
        # Same convention as get_synthetic_image: most underscore parts == synthetic.
        matches.sort(key=lambda n: len(names[n].split("_")))
        original = self.preview_image(self.select_image_from_array(matches[0]))
        synthetic = self.preview_image(self.select_image_from_array(matches[-1]))
        return original, synthetic
# def preview_corrupted_pairs(self): # """Finds the matching positive class images from both image sets and displays them in a grid.""" # posA = self.X[-self.X_prime.shape[0] :][self.y[-self.X_prime.shape[0] :] == 1] # posB = self.X_prime[self.y_prime == 1] # plt.figure(figsize=(10, 10)) # for n in range(5): # x = image.array_to_img(posA[n][0]) # ax = plt.subplot(5, 5, n + 1) # ax.imshow(x) # plt.axis("off") # plt.show() # plt.figure(figsize=(10, 10)) # for n in range(5): # x = image.array_to_img(posB[n][0]) # ax = plt.subplot(5, 5, n + 1) # ax.imshow(x) # plt.axis("off") # plt.show()
class DataPlots:
    """Base class for drawing exploratory data analysis plots from a dataframe.

    Parameters
    ----------
    df : pd.DataFrame
        data to plot
    width : int, optional
        default figure width, by default 1300
    height : int, optional
        default figure height, by default 700
    show : bool, optional
        display figures as they are created, by default False
    save_html : str, optional
        directory in which figures are saved as html, by default None (not saved)
    telescope : str, optional
        telescope name (e.g. "hst", "jwst") used to resolve group keys, by default None
    name : str, optional
        logger name, by default "DataPlots"
    """

    def __init__(
        self,
        df,
        width=1300,
        height=700,
        show=False,
        save_html=None,
        telescope=None,
        name="DataPlots",
        **log_kws,
    ):
        self.__name__ = name
        self.log = Logger(self.__name__, **log_kws).spacekit_logger()
        self.df = df
        self.width = width
        self.height = height
        self.show = show
        self.save_html = save_html
        self.telescope = telescope
        self.target = None  # target (y) name e.g. "label", "memory", "wallclock"
        self.labels = None
        self.classes = None  # target classes e.g. [0,1] or [0,1,2,3]
        self.n_classes = None
        self.group = None  # e.g. "detector", "instr", "cat"
        self.gkeys = None
        self.group_dict = None
        self.categories = None
        self.data_map = None  # read by make_feature_scatter_figs, filled by map_data
        self.cmap = ["dodgerblue", "gold", "fuchsia", "lime"]
        self.continuous = None
        self.categorical = None
        self.feature_list = None
        self.figures = None
        self.scatter = None
        self.bar = None
        self.box = None
        self.groupedbar = None
        self.kde = None
        if not check_viz_imports():
            self.log.error("plotly and/or matplotlib not installed.")
            raise ImportError(
                "You must install plotly (`pip install plotly`) "
                "and matplotlib<4 (`pip install matplotlib<4`) "
                "for the compute module to work."
                "\n\nInstall extra deps via `pip install spacekit[x]`"
            )

    def _ensure_save_dir(self):
        """Create the `save_html` directory if it does not exist yet."""
        os.makedirs(self.save_html, exist_ok=True)

    def group_keys(self):
        """Generates numerically ordered key-pairs for each unique value of self.group found in the dataframe

        Returns
        -------
        dict
            enumerated dictionary of unique values for each group

        Raises
        ------
        ValueError
            if the `group` attribute has not been set
        """
        if not self.group:
            msg = (
                "Cannot generate group keys if no grouping feature specified. Set the `group` attribute then try again."
            )
            self.log.error(msg)
            # Previously execution continued and crashed on `None.startswith`.
            raise ValueError(msg)
        if self.group.startswith("instr"):
            keys = self.instr_keys()
        elif self.group.startswith("det"):
            keys = self.det_keys()
        elif self.group.startswith("cat"):
            keys = self.targ_class_keys()
        else:
            keys = sorted(self.df[self.group].unique())
        return dict(enumerate(keys))

    def instr_keys(self):
        """Generates a list of instruments based on self.telescope

        Returns
        -------
        list
            list of instrument keys for the specified telescope (case-insensitive);
            empty when the telescope is unset or unrecognized
        """
        instruments = dict(
            hst=["acs", "wfc3", "cos", "stis"],
            jwst=["fgs", "miri", "nircam", "niriss", "nirspec"],
        )
        # Normalize case before the lookup so "HST" and "hst" behave the same.
        telescope = self.telescope.lower() if isinstance(self.telescope, str) else None
        return instruments.get(telescope, [])

    def det_keys(self):
        """Creates a list of detectors based on self.telescope

        Returns
        -------
        list
            list of detector keys for the specified telescope; falls back to the
            sorted unique values of the group column
        """
        keys = sorted(self.df[self.group].unique())
        telescope = self.telescope.lower() if isinstance(self.telescope, str) else None
        if telescope == "hst" and keys:
            if len(keys) == 2:
                return ["wfc-uvis", "other"]
            # length is checked first so keys[0] is only read when it exists
            if len(keys) == 5 and not isinstance(keys[0], str):
                return ["hrc", "ir", "sbc", "uvis", "wfc"]
        return keys

    def targ_class_keys(self):
        """List of standard astronomical target classification categories

        Returns
        -------
        list
            standard target classification categories
        """
        return [
            "calibration",
            "galaxy",
            "galaxy_cluster",
            "ISM",
            "star",
            "stellar_cluster",
            "unidentified",
        ]

    def map_df_by_group(self):
        """Instantiates `group_dict` as a dictionary of grouped dataframes and color map.

        Group codes listed in `gkeys` but absent from the dataframe are skipped.
        """
        self.group_dict = {}
        grouped = self.df.groupby(self.group)  # built once, not once per key
        for k, v in self.gkeys.items():
            if k in grouped.groups:
                self.group_dict[v] = [grouped.get_group(k), self.cmap[k]]

    def map_data(self):
        """Instantiates `data_map` as a dictionary of grouped dataframes and color maps
        for each category in `categories` attribute."""
        cmap = ["#119dff", "salmon", "#66c2a5", "fuchsia", "#f4d365"] if self.cmap is None else self.cmap
        if not self.categories:
            self.feature_subset()
        self.data_map = {}
        for key, name in self.gkeys.items():
            if name not in self.categories:
                continue  # group has no observations in the dataframe
            self.data_map[name] = dict(data=self.categories[name], color=cmap[key])

    def feature_subset(self):
        """Create a set of groups from a categorical feature (dataframe column).
        Used for plotting multiple traces on a figure

        Returns
        -------
        dictionary
            self.categories attribute containing key-value pairs: groups of
            observations (values) for each category (keys)
        """
        self.categories = {}
        # Iterate over the groups actually present rather than assuming the
        # encoded values are exactly 0..n-1.
        for code, frame in self.df.groupby(self.group):
            self.categories[self.gkeys[code]] = frame
        return self.categories

    def feature_stats_by_target(self, feature):
        """Calculates statistical info (mean and standard error) for a feature within each target class.

        Parameters
        ----------
        feature : str
            dataframe column to get statistical calculations on

        Returns
        -------
        nested lists
            list of means and list of standard errors for a feature, subdivided
            for each target class (outer) and group key (inner).
        """
        means, errs = [], []
        for c in self.classes:
            mu, ste = [], []
            for k in self.gkeys:
                data = self.df[(self.df[self.target] == c) & (self.df[self.group] == k)][feature]
                mu.append(np.mean(data))
                ste.append(np.std(data) / np.sqrt(len(data)))
            means.append(mu)
            errs.append(ste)
        return means, errs

    def make_subplots(self, figtype, xtitle, ytitle, data1, data2, name1, name2):
        """Generates figure with multiple subplots for two sets of data using previously generated figures.

        Parameters
        ----------
        figtype : str
            type of figure being generated (used for saving html file)
        xtitle : str
            title for the x-axis
        ytitle : str
            title for the y-axis
        data1 : go.Figure
            figure object for the first set of data
        data2 : go.Figure
            figure object for the second set of data
        name1 : str
            name for the first subplot
        name2 : str
            name for the second subplot

        Returns
        -------
        go.Figure
            figure object containing the subplots
        """
        fig = subplots.make_subplots(
            rows=1,
            cols=2,
            subplot_titles=(name1, name2),
            shared_yaxes=False,
            x_title=xtitle,
            y_title=ytitle,
        )
        # Copy every trace of each source figure into its own column.
        for col, source in enumerate((data1, data2), start=1):
            for trace in source.data:
                fig.add_trace(trace, 1, col)
        fig.update_layout(
            title_text=f"{name1} vs {name2}",
            margin=dict(t=50, l=80),
            width=self.width,
            height=self.height,
            paper_bgcolor="#242a44",
            plot_bgcolor="#242a44",
            font={
                "color": "#ffffff",
            },
        )
        if self.show:
            fig.show()
        if self.save_html:
            self._ensure_save_dir()
            # name1/name2 are arguments, not attributes
            pyo.plot(fig, filename=f"{self.save_html}/{figtype}_{name1}_vs_{name2}.html")
        return fig

    def make_target_scatter_figs(
        self,
        xaxis_name,
        yaxis_name,
        marker_size=15,
        cmap=["cyan", "fuchsia"],
        categories=None,
        target=None,
    ):
        """Generates scatterplots for two features in the dataframe, grouped by target classes.

        Parameters
        ----------
        xaxis_name : str
            column name in dataframe to plot on x-axis
        yaxis_name : str
            column name in dataframe to plot on y-axis
        marker_size : int, optional
            marker size for scatter plot points, by default 15
        cmap : list, optional
            list of colors for different target classes, by default ["cyan", "fuchsia"]
            (never mutated here)
        categories : dict, optional
            dictionary of categories to group data by, by default None
        target : str, optional
            name of target column in dataframe, by default None

        Returns
        -------
        list
            list of scatterplot figures for each category
        """
        if categories is None:
            categories = {"all": self.df}
        if target is None:
            target = self.target
        scatter_figs = []
        for key, data in categories.items():
            traces = []
            for pos, (cls, dx) in enumerate(data.groupby(target)):
                # The class value picks color and label, so a category that is
                # missing a class still gets consistent styling.
                try:
                    i = int(cls)
                except (TypeError, ValueError):
                    i = pos
                if self.labels is not None and 0 <= i < len(self.labels):
                    label = self.labels[i]
                else:
                    label = str(cls)
                trace = go.Scatter(
                    x=dx[xaxis_name],
                    y=dx[yaxis_name],
                    text=dx.index,
                    mode="markers",
                    opacity=0.7,
                    marker={"size": marker_size, "color": cmap[i % len(cmap)]},
                    name=label,
                )
                traces.append(trace)
            layout = go.Layout(
                xaxis={"title": xaxis_name},
                yaxis={"title": yaxis_name},
                title=key,
                hovermode="closest",
                paper_bgcolor="#242a44",
                plot_bgcolor="#242a44",
                font={"color": "#ffffff"},
                width=700,
                height=500,
            )
            fig = go.Figure(data=traces, layout=layout)
            if self.show:
                fig.show()
            if self.save_html:
                self._ensure_save_dir()
                pyo.plot(
                    fig,
                    filename=f"{self.save_html}/{key}-{xaxis_name}-{yaxis_name}-{target}-scatter.html",
                )
            scatter_figs.append(fig)
        return scatter_figs

    def make_feature_scatter_figs(self, xaxis_name, yaxis_name):
        """Generates scatterplots for two features in the dataframe, grouped by the `group` attribute.

        Parameters
        ----------
        xaxis_name : str
            name of column in dataframe to plot on x-axis
        yaxis_name : str
            name of column in dataframe to plot on y-axis

        Returns
        -------
        list
            scatterplot figures for each group in self.group attribute
        """
        if getattr(self, "data_map", None) is None:
            self.map_data()
        scatter_figs = []
        for key, datacolor in self.data_map.items():
            data = datacolor["data"]
            color = datacolor["color"]
            trace = go.Scatter(
                x=data[xaxis_name],
                y=data[yaxis_name],
                text=data.index,
                mode="markers",
                opacity=0.7,
                marker={"size": 15, "color": color},
                name=key,
            )
            layout = go.Layout(
                xaxis={"title": xaxis_name},
                yaxis={"title": yaxis_name},
                title=key,
                hovermode="closest",
                paper_bgcolor="#242a44",
                plot_bgcolor="#242a44",
                font={"color": "#ffffff"},
            )
            scatter_figs.append(go.Figure(data=trace, layout=layout))
        return scatter_figs

    def make_target_scatter(self, target=None):
        """Generates target vs feature scatterplot for a given target (by default self.target)
        for each feature in self.feature_list.

        Parameters
        ----------
        target : str, optional
            target column name, by default None

        Returns
        -------
        dict
            list of target-feature scatterplot figures keyed by feature name
        """
        if target is None:
            target = self.target
        return {f: self.make_target_scatter_figs(f, target) for f in self.feature_list}

    def bar_plots(
        self,
        X,
        Y,
        feature,
        y_err=[None, None],
        width=700,
        height=500,
        cmap=["dodgerblue", "fuchsia"],
    ):
        """Draws a bar plot for a feature, grouped by the `group` attribute.

        Parameters
        ----------
        X : array-like
            X-axis values
        Y : array-like
            Y-axis values, one sequence per target class
        feature : str
            Feature name
        y_err : list, optional
            Y-axis error values per target class, by default [None, None]
        width : int, optional
            Width of the plot, by default 700
        height : int, optional
            Height of the plot, by default 500
        cmap : list, optional
            List of colors for the plot, by default ["dodgerblue", "fuchsia"]

        Returns
        -------
        go.Figure or None
            Plotly Figure object representing the bar plot; None when `show` is set
        """
        group_text = sorted(self.group_keys().values())  # same for every trace
        traces = []
        for c in self.classes:
            i = int(c)
            trace = go.Bar(
                x=X,
                y=Y[i],
                error_y=dict(type="data", array=y_err[i], color="white", thickness=0.5),
                name=self.labels[i],
                text=group_text,
                marker=dict(color=cmap[i]),
            )
            traces.append(trace)
        layout = go.Layout(
            title=f"{feature.upper()} average by {self.group.capitalize()}",
            xaxis={"title": self.group},
            yaxis={"title": f"{feature} (mean)"},
            paper_bgcolor="#242a44",
            plot_bgcolor="#242a44",
            font={"color": "#ffffff"},
            width=width,
            height=height,
        )
        fig = go.Figure(data=traces, layout=layout)
        if self.save_html:
            self._ensure_save_dir()
            pyo.plot(fig, filename=f"{self.save_html}/{feature}-barplot.html")
        if self.show:
            fig.show()
            return None
        return fig

    def kde_plots(
        self,
        cols,
        norm=False,
        targets=False,
        hist=True,
        curve=True,
        binsize=0.2,  # [0.3, 0.2, 0.1]
        width=700,
        height=500,
        cmap=["#F66095", "#2BCDC1"],
    ):
        """Generates KDE plots for specified columns in the dataframe.

        Parameters
        ----------
        cols : list of str
            List of column names to generate KDE plots for
        norm : bool, optional
            Whether to normalize the data, by default False
        targets : bool, optional
            Whether to group data by target classes, by default False
        hist : bool, optional
            Whether to show histogram, by default True
        curve : bool, optional
            Whether to show KDE curve, by default True
        binsize : float, optional
            Bin size for the histogram, by default 0.2
        width : int, optional
            Width of the plot, by default 700
        height : int, optional
            Height of the plot, by default 500
        cmap : list, optional
            List of colors for the plot, by default ["#F66095", "#2BCDC1"]

        Returns
        -------
        go.Figure
            Plotly Figure object representing the KDE plot
        """
        if norm is True:
            df = PowerX(self.df, cols=cols, join_data=True).Xt
            cols = [c + "_scl" for c in cols]
            tag = "-norm"
        else:
            df = self.df
            tag = ""
        if targets is True:
            hist_data = [df.loc[df[self.target] == c][cols[0]] for c in self.classes]
            group_labels = self.labels
            title = f"KDE {cols[0]} by target class ({self.target})"
            name = f"kde-targets-{cols[0]}{tag}.html"
        else:
            hist_data = [df[c] for c in cols]
            group_labels = cols
            # join() gives the same text for two columns and no longer fails for one
            joined_title = " vs ".join(str(c) for c in group_labels)
            joined_name = "-".join(str(c) for c in group_labels)
            title = f"KDE {joined_title}"
            name = f"kde-{joined_name}{tag}.html"
        fig = ff.create_distplot(
            hist_data,
            group_labels,
            colors=cmap,
            bin_size=binsize,
            show_hist=hist,
            show_curve=curve,
        )
        fig.update_layout(
            title_text=title,
            paper_bgcolor="#242a44",
            plot_bgcolor="#242a44",
            font={"color": "#ffffff"},
            width=width,
            height=height,
        )
        if self.save_html:
            self._ensure_save_dir()
            pyo.plot(fig, filename=f"{self.save_html}/{name}")
        if self.show:
            fig.show()
        return fig

    def scatter3d(self, x, y, z, mask=None, target=None, width=None, height=None):
        """Generates a 3D scatterplot for three features in the dataframe.

        Parameters
        ----------
        x : str
            feature column name for x-axis
        y : str
            feature column name for y-axis
        z : str
            feature column name for z-axis
        mask : pd.DataFrame, optional
            DataFrame to use as a mask/filter, by default None
        target : str, optional
            target column name, by default None
        width : int, optional
            figure width, by default None (plotly default)
        height : int, optional
            figure height, by default None (plotly default)

        Returns
        -------
        go.Figure or None
            Plotly Figure object representing the 3D scatterplot; None when `show` is set
        """
        df = self.df if mask is None else mask
        if target is None:
            target = self.target
        traces = []
        for targ, group in df.groupby(target):
            trace = go.Scatter3d(
                x=group[x],
                y=group[y],
                z=group[z],
                name=targ,
                mode="markers",
                marker=dict(size=7, color=targ, colorscale="Plasma", opacity=0.8),
            )
            traces.append(trace)
        layout = go.Layout(
            title=f"3D Scatterplot: {x} - {y} - {z}",
            paper_bgcolor="#242a44",
            plot_bgcolor="#242a44",
            font={"color": "#ffffff"},
            legend_title_text=target,
        )
        fig = go.Figure(data=traces, layout=layout)
        fig.update_layout(scene=dict(xaxis_title=x, yaxis_title=y, zaxis_title=z))
        if width is not None or height is not None:
            fig.update_layout(width=width, height=height)
        if self.save_html:
            self._ensure_save_dir()
            pyo.plot(fig, filename=f"{self.save_html}/scatter3d.html")
        if self.show:
            fig.show()
            return None
        return fig

    def remove_outliers(self, y_data):
        """Removes outliers from a given pandas Series using the IQR method.

        Parameters
        ----------
        y_data : pd.Series
            The data from which to remove outliers.

        Returns
        -------
        pd.Series
            The values strictly inside (q1 - 1.5*IQR, q3 + 1.5*IQR).
        """
        q1, q3 = y_data.quantile([0.25, 0.75]).values
        spread = 1.5 * iqr(y_data)  # computed once for both fences
        lower_fence = q1 - spread
        upper_fence = q3 + spread
        return y_data.loc[(y_data > lower_fence) & (y_data < upper_fence)]

    def box_plots(self, cols=None, outliers=True):
        """Generates multi-trace box plots for each feature in cols param, with or without outliers

        Parameters
        ----------
        cols : list, optional
            features to plot from dataframe, by default None (uses self.continuous attribute)
        outliers : bool, optional
            whether to include outliers in the box plots, by default True

        Returns
        -------
        dict
            dictionary of plotly box plot figures for each feature in cols parameter
        """
        if not self.categories:
            self.feature_subset()
        drop_outliers = outliers is False
        title_sfx = " - no outliers" if drop_outliers else ""
        box = {}
        features = cols or self.continuous
        for f in features:
            traces = []
            for i, name in enumerate(self.gkeys.values()):
                if name not in self.categories:
                    continue  # group has no observations in the dataframe
                y_data = self.categories[name][f]
                if drop_outliers:
                    y_data = self.remove_outliers(y_data)
                traces.append(go.Box(y=y_data, name=name, marker=dict(color=self.cmap[i])))
            layout = go.Layout(
                title=f"{f} by {self.group}{title_sfx}",
                hovermode="closest",
                paper_bgcolor="#242a44",
                plot_bgcolor="#242a44",
                font={"color": "#ffffff"},
            )
            box[f] = go.Figure(data=traces, layout=layout)
        return box

    def make_box_figs(self, vars: list):
        """Generates single trace box plots, one plot for each var where `vars` is a list of columns in df

        Parameters
        ----------
        vars : list
            column names in dataframe to plot

        Returns
        -------
        list
            list of plotly box plot figures for each variable in vars parameter
        """
        box_figs = []
        if not self.group_dict:
            self.map_df_by_group()
        for v in vars:
            # group_dict values are [dataframe, color]; only the dataframe is used here
            data = [go.Box(y=frame_color[0][v], name=name) for name, frame_color in self.group_dict.items()]
            layout = go.Layout(
                title=f"{v} by {self.group}",
                hovermode="closest",
                paper_bgcolor="#242a44",
                plot_bgcolor="#242a44",
                font={"color": "#ffffff"},
            )
            box_figs.append(go.Figure(data=data, layout=layout))
        return box_figs

    def grouped_barplot(self, target="label", cmap=None):
        """Draws a grouped bar plot for a target column, grouped by the `group` attribute.

        Parameters
        ----------
        target : str, optional
            target column to plot, by default "label"
        cmap : list, optional
            list of colors for the bars, by default None

        Returns
        -------
        go.Figure or None
            plotly figure object for the grouped bar plot; None when `show` is set
        """
        if cmap is None:
            cmap = self.cmap or ["red", "orange", "yellow", "purple", "blue"]
        # Group by the column name itself (not a one-element list) so that
        # get_group accepts the plain scalar key.
        groups = self.df.groupby(self.group)[target]
        traces = []
        for key, value in self.gkeys.items():
            if key not in groups.groups:
                continue  # group has no observations in the dataframe
            dx = groups.get_group(key).value_counts()
            traces.append(go.Bar(x=dx.index, y=dx, name=str(value).upper(), marker=dict(color=cmap[key])))
        layout = go.Layout(title=f"{target.title()} by {self.group.title()}")
        fig = go.Figure(data=traces, layout=layout)
        if self.save_html:
            self._ensure_save_dir()
            pyo.plot(fig, filename=f"{self.save_html}/grouped-bar.html")
        if self.show:
            fig.show()
            return None
        return fig
class HstSvmPlots(DataPlots):
    """Instantiates an HstSvmPlots class

    Parameters
    ----------
    DataPlots : class
        spacekit.analyzer.explore.DataPlots parent class
    """

    def __init__(self, df, group="det", width=1300, height=700, show=False, save_html=None, **log_kws):
        """Set up plotting attributes for HST SVM alignment data.

        Parameters
        ----------
        df : pd.DataFrame
            data to plot; must contain the "label" target column and the `group` column
        group : str, optional
            dataframe column used to group observations, by default "det"
        width : int, optional
            default figure width, by default 1300
        height : int, optional
            default figure height, by default 700
        show : bool, optional
            display figures as they are created, by default False
        save_html : str, optional
            directory in which figures are saved as html, by default None
        """
        super().__init__(
            df,
            width=width,
            height=height,
            show=show,
            save_html=save_html,
            telescope="hst",
            name="HstSvmPlots",
            **log_kws,
        )
        self.group = group
        self.target = "label"
        # Sorted so class order matches the position-based indexing in bar_plots
        # (set iteration order is not guaranteed).
        self.classes = sorted(set(df[self.target].values))  # [0, 1]
        self.labels = ["aligned", "misaligned"]
        self.n_classes = len(set(self.labels))
        self.gkeys = self.group_keys()
        self.cmap = ["#119dff", "salmon", "#66c2a5", "fuchsia", "#f4d365"]
        self.feature_subset()
        self.continuous = ["rms_ra", "rms_dec", "gaia", "nmatches", "numexp"]
        self.categorical = ["det", "wcs", "cat"]
        self.feature_list = self.continuous + self.categorical
        self.map_df_by_group()

    def draw_plots(self):
        """Build the bar, scatter and KDE figures and store them on the instance."""
        self.alignment_bars()
        self.alignment_scatters()
        self.alignment_kde()

    def alignment_bars(self):
        """Fill `self.bar` with one mean/standard-error bar plot per continuous feature."""
        self.bar = {}
        X = sorted(self.gkeys)
        for f in self.continuous:
            means, errs = self.feature_stats_by_target(f)
            self.bar[f] = self.bar_plots(X, means, f, y_err=errs)

    def alignment_scatters(self):
        """Fill `self.scatter` with per-category scatter figures.

        Plots "rms_ra" vs "rms_dec" and "point" vs "segment", grouped by target class.
        """
        rms_scatter = self.make_target_scatter_figs("rms_ra", "rms_dec", categories=self.categories)
        source_scatter = self.make_target_scatter_figs("point", "segment", categories=self.categories)
        self.scatter = {"rms_ra_dec": rms_scatter, "point_segment": source_scatter}

    def alignment_kde(self):
        """Fill `self.kde` with KDE figures.

        Keys: "rms" (rms_ra vs rms_dec), "targ" (per continuous feature by target
        class) and "norm" (same as "targ" on normalized data).
        """
        cols = self.continuous
        self.kde = dict(rms=self.kde_plots(["rms_ra", "rms_dec"]), targ={}, norm={})
        for c in cols:
            self.kde["targ"][c] = self.kde_plots([c], targets=True)
        for c in cols:
            self.kde["norm"][c] = self.kde_plots([c], norm=True, targets=True)
class HstCalPlots(DataPlots):
    """DataPlots subclass for exploring HST calibration data ("mem_bin" target)."""

    def __init__(self, df, group="instr", width=1300, height=700, show=False, save_html=None, **log_kws):
        """Set up plotting attributes for HST calibration data.

        Parameters
        ----------
        df : pd.DataFrame
            data to plot; must contain an "instr_key" column and the `group` column
        group : str, optional
            dataframe column used to group observations, by default "instr"
        width : int, optional
            default figure width, by default 1300
        height : int, optional
            default figure height, by default 700
        show : bool, optional
            display figures as they are created, by default False
        save_html : str, optional
            directory in which figures are saved as html, by default None
        """
        super().__init__(
            df,
            width=width,
            height=height,
            show=show,
            save_html=save_html,
            telescope="hst",
            name="HstCalPlots",
            **log_kws,
        )
        self.target = "mem_bin"
        self.classes = [0, 1, 2, 3]
        self.group = group
        self.labels = ["2g", "8g", "16g", "64g"]
        self.gkeys = self.group_keys()
        self.group_dict = {}
        self.cmap = ["dodgerblue", "gold", "fuchsia", "lime"]
        self.data_map = None
        self.feature_subset()
        self.instruments = list(self.df["instr_key"].unique())
        self.continuous = ["n_files", "total_mb", "x_files", "x_size"]
        self.categorical = [
            "drizcorr",
            "pctecorr",
            "crsplit",
            "subarray",
            "detector",
            "dtype",
            "instr",
        ]
        self.feature_list = self.continuous + self.categorical
        self.scatter3 = None

    def draw_plots(self):
        """Build the scatter and box figures and store them on the instance.

        `self.box` holds one figure per continuous feature plus "memory",
        "wallclock" and their outlier-free versions "mem_fence" / "wall_fence".
        """
        self.make_cal_scatterplots()
        self.box = self.box_plots()
        targets = ["memory", "wallclock"]
        box_target = self.box_plots(cols=targets)
        box_fenced = self.box_plots(cols=targets, outliers=False)
        self.box["memory"] = box_target["memory"]
        self.box["wallclock"] = box_target["wallclock"]
        self.box["mem_fence"] = box_fenced["memory"]
        self.box["wall_fence"] = box_fenced["wallclock"]
        # self.scatter3 = self.make_cal_scatter3d()
        # self.bar
        # self.kde

    def make_cal_scatterplots(self):
        """Fill `self.scatter` with feature-vs-memory and feature-vs-wallclock figures."""
        memory_figs, wallclock_figs = {}, {}
        for f in self.feature_list:
            memory_figs[f] = self.make_feature_scatter_figs(f, "memory")
            wallclock_figs[f] = self.make_feature_scatter_figs(f, "wallclock")
        self.scatter = dict(memory=memory_figs, wallclock=wallclock_figs)

    def make_cal_scatter3d(self):
        """Fill `self.scatter3` with memory/wallclock/feature 3D scatter figures.

        One 700x700 figure is made per continuous feature, grouped by "instr_key".
        Entries are None when the parent `scatter3d` returns no figure.
        """
        x, y = "memory", "wallclock"
        self.scatter3 = {}
        for z in self.continuous:
            data = self.df[[x, y, z, "instr_key"]]
            # Size the figure afterwards, so this works whether or not the
            # parent scatter3d accepts width/height arguments.
            fig = super().scatter3d(x, y, z, mask=data, target="instr_key")
            if fig is not None:
                fig.update_layout(width=700, height=700)
            self.scatter3[z] = fig
class SignalPlots:
    """Class for plotting time series signals and their spectrograms."""

    def __init__(
        self,
        show=False,
        save_png=False,
        target_cns=None,
        color_map=None,
        output_dir=None,
        name="SignalPlots",
        **log_kws,
    ):
        """Class for manipulating and plotting time series signals and frequency spectrograms.

        Parameters
        ----------
        show : bool, optional
            display plot, by default False
        save_png : bool, optional
            save plot as PNG file, by default False
        target_cns : dict, optional
            target label and string keypairs, by default None (empty dict)
        color_map : dict, optional
            target label and color keypairs, by default None (empty dict)
        output_dir : str, optional
            directory where PNG files are saved, by default None (current working directory)
        name : str, optional
            name passed to the logger, by default "SignalPlots"
        **log_kws : dict
            additional keyword arguments forwarded to ``Logger``

        Raises
        ------
        ImportError
            if astropy or matplotlib could not be imported
        """
        self.__name__ = name
        self.log = Logger(self.__name__, **log_kws).spacekit_logger()
        self.show = show
        self.save_png = save_png
        # None sentinels avoid sharing one mutable default dict between instances
        self.target_cns = {} if target_cns is None else target_cns
        self.color_map = {} if color_map is None else color_map
        self.flux_col = "pdcsap_flux"
        self.extra_cols = ["lc_start", "lc_end", "maxpower", "transit", "mean", "median", "stddev"]
        self.output_dir = os.getcwd() if output_dir is None else output_dir
        self.check_dependencies()
        warnings.filterwarnings(action="ignore")  # ignore astropy warnings

    def check_dependencies(self):
        """Raise ImportError (after logging) when astropy or matplotlib is missing."""
        if not check_ast_imports() or not check_mpl_imports():
            self.log.error("astropy and/or matplotlib not installed.")
            raise ImportError(
                "You must have astropy and matplotlib installed "
                f"for the {self.__name__} class to work."
                "\n\nInstall extra deps via `pip install spacekit[x]`"
            )

    @staticmethod
    def _with_png_ext(fname):
        """Return ``fname`` ending in ``.png``, appending the extension only when missing."""
        fname = str(fname)
        return fname if fname.lower().endswith(".png") else fname + ".png"

    def parse_filename(self, fname, fmt="kepler.fits"):
        """Extracts target information from FITS light curve file name.

        Parameters
        ----------
        fname : str
            path to FITS light curve file (llc or lc)
        fmt : str, optional
            'kepler.fits' or 'tess.fits', by default "kepler.fits"

        Returns
        -------
        tuple
            target id (str), campaign/sector id (str)

        Raises
        ------
        ValueError
            if `fmt` is not supported or the file name does not match the
            pattern expected for `fmt`
        """
        fname = os.path.basename(fname)
        if fmt == "kepler.fits":
            # r"ktwo{obs_id}-c{campaign}_llc.fits"
            # campaign may have 2 or 3 digits (e.g. split campaigns such as c102)
            patt = r"ktwo(\d{9,15})-c(\d{2,3})_llc\.fits"
            m = re.match(patt, fname)
            if m:
                return (m.group(1), m.group(2))  # tid, campaign
        elif fmt == "tess.fits":
            # r"tess{date-time}-s{sctr}-{tid}-{scid}-{cr}_lc.fits"
            patt = r"^tess(\d{13})-s(\d{4})-(\d{16,20})-(\d{4})-s_lc\.fits$"
            m = re.match(patt, fname)
            if m:
                # group 1 is the date-time stamp; the target id is group 3
                return (m.group(3), m.group(2))  # tid, sector
        else:
            raise ValueError("fmt must be 'kepler.fits' or 'tess.fits'")
        raise ValueError("Filename does not match expected pattern")

    @staticmethod
    def read_ts_signal(fits_file, signal_col="pdcsap_flux", fmt="kepler.fits", offset=False, remove_nans=True):
        """Reads time series signal data from a FITS light curve file (_llc.fits or _lc.fits for
        kepler and tess respectively). Optionally can apply telescope-specific BJD offset as
        determined by `fmt` kwarg (most light curve files already have this applied) and remove
        NaN values from both signal and corresponding timestamp arrays. Regarding the `signal_col`
        defaults: "sap_flux" is Simple Aperture Photometry flux, the flux after summing the
        calibrated pixels within the telescope's optimal photometric aperture; the default
        (recommended) is "pdcsap_flux" (Pre-search Data Conditioned Simple Aperture Photometry,
        the SAP flux values nominally corrected for instrumental variations - these are the
        mission's best estimate of the intrinsic variability of the target.).

        Parameters
        ----------
        fits_file : str
            path to FITS light curve file (llc or lc)
        signal_col : str, optional
            header column name containing the data, by default "pdcsap_flux"
        fmt : str, optional
            'kepler.fits' or 'tess.fits', by default "kepler.fits"
        offset : bool, optional
            apply telescope-specific BJD offset to timestamps, by default False
        remove_nans : bool, optional
            remove NaN values from signal and timestamps, by default True

        Returns
        -------
        tuple of np.ndarray
            (timestamps, flux): Julian date timestamps and the signal values

        Raises
        ------
        ValueError
            if `fmt` is not 'kepler.fits' or 'tess.fits'
        """
        if fmt not in ["kepler.fits", "tess.fits"]:
            raise ValueError("fmt must be 'kepler.fits' or 'tess.fits'")
        ts = TimeSeries.read(fits_file, format=fmt)
        flux = np.asarray(ts[signal_col], dtype="float64")
        timestamps = ts.time.jd
        if offset is True:
            bjd = dict(kepler=2454833.0, tess=2457000.0)[fmt.split(".")[0]]
            # build a new array rather than modifying the one returned by ts.time.jd
            timestamps = timestamps - bjd  # convert to KBJD/TBJD
        if remove_nans is True:
            not_nan_mask = ~np.isnan(flux)
            flux = flux[not_nan_mask]
            timestamps = timestamps[not_nan_mask]
        return timestamps, flux

    def atomic_vector_plotter(
        self,
        signal,
        timestamps=None,
        label=None,
        y_units="PDCSAP Flux (e-/s)",  # aperture photometry flux
        x_units="Time (BJD)",  # Barycentric Julian Date
        figsize=(15, 10),
        fname="flux_signal.png",
        title_pfx="Flux Signal",
    ):
        """Plots scatter and line plots of time series signal values.

        Parameters
        ----------
        signal : np.ndarray or pandas Series
            time series signal data
        timestamps : ArrayLike, optional
            x-axis values; when None the cadence index is used and the x-axis
            label becomes "Time Cadence Index", by default None
        label : hashable, optional
            key looked up in `target_cns` (title suffix) and `color_map`
            (plot color, falling back to "black"), by default None
        y_units : str, optional
            y-axis label, by default "PDCSAP Flux (e-/s)"
        x_units : str, optional
            x-axis label, by default "Time (BJD)"
        figsize : tuple, optional
            figure size, by default (15, 10)
        fname : str, optional
            output file name, with or without the ".png" extension,
            by default "flux_signal.png"
        title_pfx : str, optional
            plot title prefix, by default "Flux Signal"
        """
        cn = self.target_cns.get(label, "")
        color = self.color_map.get(label, "black")
        title = title_pfx + f": {cn}" if cn != "" else title_pfx
        if timestamps is None:
            timestamps = list(range(len(signal)))
            x_units = "Time Cadence Index"

        fig, axs = plt.subplots(nrows=2, ncols=1, figsize=figsize, sharex=True)
        axs[0].plot(
            timestamps,
            signal,
            color=color,
        )
        axs[0].set_ylabel(y_units)
        axs[1].scatter(
            timestamps,
            signal,
            marker=4,
            color=color,
        )
        axs[1].set_ylabel(y_units)
        plt.xlabel(x_units)
        plt.suptitle(title)
        fig.tight_layout()
        if self.save_png:
            # only append the extension when missing (avoids "flux_signal.png.png")
            fpath = self._with_png_ext(os.path.join(self.output_dir, fname))
            fig.savefig(fpath, dpi=300)
        if self.show:
            plt.show()
        else:
            plt.close(fig)

    def signal_phase_folder(self, file_list, fmt="kepler.fits", error=True, snr=True, include_extra=False):
        """Generates phase-folded light curves from LLC/LCF flux signals.

        The flux column read from each file is `self.flux_col`.

        Parameters
        ----------
        file_list : list
            list of FITS file path(s) containing time series data
        fmt : str, optional
            'kepler.fits' or 'tess.fits', by default "kepler.fits"
        error : bool, optional
            include flux error column (`{flux_col}_err`) if available, by default True
        snr : bool, optional
            apply signal-to-noise-ratio to periodogram autopower calculation, by default True
        include_extra : bool, optional
            also store "lc_start", "lc_end", "transit", "maxpower" (index of the
            periodogram peak), "mean", "median", "stddev" and "fname" (file
            basename), by default False

        Returns
        -------
        pd.DataFrame
            one row per file with columns "tid", "sc", "period", "time_jd",
            "flux_norm", "time_bin_start", "flux_norm_binned" (plus the extra
            columns when `include_extra` is set)
        """
        transits = {}
        for index, file in enumerate(file_list):
            res = {}
            fname = os.path.basename(file)
            (tid, sc) = self.parse_filename(fname, fmt=fmt)
            ts = TimeSeries.read(file, format=fmt)  # read in timeseries
            # add to meta dict
            res["tid"] = tid
            res["sc"] = sc
            # use box least squares to estimate period
            if error is True and f"{self.flux_col}_err" in ts.columns:
                periodogram = BoxLeastSquares.from_timeseries(ts, self.flux_col, f"{self.flux_col}_err")
            else:
                periodogram = BoxLeastSquares.from_timeseries(ts, self.flux_col)
            if snr is True:
                results = periodogram.autopower(0.2 * u.day, objective="snr")
            else:
                results = periodogram.autopower(0.2 * u.day)
            maxpower = np.argmax(results.power)
            period = results.period[maxpower]
            res["period"] = period
            transit_time = results.transit_time[maxpower]
            # fold the time series using the period
            ts_folded = ts.fold(period=period, epoch_time=transit_time)
            res["time_jd"] = ts_folded.time.jd
            # normalize the flux by sigma-clipping the data to determine the baseline flux:
            mean, median, stddev = sigma_clipped_stats(ts_folded[self.flux_col])
            ts_folded["flux_norm"] = ts_folded[self.flux_col] / median
            res["flux_norm"] = ts_folded["flux_norm"]
            # downsample the time series by binning the points into bins of equal time
            ts_binned = aggregate_downsample(ts_folded, time_bin_size=0.03 * u.day)
            res["time_bin_start"] = ts_binned.time_bin_start.jd
            res["flux_norm_binned"] = ts_binned["flux_norm"]
            if include_extra:
                res["lc_start"] = ts.time.jd[0]
                res["lc_end"] = ts.time.jd[-1]
                res["transit"] = transit_time
                res["maxpower"] = maxpower
                res["mean"] = mean
                res["median"] = median
                res["stddev"] = stddev
                res["fname"] = fname
            transits[index] = res
        df = pd.DataFrame.from_dict(transits, orient="index")
        return df

    def plot_phase_signals(self, ts, title_pfx="Phase-folded Light Curve: ", figsize=(11, 5)):
        """Plots a phase-folded light curve from timeseries flux signal data.
        Requires a dataframe row containing the following columns:
        "time_jd", "flux_norm", "time_bin_start", "flux_norm_binned", "tid", "sc", "period"
        e.g.,
        df = SignalPlots.signal_phase_folder(file_list)
        ts = df.iloc[index]
        signal_plots.plot_phase_signals(ts)

        Parameters
        ----------
        ts : ArrayLike
            timeseries flux signal data
        title_pfx : str, optional
            Plot title prefix, by default "Phase-folded Light Curve: "
        figsize : tuple, optional
            figure size, by default (11,5)
        """
        fig = plt.figure(figsize=figsize)
        ax = fig.gca()
        ax.plot(ts["time_jd"], ts["flux_norm"], "k.", markersize=1)
        ax.plot(
            ts["time_bin_start"],
            ts["flux_norm_binned"],
            "r-",
            drawstyle="steps-post",
        )
        ax.set_xlabel("Time (days)")
        ax.set_ylabel("Normalized flux")
        ax.set_title(title_pfx + str(ts["tid"]))
        ax.legend([np.round(ts["period"], 3)])
        if self.save_png:
            fpath = os.path.join(self.output_dir, f"{ts['sc']}-{ts['tid']}_phase_folded.png")
            fig.savefig(fpath, dpi=300)
        if self.show:
            plt.show()
        else:
            plt.close(fig)

    def set_spec_kwargs(self, Fs=2, NFFT=256, noverlap=128, mode="psd", cmap="binary"):
        """returns dict of default spectrogram kwargs

        Returns
        -------
        dict
            default spectrogram kwargs
        """
        spec_kwargs = {
            "Fs": Fs,
            "NFFT": NFFT,
            "noverlap": noverlap,
            "mode": mode,
            "cmap": cmap,
        }
        return spec_kwargs

    def flux_specs(
        self,
        signal,
        units=("Wavelength (λ)", "Frequency (ν)"),
        colorbar=True,
        save_for_ml=False,
        fname="specgram",
        title="Spectrogram",
        **kwargs,
    ):
        """generate and save spectrograms of flux signal frequencies. By default uses kwargs
        in `set_spec_kwargs` method.

        Parameters
        ----------
        signal : ArrayLike
            1D array-like signal data
        units : sequence of strings, optional
            x and y units respectively, by default ("Wavelength (λ)", "Frequency (ν)")
        colorbar : bool, optional
            include colorbar in plot, by default True
        save_for_ml : bool, optional
            plots pixel grid only (no axes, colorbar or labels), by default False
        fname : str, optional
            filename without extension for saving png, by default 'specgram'
        title : str, optional
            plot title, by default "Spectrogram"
        **kwargs : dict
            keyword arguments accepted by `set_spec_kwargs`
            (Fs, NFFT, noverlap, mode, cmap)

        Returns
        -------
        tuple
            spectrum, freqs, t, im - see matplotlib.pyplot.specgram
        """
        fpath = os.path.join(self.output_dir, fname)
        spec_kwargs = self.set_spec_kwargs(**kwargs)
        decorate = save_for_ml is not True
        if decorate:
            fig, ax = plt.subplots(figsize=(13, 11))
        else:
            fig, ax = plt.subplots(figsize=(10, 10), frameon=False)
            ax.axis("off")
        # The spectrogram must be drawn first: the colorbar needs its image as
        # the mappable, and calling colorbar on empty axes raises an error.
        spectrum, freqs, t, im = ax.specgram(signal, **spec_kwargs)
        if decorate:
            if colorbar:
                fig.colorbar(im, ax=ax)
            if units is None or len(units) < 2:
                units = ("Wavelength (λ)", "Frequency (ν)")
            ax.set_xlabel(units[0])
            ax.set_ylabel(units[1])
            ax.set_title(title)
        if self.save_png:
            fig.savefig(fpath, dpi=300)
        if self.show:
            plt.show()
        else:
            plt.close(fig)
        return spectrum, freqs, t, im


class K2SignalPlots(SignalPlots):
    """Class for plotting K2 time series signals and their spectrograms."""

    def __init__(
        self,
        flux_col="pdcsap_flux",
        show=False,
        save_png=True,
        target_cns=None,
        color_map=None,
        **log_kws,
    ):
        """Configure signal plotting for K2 light curve files.

        Parameters
        ----------
        flux_col : str, optional
            light curve column containing the flux signal, by default "pdcsap_flux"
        show : bool, optional
            display plot, by default False
        save_png : bool, optional
            save plot as PNG file, by default True
        target_cns : dict, optional
            target label and string keypairs, by default None
            which resolves to {1: "No Planet", 2: "Planet"}
        color_map : dict, optional
            target label and color keypairs, by default None
            which resolves to {1: "red", 2: "blue"}
        **log_kws : dict
            additional keyword arguments forwarded to ``SignalPlots``
            (e.g. `output_dir`) and on to the logger
        """
        super().__init__(
            show=show,
            save_png=save_png,
            target_cns={1: "No Planet", 2: "Planet"} if target_cns is None else target_cns,
            color_map={1: "red", 2: "blue"} if color_map is None else color_map,
            name="K2SignalPlots",
            **log_kws,
        )
        # flux_col is not a SignalPlots.__init__ parameter; set it here instead
        # of letting it leak into the logger keyword arguments.
        self.flux_col = flux_col
        self.df = None
        self.files = []

    def _source_path(self, fname):
        """Return the entry of ``self.files`` whose basename is ``fname`` (or ``fname`` if none)."""
        for path in self.files:
            if os.path.basename(path) == fname:
                return path
        return fname

    def generate_dataframe(self):
        """Generates dataframe of K2 light curve signal properties from list of FITS files

        Raises
        ------
        ValueError
            if `self.files` is empty
        """
        if len(self.files) == 0:
            raise ValueError("No files provided. Set `self.files` to a list of K2 FITS light curve file paths.")
        self.df = self.signal_phase_folder(self.files, fmt="kepler.fits", error=True, snr=True, include_extra=True)

    def generate_raw_flux_df(self, flux_col="SAP_FLUX", add_label=None, ffillna=True):
        """Generates dataframe of raw flux signals from list of K2 FITS files

        Parameters
        ----------
        flux_col : str, optional
            column read from the first extension table of each file, by default "SAP_FLUX"
        add_label : int, optional
            when an int, a leading "LABEL" column with this value is added, by default None
        ffillna : bool, optional
            forward-fill missing values along each row, by default True

        Returns
        -------
        pd.DataFrame
            one row per file with columns "FLUX.1", "FLUX.2", ...

        Raises
        ------
        ValueError
            if `self.files` is empty
        """
        if len(self.files) == 0:
            raise ValueError("No files provided. Set `self.files` to a list of K2 FITS light curve file paths.")
        records = {}
        for index, file in enumerate(self.files):
            with fits.open(file) as hdulist:
                signal = hdulist[1].data[flux_col]
                records[index] = np.asarray(signal, dtype="float64")
        df = pd.DataFrame.from_dict(records, orient="index")
        if ffillna is True:
            df = df.ffill(axis=1)
        df.columns = [f"FLUX.{c + 1}" for c in df.columns]
        if isinstance(add_label, int):
            cols = list(df.columns)
            df["LABEL"] = add_label
            df = df[["LABEL"] + cols]
        return df

    def generate_specs(self, ml_ready=False, rgb=True):
        """Generates spectrograms for each light curve signal in dataframe

        Parameters
        ----------
        ml_ready : bool, optional
            passed to `flux_specs` as `save_for_ml`, by default False
        rgb : bool, optional
            use the "plasma" colormap instead of the `set_spec_kwargs`
            default, by default True
        """
        if self.df is None:
            self.generate_dataframe()
        kwargs = self.set_spec_kwargs(cmap="plasma") if rgb is True else self.set_spec_kwargs()
        for _, row in self.df.iterrows():
            fname = row["fname"].replace(".fits", "_specgram")
            # row["fname"] is a basename; read from the original path in self.files
            _, flux = self.read_ts_signal(
                self._source_path(row["fname"]), fmt="kepler.fits", offset=True, remove_nans=True
            )
            self.flux_specs(
                flux,
                save_for_ml=ml_ready,
                fname=fname,
                title=f"Spectrogram: {row['sc']}-{row['tid']}",
                **kwargs,
            )

    def generate_phase_signal_plots(self):
        """Generates phase-folded light curve plots for each signal in dataframe"""
        if self.df is None:
            self.generate_dataframe()
        for _, ts in self.df.iterrows():
            self.plot_phase_signals(ts, title_pfx="K2 Phase-folded Light Curve: ", figsize=(11, 5))

    def generate_flux_signal_plots(self):
        """Generates atomic vector plots for each signal in dataframe"""
        if self.df is None:
            self.generate_dataframe()
        for _, row in self.df.iterrows():
            fname = row["fname"].replace(".fits", "_flux_signal")
            # row["fname"] is a basename; read from the original path in self.files
            timestamps, flux = self.read_ts_signal(
                self._source_path(row["fname"]), fmt="kepler.fits", offset=True, remove_nans=True
            )
            self.atomic_vector_plotter(
                flux,
                timestamps=timestamps,
                y_units="PDCSAP Flux (e-/s)",
                x_units="Time (BJD)",
                figsize=(15, 10),
                fname=fname,
                title_pfx=f"K2 Flux Signal: {row['sc']}-{row['tid']}",
            )


# testing
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", type=str, help="path to dataframe (csv file)")
    # nargs="?" makes the positional optional so that its default can apply
    parser.add_argument("index", type=str, nargs="?", default="index", help="index column name")
    parser.add_argument("-e", "--example", type=str, choices=["svm", "cal"], help="run example demo")
    args = parser.parse_args()
    dataset = args.dataset
    index = args.index
    example = args.example
    df = pd.read_csv(dataset, index_col=index)
    if example == "svm":
        # Drop extra columns in case raw / un-preprocessed dataset is loaded
        drops = ["category", "ra_targ", "dec_targ", "imgname"]
        df.drop([c for c in drops if c in df.columns], axis=1, inplace=True)
        svm = HstSvmPlots(df)
    else:
        print("More examples coming soon!")