Source code for diempy.plots

from __future__ import annotations


# ---- stdlib ----
from collections import defaultdict
from itertools import groupby, chain


# ---- numpy / pandas ----
import numpy as np
import pandas as pd


# ---- matplotlib ----
import matplotlib.pyplot as plt
from matplotlib.patches import Wedge, Polygon, Rectangle
import matplotlib.colors as mcolors
from matplotlib.widgets import Button, Slider
from matplotlib.colors import to_rgb, LinearSegmentedColormap

import time
from dataclasses import dataclass
from typing import Iterable, Callable, Dict, Tuple, List, Optional

# ---- diem internals ----
from . import smooth


# explicitly used smoothing entry point
from .smooth import laplace_smooth_multiple_haplotypes

# more explicit imports
from fractions import Fraction
#from bisect import bisect_left
import bisect
from joblib import Parallel, delayed # for parallel computation of 'pwdmatrix' pairwise distance matrix rows

"""********************************************************
Keyword 'DMBCtoucher' is used to label the 
points where plots.py touches diemtypes.DMBC states

This is for future-proofing against the day 
where DMBC is allowed to store states other than {0,1,2,3}->{U,0,1,2} 
(c.f. vcf2diem output)

plots.py _currently_ has potential access to extra vcf2diem states through
    read_diem_bed_4_plots
    -> diemIrisFromPlotPrep
    -> diemLongFromPlotPrep

TODO:
    Chr_Nickname
    Ind_Nickname   can be linked with meta data to allow eg coloured name plotting (eg showing a-priori clusters)

    Stop using files bypass to get at extra states.
        Collab with Derek to store extra state internally in DMBC

    Rename the summaries functions more transparently - and move them out into an analytics section
********************************************************"""


""" _____________________ START Mathematica2Python _____________________"""
##############################
#### Mathematica2Python
### Author: Stuart J.E. Baird
###############################


def Split(seq, same_test=lambda a, b: a == b):
    '''
    Function to split a sequence into sublists based on a test function.
    ChatGPT 5.2 provided Mathematica Split equivalent in Python.
    
    Args:
        seq (list): The input sequence to be split.
        same_test (function): A function that takes two arguments and returns True if they are considered the same.
    Returns:
        list: A list of sublists, where each sublist contains consecutive elements that are considered the same.
    '''
    if not seq:
        return []
    out = [[seq[0]]]
    for x in seq[1:]:
        (out[-1] if same_test(out[-1][-1], x) else out.append([x])) and out[-1].append(x)
    return out


def RichRLE(lst):
    """
    Function generating a Rich Run Length Encoding of a list.
    
    Args:
        lst (list): The input list to be encoded.
    Returns:
        lists: A list containing four lists: states, lengths, starts, and ends of each run.
    """
    slst = Split(lst, lambda x, y: y == x)
    cumstart = 0
    states = []
    lengths = []
    starts = []
    ends = []
    for i in slst:
        leni = len(i)
        states.append(i[0])
        lengths.append(leni)
        starts.append(cumstart)
        cumstart += leni
        ends.append(cumstart - 1)
    return [states, lengths, starts, ends]


def Map(f, lst): 
    """
    equivalent to Mathematica Map 
    """
    return list(map(f, lst))



def ParallelMap(f, lst):
    """ 
    equivalent to Mathematica ParallelMap 
    """
    pool = mp.Pool()
    return list(pool.map(f, lst))


def Flatten(lstOlists): 
    """ 
    equivalent to Mathematica Flatten
    """
    return list(chain.from_iterable(lstOlists)) #itertools



def StringJoin(slst):
    """ 
    equivalent to Mathematica StringJoin 
    """
    separator = ''
    return separator.join(slst)


def Transpose(mat): 
    """ 
    equivalent to Mathematica Transpose
    """
    return list(np.array(mat).T)  # care here - hidden type casting on heterogeneous 'mat'rices


def StringTranspose(slst): 
    """ 
    equivalent to Mathematica StringTranspose 
    """
    return Map(StringJoin, Transpose(Map(Characters, slst)))


def Tally(lst):  # single pass so in principle fast ( O(n) ) but answers unsorted
    """ 
    equivalent to Mathematica Tally 
    """
    states = []
    tally = []
    for x in lst:
        p = FirstPosition(states, x)
        if p == []:
            states.append(x)
            tally.append([x, 1])
        else:
            tally[p[0]][1] += 1
    return tally

def Second(lst): 
    """ equivalent to Mathematica Second """
    return lst[1]


def Total(lst): 
    """"""
    return sum(lst)


def Join(lst1, lst2): 
    """ equivalent to Mathematica Join """
    return lst1 + lst2


def Take(lst, n):
    """ 
    equivalent to Mathematica Take 
    """
    if n > 0:
        ans = lst[:n]
    elif n == 0:
        ans = lst
    else:
        ans = lst[n:]
    return ans


def Drop(lst, n):
    """ equivalent to Mathematica Drop """
    if n > 0:
        ans = lst[n:]
    elif n == 0:
        ans = lst
    else:
        ans = lst[:n]
    return ans

def FirstPosition(lst, elem):
    """ equivalent to Mathematica FirstPosition """
    i = -1
    pos = []
    for l in lst:
        i += 1
        if l == elem:
            pos.append(i)
            break
    return pos

def Characters(s): 
    """ equivalent to Mathematica Characters """
    return [*s]

def StringTakeList(string, lengths):
    """ equivalent to Mathematica StringTakeList """
    substrings = []
    current_index = 0
    for length in lengths:
        substrings.append(string[current_index:current_index + length])
        current_index += length
    return substrings

""" _____________________ END Mathematica2Python _____________________"""


""" _____________________ START DIEMPy 2023 snippets _____________________"""

##############################
#### From DIEMPy
### Author: Stuart J.E. Baird
###############################


StringReplace20_dict = str.maketrans('02', '20')
""" simultaneous 2<->0 replacement dictionary """

def StringReplace20(text):
    """will _!simultaneously!_ replace 2->0 and 0->2"""
    return text.translate(StringReplace20_dict)


def sStateCount(s):
    """ 
    counts diem States as chars
    Args:   astring
    Returns: list of counts [nU,n0,n1,n2]
    """
    counts = Map(Second, Tally(Join(["0", "1", "2"], Characters(s))))
    nU = Total(
        Drop(counts, 3)
    )  # only the three 'call' chars above are not U encodings!
    counts = list(np.array(Take(counts, 3)) - 1)
    return Join([nU], counts)


def pHetErrOnString(s):
    """
    Calculates state frequency,heterozygosity and error rate from a string of diem states.
    Args:   astring
    Returns: tuple of (pHetErr, pHet, pErr)
    """
    sCount = sStateCount(s)
    callTotal = Total(Drop(sCount, 1))
    if callTotal > 0:
        ans = (
            Total(np.array(sCount) * [0, 0, 1, 2]) / (2 * callTotal),
            sCount[2] / callTotal,
            sCount[0] / Total(sCount),
        )
    else:  # no calls... are there any Us?
        if sCount[0] > 0:
            pErr = 1
        else:
            pErr = "NA"
        ans = ("NA", "NA", pErr)
    return ans


""" _____________________ START DIEMPy 2023 snippets _____________________"""



""" ______________new support by SJEB STARTING_______________________________"""

def Chr_Nickname(chr_name):
    """
    Shorten chromosome names for plotting.
    E.g., 'chromosome_1' -> 'Chr 1'
    Args:
        chr_name (str): Full chromosome name.
    Returns:
        str: Shortened chromosome name.
    """
    if 'scaffold_' in chr_name:
        return chr_name.replace('scaffold_', 'Scaf ')[-10:]
    elif 'scaffold' in chr_name:
        return chr_name.replace('scaffold', 'Scaf ')[-10:]
    elif 'chromosome_' in chr_name:
        return chr_name.replace('chromosome_', 'Chr ')[-9:]
    elif 'chromosome' in chr_name:
        return chr_name.replace('chromosome', 'Chr ')[-9:]
    else:
        return chr_name[-15:]
    
def Ind_Nickname(ind_name):
    """
    Shorten Ind names for plotting.
    E.g., 'chromosome_1' -> 'Chr 1'
    Args:
        ind_name (str): Full individual name.
    Returns:
        str: Shortened individual name.
    """
    shortlength = 11

    if '_NA' in ind_name:
        return ind_name.replace('_NA', '')[-shortlength:]
    else:
        return ind_name[-shortlength:]
    

def read_diem_bed_4_plots(bed_file_path, meta_file_path):
    """
    Reads a diem BED file and meta file for use in plots.py
    code copy from Derek Setter's read_diem_bed with additions by SJEB
    
    Derek comments:
    Fast version of read_diem_bed with significant performance improvements.
    
    Args:
    bed_file_path (str): Path to the diem BED file.
    meta_file_path (str): Path to the diem metadata file.

    Returns:
    A tuple of a DiemType object containing the diem BED data 
    (POLARISED (if hasPolarity)) and an IndsName file.
    """
    
    # Read metadata - no changes needed here as it's already fast
    df_meta = pd.read_csv(meta_file_path, sep='\t')
    chrNames = np.array(df_meta['#Chrom'].values)
    #chrRelativeRecRates = np.array(df_meta['relativeRecRates'].values)
    if "relativeRecRates" in df_meta.columns:
        chrRelativeRecRates = np.asarray(df_meta["relativeRecRates"].values)
    else:
        # Default: no recombination adjustment
        # Use chromosome count (24 in your case)
        chrRelativeRecRates = np.ones(len(chrNames), dtype=float)
    chrLengths = np.array(df_meta['RefEnd0'].values) - np.array(df_meta['RefStart0'].values)
    sampleNames = np.array(df_meta.columns[7:]) # Skip everything in title line up to relativeRecRates
    
    ploidyByChr = []
    for chr in chrNames:
        row = df_meta[df_meta['#Chrom'] == chr]
        ploidy = np.array(row.iloc[0,6:].values, dtype=int)
        ploidyByChr.append(ploidy)
    
    # Fast preamble reading - same as before
    preamble = []
    nSkipLines = 0
    individualsMasked = None
    with open(bed_file_path, 'r') as f:
        for line in f:
            if line.startswith('##'):
                preamble.append(line.strip())
                if line.startswith('##IndividualsMasked='):
                    clean_line = line.strip().removeprefix('##IndividualsMasked=')
                    if clean_line == 'None':
                        individualsMasked = None
                    else:
                        individualsMasked = clean_line.split(',')
                nSkipLines += 1
            else:
                break
    
    # Determine column names
    if len(preamble) > 0:
        hasPolarity = True
        column_names = [
            'chrom', 'start', 'end', 'qual', 'ref', 
            'SeqAlleles', 'SNV', 'nVNTs', 
            'exclusion_criterion', 'diem_genotype','nullPolarity','polarity',
            'DI','Support','masked'
        ]
    else:
        hasPolarity = False
        column_names = [
            'chrom', 'start', 'end', 'qual', 'ref', 
            'SeqAlleles', 'SNV', 'nVNTs', 
            'exclusion_criterion', 'diem_genotype'
        ]
    
    # Read the entire BED file at once
    df_bed = pd.read_csv(bed_file_path, sep='\t', names=column_names, skiprows=nSkipLines+1)

    # Polarise - these last lines SJEB
                
    if hasPolarity:
        #print('updating genotype polarities')
        mask = df_bed['polarity'] == 1
        df_bed.loc[mask, 'diem_genotype'] = df_bed.loc[mask, 'diem_genotype'].apply(StringReplace20)
        
    return df_bed,sampleNames,chrLengths,chrRelativeRecRates


def get_DI_span(aDT):
    """
    Get the min and max DI values across all chromosomes.

    Args:   aDT : DiemType
    """
    minDI=float('inf')
    maxDI=float('-inf')
    for idx, chr  in enumerate(aDT.DIByChr):
        minDI=min(minDI,min(aDT.DIByChr[idx]))
        maxDI=max(maxDI,max(aDT.DIByChr[idx]))
    return [minDI,maxDI]




"""______________________START statewise_genomes_summary_given_DI______________________________"""



def _statewise_summary_one_chr(SM, DI, ploidies, DIthreshold):
    """
    Compute statewise summary for ONE chromosome.
    Returns (counts_dict, (RetainedNumer, RetainedDenom))

    DMBCtoucher

    This is a helper for PAR_statewise_genomes_summary_given_DI.
    """

    nInds = SM.shape[0]

    DIfilter = DI >= DIthreshold
    RetainedNumer = int(np.count_nonzero(DIfilter))
    RetainedDenom = int(DIfilter.size)

    if RetainedNumer == 0:
        return (
            {
                "counts0": np.zeros(nInds, dtype=float),
                "counts1": np.zeros(nInds, dtype=float),
                "counts2": np.zeros(nInds, dtype=float),
                "counts3": np.zeros(nInds, dtype=float),
            },
            (RetainedNumer, RetainedDenom),
        )

    SMf = SM[:, DIfilter]

    # Vectorised masks
    is0 = (SMf == 0)
    is1 = (SMf == 1)
    is2 = (SMf == 2)
    is3 = ~(is0 | is1 | is2)

    counts0 = is0.sum(axis=1)
    counts1 = is1.sum(axis=1)
    counts2 = is2.sum(axis=1)
    counts3 = is3.sum(axis=1)

    w = ploidies.astype(float)

    return (
        {
            "counts0": w * counts0,
            "counts1": w * counts1,
            "counts2": w * counts2,
            "counts3": w * counts3,
        },
        (RetainedNumer, RetainedDenom),
    )


def PAR_statewise_genomes_summary_given_DI(
    aDT,
    DIthreshold: float,
    n_jobs=-1,
    backend="loky",
):
    """
    Parallel version of statewise_genomes_summary_given_DI.

    DMBCtoucher

    Parallelised over chromosomes.
    """

    nChr = len(aDT.DMBC)

    results = Parallel(
        n_jobs=n_jobs,
        backend=backend,
        prefer="processes",
        batch_size=1,
    )(
        delayed(_statewise_summary_one_chr)(
            aDT.DMBC[i],
            aDT.DIByChr[i],
            aDT.chrPloidies[i],
            DIthreshold,
        )
        for i in range(nChr)
    )

    chrom_counts = [r[0] for r in results]
    chrom_retained = [r[1] for r in results]

    return chrom_counts, chrom_retained



def SER_statewise_genomes_summary_given_DI(aDT, DIthreshold: float):
    """
    Statewise summary of genomes under a DI threshold.

    DMBCtoucher

    Refinements over genomes_summary_given_DI:
      1) counts3 = count of states NOT in {0,1,2}
      2) returns per-chromosome per-individual state counts
      3) returns per-chromosome retained counts

    Args
    ----
    aDT : DiemType
        Must provide:
          - DMBC : list of arrays, each shape (nInds, nSites_chr)
          - DIByChr : list of 1D arrays, per chromosome
          - chrPloidies : list of per-individual ploidies
    DIthreshold : float
        DI filter threshold

    Returns
    -------
    chrom_counts : list of dicts, length nChr
        Each dict has keys:
          'counts0', 'counts1', 'counts2', 'counts3'
        Each value is a float array of shape (nInds,)
        Counts are ploidy-weighted.
    chrom_retained : list of tuples, length nChr
        Each element is (RetainedNumer_chr, RetainedDenom_chr)
    """

    nChr = len(aDT.DMBC)
    nInds = aDT.DMBC[0].shape[0]

    chrom_counts = []
    chrom_retained = []

    for chr_idx in range(nChr):
        SM = aDT.DMBC[chr_idx]               # (nInds, nSites)
        DI = aDT.DIByChr[chr_idx]            # (nSites,)
        ploidies = aDT.chrPloidies[chr_idx]  # (nInds,)

        DIfilter = DI >= DIthreshold

        RetainedNumer = int(np.count_nonzero(DIfilter))
        RetainedDenom = int(DIfilter.size)
        chrom_retained.append((RetainedNumer, RetainedDenom))

        if RetainedNumer == 0:
            chrom_counts.append({
                "counts0": np.zeros(nInds, dtype=float),
                "counts1": np.zeros(nInds, dtype=float),
                "counts2": np.zeros(nInds, dtype=float),
                "counts3": np.zeros(nInds, dtype=float),
            })
            continue

        SMf = SM[:, DIfilter]  # (nInds, nRetained)

        # Vectorised state masks
        is0 = (SMf == 0)
        is1 = (SMf == 1)
        is2 = (SMf == 2)

        # counts3 = NOT in {0,1,2}
        is3 = ~(is0 | is1 | is2)

        # Per-individual counts
        counts0 = is0.sum(axis=1)
        counts1 = is1.sum(axis=1)
        counts2 = is2.sum(axis=1)
        counts3 = is3.sum(axis=1)

        # Apply ploidy weights once
        w = ploidies.astype(float)

        chrom_counts.append({
            "counts0": w * counts0,
            "counts1": w * counts1,
            "counts2": w * counts2,
            "counts3": w * counts3,
        })

    return chrom_counts, chrom_retained


def statewise_genomes_summary_given_DI(aDT, DIthreshold: float):
    if len(aDT.DMBC) >= 6:
        return PAR_statewise_genomes_summary_given_DI(aDT, DIthreshold)
    else:
        return SER_statewise_genomes_summary_given_DI(aDT, DIthreshold)
"""______________________END statewise_genomes_summary_given_DI______________________________"""


def summaries_from_statewise_counts(chrom_counts):
    """
    Compute [HI, HOM1, HET, HOM2, U] from statewise chrom_counts.

    Args: chrom_counts: iterable of dicts with keys
        'counts0', 'counts1', 'counts2', 'counts3'
        (arrays of shape nInds)

    Returns:
        summaries = [HI, HOM1, HET, HOM2, U]
    """

    A0 = sum(c["counts0"] for c in chrom_counts)
    A1 = sum(c["counts1"] for c in chrom_counts)
    A2 = sum(c["counts2"] for c in chrom_counts)
    A3 = sum(c["counts3"] for c in chrom_counts)

    dipDenom   = A1 + A2 + A3
    hapDenom   = 2 * dipDenom
    stateDenom = dipDenom + A0

    HI = np.divide((A2 + 2*A3), hapDenom, out=np.zeros_like(A2, dtype=float), where=(hapDenom != 0))
    HOM1 = np.divide(A1, dipDenom, out=np.zeros_like(A1, dtype=float), where=(dipDenom != 0))
    HET  = np.divide(A2, dipDenom, out=np.zeros_like(A2, dtype=float), where=(dipDenom != 0))
    HOM2 = np.divide(A3, dipDenom, out=np.zeros_like(A3, dtype=float), where=(dipDenom != 0))
    U    = np.divide(A0, stateDenom, out=np.zeros_like(A0, dtype=float), where=(stateDenom != 0))

    return [HI, HOM1, HET, HOM2, U]



def genomes_summary_given_DI(aDT, DIthreshold: float):
    """
    Summarises a diemType.DMBC across chromosomes, applying a DI threshold filter.
    Uses statewise_genomes_summary_given_DI for efficiency and single-point-of-entry to DMBC.

    Args:
    aDT: a DiemType
    DIthreshold: DI threshold filter
    Returns:
    summaries: list of per-individual summary arrays [HI, HOM1, HET, HOM2, U]
    RetainedNumer: number of retained sites after DI filtering
    RetainedDenom: total number of sites before DI filtering
    """
    chrom_counts, chrom_retained = \
        statewise_genomes_summary_given_DI(aDT, DIthreshold)

    summaries = summaries_from_statewise_counts(chrom_counts)

    DInumer = sum(n for n, _ in chrom_retained)
    DIdenom = sum(d for _, d in chrom_retained)

    return summaries, DInumer, DIdenom



