Source code for src.infer.motif_production_rate_constants_extension_matrix
import jax.numpy as jnp
from jax.experimental import sparse
from itertools import product as iterprod
from .motif_production_transition_kernel import _hybridization_site_categories
[docs]
def motif_production_rate_constants_extension_matrix(
number_of_letters : int = 2,
motiflength : int = 4,
hybridization_length_max : int = 4
) -> sparse.BCOO:
"""
generates a matrix as jax.experimental.sparse.BCOO matrix
that transforms hybridization site formated vectors into collision format.
Parameters:
-----------
number_of_letters : int
default : 2
motiflength : int
default : 4
hybridization_length_max : int
default : 4
"""
raise NotImplementedError("motif_production_rate_constants_extension_matrix is under construction")
hybridization_site_caterories, hybridization_configuration_indices = _hybridization_site_categories(number_of_letters, hybridization_length_max)
motif_indices = jnp.concatenate([jnp.zeros(1),jnp.cumsum(number_of_letters**jnp.arange(1,motiflength+1)),jnp.array([jnp.sum(number_of_letters**jnp.arange(1,motiflength+1))+number_of_letters**(motiflength-1)])],dtype=int)
template_indices = (motif_indices-motif_indices[1]).at[0].set(0)
number_of_left_reactants = number_of_right_reactants = int(motif_indices[-3])
number_of_templates = int(template_indices[-1])
mprcem_shape = (number_of_left_reactants,number_of_right_reactants, number_of_templates, hybridization_configuration_indices[-1])
motif_production_rate_constants_extension_matrx = sparse.BCOO((jnp.array([0]),jnp.array([[0]*len(mprcem_shape)])),shape=mprcem_shape)
for a_index in range(len(hybridization_site_caterories)):
a = hybridization_site_caterories[a_index]
left_shift = a[-1]
left_ligant_length = a[1]
right_ligant_length = a[2]
template_length = a[3]
right_shift = left_shift + right_ligant_length + left_ligant_length - template_length
if left_shift==-1:
left_ligant_lengths = jnp.arange(a[1],motiflength)
# blunt end gets extended iff hybridization_length=hybridization_length_max and the ligation spot is at the center or further apart.
# for the other part of the complex, the same rules apply as for every hybridization_length: for a dangling end, the dangling strand/motif can be extended
# for a blunt ent that is close to the ligation spot, the complex is supposed to end there
# This way, we only extend blunt ends on both sides, if the ligation_spot is at the center
# If hybridization_length_max is uneven, the right central ligation spot is treated as the center,
# so the center is at hybridization_length_max-hybridization_length_max//2 from the left side and hybridization_lengh_max//2 from the right side apart
elif (a[0]==hybridization_length_max) and (left_shift==0) and (a[1]>=hybridization_length_max-hybridization_length_max//2):
# continue left ligant
left_ligant_lengths = jnp.arange(a[1],motiflength)
else:
left_ligant_lengths = [a[1]]
if right_shift==1:
right_ligant_lengths = jnp.arange(a[2],motiflength)
elif (a[0]==hybridization_length_max) and (right_shift==0) and (a[2]>=hybridization_length_max//2):
right_ligant_lengths = jnp.arange(a[2],motiflength)
else:
right_ligant_lengths = [a[2]]
template_might_continue_forwards = (left_shift==1) or ((a[0]==hybridization_length_max) and (left_shift==0))
template_might_continue_backwards = (right_shift==-1) or ((a[0]==hybridization_length_max) and (right_shift==0))
template_might_continue_both_ways = (template_might_continue_forwards and template_might_continue_backwards)
if template_might_continue_both_ways:
template_overhangs = jnp.arange(motiflength-a[3])
elif template_might_continue_forwards: #(but not backwards)
template_overhangs = jnp.arange(motiflength-a[3])
else:
template_overhangs = jnp.arange(1)
for left_ligant_length, right_ligant_length, template_overhang in iterprod(left_ligant_lengths, right_ligant_lengths, template_overhangs):
if template_overhang+a[3] == motiflength:
template_lengths = jnp.array([motiflength])
elif template_might_continue_backwards or template_might_continue_forwards:
template_lengths = jnp.arange(a[3]+template_overhang, motiflength)
else:
template_lengths = jnp.array([a[3]+template_overhang])
for template_length in template_lengths:
if (template_length == (motiflength-1)) and template_might_continue_backwards and not template_might_continue_forwards:
template_index_start = template_indices[-1-1]
else:
template_index_start = template_indices[template_length-1]
matrix_indices = [
(il*number_of_letters**(a[1])+i1+motif_indices[left_ligant_length-1],
i2*number_of_letters**(right_ligant_length-a[2])+ir+motif_indices[right_ligant_length-1],
int(itc*number_of_letters**(template_overhang+a[3])+i3*number_of_letters**(template_overhang)+io+template_index_start),
int(i1*number_of_letters**(a[2]+a[3])+i2*number_of_letters**a[3]+i3+hybridization_configuration_indices[a_index]))
for i1,i2,i3 in iterprod(
range(number_of_letters**a[1]),
range(number_of_letters**a[2]),
range(number_of_letters**a[3])
)
for il,ir,io,itc in iterprod(
range(number_of_letters**int(left_ligant_length-a[1])),
range(number_of_letters**int(right_ligant_length-a[2])),
range(number_of_letters**int(template_overhang)),
range(number_of_letters**int(template_length-template_overhang-a[3])))
]
motif_production_rate_constants_extension_matrx = motif_production_rate_constants_extension_matrx + sparse.BCOO((jnp.array([1]*len(matrix_indices)),matrix_indices),shape=(mprcem_shape))
return motif_production_rate_constants_extension_matrx