"""Mass-correction rates for fourmer log-concentration arrays (module src.infer.fourmer_mass_correction)."""

import jax.numpy as jnp
from warnings import warn

from jax import config
# Enable 64-bit (double-precision) dtypes in JAX. This runs at import time and
# changes a process-wide JAX setting, so it affects all JAX code in the
# program, not just the functions in this module.
config.update("jax_enable_x64", True)

def _nucleotide_mass(fourmer_log_concentration_vector):
    return jnp.sum(jnp.exp(fourmer_log_concentration_vector[:,1:]),axis=(0,2,3))+jnp.sum(jnp.exp(fourmer_log_concentration_vector[:,1:,1:,0]),axis=(0,1))

def _nucl_mass_counter(number_of_letters : int, motiflength : int):
    nucl_mass_counter = jnp.zeros((number_of_letters,)+(number_of_letters+1,)*motiflength)
    for ii in range(number_of_letters):
        nucl_mass_counter = nucl_mass_counter.at[ii,:,ii+1,:,:].add(1)
        nucl_mass_counter = nucl_mass_counter.at[ii,:,1:,ii+1,0].add(1)
    return nucl_mass_counter

def _nucleotide_mass_correction_rate(
        initial_log_concentration_array,
        current_log_concentration_array
        ):
    """Return a per-fourmer correction that pushes per-letter mass toward its initial value.

    For each letter ``i``, the mass difference
    ``_nucleotide_mass(initial)[i] - _nucleotide_mass(current)[i]`` is
    multiplied by how often letter ``i`` contributes to each fourmer
    (``_nucl_mass_counter``) and by that fourmer's current concentration.
    The contributions of all letters are then summed.

    Parameters
    ----------
    initial_log_concentration_array : array
        Reference log-concentrations. Presumably shape ``(L + 1,) * 4``,
        where ``L`` is the number of letters -- confirm against callers.
    current_log_concentration_array : array
        Current log-concentrations, with the same shape as the reference.

    Returns
    -------
    array
        One correction value per fourmer entry. It is positive where letter
        mass has been lost relative to the initial state.
    """
    number_of_letters = initial_log_concentration_array.shape[0] - 1
    motiflength = initial_log_concentration_array.ndim

    # Per-letter mass deficit, reshaped to (L, 1, ..., 1) so that it
    # broadcasts against the (L, L+1, ..., L+1) multiplicity counter.
    mass_deficit = (
        _nucleotide_mass(initial_log_concentration_array)
        - _nucleotide_mass(current_log_concentration_array)
    ).reshape((-1,) + (1,) * motiflength)

    nucl_mass_counter = _nucl_mass_counter(number_of_letters, motiflength)
    # Leading singleton axis lines the concentrations up with the letter axis.
    current_concentration = jnp.exp(current_log_concentration_array)[None]
    per_letter_rate = mass_deficit * nucl_mass_counter * current_concentration

    # Collapse the letter axis: each fourmer receives the sum of its letters' corrections.
    return jnp.sum(per_letter_rate, axis=0)

def _nonending_strand_concentration(fourmer_log_concentration_vector):
    beginning_concentration = jnp.sum(jnp.exp(fourmer_log_concentration_vector[0,1:,1:,1:]))
    ending_concentration = jnp.sum(jnp.exp(fourmer_log_concentration_vector[1:,1:,1:,0]))
    return beginning_concentration-ending_concentration

def _nonending_strand_correction_rate(
        fourmer_log_concentration_vector
        ):
    """Return a per-fourmer correction proportional to the beginning/ending imbalance.

    Steps:
      1. Build a sign mask that is -1 on entries ``[0, 1:, 1:, 1:]``, +1 on
         entries ``[1:, 1:, 1:, 0]``, and 0 elsewhere.
      2. Multiply the mask elementwise by the concentrations
         ``exp(fourmer_log_concentration_vector)``.
      3. Scale the result by ``_nonending_strand_concentration``, which is the
         beginning total minus the ending total.

    When beginnings outweigh endings, the correction is therefore negative on
    beginning entries and positive on ending entries. The mask indexing
    assumes a rank-4 input with shape ``(L + 1,) * 4``.
    """
    axis_size = fourmer_log_concentration_vector.shape[0]
    rank = fourmer_log_concentration_vector.ndim

    # The two regions are disjoint (first index 0 vs. 1:), so setting each one
    # on a zero array is the same as subtracting/adding 1 there.
    sign_mask = jnp.zeros((axis_size,) * rank)
    sign_mask = sign_mask.at[0, 1:, 1:, 1:].set(-1.)
    sign_mask = sign_mask.at[1:, 1:, 1:, 0].set(1.)

    weighted_mask = sign_mask * jnp.exp(fourmer_log_concentration_vector)
    imbalance = _nonending_strand_concentration(fourmer_log_concentration_vector)
    return imbalance * weighted_mask

def mass_correction_rates(
        initial_log_concentration_array : jnp.ndarray,
        current_log_concentration_array : jnp.ndarray,
        weight : float = 1.,
        ) -> jnp.ndarray:
    """Return the combined mass-correction rates for a fourmer log-concentration array.

    The result is the sum of two terms, each multiplied by ``weight``:

    * ``_nonending_strand_correction_rate(current)``: proportional to the
      imbalance between "beginning" and "ending" fourmers in the current state.
    * ``_nucleotide_mass_correction_rate(initial, current)``: proportional to
      each letter's mass deficit relative to the initial state.

    Parameters
    ----------
    initial_log_concentration_array : jnp.ndarray
        Reference log-concentrations. Presumably shape ``(L + 1,) * 4`` --
        confirm against callers.
    current_log_concentration_array : jnp.ndarray
        Current log-concentrations, with the same shape as the reference.
    weight : float, default 1.
        Scale factor applied to both correction terms.

    Returns
    -------
    jnp.ndarray
        Per-fourmer correction rates.
    """
    # Each term is weighted separately, exactly as before, to keep the
    # floating-point results bit-for-bit unchanged.
    return (
        weight * _nonending_strand_correction_rate(current_log_concentration_array)
        + weight * _nucleotide_mass_correction_rate(
            initial_log_concentration_array,
            current_log_concentration_array,
        )
    )