import jax.numpy as jnp
from warnings import warn
import timeit
from functools import partial
from ..obj.times_vector import TimesVector
from ..obj.motif_vector import (MotifVector,
_array_to_motif_vector_dct, _motif_vector_as_array)
from ..obj.motif_trajectory import MotifTrajectory
from ..obj.motif_production_vector import (MotifProductionVector,
_motif_production_vector_as_array)
from ..obj.motif_breakage_vector import (MotifBreakageVector,
_motif_breakage_vector_as_array)
from .fourmer_production_rates import _shape_fprc
from .fourmer_breakage_rates import _shape_brc
from scipy.integrate import solve_ivp
from jax.experimental.ode import odeint
from diffrax import diffeqsolve
import diffrax
import timeit
from .fourmer_production_rates import compute_total_extension_rates
from .fourmer_production_rates import _set_invalid_log_rates_to_logzero
from .fourmer_production_rates import _set_invalid_production_rates_to_zero
from .fourmer_breakage_rates import fourmer_breakage_rates
from .fourmer_mass_correction import mass_correction_rates
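# NOTE: a stray "[docs]" link label (documentation-export residue) stood here; it is not part of the code.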
def fourmer_trajectory_from_rate_constants(
        motif_production_rate_constants : MotifProductionVector,
        motif_production_log_rate_constants : MotifProductionVector,
        breakage_rate_constants : MotifBreakageVector,
        initial_motif_concentrations_vector : MotifVector,
        times : TimesVector,
        complements : list,
        mass_correction_rate_constant : float = 0.,
        concentrations_are_logarithmized : bool = True,
        ode_integration_method : str = 'DOP853',
        execution_time_path : str = None,
        pseudo_count_concentration : float = 1.e-12,
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None,
        first_step : float = None,
        ivp_atol : float = 1.e-6,
        ivp_rtol : float = 1.e-3,
        ) -> MotifTrajectory:
    """
    Integrate the fourmer motif ODE for the stated rate constants and initial
    conditions and return the result as a MotifTrajectory.

    Parameters:
    -----------
    motif_production_rate_constants : MotifProductionVector
        production rate constants, converted to an array before integration
    motif_production_log_rate_constants : MotifProductionVector
        production rate constants handed to the integrator as
        'fourmer_production_log_rate_constants'
        (presumably the logarithm of the above - TODO confirm)
    breakage_rate_constants : MotifBreakageVector
        breakage rate constants; a value comparing equal to 0. is forwarded
        unchanged instead of being converted to an array
    initial_motif_concentrations_vector : MotifVector
        initial condition; also supplies alphabet, motiflength and unit of
        the returned trajectory
    times : TimesVector
        `times.val` is used as evaluation times, `times.domain[0].units` as
        unit of the returned times
    complements : list
        converted with jnp.asarray and forwarded to the rate equations
        (presumably the index of the complementary letter of each letter)
    mass_correction_rate_constant : float
        Rate constant for compensation of numerical mass fluctuations,
        default 0.
    concentrations_are_logarithmized : bool
        if False, the integrator clips the initial concentrations at
        pseudo_count_concentration, integrates their logarithm and
        exponentiates the result again
        default True
    ode_integration_method : str
        default 'DOP853',
    execution_time_path : str
        if not None, the timing record returned by the integrator is
        appended to this file on a new line
        default None,
    pseudo_count_concentration : float
        default 1.e-12,
    soft_reactant_threshold : float
        forwarded to the rate equations
        default None, i.e. pseudo_count_concentration is used
    hard_reactant_threshold : float
        forwarded to the rate equations
        default None
    first_step : float
        default None
    ivp_atol : float
        atol for scipy.integrate.solve_ivp
        default 1.e-6
    ivp_rtol : float
        rtol for scipy.integrate.solve_ivp
        default 1.e-3

    Returns:
    --------
    motif_trajectory : MotifTrajectory

    Raises:
    -------
    ValueError
        if the motiflength of the initial concentrations is not 4
    """
    alphabet = initial_motif_concentrations_vector.alphabet
    motiflength = initial_motif_concentrations_vector.motiflength
    time_values = times.val
    if soft_reactant_threshold is None:
        soft_reactant_threshold = pseudo_count_concentration
    # Single length check; a second, unreachable check for the same
    # condition used to follow it.
    if motiflength != 4:
        raise ValueError("Motiflength ({}!=4) does not fit fourmer length.".format(motiflength))
    unit = initial_motif_concentrations_vector.unit

    # A plain zero means "no breakage" and is passed through as is.
    if breakage_rate_constants != 0.:
        breakage_rate_constants_array = _motif_breakage_vector_as_array(breakage_rate_constants)
    else:
        breakage_rate_constants_array = breakage_rate_constants
    #TODO assert breakage_rate_constants and motif_production_rate_constants
    # have fitting units with motif_trajectory
    initial_motif_concentrations_array = _motif_vector_as_array(initial_motif_concentrations_vector)

    motif_trajectory_field, times_array, execution_time = _integrate_motif_rate_equations(
        initial_motif_concentrations_array,
        number_of_letters = len(alphabet),
        motiflength = motiflength,
        complements = jnp.asarray(complements),
        concentrations_are_logarithmized = concentrations_are_logarithmized,
        fourmer_production_log_rate_constants = _motif_production_vector_as_array(motif_production_log_rate_constants),
        fourmer_production_rate_constants = _motif_production_vector_as_array(motif_production_rate_constants),
        breakage_rate_constants = breakage_rate_constants_array,
        mass_correction_rate_constant = mass_correction_rate_constant,
        t_eval = time_values,
        ode_integration_method = ode_integration_method,
        first_step = first_step,
        pseudo_count_concentration = pseudo_count_concentration,
        ivp_atol = ivp_atol,
        ivp_rtol = ivp_rtol,
        soft_reactant_threshold = soft_reactant_threshold,
        hard_reactant_threshold = hard_reactant_threshold
    )

    # The solver may report different times than requested, so the output
    # TimesVector is rebuilt from the times it returned.
    times_vector = TimesVector(times_array, times.domain[0].units)
    # One (len(alphabet)+1)**motiflength block per time point; index 0 of the
    # second motif axis is dropped (the integrator embeds the input with
    # `.at[:,1:]` on that same axis).
    padded_shape = times_array.shape + (len(alphabet)+1,)*motiflength
    motif_trajectory_array = motif_trajectory_field.reshape(padded_shape)[:,:,1:]

    if execution_time_path is not None:
        # NOTE(review): `execution_time` is the dict {"execution_time": seconds}
        # returned by the integrator, so the dict's string form is what gets
        # written - confirm whether only the number is wanted.
        with open(execution_time_path,'a') as f:
            f.write('\n'+str(execution_time))

    mv = MotifVector(motiflength, alphabet, unit)
    motif_vectors = [
        mv(_array_to_motif_vector_dct(motif_vector_array, motiflength, alphabet))
        for motif_vector_array in motif_trajectory_array
    ]
    return MotifTrajectory(motif_vectors, times_vector)
def _integrate_motif_rate_equations(
        initial_concentration_vector : jnp.ndarray,
        number_of_letters : int = 4,
        motiflength : int = 4,
        complements : list = jnp.array([1,0,3,2]),
        concentrations_are_logarithmized : bool = True,
        influx_rate_constants : jnp.ndarray = 0.,
        fourmer_production_log_rate_constants : jnp.ndarray = 0.,
        fourmer_production_rate_constants : jnp.ndarray = 1.,
        breakage_rate_constants : jnp.ndarray = 0.,
        mass_correction_rate_constant : float = 0.,
        t_eval : jnp.ndarray = jnp.arange(0, 50000, 1),
        ode_integration_method : str = 'RK45',
        pseudo_count_concentration : float = 1.e-12,
        first_step = None,
        ivp_atol : float = 1.e-3,
        ivp_rtol : float = 1.e-6,
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None
        ):
    """
    Integrate the fourmer motif rate equations with the selected ODE solver.

    Parameters:
    -----------
    initial_concentration_vector : array
        array of initial concentrations of motifs. If its size equals
        number_of_letters*(number_of_letters+1)**(motiflength-1) it is
        embedded into the padded (number_of_letters+1,)*motiflength array
        (index 0 of the second axis stays zero), otherwise it is copied as is.
    number_of_letters : int, optional
        number of letters
        default : 4
    motiflength : int, optional
        length of the tracked motifs, only 4 is implemented
        default : 4
    complements : jnp.array, optional
        array of which letters are complementary to each other
        default : jnp.array([1,0,3,2])
    concentrations_are_logarithmized : boolean, optional (default True)
        if False, the initial concentrations are clipped from below at
        pseudo_count_concentration and logarithmized for the integration;
        the returned trajectory is exponentiated again.
    influx_rate_constants : float or array
        forwarded to the rate equations
        default : 0.
    fourmer_production_log_rate_constants : array
        array of effective ligation rates, shaped via _shape_fprc with
        fprc_are_logarithmized = True
        default : 0.
    fourmer_production_rate_constants : array
        shaped via _shape_fprc with fprc_are_logarithmized = False
        default : 1.
    breakage_rate_constants : float or d-array,
        breakage_rate_constants, if float is given, all breakage_rate_constants
        are assumed to be the same.
        default: 0.
    mass_correction_rate_constant : float
        forwarded to the rate equations
        default : 0.
    t_eval : array
        times at which the solution is stored
        default : jnp.arange(0, 50000, 1)
    ode_integration_method : string, optional
        'RK45', 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA' select
        scipy.integrate.solve_ivp with that method,
        'Dopri' selects jax.experimental.ode.odeint,
        any other value selects diffrax with the Dopri5 solver.
        default : 'RK45'
    pseudo_count_concentration : float
        default : 1.e-12
    first_step : float, optional
        initial step size for solve_ivp and diffrax (not used by odeint)
        default : None, i.e. chosen by the solver
    ivp_atol : float
        absolute tolerance of the solver
        default 1.e-3
        NOTE(review): the public wrapper uses 1.e-6 for atol and 1.e-3 for
        rtol; the defaults here look swapped - confirm before relying on them.
    ivp_rtol : float
        relative tolerance of the solver
        default 1.e-6
    soft_reactant_threshold : float
        forwarded to the rate equations
        default : None, i.e. pseudo_count_concentration
    hard_reactant_threshold : float
        forwarded to the rate equations
        default : None

    Return:
    -------
    motif_concentration_trajectory : jnp.ndarray
        solution of the integrated ODE, one row per stored time point, each
        row being the flattened state
    times_array : jnp.ndarray
        times belonging to the rows of the trajectory
    timing_dct : dict
        {"execution_time" : wall-clock seconds spent in the solver}

    Raises:
    -------
    NotImplementedError
        if motiflength != 4
    """
    if motiflength != 4:
        raise NotImplementedError(f"{motiflength=}!=4")
    if soft_reactant_threshold is None:
        soft_reactant_threshold = pseudo_count_concentration

    motif_concentration_vector_shape = (number_of_letters+1,)*(motiflength)
    if initial_concentration_vector.size == number_of_letters*(number_of_letters+1)**(motiflength-1):
        # Unpadded input: embed it so that index 0 of the second axis is zero.
        n0 = jnp.zeros(motif_concentration_vector_shape)
        n0 = n0.at[:,1:].add(
            initial_concentration_vector.reshape((number_of_letters+1,number_of_letters)+(number_of_letters+1,)*(motiflength-2))
        )
    else:
        n0 = jnp.asarray(initial_concentration_vector).copy()

    fourmer_production_log_rate_constants = _shape_fprc(
        fourmer_production_log_rate_constants,
        number_of_letters=number_of_letters,
        motiflength=motiflength,
        fprc_are_logarithmized = True
    )
    fourmer_production_rate_constants = _shape_fprc(
        fourmer_production_rate_constants,
        number_of_letters=number_of_letters,
        motiflength=motiflength,
        fprc_are_logarithmized = False
    )
    breakage_rate_constants = _shape_brc(
        breakage_rate_constants,
        number_of_letters=number_of_letters,
        motiflength=motiflength
    )

    if not concentrations_are_logarithmized:
        warn(f"motif rate equation only implemented for log concentrations to ensure positivity, will transform them to log concentrations with pseudo_count {pseudo_count_concentration} for the sake of the ode integration.")
        n0 = n0.at[n0<pseudo_count_concentration].set(pseudo_count_concentration)
        n0 = jnp.log(n0)
    n0 = n0.flatten()

    t0, t1, t_eval = _shape_t_eval(t_eval)
    rate_equations, args = _build_rate_equations(
        soft_reactant_threshold = soft_reactant_threshold,
        hard_reactant_threshold = hard_reactant_threshold,
        influx_rate_constants = influx_rate_constants,
        fourmer_production_log_rate_constants = fourmer_production_log_rate_constants,
        fourmer_production_rate_constants = fourmer_production_rate_constants,
        breakage_rate_constants = breakage_rate_constants,
        complements = complements,
        motiflength = motiflength,
        mass_correction_rate_constant = mass_correction_rate_constant,
        initial_log_concentration_array = n0.reshape(motif_concentration_vector_shape)
    )

    start_time = timeit.default_timer()
    if ode_integration_method in ['RK45','RK23','DOP853','Radau','LSODA','BDF',]:
        # integrate ode
        r = solve_ivp(
            rate_equations,
            [t0,t1],
            n0,
            t_eval = t_eval,
            args = (args,),
            method = ode_integration_method,
            first_step=first_step,
            atol=ivp_atol,
            rtol=ivp_rtol,
            dense_output=True,
        )
        if not r.success:
            print(f"solve_ivp unsuccesful (probably uncomplete solution)")
            print(f"got the following message: {r.message}")
        # solve_ivp returns y as (n_states, n_times); bring time to the front.
        motif_concentration_trajectory = jnp.asarray(r.y).T
        times_array = jnp.asarray(r.t)
    elif ode_integration_method == 'Dopri':
        # jax' odeint calls func(y, t, *args) whereas the rate equations take
        # (t, y, args). The argument dict also holds Python ints used to build
        # shapes, so it is closed over rather than passed through odeint
        # (where it would be traced).
        def _odeint_rhs(y, t):
            return rate_equations(t, y, args)

        sol = odeint(
            _odeint_rhs, #func
            n0, #y0
            jnp.asarray(t_eval, dtype=float), #t, odeint requires floats
            rtol=ivp_rtol,
            atol=ivp_atol,
        )
        times_array = jnp.asarray(t_eval)
        # odeint already returns (n_times, n_states).
        motif_concentration_trajectory = jnp.asarray(sol)
    else:
        term = diffrax.ODETerm(rate_equations)
        # Dopri5, Dopri8
        solver = diffrax.Dopri5()
        # None lets diffrax pick the initial step size itself.
        dt0 = first_step
        stepsize_controller = diffrax.PIDController(rtol=ivp_rtol,atol=ivp_atol)
        save_at = diffrax.SaveAt(ts=t_eval)
        r = diffeqsolve(
            term,
            solver,
            t0=t0,
            t1=t1,
            dt0=dt0,
            y0=n0,
            args = args,
            saveat = save_at,
            stepsize_controller=stepsize_controller
        )
        times_array = jnp.asarray(r.ts)
        # diffrax already returns ys as (n_times, n_states).
        motif_concentration_trajectory = jnp.asarray(r.ys)
    timing_dct = {"execution_time" : timeit.default_timer() - start_time}

    if not concentrations_are_logarithmized:
        motif_concentration_trajectory = jnp.exp(motif_concentration_trajectory)
    return jnp.asarray(motif_concentration_trajectory), jnp.asarray(times_array), timing_dct
def _shape_t_eval(t_eval):
    """
    Derive the integration interval from the requested evaluation times.

    Parameters:
    -----------
    t_eval : sequence or array of times

    Returns:
    --------
    t0 : start of the integration interval
    t1 : end of the integration interval
    t_eval : times at which to store the solution, or None

    A single entry is read as the end time of an integration starting at 0;
    t_eval is then returned as None, i.e. no explicit evaluation times.
    Otherwise the first and the last entry span the interval and t_eval is
    returned unchanged.
    """
    if len(t_eval) == 1:
        # Return the element itself, not the length-1 container, so that t1
        # is a scalar just like in the multi-point case.
        return 0, t_eval[0], None
    return t_eval[0], t_eval[-1], t_eval
def _build_rate_equations(
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None,
        **kwargs,
        ):
    """
    Complete the rate-equation arguments with defaults and build the
    right-hand side of the fourmer ODE.

    Parameters:
    -----------
    soft_reactant_threshold : float
        forwarded to _fourmer_rate_equations, default None
    hard_reactant_threshold : float
        forwarded to _fourmer_rate_equations, default None
    **kwargs
        arguments of the rate equations; missing entries are filled with
        influx_rate_constants : jnp.ndarray = 0.,
        fourmer_production_log_rate_constants : jnp.ndarray = 0.,
        fourmer_production_rate_constants : jnp.ndarray = 0.,
        breakage_rate_constants : jnp.ndarray = 0.,
        complements : list = jnp.array([1,0,3,2]),
        motiflength : int = 4,
        mass_correction_rate_constant : float = 0.
        'initial_log_concentration_array' has no default and is required
        whenever mass_correction_rate_constant != 0.

    Returns:
    --------
    rate_equations : callable
        function of (t, y, args) as returned by _fourmer_rate_equations
    kwargs : dict
        the completed argument dictionary, to be passed as `args`

    Raises:
    -------
    TypeError
        if mass correction is requested without
        'initial_log_concentration_array'
    """
    defaults = {
        'influx_rate_constants' : 0.,
        'fourmer_production_log_rate_constants' : 0.,
        # This key used to carry a leading blank, so the default never
        # reached the entry that is actually read afterwards.
        'fourmer_production_rate_constants' : 0.,
        'breakage_rate_constants' : 0.,
        'complements' : jnp.array([1,0,3,2]),
        'motiflength' : 4,
        'mass_correction_rate_constant' : 0.,
    }
    for key, default in defaults.items():
        kwargs.setdefault(key, default)

    if kwargs['mass_correction_rate_constant'] != 0.:
        if 'initial_log_concentration_array' not in kwargs:
            raise TypeError("For mass_correction_rate_constant != 0, initial_log_concentration_array needs to be specified")

    rate_equations = _fourmer_rate_equations(
        kwargs,
        soft_reactant_threshold = soft_reactant_threshold,
        hard_reactant_threshold = hard_reactant_threshold
    )
    return rate_equations, kwargs
def _fourmer_production_equations(
        t, y, arg1,
        soft_reactant_threshold : float = 0.,
        hard_reactant_threshold : float = None
        ):
    """
    Production (extension) contribution to the fourmer rate equations.

    Parameters:
    -----------
    t : time, not used
    y : flat state vector, reshaped to (len(complements)+1,)*motiflength
    arg1 : dict with the entries
        'fourmer_production_log_rate_constants',
        'fourmer_production_rate_constants',
        'complements',
        'motiflength'
    soft_reactant_threshold : float
        forwarded to compute_total_extension_rates, default 0.
    hard_reactant_threshold : float
        forwarded to compute_total_extension_rates, default None

    Returns:
    --------
    flat array with the result of compute_total_extension_rates
    """
    log_rate_constants = arg1['fourmer_production_log_rate_constants']
    rate_constants = arg1['fourmer_production_rate_constants']
    complements = arg1['complements']
    motiflength = arg1['motiflength']

    # Equivalent to [:,1:,1:,:,:,1:,1:,:]: drops index 0 on axes 1, 2, 5, 6.
    keep_all = slice(None)
    drop_first = slice(1, None)
    trimmed = (keep_all, drop_first, drop_first, keep_all,
               keep_all, drop_first, drop_first, keep_all)

    # Arrays whose first two axes have equal length are trimmed
    # (presumably they still carry the padding index - TODO confirm).
    if rate_constants.shape[0] == rate_constants.shape[1]:
        rate_constants = rate_constants[trimmed]
    if log_rate_constants.shape[0] == log_rate_constants.shape[1]:
        log_rate_constants = log_rate_constants[trimmed]

    state = jnp.asarray(y.reshape((len(complements)+1,)*motiflength))
    extension_rates = compute_total_extension_rates(
        state,
        log_rate_constants = log_rate_constants,
        rate_constants = rate_constants,
        motiflength = 4,
        number_of_letters = rate_constants.shape[1],
        soft_reactant_threshold = soft_reactant_threshold,
        hard_reactant_threshold = hard_reactant_threshold
    )
    return jnp.asarray(extension_rates).reshape(-1)
def _fourmer_influx_equations(t, y, arg):
    """
    Influx contribution to the fourmer rate equations.

    Parameters:
    -----------
    t : time, not used
    y : state vector
    arg : dict with the entry 'influx_rate_constants' (scalar or array)

    Returns:
    --------
    flat array influx_rate_constants * exp(-y)
    """
    # jnp.asarray lets a plain Python scalar be used as well; calling
    # .reshape on a float directly would raise an AttributeError.
    influx_rate_constants = jnp.asarray(arg['influx_rate_constants'])
    return influx_rate_constants.reshape(-1)*jnp.exp(-y).reshape(-1)
def _fourmer_breakage_equations(
        t, y, arg2,
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None
        ):
    """
    Breakage contribution to the fourmer rate equations.

    Parameters:
    -----------
    t : time, not used
    y : flat state vector, reshaped to (complements.size+1,)*motiflength
    arg2 : dict
        may specify 'breakage_rate_constants' (default 0.),
        'complements' (default jnp.array([1,0,3,2])) and
        'motiflength' (default 4); entries that are missing are written
        into the dictionary with these defaults
    soft_reactant_threshold : float
        forwarded unchanged to fourmer_breakage_rates; according to the
        original notes, the minimal concentration of a reactant to contribute
        fully to a reaction, below which the reaction rates get damped or
        clipped
        default : None
    hard_reactant_threshold : float
        forwarded unchanged to fourmer_breakage_rates; according to the
        original notes, the concentration of a reactant at which all its
        reactions are set to zero, with a smooth cos-shaped transition
        between the two thresholds
        default : None

    Returns:
    --------
    flat array with the result of fourmer_breakage_rates

    Raises:
    -------
    TypeError
        if arg2 is not a dictionary
    """
    if not isinstance(arg2, dict):
        raise TypeError("arg2 supposed to be dictionary")

    # setdefault stores a missing entry in arg2 and returns the entry that is
    # in the dictionary afterwards.
    rate_constants = arg2.setdefault('breakage_rate_constants', 0.)
    letter_complements = arg2.setdefault('complements', jnp.array([1,0,3,2]))
    n_positions = arg2.setdefault('motiflength', 4)

    state = y.reshape((letter_complements.size+1,)*n_positions)
    breakage_rates = fourmer_breakage_rates(
        state,
        effective_breakage_rate_constants = rate_constants,
        soft_reactant_threshold = soft_reactant_threshold,
        hard_reactant_threshold = hard_reactant_threshold
    )
    return breakage_rates.reshape(-1)
def _fourmer_mass_correction_rates(t, y, arg3):
    """
    Mass-correction contribution to the fourmer rate equations.

    Parameters:
    -----------
    t : time, not used
    y : flat state vector, reshaped to (complements.size+1,)*motiflength
    arg3 : dict with the entries
        'initial_log_concentration_array',
        'mass_correction_rate_constant',
        'complements',
        'motiflength'

    Returns:
    --------
    flat array with the result of mass_correction_rates, called with the
    initial array, the reshaped state and
    weight = mass_correction_rate_constant
    """
    reference_state = arg3['initial_log_concentration_array']
    correction_weight = arg3['mass_correction_rate_constant']
    letter_complements = arg3['complements']
    n_positions = arg3['motiflength']

    current_state = y.reshape((letter_complements.size+1,)*n_positions)
    correction = mass_correction_rates(
        reference_state,
        current_state,
        weight = correction_weight
    )
    return correction.reshape(-1)
def _fourmer_rate_equations(
        kwargs,
        soft_reactant_threshold : float = None,
        hard_reactant_threshold : float = None
        ):
    """
    Build the right-hand side of the fourmer ODE.

    Parameters:
    -----------
    kwargs : dict
        must contain 'influx_rate_constants', 'breakage_rate_constants',
        'fourmer_production_rate_constants' and
        'mass_correction_rate_constant'; they decide once, at build time,
        which contributions are part of the rate equation. A contribution
        whose rate constants all equal 0 is left out and a message is
        printed.
    soft_reactant_threshold : float
        passed to the breakage and production contributions, default None
    hard_reactant_threshold : float
        passed to the breakage and production contributions, default None

    Returns:
    --------
    fre : callable
        the rate equation with arguments t, y, args : dict;
        args = {influx_rate_constants,
                fourmer_production_log_rate_constants,
                fourmer_production_rate_constants,
                breakage_rate_constants,
                complements,
                motiflength,
                mass_correction_rate_constant,
                initial_log_concentration_array}
        It returns the sum of the active contributions, starting from zeros
        of the shape of y.
    """
    influx_off = jnp.all(jnp.asarray(kwargs['influx_rate_constants']) == 0.)
    if influx_off:
        print("Influx rate constants equal 0, thus, influx turned off.")
    breakage_off = jnp.all(kwargs['breakage_rate_constants'] == 0.)
    if breakage_off:
        print("Breakage rate constant equals 0, thus, breakage turned off.")
    extension_off = jnp.all(kwargs['fourmer_production_rate_constants']==0.)
    if extension_off:
        print("Extension rate constants equal 0, thus, fourmer extension turned off.")
    mass_correction_off = bool(kwargs['mass_correction_rate_constant']==0.)
    if mass_correction_off:
        print("Mass correction rate constant equals 0, thus, mass correction turned off.")

    # Plain Python switches, fixed for the lifetime of the returned function.
    use_influx = not influx_off
    use_breakage = not breakage_off
    use_extension = not extension_off
    use_mass_correction = not mass_correction_off

    thresholds = {
        'soft_reactant_threshold' : soft_reactant_threshold,
        'hard_reactant_threshold' : hard_reactant_threshold,
    }

    def fre(t, y, args):
        total = jnp.zeros(y.shape)
        if use_influx:
            total = total + _fourmer_influx_equations(t, y, args)
        if use_breakage:
            total = total + _fourmer_breakage_equations(t, y, args, **thresholds)
        if use_extension:
            total = total + _fourmer_production_equations(t, y, args, **thresholds)
        if use_mass_correction:
            total = total + _fourmer_mass_correction_rates(t, y, args)
        return total

    return fre