Source code for BuildYourOwnEmbedding.responses

from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
import numpy.typing as npt
from typing import Union, Dict, Tuple, Any, Type, Callable, List
import inspect
import itertools
import mplcursors
from matplotlib.widgets import Slider, Button
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import warnings

from .parameters import Parameter
from .functional import inverse_correlation, plot_rdm, mds, \
    pca_variance_explained, pca, fisher_information as fi

##! =====================================
##!             Responses:
##! =====================================


[docs] class ResponseFunction(ABC): """ Base class for defining neural responses. The `ResponseFunction` class serves as an abstract base class for creating neural response functions that can be evaluated based on input data. This class provides a standardised interface for defining neural response curves, allowing users to implement custom responses by extending this class and implementing the `evaluate` method. It also supports mathematical operations to combine responses, creating composite response functions. Attributes: params (Dict): A dictionary containing response-specific parameters used to define the behavior of the response function. Methods: evaluate(x: npt.NDArray) -> npt.NDArray: Abstract method to evaluate the neural response for a given input. __call__(x: npt.NDArray) -> npt.NDArray: Calls the `evaluate` method to compute the response for the given input. __add__(other: ResponseFunction) -> CompositeResponse: Combines this response with another response using addition. __sub__(other: ResponseFunction) -> CompositeResponse: Combines this response with another response using subtraction. __mul__(other: ResponseFunction) -> CompositeResponse: Combines this response with another response using multiplication. __str__() -> str: Returns a string representation of the response function. """ def __init__(self, **kwargs: Any) -> None: """ Initialises the response function with specific parameters. This constructor accepts named arguments for the parameters that define the behavior of the neural response function. These parameters are stored in a dictionary for later use by the response function. Args: **kwargs (Any): Named arguments representing the parameters of the response function. Notes: Subclasses should provide specific parameter requirements, which will be documented in their respective class definitions. """ self.params: Dict = kwargs # Store all parameters for general use
[docs] @abstractmethod def evaluate(self, x: npt.NDArray) -> npt.NDArray: """ Evaluate the response of the neuron for a given input. This abstract method must be implemented by subclasses to define how the response function computes output values based on input data. Args: x (npt.NDArray): The input data for which the response is evaluated. This should be a Numpy array representing the stimulus or feature input to the neuron. Returns: npt.NDArray: A Numpy array representing the evaluated response values corresponding to the input data. Raises: NotImplementedError: This method must be overridden by subclasses, otherwise, calling it will raise this error. """ pass
def __call__( self, x: npt.NDArray, noiseLevel: npt.number = 0, noiseType: str = "Gaussian" ) -> npt.NDArray: """ Compute the response by calling the evaluate method. This method allows the object to be used as a callable function, directly invoking the `evaluate` method when the instance is called with input data. Args: x (npt.NDArray): The input data for which the response is evaluated. noiseLevel (npt.number, optional): The standard deviation of Gaussian noise to be added to the response. Defaults to 0 (no noise). noiseType (str, optional): The type of noise to add. Defaults to 'Gaussian'. Returns: npt.NDArray: The response values computed by the `evaluate` method. """ response = self.evaluate(x) if noiseLevel == 0: return response if noiseType == "Gaussian": return response + np.random.normal(0, noiseLevel, response.shape) # Unknown noise type warnings.warn(f"Unknown noise type {noiseType}. Defaulting to Gaussian noise.") return response + np.random.normal(0, noiseLevel, response.shape) def __add__(self, other: ResponseFunction) -> CompositeResponse: """ Add another response function to this response function. This method allows the creation of a composite response by adding the outputs of two different response functions. The resulting composite response is represented by a new instance of the `CompositeResponse` class. Args: other (ResponseFunction): The other response function to be added to this response. Returns: CompositeResponse: A new `CompositeResponse` instance representing the combined response of this response and the other response. """ return CompositeResponse(self, other, np.add) def __sub__(self, other: ResponseFunction) -> CompositeResponse: """ Subtract another response function from this response function. This method allows the creation of a composite response by subtracting the outputs of another response function from this response. The resulting composite response is represented by a new instance of the `CompositeResponse` class. Args: other (ResponseFunction): The other response function to be subtracted from this response. Returns: CompositeResponse: A new `CompositeResponse` instance representing the difference between this response and the other response. """ return CompositeResponse(self, other, np.subtract) def __mul__(self, other: ResponseFunction) -> CompositeResponse: """ Multiply this response function by another response function. This method allows the creation of a composite response by multiplying the outputs of two different response functions. The resulting composite response is represented by a new instance of the `CompositeResponse` class. Args: other (ResponseFunction): The other response function to multiply with this response. Returns: CompositeResponse: A new `CompositeResponse` instance representing the product of this response and the other response. """ return CompositeResponse(self, other, np.multiply) def __str__(self) -> str: """ Return a string representation of the response function. This method provides a human-readable representation of the response function, including its class name and the parameters used to define its behavior. Returns: str: A string representation of the response function, displaying the class name and its parameter values. """ return f"{type(self).__name__}, {self.params}"
[docs] class CompositeResponse(ResponseFunction): """ Composite response class that combines two responses using a specified operation. """ def __init__( self, response1: ResponseFunction, response2: ResponseFunction, operation: Callable[[npt.NDArray, npt.NDArray], npt.NDArray], ) -> None: """ Initialises a composite response function from two individual responses and an operation. Args: response1 (ResponseFunction): First response response2 (ResponseFunction): Second response operation (Callable[[npt.NDArray, npt.NDArray], npt.NDArray]): An operation applied to the outputs of the individual responses """ super().__init__() self.response1: ResponseFunction = response1 self.response2: ResponseFunction = response2 self.operation = operation
[docs] def evaluate(self, x: npt.NDArray) -> npt.NDArray: """ Evaluates the CompositeResponse for a given input stimuli. Args: x (npt.NDArray): The input stimuli for which to generate a response. Returns: npt.NDArray: The response of the CompositeResponse function to the stimuli `x`. """ return self.operation(self.response1(x), self.response2(x))
def __str__(self) -> str: return ( f"{self.operation.__name__}({str(self.response1)}, {str(self.response2)})" )
[docs] class GaussianResponse(ResponseFunction): """GaussianResponse class. Is the response function for a 1-dimensional gaussian response model. """ def __init__(self, mean: npt.number, std: npt.number) -> None: """Constructor method for the GaussianResponse response function. Args: mean (npt.number): The mean value for the desired Gaussian response function. std (npt.number): The standard deviation of the desired Gaussian response function. """ super().__init__(mean=mean, std=std)
[docs] def evaluate( self, x: npt.NDArray, dtype: npt.DTypeLike = np.float64 ) -> npt.NDArray: """Evaluates the GaussianResponse response function for a given stimuli. Args: x (npt.NDArray): The stimuli for which to generate a response to. dtype (npt.DTypeLike, optional): The desired data type of the output. Defaults to np.float64. Returns: npt.NDArray: A numpy array containing the values of the Response to the stimulus `x`. """ mean = self.params["mean"] std = self.params["std"] return dtype( np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi)) )
[docs] class SigmoidResponse(ResponseFunction): """SigmoidResponse class. Is the response function for a 1-dimensional sigmoid response model. """ def __init__(self, alpha: npt.number, beta: npt.number) -> None: """Constructor for SigmoidResponse response function. Args: alpha (npt.number): The alpha value (steepness of slope) of the desired sigmoid response function. beta (npt.number): The beta value (pivot point) of the desired sigmoid response function. """ super().__init__(alpha=alpha, beta=beta)
[docs] def evaluate( self, x: npt.NDArray, dtype: npt.DTypeLike = np.float64 ) -> npt.NDArray: """Evaluates the SigmoidResponse response function for a given stimuli. Args: x (npt.NDArray): The stimuli for which to generate a response to. dtype (npt.DTypeLike, optional): The desired data type of the output. Defaults to np.float64. Returns: npt.NDArray: A numpy array containing the values of the Response to the stimulus `x`. """ alpha = self.params["alpha"] beta = self.params["beta"] return dtype(1 / (1 + np.exp(-alpha * (x - beta))))
[docs] class VonMisesResponse(ResponseFunction): """VonMisesResponse class. Is the response function for a 1-dimensional Von-mises response model. """ def __init__(self, kappa: npt.number, theta: npt.number) -> None: """Constructor for the VonMisesResponse response function. Args: kappa (npt.number): The kappa value for the desired von-mises response. theta (npt.number): The theta value for the desired von-mises response. """ super().__init__(kappa=kappa, theta=theta) def _I0(self, kappa: npt.number, numTerms: int = 50) -> npt.number: # modified bessel function of the first kind of order 0, calculated with a Maclaurin series expansion result = 0 cumulativeFactorial = 1 for i in range(1, numTerms + 1): cumulativeFactorial *= i result += (kappa / 2) ** (2 * i) / (cumulativeFactorial**2) return result
[docs] def evaluate( self, x: npt.NDArray, dtype: npt.DTypeLike = np.float64 ) -> npt.NDArray: """Evaluates a Von-mises response for a given stimuli Args: x (npt.NDArray): The stimuli for which to generate a response to. dtype (npt.DTypeLike, optional): The desired data type of the output. Defaults to np.float64. Returns: npt.NDArray: A numpy array containing the values of the Response to the stimulus `x`. """ kappa = self.params["kappa"] theta = self.params["theta"] return np.exp(kappa * np.cos(x - theta)) / (2 * np.pi * self._I0(kappa))
[docs] class GaussianResponse2D(ResponseFunction): def __init__( self, xMean: npt.number, yMean: npt.number, xStd: npt.number, yStd: npt.number ) -> None: """Constructor for 2D Gaussian response function. Args: xMean (npt.number): The mean of the desired Gaussian response in the x-axis. yMean (npt.number): The standard deviation of the desired Gaussian response in the x-axis. xStd (npt.number): The mean of the desired Gaussian response in the y-axis. yStd (npt.number): The standard deviation of the desired Gaussian response in the y-axis. """ super().__init__(xMean=xMean, yMean=yMean, xStd=xStd, yStd=yStd)
[docs] def evaluate( self, x: npt.NDArray, dtype: npt.DTypeLike = np.float64 ) -> npt.NDArray: """Evaluates the GaussianResponse2D response function for a given stimuli. Args: x (npt.NDArray): The stimuli for which to generate a response to. This stimuli should be a 2-dimensional numpy array. dtype (npt.DTypeLike, optional): The desired data type of the output. Defaults to np.float64. Returns: npt.NDArray: A numpy array containing the values of the Response to the stimulus `x`. """ xVals, yVals = x # unpack return dtype( np.exp( -( ((xVals - self.params["xMean"]) ** 2) / (2 * self.params["xStd"] ** 2) + ((yVals - self.params["yMean"]) ** 2) / (2 * self.params["yStd"] ** 2) ) ) )
##! ===================================== ##! Response Manager: ##! ===================================== # this class will generate sets of responses:
[docs] class ResponseManager: def __init__( self, responseClasses: Union[ Type[ResponseFunction], Tuple[Type[ResponseFunction], ...] ], **parameters: Parameter, ) -> None: """ Initialises response manager class with specific response type and parameters. Args: responseClasses (Type[ResponseFunction]): A single response class or a tuple of response classes to generate responses for. **param_strategies (Dict[str, Parameter]): Parameter strategies for each response class. """ self.responseClasses: Type[ResponseFunction] = responseClasses self.parameters: Dict[str, Parameter] = parameters def _get_response_parameters(self) -> List[str]: """ Gets the names of the parameters from the response class' __init__ method Returns: List[str]: A list of parameter names """ return list(inspect.signature(self.responseClasses.__init__).parameters.keys())[ 1: ] def _get_combinations(self) -> List[Dict[str, Any]]: """ Gets all combinations of parameter values based on the input parameters Returns: List[Dict[str, Any]]: A list of dictionaries containing all combinations of parameter values """ # paramValuesDict = {str(i): param.get_values() for i, param in enumerate(self.parameters)} #! paramValuesDict = self.parameters paramNames = self._get_response_parameters() paramCombinations = itertools.product( *[p.get_values() for p in [*paramValuesDict.values()]] ) return [dict(zip(paramNames, paramCombo)) for paramCombo in paramCombinations] def _generate_response( self, x: npt.NDArray, responseParams: Dict[str, Any], noiseLevel: npt.number = 0 ) -> ResponseData: """ Generates a single response based on the provided parameter values Args: x (npt.NDArray): The input data for which the response is generated responseParams (Dict[str, Any]): Parameter values for the response noiseLevel (npt.number, optional): The standard deviation of Gaussian noise to be added to the response. Defaults to 0 (no noise). Returns: ResponseData: A ResponseData object containing the response values and the corresponding parameters. """ response = self.responseClasses(**responseParams) responseValues = response(x, noiseLevel) return ResponseData( params=responseParams, response=responseValues, responseFunction=self.responseClasses.__name__, x=x, )
[docs] def generate_responses( self, x: npt.NDArray, noiseLevel: npt.number = 0, numSamples: int = 1 ) -> ResponseSet: """ Generates responses for each combination of parameter values Args: x (npt.NDArray): The input data for which the response is generated noiseLevel (npt.number, optional): The standard deviation of Gaussian noise to be added to the response. Defaults to 0 (no noise). numSamples (int, optional): The number of times to sample each combination of parameters. Defaults to 1. Returns: ResponseSet: A ResponseSet object containing each response and the corresponding parameters. """ paramCombinations = self._get_combinations() responses = [] for params in paramCombinations: for _ in range(numSamples): responses.append(self._generate_response(x, params, noiseLevel)) return ResponseSet(responses=responses)
[docs] @dataclass class ResponseData: """ Data class for storing a single response output and the parameters that created it. Attributes: params (Dict[str, np.number]): The parameters used for generating the response. response (np.ndarray): The response values. responseFunction (str): The name of the Response Function that generated this Response """ params: Dict[str, npt.number] response: npt.NDArray responseFunction: str x: npt.NDArray def __str__(self) -> str: return f"{self.responseFunction}:\n" + "\n".join( f"{paramName} = {self.params[paramName]}" for paramName in self.params.keys() )
# stores a set of responses
[docs] class ResponseSet: """ A class for storing and analysing a set of neural responses. The ResponseSet class provides methods to compute representational dissimilarity matrices (RDM) and representational geodesic topological matrices (RGTM), as well as functions for visualising and analysing responses. It supports both 1D and 2D response visualisations and also allows interactive plots. Attributes: responses (List[ResponseData]): A list of ResponseData objects that store the neural responses and their associated parameters. """ def __init__(self, responses: Union[None, List[ResponseData]] = None) -> None: """ Initialise the ResponseSet with a list of response objects. Args: responses (Union[None, List[ResponseData]], optional): A list of ResponseData objects representing neural responses. Defaults to None. """ if responses: self.responses = responses else: self.responses = []
[docs] def compute_rdm( self, dissimilarityMetric: Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ] = inverse_correlation, ) -> npt.ArrayLike: """ Computes the Representational Dissimilarity Matrix (RDM) for the stored responses Args: dissimilarityMetric (Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ], optional): A function to compute dissimilarity between two responses. Defaults to inverse_correlation. Returns: npt.ArrayLike: A square RDM where each entry (i, j) represents the dissimilarity between the i-th and j-th responses. """ n: int = len(self.responses) rdm: npt.NDArray = np.zeros((n, n)) for i in range(n): for j in range(i + 1, n): rdm[i, j] = rdm[j, i] = dissimilarityMetric( self.responses[i].response, self.responses[j].response ) return rdm
[docs] def compute_rgtm( self, lowerBound: npt.number, upperBound: npt.number, dissimilarityMetric: Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ] = inverse_correlation, ) -> npt.ArrayLike: """ Computes the Representational Geodesic Topological Matrix (RGTM). The RGTM is computed by applying a piecewise linear transformation on the dissimilarity matrix based on the provided lower and upper bounds. Args: lowerBound (npt.number): The lower bound for the transformation. upperBound (npt.number): The upper bound for the transformation. dissimilarityMetric (Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ], optional): A function to compute dissimilarity between two responses. Defaults to inverse_correlation. Returns: npt.ArrayLike: The computed RGTM matrix. """ assert lowerBound <= upperBound n: int = len(self.responses) rgtm: npt.NDArray = np.zeros((n, n)) for i in range(n): for j in range(i + 1, n): # compute dissimilarity: dissimilarity = dissimilarityMetric( self.responses[i].response, self.responses[j].response ) # compute geo-topological transform: if dissimilarity <= lowerBound: transformedDissimilarity = 0 elif dissimilarity >= upperBound: transformedDissimilarity = 1 else: transformedDissimilarity = (dissimilarity - lowerBound) / ( upperBound - lowerBound ) rgtm[i, j] = rgtm[j, i] = transformedDissimilarity return rgtm
[docs] def plot_rgtm( self, lowerBound: npt.number = 0, upperBound: npt.number = 1, dissimilarityMetric: Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ] = inverse_correlation, interactive: bool = True, labels: Union[List[str], None] = None, cmap: str = "viridis", title: str = "Representational Geo-Topological Matrix", figsize: Tuple[int] = (7, 7), dissimilarityLabel: str = "Dissimilarity", ) -> None: """ Plots the Representational Geo-Topological Matrix (RGTM) with optional interactivity. Args: lowerBound (npt.number, optional): The lower bound for the RGTM transformation. Defaults to 0. upperBound (npt.number, optional): The upper bound for the RGTM transformation. Defaults to 1. dissimilarityMetric (Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ], optional): Function to compute dissimilarity between two responses. Defaults to inverse_correlation. interactive (bool, optional): Whether to include sliders for adjusting the bounds interactively. Defaults to True. labels (Union[List[str], None], optional): List of labels for the axes. Defaults to None. cmap (str, optional): Colourmap used for plotting. Defaults to "viridis". title (str, optional): Title of the plot. Defaults to "Representational Geo-Topological Matrix". figsize (Tuple[int], optional): Size of the figure. Defaults to (7, 7). dissimilarityLabel (str, optional): Label for the colour bar. Defaults to "Dissimilarity". """ rgtm: npt.NDArray = self.compute_rgtm( lowerBound, upperBound, dissimilarityMetric ) # create the plot: fig, ax = plt.subplots(figsize=figsize) # display initial RGTM: heatmap = ax.imshow(rgtm, cmap=cmap, vmin=0, vmax=1) # ? plt.colorbar(heatmap, ax=ax, label=dissimilarityLabel) ax.set_title(title) if labels: ax.set_xticks(labels) ax.set_yticks(labels) # if interactive, add interactivity behaviour: if interactive: # reposition plot to account for sliders: plt.subplots_adjust(left=0.25, bottom=0.25) # add axes for the sliders: lowerBoundSliderAx = plt.axes( [0.25, 0.1, 0.65, 0.03], facecolor="lightgoldenrodyellow" ) upperBoundSliderAx = plt.axes( [0.25, 0.15, 0.65, 0.03], facecolor="lightgoldenrodyellow" ) # add the sliders: lowerBoundSlider = Slider( ax=lowerBoundSliderAx, label="Lower Bound (l)", valmin=0, valmax=1, valinit=lowerBound, ) upperBoundSlider = Slider( ax=upperBoundSliderAx, label="Upper Bound (u)", valmin=0, valmax=1, valinit=upperBound, ) # define update function for slider on change event: def onChange(*args, **kwargs): # re-compute the RGTM: newLowerBound = lowerBoundSlider.val newUpperBound = upperBoundSlider.val if newLowerBound > newUpperBound: # display error message: ax.set_title( "Error: Upper bound must be greater than or equal to the lower bound", color="red", ) return # don't plot if bounds are invalid # compute new RGTM: newRgtm = self.compute_rgtm( newLowerBound, newUpperBound, dissimilarityMetric ) # update plot: heatmap.set_data(newRgtm) ax.set_title(title, color="black") fig.canvas.draw_idle() # add on change function to both sliders lowerBoundSlider.on_changed(onChange) upperBoundSlider.on_changed(onChange) # show the plot: plt.show()
[docs] def plot_rdm( self, dissimilarityMetric: Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ] = inverse_correlation, labels: Union[List[str], None] = None, cmap: str = "viridis", title: str = None, figsize: Tuple[int] = (7, 7), dissimilarityLabel: str = "Dissimilarity", ) -> None: """ Plots the Representational Dissimilarity Matrix (RDM) for the stored responses. Args: dissimilarityMetric (Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ], optional): Function to compute dissimilarity between two responses. Defaults to inverse_correlation. labels (Union[List[str], None], optional): Labels for the plot's x and y axes. Defaults to None. cmap (str, optional): Colourmap for the plot. Defaults to "viridis". title (str, optional): Title of the plot. Defaults to None. figsize (Tuple[int], optional): Size of the figure. Defaults to (7, 7). dissimilarityLabel (str, optional): Label for the colour bar. Defaults to "Dissimilarity". """ rdm = self.compute_rdm(dissimilarityMetric) plot_rdm(rdm, labels, cmap, title, figsize, dissimilarityLabel)
[docs] def plot_responses( self, figsize: Tuple[int] = (7, 7), xlabel: str = "Stimuli", ylabel: str = "Response", title: str = "Responses", grid: bool = False, hoverEffects: bool = True, # if false will not add hover effect to plot *args, **kwargs, ) -> None: """ Plots the responses in 1D or 2D format. Higher dimensional plots are not currently supported. Args: figsize (Tuple[int], optional): Size of the figure. Defaults to (7, 7). xlabel (str, optional): Label for the x-axis. Defaults to "Stimuli". ylabel (str, optional): Label for the y-axis. Defaults to "Response". title (str, optional): Title of the plot. Defaults to "Responses". grid (bool, optional): Whether to display a grid on the plot. Defaults to False. hoverEffects (bool, optional): If True, adds hover effects to the plot. Defaults to True. Raises: NotImplementedError: If the ResponseData represents 3-dimensional or higher curves. """ responseShape = list(self.responses[0].response.shape) # plot 1-d responses: if len(responseShape) == 1: self._plot_responses_1d( figsize, xlabel, ylabel, title, grid, hoverEffects, *args, **kwargs ) # plot 2-d responses: elif len(responseShape) == 2: self._plot_responses_2d( figsize=figsize, xlabel=xlabel, ylabel=ylabel, title=title, *args, **kwargs, ) else: raise NotImplementedError( "plotting responses higher than two dimensions is not currently supported" )
def _plot_responses_1d( self, figsize: Tuple[int] = (7, 7), xlabel: str = "Stimuli", ylabel: str = "Response", title: str = "Responses", grid: bool = False, hoverEffects: bool = True, # if false will not add hover effect to plot ) -> None: # plot the responses: plt.figure(figsize=figsize) plottedResponses = [] for i in range(len(self.responses)): (plottedResponse,) = plt.plot( self.responses[i].x, self.responses[i].response ) plottedResponses.append(plottedResponse) plt.xlabel(xlabel=xlabel) plt.ylabel(ylabel=ylabel) plt.title(label=title) plt.grid(visible=grid) if hoverEffects: cursor = mplcursors.cursor(hover=True) @cursor.connect("add") def on_add(selectedResponse) -> None: curve = selectedResponse.artist idx = plottedResponses.index(curve) # set annotation to show formatted parameters: # todo selectedResponse.annotation.set(text=f"{str(self.responses[idx])}") selectedResponse.annotation.get_bbox_patch().set( fc="white", alpha=0.8 ) # todo: make these parameters variable plt.show() def _plot_responses_2d( self, figsize: Tuple[int] = (7, 7), xlabel: str = "Stimuli 1", zlabel: str = "Stimuli 2", ylabel: str = "Response", title: str = "2D Tuning Curve Visualisation", cmap: str = "viridis", type: str = "heatmap", # can be "heatmap" or "surfaceplot" ) -> None: if type == "surfaceplot": # plot the responses: fig = plt.figure(figsize=figsize) ax = fig.add_subplot(111, projection="3d") self.currentResponseIndex = 0 # function to click through plot: def update_plot(responseIndex: int) -> None: ax.clear() response = self.responses[responseIndex].response x, y = self.responses[responseIndex].x ax.plot_surface(x, y, response, cmap=cmap) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_zlabel(zlabel) ax.set_title(title) plt.draw() update_plot(self.currentResponseIndex) # function for when next button is clicked: def next_response(e) -> None: self.currentResponseIndex = (self.currentResponseIndex + 1) % len( self.responses ) update_plot(self.currentResponseIndex) # function for when previous button is clicked: def prev_response(e) -> None: self.currentResponseIndex = (self.currentResponseIndex - 1) % len( self.responses ) update_plot(self.currentResponseIndex) # add buttons for interactivity: axNext = plt.axes([0.8, 0.05, 0.1, 0.075]) btnNext = Button(axNext, "Next") btnNext.on_clicked(next_response) axPrev = plt.axes([0.7, 0.05, 0.1, 0.075]) btnPrev = Button(axPrev, "Previous") btnPrev.on_clicked(prev_response) plt.show() elif type == "heatmap": # Plot the responses as a heatmap fig = plt.figure(figsize=figsize) ax = fig.add_subplot(111) self.currentResponseIndex = 0 from matplotlib.colorbar import Colorbar colourbar: Colorbar = None # function to update heatmap plot def update_plot(responseIndex: int) -> None: nonlocal colourbar # fig.clear() # ax = fig.add_subplot(111) response = self.responses[responseIndex].response x, y = self.responses[responseIndex].x c = ax.imshow( response, cmap=cmap, origin="lower", extent=[x.min(), x.max(), y.min(), y.max()], ) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) # colourbar = fig.colorbar(c, ax=ax) plt.draw() update_plot(self.currentResponseIndex) # function for when next button is clicked: def next_response(e) -> None: self.currentResponseIndex = (self.currentResponseIndex + 1) % len( self.responses ) update_plot(self.currentResponseIndex) # function for when previous button is clicked: def prev_response(e) -> None: self.currentResponseIndex = (self.currentResponseIndex - 1) % len( self.responses ) update_plot(self.currentResponseIndex) # add buttons for interactivity: axNext = plt.axes([0.8, 0, 0.1, 0.075]) btnNext = Button(axNext, "Next") btnNext.on_clicked(next_response) axPrev = plt.axes([0.7, 0, 0.1, 0.075]) btnPrev = Button(axPrev, "Previous") btnPrev.on_clicked(prev_response) plt.show()
[docs] def plot_3D_mds( self, dissimilarityMetric: Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ] = inverse_correlation, xlabel: str = "MDS Dimension 1", ylabel: str = "MDS Dimension 2", zlabel: str = "MDS Dimension 3", cmap: str = "viridis", cbarLabel: Union[str, None] = None, title: str = None, figsize: Tuple[int] = (7, 7), responseToColour: Union[Callable[[ResponseData], npt.number], str, None] = None, ) -> None: """ Plots a 3D Multidimensional Scaling (MDS) scatter plot for the stored responses. This method computes a 3D MDS embedding from the dissimilarity matrix of the responses and visualises the responses in a 3D scatter plot. The color of each point in the scatter plot can be customized based on response parameters or a user-defined function. Args: dissimilarityMetric (Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ], optional): A function to compute the dissimilarity between two responses. Defaults to inverse_correlation. xlabel (str, optional): Label for the x-axis of the MDS plot. Defaults to "MDS Dimension 1". ylabel (str, optional): Label for the y-axis of the MDS plot. Defaults to "MDS Dimension 2". zlabel (str, optional): Label for the z-axis of the MDS plot. Defaults to "MDS Dimension 3". cmap (str, optional): The colourmap used for coloring the scatter points. Defaults to "viridis". cbarLabel (Union[str, None], optional): Label for the colour bar. Defaults to None. title (str, optional): Title of the plot. Defaults to None. figsize (Tuple[int], optional): The size of the figure. Defaults to (7, 7). responseToColour (Union[Callable[[ResponseData], npt.number], str, None], optional): A function or string to define the color mapping for the scatter points. If a string is passed, it should be the name of a parameter in the responses. If a callable is passed, it should be a function that maps a `ResponseData` object to a numeric value for coloring the points. Defaults to None. """ # compute dissimilarity rdm = self.compute_rdm(dissimilarityMetric) # compute 3D MDS: mdsCoords = mds(rdm, nComponents=3) # display: fig = plt.figure(figsize=figsize) ax: Axes3D = fig.add_subplot(projection="3d") # get colours / values based on colorFunction (if it exists) if responseToColour: if type(responseToColour) == str: colours = [ response.params[responseToColour] for response in self.responses ] elif callable(responseToColour): colours = [responseToColour(response) for response in self.responses] else: # use first parameter value as default if no colour value is provided colours = [ next(iter(response.params.values())) for response in self.responses ] # create the scatter plot: scat = ax.scatter( mdsCoords[:, 0], mdsCoords[:, 1], mdsCoords[:, 2], c=colours, cmap=cmap, marker="o", s=100, ) # set labels and title: ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_zlabel(zlabel) ax.set_title(title) # add colour bar: cb = fig.colorbar(scat, ax=ax, shrink=0.5, aspect=10) # ? if cbarLabel: cb.set_label(cbarLabel) plt.show()
[docs] def plot_2D_mds( self, dissimilarityMetric: Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ] = inverse_correlation, xlabel: str = "MDS Dimension 1", ylabel: str = "MDS Dimension 2", cmap: str = "viridis", cbarLabel: Union[str, None] = None, title: str = None, figsize: Tuple[int] = (7, 7), responseToColour: Union[Callable[[ResponseData], npt.number], str, None] = None, ) -> None: """ Plots a 2D Multidimensional Scaling (MDS) scatter plot for the stored responses. This method computes a 2D MDS embedding from the dissimilarity matrix of the responses and visualises the responses in a 2D scatter plot. The color of each point in the scatter plot can be customized based on response parameters or a user-defined function. Args: dissimilarityMetric (Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ], optional): A function to compute the dissimilarity between two responses. Defaults to inverse_correlation. xlabel (str, optional): Label for the x-axis of the MDS plot. Defaults to "MDS Dimension 1". ylabel (str, optional): Label for the y-axis of the MDS plot. Defaults to "MDS Dimension 2". cmap (str, optional): The colourmap used for coloring the scatter points. Defaults to "viridis". cbarLabel (Union[str, None], optional): Label for the colour bar. Defaults to None. title (str, optional): Title of the plot. Defaults to None. figsize (Tuple[int], optional): The size of the figure. Defaults to (7, 7). responseToColour (Union[Callable[[ResponseData], npt.number], str, None], optional): A function or string to define the color mapping for the scatter points. If a string is passed, it should be the name of a parameter in the responses. If a callable is passed, it should be a function that maps a `ResponseData` object to a numeric value for coloring the points. Defaults to None. """ # compute dissimilarity rdm = self.compute_rdm(dissimilarityMetric) # compute 3D MDS: mdsCoords = mds(rdm, nComponents=2) # display: fig = plt.figure(figsize=figsize) ax = fig.figure.subplots(1, 1) # get colours / values based on colorFunction (if it exists) if responseToColour: if type(responseToColour) == str: colours = [ response.params[responseToColour] for response in self.responses ] elif callable(responseToColour): colours = [responseToColour(response) for response in self.responses] else: # use first parameter value as default if no colour value is provided colours = [ next(iter(response.params.values())) for response in self.responses ] # create the scatter plot: scat = ax.scatter( mdsCoords[:, 0], mdsCoords[:, 1], c=colours, cmap=cmap, marker="o", s=100, ) # set labels and title: ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) # add colour bar: cb = fig.colorbar(scat, ax=ax, shrink=0.5, aspect=10) # ? if cbarLabel: cb.set_label(cbarLabel) plt.show()
[docs] def plot_PCA_variance_explained( self, dissimilarityMetric: Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ] = inverse_correlation, xlabel: str = "Number of Principal Components", ylabel: str = "Variance Explained", title: str = "Variance Explained vs. Number of Principal Components", figsize: Tuple[int] = (7, 7), grid: bool = False, ) -> None: """ Plots the variance explained by each principal component from PCA on the dissimilarity matrix. This method computes the Principal Component Analysis (PCA) of the dissimilarity matrix and visualises the variance explained by each principal component in a line plot. This provides insight into how much variance each component captures, helping to determine how many components are needed for sufficient variance representation. Args: dissimilarityMetric (Callable[ [npt.ArrayLike, npt.ArrayLike], npt.number ], optional): A function to compute the dissimilarity between two responses. Defaults to inverse_correlation. xlabel (str, optional): Label for the x-axis. Defaults to 'Number of Principal Components'. ylabel (str, optional): Label for the y-axis. Defaults to 'Variance Explained'. title (str, optional): Title of the plot. Defaults to 'Variance Explained vs. Number of Principal Components'. figsize (Tuple[int], optional): Size of the figure. Defaults to (7, 7). grid (bool, optional): Whether to show grid lines on the plot. Defaults to False. """ # compute dissimilarity rdm = self.compute_rdm(dissimilarityMetric) # compute PCA variance explained: pcaVarianceExplained = pca_variance_explained(rdm) plt.figure(figsize=figsize) plt.plot( range(1, len(pcaVarianceExplained) + 1), pcaVarianceExplained, marker="o" ) plt.xlabel(xlabel) plt.ylabel(ylabel) plt.title(title) plt.grid(grid) plt.show()
[docs] def plot_2D_PCA( self, xlabel: str = "PC 1", ylabel: str = "PC 2", title: str = "2D PCA of Embedding", figsize: Tuple[int] = (7, 7), ) -> None: curves = np.array([r.response for r in self.responses]) pcs = pca(curves, nComponents=2) plt.figure(figsize=figsize) plt.scatter(pcs[:, 0], pcs[:, 1]) plt.xlabel(xlabel) plt.ylabel(ylabel) plt.title(title) plt.show()
[docs] def fisher_information(self) -> npt.NDArray: """Computes the fisher information of the set of responses Returns: npt.NDArray: A numpy array containing the fisher information of the stored responses. """ curves = np.array([r.response for r in self.responses]) return fi(curves)