import jax
import jax.numpy as jnp
import numpy as np
from .fourmer_production_rates import compute_motif_extensions, _set_invalid_logc_to_log0


def collisions_from_motif_concentration_trajectory_and_collision_rate_constants(
        motif_concentration_trajectory,
        collision_rate_constants = 1,
        flatten : bool = True
        ):
    """
    Integrate, per ligation channel, how often its reactants meet.

    The collisions give how often a certain ligation (channel) can be
    observed. They measure how often the reactants that perform the
    reaction (end motif of the rear strand, beginning motif of the fore
    strand and template motif) are exposed to each other. Technically,
    the three concentrations are integrated over time reaction-wise,
    according to effective ligation dynamics.

    NOTE: this object-based variant was never finished. The previous body
    only read ``motif_concentration_trajectory.motifs['ending']``,
    ``.motifs['beginning']`` and ``.motifs`` and then returned a name that
    was never assigned, so every call failed with ``NameError``. It now
    fails explicitly instead.

    Parameters:
    -----------
    motif_concentration_trajectory : concentration
        trajectory object exposing a ``motifs`` mapping
        (the unfinished body read the keys 'ending' and 'beginning').
    collision_rate_constants : array, optional (default: 1)
        weights of the integral.
    flatten : boolean, optional (default: True)
        if True, output is flat

    Returns:
    --------
    field : NIFTy-field
        intended result: field with each entry being the integral over the
        corresponding concentrations of the reactants that are needed for
        the templated ligation. Currently never returned.

    Raises:
    -------
    NotImplementedError
        always, because the computation is not implemented.
    """
    # Key names noted by the original author for the motifs mapping:
    # 'length{}strand', 'beginning', 'continuation', 'end'
    # TODO(review): the unfinished body used 'ending' rather than 'end';
    # confirm the correct key when implementing.
    raise NotImplementedError(
        "collisions_from_motif_concentration_trajectory_and_collision_rate_constants "
        "is not implemented; use "
        "collisions_from_motif_concentration_trajectory_array_and_collision_rate_constants_array "
        "for array input."
    )


def motifs_collisions_array_from_motifs_array(
        motifs_array : jax.Array,
        collision_order : int = 3
        ) -> jax.Array:
    """
    Build the flattened repeated outer product of a motifs array.

    The motifs array is flattened and multiplied with itself as an outer
    product ``collision_order - 1`` times, i.e. the result is the
    $m$th outer power of the motifs array, where $m$ is the
    collision_order.

    Parameters
    ----------
    motifs_array : jax.Array
        array of motif values; its shape is ignored, it is used flattened
    collision_order : int, optional
        number of factors in the outer product, by default 3.
        For values of 1 or less the flattened input itself is returned.

    Returns
    -------
    jax.Array
        motifs_collisions_array, one-dimensional, with
        ``motifs_array.size ** collision_order`` entries for
        collision_order >= 1
    """
    outer_power = motifs_array
    # one additional factor per step; the first factor is the array itself
    for _ in range(collision_order - 1):
        as_column = outer_power.reshape(-1, 1)
        as_row = motifs_array.reshape(1, -1)
        outer_power = as_column * as_row
    return outer_power.flatten()


