AGB-Star-Deprojection

Python package for visualising the circumstellar envelopes of AGB stars.
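
Two of the spatial calculations this package relies on can be checked in isolation. The FITS loader converts angular offsets to lengths with the small-angle approximation (`ra_offsets*self.distance_to_star*np.pi/180`, which gives AU when the distance is in AU). The beam is represented by a matrix `B = R D D R^T`, so an offset `p` lies inside the beam when `p^T B p < 1`. The following is a minimal standalone sketch of both steps using plain NumPy, loosely mirroring `StarData.__process_beam` and `StarData.__pixels_in_beam`. The function names are hypothetical and not part of the package, and the toy beam is deliberately oversized so the result can be checked by hand.

```python
import numpy as np


def degrees_to_au(offset_deg, distance_au):
    """Small-angle conversion of an angular offset (degrees) to a length (AU)."""
    return np.deg2rad(offset_deg) * distance_au


def beam_matrix(bmaj_deg, bmin_deg, bpa_deg):
    """Return B such that an offset p (degrees) is inside the beam when p @ B @ p < 1."""
    a = bmaj_deg / 2  # semi-major axis
    b = bmin_deg / 2  # semi-minor axis
    theta = np.deg2rad(bpa_deg)
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    D = np.diag([1 / b, 1 / a])  # at bpa_deg = 0 the major axis lies along y
    return R @ D @ D @ R.T


def pixel_offsets_in_beam(B, cdelt1, cdelt2, semi_major):
    """List the integer pixel offsets (dx, dy) whose centres lie strictly inside the beam."""
    nx = int(abs(semi_major / cdelt1)) + 1  # search half-width along x, in pixels
    ny = int(abs(semi_major / cdelt2)) + 1  # search half-width along y, in pixels
    inside = []
    for dx in range(-nx, nx + 1):
        for dy in range(-ny, ny + 1):
            pos = np.array([cdelt1 * dx, cdelt2 * dy])  # offset in degrees
            if pos @ B @ pos < 1:
                inside.append((dx, dy))
    return inside


# Toy example: an oversized 2 x 1 degree beam at position angle 0 on a 0.25 degree grid.
B = beam_matrix(2.0, 1.0, 0.0)
offsets = pixel_offsets_in_beam(B, 0.25, 0.25, semi_major=1.0)
print(len(offsets))  # 21
# 1 arcsec seen from 206264.806 AU (about 1 pc) spans about 1 AU.
print(degrees_to_au(1 / 3600, 206264.806))
```

Here `B` reduces to diag(4, 1), so the test becomes dx**2/4 + dy**2/16 < 1. That test selects three columns (dx = -1, 0, 1) of seven offsets each, which gives the 21 offsets. The strict `<` leaves points exactly on the ellipse out, matching how `__pixels_in_beam` assigns elliptical distance 1 to the boundary list.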

Report

Code Documentation
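
The spectral-axis handling in the source below can be sketched on its own in the same way. `StarData` builds each FITS axis linearly from `NAXISn`, `CRPIXn`, `CRVALn` and `CDELTn`. It converts channel frequencies to velocities with `(1/freq_range-1/rest_freq)*rest_freq*StarData._c`, which is the optical convention v = c(f0/f - 1). It then estimates the systemic velocity as the intensity-weighted mean velocity and the expansion velocity as the square root of 5 times the weighted variance. The factor of 5 is what a parabolic line profile gives, consistent with the docstring's description of fitting a parabola. The sketch below applies these steps once to a synthetic spectrum. Unlike the package, it does not iterate or average over the beam at the image centre, and the helper names and the rest frequency are illustrative assumptions.

```python
import numpy as np

C_KMS = 299792.458  # speed of light in km/s, same value as StarData._c


def axis_values(naxis, crpix, crval, cdelt):
    """Linear FITS axis: the value at 1-based pixel p is CRVAL + CDELT * (p - CRPIX)."""
    pixels = np.arange(1, naxis + 1)
    return crval + cdelt * (pixels - crpix)


def optical_velocity(freq_hz, rest_freq_hz):
    """Optical-convention velocity v = c (f0 / f - 1) in km/s; positive means receding."""
    return C_KMS * (rest_freq_hz / freq_hz - 1)


def moment_velocities(velocities, intensities):
    """Estimate (v_sys, v_exp) from the first two moments of a line profile.

    For a parabolic profile I(v) proportional to 1 - ((v - v_sys) / v_exp)**2 on
    |v - v_sys| < v_exp, the normalised profile has variance v_exp**2 / 5,
    so v_exp = sqrt(5 * variance).
    """
    weights = np.asarray(intensities, dtype=float)
    weights = weights / weights.sum()
    v_sys = float(np.dot(weights, velocities))
    v_exp = float(np.sqrt(5 * np.dot(weights, (velocities - v_sys) ** 2)))
    return v_sys, v_exp


# Toy spectral axis: 801 channels of 0.1 MHz, with the rest frequency at pixel 401.
rest = 345.796e9  # Hz, placeholder rest frequency
freqs = axis_values(naxis=801, crpix=401, crval=rest, cdelt=-0.1e6)
vels = optical_velocity(freqs, rest)

# Synthetic parabolic line centred on v_sys = 5 km/s with v_exp = 15 km/s.
profile = np.clip(1 - ((vels - 5.0) / 15.0) ** 2, 0, None)
print(moment_velocities(vels, profile))  # close to (5.0, 15.0)
```

Because frequency decreases along this toy axis (negative CDELT), velocity increases with channel index. The loader in `StarData` flips `vel_range` when it comes out in descending order.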

"""
Python package for visualising the circumstellar envelopes of AGB stars.
"""


from astropy.io import fits
from astropy.io.fits import PrimaryHDU, Header
from astropy import units as u
import matplotlib.pyplot as plt
from matplotlib import colormaps
from matplotlib.patches import Ellipse
import numpy as np
from numpy import ndarray, dtype, float64
from scipy.interpolate import interpn
from scipy.optimize import curve_fit
from scipy.optimize.elementwise import find_root, bracket_root
from collections.abc import Callable
from typing import Literal
import plotly.graph_objects as go
import warnings
import scipy.integrate as integrate
from dataclasses import dataclass
from matplotlib.path import Path
from matplotlib.widgets import LassoSelector
from matplotlib.image import AxesImage
from matplotlib.backend_bases import MouseEvent
import matplotlib as mpl


# custom type aliases
DataArray1D = ndarray[tuple[int], dtype[float64]]
DataArray3D = ndarray[tuple[int, int, int], dtype[float64]]
Matrix2x2 = ndarray[tuple[Literal[2], Literal[2]], dtype[float64]]
VelocityModel = Callable[[ndarray], ndarray]

# custom errors
class FITSHeaderError(Exception):
    """
    Raised when the FITS file header is missing necessary information or stores it as the wrong type.
    """
    pass

@dataclass
class CondensedData:
    """
    Container for storing a condensed version of the data cube and its associated metadata.

    **Attributes**

    - `x_offsets` (`DataArray1D`): 1D array of x-coordinates (offsets) in AU.
    - `y_offsets` (`DataArray1D`): 1D array of y-coordinates (offsets) in AU.
    - `v_offsets` (`DataArray1D`): 1D array of velocity offsets in km/s.
    - `data` (`DataArray3D`): 3D data array (velocity, y, x) containing the intensity values.
    - `star_name` (`str`): Name of the star or object.
    - `distance_to_star` (`float`): Distance to the star in AU.
    - `v_exp` (`float`): Expansion velocity in km/s.
    - `v_sys` (`float`): Systemic velocity in km/s.
    - `beta` (`float`): Beta parameter of the velocity law.
    - `r_dust` (`float`): Dust formation radius in AU.
    - `beam_maj` (`float`): Major axis of the beam in degrees.
    - `beam_min` (`float`): Minor axis of the beam in degrees.
    - `beam_angle` (`float`): Beam position angle in degrees.
    - `header` (`Header`): FITS header containing metadata.
    - `mean` (`float` or `None`, optional): Mean intensity of the background noise (default is None, in which case StarData recomputes it on load).
    - `std` (`float` or `None`, optional): Standard deviation of the background noise (default is None, in which case StarData recomputes it on load).
    """
    x_offsets: DataArray1D
    y_offsets: DataArray1D
    v_offsets: DataArray1D
    data: DataArray3D
    star_name: str
    distance_to_star: float
    v_exp: float
    v_sys: float
    beta: float
    r_dust: float
    beam_maj: float
    beam_min: float
    beam_angle: float
    header: Header
    mean: float | None = None
    std: float | None = None


class StarData:

    """
    Class for manipulating and analyzing astronomical data cubes of radially expanding circumstellar envelopes.

    The StarData class provides a comprehensive interface for loading, processing, analyzing, and visualizing
    3D data cubes (typically from FITS files) representing the emission from expanding circumstellar shells.
    It supports loading both directly from FITS files and from preprocessed CondensedData objects, and manages
    all relevant metadata and derived quantities.

    Key Features
    ------------
    - **Data Loading:** Supports initialization from FITS files or CondensedData objects.
    - **Metadata Management:** Stores and exposes all relevant observational and physical parameters, including
      beam properties, systemic and expansion velocities, beta velocity law parameters, and FITS header information.
    - **Noise Estimation:** Automatically computes mean and standard deviation of background noise for filtering.
    - **Filtering:** Provides methods to filter data by significance (standard deviations) and to remove small clumps
      of points that fit within the beam (beam filtering).
    - **Coordinate Transformations:** Handles conversion between velocity space and spatial (Cartesian) coordinates,
      supporting both constant velocity models and general velocity laws.
    - **Time Evolution:** Can compute the expansion of the envelope over time, transforming the data cube accordingly.
    - **Visualization:** Includes a variety of plotting methods:
        - Channel maps (2D slices through velocity channels)
        - 3D volume rendering (with Plotly)
        - Diagnostic plots for velocity/intensity, radius/intensity and radius/velocity relationships
    - **Interactive Masking:** Supports interactive creation of masks for manual data cleaning.

    Attributes
    ----------

    - `data` (`DataArray3D`): The main data cube (v, y, x) containing intensity values.
    - `X` (`DataArray1D`): 1D array of x-coordinates (offsets) in AU.
    - `Y` (`DataArray1D`): 1D array of y-coordinates (offsets) in AU.
    - `V` (`DataArray1D`): 1D array of velocity offsets in km/s.
    - `distance_to_star` (`float`): Distance to the star in AU.
    - `beam_maj` (`float`): Major axis of the beam in degrees.
    - `beam_min` (`float`): Minor axis of the beam in degrees.
    - `beam_angle` (`float`): Beam position angle in degrees.
    - `mean` (`float`): Mean intensity of the background noise.
    - `std` (`float`): Standard deviation of the background noise.
    - `v_sys` (`float`): Systemic velocity in km/s.
    - `v_exp` (`float`): Expansion velocity in km/s.
    - `beta` (`float`): Beta parameter of the velocity law.
    - `r_dust` (`float`): Dust formation radius in AU.
    - `radius` (`float`): Characteristic radius (e.g., maximum intensity change) in AU.
    - `beta_velocity_law` (`VelocityModel`): Callable implementing the beta velocity law with the current object's parameters.
    - `star_name` (`str`): Name of the star or object.

    Methods
    -------
    - `export() -> CondensedData`: Export all defining attributes to a CondensedData object.
    - `get_filtered_data(stds=5)`: Return a copy of the data, with values below the specified number of standard deviations set to np.nan.
    - `beam_filter(filtered_data)`: Remove clumps of points that fit inside the beam, setting these values to np.nan.
    - `get_expansion(years, v_func, ...)`: Compute the expanded data cube after a given time interval.
    - `plot_channel_maps(...)`: Plot the data cube as a set of 2D channel maps.
    - `plot_3D(...)`: Plot a 3D volume rendering of the data cube using Plotly.
    - `plot_velocity_vs_intensity(...)`: Plot velocity vs. intensity at the center of the xy plane.
    - `plot_radius_vs_intensity()`: Plot radius vs. intensity at the center of the xy plane.
    - `plot_radius_vs_velocity(...)`: Plot radius vs. velocity at the center of the xy plane.
    - `create_mask(...)`: Launch an interactive mask creator for the data cube.
    """

    _c = 299792.458  # speed of light, km/s
    v0 = 3  # km/s, speed of sound

    def __init__(
        self,
        info_source: str | CondensedData,
        distance_to_star: float | None = None,
        rest_frequency: float | None = None,
        maskfile: str | None = None,
        beta_law_params: tuple[float, float] | None = None,
        v_exp: float | None = None,
        v_sys: float | None = None,
        absolute_star_pos: tuple[float, float] | None = None
    ) -> None:
        """
        Initialize a StarData object by reading data from a FITS file or a CondensedData object.

        **Parameters**

        - `info_source` (`str` or `CondensedData`): Path to a FITS file or a CondensedData object containing preprocessed data.
        - `distance_to_star` (`float` or `None`, optional): Distance to the star in AU (required if info_source is a FITS file).
        - `rest_frequency` (`float` or `None`, optional): Rest frequency in Hz (required if info_source is a FITS file).
        - `maskfile` (`str` or `None`, optional): Path to a .npy file containing a mask to apply to the data.
        - `beta_law_params` (`tuple` of `float` or `None`, optional): (r_dust (AU), beta) parameters for the beta velocity law. If None, they are fit from the data (FITS input) or taken from the CondensedData object.
        - `v_exp` (`float` or `None`, optional): Expansion velocity in km/s. If None, it is fit from the data (FITS input) or taken from the CondensedData object.
        - `v_sys` (`float` or `None`, optional): Systemic velocity in km/s. If None, it is fit from the data (FITS input) or taken from the CondensedData object.
        - `absolute_star_pos` (`tuple` of `float` or `None`, optional): Absolute (RA, Dec) position of the star in degrees. If None, taken to be the centre of the image.

        **Raises**

        - `ValueError`: If required parameters are missing when reading from a FITS file.
        - `FITSHeaderError`: If any attribute in the FITS file header has an incorrect type.
        """
        if isinstance(info_source, str):
            if distance_to_star is None or rest_frequency is None:
                raise ValueError("Distance to star and rest frequency required when reading from FITS file.")
            self.__load_from_fits_file(info_source, distance_to_star, rest_frequency, absolute_star_pos, v_sys = v_sys, v_exp = v_exp)
            if beta_law_params is None:
                self._r_dust, self._beta, self._radius = self.__get_beta_law()
            else:
                self._r_dust, self._beta = beta_law_params

        else:
            # load from CondensedData
            self._X = info_source.x_offsets
            self._Y = info_source.y_offsets
            self._V = info_source.v_offsets
            self._data = info_source.data
            self.star_name = info_source.star_name
            self._distance_to_star = info_source.distance_to_star
            self._v_exp = info_source.v_exp if v_exp is None else v_exp
            self._v_sys = info_source.v_sys if v_sys is None else v_sys
            self._r_dust = info_source.r_dust if beta_law_params is None else beta_law_params[0]
            self._beta = info_source.beta if beta_law_params is None else beta_law_params[1]
            self._beam_maj = info_source.beam_maj
            self._beam_min = info_source.beam_min
            self._beam_angle = info_source.beam_angle
            self._header = info_source.header
            self.__process_beam()

            # compute mean and standard deviation
            if info_source.mean is None or info_source.std is None:
                self._mean, self._std = self.__mean_and_std()
            else:
                self._mean = info_source.mean
                self._std = info_source.std


        if maskfile is not None:
            # mask data (permanent)
            mask = np.load(maskfile)
            self._data = self._data * mask

    # ---- READ ONLY ATTRIBUTES ----

    @property
    def data(self) -> DataArray3D:
        """
        DataArray3D

        Stores the intensity of light at each data point.

        Dimensions: k x m x n, where k is the number of frequency channels,
        m is the number of declination pixels, and n is the number of right ascension pixels.
        """
        return self._data

    @property
    def X(self) -> DataArray1D:
        """
        DataArray1D

        Stores the x-coordinates relative to the centre in AU.
        Obtained from right ascension coordinates.
        """
        return self._X

    @property
    def Y(self) -> DataArray1D:
        """
        DataArray1D

        Stores the y-coordinates relative to the centre in AU.
        Obtained from declination coordinates.
        """
        return self._Y

    @property
    def V(self) -> DataArray1D:
        """
        DataArray1D

        Stores the velocity offsets relative to the systemic velocity of the star in km/s.
        Obtained from frequency channels.
        """
        return self._V

    @property
    def distance_to_star(self) -> float:
        """
        Distance to star in AU.
        """
        return self._distance_to_star

    @property
    def B(self) -> Matrix2x2:
        """
        Matrix2x2

        Ellipse matrix of beam. For 1x2 vectors v, w with coordinates (ra, dec) in degrees,
        if (v-w)^T B (v-w) < 1, then v is within the beam centred at w.
        """
        return self._B

    @property
    def beam_maj(self) -> float:
        """
        Major axis of the beam in degrees.
        """
        return self._beam_maj

    @property
    def beam_min(self) -> float:
        """
        Minor axis of the beam in degrees.
        """
        return self._beam_min

    @property
    def beam_angle(self) -> float:
        """
        Beam position angle in degrees.
        """
        return self._beam_angle

    @property
    def mean(self) -> float:
        """
        The mean intensity of the background noise, taken over coordinates away from the centre.
        """
        return self._mean

    @property
    def std(self) -> float:
        """
        The standard deviation of the background noise intensity, taken over coordinates away from the centre.
        """
        return self._std

    @property
    def v_sys(self) -> float:
        """
        The systemic velocity of the star in km/s.
        """
        return self._v_sys

    @property
    def v_exp(self) -> float:
        """
        The maximum radial expansion speed in km/s.
        """
        return self._v_exp

    @property
    def beta(self) -> float:
        """
        Beta parameter of the velocity law.
        """
        return self._beta

    @property
    def r_dust(self) -> float:
        """
        Dust formation radius in AU.
        """
        return self._r_dust

    @property
    def radius(self) -> float:
        """
        Characteristic radius (e.g., maximum intensity change) in AU.
        """
        return self._radius

    @property
    def beta_velocity_law(self) -> VelocityModel:
        """
        VelocityModel

        Returns a callable implementing the beta velocity law with the current object's parameters.
        """
        def law(r):
            return self.__general_beta_velocity_law(r, self.r_dust, self.beta)
        return law

    # ---- EXPORT ----

    def export(self) -> CondensedData:
        """
        Export all defining attributes to a CondensedData object.
        """
        return CondensedData(
            self.X,
            self.Y,
            self.V,
            self.data,
            self.star_name,
            self.distance_to_star,
            self.v_exp,
            self.v_sys,
            self.beta,
            self.r_dust,
            self.beam_maj,
            self.beam_min,
            self.beam_angle,
            self._header,
            self.mean,
            self.std
        )

    # ---- HELPER METHODS FOR INITIALISATION ----

    @staticmethod
    def __header_check(header: Header) -> bool:
        """
        Check that the FITS header contains all required values with appropriate types.

        Parameters
        ----------
        header : Header
            FITS header object to check.

        Returns
        -------
        missing_beam : bool
            True if beam parameters are missing from the header, False otherwise.

        Raises
        ------
        FITSHeaderError
            If any required attribute is missing or has an incorrect type. The beam
            attributes (BMAJ, BMIN, BPA) are exempt when any of them is missing.
        """
        missing_beam = False
        types_to_check = {
            "BSCALE": float,
            "BZERO": float,
            "OBJECT": str,
            "BMAJ": float,
            "BMIN": float,
            "BPA": float,
            "BTYPE": str,
            "BUNIT": str
        }
        for num in range(1, 4):
            types_to_check["CTYPE" + str(num)] = str
            types_to_check["NAXIS" + str(num)] = int
            types_to_check["CRPIX" + str(num)] = float
            types_to_check["CRVAL" + str(num)] = float
            types_to_check["CDELT" + str(num)] = float

        # check if beam is present in data
        if "BMAJ" not in header or "BMIN" not in header or "BPA" not in header:
            missing_beam = True

        for attr in types_to_check:
            if attr in ["BMAJ", "BMIN", "BPA"] and missing_beam:
                continue
            if attr not in header:
                raise FITSHeaderError(f"Header attribute {attr} is missing")
            attr_type = types_to_check[attr]
            actual_type = type(header[attr])
            if actual_type is not attr_type:
                raise FITSHeaderError(f"Header attribute {attr} should have type {attr_type}, instead is {actual_type}")
        return missing_beam

    def __load_from_fits_file(
        self,
        filename: str,
        distance_to_star: float,
        rest_freq: float,
        absolute_star_pos: tuple[float, float] | None = None,
        v_sys: float | None = None,
        v_exp: float | None = None
    ) -> None:
        """
        Load data and metadata from a FITS file and initialize StarData attributes.

        Parameters
        ----------
        filename : str
            Path to the FITS file.
        distance_to_star : float
            Distance to the star in AU.
        rest_freq : float
            Rest frequency in Hz.
        absolute_star_pos : tuple of float or None, optional
            Absolute (RA, Dec) position of the star in degrees. If None, use image center.
        v_sys : float or None, optional
            Systemic velocity in km/s. If None, will be fit from data.
        v_exp : float or None, optional
            Expansion velocity in km/s. If None, will be fit from data.

        Returns
        -------
        None

        Raises
        ------
        FITSHeaderError
            If the FITS header is missing required attributes or has incorrect types.
        AssertionError
            If the FITS file does not contain data.
        """
        # read data from file
        with fits.open(filename) as hdul:
            hdu: PrimaryHDU = hdul[0] # type: ignore

            missing_beam = self.__header_check(hdu.header)  # check that all the information is available before proceeding
            self._header: Header = hdu.header
            if missing_beam:  # beam parameters are stored in a table in hdul[1] instead
                beam_data = list(hdul[1].data)

                str_to_conversion = {
                    "arcsec": 1/3600,
                    "deg": 1,
                    "degrees": 1,
                    "degree": 1
                }
                unit_maj = hdul[1].header["TUNIT1"]
                unit_min = hdul[1].header["TUNIT2"]

                # average the beam over all table rows, converting to degrees (1 arcsec = 1/3600 degree)
                self._beam_maj = np.mean(np.array([beam[0] for beam in beam_data]))*str_to_conversion[unit_maj]
                self._beam_min = np.mean(np.array([beam[1] for beam in beam_data]))*str_to_conversion[unit_min]
                self._beam_angle = np.mean(np.array([beam[2] for beam in beam_data]))
            else:
                self._beam_maj = self._header["BMAJ"]
                self._beam_min = self._header["BMIN"]
                self._beam_angle = self._header["BPA"]

            brightness_scale: float = self._header["BSCALE"]  # type: ignore
            brightness_zero: float = self._header["BZERO"]  # type: ignore

            # scale data to be in specified brightness units
            assert hdu.data is not None
            self._data: DataArray3D = np.array(hdu.data[0], dtype = float64)*brightness_scale+brightness_zero  #freq, dec, ra

        self.star_name: str = self._header["OBJECT"]  # type: ignore
        self._distance_to_star = distance_to_star


        # get velocities from frequencies (optical convention: v = c*(f0/f - 1))
        freq_range: DataArray1D = self.__get_header_array(3)
        vel_range: DataArray1D = (1/freq_range-1/rest_freq)*rest_freq*StarData._c  # velocity in km/s
        if vel_range[-1] < vel_range[0]:  # array is backwards
            vel_range = np.flip(vel_range)


        # get X and Y coordinates
        ra_vals = self.__get_header_array(1)
        dec_vals = self.__get_header_array(2)   # reverse !!
        if absolute_star_pos is None:
            ra_offsets = ra_vals - np.mean(ra_vals)
            dec_offsets = dec_vals - np.mean(dec_vals)
        else:
            ra_offsets = ra_vals - absolute_star_pos[0]
            dec_offsets = dec_vals - absolute_star_pos[1]

        self._X = ra_offsets*self.distance_to_star*np.pi/180  # measured in AU
        self._Y = dec_offsets*self.distance_to_star*np.pi/180

        self.__process_beam()

        # get mean and standard deviation of intensity values of noise
        self._mean, self._std = self.__mean_and_std()

        # get velocity offsets
        self._v_sys, self._v_exp = self.__get_star_and_exp_velocity(vel_range, v_sys = v_sys, v_exp = v_exp)
        self._V = vel_range - self.v_sys

    def __process_beam(self) -> None:
        """
        Compute the beam ellipse matrix and pixel offsets for the beam and its boundary.

        Returns
        -------
        None

        Notes
        -----
        Sets the attributes `_B`, `_offset_in_beam`, and `_boundary_offset` for use in beam-related calculations.
        """
        # get matrix for elliptical distance corresponding to beam
        # (semi-major and semi-minor axes of the beam, in degrees)
        major: float = float(self.beam_maj)/2  # type: ignore
        minor: float = self.beam_min/2  # type: ignore

        # degrees -> radians
        theta: float = self.beam_angle*np.pi/180  # type: ignore
        R = np.array([
                [np.cos(theta), -np.sin(theta)],
                [np.sin(theta), np.cos(theta)]
        ])
        D = np.diag([1/minor, 1/major])
        self._B = R@D@D@R.T  # beam matrix: (v-w)^T B (v-w) < 1 means v is within beam centred at w
        self._offset_in_beam, self._boundary_offset = self.__pixels_in_beam(major)

    def __pixels_in_beam(self, major: float) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]:
        """
        Determine which pixel offsets are inside or on the boundary of the beam ellipse.

        Parameters
        ----------
        major : float
            Semi-major axis of the beam ellipse, in degrees.

        Returns
        -------
        pixels_in_beam : list of tuple of int
            List of (x, y) offsets inside the beam ellipse, relative to the center.
        pixels_on_bdry : list of tuple of int
            List of (x, y) offsets on the boundary of the beam ellipse (elliptical distance between 1 and 1.2), relative to the center.
        """
        delta_x: float = self._header["CDELT1"]  # type: ignore
        delta_y: float = self._header["CDELT2"]  # type: ignore

        bound_x: int = int(np.abs(major/delta_x) + 1)  # half-width of the search square in x, in pixels
        bound_y: int = int(np.abs(major/delta_y) + 1)  # half-width of the search square in y, in pixels

        pixels_on_bdry = []
        pixels_in_beam = []

        for x_offset in range(-bound_x, bound_x + 1):
            for y_offset in range(-bound_y, bound_y + 1):

                # get position and elliptic distance from origin
                pos = np.array([delta_x*x_offset, delta_y*y_offset])
                dist = np.sqrt(pos.T @ self.B @ pos)

                # determine if on boundary or inside ellipse
                if 1 <= dist <= 1.2:
                    pixels_on_bdry.append((x_offset, y_offset))
                elif dist < 1:
                    pixels_in_beam.append((x_offset, y_offset))

        return pixels_in_beam, pixels_on_bdry

    def __get_header_array(self, num: Literal[1, 2, 3]) -> DataArray1D:
        """
        Get coordinate values from the FITS header for RA, DEC, or FREQ axes.

        Parameters
        ----------
        num : {1, 2, 3}
            Axis number: 1 for RA, 2 for DEC, 3 for FREQ.

        Returns
        -------
        vals : DataArray1D
            1-D array of coordinate values for the specified axis, computed from the header.
        """
        vals_length: int = self._header["NAXIS" + str(num)]  # type: ignore
        x0: float = self._header["CRPIX" + str(num)]  # type: ignore
        y0: float = self._header["CRVAL" + str(num)]  # type: ignore
        delta: float = self._header["CDELT" + str(num)]  # type: ignore

        # FITS pixel p (1-indexed) has value delta*(p - x0) + y0.
        # NumPy arrays are 0-indexed, so array index i corresponds to pixel p = i + 1.
        pixels = np.arange(vals_length) + 1
        vals: DataArray1D = delta*(pixels - x0) + y0

        return vals

    def __mean_and_std(self) -> tuple[float, float]:
        """
        Calculate the mean and standard deviation of the background noise from points away from the centre.

        After discarding the outermost 1/20th of the image on each side, the noise is sampled from a border
        of width 1/5th of the remaining image, in the first and last fifths of the velocity channels.

        Returns
        -------
        mean : float
            Mean intensity of the background noise.
        std : float
            Standard deviation of the background noise.
        """
        # trim, as edges can be unreliable
        outer_trim = 20  # remove outer 1/20th

        frames, y_obs, x_obs = self.data.shape
        trimmed_data = self.data[:, y_obs//outer_trim:y_obs - y_obs//outer_trim, x_obs//outer_trim:x_obs - x_obs//outer_trim]
        frames, y_obs, x_obs = trimmed_data.shape

        # take edges of trimmed data
        inner_trim = 5  # take outer 1/5th
        left_close = trimmed_data[:frames//inner_trim, :y_obs//inner_trim, :].flatten()
        right_close = trimmed_data[:frames//inner_trim, y_obs - y_obs//inner_trim:, :].flatten()
        top_close = trimmed_data[:frames//inner_trim, y_obs//inner_trim:y_obs - y_obs//inner_trim, :x_obs//inner_trim].flatten()
        bottom_close = trimmed_data[:frames//inner_trim, y_obs//inner_trim:y_obs - y_obs//inner_trim, x_obs-x_obs//inner_trim:].flatten()

        left_far = trimmed_data[frames - frames//inner_trim:, :y_obs//inner_trim, :].flatten()
        right_far = trimmed_data[frames - frames//inner_trim:, y_obs - y_obs//inner_trim:, :].flatten()
        top_far = trimmed_data[frames - frames//inner_trim:, y_obs//inner_trim:y_obs - y_obs//inner_trim, :x_obs//inner_trim].flatten()
        bottom_far = trimmed_data[frames - frames//inner_trim:, y_obs//inner_trim:y_obs - y_obs//inner_trim, x_obs-x_obs//inner_trim:].flatten()

        ring = np.concatenate((left_close, right_close, top_close, bottom_close, left_far, right_far, top_far, bottom_far))
        ring = ring[~np.isnan(ring)]

        if len(ring) < 20:
            warnings.warn(f"Only {len(ring)} data points selected for finding mean and standard deviation of noise.")

        mean: float = float(np.mean(ring))
        std: float = float(np.std(ring))

        return mean, std

    def __multiply_beam(self, times: float | int) -> list[tuple[int, int]]:
        """
        Generate a list of pixel offsets that represent the beam, scaled by a factor.

        Parameters
        ----------
        times : float or int
            Scaling factor for the beam size.

        Returns
        -------
        insides : list of tuple of int
            List of (x, y) offsets inside the scaled beam.
        """
        if times <= 0.0001:
            return []
        insides = []
        beam_bound = max(max([x for x, y in self._offset_in_beam]), max([y for x, y in self._offset_in_beam]))

        # search a square around the origin
        for x in range(-int(beam_bound*times), int(beam_bound*times) + 1):
            for y in range(-int(beam_bound*times), int(beam_bound*times) + 1):

                # check if point would shrink into beam
                pos = (int(x/times), int(y/times))
                if pos in self._offset_in_beam:
                    insides.append((x, y))

        return insides

    def __get_centre_intensities(self, beam_widths: float | int = 1) -> DataArray1D:
        """
        Calculate the mean intensity at the center of each velocity channel, averaged over the beam scaled by `beam_widths`.

        Parameters
        ----------
        beam_widths : float or int, optional
            Scale factor of beam (default is 1).

        Returns
        -------
        all_densities : DataArray1D
            Array of mean intensities for each velocity channel.
        """
        # centre index
        y_idx = np.argmin(self.Y**2)
        x_idx = np.argmin(self.X**2)
        beam_pixels = self.__multiply_beam(beam_widths)
        density_list = []

        # compute average density, counting only positive, non-NaN pixels that lie inside the image
        for v in range(len(self.data)):
            total = 0
            for x_offset, y_offset in beam_pixels:
                if 0 <= y_idx + y_offset < self.data.shape[1] and 0 <= x_idx + x_offset < self.data.shape[2]:
                    intensity = self.data[v][y_idx + y_offset][x_idx + x_offset]
                    if intensity > 0 and not np.isnan(intensity):
                        total += intensity
            density_list.append(total/len(beam_pixels))

        all_densities = np.array(density_list)
        return all_densities

 746    def __get_star_and_exp_velocity(self, vel_range: DataArray1D, plot: bool = False, fit_parabola: bool = False, v_sys: float | None = None, v_exp: float | None = None) -> tuple[float, float]:
 747        """
 748        Computes the systemic and expansion velocities, in km/s, by iteratively matching the moments of the centre intensity profile to a parabolic line profile.
 749
 750        Parameters
 751        ----------
 752        vel_range : DataArray1D
 753            1-D array of channel velocities in km/s.
 754        plot : bool, optional
 755            If True, plot the iterative process (default is False).
 756        fit_parabola : bool, optional
 757            If True (and `plot` is True), overlay the parabola implied by the computed velocities (default is False).
 758        v_sys : float or None, optional
 759            If provided, use this as the systemic velocity (default is None).
 760        v_exp : float or None, optional
 761            If provided, use this as the expansion velocity (default is None).
 762
 763        Returns
 764        -------
 765        v_sys : float
 766            Systemic velocity in km/s.
 767        v_exp : float
 768            Expansion velocity in km/s.
 769        """
 770        given_v_sys = v_sys
 771        given_v_exp = v_exp
 772        if given_v_exp is not None and given_v_sys is not None:
 773            return given_v_sys, given_v_exp
 774
 775        all_densities = self.__get_centre_intensities(2)
 776        
 777        
 778        v_exp_seen = np.array([])
 779        v_sys_seen = np.array([])
 780
 781        converged = False
 782        densities = all_densities.copy()
 783        i = 1
 784        while not converged:  # loop until computation converges
 785            densities /= np.sum(densities) # normalise
 786            if plot:
 787                plt.plot(vel_range, densities, label = f"iteration {i}")
 788            v_sys = np.dot(densities, vel_range) if given_v_sys is None else given_v_sys
 789
 790            v_exp = np.sqrt(5*np.dot(((vel_range - v_sys)**2), densities)) if given_v_exp is None else given_v_exp
 791            
 792            if any(np.isclose(v_exp_seen, v_exp)) and any(np.isclose(v_sys_seen, v_sys)):
 793                converged = True
 794            v_exp_seen = np.append(v_exp_seen, v_exp)
 795            v_sys_seen = np.append(v_sys_seen, v_sys)
 796
 797            densities = all_densities.copy()
 798            densities[vel_range < (v_sys - v_exp)] = 0
 799            densities[vel_range > (v_sys + v_exp)] = 0
 800            i += 1
 801
 802            if i >= 100:
 803                warnings.warn("Systemic and expansion velocity computation did not converge after 100 iterations.")
 804                break
 805
 806        if plot and fit_parabola:
 807            parabola = (1 -((vel_range - v_sys)**2/v_exp**2))
 808            parabola[parabola < 0] = 0
 809            parabola /= np.sum(parabola)
 810            plt.plot(vel_range, parabola, label = "parabola")
 811
 812        if plot:
 813            plt.title("Determining v_sys, v_exp")
 814            plt.xlabel("Relative velocity (km/s)")
 815            plt.ylabel(f"{self._header['BTYPE']} at centre point ({self._header['BUNIT']})")
 816            plt.legend()
 817            plt.show()
 818        
 819        return v_sys, v_exp
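The factor of 5 in the `v_exp` update comes from the assumed parabolic line profile, I(v) proportional to 1 - ((v - v_sys)/v_exp)**2. The mean of this profile is v_sys and its variance is v_exp**2/5, so v_exp = sqrt(5 * variance). Below is a minimal single-pass sketch on a synthetic, noise-free parabola. The method above goes further: it zeroes the channels outside v_sys +/- v_exp and repeats until the values recur. The function name and the example velocities are hypothetical.

```python
import numpy as np


def moment_velocities(vel, profile):
    """Estimate (v_sys, v_exp) from a line profile assumed to be parabolic."""
    w = profile / profile.sum()                         # normalise to unit sum
    v_sys = np.dot(w, vel)                              # first moment
    v_exp = np.sqrt(5 * np.dot(w, (vel - v_sys)**2))    # variance = v_exp**2 / 5
    return v_sys, v_exp


vel = np.linspace(-40, 40, 2001)                        # channel velocities, km/s
true_sys, true_exp = 3.0, 14.5                          # example values
profile = np.clip(1 - ((vel - true_sys) / true_exp)**2, 0, None)
v_sys, v_exp = moment_velocities(vel, profile)          # close to (3.0, 14.5)
```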
 820
 821    def __get_beta_law(self, plot_intensities = False, plot_velocities = False, plot_beta_law = False) -> tuple[float, float, float]:
 822        """
 823        Fit the beta velocity law to the data and return the dust formation radius, beta parameter, and the radius of maximum intensity change.
 824
 825        Parameters
 826        ----------
 827        plot_intensities : bool, optional
 828            If True, plot intensity vs. radius (default is False).
 829        plot_velocities : bool, optional
 830            If True, plot velocity vs. radius (default is False).
 831        plot_beta_law : bool, optional
 832            If True, plot the fitted beta law (default is False).
 833
 834        Returns
 835        -------
 836        r_dust : float
 837            Dust formation radius.
 838        beta : float
 839            Beta parameter of the velocity law.
 840        radius : float
 841            Radius of maximum intensity change.
 842        """
 843        intensities = self.__get_centre_intensities(0.5)
 844
 845        def v_from_i(i):
 846            return np.sqrt(1 - i/np.max(intensities))*self.v_exp
 847    
 848        centre_idx = np.argmin(self.V**2)
 849        frame = self.data[centre_idx]
 850        
 851        max_radius = np.minimum(np.max(self.X), np.max(self.Y))
 852        precision = min(len(self.X), len(self.Y))//2
 853        radii = np.linspace(0, max_radius, precision)
 854        
 855        X, Y = np.meshgrid(self.X, self.Y)  # default "xy" indexing: shape (len(self.Y), len(self.X)), matching frame
 856
 857        deltas = np.array([])
 858        I = np.array([])  # average intensity in each ring
 859        for i in range(len(radii) - 1):
 860            ring = frame[(radii[i]**2 <= X**2 + Y**2)  &  (X**2 + Y**2 <= radii[i+1]**2)]
 861            avg_intensity = np.mean(ring[np.isfinite(ring)]) 
 862            I = np.append(I, avg_intensity)
 863
 864            inner = frame[X**2+Y**2 <= radii[i+1]**2]
 865            outer = frame[~(X**2+Y**2 <= radii[i+1]**2)]
 866            deltas = np.append(deltas,len(inner[(inner >= self.mean+5*self.std)])+len(outer[outer < self.mean+5*self.std]))
 867
 868        V = v_from_i(I)
 869        V[~np.isfinite(V)] = 0
 870        R = radii[1:]
 871        radius_index = np.argmax(deltas)
 872        radius = R[radius_index]
 873
 874        if plot_intensities:
 875            plt.plot(R, I, label = "intensities")
 876            plt.axvline(x = radius, label = "radius", color = "gray", linestyle = "dashed")
 877            plt.legend()
 878            plt.xlabel("Radius (AU)")
 879            plt.ylabel(f"Average {self._header['BTYPE']} ({self._header['BUNIT']})")
 880            plt.title("Average intensity at each radius")
 881            plt.show()
 882            return self.r_dust, self.beta, self.radius
 883
 884        v_fit = V[(V > 0)]
 885        r_fit = R[(V > 0)]
 886
 887        if plot_velocities:
 888            r_dust, beta = self.r_dust, self.beta
 889        else:
 890            params = curve_fit(self.__general_beta_velocity_law, r_fit, v_fit)[0]
 891            r_dust, beta =  params[0], params[1]
 892
 893        if plot_velocities:
 894            plt.plot(R, V, label = "velocities")
 895            if plot_beta_law:
 896                LAW = self.__general_beta_velocity_law(R, r_dust, beta)
 897                plt.plot(R, LAW, label = "beta law")
 898            plt.axvline(x = radius, label = "radius", color = "gray", linestyle = "dashed")
 899            plt.legend()
 900            plt.xlabel("Radius (AU)")
 901            plt.ylabel("Velocity (km/s)")
 902            plt.title("Velocity at each radius")
 903            plt.show()
 904        
 905        return r_dust, beta, radius
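The velocity profile in `__get_beta_law` comes from two steps. First, the central channel is averaged over concentric rings. Second, each ring's mean intensity is converted to a velocity by inverting the parabolic profile I/I_max = 1 - (v/v_exp)**2, which gives v = v_exp*sqrt(1 - I/I_max). The sketch below covers both steps. It assumes a 2-D frame indexed (y, x) with offset vectors `x` and `y`, and its function names are hypothetical.

```python
import numpy as np


def ring_profile(frame, x, y, n_rings):
    """Mean of the finite values of `frame` (indexed (y, x)) in each annulus."""
    X, Y = np.meshgrid(x, y)                  # default "xy": shape (len(y), len(x))
    R = np.hypot(X, Y)
    edges = np.linspace(0, min(x.max(), y.max()), n_rings + 1)
    means = np.full(n_rings, np.nan)
    for i in range(n_rings):
        ring = frame[(edges[i] <= R) & (R <= edges[i + 1])]
        ring = ring[np.isfinite(ring)]
        if ring.size:
            means[i] = ring.mean()
    return edges[1:], means                   # outer edge of each ring, mean value


def velocity_from_intensity(intensity, v_exp):
    """Invert I / I_max = 1 - (v / v_exp)**2; non-finite results become 0."""
    with np.errstate(invalid="ignore"):
        v = v_exp * np.sqrt(1 - intensity / np.nanmax(intensity))
    return np.where(np.isfinite(v), v, 0.0)


x = y = np.linspace(-10, 10, 41)
r_outer, means = ring_profile(np.ones((41, 41)), x, y, 5)
```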
 906
 907    def __general_beta_velocity_law(self, r: ndarray, r_dust: float, beta: float) -> ndarray:
 908        """
 909        General beta velocity law.
 910
 911        Parameters
 912        ----------
 913        r : array_like
 914            Radius values.
 915        r_dust : float
 916            Dust formation radius.
 917        beta : float
 918            Beta parameter.
 919
 920        Returns
 921        -------
 922        v : array_like
 923            Velocity at each radius.
 924        """
 925        return self.v0+(self.v_exp-self.v0)*((1-r_dust/r)**beta)
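`__general_beta_velocity_law` is the model that `curve_fit` adjusts. The sketch below fits `r_dust` and `beta` to noise-free synthetic velocities. The values of `V0` and `V_EXP`, the initial guess and the bounds are all assumptions; the method above calls `curve_fit` with its default starting point and no bounds. Bounding `r_dust` below the smallest radius keeps `1 - r_dust/r` positive, so the fractional power stays real.

```python
import numpy as np
from scipy.optimize import curve_fit

V0, V_EXP = 3.0, 15.0   # assumed launch and terminal speeds, km/s


def beta_law(r, r_dust, beta):
    """v(r) = v0 + (v_exp - v0) * (1 - r_dust / r)**beta"""
    return V0 + (V_EXP - V0) * (1 - r_dust / r) ** beta


r = np.linspace(10, 200, 60)        # radii in AU, all larger than r_dust
v_obs = beta_law(r, 5.0, 1.5)       # noise-free synthetic velocities

# r_dust <= 9 < min(r) keeps (1 - r_dust / r) positive throughout the fit
(r_dust_fit, beta_fit), _ = curve_fit(
    beta_law, r, v_obs, p0=(4.0, 1.0), bounds=([0.0, 0.1], [9.0, 5.0])
)
```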
 926    
 927    # ---- FILTERING DATA ----
 928
 929    def get_filtered_data(self, stds: float | int = 5) -> DataArray3D:
 930        """
 931        Return a copy of the data with every value below `stds` times the standard deviation (`stds*self.std`) set to np.nan.
 932
 933        **Parameters**
 934
 935        - `stds` (`float` or `int`, optional): Number of standard deviations to filter by (default is 5).
 936
 937        **Returns**
 938
 939        - `filtered_data` (`DataArray3D`): Filtered data array.
 940        """
 941        filtered_data = self.data.copy()  # creates a deep copy
 942        filtered_data[filtered_data < stds*self.std] = np.nan
 943        return filtered_data
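`get_filtered_data` measures its threshold from zero (`stds*self.std`), not from the mean; the two agree when the noise is centred on zero. The sketch below shows the masking step on its own, with the standard deviation passed in explicitly. `sigma_mask` is a hypothetical helper, not part of the package.

```python
import numpy as np


def sigma_mask(data, std, stds=5):
    """Copy of `data` with every value below stds * std replaced by NaN."""
    out = np.array(data, dtype=float, copy=True)
    out[out < stds * std] = np.nan          # NaNs compare False, so they stay NaN
    return out


masked = sigma_mask(np.array([0.1, 0.6, 2.0, np.nan]), std=0.1)  # threshold 0.5
```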
 944    
 945    def beam_filter(self, filtered_data: DataArray3D) -> DataArray3D:
 946        """
 947        Remove clumps of points that fit inside the beam, setting these values to np.nan.
 948
 949        **Parameters**
 950
 951        - `filtered_data` (`DataArray3D`): 3-D array with the same dimensions as the data array.
 952
 953        **Returns**
 954
 955        - `beam_filtered_data` (`DataArray3D`): 3-D array with small clumps of points removed.
 956        """
 957        beam_filtered_data = filtered_data.copy()
 958        for frame in range(len(filtered_data)):
 959            for y_idx in range(len(filtered_data[frame])):
 960                for x_idx in range(len(filtered_data[frame][y_idx])):
 961                    if np.isnan(filtered_data[frame][y_idx][x_idx]):  # ignore empty points
 962                        continue
 963                    
 964                    # filled point that we are searching around
 965                
 966                    erase = True
 967
 968                    for x_offset, y_offset in self._boundary_offset:
 969                        x_check = x_idx + x_offset
 970                        y_check = y_idx + y_offset
 971                        # skip points outside the frame (negative indices would silently wrap to the far edge)
 972                        if 0 <= y_check < filtered_data.shape[1] and 0 <= x_check < filtered_data.shape[2]:
 973                            if not np.isnan(filtered_data[frame][y_check][x_check]):
 974                                erase = False  # there is something present on the border - saved!
 975                                break
 976
 977
 978                    if erase:  # consider ellipse to be an anomaly
 979                        # erase the entire inside of the beam ellipse centred at this point
 980                        for x_offset, y_offset in self._offset_in_beam:
 981                            x_check = x_idx + x_offset
 982                            y_check = y_idx + y_offset
 983                            # only erase points inside the frame (negative indices would wrap)
 984                            if 0 <= y_check < beam_filtered_data.shape[1] and 0 <= x_check < beam_filtered_data.shape[2]:
 985                                beam_filtered_data[frame][y_check][x_check] = np.nan  # erase
 986
 987            
 988        return beam_filtered_data
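`beam_filter` keeps a filled pixel only if something is present on a ring just outside the beam around it. Otherwise the whole beam-sized patch is treated as noise and erased. The 2-D sketch below uses a circular beam of radius `r_px` pixels, which is an assumption; the package uses the elliptical offsets in `_boundary_offset` and `_offset_in_beam`. It checks bounds explicitly because negative NumPy indices wrap around instead of raising `IndexError`. The function names are hypothetical.

```python
import numpy as np


def disc_offsets(r_max, r_min=None):
    """Integer (dx, dy) offsets with r_min < sqrt(dx**2 + dy**2) <= r_max."""
    n = int(np.ceil(r_max))
    dy, dx = np.mgrid[-n:n + 1, -n:n + 1]
    d2 = dx**2 + dy**2
    keep = d2 <= r_max**2
    if r_min is not None:
        keep &= d2 > r_min**2
    return list(zip(dx[keep], dy[keep]))


def remove_small_clumps(frame, r_px):
    """NaN out clumps in a 2-D frame (indexed (y, x)) whose surrounding ring is empty."""
    inside = disc_offsets(r_px)                   # the "beam", including (0, 0)
    border = disc_offsets(r_px + 1, r_min=r_px)   # one-pixel ring just outside it
    n_y, n_x = frame.shape
    out = frame.copy()
    for y, x in zip(*np.nonzero(np.isfinite(frame))):
        ring = [(y + dy, x + dx) for dx, dy in border
                if 0 <= y + dy < n_y and 0 <= x + dx < n_x]
        if all(np.isnan(frame[p]) for p in ring):  # empty border: an isolated clump
            for dx, dy in inside:
                if 0 <= y + dy < n_y and 0 <= x + dx < n_x:
                    out[y + dy, x + dx] = np.nan
    return out


single = np.full((11, 11), np.nan)
single[5, 5] = 1.0                                  # one isolated bright pixel
cleaned = remove_small_clumps(single, 2)            # removed
kept = remove_small_clumps(np.ones((11, 11)), 2)    # a filled frame survives
```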
 989
 990    # ---- HELPER METHODS FOR PLOTTING ----
 991
 992    def __crop_data(self, x: DataArray1D, y: DataArray1D, v: DataArray1D, data: DataArray3D, crop_leeway: int | float = 0, fill_data: DataArray3D | None = None, special_v = False) -> tuple[DataArray1D, DataArray1D, DataArray1D, DataArray3D]:
 993        """
 994        Crop the data arrays to the smallest region containing all valid (non-NaN) data.
 995
 996        Parameters
 997        ----------
 998        x, y, v : DataArray1D
 999            Small coordinate arrays.
1000        data : DataArray3D
1001            Data array to use for crop.
1002        crop_leeway : int or float, optional
1003            Fractional leeway to expand the crop region (default is 0).
1004        fill_data : DataArray3D or None, optional
1005            Data array to use for filling values (default is None).
1006
1007        Returns
1008        -------
1009        cropped_x, cropped_y, cropped_v : DataArray1D
1010            Cropped coordinate arrays.
1011        cropped_data : DataArray3D
1012            Cropped data array.
1013        """
1014        if fill_data is None:
1015            fill_data = self.data
1016        
1017        v_max, y_max, x_max = data.shape 
1018
1019        # gets indices
1020        v_indices = np.arange(v_max)
1021        y_indices = np.arange(y_max)
1022        x_indices = np.arange(x_max)
1023
1024        # turn into flat arrays
1025        V_IDX, Y_IDX, X_IDX = np.meshgrid(v_indices, y_indices, x_indices, indexing="ij")
1026
1027        # filter out nan data
1028        valid_V = V_IDX[~np.isnan(data)]
1029        valid_X = X_IDX[~np.isnan(data)]
1030        valid_Y = Y_IDX[~np.isnan(data)]
1031
1032        # indices to crop at
1033        v_mid = (np.min(valid_V) + np.max(valid_V))/2
1034        v_lo = max(int(v_mid - (1 + crop_leeway)*(v_mid - np.min(valid_V))), 0)
1035        v_hi = min(int(v_mid + (1 + crop_leeway)*(np.max(valid_V) - v_mid)), len(v) - 1) + 1
1036
1037        if special_v:
1038            assert v[v_lo] <= v[v_hi - 1], "v should be increasing"
1039            if v[v_lo] > -self.v_exp:
1040                # get last time v < -self.v_exp:
1041                offsets = v - (-self.v_exp)
1042                if np.any(offsets < 0):
1043                    offsets[offsets > 0] = -np.inf
1044                    v_lo = np.argmax(offsets) # want greatest neg value
1045                else:
1046                    v_lo = 0
1047            if v[v_hi - 1] < self.v_exp:  # v_hi is an exclusive end index
1048                offsets = v - (self.v_exp)
1049                if np.any(offsets > 0):
1050                    offsets[offsets < 0] = np.inf
1051                    v_hi = np.argmin(offsets) + 1 # want smallest pos value, kept via the exclusive end index
1052                else:
1053                    v_hi = len(v)  # keep up to the last channel
1054
1055
1056
1057        x_mid = (np.min(valid_X) + np.max(valid_X))/2
1058        x_lo = max(int(x_mid - (1 + crop_leeway)*(x_mid - np.min(valid_X))), 0)
1059        x_hi = min(int(x_mid + (1 + crop_leeway)*(np.max(valid_X) - x_mid)), len(x) - 1) + 1
1060
1061        y_mid = (np.min(valid_Y) + np.max(valid_Y))/2
1062        y_lo = max(int(y_mid - (1 + crop_leeway)*(y_mid - np.min(valid_Y))), 0)
1063        y_hi = min(int(y_mid + (1 + crop_leeway)*(np.max(valid_Y) - y_mid)), len(y) - 1) + 1
1064
1065        # crop x, y, v, data
1066        cropped_data = fill_data[v_lo:v_hi, y_lo:y_hi, x_lo:x_hi]
1067        cropped_v = v[v_lo: v_hi]
1068        cropped_y = y[y_lo: y_hi]
1069        cropped_x = x[x_lo: x_hi]
1070
1071        return cropped_x, cropped_y, cropped_v, cropped_data
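The crop bounds in `__crop_data` are the first and last non-NaN index along each axis, widened symmetrically about their midpoint by `crop_leeway`. The sketch below does the same calculation for any number of axes without building full index meshgrids. It leaves out the `special_v` adjustment, and `crop_bounds` is a hypothetical name.

```python
import numpy as np


def crop_bounds(data, leeway=0.0):
    """[lo, hi) index bounds, per axis, of the non-NaN values in `data`,
    widened about their midpoint by the fractional `leeway`."""
    valid = ~np.isnan(data)
    bounds = []
    for axis in range(data.ndim):
        others = tuple(a for a in range(data.ndim) if a != axis)
        idx = np.nonzero(valid.any(axis=others))[0]   # empty (IndexError below) if all NaN
        lo, hi = idx[0], idx[-1]
        mid = (lo + hi) / 2
        new_lo = max(int(mid - (1 + leeway) * (mid - lo)), 0)
        new_hi = min(int(mid + (1 + leeway) * (hi - mid)), data.shape[axis] - 1) + 1
        bounds.append((new_lo, new_hi))
    return bounds


cube = np.full((4, 6, 8), np.nan)
cube[1:3, 2:4, 3:6] = 1.0
(v_lo, v_hi), (y_lo, y_hi), (x_lo, x_hi) = crop_bounds(cube)
cropped = cube[v_lo:v_hi, y_lo:y_hi, x_lo:x_hi]      # just the filled block
```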
1072
1073    def __filter_and_crop(
1074            self, 
1075            filter_stds: float | int | None, 
1076            filter_beam: bool, 
1077            crop_leeway: float,
1078            verbose: bool
1079        ) -> tuple[DataArray3D, DataArray3D, DataArray3D, DataArray3D]:
1080        """
1081        Filter and crop data as specified.
1082
1083        Parameters
1084        ----------
1085        filter_stds : float, int, or None
1086            Number of standard deviations to filter by, or None.
1087        filter_beam : bool
1088            If True, apply beam filtering.
1089        crop_leeway : float
1090            Fractional leeway to expand the crop region.
1091        verbose : bool
1092            If True, print progress.
1093
1094        Returns
1095        -------
1096        X, Y, V : DataArray3D
1097            Meshgrids of velocity space coordinates.
1098        cropped_data : DataArray3D
1099            Cropped and filtered data array.
1100        """
1101        if filter_stds is not None:
1102            if verbose:
1103                print("Filtering data (to crop)...")
1104            data = self.get_filtered_data(filter_stds)
1105            if filter_beam:
1106                if verbose:
1107                    print("Applying beam filter (to crop)...")
1108                data = self.beam_filter(data)
1109        else:
1110            data = self.data
1111
1112        cropped_x, cropped_y, cropped_v, cropped_data = self.__crop_data(self.X, self.Y, self.V, data, crop_leeway=crop_leeway, special_v=True)
1113
1114        if verbose:  
1115            print(f"Data cropped to shape {cropped_data.shape}")
1116
1117
1118        V, Y, X = np.meshgrid(cropped_v, cropped_y, cropped_x, indexing = "ij")
1119        return X, Y, V, cropped_data
1120    
1121    def __fast_interpolation(self, points: tuple, x_bounds: tuple, y_bounds: tuple, z_bounds: tuple, data: DataArray3D, v_func: VelocityModel | None, num_points: int) -> tuple[DataArray1D, DataArray1D, DataArray1D, DataArray3D]:
1122        """
1123        Interpolate the data onto a regular grid in (X, Y, Z) space.
1124
1125        Parameters
1126        ----------
1127        points : tuple
1128            Tuple of (v, y, x) small coordinate arrays.
1129        x_bounds, y_bounds, z_bounds : tuple
1130            Bounds for the new grid in each dimension.
1131        data : DataArray3D
1132            Data array to interpolate, aligned with points.
1133        v_func : VelocityModel or None
1134            Velocity law function.
1135        num_points : int
1136            Number of points in each dimension for the new grid.
1137
1138        Returns
1139        -------
1140        X, Y, Z : DataArray1D
1141            Flattened meshgrids of new coordinates.
1142        interp_data : DataArray3D
1143            Interpolated data array.
1144        """
1145        # points = (V, Y, X)
1146        # new_shape = (data.shape[0] + 2, data.shape[1], data.shape[2])
1147        # new_data = np.full(new_shape, np.nan)
1148        # new_data[1:-1, :, :] = data.copy()  # padded on both sides with nan
1149
1150        # # extend v array to have out-of-range values
1151        # v_array = points[0]
1152        # delta_v = v_array[1] - v_array[0]
1153        # new_v_array = np.zeros(len(v_array) + 2)
1154        # new_v_array[1:-1] = v_array
1155        # new_v_array[0] = new_v_array[1] - delta_v
1156        # new_v_array[-1] = new_v_array[-2] + delta_v
1157        # new_points = (new_v_array, points[1], points[2])
1158
1159        # get points to interpolate at
1160        x_even = np.linspace(x_bounds[0], x_bounds[1], num=num_points)
1161        y_even = np.linspace(y_bounds[0], y_bounds[1], num=num_points)
1162        z_even = np.linspace(z_bounds[0], z_bounds[1], num=num_points)
1163
1164        X, Y, Z = np.meshgrid(x_even, y_even, z_even, indexing="ij")
1165
1166        # build velocities
1167        V = self.__to_velocity_space(X, Y, Z, v_func)
1168
1169        shape = V.shape
1170        # flatten arrays
1171        V = V.ravel()
1172        Y = Y.ravel()
1173        X = X.ravel()
1174        Z = Z.ravel()
1175
1176        bad_idxs = (V < np.min(points[0])) | (V > np.max(points[0])) |  ~np.isfinite(V)
1177        V[bad_idxs] = 0
1178
1179
1180        interp_points = np.column_stack((V, Y, X))
1181        interp_data = interpn(points, data, interp_points)
1182
1183        interp_data[bad_idxs] = np.nan
1184
1185        interp_data = interp_data.reshape(shape)
1186
1187        return X, Y, Z, interp_data
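`__fast_interpolation` evaluates `interpn` at arbitrary (v, y, x) points and blanks those whose velocity falls outside the data. By default `interpn` raises an error for out-of-bounds points, so such points are first moved to a safe in-range value and then set to NaN after interpolation. The self-contained check below uses a linear field, for which trilinear interpolation is exact. The grid and query values are arbitrary examples.

```python
import numpy as np
from scipy.interpolate import interpn

# original cube on a regular (v, y, x) grid
v = np.linspace(-10, 10, 21)
y = np.linspace(-5, 5, 11)
x = np.linspace(-5, 5, 11)
V, Y, X = np.meshgrid(v, y, x, indexing="ij")
cube = V + 2 * Y + 3 * X                     # linear test field

# query points; the last one lies outside the velocity range
query_v = np.array([0.5, 2.0, 12.0])
query_y = np.array([1.0, -1.5, 0.0])
query_x = np.array([0.25, 2.0, 0.0])

# move out-of-range velocities to a safe value, interpolate, then blank them
bad = (query_v < v.min()) | (query_v > v.max()) | ~np.isfinite(query_v)
safe_v = np.where(bad, 0.0, query_v)
values = interpn((v, y, x), cube, np.column_stack((safe_v, query_y, query_x)))
values[bad] = np.nan                         # approximately [3.25, 5.0, nan]
```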
1188
1189    # ---- COORDINATE CHANGE ----
1190
1191    def __radius_from_vel_space_coords(self, x: ndarray, y: ndarray, v: ndarray, v_func: VelocityModel) -> ndarray:
1192        """
1193        Compute the physical radius corresponding to velocity-space coordinates (x, y, v)
1194        using the provided velocity law.
1195
1196        Parameters
1197        ----------
1198        x : ndarray
1199            Array of x-coordinates in AU.
1200        y : ndarray
1201            Array of y-coordinates in AU.
1202        v : ndarray
1203            Array of velocity coordinates in km/s.
1204        v_func : VelocityModel
1205            Callable velocity law, v_func(r), returning velocity in km/s for given radius in AU.
1206
1207        Returns
1208        -------
1209        r : ndarray
1210            Array of radii in AU corresponding to the input (x, y, v) coordinates.
1211
1212        Notes
1213        -----
1214        Solves for r in the equation:
1215            r**2 * (1 - v**2 / v_func(r)**2) = x**2 + y**2
1216        for each (x, y, v) triplet.
1217        """
1218        # the three coordinate arrays must line up element by element
1219        assert x.shape == y.shape, f"Arrays x and y should have the same shape, instead {x.shape = }, {y.shape = }"
1220        assert x.shape == v.shape, f"Arrays x and v should have the same shape, instead {x.shape = }, {v.shape = }"
1221
1222        shape = x.shape
1223
1224        x, y, v = x.ravel(), y.ravel(), v.ravel()
1225
1226        def radius_eqn(r: ndarray, x: ndarray, y: ndarray, v: ndarray):
1227            return r**2 * (1 - v**2/(v_func(r)**2)) - x**2 - y**2
1228        
1229        # initial bracket guess
1230        r_lower = np.sqrt(x**2 + y**2)  # r must be greater than sqrt(x^2 + y^2)
1231
1232        # find bracket
1233        bracket_result = bracket_root(radius_eqn, r_lower, xmin = r_lower, args = (x, y, v))
1234        r_lower, r_upper = bracket_result.bracket
1235        result = find_root(radius_eqn, (r_lower, r_upper), args=(x, y, v))
1236
1237        r: ndarray = result.x
1238        
1239        return r.reshape(shape)
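For a constant velocity law, v_func(r) = v_exp, the equation in the Notes has the closed form r = sqrt(x**2 + y**2) / sqrt(1 - v**2/v_exp**2), which gives a check on the numerical solver. The sketch below solves one point at a time with `scipy.optimize.brentq` and a doubling bracket. This is a scalar substitute for the vectorised `bracket_root`/`find_root` pair used above, and the function name is hypothetical.

```python
import numpy as np
from scipy.optimize import brentq


def radius_from_velocity_coords(x, y, v, v_func, max_doublings=100):
    """Solve r**2 * (1 - v**2 / v_func(r)**2) = x**2 + y**2 for one point."""
    rho2 = x**2 + y**2

    def f(r):
        return r**2 * (1 - v**2 / v_func(r)**2) - rho2

    lo = np.sqrt(rho2)              # f(lo) <= 0, since r >= sqrt(x**2 + y**2)
    if f(lo) == 0:
        return lo
    hi = 2 * lo if lo > 0 else 1.0
    for _ in range(max_doublings):  # expand the bracket until f changes sign
        if f(hi) >= 0:
            return brentq(f, lo, hi)
        hi *= 2
    raise ValueError("no root found: |v| may not be below v_func(r)")


# constant 15 km/s: closed form gives 5 / sqrt(1 - 81/225) = 5 / 0.8 = 6.25 AU
r = radius_from_velocity_coords(3.0, 4.0, 9.0, lambda r: 15.0)
```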
1240
1241    def __to_velocity_space(self, X: ndarray, Y: ndarray, Z: ndarray, v_func: VelocityModel | None) -> ndarray:
1242        """
1243        Convert spatial coordinates to velocity space using the velocity law.
1244
1245        Parameters
1246        ----------
1247        X, Y, Z : array_like
1248            Spatial coordinates.
1249        v_func : VelocityModel or None
1250            Velocity law function.
1251
1252        Returns
1253        -------
1254        V : array_like
1255            Velocity at each spatial coordinate.
1256        """
1257        # build velocities
1258        if v_func is None:
1259            V = self.v_exp*Z/np.sqrt(X**2 + Y**2 + Z**2)
1260        else:
1261            V = v_func(np.sqrt(X**2 + Y**2 + Z**2))*Z/np.sqrt(X**2 + Y**2 + Z**2)
1262
1263        return V
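`__to_velocity_space` projects the radial outflow speed onto the line of sight: v = v(r) * z / r, with r = sqrt(x**2 + y**2 + z**2). With a constant speed, this is v_exp times the cosine of the angle from the z axis. The minimal sketch below follows the code above in treating positive z as positive velocity. The 15 km/s default and the function name are arbitrary examples.

```python
import numpy as np


def line_of_sight_velocity(x, y, z, v_exp=15.0, v_func=None):
    """Project the radial outflow speed onto the z (line-of-sight) axis."""
    r = np.sqrt(x**2 + y**2 + z**2)
    speed = v_exp if v_func is None else v_func(r)
    return speed * z / r


# on the z axis, in the sky plane, and on a 3-4-5 diagonal
v_los = line_of_sight_velocity(np.array([0.0, 3.0, 3.0]),
                               np.zeros(3),
                               np.array([2.0, 0.0, 4.0]))  # 15, 0 and 12 km/s
```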
1264
1265    # ---- EXPANSION OVER TIME ----
1266    def __get_new_radius(self, curr_radius: ndarray, years: float | int, v_func: VelocityModel | None) -> ndarray:
1267        """
1268        Compute the new radius after a given time, using the velocity law.
1269
1270        Parameters
1271        ----------
1272        curr_radius : ndarray
1273            Initial radii in AU.
1274        years : float or int
1275            Time interval in years.
1276        v_func : VelocityModel or None
1277            Velocity law function, or None (uses constant velocity).
1278
1279        Returns
1280        -------
1281        new_rad : ndarray
1282            New radii after the specified time interval.
1283
1284        Notes
1285        -----
1286        Setting v_func = None speeds this up significantly, and is a close approximation.
1287        """
1288        t = u.yr.to(u.s, years)  # t is the time in seconds
1289        rad_km = u.au.to(u.km, curr_radius)
1290
1291        if v_func is None: 
1292            new_rad = u.km.to(u.au, rad_km + self.v_exp*t)
1293        else:
1294
1295            def dr_dt(t, r):
1296                # need to convert r back to AU
1297                vals = v_func(u.km.to(u.au, r))
1298                vals[~np.isfinite(vals)] = 0
1299
1300                return vals
1301            
1302            # flatten
1303            shape = rad_km.shape
1304            rad_km = rad_km.ravel()
1305
1306            # remove nans
1307            nan_idxs = ~(np.isfinite(rad_km) & np.isfinite(dr_dt(0, rad_km)) & (rad_km > 0))
1308            valid_rad = np.min(rad_km[~nan_idxs])
1309            rad_km[nan_idxs] = valid_rad
1310
1311            # solve
1312            solution = integrate.solve_ivp(dr_dt,(0,t), rad_km, vectorized=True)
1313            new_rad = u.km.to(u.au, solution.y[:,-1]) # r is evaluated at different time points
1314
1315            # include nans and unflatten
1316            new_rad[nan_idxs] = np.nan
1317            new_rad = new_rad.reshape(shape)
1318
1319        return new_rad
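`__get_new_radius` either adds v_exp * t directly (constant speed) or integrates dr/dt = v(r) with `solve_ivp`. The sketch below implements both branches. It uses hard-coded conversion constants in place of `astropy.units`; the Julian year matches astropy's `u.yr`. With a constant `v_func`, the two branches should agree. `advance_radius` and its `v_exp` default are hypothetical.

```python
import numpy as np
from scipy.integrate import solve_ivp

AU_KM = 1.495978707e8            # km per astronomical unit
YEAR_S = 365.25 * 86400          # seconds per Julian year (astropy's u.yr)


def advance_radius(r0_au, years, v_func=None, v_exp=15.0):
    """Radii (AU) after `years` of outflow; v_func(r_au) returns km/s."""
    r0_km = np.asarray(r0_au, dtype=float) * AU_KM
    t = years * YEAR_S
    if v_func is None:                                   # constant speed: closed form
        return (r0_km + v_exp * t) / AU_KM

    def dr_dt(_t, r_km):
        speed = v_func(r_km / AU_KM)
        return np.where(np.isfinite(speed), speed, 0.0)  # non-finite speeds -> 0

    # one ODE per point, all integrated together; negative `years` runs backwards
    sol = solve_ivp(dr_dt, (0.0, t), r0_km.ravel(), vectorized=True, rtol=1e-8)
    return (sol.y[:, -1] / AU_KM).reshape(r0_km.shape)


r_const = advance_radius([10.0, 50.0], 100)   # about 316.4 AU further out
r_ode = advance_radius([10.0, 50.0], 100, v_func=lambda r: np.full_like(r, 15.0))
```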
1320
1321    def __time_expansion_transform(self, x: DataArray1D, y: DataArray1D, v: DataArray1D, data: DataArray3D, years: float | int, v_func: VelocityModel | None, crop: bool = True) -> tuple[DataArray1D, DataArray1D, DataArray1D, DataArray3D, tuple]:
1322        """
1323        Transform coordinates and data to account for expansion over time.
1324
1325        Parameters
1326        ----------
1327        x, y, v : DataArray1D
1328            Small coordinate arrays.
1329        data : DataArray3D
1330            Data array.
1331        years : float or int
1332            Time interval in years.
1333        v_func : VelocityModel or None
1334            Velocity law function.
1335        crop : bool, optional
1336            If True, crop to finite values (default is True).
1337
1338        Returns
1339        -------
1340        new_X, new_Y, new_V : DataArray1D
1341            Transformed coordinate arrays (flattened meshgrid).
1342        data : DataArray3D
1343            Data array (possibly cropped). Remains unchanged if crop is False.
1344        points : tuple
1345            Tuple of original (v, y, x) arrays, cropped if crop is True.
1346        """
1347        if crop:
1348            x, y, v, data = self.__crop_data(x, y, v, data, fill_data=data)
1349        
1350        V, Y, X = np.meshgrid(v, y, x, indexing="ij")  # outputs follow the argument order (v, y, x)
1351
1352
1353        if v_func is None:
1354            def v_func_mod(r):
1355                return self.v_exp
1356        else:
1357            v_func_mod = v_func
1358
1359        # get R0 and R1 arrays
1360        R0 = self.__radius_from_vel_space_coords(X, Y, V, v_func_mod)
1361        R1 = self.__get_new_radius(R0, years, v_func)
1362
1363        R1[R1 < 0] = np.inf  # deal with neg radii
1364
1365        new_X = (X*R1/R0).ravel()
1366        new_Y = (Y*R1/R0).ravel()
1367        if v_func is None:
1368            new_V = V.ravel()
1369        else:
1370            new_V = (V*v_func(R1)/v_func(R0)).ravel()
1371
1372
1373        if crop:
1374            finite_idxs = np.isfinite(new_X) & np.isfinite(new_Y) & np.isfinite(new_V) & np.isfinite(data).ravel()
1375            new_X = new_X[finite_idxs]
1376            new_Y = new_Y[finite_idxs]
1377            new_V = new_V[finite_idxs]
1378
1379        return new_X, new_Y, new_V, data, (v, y, x)
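`__time_expansion_transform` moves each point outward along its own ray from the star. The sky offsets scale by R1/R0, and the line-of-sight velocity scales by v(R1)/v(R0), which is 1 for constant expansion. Below is a single-point sketch; the helper name and the example point are hypothetical.

```python
import numpy as np


def expand_point(x, y, v, r0, r1, v_func=None):
    """Move a velocity-space point (x, y, v) from radius r0 to r1 along its ray."""
    scale = r1 / r0
    new_v = v if v_func is None else v * v_func(r1) / v_func(r0)
    return x * scale, y * scale, new_v


# a point at (3, 4, 12) AU in a shell expanding at a constant 15 km/s
x, y, z = 3.0, 4.0, 12.0
r0 = np.sqrt(x**2 + y**2 + z**2)          # 13 AU
v = 15.0 * z / r0                         # its line-of-sight velocity
new_x, new_y, new_v = expand_point(x, y, v, r0, 2 * r0)
# the sky offsets double to (6, 8); the velocity is unchanged
```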
1380
1381    def get_expansion(self, years: float | int, v_func: VelocityModel | None = None, remove_centre: float | int | None = 2, new_shape: tuple = (50, 250, 250), verbose: bool = False) -> CondensedData:
1382        """
1383        Compute the expanded data cube after a given time interval.
1384
1385        **Parameters**
1386
1387        - `years` (`float` or `int`): Time interval in years.
1388        - `v_func` (`VelocityModel` or `None`, optional): Velocity law function (default is None, which uses a constant expansion velocity).
1389        - `remove_centre` (`float` or `int` or `None`, optional): If not None, remove all points within this many beam widths of the centre (default is 2).
1390        - `new_shape` (`tuple`, optional): Shape of the output grid (default is (50, 250, 250)).
1391        - `verbose` (`bool`, optional): If True, print progress (default is False).
1392
1393        **Returns**
1394
1395        - `info` (`CondensedData`): CondensedData object containing the expanded data and metadata.
1396        """     
1397        
1398        use_data = self.data.copy()
1399        if remove_centre is not None:
1400            if verbose:
1401                print("Removing centre...")
1402            # get centre coords
1403            y_idx = np.argmin(self.Y**2)
1404            x_idx = np.argmin(self.X**2)
1405            
1406            # proportion of the radius that the beam takes up
1407            beam_rad_au = (((self.beam_maj + self.beam_min)/2)*np.pi/180)*self.distance_to_star
1408            beam_prop = beam_rad_au/self.radius
1409            v_axis_removal = remove_centre*beam_prop*self.v_exp
1410            v_idxs = np.arange(len(self.V))[np.abs(self.V) < v_axis_removal]
1411            relevant_vs = self.V[np.abs(self.V) < v_axis_removal]
1412            proportions = remove_centre * np.sqrt(1 - (relevant_vs/v_axis_removal)**2)
1413            
1414            # removing centre
1415            for i in range(len(v_idxs)):
1416                v_idx = v_idxs[i]
1417                prop = proportions[i]
1418                beam = self.__multiply_beam(prop)
1419                for x_offset, y_offset in beam:
1420                    if 0 <= y_idx + y_offset < use_data.shape[1] and 0 <= x_idx + x_offset < use_data.shape[2]:
1421                        use_data[v_idx][y_idx + y_offset][x_idx + x_offset] = np.nan
1422        if verbose:
1423            print("Transforming coordinates...")
1424        X, Y, V, data, points = self.__time_expansion_transform(self.X, self.Y, self.V, use_data, years, v_func)
1425        v_num, y_num, x_num = new_shape
1426
1427        if verbose:
1428            print("Generating grid for new object...")
1429        gridv, gridy, gridx = np.mgrid[
1430            np.min(V):np.max(V):v_num*1j, 
1431            np.min(Y):np.max(Y):y_num*1j,
1432            np.min(X):np.max(X):x_num*1j
1433        ] 
1434
1435
1436        # get preimage of grid
1437        small_gridx = gridx[0, 0, :]
1438        small_gridy = gridy[0, :, 0]
1439        small_gridv = gridv[:, 0, 0]
1440
1441        # go backwards with negative years
1442        if verbose:
1443            print("Shrinking grid to original data bounds...")
1444        prev_X, prev_Y, prev_V, _, _ = self.__time_expansion_transform(small_gridx, small_gridy, small_gridv, use_data, -years, v_func, crop=False)
1445        
1446        
1447        bad_idxs = (prev_V < np.min(points[0])) | (prev_V > np.max(points[0])) | \
1448                (prev_Y < np.min(points[1])) | (prev_Y > np.max(points[1])) | \
1449                (prev_X < np.min(points[2])) | (prev_X > np.max(points[2])) | \
1450                np.isnan(prev_V) | np.isnan(prev_X) | np.isnan(prev_Y)
1451        
1452        prev_V[bad_idxs] = 0
1453        prev_X[bad_idxs] = 0
1454        prev_Y[bad_idxs] = 0
1455
1456        # interpolate regular data at these points
1457        if verbose:
1458            print("Interpolating...")
1459        interp_points = np.column_stack((prev_V, prev_Y, prev_X))
1460        interp_data = interpn(points, data, interp_points)
1461        interp_data[bad_idxs] = np.nan
1462        new_data = interp_data.reshape(gridx.shape)
1463
1464        non_nans = len(new_data[np.isfinite(new_data)])
1465        if verbose:
1466            print(f"{non_nans} non-nan values remaining out of {np.size(new_data)}")
1467
1468        info = CondensedData(
1469            small_gridx, 
1470            small_gridy, 
1471            small_gridv, 
1472            new_data, 
1473            self.star_name, 
1474            self.distance_to_star, 
1475            self.v_exp, 
1476            self.v_sys,
1477            self.beta,
1478            self.r_dust,
1479            self.beam_maj, 
1480            self.beam_min, 
1481            self.beam_angle,
1482            self._header,
1483            self.mean, 
1484            self.std
1485        )
1486
1487        if verbose:
1488            print("Time expansion process complete!")
1489        return info
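`get_expansion` builds the output cube by backward mapping. It lays down a regular grid, runs the transform with `-years` to find where each grid point came from, and interpolates the original cube at those positions. This avoids interpolating from scattered points. The 2-D sketch below applies the same idea with two substitutions: a pure radial scaling stands in for the velocity-law transform, and `RegularGridInterpolator` with `fill_value=np.nan` replaces the clamp-then-blank handling of out-of-range points. The function name is hypothetical.

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator


def expand_image(x, y, image, scale, n_out=101):
    """Resample `image` (indexed (y, x)) after every point moves radially from
    (x, y) to (scale*x, scale*y), using backward (pull-back) mapping."""
    interp = RegularGridInterpolator((y, x), image, bounds_error=False, fill_value=np.nan)
    # regular output grid covering the expanded extent
    new_x = np.linspace(scale * x.min(), scale * x.max(), n_out)
    new_y = np.linspace(scale * y.min(), scale * y.max(), n_out)
    Yn, Xn = np.meshgrid(new_y, new_x, indexing="ij")
    # where did each output point start? sample the original image there
    pre = np.column_stack(((Yn / scale).ravel(), (Xn / scale).ravel()))
    return new_x, new_y, interp(pre).reshape(Yn.shape)


x = y = np.linspace(-10, 10, 21)
Yg, Xg = np.meshgrid(y, x, indexing="ij")
image = Yg + 2 * Xg                                   # linear test image
new_x, new_y, expanded = expand_image(x, y, image, scale=2.0)
```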
1490
1491
1492    # ---- PLOTTING ----
1493
1494    def plot_velocity_vs_intensity(self, fit_parabola: bool = True) -> None:
1495        """
1496        Plot the intensity at the centre of the xy plane (averaged over a region twice the beam size) against velocity.
1497
1498        **Parameters**
1499
1500        - `fit_parabola` (`bool`, optional): If True, fit and plot a parabola (default is True).
1501
1502        **Notes**
1503
1504        How closely the parabola matches the measured profile gives a visual check
1505        on the accuracy of the calculated systemic and expansion velocities.
1506        """
1507        self.__get_star_and_exp_velocity(self.V, plot=True, fit_parabola=fit_parabola)
1508
1509    def plot_radius_vs_intensity(self) -> None:
1510        """
1511        Plot the ring-averaged intensity against radius in the central velocity channel.
1512        """
1513        self.__get_beta_law(plot_intensities=True)
1514
1515    def plot_radius_vs_velocity(self, fit_beta_law: bool = True) -> None:
1516        """
1517        Plot the velocity inferred from the ring-averaged intensity against radius in the central velocity channel.
1518
1519        **Parameters**
1520
1521        - `fit_beta_law` (`bool`, optional): If True, fit and plot the beta law (default is True).
1522
1523        **Notes**
1524
1525        How closely the beta law curve matches the inferred velocities gives a visual
1526        check on the accuracy of the calculated beta law parameters.
1527        """
1528        self.__get_beta_law(plot_velocities=True, plot_beta_law=fit_beta_law)
1529
1530    def plot_channel_maps(self, 
1531        filter_stds: float | int | None = None, 
1532        filter_beam: bool = False, 
1533        dimensions: None | tuple[int,int] = None, 
1534        start: int = 0, 
1535        end: None | int = None, 
1536        include_beam: bool = True, 
1537        text_pos: None | tuple[float,float] = None, 
1538        beam_pos: None | tuple[float,float] = None, 
1539        title: str | None = None, 
1540        cmap: str = "viridis"
1541    ) -> None:
1542        """
1543        Plot the data cube as a set of 2D channel maps.
1544
1545        **Parameters**
1546
1547        - `filter_stds` (`float` or `int` or `None`, optional): Number of standard deviations to filter by (default is None).
1548        - `filter_beam` (`bool`, optional): If True, also apply beam filtering; this only takes effect when `filter_stds` is given (default is False).
1549        - `dimensions` (`tuple` of `int` or `None`, optional): Grid dimensions (nrows, ncols) for subplots (default is None, which chooses a near-square grid).
1550        - `start` (`int`, optional): Index of the first velocity channel to plot (default is 0).
1551        - `end` (`int` or `None`, optional): Index of the last velocity channel to plot, inclusive (default is None, which plots up to the last channel).
1552        - `include_beam` (`bool`, optional): If True, plot the beam ellipse (default is True).
1553        - `text_pos` (`tuple` of `float` or `None`, optional): Position of the velocity label, in arcseconds (default is None).
1554        - `beam_pos` (`tuple` of `float` or `None`, optional): Position of the centre of the beam ellipse, in arcseconds (default is None).
1555        - `title` (`str` or `None`, optional): Plot title (default is None, which uses the star name followed by "channel maps").
1556        - `cmap` (`str`, optional): Colormap for the plot (default is "viridis"). See https://matplotlib.org/stable/users/explain/colors/colormaps.html
1557        """
1558
1559        if end is None:
1560            end = self.data.shape[0]-1
1561        else:
1562            if end < start: #invalid so set to the end of the data
1563                end = self.data.shape[0]-1 
1564            if end >= self.data.shape[0]: #invalid so set to the end of the data
1565                end = self.data.shape[0]-1 
1566
1567        if filter_stds is not None:
1568            data = self.get_filtered_data(filter_stds)
1569            if filter_beam:
1570                data = self.beam_filter(data)
1571        else:
1572            data = self.data
1573        
1574        if start >= data.shape[0]: #invalid so set to the start of the data
1575            start = 0
1576
1577        if title is None:
1578            title = self.star_name + " channel maps"
1579
1580        if dimensions is not None:
1581            nrows, ncols = dimensions
1582        else:
1583            nrows = int(np.ceil(np.sqrt(end-start+1)))  # end is inclusive, so end-start+1 channels
1584            ncols = int(np.ceil((end-start+1)/nrows))
1585
1586        fig, axes = plt.subplots(nrows, ncols, squeeze=False)  # always a 2D array, even for a single channel
1587        colormap = colormaps[cmap]
1588        fig.suptitle(title)
1589        fig.supxlabel("Right Ascension (Arcseconds)")
1590        fig.supylabel("Declination (Arcseconds)")
1591
1592        i = start
1593        extents = u.radian.to(u.arcsec,np.array([np.min(self.X),np.max(self.X),np.min(self.Y),np.max(self.Y)])/self.distance_to_star)
1594
1595        done = False
1596        for ax in axes.flat:
1597            if not done:
1598                im = ax.imshow(data[i],vmin = self.mean, vmax = np.max(data[~np.isnan(data)]), extent=extents, cmap = cmap)
1599
1600                ax.set_facecolor(colormap(0))
1601                ax.set_aspect("equal")
1602
1603                if text_pos is None:
1604                    ax.text(extents[0]*5/6,extents[3]*1/2,f"{self.V[i]:.1f} km/s",size="x-small",c="white")
1605                else:
1606                    ax.text(text_pos[0],text_pos[1],f"{self.V[i]:.1f} km/s",size="x-small",c="white")
1607
1608                if include_beam:
1609                    bmaj = u.deg.to(u.arcsec,self.beam_maj)
1610                    bmin = u.deg.to(u.arcsec,self.beam_min)
1611                    bpa = self.beam_angle
1612
1613                    if beam_pos is None:
1614                        ellipse_artist = Ellipse(xy=(extents[0]*1/2,extents[2]*1/2),width=bmaj,height=bmin,angle=bpa,color = "white")
1615                    else:
1616                        ellipse_artist = Ellipse(xy=(beam_pos[0],beam_pos[1]),width=bmaj,height=bmin,angle=bpa, color = "white")
1617                    ax.add_artist(ellipse_artist)
1618                i += 1
1619                if (i >= data.shape[0]) or (i > end): done = True
1620            else:
1621                ax.axis("off")
1622        cbar = fig.colorbar(im, ax=axes.ravel().tolist())
1623        cbar.set_label("Flux Density (Jy/Beam)")
1624        plt.show()
1625
1626    def create_mask(self, filter_stds: float | int | None = None, savefile: str | None = None, initial_crop: tuple | None = None):
1627        """
1628        Launch an interactive mask creator for the data cube.
1629
1630        **Parameters**
1631
1632        - `filter_stds` (`float` or `int` or `None`, optional): Number of standard deviations to initially filter by (default is None).
1633        - `savefile` (`str` or `None`, optional): Filename to save the selected mask (default is None). Should end in .npy.
1634        - `initial_crop` (`tuple` or `None`, optional): Initial crop region as (v_lo, v_hi, y_lo, y_hi, x_lo, x_hi), using channel indices (default is None).
1635        """
1636        if filter_stds is not None:
1637            data = self.get_filtered_data(filter_stds)
1638        else:
1639            data = self.data
1640
1641        if initial_crop is not None:
1642            v_lo, v_hi, y_lo, y_hi, x_lo, x_hi = initial_crop
1643            new_data = data[v_lo:v_hi, y_lo:y_hi, x_lo:x_hi]
1644            selector = _PointsSelector(new_data)
1645            plt.show()
1646            mask = np.full(data.shape, np.nan)
1647            mask[v_lo:v_hi, y_lo:y_hi, x_lo:x_hi] = selector.mask
1648
1649        else:
1650            selector = _PointsSelector(data)
1651            plt.show()
1652        
1653            # mask is complete now
1654            mask = selector.mask
1655        
1656        # save mask
1657        if savefile is not None:
1658            np.save(savefile, mask)
1659
1660    def plot_3D(
1661        self, 
1662        filter_stds: float | int | None = None, 
1663        filter_beam: bool = False, 
1664        z_cutoff: float = 1,
1665        crop_leeway: float = 0,
1666        num_points: int = 50,
1667        num_surfaces: int = 50,
1668        opacity: float | int = 0.5,
1669        opacityscale: list[list[float]] = [[0, 0], [1, 1]],
1670        colorscale: str = "Reds",
1671        v_func: VelocityModel | None = None, 
1672        verbose: bool = False,
1673        title: str | None = None,
1674        folder: str | None = None,
1675        num_angles: int = 24,
1676        camera_dist: float | int = 2, **kwargs
1677    ) -> None:
1678        """
1679        Plot a 3D volume rendering of the data cube using Plotly.
1680
1681        **Parameters**
1682
1683        - `filter_stds` (`float` or `int` or `None`, optional): Number of standard deviations to filter by (default is None).
1684        - `filter_beam` (`bool`, optional): If True, apply beam filtering (default is False).
1685        - `z_cutoff` (`float`, optional): Cutoff for z values as a proportion of the largest x, y values (default is 1).
1686        - `crop_leeway` (`float`, optional): Fractional leeway to expand the crop region (default is 0).
1687        - `num_points` (`int`, optional): Number of points in each dimension for the grid (default is 50).
1688        - `num_surfaces` (`int`, optional): Number of surfaces to draw when rendering the plot (default is 50).
1689        - `opacity` (`float` or `int`, optional): Opacity of the volume rendering (default is 0.5).
1690        - `opacityscale` (`list` of `list` of `float`, optional): Opacity scale, see https://plotly.com/python-api-reference/generated/plotly.graph_objects.Volume.html (default is [[0, 0], [1, 1]]).
1691        - `colorscale` (`str`, optional): Colormap for the plot, see https://plotly.com/python/builtin-colorscales/ (default is "Reds").
1692        - `v_func` (`VelocityModel` or `None`, optional): Velocity law function (default is None, which uses constant expansion velocity).
1693        - `verbose` (`bool`, optional): If True, print progress (default is False).
1694        - `title` (`str` or `None`, optional): Plot title (default is None, which uses the star name).
1695        - `folder` (`str` or `None`, optional): If provided, generates successive frames and saves them as PNG files in this folder instead of showing the figure (default is None). Saving images requires Plotly's static image export engine (kaleido).
1696        - `num_angles` (`int`, optional): Number of camera angles (frames) to save (default is 24). Greater values give a smoother animation.
1697        - `camera_dist` (`float` or `int`, optional): Camera radius to use if generating frames (default is 2). Any additional keyword arguments are passed on to `go.Volume`.
1698        """
1699        if verbose:
1700            print("Initial filter and crop...")
1701        X, Y, V, data = self.__filter_and_crop(filter_stds, filter_beam, crop_leeway, verbose)
1702        
1703        if title is None:
1704            title = self.star_name
1705
1706        # get small 1d arrays
1707        x_small = X[0, 0, :]
1708        y_small = Y[0, :, 0]
1709        v_small = V[:, 0, 0]
1710
1711        # interpolate to grid
1712        if verbose:
1713            print("Interpolating to grid...")
1714
1715        z_bound = z_cutoff * np.maximum(np.max(np.abs(x_small)), np.max(np.abs(y_small)))
1716        
1717        x_lo, x_hi, y_lo, y_hi, z_lo, z_hi = np.min(x_small), np.max(x_small), np.min(y_small), np.max(y_small), -z_bound, z_bound
1718        gridx, gridy, gridz, out = self.__fast_interpolation((v_small, y_small, x_small), (x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi), data, v_func, num_points)
1719
1720
1721        # filter by standard deviation
1722        if filter_stds is not None:
1723            out[out < filter_stds*self.std] = np.nan
1724
1725        out[np.isnan(out)] = 0  # Plotly cannot handle NaN values
1726
1727        if verbose:
1728            print(f"Found {len(out[out > 0])} non-zero points.")
1729
1730        out = out.ravel()
1731        min_value = np.min(out[np.isfinite(out) & (out > 0)])
1732
1733        if verbose:
1734            print("Plotting figure...")
1735
1736        
1737        fig = go.Figure(
1738            data=go.Volume(
1739                x=gridx,
1740                y=gridy,
1741                z=gridz,
1742                value=out,
1743                isomin = min_value,
1744                colorscale= colorscale,
1745                colorbar = dict(title = "Flux Density (Jy/beam)"),
1746                opacityscale=opacityscale,
1747                opacity=opacity, # needs to be small to see through all surfaces
1748                surface_count=num_surfaces, # needs to be a large number for good volume rendering
1749                **kwargs
1750            ),
1751            layout=go.Layout(
1752                title = {
1753                    "text":title,
1754                    "x":0.5,
1755                    "y":0.95,
1756                    "xanchor":"center",
1757                    "font":{"size":24}
1758                },
1759                scene = dict(
1760                      xaxis=dict(
1761                          title=dict(
1762                              text='X (AU)'
1763                          )
1764                      ),
1765                      yaxis=dict(
1766                          title=dict(
1767                              text='Y (AU)'
1768                          )
1769                      ),
1770                      zaxis=dict(
1771                          title=dict(
1772                              text='Z (AU)'
1773                          )
1774                      ),
1775                      aspectmode = "cube"
1776                    ),
1777            )
1778        )
1779
1780
1781
1782        if folder is not None:
1783            if verbose:
1784                print("Generating frames...")
1785            angles = np.linspace(0, 360, num_angles, endpoint=False)  # 360 degrees would duplicate the 0 degree frame
1786            for a in angles:
1787                b = a*np.pi/180
1788                eye = dict(x=camera_dist*np.cos(b),y=camera_dist*np.sin(b),z=1.25)
1789                fig.update_layout(scene_camera_eye = eye)
1790                if folder:
1791                    fig.write_image(f"{folder}/angle{int(a)}.png")
1792                else:
1793                    fig.write_image(f"{int(a)}.png")
1794                if verbose:
1795                    print(f"Generating frames: {100*a/360:.0f}% complete.")
1796        else:
1797            fig.show()
1798                
1799
1800
1801
1802
1803class _PointsSelector:
1804
1805    """
1806    Interactive tool for selecting and masking points in a 3D data cube using matplotlib.
1807    """
1808
1809    def __init__(self, data: DataArray3D):
1810        """
1811        Initialize the _PointsSelector with a data cube.
1812
1813        Parameters
1814        ----------
1815        data : DataArray3D
1816            3D data array to select points from.
1817        """
1818        self.idx: int | None = None
1819        self.data: DataArray3D = data
1820        self.mask: DataArray3D = np.ones(data.shape)
1821        self.mask[np.isnan(data)] = np.nan
1822        self.num_plots: int = self.data.shape[0]
1823        ys = np.arange(self.data.shape[1])
1824        xs = np.arange(self.data.shape[2])
1825        X, Y = np.meshgrid(xs, ys, indexing = "ij")
1826        self.xys = np.column_stack((X.ravel(), Y.ravel()))
1827        self.width = int(np.ceil(np.sqrt(self.num_plots)))  # ceiling of square root
1828        self.fig, axs = plt.subplots(self.width, self.width, squeeze=False)  # always a 2D array, even for one channel
1829
1830        axs = axs.flatten()
1831        for i in range(len(axs)):
1832            if i >= self.num_plots:
1833                axs[i].axis("off")  # turn off unnecessary axes
1834
1835        self.axs = axs[:self.num_plots]
1836
1837        cmap = mpl.colormaps.get_cmap('viridis').copy()
1838        cmap.set_bad(color='white')
1839
1840        self.collections: list[AxesImage] = [None]*self.num_plots  # type: ignore
1841        upper = np.max(self.data[np.isfinite(self.data)])
1842        lower = np.min(self.data[np.isfinite(self.data)])
1843        for i in range(self.num_plots):
1844            self.collections[i] = self.axs[i].imshow(self.data[i], vmin= lower, vmax = upper, cmap = cmap)
1845
1846        self.canvas = self.fig.canvas
1847
1848        self.ind = []
1849        self.awaiting_keypress = False
1850        self.fig.suptitle("Double click on a subplot to start lassoing points.")
1851        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
1852        self.fig.canvas.mpl_connect("key_press_event", self.key_press)
1853
1854    def onclick(self, event: MouseEvent):
1855        """
1856        Handle double-click events on subplots to start lasso selection.
1857
1858        Parameters
1859        ----------
1860        event : MouseEvent
1861            Matplotlib mouse event.
1862        """
1863        if event.dblclick:
1864            if event.inaxes is None:
1865                return
1866            else:
1867                for idx in range(len(self.axs)):
1868                    if event.inaxes is self.axs[idx]:
1869                        self.idx = idx
1870                        self.create_lasso()
1871                        return
1872
1873    def key_press(self, event):
1874        """
1875        Handle key press events after lasso selection.
1876
1877        Parameters
1878        ----------
1879        event : KeyEvent
1880            Matplotlib key event.
1881        """
1882        if self.awaiting_keypress:
1883            self.awaiting_keypress = False
1884            if event.key == "a":
1885                self.unmask_selection()
1886            elif event.key == "r":
1887                self.mask_selection()
1888            else:
1889                self.disconnect()
1890
1891    def onselect(self, verts):
1892        """
1893        Callback for when the lasso selection is completed.
1894
1895        Parameters
1896        ----------
1897        verts : list of tuple
1898            Vertices of the lasso path.
1899        """
1900        path = Path(verts)
1901        
1902        self.ind = self.xys[path.contains_points(self.xys)]
1903        self.temp_mask = np.zeros(self.data[self.idx].shape)
1904        for idx_pair in self.ind:
1905            i, j = idx_pair[0], idx_pair[1]
1906            self.temp_mask[j][i] = 1
1907
1908        alpha = np.ones(self.temp_mask.shape)
1909        alpha *= 0.2
1910        alpha[self.temp_mask == 1] = 1
1911        self.collections[self.idx].set_alpha(alpha)
1912        self.fig.suptitle("Press R to mask points, A to mask all other points, any other key to escape.")
1913        self.awaiting_keypress = True
1914        self.fig.canvas.draw_idle()
1915
1916
1917    def create_lasso(self):
1918        """
1919        Start the lasso selector on the currently active subplot.
1920        """
1921        self.fig.suptitle("Selecting points...")
1922        self.fig.canvas.draw_idle()
1923        self.lasso = LassoSelector(self.axs[self.idx], onselect=self.onselect)
1924
1925    def mask_selection(self):
1926        """
1927        Mask (set to NaN) the selected points in the current subplot.
1928        """
1929        self.mask[self.idx][self.temp_mask == 1] = np.nan
1930        self.collections[self.idx].set(data = self.data[self.idx]*self.mask[self.idx])
1931        self.disconnect()
1932
1933    def unmask_selection(self):
1934        """
1935        Mask (set to NaN) all points except the selected points in the current subplot.
1936        """
1937        self.mask[self.idx][self.temp_mask != 1] = np.nan
1938        self.collections[self.idx].set(data = self.data[self.idx]*self.mask[self.idx])
1939        self.disconnect()
1940        
1941
1942    def disconnect(self):
1943        """
1944        Disconnect the lasso selector and reset the subplot state.
1945        """
1946        self.lasso.disconnect_events()
1947        self.fig.suptitle("Double click on a subplot to start lassoing points.")
1948        alpha = np.ones(self.data[self.idx].shape)
1949        self.collections[self.idx].set_alpha(alpha)
1950        self.idx = None
1951        self.ind = []
1952        self.canvas.draw_idle()
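The lasso step in `_PointsSelector.onselect` tests every pixel centre against the lasso polygon with `matplotlib.path.Path.contains_points`. It then sets the chosen pixels to NaN in a multiplicative mask. The sketch below reproduces that idea without a GUI, so the masking logic can be checked in isolation. The function name `mask_polygon` and the hard-coded square polygon are hypothetical, for illustration only. The sketch uses `np.meshgrid`'s default indexing. The order of the tested points therefore differs from `_PointsSelector`, but the set of points is the same.

```python
import numpy as np
from matplotlib.path import Path


def mask_polygon(channel: np.ndarray, verts, keep_inside: bool = False) -> np.ndarray:
    """Hypothetical helper: return a 1/NaN mask for one 2D channel.

    Pixel (row j, column i) is treated as the point (x=i, y=j), matching
    imshow's default data coordinates, as in _PointsSelector.
    """
    ny, nx = channel.shape
    xx, yy = np.meshgrid(np.arange(nx), np.arange(ny))   # both have shape (ny, nx)
    points = np.column_stack((xx.ravel(), yy.ravel()))    # (x, y) of every pixel centre
    inside = Path(verts).contains_points(points).reshape(ny, nx)

    mask = np.ones(channel.shape)
    mask[np.isnan(channel)] = np.nan   # pixels that were already NaN stay masked
    if keep_inside:
        mask[~inside] = np.nan         # analogous to pressing "A"
    else:
        mask[inside] = np.nan          # analogous to pressing "R"
    return mask


channel = np.arange(25, dtype=float).reshape(5, 5)
square = [(0.5, 0.5), (3.5, 0.5), (3.5, 3.5), (0.5, 3.5)]  # encloses pixel centres 1..3 in x and y
removed = mask_polygon(channel, square)
kept = mask_polygon(channel, square, keep_inside=True)
print(np.isnan(removed).sum(), np.isnan(kept).sum())  # 9 16
masked_channel = channel * removed   # multiplying by the mask blanks the selected pixels
```

`StarData` applies a saved mask the same way, by multiplying the data cube by it when `maskfile` is given.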
DataArray1D = numpy.ndarray[tuple[int], numpy.dtype[numpy.float64]]
DataArray3D = numpy.ndarray[tuple[int, int, int], numpy.dtype[numpy.float64]]
Matrix2x2 = numpy.ndarray[tuple[typing.Literal[2], typing.Literal[2]], numpy.dtype[numpy.float64]]
VelocityModel = collections.abc.Callable[[numpy.ndarray], numpy.ndarray]
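`VelocityModel` is any callable that maps an array of radii to an array of speeds. As a hedged sketch, a beta-type wind law is written below in the common form v(r) = v0 + (v_exp - v0)(1 - r_dust/r)^beta. The default v0 = 3 km/s mirrors `StarData.v0`. The exact functional form, the clipping inside r_dust, and the factory name `make_beta_law` are assumptions for illustration. They are not necessarily what `StarData.beta_velocity_law` implements.

```python
from collections.abc import Callable

import numpy as np

VelocityModel = Callable[[np.ndarray], np.ndarray]


def make_beta_law(v_exp: float, r_dust: float, beta: float, v0: float = 3.0) -> VelocityModel:
    """Hypothetical factory for a beta velocity law (radii in AU, speeds in km/s)."""
    def v(r: np.ndarray) -> np.ndarray:
        r = np.asarray(r, dtype=float)
        # inside the dust formation radius the base would be negative, so clip it at 0
        base = np.clip(1.0 - r_dust / r, 0.0, None)
        return v0 + (v_exp - v0) * base ** beta
    return v


law = make_beta_law(v_exp=15.0, r_dust=10.0, beta=1.0)
print(law(np.array([10.0, 20.0, 1e6])))   # starts at v0 and approaches v_exp at large radii
```

Returning a closure keeps the fitted parameters bound to the function. The result can then be passed wherever a `v_func: VelocityModel` is accepted.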
class FITSHeaderError(builtins.Exception):
40class FITSHeaderError(Exception):
41    """
42    Raised when the FITS file header is missing required information or stores it with the wrong type.
43    """
44    pass

Raised when the FITS file header is missing required information or stores it with the wrong type.
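The sketch below shows one hedged way a header check could raise `FITSHeaderError`. The helper `require_keyword` and the keyword names used (`BMAJ`, `BMIN`, `BPA`) are illustrative assumptions, not taken from the package. astropy's `Header` supports the same `in` and `[]` lookups as a mapping, so a plain `dict` stands in for it here.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any


class FITSHeaderError(Exception):
    """Raised when the FITS file header is missing required information or stores it with the wrong type."""


def require_keyword(header: Mapping[str, Any], key: str, expected: type | tuple[type, ...]) -> Any:
    """Hypothetical helper: return header[key], raising FITSHeaderError if it is absent or mistyped."""
    if key not in header:
        raise FITSHeaderError(f"FITS header is missing the {key!r} keyword.")
    value = header[key]
    if not isinstance(value, expected):
        raise FITSHeaderError(
            f"FITS header keyword {key!r} has type {type(value).__name__}, expected {expected}."
        )
    return value


# A plain dict stands in for astropy.io.fits.Header here.
header = {"BMAJ": 1.2e-4, "BPA": "45"}
beam_maj = require_keyword(header, "BMAJ", (int, float))
for key in ("BMIN", "BPA"):
    try:
        require_keyword(header, key, (int, float))
    except FITSHeaderError as err:
        print(err)
```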

@dataclass
class CondensedData:
46@dataclass
47class CondensedData:
48    """
49    Container for storing a condensed version of the data cube and its associated metadata.
50
51    **Attributes**
52
53    - `x_offsets` (`DataArray1D`): 1D array of x-coordinates (offsets) in AU.
54    - `y_offsets` (`DataArray1D`): 1D array of y-coordinates (offsets) in AU.
55    - `v_offsets` (`DataArray1D`): 1D array of velocity offsets in km/s.
56    - `data` (`DataArray3D`): 3D data array (velocity, y, x) containing the intensity values.
57    - `star_name` (`str`): Name of the star or object.
58    - `distance_to_star` (`float`): Distance to the star in AU.
59    - `v_exp` (`float`): Expansion velocity in km/s.
60    - `v_sys` (`float`): Systemic velocity in km/s.
61    - `beta` (`float`): Beta parameter of the velocity law.
62    - `r_dust` (`float`): Dust formation radius in AU.
63    - `beam_maj` (`float`): Major axis of the beam in degrees.
64    - `beam_min` (`float`): Minor axis of the beam in degrees.
65    - `beam_angle` (`float`): Beam position angle in degrees.
66    - `header` (`Header`): FITS header containing metadata.
67    - `mean` (`float` or `None`, optional): Mean intensity of the data (default is None).
68    - `std` (`float` or `None`, optional): Standard deviation of the data (default is None).
69    """
70    x_offsets: DataArray1D
71    y_offsets: DataArray1D
72    v_offsets: DataArray1D
73    data: DataArray3D
74    star_name: str
75    distance_to_star: float
76    v_exp: float
77    v_sys: float
78    beta: float
79    r_dust: float
80    beam_maj: float
81    beam_min: float
82    beam_angle: float
83    header: Header
84    mean: float | None = None
85    std: float | None = None

Container for storing a condensed version of the data cube and its associated metadata.

Attributes

  • x_offsets (DataArray1D): 1D array of x-coordinates (offsets) in AU.
  • y_offsets (DataArray1D): 1D array of y-coordinates (offsets) in AU.
  • v_offsets (DataArray1D): 1D array of velocity offsets in km/s.
  • data (DataArray3D): 3D data array (velocity, y, x) containing the intensity values.
  • star_name (str): Name of the star or object.
  • distance_to_star (float): Distance to the star in AU.
  • v_exp (float): Expansion velocity in km/s.
  • v_sys (float): Systemic velocity in km/s.
  • beta (float): Beta parameter of the velocity law.
  • r_dust (float): Dust formation radius in AU.
  • beam_maj (float): Major axis of the beam in degrees.
  • beam_min (float): Minor axis of the beam in degrees.
  • beam_angle (float): Beam position angle in degrees.
  • header (Header): FITS header containing metadata.
  • mean (float or None, optional): Mean intensity of the data (default is None).
  • std (float or None, optional): Standard deviation of the data (default is None).
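`x_offsets`, `y_offsets` and `distance_to_star` are all in AU. Angular offsets therefore follow from the small-angle approximation, θ = offset / distance (in radians). `StarData.plot_channel_maps` builds its axis extents the same way using astropy units. The dependency-free sketch below uses a hypothetical function name. Since 1 parsec is about 206264.806 AU, 1 AU at that distance subtends about 1 arcsecond.

```python
import numpy as np

ARCSEC_PER_RADIAN = 180.0 / np.pi * 3600.0   # about 206264.806


def au_offsets_to_arcsec(offsets_au: np.ndarray, distance_au: float) -> np.ndarray:
    """Hypothetical helper: small-angle conversion of projected offsets (AU) to arcseconds."""
    return np.asarray(offsets_au, dtype=float) / distance_au * ARCSEC_PER_RADIAN


distance_au = 206264.806                      # roughly 1 parsec expressed in AU
x_offsets = np.array([-100.0, 0.0, 100.0])    # AU
print(au_offsets_to_arcsec(x_offsets, distance_au))   # close to [-100, 0, 100] arcsec
```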
CondensedData( x_offsets: numpy.ndarray[tuple[int], numpy.dtype[numpy.float64]], y_offsets: numpy.ndarray[tuple[int], numpy.dtype[numpy.float64]], v_offsets: numpy.ndarray[tuple[int], numpy.dtype[numpy.float64]], data: numpy.ndarray[tuple[int, int, int], numpy.dtype[numpy.float64]], star_name: str, distance_to_star: float, v_exp: float, v_sys: float, beta: float, r_dust: float, beam_maj: float, beam_min: float, beam_angle: float, header: astropy.io.fits.header.Header, mean: float | None = None, std: float | None = None)
x_offsets: numpy.ndarray[tuple[int], numpy.dtype[numpy.float64]]
y_offsets: numpy.ndarray[tuple[int], numpy.dtype[numpy.float64]]
v_offsets: numpy.ndarray[tuple[int], numpy.dtype[numpy.float64]]
data: numpy.ndarray[tuple[int, int, int], numpy.dtype[numpy.float64]]
star_name: str
distance_to_star: float
v_exp: float
v_sys: float
beta: float
r_dust: float
beam_maj: float
beam_min: float
beam_angle: float
header: astropy.io.fits.header.Header
mean: float | None = None
std: float | None = None
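The beam attributes `beam_maj`, `beam_min` and `beam_angle` are all in degrees. They define the ellipse that `StarData.B` encodes: a point v lies inside the beam centred at w when (v-w)^T B (v-w) < 1. The sketch below shows one possible construction of such a matrix. It assumes the axes are full widths and the position angle is measured from north (+dec) towards east (+ra), the usual radio-astronomy convention. The package's own beam processing may use a different convention, and the helper names are hypothetical.

```python
import numpy as np


def beam_matrix(beam_maj: float, beam_min: float, beam_angle: float) -> np.ndarray:
    """Hypothetical helper: 2x2 ellipse matrix for (ra, dec) sky offsets in degrees.

    beam_maj, beam_min: full axis lengths in degrees.
    beam_angle: position angle in degrees, east of north (assumed convention).
    """
    a, b = beam_maj / 2.0, beam_min / 2.0               # semi-axes
    theta = np.deg2rad(beam_angle)
    major = np.array([np.sin(theta), np.cos(theta)])    # unit vector along the major axis
    minor = np.array([np.cos(theta), -np.sin(theta)])   # perpendicular unit vector
    return np.outer(major, major) / a**2 + np.outer(minor, minor) / b**2


def in_beam(v, w, B: np.ndarray) -> bool:
    """True if offset v lies inside the beam ellipse centred at w."""
    d = np.asarray(v, dtype=float) - np.asarray(w, dtype=float)
    return bool(d @ B @ d < 1.0)


B = beam_matrix(beam_maj=2e-4, beam_min=1e-4, beam_angle=0.0)  # major axis along dec
centre = np.zeros(2)
print(in_beam([0.0, 0.9e-4], centre, B))   # True: along the major axis, within a = 1e-4
print(in_beam([0.9e-4, 0.0], centre, B))   # False: along the minor axis, beyond b = 0.5e-4
```

Writing B as a sum of outer products of the two axis directions keeps it symmetric. It also makes the semi-axes explicit without choosing a rotation-matrix sign convention.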
class StarData:
  88class StarData:
  89
  90    """
  91    Class for manipulating and analyzing astronomical data cubes of radially expanding circumstellar envelopes.
  92
  93    The StarData class provides a comprehensive interface for loading, processing, analyzing, and visualizing
  94    3D data cubes (typically from FITS files) representing the emission from expanding circumstellar shells.
  95    It supports both direct loading from FITS files and from preprocessed CondensedData objects, and manages
  96    all relevant metadata and derived quantities.
  97
  98    Key Features
  99    ------------
 100    - **Data Loading:** Supports initialization from FITS files or CondensedData objects.
 101    - **Metadata Management:** Stores and exposes all relevant observational and physical parameters, including
 102      beam properties, systemic and expansion velocities, beta velocity law parameters, and FITS header information.
 103    - **Noise Estimation:** Automatically computes mean and standard deviation of background noise for filtering.
 104    - **Filtering:** Provides methods to filter data by significance (standard deviations) and to remove small clumps
 105      of points that fit within the beam (beam filtering).
 106    - **Coordinate Transformations:** Handles conversion between velocity space and spatial (cartesian) coordinates,
 107      supporting both constant velocity models and general velocity laws.
 108    - **Time Evolution:** Can compute the expansion of the envelope over time, transforming the data cube accordingly.
 109    - **Visualization:** Includes a variety of plotting methods:
 110        - Channel maps (2D slices through velocity channels)
 111        - 3D volume rendering (with Plotly)
 112        - Diagnostic plots for velocity/intensity and radius/velocity relationships
 113    - **Interactive Masking:** Supports interactive creation of masks for manual data cleaning.
 114
 115    Attributes
 116    ------------
 117
 118    - `data` (`DataArray3D`): The main data cube (v, y, x) containing intensity values.
 119    - `X` (`DataArray1D`): 1D array of x-coordinates (offsets) in AU.
 120    - `Y` (`DataArray1D`): 1D array of y-coordinates (offsets) in AU.
 121    - `V` (`DataArray1D`): 1D array of velocity offsets in km/s.
 122    - `distance_to_star` (`float`): Distance to the star in AU.
 123    - `beam_maj` (`float`): Major axis of the beam in degrees.
 124    - `beam_min` (`float`): Minor axis of the beam in degrees.
 125    - `beam_angle` (`float`): Beam position angle in degrees.
 126    - `mean` (`float`): Mean intensity of the background noise.
 127    - `std` (`float`): Standard deviation of the background noise.
 128    - `v_sys` (`float`): Systemic velocity in km/s.
 129    - `v_exp` (`float`): Expansion velocity in km/s.
 130    - `beta` (`float`): Beta parameter of the velocity law.
 131    - `r_dust` (`float`): Dust formation radius in AU.
 132    - `radius` (`float`): Characteristic radius (e.g., maximum intensity change) in AU.
 133    - `beta_velocity_law` (`VelocityModel`): Callable implementing the beta velocity law with the current object's parameters.
 134    - `star_name` (`str`): Name of the star or object.
 135
 136    Methods
 137    -------
 138    - `export() -> CondensedData`: Export all defining attributes to a CondensedData object.
 139    - `get_filtered_data(stds=5)`: Return a copy of the data, with values below the specified number of standard deviations set to np.nan.
 140    - `beam_filter(filtered_data)`: Remove clumps of points that fit inside the beam, setting these values to np.nan.
 141    - `get_expansion(years, v_func, ...)`: Compute the expanded data cube after a given time interval.
 142    - `plot_channel_maps(...)`: Plot the data cube as a set of 2D channel maps.
 143    - `plot_3D(...)`: Plot a 3D volume rendering of the data cube using Plotly.
 144    - `plot_velocity_vs_intensity(...)`: Plot velocity vs. intensity at the center of the xy plane.
 145    - `plot_radius_vs_intensity()`: Plot radius vs. intensity at the center of the xy plane.
 146    - `plot_radius_vs_velocity(...)`: Plot radius vs. velocity at the center of the xy plane.
 147    - `create_mask(...)`: Launch an interactive mask creator for the data cube.
 148    """
 149
 150    _c = 299792.458  # speed of light, km/s
 151    v0 = 3  # km/s, speed of sound
 152
 153    def __init__(
 154        self,
 155        info_source: str | CondensedData,
 156        distance_to_star: float | None = None,
 157        rest_frequency: float | None = None,
 158        maskfile: str | None = None,
 159        beta_law_params: tuple[float, float] | None = None,
 160        v_exp: float | None = None,
 161        v_sys: float | None = None,
 162        absolute_star_pos: tuple[float, float] | None = None
 163    ) -> None:
 164        """
 165        Initialize a StarData object by reading data from a FITS file or a CondensedData object.
 166
 167        **Parameters**
 168
 169        - `info_source` (`str` or `CondensedData`): Path to a FITS file or a CondensedData object containing preprocessed data.
 170        - `distance_to_star` (`float` or `None`, optional): Distance to the star in AU (required if info_source is a FITS file).
 171        - `rest_frequency` (`float` or `None`, optional): Rest frequency in Hz (required if info_source is a FITS file).
 172        - `maskfile` (`str` or `None`, optional): Path to a .npy file containing a mask to apply to the data.
 173        - `beta_law_params` (`tuple` of `float` or `None`, optional): (r_dust (AU), beta) parameters for the beta velocity law. If None, will be fit from data.
 174        - `v_exp` (`float` or `None`, optional): Expansion velocity in km/s. If None, will be fit from data.
 175        - `v_sys` (`float` or `None`, optional): Systemic velocity in km/s. If None, will be fit from data.
 176        - `absolute_star_pos` (`tuple` of `float` or `None`, optional): Absolute (RA, Dec) position of the star in degrees. If None, taken to be the centre of the image.
 177
 178        **Raises**
 179
 180        - `ValueError`: If required parameters are missing when reading from a FITS file.
 181        - `FITSHeaderError`: If any attribute in the FITS file header is an incorrect type.
 182        """
 183        if isinstance(info_source, str):
 184            if distance_to_star is None or rest_frequency is None:
 185                raise ValueError("Distance to star and rest frequency required when reading from FITS file.")
 186            self.__load_from_fits_file(info_source, distance_to_star, rest_frequency, absolute_star_pos, v_sys = v_sys, v_exp = v_exp)
 187            if beta_law_params is None:
 188                self._r_dust, self._beta, self._radius = self.__get_beta_law()
 189            else:
 190                self._r_dust, self._beta = beta_law_params
 191
 192        else:
 193            # load from CondensedData
 194            self._X = info_source.x_offsets
 195            self._Y = info_source.y_offsets
 196            self._V = info_source.v_offsets
 197            self._data = info_source.data
 198            self.star_name = info_source.star_name
 199            self._distance_to_star = info_source.distance_to_star
 200            self._v_exp = info_source.v_exp if v_exp is None else v_exp
 201            self._v_sys = info_source.v_sys if v_sys is None else v_sys
 202            self._r_dust = info_source.r_dust if beta_law_params is None else beta_law_params[0]
 203            self._beta = info_source.beta if beta_law_params is None else beta_law_params[1]
 204            self._beam_maj = info_source.beam_maj
 205            self._beam_min = info_source.beam_min
 206            self._beam_angle = info_source.beam_angle
 207            self._header = info_source.header
 208            self.__process_beam()
 209
 210            # compute mean and standard deviation
 211            if info_source.mean is None or info_source.std is None:
 212                self._mean, self._std = self.__mean_and_std()
 213            else:
 214                self._mean = info_source.mean
 215                self._std = info_source.std
 216
 217
 218        if maskfile is not None:
 219            # mask data (permanent)
 220            mask = np.load(maskfile)
 221            self._data = self._data * mask
 222        
 223    # ---- READ ONLY ATTRIBUTES ----
 224
 225    @property
 226    def data(self) -> DataArray3D:
 227        """
 228        DataArray3D 
 229        
 230        Stores the intensity of light at each data point.
 231        
 232        Dimensions: k x m x n, where k is the number of frequency channels,
 233        m is the number of declination channels, and n is the number of right ascension channels.
 234        """
 235        return self._data
 236
 237    @property
 238    def X(self) -> DataArray1D:
 239        """
 240        DataArray1D 
 241        
 242        Stores the x-coordinates relative to the centre in AU.
 243        Obtained from right ascension coordinates.
 244        """
 245        return self._X
 246
 247    @property
 248    def Y(self) -> DataArray1D:
 249        """
 250        DataArray1D
 251         
 252        Stores the y-coordinates relative to the centre in AU.
 253        Obtained from declination coordinates.
 254        """
 255        return self._Y
 256
 257    @property
 258    def V(self) -> DataArray1D:
 259        """
 260        DataArray1D
 261        
 262        Stores the velocity offsets relative to the systemic velocity of the star, in km/s.
 263        Obtained from frequency channels.
 264        """
 265        return self._V
 266
 267    @property
 268    def distance_to_star(self) -> float:
 269        """
 270        Distance to star in AU.
 271        """
 272        return self._distance_to_star
 273
 274    @property
 275    def B(self) -> Matrix2x2:
 276        """
 277        Matrix2x2
 278
 279        Ellipse matrix of the beam. For 2-element vectors v, w of (ra, dec) coordinates in degrees,
 280        if (v-w)^T B (v-w) < 1, then v lies within the beam centred at w.
 281        """
 282        return self._B
 283
 284    @property
 285    def beam_maj(self) -> float:
 286        """
 287        Major axis of the beam in degrees.
 288        """
 289        return self._beam_maj
 290
 291    @property
 292    def beam_min(self) -> float:
 293        """
 294        Minor axis of the beam in degrees.
 295        """
 296        return self._beam_min
 297
 298    @property
 299    def beam_angle(self) -> float:
 300        """
 301        Beam position angle in degrees.
 302        """
 303        return self._beam_angle
 304
 305    @property
 306    def mean(self) -> float:
 307        """
 308        The mean intensity of the light, taken over coordinates away from the centre.
 309        """
 310        return self._mean
 311
 312    @property
 313    def std(self) -> float:
 314        """
 315        The standard deviation of the intensity of the light, taken over coordinates away from the centre.
 316        """
 317        return self._std
 318
 319    @property
 320    def v_sys(self) -> float:
 321        """
 322        The systemic velocity of the star in km/s.
 323        """
 324        return self._v_sys
 325
 326    @property
 327    def v_exp(self) -> float:
 328        """
 329        The maximum radial expansion speed in km/s.
 330        """
 331        return self._v_exp
 332
 333    @property
 334    def beta(self) -> float:
 335        """
 336        Beta parameter of the velocity law.
 337        """
 338        return self._beta
 339
 340    @property
 341    def r_dust(self) -> float:
 342        """
 343        Dust formation radius in AU.
 344        """
 345        return self._r_dust
 346
 347    @property
 348    def radius(self) -> float:
 349        """
 350        Characteristic radius of the envelope in AU: the radius of maximum intensity change.
 351        """
 352        return self._radius
 353
 354    @property
 355    def beta_velocity_law(self) -> VelocityModel:
 356        """
 357        VelocityModel
 358
 359        Returns a callable implementing the beta velocity law with the current object's parameters.
 360        """
 361        def law(r):
 362            return self.__general_beta_velocity_law(r, self.r_dust, self.beta)
 363        return law
 364
 365    # ---- EXPORT ----
 366
 367    def export(self) -> CondensedData:
 368        """
 369        Export all defining attributes to a CondensedData object.
 370        """
 371        return CondensedData(
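 372            x_offsets = self.X,
 373            y_offsets = self.Y,
 374            v_offsets = self.V,
 375            data = self.data,
 376            star_name = self.star_name,
 377            distance_to_star = self.distance_to_star,
 378            v_exp = self.v_exp,
 379            v_sys = self.v_sys,
 380            beta = self.beta,
 381            r_dust = self.r_dust,
 382            beam_maj = self.beam_maj,
 383            beam_min = self.beam_min,
 384            beam_angle = self.beam_angle,
 385            header = self._header,
 386            mean = self.mean,
 387            std = self.std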
 388        )
 389
 390    # ---- HELPER METHODS FOR INITIALISATION ----
 391
 392    @staticmethod
 393    def __header_check(header: Header) -> bool:
 394        """
 395        Check that the FITS header contains all required values with appropriate types.
 396
 397        Parameters
 398        ----------
 399        header : Header
 400            FITS header object to check.
 401
 402        Returns
 403        -------
 404        missing_beam : bool
 405            True if beam parameters are missing from the header, False otherwise.
 406
 407        Raises
 408        ------
 409        FITSHeaderError
 410            If a required attribute is missing or has an incorrect type.
 411        """
 412        missing_beam = False
 413        types_to_check = {
 414            "BSCALE": float,
 415            "BZERO": float,
 416            "OBJECT": str,
 417            "BMAJ": float,
 418            "BMIN": float,
 419            "BPA": float,
 420            "BTYPE": str,
 421            "BUNIT": str
 422        }
 423        for num in range(1, 4):
 424            types_to_check["CTYPE" + str(num)] = str
 425            types_to_check["NAXIS" + str(num)] = int
 426            types_to_check["CRPIX" + str(num)] = float
 427            types_to_check["CRVAL" + str(num)] = float
 428            types_to_check["CDELT" + str(num)] = float
 429
 430        # check if beam is present in data
 431        if "BMAJ" not in header or "BMIN" not in header or "BPA" not in header:
 432            missing_beam = True
 433
 434        for attr in types_to_check:
 435            if attr in ["BMAJ", "BMIN", "BPA"] and missing_beam:
 436                continue
 437            attr_type = types_to_check[attr]
 438            if attr not in header or type(header[attr]) is not attr_type:
 439                raise FITSHeaderError(f"Header attribute {attr} should have type {attr_type}, instead is {type(header.get(attr))}")
 440        return missing_beam
 441
 442    def __load_from_fits_file(
 443        self,
 444        filename: str,
 445        distance_to_star: float,
 446        rest_freq: float,
 447        absolute_star_pos: tuple[float, float] | None = None,
 448        v_sys: float | None = None,
 449        v_exp: float | None = None
 450    ) -> None:
 451        """
 452        Load data and metadata from a FITS file and initialize StarData attributes.
 453
 454        Parameters
 455        ----------
 456        filename : str
 457            Path to the FITS file.
 458        distance_to_star : float
 459            Distance to the star in AU.
 460        rest_freq : float
 461            Rest frequency in Hz.
 462        absolute_star_pos : tuple of float or None, optional
 463            Absolute (RA, Dec) position of the star in degrees. If None, use image center.
 464        v_sys : float or None, optional
 465            Systemic velocity in km/s. If None, will be fit from data.
 466        v_exp : float or None, optional
 467            Expansion velocity in km/s. If None, will be fit from data.
 468
 469        Returns
 470        -------
 471        None
 472
 473        Raises
 474        ------
 475        FITSHeaderError
 476            If the FITS header is missing required attributes or has incorrect types.
 477        AssertionError
 478            If the FITS file does not contain data.
 479        """
 480        # read data from file
 481        with fits.open(filename) as hdul:
 482            hdu: PrimaryHDU = hdul[0] # type: ignore
 483
 484            missing_beam = self.__header_check(hdu.header)  # check that all the information is available before proceeding
 485            self._header: Header = hdu.header
 486            if missing_beam:  # beam parameters are stored in the table in hdul[1] instead
 487                beam_data = list(hdul[1].data)
 488
 489                str_to_conversion = {
 490                    "arcsec": 1/3600,
 491                    "deg": 1,
 492                    "degrees": 1,
 493                    "degree": 1
 494                }
 495                unit_maj = hdul[1].header["TUNIT1"]
 496                unit_min = hdul[1].header["TUNIT2"]
 497
 498                # 1 arcsec = 1/3600 degree
 499                self._beam_maj = np.mean(np.array([beam[0] for beam in beam_data]))*str_to_conversion[unit_maj]
 500                self._beam_min = np.mean(np.array([beam[1] for beam in beam_data]))*str_to_conversion[unit_min]
 501                self._beam_angle = np.mean(np.array([beam[2] for beam in beam_data]))
 502            else:
 503                self._beam_maj = self._header["BMAJ"]
 504                self._beam_min = self._header["BMIN"]
 505                self._beam_angle = self._header["BPA"]
 506                
 507            
 508
 509            brightness_scale: float = self._header["BSCALE"]  # type: ignore
 510            brightness_zero: float = self._header["BZERO"]  # type: ignore
 511
 512            # scale data to be in specified brightness units
 513            assert hdu.data is not None
 514            self._data: DataArray3D = np.array(hdu.data[0], dtype = float64)*brightness_scale+brightness_zero  #freq, dec, ra
 515
 516        self.star_name: str = self._header["OBJECT"]  # type: ignore
 517        self._distance_to_star = distance_to_star
 518
 519
 520        # get velocities from frequencies
 521        freq_range: DataArray1D = self.__get_header_array(3)
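            # Optical velocity convention: v = c*(rest_freq/freq - 1), which is the same as
            # (1/freq - 1/rest_freq)*rest_freq*c below (StarData._c is assumed to be in km/s).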
 522        vel_range: DataArray1D = (1/freq_range-1/rest_freq)*rest_freq*StarData._c  # velocity in km/s
 523        if vel_range[-1] < vel_range[0]:  # array is backwards
 524            vel_range = np.flip(vel_range)
 525            self._data = np.flip(self._data, axis = 0)  # keep the frames aligned with the velocities
 526
 527        # get X and Y coordinates
 528        ra_vals = self.__get_header_array(1)
 529        dec_vals = self.__get_header_array(2)   # reverse !!
 530        if absolute_star_pos is None:
 531            ra_offsets = ra_vals - np.mean(ra_vals)
 532            dec_offsets = dec_vals - np.mean(dec_vals)
 533        else:
 534            ra_offsets = ra_vals - absolute_star_pos[0]
 535            dec_offsets = dec_vals - absolute_star_pos[1]
 536
 537        self._X = ra_offsets*self.distance_to_star*np.pi/180  # measured in AU
 538        self._Y = dec_offsets*self.distance_to_star*np.pi/180
 539
 540        self.__process_beam()
 541
 542        # get mean and standard deviation of intensity values of noise
 543        self._mean, self._std = self.__mean_and_std()
 544
 545        # get velocity offsets
 546        self._v_sys, self._v_exp = self.__get_star_and_exp_velocity(vel_range, v_sys = v_sys, v_exp = v_exp)
 547        self._V = vel_range - self.v_sys
 548
 549    def __process_beam(self) -> None:
 550        """
 551        Compute the beam ellipse matrix and pixel offsets for the beam and its boundary.
 552
 553        Returns
 554        -------
 555        None
 556
 557        Notes
 558        -----
 559        Sets the attributes `_B`, `_offset_in_beam`, and `_boundary_offset` for use in beam-related calculations.
 560        """
 561        # get matrix for elliptical distance corresponding to beam
 562        major: float = float(self.beam_maj)/2  # type: ignore
 563        minor: float = self.beam_min/2  # type: ignore
 564
 565        # degrees -> radians
 566        theta: float = self.beam_angle*np.pi/180  # type: ignore
 567        R = np.array([
 568                [np.cos(theta), -np.sin(theta)],
 569                [np.sin(theta), np.cos(theta)]
 570        ])
 571        D = np.diag([1/minor, 1/major])
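            # With u = R^T v (v rotated into the beam frame), v^T B v = |D u|^2
            # = (u[0]/minor)**2 + (u[1]/major)**2, which equals 1 exactly on the beam ellipse.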
 572        self._B = R@D@D@R.T  # beam matrix: (v-w)^T B (v-w) < 1 means v is within beam centred at w
 573        self._offset_in_beam, self._boundary_offset = self.__pixels_in_beam(major)
 574
 575    def __pixels_in_beam(self, major: float) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]:
 576        """
 577        Determine which pixel offsets are inside or on the boundary of the beam ellipse.
 578
 579        Parameters
 580        ----------
 581        major : float
 582            Semi-major axis of the beam ellipse, in degrees.
 583
 584        Returns
 585        -------
 586        pixels_in_beam : list of tuple of int
 587            List of (x, y) offsets inside the beam ellipse, relative to the center.
 588        pixels_on_bdry : list of tuple of int
 589            List of (x, y) offsets on the boundary of the beam ellipse, relative to the center.
 590        """
 591        delta_x: float = self._header["CDELT1"]  # type: ignore
 592        delta_y: float = self._header["CDELT2"]  # type: ignore
 593
 594        bound_x: int = int(np.abs(major/delta_x) + 1)  # square to bound search in x direction
 595        bound_y: int = int(np.abs(major/delta_y) + 1)  # y direction
 596
 597        pixels_on_bdry = []
 598        pixels_in_beam = []
 599
 600        for x_offset in range(-bound_x, bound_x + 1):
 601            for y_offset in range(-bound_y, bound_y + 1):
 602
 603                # get position and elliptic distance from origin
 604                pos = np.array([delta_x*x_offset, delta_y*y_offset])
 605                dist = np.sqrt(pos.T @ self.B @ pos)
 606
 607                # determine if on boundary or inside ellipse
 608                if 1 <= dist <= 1.2:
 609                    pixels_on_bdry.append((x_offset, y_offset))
 610                elif dist < 1:
 611                    pixels_in_beam.append((x_offset, y_offset))
 612
 613        return pixels_in_beam, pixels_on_bdry
 614    
 615    def __get_header_array(self, num: Literal[1, 2, 3]) -> DataArray1D:
 616        """
 617        Get coordinate values from the FITS header for RA, DEC, or FREQ axes.
 618
 619        Parameters
 620        ----------
 621        num : {1, 2, 3}
 622            Axis number: 1 for RA, 2 for DEC, 3 for FREQ.
 623
 624        Returns
 625        -------
 626        vals : DataArray1D
 627            1-D array of coordinate values for the specified axis, computed from the header.
 628        """
 629        vals_length: int = self._header["NAXIS" + str(num)]  # type: ignore
 630        vals: np.ndarray[tuple[int], np.dtype[np.float64]] = np.zeros(vals_length)
 631        x0: float = self._header["CRPIX" + str(num)]  # type: ignore
 632        y0: float = self._header["CRVAL" + str(num)]  # type: ignore
 633        delta: float = self._header["CDELT" + str(num)]  # type: ignore
 634
 635        # vals[i] should be delta*(i - x0) + y0, for 1-indexing. but we are 0-indexed
 636        for i in range(len(vals)):
 637            vals[i] = delta*(i + 1 - x0) + y0
 638
 639        return vals
 640
 641    def __mean_and_std(self) -> tuple[float, float]:
 642        """
 643        Calculate the mean and standard deviation of the background noise, by looking at points away from the center.
 644
 645        Returns
 646        -------
 647        mean : float
 648            Mean intensity of the background noise.
 649        std : float
 650            Standard deviation of the background noise.
 651        """
 652        # trim, as edges can be unreliable
 653        outer_trim = 20  # remove outer 1/20th
 654
 655        frames, y_obs, x_obs = self.data.shape
 656        trimmed_data = self.data[:, y_obs//outer_trim:y_obs - y_obs//outer_trim, x_obs//outer_trim:x_obs - x_obs//outer_trim]
 657        frames, y_obs, x_obs = trimmed_data.shape
 658        
 659
 660        # take edges of trimmed data
 661        inner_trim = 5  # take the outer fifth of the image in the first and last fifth of the channels
 662        left_close = trimmed_data[:frames//inner_trim, :y_obs//inner_trim, :].flatten()
 663        right_close = trimmed_data[:frames//inner_trim, y_obs - y_obs//inner_trim:, :].flatten()
 664        top_close = trimmed_data[:frames//inner_trim, y_obs//inner_trim:y_obs - y_obs//inner_trim, :x_obs//inner_trim].flatten()
 665        bottom_close = trimmed_data[:frames//inner_trim, y_obs//inner_trim:y_obs - y_obs//inner_trim, x_obs-x_obs//inner_trim:].flatten()
 666        
 667        left_far = trimmed_data[frames - frames//inner_trim:, :y_obs//inner_trim, :].flatten()
 668        right_far = trimmed_data[frames - frames//inner_trim:, y_obs - y_obs//inner_trim:, :].flatten()
 669        top_far = trimmed_data[frames - frames//inner_trim:, y_obs//inner_trim:y_obs - y_obs//inner_trim, :x_obs//inner_trim].flatten()
 670        bottom_far = trimmed_data[frames - frames//inner_trim:, y_obs//inner_trim:y_obs - y_obs//inner_trim, x_obs-x_obs//inner_trim:].flatten()
 671
 672
 673        ring = np.concatenate((left_close, right_close, top_close, bottom_close, left_far, right_far, top_far, bottom_far))
 674        ring = ring[~np.isnan(ring)]
 675
 676        if len(ring) < 20:
 677            warnings.warn(f"Only {len(ring)} data points selected for finding mean and standard deviation of noise.")
 678
 679        mean: float = float(np.mean(ring))
 680        std: float = float(np.std(ring))
 681
 682        return mean, std
 683
 684    def __multiply_beam(self, times: float | int) -> list[tuple[int, int]]:
 685        """
 686        Generate a list of pixel offsets that represent the beam, scaled by a factor.
 687
 688        Parameters
 689        ----------
 690        times : float or int
 691            Scaling factor for the beam size.
 692
 693        Returns
 694        -------
 695        insides : list of tuple of int
 696            List of (x, y) offsets inside the scaled beam.
 697        """
 698        if times <= 0.0001:
 699            return []
 700        insides = []
 701        beam_bound = max(max([x for x, y in self._offset_in_beam]), max([y for x, y in self._offset_in_beam]))
 702        
 703        # search a square around the origin
 704        for x in range(-int(beam_bound*times), int(beam_bound*times) + 1):
 705            for y in range(-int(beam_bound*times), int(beam_bound*times) + 1):
 706
 707                # check if point would shrink into beam
 708                pos = (int(x/times), int(y/times))
 709                if pos in self._offset_in_beam:
 710                    insides.append((x, y))
 711
 712        return insides
 713
 714    def __get_centre_intensities(self, beam_widths: float | int = 1) -> DataArray1D:
 715        """
 716        Calculate the mean intensity at the centre of each velocity channel, averaged over a beam-shaped region scaled by `beam_widths`.
 717
 718        Parameters
 719        ----------
 720        beam_widths : float or int, optional
 721            Scale factor of beam (default is 1).
 722
 723        Returns
 724        -------
 725        all_densities : DataArray1D
 726            Array of mean intensities for each velocity channel.
 727        """
 728        # centre index
 729        y_idx = np.argmin(self.Y**2)
 730        x_idx = np.argmin(self.X**2)
 731        beam_pixels = self.__multiply_beam(beam_widths)
 732        density_list = []
 733
 734        # compute average density
 735        for v in range(len(self.data)):
 736            total = 0
 737            for x_offset, y_offset in beam_pixels:
 738                if 0 <= y_idx + y_offset < self.data.shape[1] and 0 <= x_idx + x_offset < self.data.shape[2]:
 739                    intensity = self.data[v][y_idx + y_offset][x_idx + x_offset]
 740                    if intensity > 0 and not np.isnan(intensity):  # skip negative and missing values
 741                        total += intensity
 742            density_list.append(total/len(beam_pixels))
 743
 744        all_densities = np.array(density_list)
 745        return all_densities
 746
 747    def __get_star_and_exp_velocity(self, vel_range: DataArray1D, plot: bool = False, fit_parabola: bool = False, v_sys: float | None= None, v_exp: float | None = None) -> tuple[float, float]:
 748        """
 749        Compute the systemic and expansion velocities, in km/s, by iteratively matching the moments of the centre intensity profile to those of a parabolic line profile.
 750
 751        Parameters
 752        ----------
 753        vel_range : DataArray1D
 754            1-D array of channel velocities in km/s.
 755        plot : bool, optional
 756            If True, plot the iterative process (default is False).
 757        fit_parabola : bool, optional
 758            If True, plot a fitted parabola (default is False).
 759        v_sys : float or None, optional
 760            If provided, use this as the systemic velocity (default is None).
 761        v_exp : float or None, optional
 762            If provided, use this as the expansion velocity (default is None).
 763
 764        Returns
 765        -------
 766        v_sys : float
 767            Systemic velocity in km/s.
 768        v_exp : float
 769            Expansion velocity in km/s.
 770        """
 771        given_v_sys = v_sys
 772        given_v_exp = v_exp
 773        if given_v_exp is not None and given_v_sys is not None:
 774            return given_v_sys, given_v_exp
 775
 776        all_densities = self.__get_centre_intensities(2)
 777        
 778        
 779        v_exp_seen = np.array([])
 780        v_sys_seen = np.array([])
 781
 782        converged = False
 783        densities = all_densities.copy()
 784        i = 1
 785        while not converged:  # loop until computation converges
 786            densities /= np.sum(densities) # normalise
 787            if plot:
 788                plt.plot(vel_range, densities, label = f"iteration {i}")
 789            v_sys = np.dot(densities, vel_range) if given_v_sys is None else given_v_sys
 790
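                # For a parabolic line profile I(v) proportional to 1 - (v - v_sys)**2/v_exp**2 on |v - v_sys| <= v_exp,
                # the intensity-weighted variance is v_exp**2/5, hence v_exp = sqrt(5*variance).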
 791            v_exp = np.sqrt(5*np.dot(((vel_range - v_sys)**2), densities)) if given_v_exp is None else given_v_exp
 792            
 793            if any(np.isclose(v_exp_seen, v_exp)) and any(np.isclose(v_sys_seen, v_sys)):
 794                converged = True
 795            v_exp_seen = np.append(v_exp_seen, v_exp)
 796            v_sys_seen = np.append(v_sys_seen, v_sys)
 797
 798            densities = all_densities.copy()
 799            densities[vel_range < (v_sys - v_exp)] = 0
 800            densities[vel_range > (v_sys + v_exp)] = 0
 801            i += 1
 802
 803            if i >= 100:
 804                warnings.warn("Systemic and expansion velocity computation did not converge after 100 iterations.")
 805                break
 806
 807        if plot and fit_parabola:
 808            parabola = (1 -((vel_range - v_sys)**2/v_exp**2))
 809            parabola[parabola < 0] = 0
 810            parabola /= np.sum(parabola)
 811            plt.plot(vel_range, parabola, label = "parabola")
 812
 813        if plot:
 814            plt.title("Determining v_sys, v_exp")
 815            plt.xlabel("Relative velocity (km/s)")
 816            plt.ylabel(f"{self._header['BTYPE']} at centre point ({self._header['BUNIT']})")
 817            plt.legend()
 818            plt.show()
 819        
 820        return v_sys, v_exp
 821
 822    def __get_beta_law(self, plot_intensities = False, plot_velocities = False, plot_beta_law = False) -> tuple[float, float, float]:
 823        """
 824        Fit the beta velocity law to the data and return the dust formation radius, beta parameter, and the radius of maximum intensity change.
 825
 826        Parameters
 827        ----------
 828        plot_intensities : bool, optional
 829            If True, plot intensity vs. radius (default is False).
 830        plot_velocities : bool, optional
 831            If True, plot velocity vs. radius (default is False).
 832        plot_beta_law : bool, optional
 833            If True, plot the fitted beta law (default is False).
 834
 835        Returns
 836        -------
 837        r_dust : float
 838            Dust formation radius.
 839        beta : float
 840            Beta parameter of the velocity law.
 841        radius : float
 842            Radius of maximum intensity change.
 843        """
 844        intensities = self.__get_centre_intensities(0.5)
 845
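            # Invert the parabolic profile I/I_max = 1 - (v/v_exp)**2 to estimate a velocity
            # from an intensity: v = v_exp*sqrt(1 - I/I_max).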
 846        def v_from_i(i):
 847            return np.sqrt(1 - i/np.max(intensities))*self.v_exp
 848    
 849        centre_idx = np.argmin(self.V**2)
 850        frame = self.data[centre_idx]
 851        
 852        max_radius = np.minimum(np.max(self.X), np.max(self.Y))
 853        precision = min(len(self.X), len(self.Y))//2
 854        radii = np.linspace(0, max_radius, precision)
 855        
 856        X, Y = np.meshgrid(self.X, self.Y, indexing="xy")  # shape (dec, ra), matching frame
 857
 858        deltas = np.array([])
 859        I = np.array([])  # average intensity in each ring
 860        for i in range(len(radii) - 1):
 861            ring = frame[(radii[i]**2 <= X**2 + Y**2)  &  (X**2 + Y**2 <= radii[i+1]**2)]
 862            avg_intensity = np.mean(ring[np.isfinite(ring)]) 
 863            I = np.append(I, avg_intensity)
 864
 865            inner = frame[X**2+Y**2 <= radii[i+1]**2]
 866            outer = frame[~(X**2+Y**2 <= radii[i+1]**2)]
 867            deltas = np.append(deltas,len(inner[(inner >= self.mean+5*self.std)])+len(outer[outer < self.mean+5*self.std]))
 868
 869        V = v_from_i(I)
 870        V[~np.isfinite(V)] = 0
 871        R = radii[1:]
 872        radius_index = np.argmax(deltas)
 873        radius = R[radius_index]
 874
 875        if plot_intensities:
 876            plt.plot(R, I, label = "intensities")
 877            plt.axvline(x = radius, label = "radius", color = "gray", linestyle = "dashed")
 878            plt.legend()
 879            plt.xlabel("Radius (AU)")
 880            plt.ylabel(f"Average {self._header['BTYPE']} ({self._header['BUNIT']})")
 881            plt.title("Average intensity at each radius")
 882            plt.show()
 883            return self.r_dust, self.beta, self.radius
 884
 885        v_fit = V[(V > 0)]
 886        r_fit = R[(V > 0)]
 887
 888        if plot_velocities:
 889            r_dust, beta = self.r_dust, self.beta
 890        else:
 891            params = curve_fit(self.__general_beta_velocity_law, r_fit, v_fit)[0]
 892            r_dust, beta =  params[0], params[1]
 893
 894        if plot_velocities:
 895            plt.plot(R, V, label = "velocities")
 896            if plot_beta_law:
 897                LAW = self.__general_beta_velocity_law(R, r_dust, beta)
 898                plt.plot(R, LAW, label = "beta law")
 899            plt.axvline(x = radius, label = "radius", color = "gray", linestyle = "dashed")
 900            plt.legend()
 901            plt.xlabel("Radius (AU)")
 902            plt.ylabel(f"Velocity (km/s)")
 903            plt.title("Velocity at each radius")
 904            plt.show()
 905        
 906        return r_dust, beta, radius
 907
 908    def __general_beta_velocity_law(self, r: ndarray, r_dust: float, beta: float) -> ndarray:
 909        """
 910        General beta velocity law, v(r) = v0 + (v_exp - v0)*(1 - r_dust/r)**beta.
 911
 912        Parameters
 913        ----------
 914        r : array_like
 915            Radius values.
 916        r_dust : float
 917            Dust formation radius.
 918        beta : float
 919            Beta parameter.
 920
 921        Returns
 922        -------
 923        v : array_like
 924            Velocity at each radius.
 925        """
 926        return self.v0+(self.v_exp-self.v0)*((1-r_dust/r)**beta)
 927    
 928    # ---- FILTERING DATA ----
 929
 930    def get_filtered_data(self, stds: float | int = 5) -> DataArray3D:
 931        """
 932        Return a copy of the data, with values below `stds` times the noise standard deviation set to np.nan.
 933
 934        **Parameters**
 935
 936        - `stds` (`float` or `int`, optional): Number of standard deviations to filter by (default is 5).
 937
 938        **Returns**
 939
 940        - `filtered_data` (`DataArray3D`): Filtered data array.
 941        """
 942        filtered_data = self.data.copy()  # creates a deep copy
 943        filtered_data[filtered_data < stds*self.std] = np.nan
 944        return filtered_data
 945    
 946    def beam_filter(self, filtered_data: DataArray3D) -> DataArray3D:
 947        """
 948        Remove clumps of points that fit inside the beam, setting these values to np.nan.
 949
 950        **Parameters**
 951
 952        - `filtered_data` (`DataArray3D`): 3-D array with the same dimensions as the data array.
 953
 954        **Returns**
 955
 956        - `beam_filtered_data` (`DataArray3D`): 3-D array with small clumps of points removed.
 957        """
 958        beam_filtered_data = filtered_data.copy()
 959        for frame in range(len(filtered_data)):
 960            for y_idx in range(len(filtered_data[frame])):
 961                for x_idx in range(len(filtered_data[frame][y_idx])):
 962                    if np.isnan(filtered_data[frame][y_idx][x_idx]):  # ignore empty points
 963                        continue
 964                    
 965                    # filled point that we are searching around
 966                
 967                    erase = True
 968
 969                    for x_offset, y_offset in self._boundary_offset:
 970                        x_check = x_idx + x_offset
 971                        y_check = y_idx + y_offset
 972                        # explicit bounds check: negative indices would silently wrap around
 973                        if not (0 <= y_check < filtered_data.shape[1] and 0 <= x_check < filtered_data.shape[2]):
 974                            continue
 975                        if not np.isnan(filtered_data[frame][y_check][x_check]):
 976                            erase = False  # there is something present on the border - saved!
 977                            break
 978
 979                    if erase:  # consider ellipse to be an anomaly
 980                        # erase entire inside of the ellipse centred at this point
 981                        for x_offset, y_offset in self._offset_in_beam:
 982                            x_check = x_idx + x_offset
 983                            y_check = y_idx + y_offset
 984                            # explicit bounds check: negative indices would silently wrap around
 985                            if 0 <= y_check < beam_filtered_data.shape[1] and 0 <= x_check < beam_filtered_data.shape[2]:
 986                                beam_filtered_data[frame][y_check][x_check] = np.nan  # erase
 987
 988            
 989        return beam_filtered_data
 990
 991    # ---- HELPER METHODS FOR PLOTTING ----
 992
 993    def __crop_data(self, x: DataArray1D, y: DataArray1D, v: DataArray1D, data: DataArray3D, crop_leeway: int | float = 0, fill_data: DataArray3D | None = None, special_v = False) -> tuple[DataArray1D, DataArray1D, DataArray1D, DataArray3D]:
 994        """
 995        Crop the data arrays to the smallest region containing all valid (non-NaN) data.
 996
 997        Parameters
 998        ----------
 999        x, y, v : DataArray1D
1000            1-D coordinate arrays for each axis of `data`.
1001        data : DataArray3D
1002            Data array to use for crop.
1003        crop_leeway : int or float, optional
1004            Fractional leeway to expand the crop region (default is 0).
1005        fill_data : DataArray3D or None, optional
1006            Array to crop and return; if None (default), `self.data` is used.
1007
1008        Returns
1009        -------
1010        cropped_x, cropped_y, cropped_v : DataArray1D
1011            Cropped coordinate arrays.
1012        cropped_data : DataArray3D
1013            Cropped data array.
1014        """
1015        if fill_data is None:
1016            fill_data = self.data
1017        
1018        v_max, y_max, x_max = data.shape 
1019
1020        # gets indices
1021        v_indices = np.arange(v_max)
1022        y_indices = np.arange(y_max)
1023        x_indices = np.arange(x_max)
1024
1025        # turn into flat arrays
1026        V_IDX, Y_IDX, X_IDX = np.meshgrid(v_indices, y_indices, x_indices, indexing="ij")
1027
1028        # filter out nan data
1029        valid_V = V_IDX[~np.isnan(data)]
1030        valid_X = X_IDX[~np.isnan(data)]
1031        valid_Y = Y_IDX[~np.isnan(data)]
1032
1033        # indices to crop at
1034        v_mid = (np.min(valid_V) + np.max(valid_V))/2
1035        v_lo = max(int(v_mid - (1 + crop_leeway)*(v_mid - np.min(valid_V))), 0)
1036        v_hi = min(int(v_mid + (1 + crop_leeway)*(np.max(valid_V) - v_mid)), len(v) - 1) + 1
1037
1038        if special_v:
1039            assert v[v_lo] <= v[v_hi - 1], "v should be increasing"
1040            if v[v_lo] > -self.v_exp:
1041                # get last index where v < -self.v_exp
1042                offsets = v - (-self.v_exp)
1043                if np.any(offsets < 0):
1044                    offsets[offsets > 0] = -np.inf
1045                    v_lo = np.argmax(offsets) # want greatest neg value
1046                else:
1047                    v_lo = 0
1048            if v[v_hi - 1] < self.v_exp:
1049                offsets = v - (self.v_exp)
1050                if np.any(offsets > 0):
1051                    offsets[offsets < 0] = np.inf
1052                    v_hi = np.argmin(offsets) + 1  # want smallest pos value; v_hi is an exclusive bound
1053                else:
1054                    v_hi = len(v)
1055
1056
1057
1058        x_mid = (np.min(valid_X) + np.max(valid_X))/2
1059        x_lo = max(int(x_mid - (1 + crop_leeway)*(x_mid - np.min(valid_X))), 0)
1060        x_hi = min(int(x_mid + (1 + crop_leeway)*(np.max(valid_X) - x_mid)), len(x) - 1) + 1
1061
1062        y_mid = (np.min(valid_Y) + np.max(valid_Y))/2
1063        y_lo = max(int(y_mid - (1 + crop_leeway)*(y_mid - np.min(valid_Y))), 0)
1064        y_hi = min(int(y_mid + (1 + crop_leeway)*(np.max(valid_Y) - y_mid)), len(y) - 1) + 1
1065
1066        # crop x, y, v, data
1067        cropped_data = fill_data[v_lo:v_hi, y_lo:y_hi, x_lo:x_hi]
1068        cropped_v = v[v_lo: v_hi]
1069        cropped_y = y[y_lo: y_hi]
1070        cropped_x = x[x_lo: x_hi]
1071
1072        return cropped_x, cropped_y, cropped_v, cropped_data
1073
1074    def __filter_and_crop(
1075            self, 
1076            filter_stds: float | int | None, 
1077            filter_beam: bool, 
1078            crop_leeway: float,
1079            verbose: bool
1080        ) -> tuple[DataArray3D, DataArray3D, DataArray3D, DataArray3D]:
1081        """
1082        Filter and crop data as specified.
1083
1084        Parameters
1085        ----------
1086        filter_stds : float, int, or None
1087            Number of standard deviations to filter by, or None.
1088        filter_beam : bool
1089            If True, also apply beam filtering (only used when filter_stds is not None).
1090        crop_leeway : float
1091            Fractional leeway to expand the crop region.
1092        verbose : bool
1093            If True, print progress.
1094
1095        Returns
1096        -------
1097        X, Y, V : DataArray3D
1098            Meshgrids of velocity space coordinates.
1099        cropped_data : DataArray3D
1100            Cropped and filtered data array.
1101        """
1102        if filter_stds is not None:
1103            if verbose:
1104                print("Filtering data (to crop)...")
1105            data = self.get_filtered_data(filter_stds)
1106            if filter_beam:
1107                if verbose:
1108                    print("Applying beam filter (to crop)...")
1109                data = self.beam_filter(data)
1110        else:
1111            data = self.data
1112
1113        cropped_x, cropped_y, cropped_v, cropped_data = self.__crop_data(self.X,self.Y,self.V, data, crop_leeway=crop_leeway, special_v = True)
1114
1115        if verbose:  
1116            print(f"Data cropped to shape {cropped_data.shape}")
1117
1118
1119        V, Y, X = np.meshgrid(cropped_v, cropped_y, cropped_x, indexing = "ij")
1120        return X, Y, V, cropped_data
1121    
1122    def __fast_interpolation(self, points: tuple, x_bounds: tuple, y_bounds: tuple, z_bounds: tuple, data: DataArray3D, v_func: VelocityModel | None, num_points: int) -> tuple[DataArray1D, DataArray1D, DataArray1D, DataArray3D]:
1123        """
1124        Interpolate the data onto a regular grid in (X, Y, Z) space.
1125
1126        Parameters
1127        ----------
1128        points : tuple
1129            Tuple of (v, y, x) small coordinate arrays.
1130        x_bounds, y_bounds, z_bounds : tuple
1131            Bounds for the new grid in each dimension.
1132        data : DataArray3D
1133            Data array to interpolate, aligned with points.
1134        v_func : VelocityModel or None
1135            Velocity law function.
1136        num_points : int
1137            Number of points in each dimension for the new grid.
1138
1139        Returns
1140        -------
1141        X, Y, Z : DataArray1D
1142            Flattened meshgrids of new coordinates.
1143        interp_data : DataArray3D
1144            Interpolated data array.
1145        """
1146        # points = (v, y, x)
1147        # new_shape = (data.shape[0] + 2, data.shape[1], data.shape[2])
1148        # new_data = np.full(new_shape, np.nan)
1149        # new_data[1:-1, :, :] = data.copy()  # padded on both sides with nan
1150
1151        # # extend v array to have out-of-range values
1152        # v_array = points[0]
1153        # delta_v = v_array[1] - v_array[0]
1154        # new_v_array = np.zeros(len(v_array) + 2)
1155        # new_v_array[1:-1] = v_array
1156        # new_v_array[0] = new_v_array[1] - delta_v
1157        # new_v_array[-1] = new_v_array[-2] + delta_v
1158        # new_points = (new_v_array, points[1], points[2])
1159
1160        # get points to interpolate at
1161        x_even = np.linspace(x_bounds[0], x_bounds[1], num=num_points)
1162        y_even = np.linspace(y_bounds[0], y_bounds[1], num=num_points)
1163        z_even = np.linspace(z_bounds[0], z_bounds[1], num=num_points)
1164
1165        X, Y, Z = np.meshgrid(x_even, y_even, z_even, indexing="ij")
1166
1167        # build velocities
1168        V = self.__to_velocity_space(X, Y, Z, v_func)
1169
1170        shape = V.shape
1171        # flatten arrays
1172        V = V.ravel()
1173        Y = Y.ravel()
1174        X = X.ravel()
1175        Z = Z.ravel()
1176
1177        bad_idxs = (V < np.min(points[0])) | (V > np.max(points[0])) |  ~np.isfinite(V)
1178        V[bad_idxs] = np.min(points[0])  # in-range placeholder; reset to NaN after interpolation
1179
1180
1181        interp_points = np.column_stack((V, Y, X))
1182        interp_data = interpn(points, data, interp_points)
1183
1184        interp_data[bad_idxs] = np.nan
1185
1186        interp_data = interp_data.reshape(shape)
1187
1188        return X, Y, Z, interp_data
1189
1190    # ---- COORDINATE CHANGE ----
1191
1192    def __radius_from_vel_space_coords(self, x: ndarray, y: ndarray, v: ndarray, v_func: VelocityModel) -> ndarray:
1193        """
1194        Compute the physical radius corresponding to velocity-space coordinates (x, y, v)
1195        using the provided velocity law.
1196
1197        Parameters
1198        ----------
1199        x : ndarray
1200            Array of x-coordinates in AU.
1201        y : ndarray
1202            Array of y-coordinates in AU.
1203        v : ndarray
1204            Array of velocity coordinates in km/s.
1205        v_func : VelocityModel
1206            Callable velocity law, v_func(r), returning velocity in km/s for given radius in AU.
1207
1208        Returns
1209        -------
1210        r : ndarray
1211            Array of radii in AU corresponding to the input (x, y, v) coordinates.
1212
1213        Notes
1214        -----
1215        Solves for r in the equation:
1216            r**2 * (1 - v**2 / v_func(r)**2) = x**2 + y**2
1217        for each (x, y, v) triplet.
1218        """
1219        # all coordinate arrays must share one shape
1220        assert x.shape == y.shape, f"Arrays x and y should have the same shape, instead {x.shape = }, {y.shape = }"
1221        assert x.shape == v.shape, f"Arrays x and v should have the same shape, instead {x.shape = }, {v.shape = }"
1222
1223        shape = x.shape
1224
1225        x, y, v = x.ravel(), y.ravel(), v.ravel()
1226
1227        def radius_eqn(r: ndarray, x: ndarray, y: ndarray, v: ndarray):
1228            return r**2 * (1 - v**2/(v_func(r)**2)) - x**2 - y**2
1229        
1230        # initial bracket guess
1231        r_lower = np.sqrt(x**2 + y**2)  # r must be greater than sqrt(x^2 + y^2)
1232
1233        # find bracket
1234        bracket_result = bracket_root(radius_eqn, r_lower, xmin = r_lower, args = (x, y, v))
1235        r_lower, r_upper = bracket_result.bracket
1236        result = find_root(radius_eqn, (r_lower, r_upper), args=(x, y, v))
1237
1238        r: ndarray = result.x
1239        
1240        return r.reshape(shape)
1241
1242    def __to_velocity_space(self, X: ndarray, Y: ndarray, Z: ndarray, v_func: VelocityModel | None) -> ndarray:
1243        """
1244        Convert spatial coordinates to velocity space using the velocity law.
1245
1246        Parameters
1247        ----------
1248        X, Y, Z : array_like
1249            Spatial coordinates.
1250        v_func : VelocityModel or None
1251            Velocity law function.
1252
1253        Returns
1254        -------
1255        V : array_like
1256            Velocity at each spatial coordinate.
1257        """
1258        # build velocities
1259        if v_func is None:
1260            V = self.v_exp*Z/np.sqrt(X**2 + Y**2 + Z**2)
1261        else:
1262            V = v_func(np.sqrt(X**2 + Y**2 + Z**2))*Z/np.sqrt(X**2 + Y**2 + Z**2)
1263
1264        return V
1265
1266    # ---- EXPANSION OVER TIME ----
1267    def __get_new_radius(self, curr_radius: ndarray, years: float | int, v_func: VelocityModel | None) -> ndarray:
1268        """
1269        Compute the new radius after a given time, using the velocity law.
1270
1271        Parameters
1272        ----------
1273        curr_radius : ndarray
1274            Initial radii in AU.
1275        years : float or int
1276            Time interval in years.
1277        v_func : VelocityModel or None
1278            Velocity law function, or None (uses constant velocity).
1279
1280        Returns
1281        -------
1282        new_rad : ndarray
1283            New radii after the specified time interval.
1284
1285        Notes
1286        -----
1287        Setting v_func = None (constant expansion at v_exp) is much faster, and is a close approximation when v_func stays near v_exp over the radii involved.
1288        """
1289        t = u.yr.to(u.s, years)  # t is the time in seconds
1290        rad_km = u.au.to(u.km, curr_radius)
1291
1292        if v_func is None: 
1293            new_rad = u.km.to(u.au, rad_km + self.v_exp*t)
1294        else:
1295
1296            def dr_dt(t, r):
1297                # need to convert r back to AU
1298                vals = v_func(u.km.to(u.au, r))
1299                vals[~np.isfinite(vals)] = 0
1300
1301                return vals
1302            
1303            # flatten
1304            shape = rad_km.shape
1305            rad_km = rad_km.ravel()
1306
1307            # remove nans
1308            nan_idxs = ~(np.isfinite(rad_km) & np.isfinite(dr_dt(0, rad_km)) & (rad_km > 0))
1309            valid_rad = np.min(rad_km[~nan_idxs])
1310            rad_km[nan_idxs] = valid_rad
1311
1312            # solve
1313            solution = integrate.solve_ivp(dr_dt,(0,t), rad_km, vectorized=True)
1314            new_rad = u.km.to(u.au, solution.y[:,-1]) # r is evaluated at different time points
1315
1316            # include nans and unflatten
1317            new_rad[nan_idxs] = np.nan
1318            new_rad = new_rad.reshape(shape)
1319
1320        return new_rad
1321
1322    def __time_expansion_transform(self, x: DataArray1D, y: DataArray1D, v: DataArray1D, data: DataArray3D, years: float | int, v_func: VelocityModel | None, crop: bool = True) -> tuple[DataArray1D, DataArray1D, DataArray1D, DataArray3D, tuple]:
1323        """
1324        Transform coordinates and data to account for expansion over time.
1325
1326        Parameters
1327        ----------
1328        x, y, v : DataArray1D
1329            Small coordinate arrays.
1330        data : DataArray3D
1331            Data array.
1332        years : float or int
1333            Time interval in years.
1334        v_func : VelocityModel or None
1335            Velocity law function.
1336        crop : bool, optional
1337            If True, crop data to its non-NaN extent and keep only finite transformed points (default is True).
1338
1339        Returns
1340        -------
1341        new_X, new_Y, new_V : DataArray1D
1342            Transformed coordinate arrays (flattened meshgrid, restricted to finite points if crop is True).
1343        data : DataArray3D
1344            Data array (possibly cropped). Remains unchanged if crop is False.
1345        points : tuple
1346            Tuple of original (v, y, x) arrays, cropped if crop is True.
1347        """
1348        if crop:
1349            x, y, v, data = self.__crop_data(x, y, v, data, fill_data=data)
1350        
1351        V, Y, X = np.meshgrid(v, y, x, indexing="ij")
1352
1353
1354        if v_func is None:
1355            def v_func_mod(r):
1356                return self.v_exp
1357        else:
1358            v_func_mod = v_func
1359
1360        # get R0 and R1 arrays
1361        R0 = self.__radius_from_vel_space_coords(X, Y, V, v_func_mod)
1362        R1 = self.__get_new_radius(R0, years, v_func)
1363
1364        R1[R1 < 0] = np.inf  # deal with neg radii
1365
1366        new_X = (X*R1/R0).ravel()
1367        new_Y = (Y*R1/R0).ravel()
1368        if v_func is None:
1369            new_V = V.ravel()
1370        else:
1371            new_V = (V*v_func(R1)/v_func(R0)).ravel()
1372
1373
1374        if crop:
1375            finite_idxs = np.isfinite(new_X) & np.isfinite(new_Y) & np.isfinite(new_V) & np.isfinite(data).ravel()
1376            new_X = new_X[finite_idxs]
1377            new_Y = new_Y[finite_idxs]
1378            new_V = new_V[finite_idxs]
1379
1380        return new_X, new_Y, new_V, data, (v, y, x)
1381
1382    def get_expansion(self, years: float | int, v_func: VelocityModel | None = None, remove_centre: float | int | None = 2, new_shape: tuple = (50, 250, 250), verbose: bool = False) -> CondensedData:
1383        """
1384        Compute the expanded data cube after a given time interval.
1385
1386        **Parameters**
1387
1388        - `years` (`float` or `int`): Time interval in years.
1389        - `v_func` (`VelocityModel` or `None`): Velocity law function or None (default is None, which uses constant expansion velocity).
1390        - `remove_centre` (`float` or `int` or `None`, optional): If not None, remove all points within this many beam widths of the centre (default is 2).
1391        - `new_shape` (`tuple`, optional): Shape of the output grid as (v, y, x) (default is (50, 250, 250)).
1392        - `verbose` (`bool`, optional): If True, print progress (default is False).
1393
1394        **Returns**
1395
1396        - `info` (`CondensedData`): CondensedData object containing the expanded data and metadata.
1397        """     
1398        
1399        use_data = self.data.copy()
1400        if remove_centre is not None:
1401            if verbose:
1402                print("Removing centre...")
1403            # get centre coords
1404            y_idx = np.argmin(self.Y**2)
1405            x_idx = np.argmin(self.X**2)
1406            
1407            # proportion of the radius that the beam takes up
1408            beam_rad_au = (((self.beam_maj + self.beam_min)/2)*np.pi/180)*self.distance_to_star
1409            beam_prop = beam_rad_au/self.radius
1410            v_axis_removal = remove_centre*beam_prop*self.v_exp
1411            v_idxs = np.arange(len(self.V))[np.abs(self.V) < v_axis_removal]
1412            relevant_vs = self.V[np.abs(self.V) < v_axis_removal]
1413            proportions = remove_centre * np.sqrt(1 - (relevant_vs/v_axis_removal)**2)
1414            
1415            # removing centre
1416            for i in range(len(v_idxs)):
1417                v_idx = v_idxs[i]
1418                prop = proportions[i]
1419                beam = self.__multiply_beam(prop)
1420                for x_offset, y_offset in beam:
1421                    use_data[v_idx, y_idx + y_offset, x_idx + x_offset] = np.nan
1422        
1423        if verbose:
1424            print("Transforming coordinates...")
1425        X, Y, V, data, points = self.__time_expansion_transform(self.X, self.Y, self.V, use_data, years, v_func)
1426        v_num, y_num, x_num = new_shape
1427
1428        if verbose:
1429            print("Generating grid for new object...")
1430        gridv, gridy, gridx = np.mgrid[
1431            np.min(V):np.max(V):v_num*1j, 
1432            np.min(Y):np.max(Y):y_num*1j,
1433            np.min(X):np.max(X):x_num*1j
1434        ] 
1435
1436
1437        # get preimage of grid
1438        small_gridx = gridx[0, 0, :]
1439        small_gridy = gridy[0, :, 0]
1440        small_gridv = gridv[:, 0, 0]
1441
1442        # go backwards with negative years
1443        if verbose:
1444            print("Shrinking grid to original data bounds...")
1445        prev_X, prev_Y, prev_V, _, _ = self.__time_expansion_transform(small_gridx, small_gridy, small_gridv, use_data,  -years, v_func, crop = False)
1446        
1447        
1448        bad_idxs = (prev_V < np.min(points[0])) | (prev_V > np.max(points[0])) | \
1449                (prev_Y < np.min(points[1])) | (prev_Y > np.max(points[1])) | \
1450                (prev_X < np.min(points[2])) | (prev_X > np.max(points[2])) | \
1451                np.isnan(prev_V) | np.isnan(prev_X) | np.isnan(prev_Y)
1452        
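1453        prev_V[bad_idxs] = np.min(points[0])  # in-range placeholders; reset to NaN after interpolation
1454        prev_X[bad_idxs] = np.min(points[2])
1455        prev_Y[bad_idxs] = np.min(points[1])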
1456
1457        # interpolate regular data at these points
1458        if verbose:
1459            print("Interpolating...")
1460        interp_points = np.column_stack((prev_V, prev_Y, prev_X))
1461        interp_data = interpn(points, data, interp_points)
1462        interp_data[bad_idxs] = np.nan
1463        new_data = interp_data.reshape(gridx.shape)
1464
1465        non_nans = np.count_nonzero(np.isfinite(new_data))
1466        if verbose:
1467            print(f"{non_nans} non-nan values remaining out of {np.size(new_data)}")
1468
1469        info = CondensedData(
1470            small_gridx, 
1471            small_gridy, 
1472            small_gridv, 
1473            new_data, 
1474            self.star_name, 
1475            self.distance_to_star, 
1476            self.v_exp, 
1477            self.v_sys,
1478            self.beta,
1479            self.r_dust,
1480            self.beam_maj, 
1481            self.beam_min, 
1482            self.beam_angle,
1483            self._header,
1484            self.mean, 
1485            self.std
1486        )
1487
1488        if verbose:
1489            print("Time expansion process complete!")
1490        return info
1491
1492
1493    # ---- PLOTTING ----
1494
1495    def plot_velocity_vs_intensity(self, fit_parabola: bool = True) -> None:
1496        """
1497        Plot velocity vs. intensity at the center of the xy plane.
1498
1499        **Parameters**
1500
1501        - `fit_parabola` (`bool`, optional): If True, fit and plot a parabola (default is True).
1502
1503        **Notes**
1504
1505        How well the parabola fits the data can help you visually judge
1506        the accuracy of the calculated systemic and expansion velocities.
1507        """
1508        self.__get_star_and_exp_velocity(self.V, plot=True, fit_parabola=fit_parabola)
1509
1510    def plot_radius_vs_intensity(self) -> None:
1511        """
1512        Plot radius vs. intensity at the center of the xy plane.
1513        """
1514        self.__get_beta_law(plot_intensities=True)
1515
1516    def plot_radius_vs_velocity(self, fit_beta_law: bool = True) -> None:
1517        """
1518        Plot radius vs. velocity at the center of the xy plane.
1519
1520        **Parameters**
1521
1522        - `fit_beta_law` (`bool`, optional): If True, fit and plot the beta law (default is True).
1523
1524        **Notes**
1525
1526        How well the beta law curve fits the data can help you visually judge
1527        the accuracy of the calculated beta law parameters.
1528        """
1529        self.__get_beta_law(plot_velocities=True, plot_beta_law=fit_beta_law)
1530
1531    def plot_channel_maps(self, 
1532        filter_stds: float | int | None = None, 
1533        filter_beam: bool = False, 
1534        dimensions: None | tuple[int,int] = None, 
1535        start: int = 0, 
1536        end: None | int = None, 
1537        include_beam: bool = True, 
1538        text_pos: None | tuple[float,float] = None, 
1539        beam_pos: None | tuple[float,float] = None, 
1540        title: str | None = None, 
1541        cmap: str = "viridis"
1542    ) -> None:
1543        """
1544        Plot the data cube as a set of 2D channel maps.
1545
1546        **Parameters**
1547
1548        - `filter_stds` (`float` or `int` or `None`, optional): Number of standard deviations to filter by (default is None).
1549        - `filter_beam` (`bool`, optional): If True, apply beam filtering; only used when `filter_stds` is not None (default is False).
1550        - `dimensions` (`tuple` of `int` or `None`, optional): Grid dimensions (nrows, ncols) for subplots (default is None, which picks a near-square grid).
1551        - `start` (`int`, optional): Starting velocity channel index (default is 0).
1552        - `end` (`int` or `None`, optional): Ending velocity channel index, inclusive (default is None, which uses the last channel).
1553        - `include_beam` (`bool`, optional): If True, plot the beam ellipse (default is True).
1554        - `text_pos` (`tuple` of `float` or `None`, optional): Position for velocity text annotation (default is None).
1555        - `beam_pos` (`tuple` of `float` or `None`, optional): Position for beam ellipse (default is None).
1556        - `title` (`str` or `None`, optional): Plot title (default is None, which uses the name of the star).
1557        - `cmap` (`str`, optional): Colormap for the plot (default is "viridis"). See https://matplotlib.org/stable/users/explain/colors/colormaps.html
1558        """
1559
1560        if end is None:
1561            end = self.data.shape[0]-1
1562        else:
1563            if end < start: #invalid so set to the end of the data
1564                end = self.data.shape[0]-1 
1565            if end >= self.data.shape[0]: #invalid so set to the end of the data
1566                end = self.data.shape[0]-1 
1567
1568        if filter_stds is not None:
1569            data = self.get_filtered_data(filter_stds)
1570            if filter_beam:
1571                data = self.beam_filter(data)
1572        else:
1573            data = self.data
1574        
1575        if start >= data.shape[0]: #invalid so set to the start of the data
1576            start = 0
1577
1578        if title is None:
1579            title = self.star_name + " channel maps"
1580
1581        if dimensions is not None:
1582            nrows, ncols = dimensions
1583        else:
1584            nrows = int(np.ceil(np.sqrt(end - start + 1)))
1585            ncols = int(np.ceil((end - start + 1)/nrows))
1586
1587        fig, axes = plt.subplots(nrows, ncols, squeeze=False)  # always a 2D array of Axes
1588        colormap = colormaps[cmap]
1589        fig.suptitle(title)
1590        fig.supxlabel("Right Ascension (Arcseconds)")
1591        fig.supylabel("Declination (Arcseconds)")
1592
1593        i = start
1594        extents = u.radian.to(u.arcsec,np.array([np.min(self.X),np.max(self.X),np.min(self.Y),np.max(self.Y)])/self.distance_to_star)
1595
1596        done = False
1597        for ax in axes.flat:
1598            if not done:
1599                im = ax.imshow(data[i],vmin = self.mean, vmax = np.max(data[~np.isnan(data)]), extent=extents, cmap = cmap)
1600
1601                ax.set_facecolor(colormap(0))
1602                ax.set_aspect("equal")
1603
1604                if text_pos is None:
1605                    ax.text(extents[0]*5/6,extents[3]*1/2,f"{self.V[i]:.1f} km/s",size="x-small",c="white")
1606                else:
1607                    ax.text(text_pos[0],text_pos[1],f"{self.V[i]:.1f} km/s",size="x-small",c="white")
1608
1609                if include_beam:
1610                    bmaj = u.deg.to(u.arcsec,self.beam_maj)
1611                    bmin = u.deg.to(u.arcsec,self.beam_min)
1612                    bpa = self.beam_angle
1613
1614                    if beam_pos is None:
1615                        ellipse_artist = Ellipse(xy=(extents[0]*1/2,extents[2]*1/2),width=bmaj,height=bmin,angle=bpa,color = "white")
1616                    else:
1617                        ellipse_artist = Ellipse(xy=(beam_pos[0],beam_pos[1]),width=bmaj,height=bmin,angle=bpa, color = "white")
1618                    ax.add_artist(ellipse_artist)
1619                i += 1
1620                if (i >= data.shape[0]) or (i > end): done = True
1621            else:
1622                ax.axis("off")
1623        cbar = fig.colorbar(im, ax=axes.ravel().tolist())
1624        cbar.set_label("Flux Density (Jy/Beam)")
1625        plt.show()
1626
1627    def create_mask(self, filter_stds: float | int | None = None, savefile: str | None = None, initial_crop: tuple | None = None):
1628        """
1629        Launch an interactive mask creator for the data cube.
1630
1631        **Parameters**
1632
1633        - `filter_stds` (`float` or `int` or `None`, optional): Number of standard deviations to initially filter by (default is None).
1634        - `savefile` (`str` or `None`, optional): Filename to save the selected mask (default is None). Should end in .npy.
1635        - `initial_crop` (`tuple` or `None`, optional): Initial crop region as (v_lo, v_hi, y_lo, y_hi, x_lo, x_hi), using channel indices (default is None).
1636        """
1637        if filter_stds is not None:
1638            data = self.get_filtered_data(filter_stds)
1639        else:
1640            data = self.data
1641
1642        if initial_crop is not None:
1643            v_lo, v_hi, y_lo, y_hi, x_lo, x_hi = initial_crop
1644            new_data = data[v_lo:v_hi, y_lo:y_hi, x_lo:x_hi]
1645            selector = _PointsSelector(new_data)
1646            plt.show()
1647            mask = np.full(data.shape, np.nan)
1648            mask[v_lo:v_hi, y_lo:y_hi, x_lo:x_hi] = selector.mask
1649
1650        else:
1651            selector = _PointsSelector(data)
1652            plt.show()
1653        
1654            # mask is complete now
1655            mask = selector.mask
1656        
1657        # save mask
1658        if savefile is not None:
1659            np.save(savefile, mask)
1660
1661    def plot_3D(
1662        self, 
1663        filter_stds: float | int | None = None, 
1664        filter_beam: bool = False, 
1665        z_cutoff: float = 1,
1666        crop_leeway: float = 0,
1667        num_points: int = 50,
1668        num_surfaces: int = 50,
1669        opacity: float | int = 0.5,
1670        opacityscale: list[list[float]] = [[0, 0], [1, 1]],
1671        colorscale: str = "Reds",
1672        v_func: VelocityModel | None = None, 
1673        verbose: bool = False,
1674        title: str | None = None,
1675        folder: str | None = None,
1676        num_angles: int = 24,
1677        camera_dist: float | int = 2
1678    ) -> None:
1679        """
1680        Plot a 3D volume rendering of the data cube using Plotly.
1681
1682        **Parameters**
1683
1684        - `filter_stds` (`float` or `int` or `None`, optional): Number of standard deviations to filter by (default is None).
1685        - `filter_beam` (`bool`, optional): If True, apply beam filtering; only used when `filter_stds` is not None (default is False).
1686        - `z_cutoff` (`float`, optional): Half-extent of the z axis, as a proportion of the largest absolute x or y offset (default is 1).
1687        - `crop_leeway` (`int` or `float`, optional): Fractional leeway to expand the crop region (default is 0).
1688        - `num_points` (`int`, optional): Number of points in each dimension for the grid (default is 50).
1689        - `num_surfaces` (`int`, optional): Number of surfaces to draw when rendering the plot (default is 50).
1690        - `opacity` (`float` or `int`, optional): Opacity of the volume rendering (default is 0.5).
1691        - `opacityscale` (`list` of `list` of `float`, optional): Opacity scale, see https://plotly.com/python-api-reference/generated/plotly.graph_objects.Volume.html (default is [[0, 0], [1, 1]]).
1692        - `colorscale` (`str`, optional): Colormap for the plot, see https://plotly.com/python/builtin-colorscales/ (default is "Reds").
1693        - `v_func` (`VelocityModel` or `None`, optional): Velocity law function (default is None, which uses constant expansion velocity).
1694        - `verbose` (`bool`, optional): If True, print progress (default is False).
1695        - `title` (`str` or `None`, optional): Plot title (default is None, which uses the star name).
1696        - `folder` (`str` or `None`, optional): If provided, render frames from successive camera angles and save them as PNG files in this folder instead of showing the figure (default is None).
1697        - `num_angles` (`int`, optional): Number of camera angles (frames) around a full rotation when saving images (default is 24). Greater values give a smoother animation.
1698        - `camera_dist` (`float` or `int`, optional): Camera radius to use if generating frames (default is 2).
1699        """
1700        if verbose:
1701            print("Initial filter and crop...")
1702        X, Y, V, data = self.__filter_and_crop(filter_stds, filter_beam, crop_leeway, verbose)
1703        
1704        if title is None:
1705            title = self.star_name
1706
1707        # get small 1d arrays
1708        x_small = X[0, 0, :]
1709        y_small = Y[0, :, 0]
1710        v_small = V[:, 0, 0]
1711
1712        # interpolate to grid
1713        if verbose:
1714            print("Interpolating to grid...")
1715
1716        z_bound = z_cutoff * np.maximum(np.max(np.abs(x_small)), np.max(np.abs(y_small)))
1717        
1718        x_lo, x_hi, y_lo, y_hi, z_lo, z_hi = np.min(x_small), np.max(x_small), np.min(y_small), np.max(y_small), -z_bound, z_bound
1719        gridx, gridy, gridz, out = self.__fast_interpolation((v_small, y_small, x_small), (x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi), data, v_func, num_points)
1720
1721
1722        # filter by standard deviation
1723        if filter_stds is not None:
1724            out[out < filter_stds*self.std] = np.nan
1725
1726        out[np.isnan(out)] = 0  # Plotly can't handle NaNs
1727
1728        if verbose:
1729            print(f"Found {np.count_nonzero(out > 0)} points with positive intensity.")
1730
1731        out = out.ravel()
1732        min_value = np.min(out[np.isfinite(out) & (out > 0)])
1733
1734        if verbose:
1735            print("Plotting figure...")
1736
1737        
1738        fig = go.Figure(
1739            data=go.Volume(
1740                x=gridx,
1741                y=gridy,
1742                z=gridz,
1743                value=out,
1744                isomin = min_value,
1745                colorscale= colorscale,
1746                colorbar = dict(title = "Flux Density (Jy/beam)"),
1747                opacityscale=opacityscale,
1748                opacity=opacity, # needs to be small to see through all surfaces
1749                surface_count=num_surfaces, # needs to be a large number for good volume rendering
1751            ),
1752            layout=go.Layout(
1753                title = {
1754                    "text":title,
1755                    "x":0.5,
1756                    "y":0.95,
1757                    "xanchor":"center",
1758                    "font":{"size":24}
1759                },
1760                scene = dict(
1761                      xaxis=dict(
1762                          title=dict(
1763                              text='X (AU)'
1764                          )
1765                      ),
1766                      yaxis=dict(
1767                          title=dict(
1768                              text='Y (AU)'
1769                          )
1770                      ),
1771                      zaxis=dict(
1772                          title=dict(
1773                              text='Z (AU)'
1774                          )
1775                      ),
1776                      aspectmode = "cube"
1777                    ),
1778            )
1779        )
1780
1781
1782
1783        if folder is not None:
1784            if verbose:
1785                print("Generating frames...")
1786            angles = np.linspace(0, 360, num_angles, endpoint=False)  # 360 would duplicate the 0 frame
1787            for k, a in enumerate(angles):
1788                b = a*np.pi/180
1789                eye = dict(x=camera_dist*np.cos(b),y=camera_dist*np.sin(b),z=1.25)
1790                fig.update_layout(scene_camera_eye = eye)
1791                if folder:
1792                    fig.write_image(f"{folder}/angle{int(a)}.png")
1793                else:
1794                    fig.write_image(f"{int(a)}.png")
1795                if verbose:
1796                    print(f"Generating frames: {100*(k + 1)/num_angles:.0f}% complete.")
1797        else:
1798            fig.show()

Class for manipulating and analyzing astronomical data cubes of radially expanding circumstellar envelopes.

The StarData class provides a comprehensive interface for loading, processing, analyzing, and visualizing 3D data cubes (typically from FITS files) representing the emission from expanding circumstellar shells. It supports both direct loading from FITS files and from preprocessed CondensedData objects, and manages all relevant metadata and derived quantities.

Key Features

  • Data Loading: Supports initialization from FITS files or CondensedData objects.
  • Metadata Management: Stores and exposes all relevant observational and physical parameters, including beam properties, systemic and expansion velocities, beta velocity law parameters, and FITS header information.
  • Noise Estimation: Automatically computes mean and standard deviation of background noise for filtering.
  • Filtering: Provides methods to filter data by significance (standard deviations) and to remove small clumps of points that fit within the beam (beam filtering).
  • Coordinate Transformations: Handles conversion between velocity space and spatial (cartesian) coordinates, supporting both constant velocity models and general velocity laws.
  • Time Evolution: Can compute the expansion of the envelope over time, transforming the data cube accordingly.
  • Visualization: Includes a variety of plotting methods:
    • Channel maps (2D slices through velocity channels)
    • 3D volume rendering (with Plotly)
    • Diagnostic plots for velocity/intensity and radius/velocity relationships
  • Interactive Masking: Supports interactive creation of masks for manual data cleaning.
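The noise estimation step can be pictured with a short sketch. The helper name `background_mean_std` and its `exclude_radius` argument are hypothetical: the private method StarData actually uses is not shown in this section, and the sketch simply assumes that "away from the centre" means outside a fixed projected radius.

```python
import numpy as np

def background_mean_std(data, x, y, exclude_radius):
    """Hypothetical helper: mean and standard deviation of a (v, y, x) cube,
    using only pixels whose projected distance from the centre exceeds exclude_radius."""
    # broadcast the 1D offset axes into a 2D (y, x) map of projected radii
    yy, xx = np.meshgrid(y, x, indexing="ij")
    outside = np.hypot(xx, yy) > exclude_radius  # boolean (y, x) mask
    # apply the same spatial mask to every velocity channel -> shape (v, n_outside)
    background = data[:, outside]
    return float(np.nanmean(background)), float(np.nanstd(background))
```

With the mean and standard deviation in hand, filtering by significance amounts to comparing each value with a multiple of the standard deviation.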

Attributes

  • data (DataArray3D): The main data cube (v, y, x) containing intensity values.
  • X (DataArray1D): 1D array of x-coordinates (offsets) in AU.
  • Y (DataArray1D): 1D array of y-coordinates (offsets) in AU.
  • V (DataArray1D): 1D array of velocity offsets in km/s.
  • distance_to_star (float): Distance to the star in AU.
  • beam_maj (float): Major axis of the beam in degrees.
  • beam_min (float): Minor axis of the beam in degrees.
  • beam_angle (float): Beam position angle in degrees.
  • mean (float): Mean intensity of the background noise.
  • std (float): Standard deviation of the background noise.
  • v_sys (float): Systemic velocity in km/s.
  • v_exp (float): Expansion velocity in km/s.
  • beta (float): Beta parameter of the velocity law.
  • r_dust (float): Dust formation radius in AU.
  • radius (float): Characteristic radius (e.g., maximum intensity change) in AU.
  • beta_velocity_law (VelocityModel): Callable implementing the beta velocity law with the current object's parameters.
  • star_name (str): Name of the star or object.

Methods

  • export() -> CondensedData: Export all defining attributes to a CondensedData object.
  • get_filtered_data(stds=5): Return a copy of the data, with values below stds times the background noise standard deviation set to np.nan.
  • beam_filter(filtered_data): Remove clumps of points that fit inside the beam, setting these values to np.nan.
  • get_expansion(years, v_func, ...): Compute the expanded data cube after a given time interval.
  • plot_channel_maps(...): Plot the data cube as a set of 2D channel maps.
  • plot_3D(...): Plot a 3D volume rendering of the data cube using Plotly.
  • plot_velocity_vs_intensity(...): Plot velocity vs. intensity at the center of the xy plane.
  • plot_radius_vs_intensity(): Plot radius vs. intensity at the center of the xy plane.
  • plot_radius_vs_velocity(...): Plot radius vs. velocity at the center of the xy plane.
  • create_mask(...): Launch an interactive mask creator for the data cube.
StarData( info_source: str | AGB-Star-Deprojection.CondensedData, distance_to_star: float | None = None, rest_frequency: float | None = None, maskfile: str | None = None, beta_law_params: tuple[float, float] | None = None, v_exp: float | None = None, v_sys: float | None = None, absolute_star_pos: tuple[float, float] | None = None)
    def __init__(
        self,
        info_source: str | CondensedData,
        distance_to_star: float | None = None,
        rest_frequency: float | None = None,
        maskfile: str | None = None,
        beta_law_params: tuple[float, float] | None = None,
        v_exp: float | None = None,
        v_sys: float | None = None,
        absolute_star_pos: tuple[float, float] | None = None
    ) -> None:
        """
        Initialize a StarData object by reading data from a FITS file or a CondensedData object.

        **Parameters**

        - `info_source` (`str` or `CondensedData`): Path to a FITS file or a CondensedData object containing preprocessed data.
        - `distance_to_star` (`float` or `None`, optional): Distance to the star in AU (required if info_source is a FITS file).
        - `rest_frequency` (`float` or `None`, optional): Rest frequency in Hz (required if info_source is a FITS file).
        - `maskfile` (`str` or `None`, optional): Path to a .npy file containing a mask that is multiplied element-wise into the data.
        - `beta_law_params` (`tuple` of `float` or `None`, optional): (r_dust (AU), beta) parameters for the beta velocity law. If None, fitted from the data (FITS input) or taken from the CondensedData object.
        - `v_exp` (`float` or `None`, optional): Expansion velocity in km/s. If None, fitted from the data (FITS input) or taken from the CondensedData object.
        - `v_sys` (`float` or `None`, optional): Systemic velocity in km/s. If None, fitted from the data (FITS input) or taken from the CondensedData object.
        - `absolute_star_pos` (`tuple` of `float` or `None`, optional): Absolute (RA, Dec) position of the star in degrees. If None, taken to be the centre of the image.

        **Raises**

        - `ValueError`: If required parameters are missing when reading from a FITS file.
        - `FITSHeaderError`: If any attribute in the FITS file header is an incorrect type.
        """
        if isinstance(info_source, str):
            if distance_to_star is None or rest_frequency is None:
                raise ValueError("Distance to star and rest frequency required when reading from FITS file.")
            self.__load_from_fits_file(info_source, distance_to_star, rest_frequency, absolute_star_pos, v_sys = v_sys, v_exp = v_exp)
            if beta_law_params is None:
                self._r_dust, self._beta, self._radius = self.__get_beta_law()
            else:
                self._r_dust, self._beta = beta_law_params

        else:
            # load from CondensedData
            self._X = info_source.x_offsets
            self._Y = info_source.y_offsets
            self._V = info_source.v_offsets
            self._data = info_source.data
            self.star_name = info_source.star_name
            self._distance_to_star = info_source.distance_to_star
            self._v_exp = info_source.v_exp if v_exp is None else v_exp
            self._v_sys = info_source.v_sys if v_sys is None else v_sys
            self._r_dust = info_source.r_dust if beta_law_params is None else beta_law_params[0]
            self._beta = info_source.beta if beta_law_params is None else beta_law_params[1]
            self._beam_maj = info_source.beam_maj
            self._beam_min = info_source.beam_min
            self._beam_angle = info_source.beam_angle
            self._header = info_source.header
            self.__process_beam()

            # compute mean and standard deviation
            if info_source.mean is None or info_source.std is None:
                self._mean, self._std = self.__mean_and_std()
            else:
                self._mean = info_source.mean
                self._std = info_source.std

        if maskfile is not None:
            # mask data (permanent)
            mask = np.load(maskfile)
            self._data = self._data * mask

Initialize a StarData object by reading data from a FITS file or a CondensedData object.

Parameters

  • info_source (str or CondensedData): Path to a FITS file or a CondensedData object containing preprocessed data.
  • distance_to_star (float or None, optional): Distance to the star in AU (required if info_source is a FITS file).
  • rest_frequency (float or None, optional): Rest frequency in Hz (required if info_source is a FITS file).
  • maskfile (str or None, optional): Path to a .npy file containing a mask that is multiplied element-wise into the data.
  • beta_law_params (tuple of float or None, optional): (r_dust (AU), beta) parameters for the beta velocity law. If None, fitted from the data (FITS input) or taken from the CondensedData object.
  • v_exp (float or None, optional): Expansion velocity in km/s. If None, fitted from the data (FITS input) or taken from the CondensedData object.
  • v_sys (float or None, optional): Systemic velocity in km/s. If None, fitted from the data (FITS input) or taken from the CondensedData object.
  • absolute_star_pos (tuple of float or None, optional): Absolute (RA, Dec) position of the star in degrees. If None, taken to be the centre of the image.

Raises

  • ValueError: If required parameters are missing when reading from a FITS file.
  • FITSHeaderError: If any attribute in the FITS file header is an incorrect type.
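The constructor applies maskfile by loading it with np.load and multiplying it element-wise into the data, so the saved array must broadcast against the (v, y, x) cube. The sketch below builds a toy mask in a temporary file; the file name and values are illustrative only. Marking removed points with np.nan rather than 0 matches the way get_filtered_data and beam_filter represent empty points, whereas a 0 would stay in data as a literal zero value.

```python
import os
import tempfile
import numpy as np

# a toy (v, y, x) cube standing in for StarData.data
data = np.arange(8, dtype=float).reshape(2, 2, 2)

# illustrative mask: 1 keeps a value, np.nan blanks it in every velocity channel
mask = np.ones_like(data)
mask[:, 0, 0] = np.nan

path = os.path.join(tempfile.mkdtemp(), "mask.npy")
np.save(path, mask)

# same operation as the constructor: load the .npy file and multiply element-wise
masked = data * np.load(path)
```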
v0 = 3
data: numpy.ndarray[tuple[int, int, int], numpy.dtype[numpy.float64]]
    @property
    def data(self) -> DataArray3D:
        """
        DataArray3D

        Stores the intensity of light at each data point.

        Dimensions: k x m x n, where k is the number of frequency channels,
        m is the number of declination pixels, and n is the number of right ascension pixels.
        """
        return self._data

DataArray3D

Stores the intensity of light at each data point.

Dimensions: k x m x n, where k is the number of frequency channels, m is the number of declination pixels, and n is the number of right ascension pixels.

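Because the cube is stored in (v, y, x) order, a spectrum is a slice along the first axis and a channel map is a single index into it. The toy arrays below stand in for V, Y, X and data; the central pixel is found with np.argmin of the squared offsets, the same way get_expansion locates the centre.

```python
import numpy as np

# toy axes: 5 velocity channels, 3 declination rows, 4 right-ascension columns
V = np.linspace(-10.0, 10.0, 5)
Y = np.array([-1.0, 0.0, 1.0])
X = np.array([-1.5, -0.5, 0.5, 1.5])
data = np.random.default_rng(0).random((V.size, Y.size, X.size))

# pixel closest to the centre (the first one if there is a tie)
y_idx = np.argmin(Y**2)
x_idx = np.argmin(X**2)

spectrum = data[:, y_idx, x_idx]   # intensity vs velocity at the centre
channel_map = data[2]              # one (y, x) image: the V[2] = 0 km/s channel
```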
X: numpy.ndarray[tuple[int], numpy.dtype[numpy.float64]]
    @property
    def X(self) -> DataArray1D:
        """
        DataArray1D

        Stores the x-coordinates relative to the centre in AU.
        Obtained from right ascension coordinates.
        """
        return self._X

DataArray1D

Stores the x-coordinates relative to the centre in AU. Obtained from right ascension coordinates.

Y: numpy.ndarray[tuple[int], numpy.dtype[numpy.float64]]
    @property
    def Y(self) -> DataArray1D:
        """
        DataArray1D

        Stores the y-coordinates relative to the centre in AU.
        Obtained from declination coordinates.
        """
        return self._Y

DataArray1D

Stores the y-coordinates relative to the centre in AU. Obtained from declination coordinates.

V: numpy.ndarray[tuple[int], numpy.dtype[numpy.float64]]
    @property
    def V(self) -> DataArray1D:
        """
        DataArray1D

        Stores the velocity offsets relative to the star's systemic velocity in km/s.
        Obtained from frequency channels.
        """
        return self._V

DataArray1D

Stores the velocity offsets relative to the star's systemic velocity in km/s. Obtained from frequency channels.

distance_to_star: float
    @property
    def distance_to_star(self) -> float:
        """
        Distance to star in AU.
        """
        return self._distance_to_star

Distance to star in AU.

B: numpy.ndarray[tuple[typing.Literal[2], typing.Literal[2]], numpy.dtype[numpy.float64]]
    @property
    def B(self) -> Matrix2x2:
        """
        Matrix2x2

        Ellipse matrix of the beam. For 2-element column vectors v, w with coordinates (ra, dec) in degrees,
        if (v-w)^T B (v-w) < 1, then v is within the beam centred at w.
        """
        return self._B

Matrix2x2

Ellipse matrix of the beam. For 2-element column vectors v, w with coordinates (ra, dec) in degrees, if (v-w)^T B (v-w) < 1, then v is within the beam centred at w.

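One standard way to build such a matrix is B = R diag(1/a², 1/b²) Rᵀ, where a and b are the semi-axes and R rotates the ellipse by its angle. This is a sketch under stated assumptions, not the package's own construction (presumably done in the private __process_beam, which is not shown): it assumes the semi-axes are half of beam_maj and beam_min, and that the angle is measured counter-clockwise from the x-axis, which differs from the usual astronomical position angle (east of north). The function names are hypothetical.

```python
import numpy as np

def beam_matrix(beam_maj, beam_min, angle_deg):
    """Hypothetical 2x2 ellipse matrix B: d^T B d < 1 exactly when offset d is inside the beam."""
    a, b = beam_maj / 2, beam_min / 2          # assumed semi-axes: half the major/minor axes
    t = np.deg2rad(angle_deg)                  # assumed: angle counter-clockwise from the x-axis
    R = np.array([[np.cos(t), -np.sin(t)],
                  [np.sin(t),  np.cos(t)]])    # rotates the ellipse's own axes onto (x, y)
    return R @ np.diag([1 / a**2, 1 / b**2]) @ R.T

def in_beam(v, w, B):
    """True if point v lies inside the beam centred at w."""
    d = np.asarray(v, dtype=float) - np.asarray(w, dtype=float)
    return float(d @ B @ d) < 1.0
```

For example, with beam_matrix(4.0, 2.0, 0.0) the point (1.9, 0) is inside (1.9²/2² = 0.9025 < 1), while (0, 1.1) is outside (1.1²/1² = 1.21).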
beam_maj: float
    @property
    def beam_maj(self) -> float:
        """
        Major axis of the beam in degrees.
        """
        return self._beam_maj

Major axis of the beam in degrees.

beam_min: float
    @property
    def beam_min(self) -> float:
        """
        Minor axis of the beam in degrees.
        """
        return self._beam_min

Minor axis of the beam in degrees.

beam_angle: float
    @property
    def beam_angle(self) -> float:
        """
        Beam position angle in degrees.
        """
        return self._beam_angle

Beam position angle in degrees.

mean: float
    @property
    def mean(self) -> float:
        """
        The mean intensity of the light, taken over coordinates away from the centre.
        """
        return self._mean

The mean intensity of the light, taken over coordinates away from the centre.

std: float
    @property
    def std(self) -> float:
        """
        The standard deviation of the intensity of the light, taken over coordinates away from the centre.
        """
        return self._std

The standard deviation of the intensity of the light, taken over coordinates away from the centre.

v_sys: float
    @property
    def v_sys(self) -> float:
        """
        The systemic velocity of the star in km/s.
        """
        return self._v_sys

The systemic velocity of the star in km/s.

v_exp: float
    @property
    def v_exp(self) -> float:
        """
        The maximum radial expansion speed in km/s.
        """
        return self._v_exp

The maximum radial expansion speed in km/s.

beta: float
    @property
    def beta(self) -> float:
        """
        Beta parameter of the velocity law.
        """
        return self._beta

Beta parameter of the velocity law.

r_dust: float
    @property
    def r_dust(self) -> float:
        """
        Dust formation radius in AU.
        """
        return self._r_dust

Dust formation radius in AU.

radius: float
    @property
    def radius(self) -> float:
        """
        Characteristic radius in AU (e.g., the radius of maximum intensity change).
        """
        return self._radius

Characteristic radius in AU (e.g., the radius of maximum intensity change).

beta_velocity_law: Callable[[numpy.ndarray], numpy.ndarray]
    @property
    def beta_velocity_law(self) -> VelocityModel:
        """
        VelocityModel

        Returns a callable implementing the beta velocity law with the current object's parameters.
        """
        def law(r):
            return self.__general_beta_velocity_law(r, self.r_dust, self.beta)
        return law

VelocityModel

Returns a callable implementing the beta velocity law with the current object's parameters.
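The private __general_beta_velocity_law that this property wraps is not shown in this section. A commonly used form of the beta law is v(r) = v_exp (1 - r_dust/r)^beta; the sketch below assumes that form, and additionally assumes the velocity is zero at or inside r_dust. Since the property passes only r, r_dust and beta, the package's version may work in units of v_exp or read v_exp elsewhere; this sketch takes v_exp explicitly. Both helper names are hypothetical.

```python
import numpy as np

def beta_velocity_law(r, v_exp, r_dust, beta):
    """Assumed beta law v(r) = v_exp * (1 - r_dust / r)**beta, zero at or inside r_dust."""
    r = np.asarray(r, dtype=float)
    # clip the base at 0 so radii at or inside r_dust give zero velocity instead of NaN
    base = np.clip(1.0 - r_dust / r, 0.0, None)
    return v_exp * base**beta

def make_law(v_exp, r_dust, beta):
    # bind the parameters, mirroring how the property closes over self.r_dust and self.beta
    return lambda r: beta_velocity_law(r, v_exp, r_dust, beta)
```

For example, make_law(15.0, 2.0, 1.0)(4.0) gives 15 × (1 - 2/4) = 7.5 km/s, and the velocity approaches v_exp as r grows.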

def export(self) -> AGB-Star-Deprojection.CondensedData:
    def export(self) -> CondensedData:
        """
        Export all defining attributes to a CondensedData object.
        """
        return CondensedData(
            self.X,
            self.Y,
            self.V,
            self.data,
            self.star_name,
            self.distance_to_star,
            self.v_exp,
            self.v_sys,
            self.beta,
            self.r_dust,
            self.beam_maj,
            self.beam_min,
            self.beam_angle,
            self._header,
            self.mean,
            self.std
        )

Export all defining attributes to a CondensedData object.

def get_filtered_data( self, stds: float | int = 5) -> numpy.ndarray[tuple[int, int, int], numpy.dtype[numpy.float64]]:
    def get_filtered_data(self, stds: float | int = 5) -> DataArray3D:
        """
        Return a copy of the data, with values below `stds` times the background noise standard deviation set to np.nan.

        **Parameters**

        - `stds` (`float` or `int`, optional): Number of standard deviations to filter by (default is 5).

        **Returns**

        - `filtered_data` (`DataArray3D`): Filtered data array.
        """
        filtered_data = self.data.copy()  # creates a deep copy
        filtered_data[filtered_data < stds*self.std] = np.nan
        return filtered_data

Return a copy of the data, with values below stds times the background noise standard deviation set to np.nan.

Parameters

  • stds (float or int, optional): Number of standard deviations to filter by (default is 5).

Returns

  • filtered_data (DataArray3D): Filtered data array.

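The thresholding in get_filtered_data can be reproduced on a toy array. As in the method body, the cut-off is stds × std measured from zero (the background mean is not subtracted), and the input is copied so the original stays untouched. The function name filter_by_sigma is hypothetical.

```python
import numpy as np

def filter_by_sigma(data, std, stds=5):
    """Hypothetical stand-in for get_filtered_data: copy of data with values below stds * std set to NaN."""
    filtered = data.copy()                 # leave the caller's array untouched
    filtered[filtered < stds * std] = np.nan
    return filtered

# toy cube with one channel; with std = 0.1 and stds = 5 the cut-off is 0.5
data = np.array([[[0.1, 0.6], [1.0, 0.4]]])
out = filter_by_sigma(data, std=0.1, stds=5)
```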
def beam_filter( self, filtered_data: numpy.ndarray[tuple[int, int, int], numpy.dtype[numpy.float64]]) -> numpy.ndarray[tuple[int, int, int], numpy.dtype[numpy.float64]]:
    def beam_filter(self, filtered_data: DataArray3D) -> DataArray3D:
        """
        Remove isolated clumps of points small enough to fit inside the beam (clumps with no
        filled points on the beam's boundary), setting these values to np.nan.

        **Parameters**

        - `filtered_data` (`DataArray3D`): 3-D array with the same dimensions as the data array.

        **Returns**

        - `beam_filtered_data` (`DataArray3D`): 3-D array with small clumps of points removed.
        """
        beam_filtered_data = filtered_data.copy()
        n_frames, n_y, n_x = filtered_data.shape
        for frame in range(n_frames):
            for y_idx in range(n_y):
                for x_idx in range(n_x):
                    if np.isnan(filtered_data[frame, y_idx, x_idx]):  # ignore empty points
                        continue

                    # filled point that we are searching around
                    erase = True

                    for x_offset, y_offset in self._boundary_offset:
                        x_check = x_idx + x_offset
                        y_check = y_idx + y_offset
                        # skip out-of-range points explicitly: a negative index would
                        # silently wrap to the opposite edge rather than raise IndexError
                        if not (0 <= y_check < n_y and 0 <= x_check < n_x):
                            continue
                        if not np.isnan(filtered_data[frame, y_check, x_check]):
                            erase = False  # there is something present on the border - saved!
                            break

                    if erase:  # consider ellipse to be an anomaly
                        # erase entire inside of the ellipse centred at the current point
                        for x_offset, y_offset in self._offset_in_beam:
                            x_check = x_idx + x_offset
                            y_check = y_idx + y_offset
                            if 0 <= y_check < n_y and 0 <= x_check < n_x:
                                beam_filtered_data[frame, y_check, x_check] = np.nan  # erase

        return beam_filtered_data

Remove isolated clumps of points small enough to fit inside the beam (clumps with no filled points on the beam's boundary), setting these values to np.nan.

Parameters

  • filtered_data (DataArray3D): 3-D array with the same dimensions as the data array.

Returns

  • beam_filtered_data (DataArray3D): 3-D array with small clumps of points removed.
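beam_filter relies on two lists of integer pixel offsets, _offset_in_beam and _boundary_offset, which are built elsewhere (not shown in this section, presumably by __process_beam). The sketch below is one hypothetical way to build such lists from a pixel-space ellipse matrix: "inside" means dᵀBd < 1, and the boundary is taken to be the ring of outside pixels with at least one inside 4-neighbour. That ring definition is an assumption made for illustration, not the package's documented rule.

```python
import numpy as np

def beam_offsets(B_pix):
    """Hypothetical: (dx, dy) offsets inside the ellipse d^T B d < 1, and the ring just outside it."""
    # largest |dx| or |dy| inside the ellipse is sqrt of the diagonal of B^-1; +1 leaves room for the ring
    cov = np.linalg.inv(B_pix)
    half = int(np.ceil(np.sqrt(max(cov[0, 0], cov[1, 1])))) + 1
    rng = np.arange(-half, half + 1)
    dx, dy = np.meshgrid(rng, rng, indexing="xy")      # rows index dy, columns index dx
    d = np.stack([dx, dy], axis=-1).astype(float)
    q = np.einsum("...i,ij,...j->...", d, B_pix, d)     # quadratic form for every offset
    inside = q < 1.0
    # ring: outside cells with at least one inside neighbour (up, down, left or right)
    padded = np.pad(inside, 1)
    touches = (padded[:-2, 1:-1] | padded[2:, 1:-1] | padded[1:-1, :-2] | padded[1:-1, 2:])
    ring = ~inside & touches
    in_beam = list(zip(dx[inside].tolist(), dy[inside].tolist()))
    boundary = list(zip(dx[ring].tolist(), dy[ring].tolist()))
    return in_beam, boundary
```

For a circular beam of radius 2 pixels (B_pix = I/4) this gives 9 offsets inside and a ring of 12. In beam_filter terms, a filled pixel whose 12 ring neighbours are all empty sits in a clump no larger than the beam, so its 9 inside offsets are erased.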
def get_expansion( self, years: float | int, v_func: Callable[[numpy.ndarray], numpy.ndarray] | None = None, remove_centre: float | int | None = 2, new_shape: tuple = (50, 250, 250), verbose: bool = False) -> AGB-Star-Deprojection.CondensedData:
    def get_expansion(self, years: float | int, v_func: VelocityModel | None = None, remove_centre: float | int | None = 2, new_shape: tuple = (50, 250, 250), verbose: bool = False) -> CondensedData:
        """
        Compute the expanded data cube after a given time interval.

        **Parameters**

        - `years` (`float` or `int`): Time interval in years.
        - `v_func` (`VelocityModel` or `None`): Velocity law function or None (default is None, which uses constant expansion velocity).
        - `remove_centre` (`float` or `int` or `None`, optional): If not None, remove all points within this many beam widths of the centre (default is 2).
        - `new_shape` (`tuple`, optional): Shape of the output grid (default is (50, 250, 250)).
        - `verbose` (`bool`, optional): If True, print progress (default is False).

        **Returns**

        - `info` (`CondensedData`): CondensedData object containing the expanded data and metadata.
        """

        use_data = self.data.copy()
        if remove_centre is not None:
            if verbose:
                print("Removing centre...")
            # get centre coords
            y_idx = np.argmin(self.Y**2)
            x_idx = np.argmin(self.X**2)

            # proportion of the radius that the beam takes up
            beam_rad_au = (((self.beam_maj + self.beam_min)/2)*np.pi/180)*self.distance_to_star
            beam_prop = beam_rad_au/self.radius
            v_axis_removal = remove_centre*beam_prop*self.v_exp
            v_idxs = np.arange(len(self.V))[np.abs(self.V) < v_axis_removal]
            relevant_vs = self.V[np.abs(self.V) < v_axis_removal]
            proportions = remove_centre * np.sqrt(1 - (relevant_vs/v_axis_removal)**2)

            # removing centre
            for i in range(len(v_idxs)):
                v_idx = v_idxs[i]
                prop = proportions[i]
                beam = self.__multiply_beam(prop)
                for x_offset, y_offset in beam:
                    use_data[v_idx][y_idx + y_offset][x_idx + x_offset] = np.nan

        if verbose:
            print("Transforming coordinates...")
        X, Y, V, data, points = self.__time_expansion_transform(self.X, self.Y, self.V, use_data, years, v_func)
        v_num, y_num, x_num = new_shape

        if verbose:
            print("Generating grid for new object...")
        gridv, gridy, gridx = np.mgrid[
            np.min(V):np.max(V):v_num*1j,
            np.min(Y):np.max(Y):y_num*1j,
            np.min(X):np.max(X):x_num*1j
        ]

        # get preimage of grid
        small_gridx = gridx[0, 0, :]
        small_gridy = gridy[0, :, 0]
        small_gridv = gridv[:, 0, 0]

        # go backwards with negative years
        if verbose:
            print("Shrinking grid to original data bounds...")
        prev_X, prev_Y, prev_V, _, _ = self.__time_expansion_transform(small_gridx, small_gridy, small_gridv, use_data, -years, v_func, crop = False)

        bad_idxs = (prev_V < np.min(points[0])) | (prev_V > np.max(points[0])) | \
                (prev_Y < np.min(points[1])) | (prev_Y > np.max(points[1])) | \
                (prev_X < np.min(points[2])) | (prev_X > np.max(points[2])) | \
                np.isnan(prev_V) | np.isnan(prev_X) | np.isnan(prev_Y)

        prev_V[bad_idxs] = 0
        prev_X[bad_idxs] = 0
        prev_Y[bad_idxs] = 0

        # interpolate regular data at these points
        if verbose:
            print("Interpolating...")
        interp_points = np.column_stack((prev_V, prev_Y, prev_X))
        interp_data = interpn(points, data, interp_points)
        interp_data[bad_idxs] = np.nan
        new_data = interp_data.reshape(gridx.shape)

        non_nans = len(new_data[np.isfinite(new_data)])
        if verbose:
            print(f"{non_nans} non-nan values remaining out of {np.size(new_data)}")

        info = CondensedData(
            small_gridx,
            small_gridy,
            small_gridv,
            new_data,
            self.star_name,
            self.distance_to_star,
            self.v_exp,
            self.v_sys,
            self.beta,
            self.r_dust,
            self.beam_maj,
            self.beam_min,
            self.beam_angle,
            self._header,
            self.mean,
            self.std
        )

        if verbose:
            print("Time expansion process complete!")
        return info

Compute the expanded data cube after a given time interval.

Parameters

  • years (float or int): Time interval in years.
  • v_func (VelocityModel or None): Velocity law function or None (default is None, which uses constant expansion velocity).
  • remove_centre (float or int or None, optional): If not None, remove all points within this many beam widths of the centre (default is 2).
  • new_shape (tuple, optional): Shape of the output grid (default is (50, 250, 250)).
  • verbose (bool, optional): If True, print progress (default is False).

Returns

  • info (CondensedData): CondensedData object containing the expanded data and metadata.
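The coordinate transform itself (the private __time_expansion_transform) is not shown in this section, so the following is only a sketch of the simplest case, v_func=None (constant expansion velocity). It assumes each point is already in spatial (x, y, z) coordinates in AU and moves radially away from the star; the helper names are hypothetical. It also illustrates why get_expansion can call the transform again with -years: at constant speed, running the clock backwards undoes the radial shift.

```python
import numpy as np

KM_PER_AU = 1.495978707e8          # IAU definition of the astronomical unit, in km
SECONDS_PER_YEAR = 365.25 * 86400  # Julian year, in seconds

def radial_shift_au(v_kms, years):
    """Distance in AU travelled in `years` years at a constant speed of v_kms km/s."""
    return v_kms * years * SECONDS_PER_YEAR / KM_PER_AU

def expand_points(x, y, z, years, v_exp):
    """Hypothetical: move points radially away from the star (at the origin) at constant speed v_exp."""
    r = np.sqrt(x**2 + y**2 + z**2)
    # scale factor (r + dr) / r; points at the star itself (r = 0) have no direction, so leave them
    with np.errstate(divide="ignore", invalid="ignore"):
        scale = np.where(r > 0, (r + radial_shift_au(v_exp, years)) / r, 1.0)
    return x * scale, y * scale, z * scale
```

As a scale check, material moving at 1 km/s travels about 0.211 AU per year, so at a typical v_exp of 15 km/s the envelope grows by roughly 3.2 AU per year.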
def plot_velocity_vs_intensity(self, fit_parabola: bool = True) -> None:
    def plot_velocity_vs_intensity(self, fit_parabola: bool = True) -> None:
        """
        Plot velocity vs. intensity at the center of the xy plane.

        **Parameters**

        - `fit_parabola` (`bool`, optional): If True, fit and plot a parabola (default is True).

        **Notes**

        How well the parabola fits the data gives a visual check on
        the accuracy of the calculated systemic and expansion velocities.
        """
        self.__get_star_and_exp_velocity(self.V, plot=True, fit_parabola=fit_parabola)

Plot velocity vs. intensity at the center of the xy plane.

Parameters

  • fit_parabola (bool, optional): If True, fit and plot a parabola (default is True).

Notes

How well the parabola fits the data gives a visual check on the accuracy of the calculated systemic and expansion velocities.

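The fitting routine behind this plot (the private __get_star_and_exp_velocity) is not shown here. As a sketch of how two velocities can be read off a parabolic fit, the code below assumes the vertex of the fitted parabola marks v_sys and its zero crossings mark v_sys ± v_exp; this is an illustrative assumption, not necessarily the package's exact procedure, and the function name is hypothetical.

```python
import numpy as np

def velocities_from_parabola(v, intensity):
    """Hypothetical: fit intensity = a v^2 + b v + c and read off (v_sys, v_exp)."""
    a, b, c = np.polyfit(v, intensity, 2)
    v_sys = -b / (2 * a)                    # vertex of the parabola
    roots = np.roots([a, b, c])             # where the fitted profile falls to zero
    # the roots sit symmetrically at v_sys +/- v_exp, so take the distance from the vertex
    v_exp = float(np.max(np.abs(roots.real - v_sys)))
    return float(v_sys), v_exp
```

On a noiseless toy profile 2(1 - ((v - 5)/10)²), this recovers v_sys = 5 km/s and v_exp = 10 km/s; on real data, a poor fit shows up directly as a mismatch in the plot.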
def plot_radius_vs_intensity(self) -> None:
    def plot_radius_vs_intensity(self) -> None:
        """
        Plot radius vs. intensity at the center of the xy plane.
        """
        self.__get_beta_law(plot_intensities=True)

Plot radius vs. intensity at the center of the xy plane.
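How the package converts the central spectrum into a radius profile (inside the private __get_beta_law) is not shown in this section. Given such a profile, the characteristic radius described as "maximum intensity change" can be located as sketched below, assuming it means the radius where |dI/dr| is largest. The function name is hypothetical.

```python
import numpy as np

def radius_of_steepest_drop(radii, intensity):
    """Hypothetical: radius at which intensity changes fastest with radius (largest |dI/dr|)."""
    slope = np.gradient(intensity, radii)   # finite-difference derivative dI/dr on the given grid
    return float(radii[np.argmax(np.abs(slope))])
```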

def plot_radius_vs_velocity(self, fit_beta_law: bool = True) -> None:
    def plot_radius_vs_velocity(self, fit_beta_law: bool = True) -> None:
        """
        Plot radius vs. velocity at the center of the xy plane.

        **Parameters**

        - `fit_beta_law` (`bool`, optional): If True, fit and plot the beta law (default is True).

        **Notes**

        How well the beta law curve fits the data gives a visual check on
        the accuracy of the calculated beta law parameters.
        """
        self.__get_beta_law(plot_velocities=True, plot_beta_law=fit_beta_law)

Plot radius vs. velocity at the center of the xy plane.

Parameters

  • fit_beta_law (bool, optional): If True, fit and plot the beta law (default is True).

Notes

How well the beta law curve fits the data gives a visual check on the accuracy of the calculated beta law parameters.

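The package fits its beta law parameters in the private __get_beta_law, which is not shown here. The sketch below is a NumPy-only alternative, not the package's method: it assumes the law v = v_exp (1 - r_dust/r)^beta with v_exp held fixed. For each trial r_dust, log(v/v_exp) = beta · log(1 - r_dust/r) is linear in beta, so beta has a closed-form least-squares value, and the (r_dust, beta) pair with the smallest residual is kept. The function name and grid approach are hypothetical.

```python
import numpy as np

def fit_beta_law(r, v, v_exp, r_dust_grid):
    """Hypothetical grid fit of v = v_exp * (1 - r_dust / r)**beta for (r_dust, beta), v_exp fixed."""
    y = np.log(v / v_exp)                   # requires v > 0
    best = (np.inf, np.nan, np.nan)
    for r_dust in r_dust_grid:
        if r_dust >= np.min(r):             # log(1 - r_dust/r) is undefined for r <= r_dust
            continue
        x = np.log(1.0 - r_dust / r)
        beta = np.dot(x, y) / np.dot(x, x)  # least-squares slope of a line through the origin
        resid = np.sum((y - beta * x) ** 2)
        if resid < best[0]:
            best = (resid, r_dust, beta)
    return best[1], best[2]
```

A finer grid (or a local refinement around the best grid point) trades run time for precision in r_dust.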
def plot_channel_maps( self, filter_stds: float | int | None = None, filter_beam: bool = False, dimensions: None | tuple[int, int] = None, start: int = 0, end: None | int = None, include_beam: bool = True, text_pos: None | tuple[float, float] = None, beam_pos: None | tuple[float, float] = None, title: str | None = None, cmap: str = 'viridis') -> None:
    def plot_channel_maps(self,
        filter_stds: float | int | None = None,
        filter_beam: bool = False,
        dimensions: None | tuple[int,int] = None,
        start: int = 0,
        end: None | int = None,
        include_beam: bool = True,
        text_pos: None | tuple[float,float] = None,
        beam_pos: None | tuple[float,float] = None,
        title: str | None = None,
        cmap: str = "viridis"
    ) -> None:
        """
        Plot the data cube as a set of 2D channel maps.

        **Parameters**

        - `filter_stds` (`float` or `int` or `None`, optional): Number of standard deviations to filter by (default is None, no filtering).
        - `filter_beam` (`bool`, optional): If True, also apply beam filtering; only used when `filter_stds` is set (default is False).
        - `dimensions` (`tuple` of `int` or `None`, optional): Grid dimensions (nrows, ncols) for subplots (default is None, which chooses a near-square grid).
        - `start` (`int`, optional): Index of the first velocity channel to plot (default is 0).
        - `end` (`int` or `None`, optional): Index of the last velocity channel to plot, inclusive (default is None, the last channel).
        - `include_beam` (`bool`, optional): If True, plot the beam ellipse (default is True).
        - `text_pos` (`tuple` of `float` or `None`, optional): Position of the velocity label, in arcseconds (default is None, which places it automatically).
        - `beam_pos` (`tuple` of `float` or `None`, optional): Position of the beam ellipse, in arcseconds (default is None, which places it automatically).
        - `title` (`str` or `None`, optional): Plot title (default is None, which uses the star name followed by "channel maps").
        - `cmap` (`str`, optional): Colormap for the plot (default is "viridis"). See https://matplotlib.org/stable/users/explain/colors/colormaps.html
        """

        if end is None:
            end = self.data.shape[0]-1
        else:
            if end < start: # invalid, so set to the end of the data
                end = self.data.shape[0]-1
            if end >= self.data.shape[0]: # invalid, so set to the end of the data
                end = self.data.shape[0]-1

        if filter_stds is not None:
            data = self.get_filtered_data(filter_stds)
            if filter_beam:
                data = self.beam_filter(data)
        else:
            data = self.data

        if start >= data.shape[0]: # invalid, so set to the start of the data
            start = 0

        if title is None:
            title = self.star_name + " channel maps"

        n_panels = end - start + 1  # the channel range includes end
        if dimensions is not None:
            nrows, ncols = dimensions
        else:
            nrows = int(np.ceil(np.sqrt(n_panels)))
            ncols = int(np.ceil(n_panels/nrows))

        # squeeze=False keeps axes 2D, so axes.flat also works for a 1x1 grid
        fig, axes = plt.subplots(nrows, ncols, squeeze=False)
        colourmap = colormaps[cmap]
        fig.suptitle(title)
        fig.supxlabel("Right Ascension (Arcseconds)")
        fig.supylabel("Declination (Arcseconds)")

        i = start
        # AU offset / distance in AU = angle in radians, then converted to arcseconds
        extents = u.radian.to(u.arcsec,np.array([np.min(self.X),np.max(self.X),np.min(self.Y),np.max(self.Y)])/self.distance_to_star)
        vmax = np.nanmax(data)  # shared colour scale for all panels

        done = False
        for ax in axes.flat:
            if not done:
                im = ax.imshow(data[i], vmin=self.mean, vmax=vmax, extent=extents, cmap=cmap)

                ax.set_facecolor(colourmap(0))  # NaN pixels are transparent, so they show the lowest colour
                ax.set_aspect("equal")

                if text_pos is None:
                    ax.text(extents[0]*5/6,extents[3]*1/2,f"{self.V[i]:.1f} km/s",size="x-small",c="white")
                else:
                    ax.text(text_pos[0],text_pos[1],f"{self.V[i]:.1f} km/s",size="x-small",c="white")

                if include_beam:
                    bmaj = u.deg.to(u.arcsec,self.beam_maj)
                    bmin = u.deg.to(u.arcsec,self.beam_min)
                    bpa = self.beam_angle

                    if beam_pos is None:
                        ellipse_artist = Ellipse(xy=(extents[0]*1/2,extents[2]*1/2),width=bmaj,height=bmin,angle=bpa,color="white")
                    else:
                        ellipse_artist = Ellipse(xy=(beam_pos[0],beam_pos[1]),width=bmaj,height=bmin,angle=bpa,color="white")
                    ax.add_artist(ellipse_artist)
                i += 1
                if (i >= data.shape[0]) or (i > end): done = True
            else:
                ax.axis("off")
        cbar = fig.colorbar(im, ax=axes.ravel().tolist())
        cbar.set_label("Flux Density (Jy/Beam)")
        plt.show()

Plot the data cube as a set of 2D channel maps.

Parameters

  • filter_stds (float or int or None, optional): Number of standard deviations to filter by (default is None).
  • filter_beam (bool, optional): If True, apply beam filtering (default is False).
  • dimensions (tuple of int or None, optional): Grid dimensions (nrows, ncols) for subplots (default is None).
  • start (int, optional): Starting velocity channel index (default is 0).
  • end (int or None, optional): Ending velocity channel index, inclusive (default is None).
  • include_beam (bool, optional): If True, plot the beam ellipse (default is True).
  • text_pos (tuple of float or None, optional): (x, y) position of the velocity label in arcseconds (default is None, which places it automatically).
  • beam_pos (tuple of float or None, optional): (x, y) position of the beam ellipse centre in arcseconds (default is None, which places it automatically).
  • title (str or None, optional): Plot title (default is None, which uses the star name followed by "channel maps").
  • cmap (str, optional): Colormap for the plot (default is "viridis"). See https://matplotlib.org/stable/users/explain/colors/colormaps.html
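The default panel layout and axis extents of the channel-map method above reduce to a little arithmetic. The standard-library sketch below uses hypothetical helper names that are not part of the package. `channel_grid` picks a near-square grid with one panel per channel from `start` to `end` inclusive. `offset_to_arcsec` applies the small-angle conversion used for `extents`, assuming the offset and the distance are in the same length unit.

```python
import math

# Radians to arcseconds: 180/pi degrees per radian, 3600 arcsec per degree.
ARCSEC_PER_RADIAN = 180 / math.pi * 3600


def channel_grid(start: int, end: int, n_channels_total: int) -> tuple[int, int]:
    """Hypothetical helper: (nrows, ncols) for channels start..end inclusive."""
    last = min(end, n_channels_total - 1)  # clip to the channels that exist
    n = last - start + 1                   # number of panels needed
    nrows = math.ceil(math.sqrt(n))        # near-square layout
    ncols = math.ceil(n / nrows)           # enough columns to fit every panel
    return nrows, ncols


def offset_to_arcsec(offset: float, distance: float) -> float:
    """Hypothetical helper: small-angle conversion of a projected offset.

    offset and distance must share a length unit (e.g. both in AU).
    """
    return offset / distance * ARCSEC_PER_RADIAN
```

As a sanity check, one parsec is 648000/π AU, so an offset of 1 AU at a distance of 1 pc subtends 1 arcsecond.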
def create_mask(self, filter_stds: float | int | None = None, savefile: str | None = None, initial_crop: tuple | None = None):
    def create_mask(self, filter_stds: float | int | None = None, savefile: str | None = None, initial_crop: tuple | None = None):
        """
        Launch an interactive mask creator for the data cube.

        **Parameters**

        - `filter_stds` (`float` or `int` or `None`, optional): Number of standard deviations to initially filter by (default is None).
        - `savefile` (`str` or `None`, optional): Filename to save the selected mask to (default is None). Should end in .npy; `np.save` appends the extension if it is missing.
        - `initial_crop` (`tuple` or `None`, optional): Initial crop region as (v_lo, v_hi, y_lo, y_hi, x_lo, x_hi), using array indices with exclusive upper bounds as in Python slicing (default is None). Voxels outside the crop are set to NaN in the mask.
        """
        if filter_stds is not None:
            data = self.get_filtered_data(filter_stds)
        else:
            data = self.data

        if initial_crop is not None:
            v_lo, v_hi, y_lo, y_hi, x_lo, x_hi = initial_crop
            new_data = data[v_lo:v_hi, y_lo:y_hi, x_lo:x_hi]
            selector = _PointsSelector(new_data)
            plt.show()
            # embed the cropped mask in a full-size mask that is NaN outside the crop
            mask = np.full(data.shape, np.nan)
            mask[v_lo:v_hi, y_lo:y_hi, x_lo:x_hi] = selector.mask

        else:
            selector = _PointsSelector(data)
            plt.show()

            # mask is complete now
            mask = selector.mask

        # save mask
        if savefile is not None:
            np.save(savefile, mask)

Launch an interactive mask creator for the data cube.

Parameters

  • filter_stds (float or int or None, optional): Number of standard deviations to initially filter by (default is None).
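  • savefile (str or None, optional): Filename to save the selected mask to (default is None). Should end in .npy; np.save appends the extension if it is missing.
  • initial_crop (tuple or None, optional): Initial crop region as (v_lo, v_hi, y_lo, y_hi, x_lo, x_hi), using array indices with exclusive upper bounds as in Python slicing (default is None). Voxels outside the crop are set to NaN in the mask.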
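When `initial_crop` is given, the selector only sees the cropped sub-cube, and its mask is written back into a full-size array that is NaN everywhere outside the crop. The NumPy sketch below shows that step, plus one possible way to apply a saved mask to a cube. Both helper names are hypothetical. `apply_mask` also assumes selected voxels are stored as finite values and unselected ones as NaN, which the source does not confirm.

```python
import numpy as np


def embed_cropped_mask(full_shape, crop, sub_mask):
    """Hypothetical helper: place a mask built on a cropped cube into a full-size cube.

    crop is (v_lo, v_hi, y_lo, y_hi, x_lo, x_hi) with exclusive upper bounds,
    as in Python slicing. Voxels outside the crop are NaN.
    """
    v_lo, v_hi, y_lo, y_hi, x_lo, x_hi = crop
    mask = np.full(full_shape, np.nan)
    mask[v_lo:v_hi, y_lo:y_hi, x_lo:x_hi] = sub_mask
    return mask


def apply_mask(data, mask):
    """Hypothetical helper: keep data where the mask is finite, NaN elsewhere.

    Assumes the convention that selected voxels are finite and the rest NaN.
    """
    return np.where(np.isfinite(mask), data, np.nan)
```

Under the same assumption, a mask saved with `savefile` could be read back with `np.load` and passed to `apply_mask` together with the data cube.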
def plot_3D(self, filter_stds: float | int | None = None, filter_beam: bool = False, z_cutoff: float = 1, crop_leeway: float = 0, num_points: int = 50, num_surfaces: int = 50, opacity: float | int = 0.5, opacityscale: list[list[float]] = [[0, 0], [1, 1]], colorscale: str = 'Reds', v_func: Callable[[numpy.ndarray], numpy.ndarray] | None = None, verbose: bool = False, title: str | None = None, folder: str | None = None, num_angles: int = 24, camera_dist: float | int = 2) -> None:
    def plot_3D(
        self,
        filter_stds: float | int | None = None,
        filter_beam: bool = False,
        z_cutoff: float = 1,
        crop_leeway: float = 0,
        num_points: int = 50,
        num_surfaces: int = 50,
        opacity: float | int = 0.5,
        opacityscale: list[list[float]] = [[0, 0], [1, 1]],
        colorscale: str = "Reds",
        v_func: VelocityModel | None = None,
        verbose: bool = False,
        title: str | None = None,
        folder: str | None = None,
        num_angles: int = 24,
        camera_dist: float | int = 2
    ) -> None:
        """
        Plot a 3D volume rendering of the data cube using Plotly.

        **Parameters**

        - `filter_stds` (`float` or `int` or `None`, optional): Number of standard deviations to filter by (default is None).
        - `filter_beam` (`bool`, optional): If True, apply beam filtering (default is False).
        - `z_cutoff` (`float`, optional): Half-extent of the z-axis as a proportion of the largest absolute x or y offset (default is 1).
        - `crop_leeway` (`int` or `float`, optional): Fractional leeway to expand the crop region (default is 0).
        - `num_points` (`int`, optional): Number of points in each dimension for the grid (default is 50).
        - `num_surfaces` (`int`, optional): Number of surfaces to draw when rendering the plot (default is 50).
        - `opacity` (`float` or `int`, optional): Opacity of the volume rendering (default is 0.5).
        - `opacityscale` (`list` of `list` of `float`, optional): Opacity scale, see https://plotly.com/python-api-reference/generated/plotly.graph_objects.Volume.html (default is [[0, 0], [1, 1]]).
        - `colorscale` (`str`, optional): Colormap for the plot, see https://plotly.com/python/builtin-colorscales/ (default is "Reds").
        - `v_func` (`VelocityModel` or `None`, optional): Velocity law function (default is None, which uses constant expansion velocity).
        - `verbose` (`bool`, optional): If True, print progress (default is False).
        - `title` (`str` or `None`, optional): Plot title (default is None, which uses the star name).
        - `folder` (`str` or `None`, optional): If provided, renders frames from successive camera angles and saves them as PNG files in this folder instead of displaying the figure; an empty string saves them to the current working directory (default is None). Static image export requires the kaleido package.
        - `num_angles` (`int`, optional): Number of camera angles, evenly spaced over a full rotation, at which frames are saved (default is 24). Greater values give a smoother animation.
        - `camera_dist` (`float` or `int`, optional): Horizontal distance of the camera from the z-axis when generating frames (default is 2).
        """
        if verbose:
            print("Initial filter and crop...")
        X, Y, V, data = self.__filter_and_crop(filter_stds, filter_beam, crop_leeway, verbose)

        if title is None:
            title = self.star_name

        # get small 1d arrays
        x_small = X[0, 0, :]
        y_small = Y[0, :, 0]
        v_small = V[:, 0, 0]

        # interpolate to grid
        if verbose:
            print("Interpolating to grid...")

        # z runs from -z_bound to z_bound, scaled from the largest sky-plane offset
        z_bound = z_cutoff * np.maximum(np.max(np.abs(x_small)), np.max(np.abs(y_small)))

        x_lo, x_hi, y_lo, y_hi, z_lo, z_hi = np.min(x_small), np.max(x_small), np.min(y_small), np.max(y_small), -z_bound, z_bound
        gridx, gridy, gridz, out = self.__fast_interpolation((v_small, y_small, x_small), (x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi), data, v_func, num_points)

        # filter by standard deviation
        if filter_stds is not None:
            out[out < filter_stds*self.std] = np.nan

        out[np.isnan(out)] = 0  # Plotly can't deal with NaNs

        if verbose:
            print(f"Found {np.count_nonzero(out > 0)} points with positive flux.")

        out = out.ravel()
        min_value = np.min(out[np.isfinite(out) & (out > 0)])

        if verbose:
            print("Plotting figure...")

        fig = go.Figure(
            data=go.Volume(
                x=gridx,
                y=gridy,
                z=gridz,
                value=out,
                isomin=min_value,
                colorscale=colorscale,
                colorbar=dict(title="Flux Density (Jy/beam)"),
                opacityscale=opacityscale,
                opacity=opacity,  # needs to be small to see through all surfaces
                surface_count=num_surfaces,  # needs to be a large number for good volume rendering
            ),
            layout=go.Layout(
                title={
                    "text": title,
                    "x": 0.5,
                    "y": 0.95,
                    "xanchor": "center",
                    "font": {"size": 24},
                },
                scene=dict(
                    xaxis=dict(title=dict(text="X (AU)")),
                    yaxis=dict(title=dict(text="Y (AU)")),
                    zaxis=dict(title=dict(text="Z (AU)")),
                    aspectmode="cube",
                ),
            ),
        )

        if folder is not None:
            if verbose:
                print("Generating frames...")
            # endpoint=False avoids rendering both 0 and 360 degrees, which are the same view
            angles = np.linspace(0, 360, num_angles, endpoint=False)
            for k, a in enumerate(angles):
                b = np.deg2rad(a)
                # the camera orbits the z-axis at radius camera_dist and a fixed height
                eye = dict(x=camera_dist*np.cos(b), y=camera_dist*np.sin(b), z=1.25)
                fig.update_layout(scene_camera_eye=eye)
                if folder:
                    fig.write_image(f"{folder}/angle{int(a)}.png")
                else:
                    fig.write_image(f"{int(a)}.png")  # empty string: save to the working directory
                if verbose:
                    print(f"Generating frames: {100*(k+1)/num_angles:.0f}% complete.")
        else:
            fig.show()

Plot a 3D volume rendering of the data cube using Plotly.

Parameters

  • filter_stds (float or int or None, optional): Number of standard deviations to filter by (default is None).
  • filter_beam (bool, optional): If True, apply beam filtering (default is False).
  • z_cutoff (float, optional): Half-extent of the z-axis as a proportion of the largest absolute x or y offset (default is 1).
  • crop_leeway (int or float, optional): Fractional leeway to expand the crop region (default is 0).
  • num_points (int, optional): Number of points in each dimension for the grid (default is 50).
  • num_surfaces (int, optional): Number of surfaces to draw when rendering the plot (default is 50).
  • opacity (float or int, optional): Opacity of the volume rendering (default is 0.5).
  • opacityscale (list of list of float, optional): Opacity scale, see https://plotly.com/python-api-reference/generated/plotly.graph_objects.Volume.html (default is [[0, 0], [1, 1]]).
  • colorscale (str, optional): Colormap for the plot, see https://plotly.com/python/builtin-colorscales/ (default is "Reds").
  • v_func (VelocityModel or None, optional): Velocity law function (default is None, which uses constant expansion velocity).
  • verbose (bool, optional): If True, print progress (default is False).
  • title (str or None, optional): Plot title (default is None, which uses the star name).
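  • folder (str or None, optional): If provided, renders frames from successive camera angles and saves them as PNG files in this folder instead of displaying the figure; an empty string saves them to the current working directory (default is None). Static image export requires the kaleido package.
  • num_angles (int, optional): Number of camera angles, evenly spaced over a full rotation, at which frames are saved (default is 24). Greater values give a smoother animation.
  • camera_dist (float or int, optional): Horizontal distance of the camera from the z-axis when generating frames (default is 2).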
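When `folder` is set, the frame loop in the source orbits the camera around the z-axis at a fixed height. The standard-library sketch below reproduces that idea with a hypothetical helper, `orbit_eyes`, which is not part of the package. It spaces the angles over a full turn without repeating 0° and 360° (the same view), which a seamlessly looping animation needs. As in the source, it assumes a camera height of 1.25 in Plotly's scene coordinates.

```python
import math


def orbit_eyes(num_angles: int, camera_dist: float, height: float = 1.25):
    """Hypothetical helper: (angle_deg, eye) pairs for a camera circling the z-axis.

    Angles are 360*k/num_angles for k = 0..num_angles-1, so 360 degrees
    (identical to 0 degrees) is never repeated.
    """
    frames = []
    for k in range(num_angles):
        a = 360 * k / num_angles          # angle in degrees
        b = math.radians(a)               # same angle in radians
        eye = dict(
            x=camera_dist * math.cos(b),  # horizontal position on the orbit
            y=camera_dist * math.sin(b),
            z=height,                     # fixed camera height
        )
        frames.append((a, eye))
    return frames
```

Each `eye` dict has the shape passed to `fig.update_layout(scene_camera_eye=eye)` before `fig.write_image(...)`, which is how the source renders each frame.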