Source code for src.infer.fourmer_production_rates

import jax.numpy as jnp
from warnings import warn

from jax import config
config.update("jax_enable_x64", True)

from ._rates_utils import _set_invalid_logc_to_log0
from ._rates_utils import _set_invalid_logc_diff_to_zero
from ._rates_utils import _cut_low_concentrations
from ._smooth_rate_clipping import _clip_smoothly

def _initialize_empty_fourmer_production_rates(
        number_of_letters : int = 4,
        motiflength : int = 4
        ) -> jnp.ndarray :
    if motiflength != 4:
        raise NotImplementedError("motiflength!=4")
    shape = (number_of_letters+1,)*(motiflength-1)
    shape = shape*2+(number_of_letters+1,)*motiflength
    return jnp.zeros(shape)

def _set_invalid_production_rates_to_zero(
        reaction_rates : jnp.ndarray,
        ) -> jnp.ndarray:
    rr = reaction_rates.at[:,0].set(0)
    rr = rr.at[:,:,0].set(0)
    rr = rr.at[:,:,:,:,:,0].set(0)
    return rr.at[:,:,:,:,:,:,0].set(0)

def _set_invalid_log_production_rates_to_logzero(
        log_reaction_rates : jnp.ndarray,
        pseudo_count_concentration : float = 1.e-12
        ) -> jnp.ndarray:
    log0 = jnp.log(pseudo_count_concentration)
    rr = log_reaction_rates.at[:,:,0].set(log0)
    rr = rr.at[:,0].set(log0)
    rr = rr.at[:,:,0].set(log0)
    rr = rr.at[:,:,:,:,:,0].set(log0)
    return rr.at[:,:,:,:,:,:,0].set(log0)

def _set_invalid_rates_to_zero(
        reaction_rates : jnp.ndarray,
        ) -> jnp.ndarray:
    """
    set invalid rates to zeros.
    Format of the rates: [l0,l1,l2,r1,r2,r3,t0,t1,t2,t3]
    """
    rr = reaction_rates.at[:,:,0].set(0)
    rr = rr.at[1:,0].set(0)
    rr = rr.at[:,:,:,0].set(0)
    rr = rr.at[:,:,:,:,0,1:].set(0)
    rr = rr.at[:,:,:,:,:,:,:,0,:,:].set(0)
    return rr.at[:,:,:,:,:,:,:,:,0,:].set(0)

def _set_invalid_log_rates_to_logzero(
        log_reaction_rates : jnp.ndarray,
        pseudo_count_concentration : float = 1.e-12
        ) -> jnp.ndarray:
    log0 = jnp.log(pseudo_count_concentration)
    lrr = log_reaction_rates.at[:,:,0].set(log0)
    lrr = lrr.at[1:,0].set(log0)
    lrr = lrr.at[:,:,:,0].set(log0)
    lrr = lrr.at[:,:,:,:,0,1:].set(log0)
    lrr = lrr.at[:,:,:,:,:,:,:,0,:,:].set(log0)
    return lrr.at[:,:,:,:,:,:,:,:,0,:].set(log0)