def fractional_positions_of_multiples(A, delta):
    """
    Calculate fractional positions of multiples of delta in a sorted array A.

    Used as an inverse linear interpolation from reference (physical) positions to site indices;
    for example, to place ticks at regular physical intervals on a plot of site indices.

    Args:
        A (array-like): Sorted array of values (e.g. physical positions of DI filtered SNVs).
        delta (float): The interval for multiples (e.g. 1000 for kb ticks).
    Returns:
        np.ndarray: Array of (value (tick label), position (tick placement on SNV metric)) pairs.
    """

    A = np.asarray(A)
    n = len(A)

    values = []
    positions = []

    max_k = A[-1] // delta

    for k in range(1, max_k + 1):
        x = k * delta
        i = bisect.bisect_left(A, x)

        # skip values below the first element
        if i == 0:
            continue

        # exact match
        if i < n and A[i] == x:
            pos = float(i)
        else:
            left, right = A[i - 1], A[i]
            pos = (i - 1) + (x - left) / (right - left)

        values.append(x)
        positions.append(pos)
    tick_values = np.array(values)/delta
    tick_positions = np.array(positions)+1
    return np.column_stack((tick_values, tick_positions))
""" new support by SJEB ENDING"""


diemColours = [
    'white',
    mcolors.to_hex((128/255, 0, 128/255)),  # RGBColor[128/255, 0, 128/255] - Purple
    mcolors.to_hex((255/255, 229/255, 0)),  # RGBColor[255/255, 229/255, 0] - Yellow
    mcolors.to_hex((0, 128/255, 128/255))   # RGBColor[0, 128/255, 128/255] - Teal
]


"""________________________________________ cache-ing helpers ___________________"""


# Types matching existing API
ChromCounts = List[dict]                 # length nChr, dict with counts0..counts3 arrays (nInds,)
ChromRetained = List[Tuple[int, int]]    # length nChr, (kept, total)


@dataclass
class _ChrIncrementalState:
    """Internal mutable state for one chromosome during cache prefill."""
    order_asc: np.ndarray          # (nSites,) indices sorting DI ascending
    di_sorted: np.ndarray          # (nSites,) DI in ascending order
    k: int                         # current start index of retained suffix in di_sorted
    c0: np.ndarray                 # (nInds,) int counts (unweighted)
    c1: np.ndarray                 # (nInds,)
    c2: np.ndarray                 # (nInds,)
    c3: np.ndarray                 # (nInds,)
    ploidy_w: np.ndarray           # (nInds,) float weights


class StatewiseDIIncrementalCache:
    """
    Prefills snapshots for statewise_genomes_summary_given_DI for a fixed DI grid.

    Key property:
      - Across the whole prefill, each site is incorporated at most once (incremental).
      - No giant per-site prefix arrays are stored.
      - Snapshots are stored as ploidy-weighted float arrays (same as your API).
    """

    def __init__(
        self,
        aDT,
        di_grid: np.ndarray,
        progress: Optional[str] = None,   # None | "text"
        label: str = "Statewise cache prefill",
    ):
        self.aDT = aDT
        self.di_grid = np.asarray(di_grid, dtype=float)
        self.progress = progress
        self.label = label

        # Snapshots keyed by exact grid value (float)
        self._snapshots: Dict[float, Tuple[ChromCounts, ChromRetained]] = {}

        # Precompute per-chromosome incremental state
        self._chr_states: List[_ChrIncrementalState] = self._init_chr_states()

    def _init_chr_states(self) -> List[_ChrIncrementalState]:
        nChr = len(self.aDT.DMBC)
        nInds = self.aDT.DMBC[0].shape[0]
        states: List[_ChrIncrementalState] = []

        for chr_idx in range(nChr):
            DI = np.asarray(self.aDT.DIByChr[chr_idx], dtype=float)
            order = np.argsort(DI, kind="mergesort")  # stable, asc
            di_sorted = DI[order]

            ploidy_w = np.asarray(self.aDT.chrPloidies[chr_idx], dtype=float)

            states.append(
                _ChrIncrementalState(
                    order_asc=order,
                    di_sorted=di_sorted,
                    k=di_sorted.size,  # start with none retained (threshold > max)
                    c0=np.zeros(nInds, dtype=np.int64),
                    c1=np.zeros(nInds, dtype=np.int64),
                    c2=np.zeros(nInds, dtype=np.int64),
                    c3=np.zeros(nInds, dtype=np.int64),
                    ploidy_w=ploidy_w,
                )
            )
        return states

    def prefill(self):
        """
        Fill snapshots for all DI values in self.di_grid.

        We iterate thresholds from HIGH → LOW so the retained set only grows,
        letting us add newly-included sites once.
        """
        t0 = time.time()

        # We prefill in descending DI so retained grows monotonically.
        grid_desc = np.array(sorted(self.di_grid, reverse=True), dtype=float)
        n_steps = grid_desc.size

        for step_i, thr in enumerate(grid_desc, start=1):
            chrom_counts: ChromCounts = []
            chrom_retained: ChromRetained = []

            for chr_idx, st in enumerate(self._chr_states):
                nSites = st.di_sorted.size
                nInds = st.c0.size

                # New retained suffix start index for DI >= thr
                k_new = int(np.searchsorted(st.di_sorted, thr, side="left"))

                # Add sites that become newly retained: [k_new : st.k)
                if k_new < st.k:
                    # indices in original site order
                    idxs = st.order_asc[k_new:st.k]
                    SM = self.aDT.DMBC[chr_idx]  # (nInds, nSites_chr)

                    # Take the block (this allocates a temporary block array)
                    block = np.take(SM, idxs, axis=1)

                    # Count states in the block
                    b0 = (block == 0).sum(axis=1)
                    b1 = (block == 1).sum(axis=1)
                    b2 = (block == 2).sum(axis=1)
                    # everything else -> state3
                    b3 = block.shape[1] - (b0 + b1 + b2)

                    st.c0 += b0
                    st.c1 += b1
                    st.c2 += b2
                    st.c3 += b3

                    st.k = k_new  # retained suffix starts earlier now

                kept = nSites - st.k
                chrom_retained.append((int(kept), int(nSites)))

                w = st.ploidy_w
                chrom_counts.append({
                    "counts0": w * st.c0.astype(float),
                    "counts1": w * st.c1.astype(float),
                    "counts2": w * st.c2.astype(float),
                    "counts3": w * st.c3.astype(float),
                })

            # Store snapshot under exact threshold value
            self._snapshots[float(thr)] = (chrom_counts, chrom_retained)

            if self.progress == "text":
                elapsed = time.time() - t0
                pct = int(round(100.0 * step_i / n_steps))
                print(f"{self.label}: {step_i}/{n_steps} ({pct}%)  elapsed {elapsed:.1f}s")

        return self

    def get(self, DIthreshold: float) -> Tuple[ChromCounts, ChromRetained]:
        """
        Retrieve nearest snapshot for arbitrary DIthreshold.
        """
        thr = float(DIthreshold)

        # Find nearest grid value
        grid = self.di_grid
        j = int(np.argmin(np.abs(grid - thr)))
        key = float(grid[j])

        # NOTE: snapshots stored by exact float of grid values; should exist if prefilling done
        return self._snapshots[key]


@dataclass
class _ChrPWState:
    order_asc: np.ndarray   # indices sorting DI ascending
    di_sorted: np.ndarray   # DI ascending
    k: int                  # current suffix start index (retained = [k:])
    SM: np.ndarray          # (n_ind, n_sites_chr) numeric codes, clamped to 0..3


def _pw_weight_matrix():
    W = np.zeros((4, 4), dtype=np.float32)
    W[1, 2] = W[2, 1] = 1
    W[1, 3] = W[3, 1] = 2
    W[2, 2] = 1
    W[2, 3] = W[3, 2] = 1
    return W


class PairwiseDIIncrementalCache:
    """
    Incremental prefill cache for PARApwmatrixFromDiemType-like distance matrices
    over a fixed DI grid.

    - Prefill iterates DI thresholds HIGH → LOW so retained set grows.
    - Each site is incorporated at most once across the entire prefill.
    - Snapshots store M (num/den) for each DI grid point.
    """

    def __init__(
        self,
        aDT,
        di_grid: np.ndarray,
        *,
        chrom_indices: Optional[Sequence[int]] = None,
        progress: Optional[str] = None,      # None | "text"
        label: str = "Pairwise cache prefill",
        snapshot_dtype=np.float32,           # store snapshots as float32 to reduce memory
    ):
        self.aDT = aDT
        self.di_grid = np.asarray(di_grid, dtype=float)
        self.progress = progress
        self.label = label
        self.snapshot_dtype = snapshot_dtype

        self.n_ind = int(aDT.DMBC[0].shape[0])
        self.W = _pw_weight_matrix()

        if chrom_indices is None:
            chrom_indices = range(len(aDT.DMBC))
        self.chrom_indices = [int(i) for i in chrom_indices]

        self._chr_states = self._init_chr_states()
        self._snapshots: Dict[float, np.ndarray] = {}

        # Global incremental accumulators across ALL chosen chromosomes
        self._num = np.zeros((self.n_ind, self.n_ind), dtype=np.float64)
        self._den = np.zeros((self.n_ind, self.n_ind), dtype=np.int64)
        self._diag_num = np.zeros(self.n_ind, dtype=np.float64)
        self._diag_den = np.zeros(self.n_ind, dtype=np.int64)

    def _init_chr_states(self):
        states = []
        for chr_idx in self.chrom_indices:
            DI = np.asarray(self.aDT.DIByChr[chr_idx], dtype=float)
            order = np.argsort(DI, kind="mergesort")   # asc
            di_sorted = DI[order]

            # Clamp states: >2 -> 3 (future-proofing) ; keep 0 as missing
            #SM = np.minimum(self.aDT.DMBC[chr_idx], 3).astype(np.int8)
            # the clamp is wrong. Fix it: we want 0 to stay 0, and everything else >2 to become 3
            SM0 = self.aDT.DMBC[chr_idx]
            SM = np.where((SM0 >= 0) & (SM0 <= 3), SM0, 0).astype(np.int8)

            states.append(_ChrPWState(
                order_asc=order,
                di_sorted=di_sorted,
                k=di_sorted.size,  # start with none retained (thr > max)
                SM=SM,
            ))
        return states

    def _add_site_block(self, SM, idxs):
        W = self.W

        for s in idxs:
            col = SM[:, s]          # (n_ind,)
            valid = (col != 0)
            idx = np.nonzero(valid)[0]
            m = idx.size
            if m == 0:
                continue

            vals = col[idx].astype(np.int64)

            # ---- diagonal accumulators (within-individual) ----
            # valid site contributes to denominator
            self._diag_den[idx] += 1
            # HET contributes 1 unit (later /2 -> 0.5)
            self._diag_num[idx] += (vals == 2)

            # ---- off-diagonal accumulators ----
            if m < 2:
                continue

            ww = W[vals[:, None], vals[None, :]].astype(np.float64)

            ones = np.ones((m, m), dtype=np.int64)
            np.fill_diagonal(ones, 0)

            ix = np.ix_(idx, idx)
            self._num[ix] += ww
            self._den[ix] += ones


    def prefill(self):
        t0 = time.time()
        grid_desc = np.array(sorted(self.di_grid, reverse=True), dtype=float)
        n_steps = grid_desc.size

        for step_i, thr in enumerate(grid_desc, start=1):
            # Update each chromosome incremental state, and add newly-retained sites
            for st in self._chr_states:
                # retained set is DI >= thr => suffix start k_new
                k_new = int(np.searchsorted(st.di_sorted, thr, side="left"))

                if k_new < st.k:
                    # newly included sites are [k_new : st.k) in DI-sorted order
                    idxs = st.order_asc[k_new:st.k]
                    self._add_site_block(st.SM, idxs)
                    st.k = k_new

            # Snapshot matrix for this threshold
            M = np.full((self.n_ind, self.n_ind), np.nan, dtype=np.float64)

            mask = self._den > 0
            M[mask] = self._num[mask] / (2.0 * self._den[mask])

            # fill diagonal with within-individual heterozygosity-like distance
            dmask = self._diag_den > 0
            diag = np.full(self.n_ind, np.nan, dtype=np.float64)
            diag[dmask] = self._diag_num[dmask] / (2.0 * self._diag_den[dmask])
            np.fill_diagonal(M, diag)

            self._snapshots[float(thr)] = M.astype(self.snapshot_dtype, copy=False)


            if self.progress == "text":
                elapsed = time.time() - t0
                pct = int(round(100.0 * step_i / n_steps))
                print(f"{self.label}: {step_i}/{n_steps} ({pct}%)  elapsed {elapsed:.1f}s")

        return self

    def get(self, DIthreshold: float) -> np.ndarray:
        """
        Return nearest snapshot matrix to DIthreshold.
        """
        thr = float(DIthreshold)
        grid = self.di_grid
        j = int(np.argmin(np.abs(grid - thr)))
        key = float(grid[j])
        return self._snapshots[key]

"""________________________________________ START GenomeSummariesPlot ___________________"""