def collisions_from_motif_concentration_trajectory_array_and_collision_rate_constants_array(
        motif_concentration_trajectory_array : jax.Array,
        motif_concentration_trajectory_times_array : jax.Array = None,
        collision_rate_constants_array : jax.Array = 1.,
        motiflength : int = 4,
        alphabet : list = ['A', 'T'],
        complements : list = [1,0],
        concentrations_are_logarithmised : bool = True,
        pseudo_count_concentration : float = 1.e-12
        ):
    """
    Jax version of
    collisions_from_motif_concentration_trajectory_and_collision_rate_constants.

    For every time point, `compute_motif_extensions` is evaluated on the
    (log-)concentrations of that time point; the flattened results are then
    integrated over time with the trapezoidal rule.

    Parameters
    ----------
    motif_concentration_trajectory_array : jax.Array
        concentrations (or their logarithms) per time point. With
        ``T = len(times)`` and ``nol = len(alphabet)`` it is written into
        ``logc[:, :, 1:, :, :]`` of a zero array of shape
        ``(T,) + (nol+1,)*4`` and therefore must be broadcastable to
        ``(T, nol+1, nol, nol+1, nol+1)``.
    motif_concentration_trajectory_times_array : jax.Array
        sample times of the trajectory (integration variable). Required,
        although the signature default is None for backward compatibility.
    collision_rate_constants_array : jax.Array, optional
        rate constants handed to `compute_motif_extensions` as
        ``basic_rate_constants``. Anything that is neither a jax.Array nor
        a list (e.g. the scalar default 1.) is multiplied with an array of
        ones of shape ``(nol+1,)*(2*motiflength)``.
    motiflength : int, optional
        only 4 is supported.
    alphabet : list, optional
        only its length (number of letters) is used here. The default list
        is never mutated.
    complements : list, optional
        currently unused by this function.
    concentrations_are_logarithmised : bool, optional
        if False, values below `pseudo_count_concentration` are raised to
        it and the logarithm is taken before further processing.
    pseudo_count_concentration : float, optional
        lower bound applied to non-logarithmised concentrations.

    Returns
    -------
    jax.Array
        one-dimensional array: the time integral of the flattened,
        sliced output of `compute_motif_extensions`.

    Raises
    ------
    NotImplementedError
        if motiflength is not 4.
    TypeError
        if motif_concentration_trajectory_times_array is None.
    """
    if motiflength != 4:
        raise NotImplementedError("Function only implemented for motiflength 4.")
    if motif_concentration_trajectory_times_array is None:
        # Previously this surfaced as an opaque error from converting None
        # to an array; the times are needed for the array shape and the integral.
        raise TypeError(
            "motif_concentration_trajectory_times_array is required: "
            "pass the sample times of the trajectory."
        )
    nol = len(alphabet)
    times = jnp.asarray(motif_concentration_trajectory_times_array)
    number_of_times = len(times)

    # NOTE: unfinished sketch by the original author for arbitrary motiflength:
    # left_reactants = motif_concentration_trajectory_array.at[(slice(None),)+(slice(None),)*(motiflength-1)+(0,)]
    # right_reactants = motif_concentration_trajectory_array.at[(slice(None),)+(0,)]
    # templates = motif_concentration_trajectory_array.at[(slice(None),)+(slice(None),)*(motiflength//2-2)+(slice(None),slice(1,None))]
    # collisions = jnp.zeros((nol+1,)*(motiflength-2)+(nol,nol)+(nol+1,)*(motiflength-2)+(nol+1,)+(nol))
    # for strandlength in range(1,motiflength-1):
    #     ...

    if concentrations_are_logarithmised:
        log_concentrations = motif_concentration_trajectory_array
    else:
        # jnp.where instead of a boolean-mask `.at[...].set`: same values,
        # but it also accepts NumPy input and can be traced.
        clipped_concentrations = jnp.where(
            motif_concentration_trajectory_array < pseudo_count_concentration,
            pseudo_count_concentration,
            motif_concentration_trajectory_array,
        )
        log_concentrations = jnp.log(clipped_concentrations)

    # Entries with index 0 on the third axis are not covered by the input
    # and stay 0. NOTE(review): how compute_motif_extensions treats these
    # entries is not visible here — confirm.
    logc = jnp.zeros((number_of_times,) + (nol + 1,) * motiflength)
    logc = logc.at[:, :, 1:, :, :].set(log_concentrations)

    if not isinstance(collision_rate_constants_array, (jax.Array, list)):
        collision_rate_constants_array = (
            collision_rate_constants_array
            * jnp.ones((nol + 1,) * (2 * motiflength))
        )
    # Convert once instead of once per time point.
    rate_constants = jnp.asarray(collision_rate_constants_array)

    collisions = jnp.array([
        compute_motif_extensions(
            logc[time_index],  # logc or c
            basic_rate_constants=rate_constants,
            motiflength=motiflength,
            number_of_letters=nol,
        )[:, 1:, 1:, :, :, 1:, 1:, :].reshape(-1)
        for time_index in range(number_of_times)
    ])
    # one row per time point, one column per flattened output entry
    collisions = collisions.reshape(collisions.shape[0], -1)
    return jnp.trapezoid(collisions, x=times, axis=0)