def _add_log_rate_constants(
        log_rate_constants : jnp.ndarray,
        log_reaction_rates : jnp.ndarray,
        pseudo_count_concentration : float = 1.e-12

        ) -> jnp.ndarray:
    return log_reaction_rates.at[:,:,1:,1:,:,:,:,1:,1:,:].add(log_rate_constants.reshape((1,)+log_rate_constants.shape[:len(log_rate_constants.shape)//2]+(1,)+log_rate_constants.shape[len(log_rate_constants.shape)//2:]))

def _multiply_rr_with_rc(rr,rc):
    return rr.at[:,:,1:,1:,:,:,:,1:,1:,:].multiply(rc[None,:,:,:,:,None])

def _add_template_logc(
        fourmer_logc : jnp.ndarray,
        log_reaction_rates : jnp.ndarray,
        number_of_letters : int = 4,
        pseudo_count_concentration : float = 1.e-12
        ) -> jnp.ndarray:
    """
    Add the template log-concentrations onto the log reaction rates.

    ``fourmer_logc`` is first passed through ``_set_invalid_logc_to_log0``
    with ``pseudo_count_concentration/10``. Fourmers whose positions 1 and
    2 are non-empty are then added to the last four axes (t0..t3) of the
    rates, broadcast over all six reactant axes.
    """
    motiflength = 4
    reactant_axes = 2*(motiflength - 1)
    cleaned = _set_invalid_logc_to_log0(fourmer_logc, pseudo_count_concentration/10.)
    template = cleaned[:, 1:, 1:, :]
    template_shape = (
            number_of_letters+1, number_of_letters, number_of_letters, number_of_letters+1
            )
    template = template.reshape((1,)*reactant_axes + template_shape)
    everything, nonempty = slice(None), slice(1, None)
    window = (everything,)*(reactant_axes + 1) + (nonempty, nonempty, everything)
    return log_reaction_rates.at[window].add(template)

def _add_extended_end_logc(
        fourmer_logc : jnp.ndarray,
        log_reaction_rates : jnp.ndarray,
        number_of_letters : int = 4,
        pseudo_count_concentration : float = 1.e-12
        ) -> jnp.ndarray:
    """
    Add the log-concentrations of the left (extended-end) reactant.

    Two disjoint cases of the [l0,l1,l2,...] layout are covered:
      - monomers: rates with l0 = l1 = 0 receive logc[0, l2, 0, 0];
      - ending motifs: rates with l1, l2 non-empty receive logc[l0, l1, l2, 0].
    ``fourmer_logc`` is first passed through ``_set_invalid_logc_to_log0``
    with ``pseudo_count_concentration/10``.
    """
    motiflength = 4
    trailing_ones = (1,)*(motiflength-1) + (1,)*motiflength
    logc = _set_invalid_logc_to_log0(fourmer_logc, pseudo_count_concentration/10.)
    monomer_logc = logc[0, 1:, 0, 0].reshape((number_of_letters,) + trailing_ones)
    ending_logc = logc[:, 1:, 1:, 0].reshape(
            (number_of_letters+1, number_of_letters, number_of_letters) + trailing_ones
            )
    with_monomers = log_reaction_rates.at[0, 0, 1:].add(monomer_logc)
    return with_monomers.at[:, 1:, 1:].add(ending_logc)

def _add_extending_beginning_logc(
        fourmer_logc : jnp.ndarray,
        log_reaction_rates : jnp.ndarray,
        motiflength : int = 4,
        pseudo_count_concentration : float = 1.e-12
        ) -> jnp.ndarray:
    """
    Add the log-concentrations of the right (extending-beginning) reactant.

    Two disjoint cases of the [..., r1, r2, r3, ...] axes are covered:
      - monomers (0x00) and dimers (0xy0): rates with r1 non-empty and
        r3 empty receive logc[0, r1, r2, 0];
      - beginnings (0xyz): rates with r1, r2, r3 non-empty receive
        logc[0, r1, r2, r3].
    ``fourmer_logc`` is first passed through ``_set_invalid_logc_to_log0``
    with ``pseudo_count_concentration/10``.

    Raises
    ------
    NotImplementedError
        if ``motiflength`` is not 4.
    """
    if motiflength != 4:
        raise NotImplementedError("motiflength!=4")
    logc = _set_invalid_logc_to_log0(fourmer_logc, pseudo_count_concentration/10.)
    left_pad = (1,)*(motiflength-1)
    right_pad = (1,)*motiflength
    short_logc = logc[0, 1:, :, 0]
    beginning_logc = logc[0, 1:, 1:, 1:]
    updated = log_reaction_rates.at[:, :, :, 1:, :, 0].add(
            short_logc.reshape(left_pad + short_logc.shape + right_pad)
            )
    return updated.at[:, :, :, 1:, 1:, 1:].add(
            beginning_logc.reshape(left_pad + beginning_logc.shape + right_pad)
            )

def _subtract_produced_fourmers(
        fourmer_logc : jnp.ndarray,
        log_reaction_rates : jnp.ndarray,
        product_index : int = 0,
        motiflength : int = 4,
        pseudo_count_concentration : float  = 1.e-12
        ) -> jnp.ndarray:
    """
    Subtract the log-concentration of a produced fourmer from the log rates.

    The six reactant axes [l0,l1,l2,r1,r2,r3] are treated as the ligation
    product. ``product_index`` selects the window of four consecutive
    reactant axes (starting at axis ``product_index``) whose fourmer
    log-concentration is subtracted, broadcast over all other axes. The
    subtraction is applied only where the log rate exceeds
    ``log(pseudo_count_concentration)``.

    Parameters
    ----------
    fourmer_logc : jnp.ndarray
        log-concentrations of the fourmers (4 axes).
    log_reaction_rates : jnp.ndarray
        log reaction rates in the [l0,l1,l2,r1,r2,r3,t0,t1,t2,t3] layout.
    product_index : int
        offset of the produced fourmer within the reactant axes; callers in
        this module use 0, 1 and 2.
    motiflength : int
        only 4 is supported.
    pseudo_count_concentration : float
        log of this value is the "nonzero" threshold; log of a tenth of it
        is written into invalid reactions at the end.

    Returns
    -------
    jnp.ndarray
        the adjusted log rates, with invalid reactions reset by
        ``_set_invalid_log_rates_to_logzero``.

    Raises
    ------
    NotImplementedError
        if ``motiflength`` is not 4.
    """
    if motiflength != 4:
        raise NotImplementedError("motiflength != 4")
    # only subtract product, where term is nonzero (so greater log0)
    log0 = jnp.log(pseudo_count_concentration)
    # Push reactions whose window has an empty letter at window position 1
    # (l1, product_index 0) or position 2 (r2, product_index 2) just below
    # log0, so that `nonzero` is False there and nothing is subtracted.
    # NOTE(review): this also overwrites those log rates with log0-1 instead
    # of leaving them unchanged -- confirm that is intended.
    if product_index==0:
        lrr = log_reaction_rates.at[:,0,:].set(log0-1)
    elif product_index==2:
        lrr = log_reaction_rates.at[:,:,:,:,0].set(log0-1)
    else:
        lrr = log_reaction_rates
    fourmer_logc = _set_invalid_logc_to_log0(fourmer_logc, pseudo_count_concentration/10.)
    nonzero = lrr > log0
    # NOTE(review): presumably zeroes the entries of invalid fourmers so that
    # nothing is subtracted for them -- see _rates_utils, TODO confirm.
    produced_fourmer_logc = _set_invalid_logc_diff_to_zero(fourmer_logc, print_warning = False)
    # Place the four fourmer axes at reactant axes
    # product_index .. product_index+3; every other axis is a singleton.
    produced_fourmer_logc = produced_fourmer_logc.reshape(
                (1,)*product_index + produced_fourmer_logc.shape + (1,)*(2*(motiflength-1)-product_index)
                )
    # boolean mask broadcasts the 10-axis result to the full rate shape
    produced_fourmer_logc = nonzero*produced_fourmer_logc
    lrr = lrr.at[:].subtract(
            produced_fourmer_logc
            )
    return _set_invalid_log_rates_to_logzero(lrr, pseudo_count_concentration/10.)

def _exponentiate_log_reaction_rates(
        log_reaction_rates : jnp.ndarray
        ) -> jnp.ndarray:
    return jnp.exp(log_reaction_rates)

def _compute_extended_end_motif_reaction_logc_rates(
        fourmer_logc : jnp.ndarray,
        log_rate_constants : jnp.ndarray = 0.,
        rate_constants : jnp.ndarray = 1.,
        motiflength : int = 4,
        number_of_letters : int = 4,
        pseudo_count_concentration : float = 1.e-12,
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None,
        ) -> jnp.ndarray:
    """
    Negated reaction rates as seen from the left (extended-end) reactant.

    For every ligation in the [l0,l1,l2,r1,r2,r3,t0,t1,t2,t3] layout the log
    rate is assembled from the log rate constants, the template
    log-concentration and the right (extending-beginning) reactant
    log-concentration. The left reactant's own log-concentration is not
    added. The rates are exponentiated, weighted by ``_clip``, multiplied by
    the linear rate constants, and invalid reactions are zeroed.

    Parameters
    ----------
    fourmer_logc : jnp.ndarray
        log-concentrations of the fourmers.
    log_rate_constants, rate_constants : float or jnp.ndarray
        log and linear rate constants in any form accepted by
        ``_split_and_shape_fprc``.
    motiflength : int
        only 4 is supported.
    number_of_letters : int
        alphabet size; every axis has ``number_of_letters+1`` entries.
    pseudo_count_concentration : float
        log(pseudo_count_concentration/10) is used for invalid entries;
        forwarded to every helper so they all agree on that value.
    soft_reactant_threshold, hard_reactant_threshold : float, optional
        clipping thresholds; default to ``pseudo_count_concentration`` and
        half the soft threshold.

    Returns
    -------
    jnp.ndarray
        the negated (consumption) rates, one entry per reaction.
    """
    if soft_reactant_threshold is None:
        soft_reactant_threshold = pseudo_count_concentration
    if hard_reactant_threshold is None:
        hard_reactant_threshold = soft_reactant_threshold/2.
    fourmer_logc = _set_invalid_logc_to_log0(fourmer_logc, pseudo_count_concentration/10.)
    log_rate_constants, rate_constants = _split_and_shape_fprc(
            (log_rate_constants, rate_constants),
            number_of_letters=number_of_letters,
            motiflength=motiflength,
            pseudo_count_concentration=pseudo_count_concentration
            )
    lrr = _initialize_empty_fourmer_production_rates(
            number_of_letters=number_of_letters,
            motiflength=motiflength
            )
    lrr = _add_log_rate_constants(log_rate_constants, lrr)
    lrr = _set_invalid_log_rates_to_logzero(lrr, pseudo_count_concentration/10.)
    lrr = _add_template_logc(fourmer_logc, lrr, number_of_letters, pseudo_count_concentration)
    # forward pseudo_count_concentration: the helper re-applies
    # _set_invalid_logc_to_log0 and would otherwise use its 1e-12 default
    lrr = _add_extending_beginning_logc(
            fourmer_logc,
            lrr,
            motiflength=motiflength,
            pseudo_count_concentration=pseudo_count_concentration
            )
    rr = _exponentiate_log_reaction_rates(lrr)
    # clip smoothly
    rr = _clip(fourmer_logc, rr, soft_reactant_threshold, hard_reactant_threshold)
    rr = _multiply_rr_with_rc(rr, rate_constants)
    return -_set_invalid_rates_to_zero(rr)

def _clip(fourmer_logc, rr, soft_reactant_threshold, hard_reactant_threshold):
    """
    Weight reaction rates by a smooth clipping of the template concentrations.

    ``exp(fourmer_logc)`` is passed to ``_clip_smoothly`` together with the
    two (linear-concentration) thresholds, and the returned weight is
    multiplied onto ``rr``. For the 10-axis rates used in this module a
    4-axis weight broadcasts over the trailing template axes t0..t3
    (assumes ``_clip_smoothly`` returns an array shaped like its input --
    TODO confirm).

    Format adaptation, comparing axis 1 of ``fourmer_logc`` with axis 2 of
    ``rr``:
      - if ``fourmer_logc`` is longer, its first entry on axis 1 is dropped;
      - if it is shorter, it is padded at index 0 of axis 1 with
        ``log(hard_reactant_threshold)``.
    """
    # check for extensive or excessive format
    if rr.shape[2] < fourmer_logc.shape[1]:
        fourmer_logc = fourmer_logc[:, 1:]
    elif rr.shape[2] > fourmer_logc.shape[1]:
        # Pad in log space: the thresholds are linear concentrations, so the
        # padded entries must hold log(threshold) for exp() below to map
        # them back onto hard_reactant_threshold (not onto exp(threshold) ~ 1).
        padding = jnp.log(hard_reactant_threshold)*jnp.ones(
                (fourmer_logc.shape[0],)*len(fourmer_logc.shape)
                )
        fourmer_logc = padding.at[:, 1:].set(fourmer_logc)
    weight = _clip_smoothly(
            jnp.exp(fourmer_logc),
            soft_reactant_threshold=soft_reactant_threshold,
            hard_reactant_threshold=hard_reactant_threshold)
    return weight*rr

def _compute_extending_beginning_motif_reaction_logc_rates(
        fourmer_logc : jnp.ndarray,
        log_rate_constants : jnp.ndarray = 0.,
        rate_constants : jnp.ndarray = 1.,
        motiflength : int = 4,
        number_of_letters : int = 4,
        pseudo_count_concentration : float = 1.e-12,
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None,
        ) -> jnp.ndarray :
    """
    Negated reaction rates as seen from the right (extending-beginning) reactant.

    For every ligation in the [l0,l1,l2,r1,r2,r3,t0,t1,t2,t3] layout the log
    rate is assembled from the log rate constants, the template
    log-concentration and the left (extended-end) reactant
    log-concentration. The right reactant's own log-concentration is not
    added. The rates are exponentiated, weighted by ``_clip``, multiplied by
    the linear rate constants, and invalid reactions are zeroed.

    Parameters
    ----------
    fourmer_logc : jnp.ndarray
        log-concentrations of the fourmers.
    log_rate_constants, rate_constants : float or jnp.ndarray
        log and linear rate constants in any form accepted by
        ``_split_and_shape_fprc``.
    motiflength : int
        only 4 is supported.
    number_of_letters : int
        alphabet size; every axis has ``number_of_letters+1`` entries.
    pseudo_count_concentration : float
        log(pseudo_count_concentration/10) is used for invalid entries;
        forwarded to every helper so they all agree on that value.
    soft_reactant_threshold, hard_reactant_threshold : float, optional
        clipping thresholds; default to ``pseudo_count_concentration`` and
        half the soft threshold.

    Returns
    -------
    jnp.ndarray
        the negated (consumption) rates, one entry per reaction.
    """
    if soft_reactant_threshold is None:
        soft_reactant_threshold = pseudo_count_concentration
    if hard_reactant_threshold is None:
        hard_reactant_threshold = soft_reactant_threshold/2.
    fourmer_logc = _set_invalid_logc_to_log0(fourmer_logc, pseudo_count_concentration/10.)
    log_rate_constants, rate_constants = _split_and_shape_fprc(
            (log_rate_constants, rate_constants),
            number_of_letters=number_of_letters,
            motiflength=motiflength,
            pseudo_count_concentration=pseudo_count_concentration
            )
    lrr = _initialize_empty_fourmer_production_rates(
            number_of_letters=number_of_letters,
            motiflength=motiflength
            )
    lrr = _add_log_rate_constants(log_rate_constants, lrr)
    lrr = _set_invalid_log_rates_to_logzero(lrr, pseudo_count_concentration/10.)
    lrr = _add_template_logc(fourmer_logc, lrr, number_of_letters, pseudo_count_concentration)
    # forward pseudo_count_concentration: the helper re-applies
    # _set_invalid_logc_to_log0 and would otherwise use its 1e-12 default
    lrr = _add_extended_end_logc(fourmer_logc, lrr, number_of_letters, pseudo_count_concentration)
    rr = _exponentiate_log_reaction_rates(lrr)
    # clip smoothly
    rr = _clip(fourmer_logc, rr, soft_reactant_threshold, hard_reactant_threshold)
    rr = _multiply_rr_with_rc(rr, rate_constants)
    return -_set_invalid_rates_to_zero(rr)

def _compute_produced_motif_reaction_logc_rates(
        fourmer_logc : jnp.ndarray,
        log_rate_constants : jnp.ndarray = 0.,
        rate_constants : jnp.ndarray = 1.,
        product_index : int = 0,
        number_of_letters : int = 4,
        motiflength : int = 4,
        pseudo_count_concentration : float = 1.e-12,
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None,
        ) -> jnp.ndarray:
    """
    Reaction rates as seen from the fourmer produced at ``product_index``.

    For every ligation in the [l0,l1,l2,r1,r2,r3,t0,t1,t2,t3] layout the log
    rate is assembled from the log rate constants and the template, left
    and right reactant log-concentrations. The log-concentration of the
    fourmer occupying reactant axes ``product_index .. product_index+3`` is
    then subtracted via ``_subtract_produced_fourmers``. The rates are
    exponentiated, weighted by ``_clip``, multiplied by the linear rate
    constants, and invalid reactions are zeroed.

    Parameters
    ----------
    fourmer_logc : jnp.ndarray
        log-concentrations of the fourmers.
    log_rate_constants, rate_constants : float or jnp.ndarray
        log and linear rate constants in any form accepted by
        ``_split_and_shape_fprc``.
    product_index : int
        offset of the produced fourmer within the six reactant axes.
    number_of_letters : int
        alphabet size; every axis has ``number_of_letters+1`` entries.
    motiflength : int
        only 4 is supported.
    pseudo_count_concentration : float
        log(pseudo_count_concentration/10) is used for invalid entries;
        forwarded to every helper so they all agree on that value.
    soft_reactant_threshold, hard_reactant_threshold : float, optional
        clipping thresholds; default to ``pseudo_count_concentration`` and
        half the soft threshold.

    Returns
    -------
    jnp.ndarray
        the (positive) production rates, one entry per reaction.
    """
    if soft_reactant_threshold is None:
        soft_reactant_threshold = pseudo_count_concentration
    if hard_reactant_threshold is None:
        hard_reactant_threshold = soft_reactant_threshold/2.
    fourmer_logc = _set_invalid_logc_to_log0(fourmer_logc, pseudo_count_concentration/10.)
    log_rate_constants, rate_constants = _split_and_shape_fprc(
            (log_rate_constants, rate_constants),
            number_of_letters=number_of_letters,
            motiflength=motiflength,
            pseudo_count_concentration=pseudo_count_concentration
            )
    lrr = _initialize_empty_fourmer_production_rates(
            number_of_letters=number_of_letters,
            motiflength=motiflength
            )
    lrr = _add_log_rate_constants(log_rate_constants, lrr)
    lrr = _set_invalid_log_rates_to_logzero(lrr, pseudo_count_concentration/10.)
    lrr = _add_template_logc(fourmer_logc, lrr, number_of_letters, pseudo_count_concentration)
    # Forward pseudo_count_concentration everywhere: these helpers re-apply
    # _set_invalid_logc_to_log0 / use log(pseudo_count_concentration) as a
    # threshold and would otherwise fall back to their 1e-12 default.
    lrr = _add_extended_end_logc(fourmer_logc, lrr, number_of_letters, pseudo_count_concentration)
    lrr = _add_extending_beginning_logc(
            fourmer_logc,
            lrr,
            motiflength=motiflength,
            pseudo_count_concentration=pseudo_count_concentration
            )
    lrr = _subtract_produced_fourmers(
            fourmer_logc,
            lrr,
            product_index=product_index,
            motiflength=motiflength,
            pseudo_count_concentration=pseudo_count_concentration
            )
    rr = _exponentiate_log_reaction_rates(lrr)
    # clip smoothly
    rr = _clip(fourmer_logc, rr, soft_reactant_threshold, hard_reactant_threshold)
    rr = _multiply_rr_with_rc(rr, rate_constants)
    return +_set_invalid_rates_to_zero(rr)

def compute_logc_diff_extended_end_motifs(
        fourmer_logc : jnp.ndarray,
        log_rate_constants : jnp.ndarray = 0.,
        rate_constants : jnp.ndarray = 1.,
        motiflength : int = 4,
        number_of_letters : int = 4,
        pseudo_count_concentration : float = 1.e-12,
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None
        ) -> jnp.ndarray:
    """
    Change of the fourmer log-concentrations caused by consuming left
    (extended-end) reactants.

    The consumption rates from
    ``_compute_extended_end_motif_reaction_logc_rates`` are summed over all
    axes except the reactant's own letters:
      - monomer reactants (l0 = l1 = 0) go to entries ``[0, x, 0, 0]``;
      - longer reactants (l1, l2 non-empty) go to ending-motif entries
        ``[l0, l1, l2, 0]``.

    Raises
    ------
    NotImplementedError
        if ``motiflength`` is not 4.
    """
    log_rate_constants, rate_constants = _split_and_shape_fprc(
            (log_rate_constants, rate_constants),
            number_of_letters=number_of_letters,
            motiflength=motiflength,
            pseudo_count_concentration=pseudo_count_concentration
            )
    fourmer_logc = _set_invalid_logc_to_log0(fourmer_logc, pseudo_count_concentration/10.)
    if motiflength!=4:
        raise NotImplementedError("motiflength!=4")
    logc_diff = jnp.zeros((number_of_letters+1,)*motiflength)
    rates = _compute_extended_end_motif_reaction_logc_rates(
            fourmer_logc,
            log_rate_constants,
            rate_constants,
            motiflength,
            number_of_letters,
            pseudo_count_concentration,
            soft_reactant_threshold=soft_reactant_threshold,
            hard_reactant_threshold=hard_reactant_threshold
            )
    monomer_rates = rates[0, 0, 1:]
    ending_rates = rates[:, 1:, 1:]
    # monomers: keep only the letter axis
    logc_diff = logc_diff.at[0, 1:, 0, 0].add(
            jnp.sum(monomer_rates, axis=tuple(range(1, len(monomer_rates.shape))))
            )
    # ending motifs: keep the three left-reactant axes
    logc_diff = logc_diff.at[:, 1:, 1:, 0].add(
            jnp.sum(ending_rates, axis=tuple(range(motiflength-1, len(ending_rates.shape))))
            )
    return _set_invalid_logc_diff_to_zero(logc_diff, print_warning = False)
def compute_logc_diff_extending_beginning_motifs(
        fourmer_logc : jnp.ndarray,
        log_rate_constants : jnp.ndarray = 0.,
        rate_constants : jnp.ndarray = 1.,
        motiflength : int = 4,
        number_of_letters : int = 4,
        pseudo_count_concentration : float = 1.e-12,
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None
        ) -> jnp.ndarray:
    """
    Change of the fourmer log-concentrations caused by consuming right
    (extending-beginning) reactants.

    The consumption rates from
    ``_compute_extending_beginning_motif_reaction_logc_rates`` are summed
    over the left-reactant and template axes:
      - monomers/dimers (r1 non-empty, r3 empty) go to ``[0, r1, r2, 0]``;
      - beginnings (r1, r2, r3 non-empty) go to ``[0, r1, r2, r3]``.
    """
    fourmer_logc = _set_invalid_logc_to_log0(fourmer_logc, pseudo_count_concentration/10.)
    logc_diff = jnp.zeros((number_of_letters+1,)*motiflength)
    log_rate_constants, rate_constants = _split_and_shape_fprc(
            (log_rate_constants, rate_constants),
            number_of_letters=number_of_letters,
            motiflength=motiflength,
            pseudo_count_concentration=pseudo_count_concentration
            )
    rates = _compute_extending_beginning_motif_reaction_logc_rates(
            fourmer_logc,
            log_rate_constants,
            rate_constants,
            motiflength,
            number_of_letters,
            pseudo_count_concentration,
            soft_reactant_threshold=soft_reactant_threshold,
            hard_reactant_threshold=hard_reactant_threshold
            )
    n_axes = len(rates.shape)
    left_axes = tuple(range(motiflength-1))
    # monomers and dimers (the r3 axis is indexed away, hence n_axes-1)
    short_rates = rates[:, :, :, 1:, :, 0]
    short_sum = jnp.sum(
            short_rates,
            axis=left_axes + tuple(range(2*motiflength-3, n_axes-1))
            )
    logc_diff = logc_diff.at[0, 1:, :, 0].add(short_sum)
    # beginnings
    beginning_rates = rates[:, :, :, 1:, 1:, 1:]
    beginning_sum = jnp.sum(
            beginning_rates,
            axis=left_axes + tuple(range(2*motiflength-2, n_axes))
            )
    logc_diff = logc_diff.at[0, 1:, 1:, 1:].add(beginning_sum)
    return _set_invalid_logc_diff_to_zero(logc_diff, print_warning = False)
def compute_logc_diff_produced_motifs(
        fourmer_logc : jnp.ndarray,
        log_rate_constants : jnp.ndarray = 0.,
        rate_constants : jnp.ndarray = 1.,
        motiflength : int = 4,
        number_of_letters : int = 4,
        pseudo_count_concentration : float = 1.e-12,
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None
        ) -> jnp.ndarray:
    """
    Change of the fourmer log-concentrations caused by produced fourmers.

    For each window offset ``product_index`` in ``0 .. motiflength-2`` the
    rates from ``_compute_produced_motif_reaction_logc_rates`` are summed
    over every axis outside ``product_index .. product_index+motiflength-1``.
    The result is accumulated into entries whose positions 1 and 2 are
    non-empty (dimers, beginnings, ends and continuations).
    """
    fourmer_logc = _set_invalid_logc_to_log0(fourmer_logc, pseudo_count_concentration/10.)
    logc_diff = jnp.zeros((number_of_letters+1,)*motiflength)
    log_rate_constants, basic_rate_constants = _split_and_shape_fprc(
            (log_rate_constants, rate_constants),
            number_of_letters=number_of_letters,
            motiflength=motiflength,
            pseudo_count_concentration=pseudo_count_concentration
            )
    total_axes = 3*motiflength - 2
    for product_index in range(motiflength-1):
        rates = _compute_produced_motif_reaction_logc_rates(
                fourmer_logc,
                log_rate_constants,
                basic_rate_constants,
                product_index,
                number_of_letters,
                motiflength=motiflength,
                pseudo_count_concentration=pseudo_count_concentration,
                soft_reactant_threshold=soft_reactant_threshold,
                hard_reactant_threshold=hard_reactant_threshold
                )
        window_end = product_index + motiflength
        summed_axes = tuple(
                axis for axis in range(total_axes)
                if not product_index <= axis < window_end
                )
        per_fourmer = jnp.sum(rates, axis=summed_axes)
        logc_diff = logc_diff.at[:, 1:, 1:, :].add(per_fourmer[:, 1:, 1:, :])
    return _set_invalid_logc_diff_to_zero(logc_diff, print_warning = False)
def compute_total_extension_rates(
        fourmer_logc : jnp.ndarray,
        log_rate_constants : jnp.ndarray = 0.,
        rate_constants : jnp.ndarray = 1.,
        motiflength : int = 4,
        number_of_letters : int = 4,
        pseudo_count_concentration : float = 1.e-12,
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None,
        conserve_mass : bool = False,
        fourmer_logc_slope : float = 0.,
        logc_diff_slope : float = 0.
        ) -> jnp.ndarray:
    """
    Total change of the fourmer log-concentrations due to ligation.

    Sums the contributions of consumed left reactants, consumed right
    reactants and produced fourmers.

    Parameters:
    -----------
    fourmer_logc : jnp.ndarray
        log-concentrations of the fourmers
    log_rate_constants, rate_constants : float or jnp.ndarray
        log and linear rate constants (scalar, thin or full shape, as
        accepted by _split_and_shape_fprc)
    motiflength : int
        only 4 is supported
    number_of_letters : int
        alphabet size
    pseudo_count_concentration : float
        concentration used for invalid entries
    soft_reactant_threshold, hard_reactant_threshold : float, optional
        thresholds for the smooth clipping of reactant concentrations
    conserve_mass : bool
        whether concentration of monomers shall compensate variation in
        the total mass
        (deprecated! Please use mass_correction_rates for this from now on,
        since this feature leads to inconsistencies in the zebraness and
        strand number; a DeprecationWarning is emitted when enabled)
        default False
    fourmer_logc_slope, logc_diff_slope : float
        accepted for interface compatibility; currently unused

    Returns:
    --------
    total_extension_rates : jnp.ndarray
    """
    if conserve_mass:
        warn(
            "conserve_mass is deprecated; use mass_correction_rates instead",
            DeprecationWarning,
            stacklevel=2
            )
    fourmer_logc = _set_invalid_logc_to_log0(fourmer_logc, pseudo_count_concentration/10.)
    # shape the constants once; the helpers accept the thin shape unchanged
    log_rate_constants, basic_rate_constants = _split_and_shape_fprc(
            (log_rate_constants, rate_constants),
            number_of_letters=number_of_letters,
            motiflength=motiflength,
            pseudo_count_concentration=pseudo_count_concentration
            )
    shared_arguments = (
            fourmer_logc,
            log_rate_constants,
            basic_rate_constants,
            motiflength,
            number_of_letters,
            pseudo_count_concentration,
            )
    thresholds = dict(
            soft_reactant_threshold=soft_reactant_threshold,
            hard_reactant_threshold=hard_reactant_threshold
            )
    # initialize logc_diff
    logc_diff = jnp.zeros((number_of_letters+1,)*motiflength)
    # consume_left_reactant
    logc_diff = jnp.add(logc_diff,
            compute_logc_diff_extended_end_motifs(*shared_arguments, **thresholds)
            )
    # consume_right_reactant
    logc_diff = jnp.add(logc_diff,
            compute_logc_diff_extending_beginning_motifs(*shared_arguments, **thresholds)
            )
    # create_products
    logc_diff = jnp.add(logc_diff,
            compute_logc_diff_produced_motifs(*shared_arguments, **thresholds)
            )
    if conserve_mass:
        # rebalance monomer concentrations in linear space, then go back to logc_diff
        c_diff = jnp.exp(fourmer_logc)*logc_diff
        c_diff = c_diff.at[0,1:,0,0].set(
                -jnp.sum(c_diff[:,1:,1:,:],axis=(0,2,3))
                -jnp.sum(c_diff[:,1:,1:,0],axis=(0,1))
                )
        logc_diff = jnp.exp(-fourmer_logc)*c_diff
    return _set_invalid_logc_diff_to_zero(logc_diff, print_warning = False)
def compute_motif_extensions(
        fourmer_logc : jnp.ndarray,
        log_rate_constants : jnp.ndarray = 0.,
        basic_rate_constants : jnp.ndarray = 1.,
        motiflength : int = 4,
        number_of_letters : int = 4,
        pseudo_count_concentration : float = 1.e-12,
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None
        ) -> jnp.ndarray:
    """
    Ligation rates including reactant and template concentrations, summed
    over the outermost reactant letters.

    The log rate of every ligation in the [l0,l1,l2,r1,r2,r3,t0,t1,t2,t3]
    layout is assembled from the log rate constants and the left reactant,
    right reactant and template log-concentrations. It is then
    exponentiated, weighted by ``_clip`` and multiplied by the linear rate
    constants.

    Parameters
    ----------
    fourmer_logc : jnp.ndarray
        log-concentrations of the fourmers.
    log_rate_constants, basic_rate_constants : float or jnp.ndarray
        log and linear rate constants in any form accepted by
        ``_split_and_shape_fprc``.
    motiflength : int
        only 4 is supported.
    number_of_letters : int
        alphabet size.
    pseudo_count_concentration : float
        log(pseudo_count_concentration/10) is used for invalid entries;
        forwarded to every helper so they all agree on that value.
    soft_reactant_threshold, hard_reactant_threshold : float, optional
        clipping thresholds; default to ``pseudo_count_concentration`` and
        half the soft threshold.

    Returns
    -------
    jnp.ndarray
        rates summed over axis 0 (l0) and axis 2*(motiflength-1)-1 (r3),
        i.e. indexed by (l1, l2, r1, r2, t0, t1, t2, t3).

    NOTE(review): unlike the ``_compute_*_reaction_logc_rates`` helpers,
    ``_set_invalid_rates_to_zero`` is not applied before summing -- confirm
    that is intended.
    """
    if soft_reactant_threshold is None:
        soft_reactant_threshold = pseudo_count_concentration
    if hard_reactant_threshold is None:
        hard_reactant_threshold = soft_reactant_threshold/2.
    fourmer_logc = _set_invalid_logc_to_log0(fourmer_logc, pseudo_count_concentration/10.)
    log_rate_constants, basic_rate_constants = _split_and_shape_fprc(
            (log_rate_constants, basic_rate_constants),
            number_of_letters=number_of_letters,
            motiflength=motiflength,
            pseudo_count_concentration=pseudo_count_concentration
            )
    lrr = _initialize_empty_fourmer_production_rates(
            number_of_letters=number_of_letters,
            motiflength=motiflength
            )
    lrr = _add_log_rate_constants(log_rate_constants, lrr)
    lrr = _set_invalid_log_rates_to_logzero(lrr, pseudo_count_concentration/10.)
    # forward pseudo_count_concentration: these helpers re-apply
    # _set_invalid_logc_to_log0 and would otherwise use their 1e-12 default
    lrr = _add_extended_end_logc(fourmer_logc, lrr, number_of_letters, pseudo_count_concentration)
    lrr = _add_extending_beginning_logc(
            fourmer_logc,
            lrr,
            motiflength=motiflength,
            pseudo_count_concentration=pseudo_count_concentration
            )
    lrr = _add_template_logc(fourmer_logc, lrr, number_of_letters, pseudo_count_concentration)
    rr = _exponentiate_log_reaction_rates(lrr)
    rr = _clip(fourmer_logc, rr, soft_reactant_threshold, hard_reactant_threshold)
    rr = _multiply_rr_with_rc(rr, basic_rate_constants)
    return jnp.sum(rr, axis=(0, 2*(motiflength-1)-1))
def _split_and_shape_fprc( lrc_brc, number_of_letters = 4, motiflength = 4, pseudo_count_concentration = 1.e-12, ): lrc, brc = lrc_brc lrc = _shape_fprc( lrc, number_of_letters = number_of_letters, motiflength = motiflength, pseudo_count_concentration = pseudo_count_concentration, fprc_are_logarithmized=True, shape_is_thin=True ) brc = _shape_fprc( brc, number_of_letters = number_of_letters, motiflength = motiflength, pseudo_count_concentration = pseudo_count_concentration, fprc_are_logarithmized=False, shape_is_thin=True ) return lrc,brc def _shape_fprc( fourmer_production_rate_constants, number_of_letters : int = 4, motiflength : int = 4, pseudo_count_concentration : float = 1.e-12, fprc_are_logarithmized : bool = False, shape_is_thin : bool = False ) -> jnp.ndarray: motif_production_rate_constants_shape = (number_of_letters+1,)*(2*motiflength) if isinstance(fourmer_production_rate_constants,(float,int)): eerc = float(fprc_are_logarithmized)*jnp.ones(motif_production_rate_constants_shape) fourmer_production_rate_constants = eerc.at[:,1:,1:,:,:,1:,1:,:].set(fourmer_production_rate_constants) elif fourmer_production_rate_constants.size == (number_of_letters+1)**motiflength*number_of_letters**motiflength: eerc_new = float(fprc_are_logarithmized)*jnp.ones(motif_production_rate_constants_shape) fourmer_production_rate_constants = eerc_new.at[:,1:,1:,:,:,1:,1:,:].set(fourmer_production_rate_constants) if fourmer_production_rate_constants.shape != motif_production_rate_constants_shape: raise ValueError(f"{fourmer_production_rate_constants.shape=} != {motif_production_rate_constants_shape=}") if fprc_are_logarithmized: if shape_is_thin: #thin_shape = (number_of_letters+1,number_of_letters) #thin_shape += shape[::-1] #thin_shape += shape return fourmer_production_rate_constants[:,1:,1:,:,:,1:,1:,:] else: return _set_invalid_log_production_rates_to_logzero( fourmer_production_rate_constants, pseudo_count_concentration=pseudo_count_concentration ) else: if shape_is_thin: 
return fourmer_production_rate_constants[:,1:,1:,:,:,1:,1:,:] else: return _set_invalid_production_rates_to_zero( fourmer_production_rate_constants, )