[docs] class GenomeSummaryPlot: """ Plots genome summaries with DI filtering and interactive widgets. These summaries include HI, HOM1, HET, HOM2, and U proportions per individual. Cursor hover displays individual IDs. Reorder button sorts individuals by HI given current DI filter. Args: dPol: DiemType object containing genomic data. Drop-in extension: - optional cache prefill over a DI grid (text progress) - DI slider uses cached results (nearest match within tol) - PREFILL uses StatewiseDIIncrementalCache (incremental, sorted-DI sweep) """ def __init__( self, dPol, *, prefill_cache=False, # NEW prefill_step=None, # NEW cache_tol=None, # NEW progress=None, # NEW: "text" | "none" ): self.dPol = dPol # ---- initial state ---- self.IndNickNames = [Ind_Nickname(name) for name in dPol.indNames] self.indNameFont = 6 self.indHIorder = np.arange(len(dPol.indNames)) # ---- cache (per instance) ---- # cache maps DI_value(float) -> (summaries, chrom_retained) self._cache = {} self._cache_keys_sorted = [] def _cache_set(di, value): di = float(di) if di not in self._cache: bisect.insort(self._cache_keys_sorted, di) self._cache[di] = value def _cache_get_nearest(di, tol): """Return cached payload for nearest DI within tol, else None.""" if not self._cache_keys_sorted: return None di = float(di) keys = self._cache_keys_sorted j = bisect.bisect_left(keys, di) candidates = [] if 0 <= j < len(keys): candidates.append(keys[j]) if 0 <= j - 1 < len(keys): candidates.append(keys[j - 1]) best = None best_dist = None for k in candidates: dist = abs(k - di) if best_dist is None or dist < best_dist: best = k best_dist = dist if best is None or best_dist is None or best_dist > tol: return None return self._cache[best] self._cache_set = _cache_set self._cache_get_nearest = _cache_get_nearest # ---- initial summaries (no filter) ---- self.chrom_counts, self.chrom_retained = statewise_genomes_summary_given_DI( self.dPol, float("-inf") ) self.summaries = summaries_from_statewise_counts(self.chrom_counts) self.DInumer = sum(n for n, _ in self.chrom_retained) self.DIdenom = sum(d for _, d in self.chrom_retained) self.prop = self.DInumer / self.DIdenom if self.DIdenom else 0.0 # ---- figure & axes ---- self.fig, self.ax = plt.subplots(figsize=(11, 4)) colours = Flatten([['red'], diemColours[1:], ['gray']]) self.lines = [] for summary, colour in zip(self.summaries, colours): line, = self.ax.plot(summary, color=colour, marker='.') self.lines.append(line) self.ax.legend(['HI', 'HOM1', 'HET', 'HOM2', 'U']) self.ax.set_ylim(0, 1) self.ax.set_title('Genomes summaries; no DI filter') self.ax.tick_params(axis='x', rotation=55) self._update_xticks() # ---- widgets ---- self._init_widgets() # ---- OPTIONAL: prefill cache grid (INCREMENTAL helper) ---- if prefill_cache: DI_span = get_DI_span(self.dPol) di_min, di_max = float(DI_span[0]), float(DI_span[1]) if prefill_step is None: span = di_max - di_min prefill_step = span / 200.0 if span > 0 else 1.0 if cache_tol is None: cache_tol = float(prefill_step) / 2.0 # Build DI grid including endpoints if di_max > di_min: n_steps = int(np.floor((di_max - di_min) / prefill_step)) di_values = [di_min + k * prefill_step for k in range(n_steps + 1)] if not di_values or di_values[-1] < di_max: di_values.append(di_max) else: di_values = [di_min] di_grid = np.asarray(di_values, dtype=float) # --- incremental cache prefill --- inc = StatewiseDIIncrementalCache( self.dPol, di_grid=di_grid, progress=("text" if progress == "text" else None), label="GenomeSummary cache prefill", ).prefill() # Convert chrom_counts snapshots into summaries once, and store # NOTE: inc._snapshots keys are floats from the grid (descending fill, but dict unordered) for di_key, (chrom_counts, chrom_retained) in inc._snapshots.items(): summaries = summaries_from_statewise_counts(chrom_counts) self._cache_set(di_key, (summaries, chrom_retained)) self._cache_tol = float(cache_tol) else: self._cache_tol = 0.0 # ---- coordinate display ---- self._install_format_coord() plt.show() # ---------------- helpers ---------------- def _update_xticks(self): self.ax.set_xticks( np.arange(len(self.IndNickNames)), np.array(self.IndNickNames)[self.indHIorder], rotation=55, fontsize=self.indNameFont, horizontalalignment='right' ) def _install_format_coord(self): n = len(self.dPol.indNames) tolerance = 0.03 # vertical proximity in y-units def format_coord(x, y): fallback = "\u2007" * 30 i = int(round(x)) if i < 0 or i >= n: return fallback for summary in self.summaries: y0 = summary[self.indHIorder][i] if abs(y - y0) < tolerance: return f"IndID: {self.dPol.indNames[self.indHIorder[i]]}" return fallback self.ax.format_coord = format_coord # ---------------- widgets ---------------- def _init_widgets(self): DI_span = get_DI_span(self.dPol) self.fig.subplots_adjust(bottom=0.3) # DI slider DI_box = self.fig.add_axes([0.2, 0.1, 0.65, 0.03]) self.DI_slider = Slider( ax=DI_box, label='DI', valmin=DI_span[0], valmax=DI_span[1], valinit=DI_span[0], ) self.DI_slider.on_changed(self.DIupdate) # Font slider FONT_box = self.fig.add_axes([0.25, 0.025, 0.1, 0.04]) self.FONT_slider = Slider( ax=FONT_box, label='IndLabels font', valmin=1, valmax=16, valinit=self.indNameFont, ) self.FONT_slider.on_changed(self.FONTupdate) # Reorder button reorderBox = self.fig.add_axes([0.8, 0.025, 0.1, 0.04]) self.reo_button = Button( reorderBox, 'Reorder by HI', hovercolor='0.975', color='red' ) self.reo_button.on_clicked(self.reorder) # ---------------- callbacks ---------------- def DIupdate(self, val): payload = None if self._cache_tol > 0: payload = self._cache_get_nearest(val, self._cache_tol) if payload is None: # lazy compute + store chrom_counts, chrom_retained = statewise_genomes_summary_given_DI(self.dPol, val) summaries = summaries_from_statewise_counts(chrom_counts) payload = (summaries, chrom_retained) self._cache_set(val, payload) self.summaries, self.chrom_retained = payload self.DInumer = sum(n for n, _ in self.chrom_retained) self.DIdenom = sum(d for _, d in self.chrom_retained) self.prop = self.DInumer / self.DIdenom if self.DIdenom else 0.0 self.ax.set_title( "Genomes summaries DI ≥ {:.2f} {} SNVs ({:.1f}% divergent across barrier)" .format(val, self.DInumer, 100 * self.prop) ) for line, summary in zip(self.lines, self.summaries): line.set_ydata(summary[self.indHIorder]) self.fig.canvas.draw_idle() def reorder(self, event): self.indHIorder = np.argsort(self.summaries[0]) # HI self._update_xticks() for line, summary in zip(self.lines, self.summaries): line.set_ydata(summary[self.indHIorder]) self.fig.canvas.draw_idle() def FONTupdate(self, val): self.indNameFont = int(val) self._update_xticks() self.fig.canvas.draw_idle()
#________________________________________ END GenomeSummariesPlot ___________________ #________________________________________ START GenomeMultiSummaryPlot ___________________
[docs] class GenomeMultiSummaryPlot: """ Plots genome summaries per chromosome with DI filtering and interactive widgets. These summaries include HI, HOM1, HET, HOM2, and U proportions per individual. Cursor hover displays individual IDs. Reorder button sorts individuals by global HI given current DI filter. Drop-in extension: - optional incremental cache prefill using StatewiseDIIncrementalCache - DI slider uses cached results (nearest match within tol) - keeps existing hover behaviour and plot style Args: dPol: DiemType object containing genomic data. chrom_indices: List of chromosome indices to plot. max_cols: max subplot columns. prefill_cache: precompute incremental cache over DI grid. prefill_step: DI step for grid (defaults to span/200). cache_tol: nearest-cache tolerance (defaults to prefill_step/2). progress: "text" | "none" """ def __init__( self, dPol, chrom_indices, max_cols=3, *, prefill_cache=False, prefill_step=None, cache_tol=None, progress=None,# "text" | None ): self.dPol = dPol self.IndNickNames = [Ind_Nickname(name) for name in dPol.indNames] self.ChrNickNames = [Chr_Nickname(name) for name in dPol.chrNames] self.max_cols = max_cols # ---- validate chromosomes ---- self.chrom_indices = self._validate_chrom_indices(chrom_indices) # ---- ordering state ---- self.indNameFont = 6 # ---- cache (per instance) ---- # cache maps DI_value(float) -> payload # payload = (global_summaries, chrom_retained, per_chr_summaries_dict) self._cache = {} self._cache_keys_sorted = [] def _cache_set(di, payload): di = float(di) if di not in self._cache: bisect.insort(self._cache_keys_sorted, di) self._cache[di] = payload def _cache_get_nearest(di, tol): if not self._cache_keys_sorted: return None di = float(di) keys = self._cache_keys_sorted j = bisect.bisect_left(keys, di) candidates = [] if 0 <= j < len(keys): candidates.append(keys[j]) if 0 <= j - 1 < len(keys): candidates.append(keys[j - 1]) best = None best_dist = None for k in candidates: dist = abs(k - di) if best_dist is None or dist < best_dist: best = k best_dist = dist if best is None or best_dist is None or best_dist > tol: return None return self._cache[best] self._cache_set = _cache_set self._cache_get_nearest = _cache_get_nearest # ---- initial DI snapshot (no filter) ---- self.chrom_counts, self.chrom_retained = statewise_genomes_summary_given_DI( self.dPol, float("-inf") ) # authoritative whole-genome summaries (from statewise counts) global_summaries = summaries_from_statewise_counts(self.chrom_counts) self.global_HI = global_summaries[0] self.indHIorder = np.argsort(self.global_HI) # build per-chromosome summaries once (initial DI) self.chrom_summaries = {} for idx in self.chrom_indices: self.chrom_summaries[idx] = summaries_from_statewise_counts([self.chrom_counts[idx]]) # ---- grid layout ---- n_plots = len(self.chrom_indices) n_cols = min(self.max_cols, n_plots) n_rows = (n_plots + n_cols - 1) // n_cols fig_w = 4.5 * n_cols fig_h = 3.5 * n_rows self.fig, self.axes = plt.subplots( n_rows, n_cols, figsize=(fig_w, fig_h), squeeze=False, sharey=True ) self.fig.subplots_adjust( left=0.06, right=0.98, top=0.92, bottom=0.45, hspace=0.60, wspace=0.25 ) # ---- draw plots ---- self.lines = {} axes_flat = self.axes.flatten() colours = Flatten([['red'], diemColours[1:], ['gray']]) global_hi_colour = "cyan" self.chrom_axes = {} self.global_hi_lines = {} for ax, chrom_idx in zip(axes_flat, self.chrom_indices): self.chrom_axes[chrom_idx] = ax summaries = self.chrom_summaries[chrom_idx] chrom_lines = [] for summary, colour in zip(summaries, colours): line, = ax.plot( summary[self.indHIorder], color=colour, marker='.', linewidth=0.8 ) chrom_lines.append(line) self.lines[chrom_idx] = chrom_lines # global HI overlay global_hi_line, = ax.plot( self.global_HI[self.indHIorder], color=global_hi_colour, linestyle="-", linewidth=1.5, alpha=0.8, ) self.global_hi_lines[chrom_idx] = global_hi_line ax.set_ylim(0, 1) num, denom = self.chrom_retained[chrom_idx] ax.set_title(f"{self.ChrNickNames[chrom_idx]} | {num:,}/{denom:,} sites", fontsize=10) ax.tick_params(axis='x', rotation=55) ax.set_xticks( np.arange(len(self.IndNickNames)), np.array(self.IndNickNames)[self.indHIorder], fontsize=self.indNameFont, ha='right' ) # hide unused axes for ax in axes_flat[len(self.chrom_indices):]: ax.axis("off") # legend once axes_flat[0].legend( ['HIc', 'HOM1', 'HET', 'HOM2', 'U', 'HIg'], fontsize=8, frameon=False ) # ---- widgets ---- self._init_widgets() # Force an initial DI computation at the *max* endpoint (or whatever value you want) EURG! #DI_span = get_DI_span(self.dPol) #self._on_DI_change(float(DI_span[1])) # ---- OPTIONAL: incremental cache prefill ---- if prefill_cache: DI_span = get_DI_span(self.dPol) di_min, di_max = float(DI_span[0]), float(DI_span[1]) if prefill_step is None: span = di_max - di_min prefill_step = span / 200.0 if span > 0 else 1.0 if cache_tol is None: cache_tol = float(prefill_step) / 2.0 # Build DI grid including endpoints if di_max > di_min: n_steps = int(np.floor((di_max - di_min) / prefill_step)) di_grid = di_min + prefill_step * np.arange(n_steps + 1, dtype=float) di_grid = np.clip(di_grid, di_min, di_max) di_grid[-1] = di_max else: di_grid = np.array([di_min], dtype=float) #di_grid = np.asarray(di_values, dtype=float) inc = StatewiseDIIncrementalCache( self.dPol, di_grid=di_grid, progress=("text" if progress == "text" else None), label="GenomeMultiSummary cache prefill", ).prefill() # Store only what this plot needs: # - global summaries (all chr) # - chrom_retained (all chr) # - per-chr summaries for selected chrom_indices for di_key, (chrom_counts, chrom_retained) in inc._snapshots.items(): global_summaries = summaries_from_statewise_counts(chrom_counts) per_chr = {} for idx in self.chrom_indices: per_chr[idx] = summaries_from_statewise_counts([chrom_counts[idx]]) self._cache_set(di_key, (global_summaries, chrom_retained, per_chr)) self._cache_tol = float(cache_tol) else: self._cache_tol = 0.0 # ---- coordinate display ---- self._install_format_coord() plt.show() # ================================================== # Validation # ================================================== def _validate_chrom_indices(self, chrom_indices): max_idx = len(self.dPol.chrLengths) - 1 valid, rejected = [], [] for idx in chrom_indices: if isinstance(idx, (int, np.integer)) and 0 <= int(idx) <= max_idx: valid.append(int(idx)) else: rejected.append(idx) if rejected: print("GenomeMultiSummaryPlot: rejected chromosome indices:", rejected) if not valid: raise ValueError("GenomeMultiSummaryPlot: no valid chromosome indices.") return valid # ================================================== # Hover # ================================================== def _install_format_coord(self): n = len(self.dPol.indNames) tolerance = 0.03 axes_flat = self.axes.flatten()[:len(self.chrom_indices)] for ax, chrom_idx in zip(axes_flat, self.chrom_indices): chrom_lines = self.lines[chrom_idx] def make_format_coord(chrom_lines_local): def format_coord(x, y): fallback = "\u2007" * 30 i = int(round(x)) if i < 0 or i >= n: return fallback for line in chrom_lines_local: ydata = line.get_ydata() if abs(y - ydata[i]) < tolerance: return f"IndID: {self.dPol.indNames[self.indHIorder[i]]}" return fallback return format_coord ax.format_coord = make_format_coord(chrom_lines) # ================================================== # Widgets # ================================================== def _init_widgets(self): DI_span = get_DI_span(self.dPol) ax_DI = self.fig.add_axes([0.15, 0.18, 0.7, 0.03]) self.DI_slider = Slider(ax_DI, "DI", DI_span[0], DI_span[1], valinit=DI_span[0]) self.DI_slider.on_changed(self._on_DI_change) ax_FS = self.fig.add_axes([0.25, 0.12, 0.1, 0.03]) self.FONT_slider = Slider(ax_FS, "IndLabel font", 4, 16, valinit=self.indNameFont, valstep=1) self.FONT_slider.on_changed(self._on_font_change) ax_RE = self.fig.add_axes([0.75, 0.115, 0.15, 0.045]) self.reorder_button = Button( ax_RE, "Reorder by global HI", hovercolor="0.95", color="cyan" ) self.reorder_button.on_clicked(self._on_reorder) # ================================================== # Callbacks # ================================================== def _on_DI_change(self, val): payload = None if self._cache_tol > 0: payload = self._cache_get_nearest(val, self._cache_tol) if payload is None: # Lazy compute full statewise summary once chrom_counts, chrom_retained = statewise_genomes_summary_given_DI(self.dPol, val) global_summaries = summaries_from_statewise_counts(chrom_counts) per_chr = {} for idx in self.chrom_indices: per_chr[idx] = summaries_from_statewise_counts([chrom_counts[idx]]) payload = (global_summaries, chrom_retained, per_chr) self._cache_set(val, payload) global_summaries, self.chrom_retained, per_chr = payload self.global_HI = global_summaries[0] # Update plotted data for each chromosome for idx in self.chrom_indices: summaries = per_chr[idx] self.chrom_summaries[idx] = summaries for line, summary in zip(self.lines[idx], summaries): line.set_ydata(summary[self.indHIorder]) self.global_hi_lines[idx].set_ydata(self.global_HI[self.indHIorder]) num, denom = self.chrom_retained[idx] self.chrom_axes[idx].set_title( f"{self.ChrNickNames[idx]} | {num:,}/{denom:,} sites", fontsize=10 ) self.fig.canvas.draw_idle() def _on_font_change(self, val): self.indNameFont = int(val) labels = np.array(self.IndNickNames)[self.indHIorder] for ax in self.axes.flatten()[:len(self.chrom_indices)]: ax.set_xticklabels(labels, fontsize=self.indNameFont) self.fig.canvas.draw_idle() def _on_reorder(self, event=None): """ Reorder individuals by *current* whole-genome HI (statewise), without recomputing anything expensive. """ self.indHIorder = np.argsort(self.global_HI) labels = np.array(self.IndNickNames)[self.indHIorder] for idx in self.chrom_indices: # global HI overlay self.global_hi_lines[idx].set_ydata(self.global_HI[self.indHIorder]) # chromosome lines summaries = self.chrom_summaries[idx] for line, summary in zip(self.lines[idx], summaries): line.set_ydata(summary[self.indHIorder]) for ax in self.axes.flatten()[:len(self.chrom_indices)]: ax.set_xticks( np.arange(len(labels)), labels, fontsize=self.indNameFont, ha="right" ) self.fig.canvas.draw_idle()
"""________________________________________ END GenomeMultiSummaryPlot ___________________""" """________________________________________ START GenomicDeFinettiPlot ___________________"""
[docs] class GenomicDeFinettiPlot: """ Plots a genomic de Finetti plot with DI filtering and interactive widgets. Cursor hover displays individual IDs. c.f. Figure 2, Figure 4: Petružela, J., Nürnberger, B., Ribas, A., Koutsovoulos, G., Čížková, D., Fornůsková, A., Aghová, T., Blaxter, M., de Bellocq, J.G. and Baird, S.J.E. (2025), Comparative Genomic Analysis of Co-Occurring Hybrid Zones of House Mouse Parasites Pneumocystis murina and Syphacia obvelata Using Genome Polarisation. Mol Ecol, 34: e70044. https://doi.org/10.1111/mec.70044 Figure 4: Ebdon, S., Laetsch, D. R., Vila, R., Baird, S. J. E., & Lohse, K. (2025). Genomic regions of current low hybridisation mark long-term barriers to gene flow in scarce swallowtail butterflies. PLoS Genetics, 21(4), 30. doi:https://doi.org/10.1371/journal.pgen.1011655 Drop-in extension: - optional incremental cache prefill using StatewiseDIIncrementalCache - DI slider uses cached results (nearest match within tol) - keeps output + hover behaviour the same Uses: - summaries_from_statewise_counts(statewise counts) - StatewiseDIIncrementalCache (fast prefill) Args: dPol: DiemType object containing genomic data. """ def __init__( self, dPol, *, prefill_cache=False, # NEW prefill_step=None, # NEW cache_tol=None, # NEW progress=None, # NEW: "text" | "none" ): self.dPol = dPol # ---- initial state ---- self.marker_size = 60 self.indHIorder = np.arange(len(dPol.indNames)) # ---- cache (per instance) ---- # maps DI_value(float) -> (summaries, DInumer, DIdenom) self._cache = {} self._cache_keys_sorted = [] def _cache_set(di, payload): di = float(di) if di not in self._cache: bisect.insort(self._cache_keys_sorted, di) self._cache[di] = payload def _cache_get_nearest(di, tol): if not self._cache_keys_sorted: return None di = float(di) keys = self._cache_keys_sorted j = bisect.bisect_left(keys, di) candidates = [] if 0 <= j < len(keys): candidates.append(keys[j]) if 0 <= j - 1 < len(keys): candidates.append(keys[j - 1]) best = None best_dist = None for k in candidates: dist = abs(k - di) if best_dist is None or dist < best_dist: best = k best_dist = dist if best is None or best_dist is None or best_dist > tol: return None return self._cache[best] self._cache_set = _cache_set self._cache_get_nearest = _cache_get_nearest # ---- initial summaries (no DI filter) ---- # Keep your original semantics, but compute via statewise so cache path matches exactly. chrom_counts, chrom_retained = statewise_genomes_summary_given_DI(self.dPol, float("-inf")) self.summaries = summaries_from_statewise_counts(chrom_counts) self.DInumer = sum(n for n, _ in chrom_retained) self.DIdenom = sum(d for _, d in chrom_retained) # unpack summaries self.HOM1 = self.summaries[1] self.HET = self.summaries[2] self.HOM2 = self.summaries[3] self.U = self.summaries[4] # ---- figure & axes ---- self.fig, self.ax = plt.subplots(figsize=(10, 10)) self._setup_axes() self.ax.set_title('Genomic de Finetti; no DI filter') # background self._draw_triangle() self._draw_hwe_curve() # points self.scatter = self._draw_points() # widgets self._init_widgets() # ---- OPTIONAL: prefill incremental cache ---- if prefill_cache: DI_span = get_DI_span(self.dPol) di_min, di_max = float(DI_span[0]), float(DI_span[1]) if prefill_step is None: span = di_max - di_min prefill_step = span / 200.0 if span > 0 else 1.0 if cache_tol is None: cache_tol = float(prefill_step) / 2.0 # Build DI grid including endpoints if di_max > di_min: n_steps = int(np.floor((di_max - di_min) / prefill_step)) di_values = [di_min + k * prefill_step for k in range(n_steps + 1)] if not di_values or di_values[-1] < di_max: di_values.append(di_max) else: di_values = [di_min] di_grid = np.asarray(di_values, dtype=float) inc = StatewiseDIIncrementalCache( self.dPol, di_grid=di_grid, progress=("text" if progress == "text" else None), label="GenomicDeFinetti cache prefill", ).prefill() # Store summaries snapshots only (small memory) for di_key, (chrom_counts, chrom_retained) in inc._snapshots.items(): summaries = summaries_from_statewise_counts(chrom_counts) DInumer = sum(n for n, _ in chrom_retained) DIdenom = sum(d for _, d in chrom_retained) self._cache_set(di_key, (summaries, DInumer, DIdenom)) self._cache_tol = float(cache_tol) else: self._cache_tol = 0.0 # coordinate display self._install_format_coord() plt.show() # -------------------------------------------------- # Helpers # -------------------------------------------------- @staticmethod def _to_triangle_coords(hom1, het, hom2): x = hom2 + 0.5 * het y = (np.sqrt(3) / 2) * het return x, y def _update_title(self, DIval): prop = self.DInumer / self.DIdenom if self.DIdenom > 0 else 0.0 self.ax.set_title( "Genomic de Finetti plot DI ≥ {:.2f} {} SNVs ({:.1f}% divergent across barrier)" .format(DIval, self.DInumer, 100 * prop), fontsize=12, pad=12 ) # -------------------------------------------------- # Axes / background # -------------------------------------------------- def _setup_axes(self): self.ax.set_aspect("equal") self.ax.set_xlim(-0.05, 1.05) self.ax.set_ylim(-0.05, np.sqrt(3) / 2 + 0.05) self.ax.set_xticks([]) self.ax.set_yticks([]) for spine in self.ax.spines.values(): spine.set_visible(False) self.ax.set_title("Genomic de Finetti plot") def _draw_triangle(self): h = np.sqrt(3) / 2 triangle = np.array([[0, 0], [1, 0], [0.5, h]]) self.ax.add_patch( Polygon(triangle, closed=True, fill=False, lw=1.2, color="black") ) self.ax.text(0, -0.04, "HOM1", ha="center", va="top", fontsize=9) self.ax.text(1, -0.04, "HOM2", ha="center", va="top", fontsize=9) self.ax.text(0.5, h + 0.03, "HET", ha="center", va="bottom", fontsize=9) def _draw_hwe_curve(self): p = np.linspace(0, 1, 400) hom1 = p**2 het = 2 * p * (1 - p) hom2 = (1 - p)**2 x, y = self._to_triangle_coords(hom1, het, hom2) self.ax.plot(x, y, color="black", lw=0.8, alpha=0.5) # -------------------------------------------------- # Points # -------------------------------------------------- def _blend_colours(self): weights = np.column_stack([self.HOM1, self.HET, self.HOM2, self.U]) base_colours = np.array([ to_rgb(diemColours[1]), # HOM1 to_rgb(diemColours[2]), # HET to_rgb(diemColours[3]), # HOM2 to_rgb(diemColours[0]), # U ]) # (4,3) rgb = weights @ base_colours rgb = np.clip(rgb, 0.0, 1.0) return rgb def _draw_points(self): x, y = self._to_triangle_coords( self.HOM1[self.indHIorder], self.HET[self.indHIorder], self.HOM2[self.indHIorder], ) colours = self._blend_colours()[self.indHIorder] return self.ax.scatter( x, y, s=self.marker_size, c=colours, edgecolor="black", linewidth=0.3, ) def _update_points(self): self.HOM1 = self.summaries[1] self.HET = self.summaries[2] self.HOM2 = self.summaries[3] self.U = self.summaries[4] x, y = self._to_triangle_coords( self.HOM1[self.indHIorder], self.HET[self.indHIorder], self.HOM2[self.indHIorder], ) self.scatter.set_offsets(np.column_stack([x, y])) self.scatter.set_facecolors(self._blend_colours()[self.indHIorder]) # -------------------------------------------------- # Widgets # -------------------------------------------------- def _init_widgets(self): DI_span = get_DI_span(self.dPol) self.fig.subplots_adjust(bottom=0.25) ax_DI = self.fig.add_axes([0.15, 0.12, 0.7, 0.03]) self.DI_slider = Slider( ax_DI, "DI", DI_span[0], DI_span[1], valinit=DI_span[0] ) self.DI_slider.on_changed(self.DIupdate) ax_SZ = self.fig.add_axes([0.25, 0.10, 0.1, 0.03]) self.size_slider = Slider( ax_SZ, "Symbol size", 10, 300, valinit=self.marker_size ) self.size_slider.on_changed(self.SIZEupdate) # -------------------------------------------------- # Callbacks # -------------------------------------------------- def DIupdate(self, val): payload = None if self._cache_tol > 0: payload = self._cache_get_nearest(val, self._cache_tol) if payload is None: # Lazy compute via statewise (consistent with cached path) chrom_counts, chrom_retained = statewise_genomes_summary_given_DI(self.dPol, val) summaries = summaries_from_statewise_counts(chrom_counts) DInumer = sum(n for n, _ in chrom_retained) DIdenom = sum(d for _, d in chrom_retained) payload = (summaries, DInumer, DIdenom) self._cache_set(val, payload) self.summaries, self.DInumer, self.DIdenom = payload # unpack (important) self.HOM1 = self.summaries[1] self.HET = self.summaries[2] self.HOM2 = self.summaries[3] self.U = self.summaries[4] self._update_points() self._update_title(val) self.fig.canvas.draw_idle() def SIZEupdate(self, val): self.marker_size = val self.scatter.set_sizes(np.full(len(self.dPol.indNames), val)) self.fig.canvas.draw_idle() # -------------------------------------------------- # Coordinate display # -------------------------------------------------- def _install_format_coord(self): tol = 0.03 def format_coord(x, y): fallback = "\u2007" * 30 pts = self.scatter.get_offsets() d = np.hypot(pts[:, 0] - x, pts[:, 1] - y) i = np.argmin(d) if d[i] < tol: return f"IndID: {self.dPol.indNames[self.indHIorder[i]]}" return fallback self.ax.format_coord = format_coord
"""________________________________________ END GenomicDeFinettiPlot ___________________""" """________________________________________ START GenomicMultiDeFinettiPlot ___________________"""
[docs] class GenomicMultiDeFinettiPlot: """ Multiple de Finetti plots, one per chromosome, all controlled by a shared DI slider and size slider. Uses statewise_genomes_summary_given_DI c.f. Figure 2, Figure 4: Petružela, J., Nürnberger, B., Ribas, A., Koutsovoulos, G., Čížková, D., Fornůsková, A., Aghová, T., Blaxter, M., de Bellocq, J.G. and Baird, S.J.E. (2025), Comparative Genomic Analysis of Co-Occurring Hybrid Zones of House Mouse Parasites Pneumocystis murina and Syphacia obvelata Using Genome Polarisation. Mol Ecol, 34: e70044. https://doi.org/10.1111/mec.70044 Figure 4: Ebdon, S., Laetsch, D. R., Vila, R., Baird, S. J. E., & Lohse, K. (2025). Genomic regions of current low hybridisation mark long-term barriers to gene flow in scarce swallowtail butterflies. PLoS Genetics, 21(4), 30. doi:https://doi.org/10.1371/journal.pgen.1011655 Drop-in extension: - optional incremental cache prefill using StatewiseDIIncrementalCache - DI slider uses cached results (nearest match within tol) - output + hover semantics unchanged Uses statewise_genomes_summary_given_DI + summaries_from_statewise_counts. Args: dPol: DiemType object containing genomic data. chrom_indices: List of chromosome indices to plot. """ def __init__( self, dPol, chrom_indices, max_cols=3, *, prefill_cache=False, # NEW prefill_step=None, # NEW cache_tol=None, # NEW progress=None, # NEW: "text" | "none" ): self.dPol = dPol self.chrom_indices = self._validate_chrom_indices(chrom_indices) self.ChrNickNames = [Chr_Nickname(name) for name in dPol.chrNames] self.max_cols = max_cols self.marker_size = 60 self.n_ind = len(dPol.indNames) self.indHIorder = np.arange(self.n_ind) # ---------- cache (per instance) ---------- # maps DI(float) -> (chrom_counts, chrom_retained) self._cache = {} self._cache_keys_sorted = [] self._cache_tol = 0.0 def _cache_set(di, payload): di = float(di) if di not in self._cache: bisect.insort(self._cache_keys_sorted, di) self._cache[di] = payload def _cache_get_nearest(di, tol): if not self._cache_keys_sorted: return None di = float(di) keys = self._cache_keys_sorted j = bisect.bisect_left(keys, di) candidates = [] if 0 <= j < len(keys): candidates.append(keys[j]) if 0 <= j - 1 < len(keys): candidates.append(keys[j - 1]) best = None best_dist = None for k in candidates: dist = abs(k - di) if best_dist is None or dist < best_dist: best = k best_dist = dist if best is None or best_dist is None or best_dist > tol: return None return self._cache[best] self._cache_set = _cache_set self._cache_get_nearest = _cache_get_nearest # ---------- initial statewise computation ---------- self.chrom_counts, self.chrom_retained = \ statewise_genomes_summary_given_DI(self.dPol, float("-inf")) # global summaries (authoritative ordering) self.global_summaries = summaries_from_statewise_counts(self.chrom_counts) self.global_HI = self.global_summaries[0] self.indHIorder = np.argsort(self.global_HI) # ---------- layout ---------- n_plots = len(self.chrom_indices) n_cols = min(self.max_cols, n_plots) n_rows = int(np.ceil(n_plots / n_cols)) self.fig, self.axes = plt.subplots( n_rows, n_cols, figsize=(4.8 * n_cols, 4.6 * n_rows), squeeze=False ) self.fig.subplots_adjust( left=0.06, right=0.98, top=0.92, bottom=0.32, hspace=0.45, wspace=0.25 ) # ---------- draw ---------- self.scatters = {} self.chrom_axes = {} axes_flat = self.axes.flatten() for ax, idx in zip(axes_flat, self.chrom_indices): self.chrom_axes[idx] = ax self._setup_axes(ax) self._draw_triangle(ax) self._draw_hwe_curve(ax) summaries = summaries_from_statewise_counts([self.chrom_counts[idx]]) _, HOM1, HET, HOM2, U = summaries x, y = self._to_triangle_coords( HOM1[self.indHIorder], HET[self.indHIorder], HOM2[self.indHIorder] ) colours = self._blend_colours(HOM1, HET, HOM2, U) sc = ax.scatter( x, y, s=self.marker_size, c=colours[self.indHIorder], edgecolor="black", linewidth=0.3 ) num, denom = self.chrom_retained[idx] ax.set_title(f"{self.ChrNickNames[idx]} | {num:,}/{denom:,} sites", fontsize=10) self.scatters[idx] = sc for ax in axes_flat[len(self.chrom_indices):]: ax.axis("off") # ---------- widgets ---------- self._init_widgets() # ---------- OPTIONAL: prefill incremental cache ---------- if prefill_cache: DI_span = get_DI_span(self.dPol) di_min, di_max = float(DI_span[0]), float(DI_span[1]) if prefill_step is None: span = di_max - di_min prefill_step = span / 200.0 if span > 0 else 1.0 if cache_tol is None: cache_tol = float(prefill_step) / 2.0 # Build DI grid including endpoints if di_max > di_min: n_steps = int(np.floor((di_max - di_min) / prefill_step)) di_values = [di_min + k * prefill_step for k in range(n_steps + 1)] if not di_values or di_values[-1] < di_max: di_values.append(di_max) else: di_values = [di_min] di_grid = np.asarray(di_values, dtype=float) inc = StatewiseDIIncrementalCache( self.dPol, di_grid=di_grid, progress=("text" if progress == "text" else None), label="GenomicMultiDeFinetti cache prefill", ).prefill() # Store snapshots (small memory compared to per-site) for di_key, payload in inc._snapshots.items(): self._cache_set(di_key, payload) self._cache_tol = float(cache_tol) # ---------- hover ---------- self._install_format_coord() plt.show() # ====================================================== # Helpers # ====================================================== @staticmethod def _to_triangle_coords(hom1, het, hom2): x = hom2 + 0.5 * het y = (np.sqrt(3) / 2) * het return x, y def _blend_colours(self, HOM1, HET, HOM2, U): weights = np.column_stack([HOM1, HET, HOM2, U]) base = np.array([ to_rgb(diemColours[1]), to_rgb(diemColours[2]), to_rgb(diemColours[3]), to_rgb(diemColours[0]), ]) return np.clip(weights @ base, 0, 1) def _setup_axes(self, ax): ax.set_aspect("equal") ax.set_xlim(-0.05, 1.05) ax.set_ylim(-0.05, np.sqrt(3)/2 + 0.05) ax.set_xticks([]) ax.set_yticks([]) for s in ax.spines.values(): s.set_visible(False) def _draw_triangle(self, ax): h = np.sqrt(3)/2 ax.add_patch(Polygon([[0,0],[1,0],[0.5,h]], fill=False, lw=1.2)) ax.text(0, -0.04, "HOM1", ha="center", va="top", fontsize=8) ax.text(1, -0.04, "HOM2", ha="center", va="top", fontsize=8) ax.text(0.5, h + 0.03, "HET", ha="center", va="bottom", fontsize=8) def _draw_hwe_curve(self, ax): p = np.linspace(0,1,400) x, y = self._to_triangle_coords(p*p, 2*p*(1-p), (1-p)**2) ax.plot(x, y, color="black", lw=0.8, alpha=0.5) # ====================================================== # Widgets # ====================================================== def _init_widgets(self): DI_span = get_DI_span(self.dPol) ax_DI = self.fig.add_axes([0.15, 0.20, 0.70, 0.035]) self.DI_slider = Slider(ax_DI, "DI", *DI_span, valinit=DI_span[0]) self.DI_slider.on_changed(self._on_DI_change) ax_SZ = self.fig.add_axes([0.25, 0.16, 0.1, 0.03]) self.size_slider = Slider(ax_SZ, "Symbol size", 10, 300, valinit=self.marker_size) self.size_slider.on_changed(self._on_size_change) # ====================================================== # Callbacks # ====================================================== def _on_DI_change(self, val): payload = None if self._cache_tol > 0: payload = self._cache_get_nearest(val, self._cache_tol) if payload is None: # fallback (no prefill or outside tol): compute and (optionally) cache chrom_counts, chrom_retained = statewise_genomes_summary_given_DI(self.dPol, val) payload = (chrom_counts, chrom_retained) # opportunistic cache self._cache_set(val, payload) self.chrom_counts, self.chrom_retained = payload # global summaries (authoritative ordering) self.global_summaries = summaries_from_statewise_counts(self.chrom_counts) self.indHIorder = np.argsort(self.global_summaries[0]) totNumer = 0 totDenom = 0 for idx in self.chrom_indices: summaries = summaries_from_statewise_counts([self.chrom_counts[idx]]) _, H1, Ht, H2, U = summaries x, y = self._to_triangle_coords( H1[self.indHIorder], Ht[self.indHIorder], H2[self.indHIorder] ) sc = self.scatters[idx] sc.set_offsets(np.column_stack([x, y])) sc.set_facecolors(self._blend_colours(H1, Ht, H2, U)[self.indHIorder]) num, denom = self.chrom_retained[idx] self.chrom_axes[idx].set_title( f"{self.ChrNickNames[idx]} | {num:,}/{denom:,} sites", fontsize=10 ) totNumer += num totDenom += denom # (you previously computed prop but didn't display it; keep behaviour unchanged) self.fig.canvas.draw_idle() def _on_size_change(self, val): self.marker_size = int(val) for sc in self.scatters.values(): sc.set_sizes(np.full(self.n_ind, self.marker_size)) self.fig.canvas.draw_idle() # ====================================================== # Hover # ====================================================== def _install_format_coord(self): tol = 0.03 names = self.dPol.indNames for idx, sc in self.scatters.items(): ax = self.chrom_axes[idx] def make_fmt(scatter): def fmt(x, y): pts = scatter.get_offsets() d = np.hypot(pts[:, 0] - x, pts[:, 1] - y) i = np.argmin(d) if d[i] < tol: return f"IndID: {names[self.indHIorder[i]]}" return "\u2007" * 30 return fmt ax.format_coord = make_fmt(sc) # ====================================================== # Validation # ====================================================== def _validate_chrom_indices(self, chrom_indices): max_idx = len(self.dPol.chrLengths) - 1 valid = [] rejected = [] for i in chrom_indices: try: ii = int(i) if 0 <= ii <= max_idx: valid.append(ii) else: rejected.append(i) except Exception: rejected.append(i) if rejected: print("GenomicMultiDeFinettiPlot: rejected chromosome indices:", rejected) if not valid: raise ValueError("No valid chromosome indices") return valid
"""________________________________________ END GenomicMultiDeFinettiPlot ___________________""" """________________________________________ START GenomicContributionsPlot ___________________"""
[docs] class GenomicContributionsPlot: """ Plots per-chromosome genomic contributions (HOM1, HET, HOM2, U, excluded) with DI filtering and interactive widgets. Drop-in extension: - optional incremental cache prefill using StatewiseDIIncrementalCache - DI slider uses cached statewise snapshots (nearest within tol) - output unchanged Uses statewise_genomes_summary_given_DI (or cached equivalent). Args: dPol: DiemType object containing genomic data. """ def __init__( self, dPol, chrom_indices=None, *, prefill_cache=False, # NEW prefill_step=None, # NEW cache_tol=None, # NEW progress=None, # NEW: "text" | "none" ): self.dPol = dPol self.chrom_indices = chrom_indices self.fontsize = 8 # ---- cache (per instance) ---- # maps DI(float) -> (chrom_counts, chrom_retained) self._cache = {} self._cache_keys_sorted = [] self._cache_tol = 0.0 def _cache_set(di, payload): di = float(di) if di not in self._cache: bisect.insort(self._cache_keys_sorted, di) self._cache[di] = payload def _cache_get_nearest(di, tol): if not self._cache_keys_sorted: return None di = float(di) keys = self._cache_keys_sorted j = bisect.bisect_left(keys, di) candidates = [] if 0 <= j < len(keys): candidates.append(keys[j]) if 0 <= j - 1 < len(keys): candidates.append(keys[j - 1]) best = None best_dist = None for k in candidates: dist = abs(k - di) if best_dist is None or dist < best_dist: best = k best_dist = dist if best is None or best_dist is None or best_dist > tol: return None return self._cache[best] self._cache_set = _cache_set self._cache_get_nearest = _cache_get_nearest # -------------------------------------------------- # Initial compute (no filter) # -------------------------------------------------- self.DInumer = 0 self.DIdenom = 0 self._compute_contributions(float("-inf")) # ---- figure & axes ---- self.fig, self.ax = plt.subplots(figsize=(10, 5)) self.ax.format_coord = None self.fig.subplots_adjust(bottom=0.40, right=0.85) self._draw_bars() self._init_widgets() # -------------------------------------------------- # OPTIONAL: prefill incremental cache # -------------------------------------------------- if prefill_cache: DI_span = get_DI_span(self.dPol) di_min, di_max = float(DI_span[0]), float(DI_span[1]) if prefill_step is None: span = di_max - di_min prefill_step = span / 200.0 if span > 0 else 1.0 if cache_tol is None: cache_tol = float(prefill_step) / 2.0 # Build DI grid including endpoints if di_max > di_min: n_steps = int(np.floor((di_max - di_min) / prefill_step)) di_values = [di_min + k * prefill_step for k in range(n_steps + 1)] if not di_values or di_values[-1] < di_max: di_values.append(di_max) else: di_values = [di_min] di_grid = np.asarray(di_values, dtype=float) inc = StatewiseDIIncrementalCache( self.dPol, di_grid=di_grid, progress=("text" if progress == "text" else None), label="GenomicContributions cache prefill", ).prefill() # Store snapshots for di_key, payload in inc._snapshots.items(): self._cache_set(di_key, payload) self._cache_tol = float(cache_tol) plt.show() # -------------------------------------------------- # Core computation # -------------------------------------------------- def _compute_contributions(self, DIval): """ Compute contributions at DIval. Uses cache if available (nearest within tol), otherwise computes directly. """ payload = None if self._cache_tol > 0: payload = self._cache_get_nearest(DIval, self._cache_tol) if payload is None: chrom_counts, chrom_retained = statewise_genomes_summary_given_DI(self.dPol, DIval) payload = (chrom_counts, chrom_retained) # opportunistic cache (even without prefill, harmless) self._cache_set(DIval, payload) else: chrom_counts, chrom_retained = payload # -------------------------------------------------- # Restrict chromosomes if requested (same as your code) # -------------------------------------------------- if self.chrom_indices is not None: kept = [] n_chr = len(chrom_counts) for ci in self.chrom_indices: if isinstance(ci, (int, np.integer)) and 0 <= int(ci) < n_chr: kept.append(int(ci)) if not kept: raise ValueError("GenomicContributionsPlot: no valid chromosome indices") else: kept = range(len(chrom_counts)) # -------------------------------------------------- # Aggregate over kept chromosomes # -------------------------------------------------- self.DInumer = 0 self.DIdenom = 0 kept_list = list(kept) self.chrom_labels = [] self.props = np.zeros((len(kept_list), 5)) # HOM1, HET, HOM2, U, excluded for out_i, chr_i in enumerate(kept_list): chr_name = Chr_Nickname(self.dPol.chrNames[chr_i]) self.chrom_labels.append(chr_name) counts = chrom_counts[chr_i] kept_sites, total_sites = chrom_retained[chr_i] self.DInumer += kept_sites self.DIdenom += total_sites c0 = float(np.sum(counts["counts0"])) c1 = float(np.sum(counts["counts1"])) c2 = float(np.sum(counts["counts2"])) c3 = float(np.sum(counts["counts3"])) total_alleles = float(total_sites) * float(np.sum(self.dPol.chrPloidies[chr_i])) if total_alleles == 0: continue self.props[out_i, :] = [ c1 / total_alleles, # HOM1 c2 / total_alleles, # HET c3 / total_alleles, # HOM2 c0 / total_alleles, # U (1.0 - kept_sites / total_sites) if total_sites else 0 ] self.current_DI = float(DIval) # -------------------------------------------------- # Drawing # -------------------------------------------------- def _draw_bars(self): self.ax.clear() x = np.arange(len(self.chrom_labels)) bottoms = np.zeros(len(x)) colours = [ diemColours[1], # HOM1 diemColours[2], # HET diemColours[3], # HOM2 "lightgray", # U "white", # excluded ] labels = ["HOM1", "HET", "HOM2", "U", "<DI"] for i in range(5): self.ax.bar( x, self.props[:, i], bottom=bottoms, color=colours[i], edgecolor="black" if i == 4 else None, linewidth=0.4 if i == 4 else 0, label=labels[i], ) bottoms += self.props[:, i] self.ax.set_xlim(-0.5, len(x) - 0.5) self.ax.set_ylim(0, 1) self.ax.set_xticks(x) self.ax.set_xticklabels( self.chrom_labels, rotation=90, fontsize=self.fontsize, ha="center", ) self.ax.set_ylabel("Proportion of SNVs") prop = self.DInumer / self.DIdenom if self.DIdenom > 0 else 0.0 self.ax.set_title( "Genomic contributions plot DI ≥ {:.2f} {} SNVs ({:.1f}% divergent across barrier)" .format(self.current_DI, self.DInumer, 100 * prop), fontsize=12, pad=12 ) self.ax.legend( loc="upper left", bbox_to_anchor=(1.01, 1.0), fontsize=8, frameon=False, ) self.fig.canvas.draw_idle() # -------------------------------------------------- # Widgets # -------------------------------------------------- def _init_widgets(self): DI_span = get_DI_span(self.dPol) ax_DI = self.fig.add_axes([0.15, 0.20, 0.70, 0.03]) self.DI_slider = Slider( ax_DI, "DI", DI_span[0], DI_span[1], valinit=DI_span[0], ) self.DI_slider.on_changed(self.DIupdate) ax_FS = self.fig.add_axes([0.15, 0.13, 0.15, 0.03]) self.font_slider = Slider( ax_FS, "Label font size", 4, 16, valinit=self.fontsize, valstep=1, ) self.font_slider.on_changed(self.FONTupdate) # -------------------------------------------------- # Callbacks # -------------------------------------------------- def DIupdate(self, val): self._compute_contributions(val) self._draw_bars() def FONTupdate(self, val): self.fontsize = int(val) self._draw_bars()
"""________________________________________ END GenomicContributions__________________""" """________________________________________ START IndGenomicContributions ___________________""" class IndGenomicContributionsPlot: """ Stacked-bar genomic contributions per INDIVIDUAL (HOM1, HET, HOM2, U, excluded), with DI filtering and widgets. - bars correspond to individuals (like GenomeSummaryPlot focus) - visuals correspond to GenomicContributionsPlot - reorder-by-HI button (HI computed at current DI) - optional incremental cache prefill via StatewiseDIIncrementalCache Uses: statewise_genomes_summary_given_DI c.f. Figure 2, Figure 4: Petružela, J., Nürnberger, B., Ribas, A., Koutsovoulos, G., Čížková, D., Fornůsková, A., Aghová, T., Blaxter, M., de Bellocq, J.G. and Baird, S.J.E. (2025), Comparative Genomic Analysis of Co-Occurring Hybrid Zones of House Mouse Parasites Pneumocystis murina and Syphacia obvelata Using Genome Polarisation. Mol Ecol, 34: e70044. https://doi.org/10.1111/mec.70044 """ def __init__( self, dPol, *, prefill_cache=False, # NEW prefill_step=None, # NEW cache_tol=None, # NEW progress=None, # NEW: "text" | "none" ): self.dPol = dPol # labels + ordering self.IndNickNames = [Ind_Nickname(name) for name in dPol.indNames] self.indNameFont = 6 self.indHIorder = np.arange(len(dPol.indNames), dtype=int) # plotted state self.current_DI = float("-inf") self.DInumer = 0 self.DIdenom = 0 self.global_HI = None self.props = None # (nInd, 5) # ---- cache (per instance) ---- # maps DI(float) -> (chrom_counts, chrom_retained) self._cache = {} self._cache_keys_sorted = [] self._cache_tol = 0.0 def _cache_set(di, payload): di = float(di) if di not in self._cache: bisect.insort(self._cache_keys_sorted, di) self._cache[di] = payload def _cache_get_nearest(di, tol): if not self._cache_keys_sorted: return None di = float(di) keys = self._cache_keys_sorted j = bisect.bisect_left(keys, di) candidates = [] if 0 <= j < len(keys): candidates.append(keys[j]) if 0 <= j - 1 < len(keys): candidates.append(keys[j - 1]) best = None best_dist = None for k in candidates: dist = abs(k - di) if best_dist is None or dist < best_dist: best = k best_dist = dist if best is None or best_dist is None or best_dist > tol: return None return self._cache[best] self._cache_set = _cache_set self._cache_get_nearest = _cache_get_nearest # ---- figure & axes ---- self.fig, self.ax = plt.subplots(figsize=(11, 4.8)) self.fig.subplots_adjust(bottom=0.35, right=0.88) self.ax.format_coord = None # no hover # ---- initial compute ---- self._compute_props(float("-inf")) # ---- INITIAL ORDER BY HI (NEW) ---- if self.global_HI is not None: self.indHIorder = np.argsort(self.global_HI) # ---- draw ---- self._draw_bars() # ---- widgets ---- self._init_widgets() # ---- OPTIONAL: incremental cache prefill ---- if prefill_cache: DI_span = get_DI_span(self.dPol) di_min, di_max = float(DI_span[0]), float(DI_span[1]) if prefill_step is None: span = di_max - di_min prefill_step = span / 200.0 if span > 0 else 1.0 if cache_tol is None: cache_tol = float(prefill_step) / 2.0 # Build DI grid including endpoints if di_max > di_min: n_steps = int(np.floor((di_max - di_min) / prefill_step)) di_values = [di_min + k * prefill_step for k in range(n_steps + 1)] if not di_values or di_values[-1] < di_max: di_values.append(di_max) else: di_values = [di_min] di_grid = np.asarray(di_values, dtype=float) inc = StatewiseDIIncrementalCache( self.dPol, di_grid=di_grid, progress=("text" if progress == "text" else None), label="IndGenomicContrib cache prefill", ).prefill() # Store snapshots for di_key, payload in inc._snapshots.items(): self._cache_set(di_key, payload) self._cache_tol = float(cache_tol) self.fig.canvas.draw() # force first render NOW (fix for issue 1) plt.show() # -------------------------------------------------- # Core computation # -------------------------------------------------- def _get_statewise_payload(self, DIval): """ Return (chrom_counts, chrom_retained) using cache if available. """ payload = None used_cache = False # <-- ALWAYS define this if self._cache_tol > 0: payload = self._cache_get_nearest(DIval, self._cache_tol) if payload is None: chrom_counts, chrom_retained = statewise_genomes_summary_given_DI(self.dPol, DIval) payload = (chrom_counts, chrom_retained) # opportunistic cache (safe even without prefill) self._cache_set(DIval, payload) return payload def _compute_props(self, DIval): """ Compute per-individual stacked proportions: HOM1, HET, HOM2, U, excluded Denominator is TOTAL alleles (including excluded DI-filtered-out sites), so excluded portion is meaningful per individual even under variable ploidy. """ chrom_counts, chrom_retained = self._get_statewise_payload(DIval) nChr = len(chrom_counts) nInd = chrom_counts[0]["counts0"].shape[0] # totals of sites retained/total (SNVs, not alleles) for title self.DInumer = sum(n for (n, _) in chrom_retained) self.DIdenom = sum(d for (_, d) in chrom_retained) self.current_DI = float(DIval) # Sum ploidy-weighted counts over chromosomes for each individual sum0 = np.zeros(nInd, dtype=float) sum1 = np.zeros(nInd, dtype=float) sum2 = np.zeros(nInd, dtype=float) sum3 = np.zeros(nInd, dtype=float) # Total alleles per individual across ALL sites (retained + excluded) total_alleles_all = np.zeros(nInd, dtype=float) for chr_i in range(nChr): c = chrom_counts[chr_i] sum0 += c["counts0"] sum1 += c["counts1"] sum2 += c["counts2"] sum3 += c["counts3"] # total alleles for this chromosome for each individual total_sites_chr = chrom_retained[chr_i][1] # denom w = np.asarray(self.dPol.chrPloidies[chr_i], dtype=float) # (nInd,) total_alleles_all += float(total_sites_chr) * w retained_alleles = sum0 + sum1 + sum2 + sum3 # Avoid division by 0 (rare but possible) denom = np.where(total_alleles_all > 0, total_alleles_all, 1.0) props = np.zeros((nInd, 5), dtype=float) props[:, 0] = sum1 / denom # HOM1 props[:, 1] = sum2 / denom # HET props[:, 2] = sum3 / denom # HOM2 props[:, 3] = sum0 / denom # U props[:, 4] = 1.0 - (retained_alleles / denom) # excluded (DI-filtered-out) # Clamp minor floating error props = np.clip(props, 0.0, 1.0) self.props = props # Compute global HI for reorder-by-HI (authoritative from statewise counts) global_summaries = summaries_from_statewise_counts(chrom_counts) self.global_HI = global_summaries[0] # HI is first summary # -------------------------------------------------- # Drawing # -------------------------------------------------- def _draw_bars(self): self.ax.clear() nInd = len(self.dPol.indNames) order = self.indHIorder x = np.arange(nInd) bottoms = np.zeros(nInd, dtype=float) colours = [ diemColours[1], # HOM1 diemColours[2], # HET diemColours[3], # HOM2 "lightgray", # U "white", # excluded ] labels = ["HOM1", "HET", "HOM2", "U", "<DI"] for k in range(5): y = np.asarray(self.props[order, k], dtype=float) self.ax.bar( x, y, bottom=bottoms.copy(), # freeze for this layer color=colours[k], edgecolor="black" if k == 4 else None, linewidth=0.4 if k == 4 else 0, label=labels[k], ) bottoms = bottoms + y # NOT in-place self.ax.set_xlim(-0.5, nInd - 0.5) self.ax.set_ylim(0, 1) self.ax.set_ylabel("Proportion of genotypes") # xticks labels = np.array(self.IndNickNames)[order] self.ax.set_xticks(x) self.ax.set_xticklabels( labels, rotation=90, fontsize=self.indNameFont, ha="center", ) # title prop_sites = self.DInumer / self.DIdenom if self.DIdenom > 0 else 0.0 self.ax.set_title( "Individual genomic contributions DI ≥ {:.2f} {} SNVs ({:.1f}% divergent across barrier)" .format(self.current_DI, self.DInumer, 100 * prop_sites), fontsize=12, pad=12 ) # legend outside self.ax.legend( loc="upper left", bbox_to_anchor=(1.01, 1.0), fontsize=8, frameon=False, ) self.fig.canvas.draw_idle() # -------------------------------------------------- # Widgets # -------------------------------------------------- def _init_widgets(self): DI_span = get_DI_span(self.dPol) # DI slider ax_DI = self.fig.add_axes([0.18, 0.18, 0.64, 0.03]) self.DI_slider = Slider( ax_DI, "DI", DI_span[0], DI_span[1], valinit=DI_span[0], ) self.DI_slider.on_changed(self.DIupdate) # font size slider ax_FS = self.fig.add_axes([0.18, 0.11, 0.14, 0.03]) self.font_slider = Slider( ax_FS, "Label font", 4, 16, valinit=self.indNameFont, valstep=1, ) self.font_slider.on_changed(self.FONTupdate) # reorder button ax_RE = self.fig.add_axes([0.8, 0.025, 0.1, 0.04])#[0.70, 0.105, 0.15, 0.045]) # EURG self.reorder_button = Button( ax_RE, "Reorder by HI", hovercolor="0.95", color="red", ) self.reorder_button.on_clicked(self.reorder) # -------------------------------------------------- # Callbacks # -------------------------------------------------- def DIupdate(self, val): # recompute props + HI at this DI self._compute_props(val) # keep current ordering unless user presses reorder self._draw_bars() def FONTupdate(self, val): self.indNameFont = int(val) self._draw_bars() def reorder(self, event=None): """ Reorder individuals by HI given the CURRENT DI threshold. """ if self.global_HI is None: print(" global_HI is None -> returning") return self.indHIorder = np.argsort(self.global_HI) print(" order AFTER reorder head:", self.indHIorder[:12]) print(" HI head (new order):", np.round(self.global_HI[self.indHIorder][:10], 4)) self._draw_bars() """________________________________________ END IndGenomicContributions ___________________""" """________________________________________ START diemPairsPlot ___________________""" def pwmatrixFromDiemType(aDT, DIthreshold=float("-inf")): """ Compute a DI-filtered pairwise distance matrix from a DiemType object. DMBCtoucher Args: aDT : DiemType DIthreshold : float Only sites with DI >= DIthreshold are retained. Returns: M : (N, N) numpy array Symmetric pairwise distance matrix. """ # ------------------------------------------------- # Dimensions # ------------------------------------------------- n_ind = aDT.DMBC[0].shape[0] # ------------------------------------------------- # Pairwise weight matrix (codes 0..3) # ------------------------------------------------- W = np.zeros((4, 4), dtype=float) W[1, 2] = W[2, 1] = 1 W[1, 3] = W[3, 1] = 2 W[2, 2] = 1 W[2, 3] = W[3, 2] = 1 # (1,1) and (3,3) remain 0 # ------------------------------------------------- # Accumulators # ------------------------------------------------- num = np.zeros((n_ind, n_ind), dtype=float) den = np.zeros((n_ind, n_ind), dtype=float) # ------------------------------------------------- # Single pass over chromosomes # ------------------------------------------------- for chr_idx, SM in enumerate(aDT.DMBC): # SM shape: (n_ind, n_sites) DIvals = aDT.DIByChr[chr_idx] keep = DIvals >= DIthreshold if not np.any(keep): continue SMf = SM[:, keep] # iterate retained sites for s in range(SMf.shape[1]): col = SMf[:, s] valid = col != 0 idx = np.where(valid)[0] if idx.size < 2: continue vals = col[idx] # pairwise contribution for ii, i in enumerate(idx): ai = vals[ii] for jj in range(ii + 1, len(idx)): j = idx[jj] aj = vals[jj] w = W[ai, aj] num[i, j] += w num[j, i] += w den[i, j] += 1 den[j, i] += 1 # ------------------------------------------------- # Final matrix # ------------------------------------------------- M = np.full((n_ind, n_ind), np.nan) mask = den > 0 M[mask] = num[mask] / den[mask] #np.fill_diagonal(M, 0.0) an example of ChatGPT 5.2 suggestion that is not biologically valid return M # for parallel computation of pairwise distance matrix rows def _pwmatrix_row(i, G, W): """ Compute one row of the pairwise distance matrix. """ n = G.shape[0] row = np.zeros(n, dtype=float) ai = G[i] for j in range(n): aj = G[j] valid = (ai != 0) & (aj != 0) denom = valid.sum() if denom == 0: row[j] = np.nan else: row[j] = W[ai[valid], aj[valid]].sum() / denom return i, row #------------------version without char sidestep------------------------------- def _pwmatrix_row_numeric(i, G, W): """ Compute one row of the pairwise distance matrix from numeric codes. Distance definition (see also diem2fasta.py): - exclude sites where either is Unencodable (0) - per-site contributions via W - normalize by 2 so HET×HET -> 0.5 and HOM1×HOM2 -> 1.0 - diagonal is within-individual: 0.5 per HET site, 0 otherwise, excluding Unencodable """ n = G.shape[0] row = np.full(n, np.nan, dtype=float) # default NaN if no valid sites ai = G[i] ai_ok = (ai != 0) for j in range(n): aj = G[j] if i == j: # within-individual: exclude unencodable sites valid = ai_ok else: # pairwise: exclude sites where either is unencodable valid = ai_ok & (aj != 0) denom = int(valid.sum()) if denom: # divide by 2 to implement "average over phases" (partial credit) row[j] = W[ai[valid], aj[valid]].sum() / (2.0 * denom) return i, row def PARApwmatrixFromDiemType( aDT, DIthreshold=float("-inf"), chrom_indices=None, n_jobs=-1, backend="loky", ): """ Parallel computation of pairwise distance matrix. Args ---- aDT : DiemType DIthreshold : float n_jobs : int Number of cores (-1 = all) backend : str joblib backend ("loky" recommended) Returns ------- M : (n_ind, n_ind) numpy array """ # ------------------------------------------------- # DI-filtered genomes, optionally chromosome-restricted # ------------------------------------------------- chunks = [] n_ind = aDT.DMBC[0].shape[0] # Default: all chromosomes if chrom_indices is None: chrom_indices = range(len(aDT.DMBC)) for chr_idx in chrom_indices: SM = aDT.DMBC[chr_idx] DI = aDT.DIByChr[chr_idx] keep = DI >= DIthreshold if not np.any(keep): continue # DMBC toucher: any int8 greater than 3 goes lto Derek state 0: U SMk = SM[:, keep] SMf = np.where((SMk >= 0) & (SMk <= 3), SMk, 0).astype(np.int8) chunks.append(SMf) if not chunks: return np.full((n_ind, n_ind), np.nan) # Concatenate retained chromosomes G = np.concatenate(chunks, axis=1) n = G.shape[0] # ------------------------------------------------- # Distance weight matrix # ------------------------------------------------- W = np.zeros((4, 4), dtype=float) W[1, 2] = W[2, 1] = 1 W[1, 3] = W[3, 1] = 2 W[2, 2] = 1 W[2, 3] = W[3, 2] = 1 # (0,* and *,0 excluded by valid mask) # ------------------------------------------------- # Parallel row computation # ------------------------------------------------- results = Parallel( n_jobs=n_jobs, backend=backend, prefer="processes", batch_size=1, )( delayed(_pwmatrix_row_numeric)(i, G, W) for i in range(n) ) # ------------------------------------------------- # Assemble matrix # ------------------------------------------------- M = np.zeros((n, n), dtype=float) for i, row in results: M[i, :] = row return M #-------------------------------------------------
[docs] class diemPairsPlot: """ Pairwise distance plot using BRICK rectangles (no imshow), now with: - DI slider - Reorder by HI button - optional incremental cache prefill (pairwise + optional HI/statewise) """ """ Pairwise distance plot using brickDiagram semantics. Uses genomes_summary_given_DI and PARApwmatrixFromDiemType c.f. Figure 2, Figure 4: Petružela, J., Nürnberger, B., Ribas, A., Koutsovoulos, G., Čížková, D., Fornůsková, A., Aghová, T., Blaxter, M., de Bellocq, J.G. and Baird, S.J.E. (2025), Comparative Genomic Analysis of Co-Occurring Hybrid Zones of House Mouse Parasites Pneumocystis murina and Syphacia obvelata Using Genome Polarisation. Mol Ecol, 34: e70044. https://doi.org/10.1111/mec.70044 Coding co-pilot: ChatGPT 5.2 Left panel: Square heatmap of pairwise distances (brick rectangles), ordered by Hybrid Index at the specified DI threshold. Right panel: Vertical colour key. Hover: Shows "IndA × IndB : distance". New features: - DI slider - Reorder by HI button - incremental cache prefill (pairwise + optional HI/statewise) """ def __init__( self, dPol, DIthreshold=float("-inf"), figsize=(9, 6), chrom_indices=None, # caching options prefill_cache=True, prefill_step=None, cache_tol=None, progress=None, # None | "text" cache_statewise_for_HI=True, # True recommended if you already have StatewiseDIIncrementalCache ): self.dPol = dPol self.chrom_indices = chrom_indices self.current_DI = float(DIthreshold) # ---------- DI span / grid ---------- DI_span = get_DI_span(self.dPol) self.di_min, self.di_max = float(DI_span[0]), float(DI_span[1]) if prefill_step is None: span = self.di_max - self.di_min prefill_step = span / 200.0 if span > 0 else 1.0 if cache_tol is None: cache_tol = float(prefill_step) / 2.0 self._cache_tol = float(cache_tol) self.di_grid = self._build_di_grid(self.di_min, self.di_max, float(prefill_step)) # ---------- caches ---------- self._inc_pairwise = None self._inc_statewise = None if prefill_cache: # Pairwise incremental cache self._inc_pairwise = PairwiseDIIncrementalCache( self.dPol, di_grid=self.di_grid, chrom_indices=self.chrom_indices, progress=("text" if progress == "text" else None), label="PairsPlot matrix cache prefill", snapshot_dtype=np.float32, ).prefill() # Optional HI ordering cache (fast ordering under slider motion) if cache_statewise_for_HI: self._inc_statewise = StatewiseDIIncrementalCache( self.dPol, di_grid=self.di_grid, progress=("text" if progress == "text" else None), label="PairsPlot HI cache prefill", ).prefill() # ---------- figure layout ---------- self.fig = plt.figure(figsize=figsize) gs = self.fig.add_gridspec( nrows=1, ncols=2, width_ratios=[20, 1], wspace=0.08 ) self.ax = self.fig.add_subplot(gs[0, 0]) self.cax = self.fig.add_subplot(gs[0, 1]) # room at bottom for widgets self.fig.subplots_adjust(bottom=0.22) # ---------- colormap ---------- self.cmap = LinearSegmentedColormap.from_list( "soft_coolwarm", ["#1e90ff", "white", "#fff266", "#ff1a1a"] ) # ---------- initial compute ---------- self._compute_for_DI(self.current_DI, force_reorder=True) # ---------- draw once ---------- self._setup_axes() self._init_bricks() # main matrix bricks (persistent) self._init_colour_key() # colour key bricks (persistent) # ---------- widgets ---------- self._init_widgets() # ---------- hover ---------- self._install_format_coord() plt.show() # ================================================= # DI grid helpers # ================================================= @staticmethod def _build_di_grid(di_min, di_max, step): if di_max > di_min: n_steps = int(np.floor((di_max - di_min) / step)) vals = [di_min + k * step for k in range(n_steps + 1)] if not vals or vals[-1] < di_max: vals.append(di_max) else: vals = [di_min] return np.asarray(vals, dtype=float) def _nearest_grid_value(self, di): di = float(di) grid = self.di_grid j = int(np.argmin(np.abs(grid - di))) return float(grid[j]) # ================================================= # Computation # ================================================= def _get_HI_and_retained_for_DI(self, DIthreshold): """ Return (HI, DInumer, DIdenom) at DIthreshold. DInumer/DIdenom are SNV counts retained/total across all chromosomes. """ if self._inc_statewise is not None: di_key = self._nearest_grid_value(DIthreshold) chrom_counts, chrom_retained = self._inc_statewise._snapshots[di_key] HI = summaries_from_statewise_counts(chrom_counts)[0] DInumer = sum(n for (n, _) in chrom_retained) DIdenom = sum(d for (_, d) in chrom_retained) return HI, DInumer, DIdenom # fallback (no statewise cache) # Use the same statewise function IndGenomicContributionsPlot uses chrom_counts, chrom_retained = statewise_genomes_summary_given_DI(self.dPol, float(DIthreshold)) HI = summaries_from_statewise_counts(chrom_counts)[0] DInumer = sum(n for (n, _) in chrom_retained) DIdenom = sum(d for (_, d) in chrom_retained) return HI, DInumer, DIdenom def _get_M_for_DI(self, DIthreshold): if self._inc_pairwise is not None: return self._inc_pairwise.get(DIthreshold) # Fallback: the existing parallel matrix builder return PARApwmatrixFromDiemType( self.dPol, DIthreshold=float(DIthreshold), chrom_indices=self.chrom_indices, ) def _compute_for_DI(self, DIthreshold, *, force_reorder=False): """ Compute matrix + (optionally) compute HI ordering. If force_reorder=False, keep existing order and only update M accordingly. """ self.current_DI = float(DIthreshold) # matrix (unordered) M_raw = self._get_M_for_DI(self.current_DI) # ordering HI, DInumer, DIdenom = self._get_HI_and_retained_for_DI(self.current_DI) self.DInumer = int(DInumer) self.DIdenom = int(DIdenom) if force_reorder or not hasattr(self, "ind_order") or self.ind_order is None: self.ind_order = np.argsort(HI) # apply ordering self.indNames = np.array(self.dPol.indNames)[self.ind_order] self.n = len(self.indNames) self.M = M_raw[self.ind_order][:, self.ind_order] # range self.vmin = float(np.nanmin(self.M)) if np.any(np.isfinite(self.M)) else 0.0 self.vmax = float(np.nanmax(self.M)) if np.any(np.isfinite(self.M)) else 1.0 # ================================================= # Drawing setup # ================================================= def _setup_axes(self): self.ax.set_xlim(0, self.n) self.ax.set_ylim(0, self.n) self.ax.set_aspect("equal") centers = np.arange(self.n) + 0.5 self.ax.set_xticks(centers) self.ax.set_yticks(centers) # initial label fontsize (linked to slider) self.label_fontsize = 4 self.ax.set_xticklabels(self.indNames, rotation=90, fontsize=self.label_fontsize) self.ax.set_yticklabels(self.indNames, fontsize=self.label_fontsize) prop_sites = self.DInumer / self.DIdenom if self.DIdenom > 0 else 0.0 self.ax.set_title( "Pairwise distances DI ≥ {:.2f} {} SNVs ({:.1f}% divergent across barrier)" .format(self.current_DI, self.DInumer, 100 * prop_sites), pad=10, ) def _color_for_val(self, val, norm): if not np.isfinite(val): return "black" return self.cmap(norm(val)) # ================================================= # Main bricks: create once, update facecolors later # ================================================= def _init_bricks(self): # Remove only previously created bricks (do NOT use self.ax.patches.clear()) if hasattr(self, "_bricks"): for r in self._bricks: try: r.remove() except Exception: pass norm = plt.Normalize(self.vmin, self.vmax) self._norm = norm self._bricks = [] # IMPORTANT: preserve your original orientation: val = M[j, i] drawn at (i, j) for i in range(self.n): for j in range(self.n): val = self.M[j, i] rect = Rectangle( (i, j), 1, 1, facecolor=self._color_for_val(val, norm), edgecolor="none" ) self.ax.add_patch(rect) self._bricks.append(rect) self.fig.canvas.draw_idle() def _update_bricks(self): # update normalisation if needed norm = plt.Normalize(self.vmin, self.vmax) self._norm = norm # update all facecolors k = 0 for i in range(self.n): for j in range(self.n): val = self.M[j, i] self._bricks[k].set_facecolor(self._color_for_val(val, norm)) k += 1 # update title #self.ax.set_title(f"Pairwise distances (DI ≥ {self.current_DI:.2f})", pad=10) prop_sites = self.DInumer / self.DIdenom if self.DIdenom > 0 else 0.0 self.ax.set_title( "Pairwise distances DI ≥ {:.2f} {} SNVs ({:.1f}% divergent across barrier)" .format(self.current_DI, self.DInumer, 100 * prop_sites), pad=10, ) self.fig.canvas.draw_idle() # ================================================= # Colour key: bricks, not imshow # ================================================= def _init_colour_key(self): self.cax.clear() # key as a vertical stack of small bricks self._key_bins = 256 self._key_rects = [] # axis coordinates: # x in [0,1], y in [0, key_bins] self.cax.set_xlim(0, 1) self.cax.set_ylim(0, self._key_bins) norm = plt.Normalize(self.vmin, self.vmax) for y in range(self._key_bins): # map y to value in [vmin, vmax] frac = y / (self._key_bins - 1) val = self.vmin + frac * (self.vmax - self.vmin) rect = Rectangle( (0, y), 1, 1, facecolor=self.cmap(norm(val)), edgecolor="none" ) self.cax.add_patch(rect) self._key_rects.append(rect) # ticks + labels self.cax.set_xticks([]) self.cax.set_yticks([0, self._key_bins - 1]) self.cax.set_yticklabels([f"{self.vmin:.2f}", f"{self.vmax:.2f}"], fontsize=8) self.cax.set_title("Distance", fontsize=9) self.fig.canvas.draw_idle() def _update_colour_key(self): # recolour bricks according to updated vmin/vmax norm = plt.Normalize(self.vmin, self.vmax) for y in range(self._key_bins): frac = y / (self._key_bins - 1) val = self.vmin + frac * (self.vmax - self.vmin) self._key_rects[y].set_facecolor(self.cmap(norm(val))) self.cax.set_yticklabels([f"{self.vmin:.2f}", f"{self.vmax:.2f}"], fontsize=8) self.fig.canvas.draw_idle() # ================================================= # Widgets # ================================================= def _init_widgets(self): # DI slider ax_DI = self.fig.add_axes([0.18, 0.10, 0.52, 0.03]) self.DI_slider = Slider(ax_DI, "DI", self.di_min, self.di_max, valinit=self.di_min) self.DI_slider.on_changed(self._on_DI_change) # reorder button ax_RE = self.fig.add_axes([0.82, 0.10, 0.14, 0.035]) self.reorder_button = Button(ax_RE, "Reorder by HI", hovercolor="0.95", color="red") self.reorder_button.on_clicked(self._on_reorder) # font slider: beneath colour key (your original style) pos = self.cax.get_position() ax_FS = self.fig.add_axes([pos.x0, 0.04, pos.width, 0.03]) self.font_slider = Slider(ax_FS, "Labels", 0, 8, valinit=self.label_fontsize, valstep=1) self.font_slider.on_changed(self._on_fontsize_change) # ================================================= # Callbacks # ================================================= def _on_DI_change(self, val): # If cached, snap to nearest grid DI if self._inc_pairwise is not None or self._inc_statewise is not None: di_eff = self._nearest_grid_value(val) else: di_eff = float(val) # IMPORTANT: keep current ordering during slider motion self._compute_for_DI(di_eff, force_reorder=False) # update bricks + key self._update_bricks() self._update_colour_key() def _on_reorder(self, event=None): # recompute ordering at current DI, then redraw labels + bricks self._compute_for_DI(self.current_DI, force_reorder=True) # update axis tick labels (same ticks, new labels) centers = np.arange(self.n) + 0.5 self.ax.set_xticks(centers) self.ax.set_yticks(centers) self.ax.set_xticklabels(self.indNames, rotation=90, fontsize=self.label_fontsize) self.ax.set_yticklabels(self.indNames, fontsize=self.label_fontsize) self._update_bricks() self._update_colour_key() def _on_fontsize_change(self, val): fs = int(val) self.label_fontsize = fs self.ax.set_xticklabels(self.indNames, rotation=90, fontsize=fs) self.ax.set_yticklabels(self.indNames, fontsize=fs) self.fig.canvas.draw_idle() # ================================================= # Hover # ================================================= def _install_format_coord(self): n = self.n def format_coord(x, y): fallback = " " * 40 i = int(np.floor(x)) j = int(np.floor(y)) if 0 <= i < n and 0 <= j < n: a = self.indNames[j] b = self.indNames[i] d = self.M[j, i] if np.isfinite(d): return f"{a} × {b} : {d:.3f}" return f"{a} × {b} : NA" return fallback self.ax.format_coord = format_coord
"""________________________________________ END diemPairsPlot ___________________""" """________________________________________ START DiemMultiPairsPlot__________________"""
[docs] class diemMultiPairsPlot: """ Multi-chromosome version of diemPairsPlot. One brick heatmap per chromosome, ordered by global Hybrid Index, arranged in a grid. The top-right grid cell contains the shared colour key. Widgets: - DI slider (updates matrices, keeps current order) - Reorder by HI button (recomputes order at current DI) - Label font size slider (all subplots) Optional: - prefill caching (pairwise per chromosome + optional HI/statewise) """ def __init__( self, dPol, chrom_indices, DIthreshold=float("-inf"), max_cols=3, figsize=(12, 8), # caching options prefill_cache=True, prefill_step=None, cache_tol=None, progress=None, # None | "text" cache_statewise_for_HI=True, # True recommended if using cache # layout tweak row_hspace=0.60, col_wspace=0.35, ): self.dPol = dPol self.chrom_indices = [int(i) for i in chrom_indices] self.chrom_indices = self._validate_chrom_indices(chrom_indices) self.ChrNickNames = [Chr_Nickname(name) for name in dPol.chrNames] self.current_DI = float(DIthreshold) # ---------- DI span / grid ---------- DI_span = get_DI_span(self.dPol) self.di_min, self.di_max = float(DI_span[0]), float(DI_span[1]) if prefill_step is None: span = self.di_max - self.di_min prefill_step = span / 200.0 if span > 0 else 1.0 if cache_tol is None: cache_tol = float(prefill_step) / 2.0 self._cache_tol = float(cache_tol) self.di_grid = self._build_di_grid(self.di_min, self.di_max, float(prefill_step)) # ---------- optional caches ---------- self._inc_statewise = None self._inc_pairwise_by_chr = {} # chr_idx -> PairwiseDIIncrementalCache if prefill_cache: for chr_idx in self.chrom_indices: self._inc_pairwise_by_chr[chr_idx] = PairwiseDIIncrementalCache( self.dPol, di_grid=self.di_grid, chrom_indices=[chr_idx], progress=("text" if progress == "text" else None), label=f"MultiPairsPlot chr{chr_idx} matrix cache prefill", snapshot_dtype=np.float32, ).prefill() if cache_statewise_for_HI: self._inc_statewise = StatewiseDIIncrementalCache( self.dPol, di_grid=self.di_grid, progress=("text" if progress == "text" else None), label="MultiPairsPlot HI cache prefill", ).prefill() # ---------- initial ordering by HI at current DI ---------- self.ind_order = None self._compute_order_for_DI(self.current_DI) self.indNames = np.array(self.dPol.indNames)[self.ind_order] self.n = len(self.indNames) # ---------- initial matrices at current DI ---------- self.Ms = [] self._compute_matrices_for_DI(self.current_DI, keep_order=True) # ---------- figure / grid layout ---------- n_plots = len(self.chrom_indices) n_cols = min(max_cols, n_plots) n_rows = int(np.ceil(n_plots / n_cols)) self.fig = plt.figure(figsize=figsize) gs = self.fig.add_gridspec( n_rows, n_cols + 1, width_ratios=[20] * n_cols + [1.6], hspace=float(row_hspace), wspace=float(col_wspace), ) self.axes = [] for r in range(n_rows): for c in range(n_cols): idx = r * n_cols + c if idx < n_plots: self.axes.append(self.fig.add_subplot(gs[r, c])) self.cax = self.fig.add_subplot(gs[0, -1]) # colour key # room at bottom for widgets self.fig.subplots_adjust(bottom=0.20) # ---------- colormap ---------- self.cmap = LinearSegmentedColormap.from_list( "soft_coolwarm", ["#1e90ff", "white", "#fff266", "#ff1a1a"], ) # ---------- style state ---------- self.fontsize = 4 # ---------- draw once (persistent patches) ---------- self._setup_axes_all() self._init_bricks_all() self._init_colour_key() # ---------- widgets ---------- self._init_widgets() # ---------- hover ---------- self._install_format_coord() plt.show() # ================================================= # DI grid helpers # ================================================= @staticmethod def _build_di_grid(di_min, di_max, step): if di_max > di_min: n_steps = int(np.floor((di_max - di_min) / step)) vals = di_min + step * np.arange(n_steps + 1, dtype=float) vals = np.clip(vals, di_min, di_max) vals[-1] = di_max else: vals = np.array([di_min], dtype=float) return np.asarray(vals, dtype=float) def _nearest_grid_value(self, di): di = float(di) j = int(np.argmin(np.abs(self.di_grid - di))) return float(self.di_grid[j]) # ================================================= # Computation # ================================================= def _get_HI_and_retained_for_DI(self, DIthreshold): if self._inc_statewise is not None: di_key = self._nearest_grid_value(DIthreshold) chrom_counts, chrom_retained = self._inc_statewise._snapshots[di_key] HI = summaries_from_statewise_counts(chrom_counts)[0] else: chrom_counts, chrom_retained = statewise_genomes_summary_given_DI(self.dPol, float(DIthreshold)) HI = summaries_from_statewise_counts(chrom_counts)[0] DInumer = sum(n for (n, _) in chrom_retained) DIdenom = sum(d for (_, d) in chrom_retained) return HI, int(DInumer), int(DIdenom) def _compute_order_for_DI(self, DIthreshold): HI, DInumer, DIdenom = self._get_HI_and_retained_for_DI(DIthreshold) self.ind_order = np.argsort(HI) self.DInumer = DInumer self.DIdenom = DIdenom def _get_M_chr_for_DI(self, chr_idx, DIthreshold): chr_idx = int(chr_idx) if chr_idx in self._inc_pairwise_by_chr: return self._inc_pairwise_by_chr[chr_idx].get(DIthreshold) return PARApwmatrixFromDiemType( self.dPol, DIthreshold=float(DIthreshold), chrom_indices=[chr_idx], ) def _compute_matrices_for_DI(self, DIthreshold, *, keep_order=True): """ Recompute per-chromosome matrices at DIthreshold. If keep_order=True, use existing self.ind_order. """ self.current_DI = float(DIthreshold) if not keep_order or self.ind_order is None: self._compute_order_for_DI(self.current_DI) self.indNames = np.array(self.dPol.indNames)[self.ind_order] self.n = len(self.indNames) Ms = [] vmins, vmaxs = [], [] for chr_idx in self.chrom_indices: M = self._get_M_chr_for_DI(chr_idx, self.current_DI) M = M[self.ind_order][:, self.ind_order] Ms.append(M) finite = np.isfinite(M) if np.any(finite): vmins.append(np.nanmin(M)) vmaxs.append(np.nanmax(M)) self.Ms = Ms self.vmin = min(vmins) if vmins else 0.0 self.vmax = max(vmaxs) if vmaxs else 1.0 # ================================================= # Drawing setup # ================================================= def _setup_axes_all(self): centers = np.arange(self.n) + 0.5 for ax, chr_idx in zip(self.axes, self.chrom_indices): ax.set_xlim(0, self.n) ax.set_ylim(0, self.n) ax.set_aspect("equal") ax.set_xticks(centers) ax.set_yticks(centers) ax.set_xticklabels(self.indNames, rotation=90, fontsize=self.fontsize) ax.set_yticklabels(self.indNames, fontsize=self.fontsize) ax.set_title(self.ChrNickNames[int(chr_idx)], fontsize=10, pad=8) prop_sites = self.DInumer / self.DIdenom if self.DIdenom > 0 else 0.0 self.fig.suptitle( "Pairwise distances DI ≥ {:.2f} {} SNVs ({:.1f}% divergent across barrier)" .format(self.current_DI, self.DInumer, 100 * prop_sites), y=0.98 ) def _color_for_val(self, val, norm): if not np.isfinite(val): return "black" return self.cmap(norm(val)) # ================================================= # Bricks: create once, update facecolors later # ================================================= def _init_bricks_all(self): self._bricks_by_ax = [] # list aligned with self.axes / self.Ms norm = plt.Normalize(self.vmin, self.vmax) self._norm = norm for ax, M in zip(self.axes, self.Ms): bricks = [] for i in range(self.n): for j in range(self.n): val = M[j, i] rect = Rectangle( (i, j), 1, 1, facecolor=self._color_for_val(val, norm), edgecolor="none", ) ax.add_patch(rect) bricks.append(rect) self._bricks_by_ax.append(bricks) self.fig.canvas.draw_idle() def _update_bricks_all(self): norm = plt.Normalize(self.vmin, self.vmax) self._norm = norm for bricks, M, ax in zip(self._bricks_by_ax, self.Ms, self.axes): k = 0 for i in range(self.n): for j in range(self.n): val = M[j, i] bricks[k].set_facecolor(self._color_for_val(val, norm)) k += 1 prop_sites = self.DInumer / self.DIdenom if self.DIdenom > 0 else 0.0 self.fig.suptitle( "Pairwise distances DI ≥ {:.2f} {} SNVs ({:.1f}% divergent across barrier)" .format(self.current_DI, self.DInumer, 100 * prop_sites), y=0.98 ) self.fig.canvas.draw_idle() # ================================================= # Colour key # ================================================= def _init_colour_key(self): self.cax.clear() self._key_bins = 256 gradient = np.linspace(self.vmin, self.vmax, self._key_bins).reshape(-1, 1) self._key_im = self.cax.imshow( gradient, aspect="auto", cmap=self.cmap, origin="lower", ) self.cax.set_xticks([]) self.cax.set_yticks([0, self._key_bins - 1]) self.cax.set_yticklabels([f"{self.vmin:.2f}", f"{self.vmax:.2f}"], fontsize=8) self.cax.set_title("Distance", fontsize=9) self.fig.canvas.draw_idle() def _update_colour_key(self): gradient = np.linspace(self.vmin, self.vmax, self._key_bins).reshape(-1, 1) self._key_im.set_data(gradient) self._key_im.set_clim(self.vmin, self.vmax) self.cax.set_yticklabels([f"{self.vmin:.2f}", f"{self.vmax:.2f}"], fontsize=8) self.fig.canvas.draw_idle() # ================================================= # Widgets # ================================================= def _init_widgets(self): # DI slider ax_DI = self.fig.add_axes([0.15, 0.10, 0.60, 0.03]) self.DI_slider = Slider(ax_DI, "DI", self.di_min, self.di_max, valinit=self.di_min) self.DI_slider.on_changed(self._on_DI_change) # Reorder button ax_RE = self.fig.add_axes([0.88, 0.10, 0.08, 0.035]) self.reorder_button = Button(ax_RE, "Reorder by HI", hovercolor="0.95", color="red") self.reorder_button.on_clicked(self._on_reorder) # Font slider under colour key (your original location) pos = self.cax.get_position() ax_FS = self.fig.add_axes([pos.x0, 0.04, pos.width*2, 0.03]) self.font_slider = Slider(ax_FS, "Labels", 0, 8, valinit=self.fontsize, valstep=1) self.font_slider.on_changed(self._on_font_change) # ================================================= # Callbacks # ================================================= def _on_DI_change(self, val): # If cached, snap to nearest grid DI if self._inc_pairwise_by_chr or (self._inc_statewise is not None): di_eff = self._nearest_grid_value(val) else: di_eff = float(val) # update retained/total counts for title (no reorder) _, self.DInumer, self.DIdenom = self._get_HI_and_retained_for_DI(di_eff) # keep current ordering during slider motion self._compute_matrices_for_DI(di_eff, keep_order=True) # update bricks + key self._update_bricks_all() self._update_colour_key() def _on_reorder(self, event=None): # recompute ordering at current DI, and recompute matrices with new order self._compute_matrices_for_DI(self.current_DI, keep_order=False) # update tick labels everywhere centers = np.arange(self.n) + 0.5 for ax, chr_idx in zip(self.axes, self.chrom_indices): ax.set_xticks(centers) ax.set_yticks(centers) ax.set_xticklabels(self.indNames, rotation=90, fontsize=self.fontsize) ax.set_yticklabels(self.indNames, fontsize=self.fontsize) ax.set_title(self.ChrNickNames[int(chr_idx)], fontsize=10, pad=8) self._update_bricks_all() self._update_colour_key() def _on_font_change(self, val): self.fontsize = int(val) for ax in self.axes: ax.set_xticklabels(self.indNames, rotation=90, fontsize=self.fontsize) ax.set_yticklabels(self.indNames, fontsize=self.fontsize) self.fig.canvas.draw_idle() # ================================================== # Validation # ================================================== def _validate_chrom_indices(self, chrom_indices): max_idx = len(self.dPol.chrLengths) - 1 valid, rejected = [], [] for idx in chrom_indices: if isinstance(idx, (int, np.integer)) and 0 <= int(idx) <= max_idx: valid.append(int(idx)) else: rejected.append(idx) if rejected: print("diemMultiPairsPlot: rejected chromosome indices:", rejected) if not valid: raise ValueError("diemMultiPairsPlot: no valid chromosome indices.") return valid # ================================================= # Hover logic # ================================================= def _install_format_coord(self): # Formatter that always reads current matrices (no stale closures) n = self.n ax_to_idx = {ax: k for k, ax in enumerate(self.axes)} def make_formatter(ax): def format_coord(x, y): fallback = " " * 40 i = int(np.floor(x)) j = int(np.floor(y)) if 0 <= i < self.n and 0 <= j < self.n: k = ax_to_idx.get(ax, None) if k is None: return fallback M = self.Ms[k] a = self.indNames[j] b = self.indNames[i] d = M[j, i] if np.isfinite(d): return f"{a} × {b} : {d:.3f}" return f"{a} × {b} : NA" return fallback return format_coord for ax in self.axes: ax.format_coord = make_formatter(ax)
"""________________________________________ END DiemMultiPairsPlot __________________""" """________________________________________ START DiemPlotPrep__________________""" class DiemPlotPrep: """ Prepares data for DI-based plotting, including filtering, smoothing, dithering, and label generation. Args: plot_theme: Theme for plotting. ind_ids: List of individual IDs. chrRefLengths: Dictionary of chromosome reference lengths. Perhaps polarised_data: DataFrame containing polarised genomic data. di_threshold: DI threshold for filtering. di_column: Column name for DI values. diemStringPyCol: Column name for Diem genotype strings. genome_pixels: Number of genome pixels for dithering. ticks: Optional ticks for plotting. smooth: Optional smoothing parameter. """ def __init__(self, plot_theme, ind_ids, chrRefLengths, polarised_data, di_threshold, di_column, diemStringPyCol, genome_pixels, ticks=None, chrRelativeRecRates=None, smooth=None): self.polarised_data = polarised_data self.di_threshold = di_threshold self.di_column = di_column self.diemStringPyCol = diemStringPyCol self.genome_pixels = genome_pixels self.plot_theme = plot_theme self.ind_ids = ind_ids self.chrRefLengths = chrRefLengths self.chrRelativeRecRates=chrRelativeRecRates self.ticks = ticks self.smooth = smooth self.diemPlotLabel = None self.DIfilteredDATA = None self.DIfilteredGenomes = None self.DIfilteredHIs = None self.DIfilteredBED = None self.DIpercent = None self.DIfilteredScafRLEs = None self.diemDITgenomes = None self.DIfilteredGenomes_unsmoothed = None self.DIfilteredBED_formatted = None self.IndIDs_ordered = None self.unit_plot_prep = [] self.plot_ordered = None self.length_of_chromosomes = {} self.iris_plot_prep = {} self.diemDITgenomes_ordered = None self.nBasesDithered = None self.chrom_keys = None self.MapBC = None # primitive map positions of each SNV self.diem_plot_prep() def diem_plot_prep(self): """ Perform DI filtering, dithering, and label generation """ self.filter_data() if self.smooth: self.initMapcoords() # before smoothing self.kernel_smooth(self.smooth) self.diem_dithering() self.generate_plot_label(self.plot_theme) self.format_bed_data() def initMapcoords(self): """Initialize map coordinates based on filtered BED data""" if self.DIfilteredBED_formatted is None: self.MapBC = None return self.MapBC = [] for i in range(len(self.chrom_keys)): bed_positions = self.DIfilteredBED_formatted[i] ref_len = self.chrRefLengths[i] # normalize to chromosome-relative coordinate [0, 1] self.MapBC.append(bed_positions / ref_len) # format_bed_data reworked by ChatGPT 5.2 for speed, but mostly clarity def format_bed_data(self): # ------------------------------------------------- # 1. Compute ordering by HI (vectorised + clearer) # ------------------------------------------------- HI_values = np.array([ float(b[0]) if b[0] is not None else np.nan for b in self.DIfilteredHIs ]) # stable sort: NaNs last sorted_indices = np.argsort( np.isnan(HI_values), kind="stable" ) sorted_indices = sorted_indices[np.argsort(HI_values[sorted_indices], kind="stable")] self.plot_ordered = list(zip(HI_values[sorted_indices], sorted_indices + 1)) self.IndIDs_ordered = [self.ind_ids[i] for i in sorted_indices] self.diemDITgenomes_ordered = [self.diemDITgenomes[i] for i in sorted_indices] # ------------------------------------------------- # 2. Prepare unit_plot_prep (slice once, reuse) # ------------------------------------------------- self.unit_plot_prep = [] start = 0 for bed_data in self.DIfilteredBED_formatted: end = start + len(bed_data) sublist = [genome[start:end] for genome in self.DIfilteredGenomes] self.unit_plot_prep.append([sublist[idx] for idx in sorted_indices]) start = end def filter_data(self): """ Apply DI threshold filtering on the data """ if isinstance(self.di_threshold, str): # No filtering if threshold is a string self.DIfilteredDATA = self.polarised_data elif isinstance(self.di_threshold, int) or isinstance(self.di_threshold, float): # Filter above if threshold is just one number self.DIfilteredDATA = self.polarised_data[self.polarised_data.DI >= self.di_threshold] else: # Filter within an interval if threshold is a tuple or list self.DIfilteredDATA = self.polarised_data[(self.di_threshold[0] <= self.polarised_data.DI) & (self.polarised_data.DI <= self.di_threshold[1])] # Extract relevant data after filtering self.DIfilteredGenomes = StringTranspose(self.DIfilteredDATA['diem_genotype'])[1:] # slice off the 'S' column self.DIfilteredHIs = [pHetErrOnString(genome) for genome in self.DIfilteredGenomes] self.DIfilteredBED = self.DIfilteredDATA[['chrom','start']].values.tolist() self.DIpercent = round(100 * len(self.DIfilteredDATA) / len(self.polarised_data)) self.DIfilteredScafRLEs = RichRLE(self.DIfilteredDATA['chrom'].values.tolist()) # the following WAS the first operation of dithering section! # now here, so map of refpos can be set before smoothing... # ------------------------------------------------- # 1. Group BED entries by chromosome (single pass) # ------------------------------------------------- grouped = defaultdict(list) for key, value in self.DIfilteredBED: grouped[key].append(value) self.chrom_keys = list(grouped.keys()) self.DIfilteredBED_formatted = [ np.asarray(grouped[k]) for k in self.chrom_keys ] def kernel_smooth(self, scale): # ChatGPT drop-in # -------------------------------------------------- # 1. Precompute scaffold → indices # -------------------------------------------------- scaffold_indices = defaultdict(list) for idx, (scaffold, _) in enumerate(self.DIfilteredBED): scaffold_indices[scaffold].append(idx) # -------------------------------------------------- # 2. Precompute scaffold → positions array # Use MapBC if available (chromosome-relative metric) # -------------------------------------------------- if self.MapBC is not None: scaffold_arrays = { chrom: self.MapBC[i] for i, chrom in enumerate(self.chrom_keys) } else: scaffold_positions = defaultdict(list) for scaffold, pos in self.DIfilteredBED: scaffold_positions[scaffold].append(pos) scaffold_arrays = { scaffold: np.asarray(positions) for scaffold, positions in scaffold_positions.items() } # -------------------------------------------------- # 3. Split genomes by scaffold (numeric form) # -------------------------------------------------- # scaffold_haplotypes[scaffold] = list of np.arrays (one per individual) scaffold_haplotypes = { scaffold: [] for scaffold in scaffold_indices } for genome in self.DIfilteredGenomes: for scaffold, indices in scaffold_indices.items(): # extract once, convert once s = ''.join(genome[i] for i in indices) s = s.replace("U", "3") # DMBCtoucher {U,0,1,2} → {0,1,2,3} s = s.replace("_", "3") # DMBCtoucher {U,0,1,2} → {0,1,2,3} scaffold_haplotypes[scaffold].append( np.fromiter((ord(c) - 48 for c in s), dtype=np.int8) ) # -------------------------------------------------- # 4. Smooth ALL haplotypes per scaffold (key speedup) # -------------------------------------------------- smoothed_scaffold_haplotypes = {} for chr_i, scaffold in enumerate(self.chrom_keys): haplos = scaffold_haplotypes[scaffold] haplo_matrix = np.vstack(haplos) rec_rate = self.chrRelativeRecRates[chr_i] if rec_rate <= 0: raise ValueError( f"Invalid recombination rate for {scaffold}: {rec_rate}" ) effective_scale = scale / rec_rate smoothed = smooth.laplace_smooth_multiple_haplotypes( scaffold_arrays[scaffold], haplo_matrix, effective_scale ) smoothed_scaffold_haplotypes[scaffold] = smoothed # -------------------------------------------------- # 5. Reassemble genomes (string form) # -------------------------------------------------- n_individuals = len(self.DIfilteredGenomes) smoothed_split_genomes = [ {} for _ in range(n_individuals) ] for scaffold, smoothed_matrix in smoothed_scaffold_haplotypes.items(): for i in range(n_individuals): arr = smoothed_matrix[i] chars = np.where(arr == 3, "_", arr.astype(str)) # # DMBCtoucher {U,0,1,2} → {0,1,2,3} smoothed_split_genomes[i][scaffold] = ''.join(chars.tolist()) # -------------------------------------------------- # 6. Finalise # -------------------------------------------------- self.DIfilteredGenomes_unsmoothed = self.DIfilteredGenomes self.DIfilteredGenomes = self._reconstruct_genomes( smoothed_split_genomes, scaffold_indices ) def _reconstruct_genomes(self, smoothed_split_genomes, scaffold_indices): reconstructed_genomes = [] for individual in smoothed_split_genomes: full_genome = ['0'] * len(self.DIfilteredBED) for scaffold, indices in scaffold_indices.items(): scaffold_str = individual[scaffold] for i, idx in enumerate(indices): full_genome[idx] = scaffold_str[i] reconstructed_genome = ''.join(full_genome) reconstructed_genomes.append(reconstructed_genome) return reconstructed_genomes def diem_dithering(self): # ------------------------------------------------- # 2. Precompute chromosome spans # ------------------------------------------------- self.length_of_chromosomes = {} start = 0 for key, bed_data in zip(self.chrom_keys, self.DIfilteredBED_formatted): end = start + len(bed_data) self.length_of_chromosomes[key] = (start, end, len(bed_data)) start = end # ------------------------------------------------- # 3. Prepare iris_plot_prep ticks (vectorised shift) # ------------------------------------------------- for idx, (key, bed) in enumerate(zip(self.chrom_keys, self.DIfilteredBED_formatted), start=1): x_ticks = fractional_positions_of_multiples(bed, self.ticks) offset = self.length_of_chromosomes[key][0] x_ticks[:, 1] += offset self.iris_plot_prep[idx] = x_ticks # ------------------------------------------------- # 4. Calculate nBasesDithered SJEB 24 Jan 2026 # ------------------------------------------------- ringSpanInBases = 0; for chrRefPoses in self.DIfilteredBED_formatted: ringSpanInBases = ringSpanInBases + chrRefPoses[-1] - chrRefPoses[0] + 1 # Input argument 'genome_pixels' is number of dithering 'pixels' along genome (pixels may be curved for iris plots) # Here, GappedQuotientSplitLengths takes the number of bases that should be dithered together. self.nBasesDithered = max(1,round(ringSpanInBases/self.genome_pixels)) # ------------------------------------------------- # 5. Perform dithering on the filtered data give nBasesDithered # ------------------------------------------------- diem_dit_genomes_bed = [list(group) for _, group in groupby(self.DIfilteredBED, key=lambda x: x[0])] processed_diemDITgenomes = [] for chr in diem_dit_genomes_bed: length_data = [row[1] for row in chr] split_lengths = self.GappedQuotientSplitLengths(length_data, self.nBasesDithered)# nBasesDithered was self.genome_pixels SJEB 24 Jan 2026 processed_diemDITgenomes.append(split_lengths) #processed_diemDITgenomes = Flatten(processed_diemDITgenomes) # IMPORTANT: EURG # processed_diemDITgenomes is now: # List[chromosome][segment_length] diemDITgenomes = [] # IMPORTANT: need chromosome order consistent with how processed_diemDITgenomes was built potential BUG chrom_keys = list(self.length_of_chromosomes.keys()) for genome in self.DIfilteredGenomes: per_chr = [] # iterate chromosomes in the SAME order as processed_diemDITgenomes for chr_i, chr_lengths in enumerate(processed_diemDITgenomes): chrom = chrom_keys[chr_i] g0, g1, _ = self.length_of_chromosomes[chrom] # global [start,end) slice in concatenated genome genome_chr = genome[g0:g1] # chromosome-local genome string # now take segments from the chromosome-local string string_take_result = StringTakeList(genome_chr, chr_lengths) state_count = Map(sStateCount, string_take_result) combined = list(zip(state_count, chr_lengths)) compressed = self.DITcompress(combined) per_chr.append(self.Lengths2StartEnds(compressed)) diemDITgenomes.append(per_chr) self.diemDITgenomes = diemDITgenomes def generate_plot_label(self, plot_theme): """ Generate the label for the plot """ self.diemPlotLabel = f"{plot_theme} @ DI = {self.di_threshold}: {len(self.DIfilteredDATA)} sites ({self.DIpercent}%) {self.nBasesDithered} bases dithered." @staticmethod def GappedQuotientSplit(lst, Q): """ Splits the list `lst` into sublists where consecutive elements share the same quotient when divided by `Q`. """ quotients = [x // Q for x in lst] groups = [] current_group = [lst[0]] for i in range(1, len(lst)): if quotients[i] == quotients[i - 1]: current_group.append(lst[i]) else: groups.append(current_group) current_group = [lst[i]] groups.append(current_group) return groups def GappedQuotientSplitLengths(self, lst, Q): """ Returns the lengths of the sublists produced by `gapped_quotient_split`. """ return Map(len, self.GappedQuotientSplit(lst, Q)) @staticmethod def normalize_4list(lst): """ Normalizes a 4list by converting each element to its ratio of the total sum. Uses Fraction for precise comparison without floating-point errors. """ total = sum(lst) if total == 0: return tuple(0 for _ in lst) # Handle case where total is 0 return tuple(Fraction(x, total) for x in lst) def DITcompress(self, DITl): """ Compresses the list of {4list, length} tuples. """ grouped_data = [list(group) for _, group in groupby(DITl, key=lambda x: self.normalize_4list(x[0]))] final_data = [] for group in grouped_data: summed_states = [sum(x) for x in zip(*(item[0] for item in group))] summed_value = sum(item[1] for item in group) result = (summed_states, summed_value) final_data.append(result) return final_data @staticmethod def Lengths2StartEnds(stateNlen): lengths = [x[1] for x in stateNlen] ends = np.cumsum(lengths) # Calculate the start positions (end positions minus length plus 1) starts = ends - np.array(lengths) + 1 # Combine states, starts, and ends into a list of triplets result = [(state, int(start), int(end)) for (state, start, end) in zip([x[0] for x in stateNlen], starts, ends)] return result def flatten_ring_with_offsets(per_chr_ring, length_of_chromosomes): """ Convert per-chromosome ring representation into a single global-coordinate ring suitable for IrisPlot / LongPlot. per_chr_ring: [ [(w,s,e), ...], # chromosome 0 (local coords) [(w,s,e), ...], # chromosome 1 ... ] length_of_chromosomes: dict preserving chromosome order: chrom -> (start, end, length) """ flat = [] chrom_keys = list(length_of_chromosomes.keys()) for chr_idx, chr_segments in enumerate(per_chr_ring): chrom = chrom_keys[chr_idx] chrom_start = length_of_chromosomes[chrom][0] for weights, s, e in chr_segments: flat.append(( weights, chrom_start + s - 1, # preserve your 1→0 convention chrom_start + e )) return flat def prefill_slider_cache( *, cache, # PlotCache instance namespace: str, basekey, di_values: Iterable[float], compute_fn: Callable[[float], object], # returns payload to cache tol: Optional[float] = None, # optional: skip if already cached "nearby" progress: str = "text", # "text", "tqdm", or "none" label: str = "Prefill cache", ): di_values = list(di_values) n = len(di_values) if n == 0: return use_tqdm = (progress == "tqdm") pbar = None if use_tqdm: try: from tqdm.auto import tqdm pbar = tqdm(total=n, desc=label) except Exception: use_tqdm = False # fall back silently t0 = time.time() done = 0 for i, di in enumerate(di_values, start=1): # optional "near" skip if tol is not None: hit = cache.get_nearest_float_key( namespace=namespace, basekey=basekey, x=float(di), tol=float(tol), ) if hit is not None: done += 1 if use_tqdm: pbar.update(1) continue payload = compute_fn(float(di)) cache.set_float_key( namespace=namespace, basekey=basekey, x=float(di), value=payload, ) done += 1 if use_tqdm: pbar.update(1) elif progress == "text": # lightweight text progress every ~5% if i == 1 or i == n or (i % max(1, n // 20) == 0): dt = time.time() - t0 print(f"{label}: {i}/{n} ({100*i/n:.0f}%) elapsed {dt:.1f}s") if use_tqdm and pbar is not None: pbar.close() def diemPlotPrepFromBedMeta(plot_theme, bed_file_path, meta_file_path,di_threshold,genome_pixels,ticks, smooth = None): pzbed, bmIndIDs, chrRefLengths, chrRelativeRecRates = read_diem_bed_4_plots(bed_file_path, meta_file_path) prep = DiemPlotPrep( plot_theme=plot_theme, ind_ids=bmIndIDs, chrRefLengths=chrRefLengths, polarised_data=pzbed, di_threshold=di_threshold, diemStringPyCol=10, di_column=13, genome_pixels=genome_pixels, ticks=ticks, chrRelativeRecRates=chrRelativeRecRates, smooth=smooth ) return prep """________________________________________ END DiemPlotPrep ___________________""" """________________________________________ START DiemIris ___________________""" class WheelDiagram: """ Utility class for creating wheel diagrams (iris plots). Args: subplot: Matplotlib subplot to draw on. center: Center coordinates of the wheel. radius: Outer radius of the wheel. number_of_rings: Number of concentric rings in the wheel. cutout_angle: Angle of the cutout section (default is 13 degrees). """ def __init__(self, subplot, center, radius, number_of_rings, cutout_angle=13): self.subplot = subplot self.center = center self.radius = radius self.center_radius = radius / 2 self.number_of_rings = number_of_rings self.cutout_angle = cutout_angle self.rings_added = 0 def add_wedge(self, radius, from_angle, to_angle, color): self.subplot.add_artist( Wedge(self.center, radius, from_angle, to_angle, color=color, clip_on=False) # SJEB this last avoids a world of pain ) def add_ring(self, list_of_thingies): # print(f'Adding ring: {self.rings_added + 1}') available_angle = 360 - self.cutout_angle angle_scale = available_angle / list_of_thingies[-1][-1] colors = np.array(Map(mcolors.to_rgb,diemColours)) ring_radius = self.radius - self.rings_added * (self.radius - self.center_radius) / self.number_of_rings start_angle_offset = 90 for index, thing in enumerate(list_of_thingies): weights = np.array(thing[0]) total_weight = np.sum(weights) if total_weight == 0: blended_rgb = (0, 0, 0) else: blended_rgb = np.sum(colors.T * weights, axis=1) / total_weight blended_hex = mcolors.to_hex(blended_rgb) from_angle = start_angle_offset + 360 - (angle_scale * (thing[1] - 1)) to_angle = start_angle_offset + 360 - (angle_scale * thing[2]) self.add_wedge(ring_radius, to_angle,from_angle, blended_hex) self.rings_added += 1 def add_heatmap_ring(self, heatmap): # needs work. This version is specific to Honza's MolEcol figures. available_angle = 360 - self.cutout_angle angle_scale = available_angle / int(heatmap[-1][-1]) keys = ["barr", "int", "ovm"] values = ["Red", "Blue", "Yellow"] color_map = dict(zip(keys, values)) ring_radius = self.radius + 2 * (self.radius - self.center_radius) / self.number_of_rings start_angle_offset = 90 for index, thing in enumerate(heatmap): from_angle = start_angle_offset + 360 - (angle_scale * (int(thing[1]) - 1)) to_angle = start_angle_offset + 360 - (angle_scale * int(thing[2])) self.add_wedge(ring_radius, to_angle, from_angle, color_map[thing[0]]) def clear_center(self): self.add_wedge(self.center_radius, 0, 360, "white") """________________________________________ END WheelDiagram ___________________""" """________________________________________ START diemIrisPlot ___________________""" """________________________________________ START Iris and Long helper function ___________________""" def _restrict_chromosomes( *, input_data, refposes, length_of_chromosomes, bed_info=None, chrom_indices=None, ): """ Restrict per-chromosome genome data to a subset of chromosomes and, if requested, pack them contiguously into a genome-global coordinate system. Canonical assumptions: - input_data[ind][chr] = list of (weights, start, end), chromosome-local - start/end are 1-based and inclusive (as produced by PlotPrep) - NO flattening occurs here """ # ------------------------------------------------- # Trivial case: no restriction # ------------------------------------------------- if chrom_indices is None: return input_data, refposes, length_of_chromosomes, bed_info if length_of_chromosomes is None: raise ValueError("chrom_indices requires length_of_chromosomes") chrom_keys = list(length_of_chromosomes.keys()) n_chr = len(chrom_keys) # ------------------------------------------------- # Validate chromosome indices # ------------------------------------------------- kept = [] rejected = [] for ci in chrom_indices: if isinstance(ci, (int, np.integer)) and 0 <= int(ci) < n_chr: kept.append(int(ci)) else: rejected.append(ci) if rejected: print("restrict_chromosomes: rejected chromosome indices:", rejected) if not kept: raise ValueError("restrict_chromosomes: no valid chromosome indices") kept = sorted(kept) # ------------------------------------------------- # Build packed chromosome offsets # ------------------------------------------------- chrom_offsets = {} packed_cursor = 0.0 for chr_idx in kept: chrom = chrom_keys[chr_idx] _, _, L = length_of_chromosomes[chrom] chrom_offsets[chr_idx] = packed_cursor packed_cursor += L # ------------------------------------------------- # Remap input_data (per individual) # ------------------------------------------------- new_input_data = [] for indiv in input_data: new_ring = [] for chr_idx in kept: offset = chrom_offsets[chr_idx] chr_segments = indiv[chr_idx] for weights, s, e in chr_segments: # s,e are chromosome-local (1-based) new_ring.append(( weights, offset + s, offset + e )) new_input_data.append(new_ring) # ------------------------------------------------- # Restrict refposes (order preserved) # ------------------------------------------------- new_refposes = [refposes[i] for i in kept] # ------------------------------------------------- # Remap chromosome lengths (geometry only) # ------------------------------------------------- length_of_chromosomes_remapped = {} for chr_idx in kept: chrom = chrom_keys[chr_idx] _, _, L = length_of_chromosomes[chrom] start = chrom_offsets[chr_idx] end = start + L length_of_chromosomes_remapped[chrom] = (start, end, L) # ------------------------------------------------- # Remap bed_info (outer ticks) — GLOBAL → PACKED # ------------------------------------------------- new_bed = None if bed_info is not None: new_bed = {} for chr_idx in kept: bed_key = chr_idx + 1 # iris-style 1-based indexing if bed_key not in bed_info: continue chrom = chrom_keys[chr_idx] chrom_start, _, _ = length_of_chromosomes[chrom] offset = chrom_offsets[chr_idx] positions = [] for label, pos in bed_info[bed_key]: # pos is GLOBAL; convert to chromosome-local then pack pos2 = offset + (pos - chrom_start) positions.append((label, pos2)) if positions: new_bed[chrom] = positions return ( new_input_data, new_refposes, length_of_chromosomes_remapped, new_bed, ) """________________________________________ END Iris and Long helper function ___________________"""
[docs] def diemIrisFromPlotPrep(prepped, chrom_indices=None): """ Uses per-chromosome diemDITgenomes_ordered as the canonical form. - If chrom_indices is None: flatten to whole-genome rings for WheelDiagram - If chrom_indices is provided: pass per-chromosome structure through unchanged (_restrict_chromosomes will select + pack) """ if chrom_indices is None: # Whole-genome plot → flatten for WheelDiagram input_data = [ flatten_ring_with_offsets( ring, prepped.length_of_chromosomes ) for ring in prepped.diemDITgenomes_ordered ] else: # Chromosome-restricted plot → MUST stay per-chromosome input_data = prepped.diemDITgenomes_ordered diemIrisPlot( title=prepped.diemPlotLabel, input_data=input_data, refposes=prepped.DIfilteredBED_formatted, names=prepped.IndIDs_ordered, bed_info=prepped.iris_plot_prep, length_of_chromosomes=prepped.length_of_chromosomes, chrom_indices=chrom_indices, )
"""________________________________________ END DiemIris ___________________""" """________________________________________ START diemLongPlot ___________________""" class BrickDiagram: """ Utility class for creating linear (brick) genome diagrams. Each ring is a horizontal band. Each brick spans a genomic interval [start, end). """ def __init__(self, subplot, x_min, x_max, y_min, y_max, number_of_rings): self.subplot = subplot self.x_min = x_min self.x_max = x_max self.y_min = y_min self.y_max = y_max self.number_of_rings = number_of_rings self.rings_added = 0 self.ring_height = (y_max - y_min) / number_of_rings def add_brick(self, x0, x1, ring_idx, color): y0 = self.y_max - (ring_idx + 1) * self.ring_height width = x1 - x0 rect = Rectangle( (x0, y0), width, self.ring_height, facecolor=color, edgecolor=None, linewidth=0, clip_on=False ) self.subplot.add_patch(rect) def add_ring(self, list_of_thingies, colors): """ list_of_thingies: [(weights, start_pos, end_pos), ...] colors: RGB base colours (same as iris) """ ring_idx = self.rings_added for thing in list_of_thingies: weights, start, end = thing weights = np.asarray(weights) total = weights.sum() if total == 0: blended_rgb = (0, 0, 0) else: blended_rgb = (colors.T * weights).sum(axis=1) / total self.add_brick(start, end, ring_idx, mcolors.to_hex(blended_rgb)) self.rings_added += 1 class BrickDiagram: """ Draws horizontal rings made of rectangles spanning [x0,x1) in data coords. """ def __init__(self, ax, n_rings, y_min=0.0, y_max=1.0): self.ax = ax self.n_rings = n_rings self.y_min = y_min self.y_max = y_max self.ring_h = (y_max - y_min) / n_rings def add_brick(self, x0, x1, ring_idx, color): y0 = self.y_max - (ring_idx + 1) * self.ring_h self.ax.add_patch( Rectangle( (x0, y0), x1 - x0, self.ring_h, facecolor=color, edgecolor="none", linewidth=0, clip_on=False ) )
[docs] def diemLongFromPlotPrep(prepped, chrom_indices=None): """ Uses per-chromosome diemDITgenomes_ordered as the canonical form. - If chrom_indices is None: flatten to whole-genome rings for WheelDiagram - If chrom_indices is provided: pass per-chromosome structure through unchanged (_restrict_chromosomes will select + pack) """ if chrom_indices is None: # Whole-genome plot → flatten for WheelDiagram input_data = [ flatten_ring_with_offsets( ring, prepped.length_of_chromosomes ) for ring in prepped.diemDITgenomes_ordered ] else: # Chromosome-restricted plot → MUST stay per-chromosome input_data = prepped.diemDITgenomes_ordered diemLongPlot( title=prepped.diemPlotLabel, input_data=input_data, refposes=prepped.DIfilteredBED_formatted, names=prepped.IndIDs_ordered, bed_info=prepped.iris_plot_prep, length_of_chromosomes=prepped.length_of_chromosomes, chrom_indices=chrom_indices, )
"""________________________________________ END diemLongPlot ___________________""" def diemIrisPlot( input_data, refposes, title=None, names=None, bed_info=None, length_of_chromosomes=None, heatmap=None, chrom_indices=None, # optional (same as long) show_outer_ticks=True, ): # ------------------------------------------------- # Figure & axes # ------------------------------------------------- fig, ax = plt.subplots(figsize=(10, 10)) fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.98) if title is not None: ax.set_title(title, pad=20) # Move axes DOWN to make room for title (no resizing distortion) pos = ax.get_position() ax.set_position([pos.x0, pos.y0, pos.width, pos.height * 0.96]) ax.set_xlim(0, 1) ax.set_ylim(0, 1) # keep wheel circular regardless of margins ax.set_box_aspect(1) ax.set_xticks([]) ax.set_yticks([]) for spine in ax.spines.values(): spine.set_visible(False) # ------------------------------------------------- # Chromosome restriction (shared helper) # ------------------------------------------------- ( input_data, refposes, length_of_chromosomes_remapped, bed_info_remapped, ) = _restrict_chromosomes( input_data=input_data, refposes=refposes, length_of_chromosomes=length_of_chromosomes, bed_info=bed_info, chrom_indices=chrom_indices, ) # ------------------------------------------------- # Wheel geometry # ------------------------------------------------- center = np.array((0.5, 0.5)) radius = 0.48 cutout_angle = 20 number_of_rings = len(input_data) wd = WheelDiagram( ax, center, radius, number_of_rings + (1 if heatmap is not None else 0), cutout_angle=cutout_angle, ) if heatmap is not None: wd.add_heatmap_ring(heatmap) for ring in input_data: wd.add_ring(ring) wd.clear_center() # ------------------------------------------------- # Geometry helpers # ------------------------------------------------- available_angle = 360 - cutout_angle start_angle_offset = 90 if length_of_chromosomes_remapped is not None: max_position = max(end for (_, end, _) in length_of_chromosomes_remapped.values()) else: max_position = input_data[0][-1][2] ring_width = (radius - wd.center_radius) / number_of_rings # ------------------------------------------------- # Inner chromosome wedges + labels # ------------------------------------------------- chrom_ranges = [] if length_of_chromosomes_remapped is not None: for chrom, (start, end, _) in length_of_chromosomes_remapped.items(): chrom_ranges.append((chrom, float(start), float(end))) if chrom_ranges: for idx, (chrom, start, end) in enumerate(chrom_ranges): start_angle = start_angle_offset + 360 - available_angle * start / max_position end_angle = start_angle_offset + 360 - available_angle * end / max_position if idx % 2 == 1: ax.add_artist( Wedge(center, radius / 2, end_angle, start_angle, color="lightgrey", alpha=0.3) ) midpoint = 0.5 * (start + end) mid_angle = start_angle_offset + 360 - available_angle * midpoint / max_position mid_rad = np.deg2rad(mid_angle) label_xy = center + (radius - 0.28) * np.array([np.cos(mid_rad), np.sin(mid_rad)]) ax.text( label_xy[0], label_xy[1], Chr_Nickname(chrom), ha="center", va="center", fontsize=8, rotation=mid_angle, rotation_mode="anchor", ) ax.add_artist(Wedge(center, 0.18, 0, 360, color="white")) # ------------------------------------------------- # Outer ticks (iris) # ------------------------------------------------- if show_outer_ticks and bed_info_remapped is not None: outer_radius = radius + (0.035 if heatmap is not None else 0.015) for positions in bed_info_remapped.values(): for label, position in positions: position = float(position) angle = start_angle_offset + 360 - available_angle * position / max_position ang_rad = np.deg2rad(angle) base = center + outer_radius * np.array([np.cos(ang_rad), np.sin(ang_rad)]) text_pos = base + 0.006 * np.array([np.cos(ang_rad), np.sin(ang_rad)]) ax.text( text_pos[0], text_pos[1], str(int(label)), ha="left", va="center", fontsize=6, rotation=angle, rotation_mode="anchor", ) line_start = base - 0.01 * np.array([np.cos(ang_rad), np.sin(ang_rad)]) line_end = line_start + 0.01 * np.array([np.cos(ang_rad), np.sin(ang_rad)]) ax.plot([line_start[0], line_end[0]], [line_start[1], line_end[1]], color="black", linewidth=0.5) # ------------------------------------------------- # Ring labels (optional) # ------------------------------------------------- if names is not None and len(names) == number_of_rings: for i, name in enumerate(names): ring_radius = radius - (i + 0.5) * ring_width label_xy = center + ring_radius * np.array([0, 1]) ax.text(label_xy[0], label_xy[1], name, ha="right", va="center", fontsize=4) # ------------------------------------------------- # Hover (make it actually visible) # ------------------------------------------------- # Key change: do NOT return all-spaces fallback (looks like “nothing” in some backends). # Use a quiet but visible fallback. fallback = "" def iris_format_coord(x, y): dx = x - center[0] dy = y - center[1] r = np.hypot(dx, dy) # outside wheel or inside hole if r < wd.center_radius or r > radius: return fallback angle = (np.degrees(np.arctan2(dy, dx)) + 360) % 360 rel_angle = (start_angle_offset + 360 - angle) % 360 if rel_angle > available_angle: return fallback raw_pos = rel_angle / available_angle * max_position # chromosome lookup chrom_label = None bp = None for chrom_idx, (chrom, start, end) in enumerate(chrom_ranges): if start <= raw_pos < end: frac = (raw_pos - start) / (end - start) ref = refposes[chrom_idx] if len(ref) == 0: return fallback ref_idx = int(np.clip(round(frac * (len(ref) - 1)), 0, len(ref) - 1)) bp = ref[ref_idx] chrom_label = Chr_Nickname(chrom) break if chrom_label is None or bp is None: return fallback # ring lookup ring_idx = int((radius - r) / ring_width) if ring_idx < 0 or ring_idx >= number_of_rings: return fallback sample = names[ring_idx] if (names is not None and len(names) == number_of_rings) else f"ring {ring_idx}" return f"{chrom_label} bp={int(bp):,} sample={sample}" ax.format_coord = iris_format_coord plt.show() def diemLongPlot( input_data, refposes, chrom_indices=None, # NOW OPTIONAL (matches iris) title=None, names=None, bed_info=None, length_of_chromosomes=None, show_outer_ticks=True, ): """ Linear analogue of diemIrisPlot, using BrickDiagram. Same semantics as iris: - chrom_indices is optional. - If chrom_indices is provided, chromosomes are remapped into a packed coordinate system. - Hover reports chrom + bp + sample. - Outer ticks are supported (bed_info) and drawn above the rings. """ if length_of_chromosomes is None: raise ValueError("diemLongPlot: length_of_chromosomes is required.") # ------------------------------------------------- # Chromosome restriction (shared helper) # ------------------------------------------------- ( input_data, refposes, length_of_chromosomes_remapped, bed_info_remapped, ) = _restrict_chromosomes( input_data=input_data, refposes=refposes, length_of_chromosomes=length_of_chromosomes, bed_info=bed_info, chrom_indices=chrom_indices, ) # ------------------------------------------------- # Normalize bed_info keys for unrestricted Long plot # ------------------------------------------------- if chrom_indices is None and bed_info_remapped is not None: # Convert 1-based BED keys → chromosome-name keys chrom_keys = list(length_of_chromosomes.keys()) bed_info_remapped = { chrom_keys[k - 1]: v for k, v in bed_info_remapped.items() if 1 <= k <= len(chrom_keys) } # ------------------------------------------------- # Packed chromosome ranges # ------------------------------------------------- chrom_ranges = [] for chrom, v in (length_of_chromosomes_remapped or length_of_chromosomes).items(): start, end = float(v[0]), float(v[1]) chrom_ranges.append((chrom, start, end)) if not chrom_ranges: raise ValueError("diemLongPlot: no chromosomes available after restriction.") packed_len = max(end for (_, _, end) in chrom_ranges) # ------------------------------------------------- # Figure & axes # ------------------------------------------------- n_rings = len(input_data) fig, ax = plt.subplots(figsize=(11, 4)) fig.subplots_adjust(left=0.06, right=0.98, bottom=0.18, top=0.86) if title is not None: ax.set_title(title, pad=16) # Move axes DOWN to make room for outer ticks (no geometry distortion) pos = ax.get_position() ax.set_position([pos.x0, pos.y0, pos.width, pos.height * 0.94]) ax.set_xlim(0, packed_len) # Allocate room above rings for outer ticks ax.set_ylim(-0.9, n_rings + 0.7) ax.set_aspect("auto") ax.set_xticks([]) ax.set_yticks([]) for spine in ax.spines.values(): spine.set_visible(False) # ------------------------------------------------- # Draw rings as bricks (already packed if chrom_indices was used) # ------------------------------------------------- colors = np.array(list(map(mcolors.to_rgb, diemColours))) bd = BrickDiagram(ax, n_rings, y_min=0, y_max=n_rings) for ring_idx, ring in enumerate(input_data): for weights, x0, x1 in ring: # conditionally convert 1-based to 0-based if True or chrom_indices is None: # global genome: 1-based → 0-based x0 = float(x0 - 1) x1 = float(x1) else: # already packed by _restrict_chromosomes x0 = float(x0) x1 = float(x1) x1 = float(x1) if x1 <= x0: continue w = np.asarray(weights) tot = float(np.sum(w)) if tot == 0: blended_rgb = (0, 0, 0) else: blended_rgb = (colors.T * w).sum(axis=1) / tot bd.add_brick(x0, x1, ring_idx, mcolors.to_hex(blended_rgb)) # ------------------------------------------------- # Chromosome bricks + labels at base # ------------------------------------------------- base_y = -0.62 for i, (chrom, p0, p1) in enumerate(chrom_ranges): if i % 2 == 1: ax.axvspan(p0, p1, color="grey", alpha=0.35, ymin=-0.18, ymax=0.01, clip_on=False) ax.text( 0.5 * (p0 + p1), base_y - 0.12, Chr_Nickname(chrom), ha="center", va="top", fontsize=8, rotation=90 ) # ------------------------------------------------- # Outer ticks (linear analogue of iris outer ticks) # ------------------------------------------------- if show_outer_ticks and bed_info_remapped is not None: tick_y0 = n_rings + 0 + 0.05 tick_y1 = n_rings + 1 + 0.18 text_y = n_rings + 2 + 0.22 # NOTE: positions are already packed coords for chrom, p0, p1 in chrom_ranges: if chrom not in bed_info_remapped: continue for label, pos in bed_info_remapped[chrom]: #for positions in bed_info_remapped.values(): # for label, pos in positions: x = float(pos) ax.plot([x, x], [tick_y0, tick_y1], color="black", linewidth=0.5, clip_on=False) ax.text( x, text_y, str(int(label)), ha="center", va="bottom", fontsize=6, rotation=90, clip_on=False ) # ------------------------------------------------- # Ring labels (optional) # ------------------------------------------------- if names is not None and len(names) == n_rings: for i, name in enumerate(names): y = n_rings - i - 0.5 ax.text( -0.01 * packed_len, y, name, ha="right", va="center", fontsize=6, clip_on=False ) # ------------------------------------------------- # Hover (same semantics as iris) # ------------------------------------------------- fallback = "" def long_format_coord(x, y): # ring index from y ring_idx = int(np.floor(n_rings - y)) if ring_idx < 0 or ring_idx >= n_rings: return fallback # find chromosome interval in packed coords chrom_hit = None for chrom_i, (chrom, p0, p1) in enumerate(chrom_ranges): if p0 <= x < p1: chrom_hit = (chrom_i, chrom, p0, p1) break if chrom_hit is None: return fallback chrom_i, chrom, p0, p1 = chrom_hit if p1 <= p0: return fallback frac = (x - p0) / (p1 - p0) ref = refposes[chrom_i] if len(ref) == 0: return fallback ref_i = int(np.clip(round(frac * (len(ref) - 1)), 0, len(ref) - 1)) bp = ref[ref_i] sample = names[ring_idx] if (names is not None and len(names) == n_rings) else f"ring {ring_idx}" return f"{Chr_Nickname(chrom)} bp={int(bp):,} sample={sample}" ax.format_coord = long_format_coord plt.show()