Source code for src.infer.total_mass

from ..obj.motif_vector import MotifVector
from ..obj.motif_trajectory import MotifTrajectory

from ..domains.motif_space import _motif_categories

import jax.numpy as jnp

def total_mass(motif_vector: MotifVector) -> jnp.ndarray:
    """Compute the total mass stored in a single motif vector.

    Every concentration category is summed over all of its entries and
    weighted as follows:

    * strand categories ``motif_categories[0].format(k)`` for
      ``k = 1 .. motiflength-2`` are weighted by ``k``;
    * ``motif_categories[-3]`` (beginnings) is weighted by 1;
    * ``motif_categories[-1]`` (endings) is weighted by ``motiflength - 2``;
    * ``motif_categories[-2]`` (continuations) is weighted by 1.

    Parameters:
    -----------
    motif_vector : MotifVector
        Object exposing ``motiflength`` and ``motifs.val``, the latter being
        indexable by the category keys produced by ``_motif_categories()``.

    Returns:
    --------
    total_mass : jnp.ndarray
        The weighted sum over all categories (a 0-d array).
    """
    categories = _motif_categories()
    concentrations = motif_vector.motifs.val
    motiflength = motif_vector.motiflength

    # Strands shorter than motiflength-1: each entry contributes its length.
    mass = sum(
        length * jnp.sum(concentrations[categories[0].format(length)])
        for length in range(1, motiflength - 1)
    )

    beginnings = jnp.sum(concentrations[categories[-3]].flatten())
    endings = jnp.sum(concentrations[categories[-1]].flatten())
    continuations = jnp.sum(concentrations[categories[-2]].flatten())

    # Accumulate in a fixed order (beginnings, endings, continuations) so the
    # floating-point result is reproducible.
    mass += beginnings
    mass += endings * (motiflength - 2)
    mass += continuations
    return mass
def total_mass_of_motif_concentration_trajectory_array(
    motif_concentration_trajectory_array: jnp.ndarray,
    number_of_letters: int = 4,
    motiflength: int = 4,
) -> jnp.ndarray:
    """Compute the total mass for every row of a flat motif-concentration array.

    The LAST axis of ``motif_concentration_trajectory_array`` must enumerate
    all motif concentrations in this fixed, contiguous order:

    1. strands of length ``k`` for ``k = 1 .. motiflength-2``
       (``number_of_letters**k`` entries each, weight ``k``);
    2. ``number_of_letters**(motiflength-1)`` entries with weight 1;
    3. ``number_of_letters**motiflength`` entries with weight 1;
    4. ``number_of_letters**(motiflength-1)`` entries with weight
       ``motiflength - 2``.

    NOTE(review): segments 2-4 carry the same weights that
    ``total_mass_trajectory`` gives to beginnings, continuations and endings,
    so they presumably are those categories — confirm against the code that
    flattens motif vectors into this layout.

    Parameters:
    -----------
    motif_concentration_trajectory_array : jnp.ndarray
        Concentrations with the motif axis last, e.g. ``(n_times, n_motifs)``.
        Any number of leading axes (including none) is accepted.
    number_of_letters : int
        Alphabet size.
    motiflength : int
        Motif length.

    Returns:
    --------
    total_mass : jnp.ndarray
        The weighted sum over the last axis; leading axes are preserved,
        so a ``(n_times, n_motifs)`` input yields shape ``(n_times,)``.

    Raises:
    -------
    ValueError
        If the input is 0-d or its last axis does not have the length implied
        by ``number_of_letters`` and ``motiflength``.
    """
    concentrations = jnp.asarray(motif_concentration_trajectory_array)

    # (weight, number of entries) for each contiguous segment of the motif axis.
    segments = [
        (strandlength, number_of_letters**strandlength)
        for strandlength in range(1, motiflength - 1)
    ]
    segments += [
        (1, number_of_letters ** (motiflength - 1)),
        (1, number_of_letters**motiflength),
        (motiflength - 2, number_of_letters ** (motiflength - 1)),
    ]
    strandmasses = jnp.concatenate(
        [jnp.full((count,), weight) for weight, count in segments]
    )

    expected_length = strandmasses.shape[0]
    if concentrations.ndim == 0 or concentrations.shape[-1] != expected_length:
        raise ValueError(
            f"expected the last axis to have length {expected_length} "
            f"(number_of_letters={number_of_letters}, "
            f"motiflength={motiflength}), got array of shape "
            f"{concentrations.shape}"
        )

    # Contract only the motif (last) axis against the weights; every leading
    # axis (time, batch, ...) is kept as-is.
    return concentrations @ strandmasses
def total_mass_trajectory(motif_trajectory: MotifTrajectory) -> jnp.ndarray:
    """Compute the total mass at every time point of a motif trajectory.

    For each time point, every concentration category is summed over all of
    its motifs and weighted as follows:

    * strand categories ``motif_categories[0].format(k)`` for
      ``k = 1 .. motiflength-2`` are weighted by ``k``;
    * beginnings (``motif_categories[-3]``) are weighted by 1;
    * continuations (``motif_categories[-2]``) are weighted by 1;
    * endings (``motif_categories[-1]``) are weighted by ``motiflength - 2``.

    Parameters:
    -----------
    motif_trajectory : MotifTrajectory
        Object exposing ``times``, ``motiflength`` and ``motifs.val``, the
        latter being indexable by the category keys produced by
        ``_motif_categories()``.

    Returns:
    --------
    total_mass_trajectory : jnp.ndarray
        Array of shape ``(motif_trajectory.times.size,)``.
    """
    motif_categories = _motif_categories()
    motifs = motif_trajectory.motifs.val
    motiflength = motif_trajectory.motiflength
    number_of_times = motif_trajectory.times.size

    def _per_time_sum(category):
        # Collapse everything after the leading axis and sum it, leaving one
        # value per time point. NOTE(review): assumes the leading axis of each
        # category array is time — confirm against MotifTrajectory.
        return jnp.sum(motifs[category].reshape(number_of_times, -1), axis=-1)

    # Float accumulator fixes the output dtype and shape (number_of_times,).
    total_mass_trajectory = jnp.zeros(number_of_times)
    for strandlength in range(1, motiflength - 1):
        total_mass_trajectory += strandlength * _per_time_sum(
            motif_categories[0].format(strandlength)
        )
    # beginnings
    total_mass_trajectory += _per_time_sum(motif_categories[-3])
    # continuations
    total_mass_trajectory += _per_time_sum(motif_categories[-2])
    # endings
    total_mass_trajectory += (motiflength - 2) * _per_time_sum(motif_categories[-1])
    return total_mass_trajectory