"""Source code for src.obj.motif_production_vector."""

import itertools
from collections import namedtuple
from typing import Callable, Tuple

import numpy as np
import nifty8 as ift
from jax import numpy as jnp
from scipy.sparse import coo_matrix, save_npz, load_npz
import yaml

from ..domains.motif_space import _return_motif_categories
from ..domains.motif_production_space import (MotifProductionSpace,
                                              make_motif_production_dct,
                                              _determine_product_and_template_categories_and_ligation_spots,
                                              _production_channel_id,
                                              _valid_production_channel
                                              )

from .units import (Unit, make_unit,
                    transform_unit_to_dict, transform_dict_to_unit)

from ..utils.save import create_directory_path_if_not_already_existing

def _create_empty_motif_production_dict(motiflength : int,
        alphabet : list,
        maximum_ligation_window_length : int) -> dict:
    """
    Returns a dict of zero arrays, one per key of make_motif_production_dct.

    Parameters:
    -----------
    motiflength : int
    alphabet : list
    maximum_ligation_window_length : int

    Returns:
    --------
    empty_motif_production_vector : dict
        same keys as make_motif_production_dct(alphabet, motiflength,
        maximum_ligation_window_length); each value is np.zeros of the
        `.shape` of the corresponding entry.
    """
    mpd = make_motif_production_dct(
            alphabet,
            motiflength,
            maximum_ligation_window_length
            )
    return {key: np.zeros(mpd[key].shape) for key in mpd.keys()}

def MotifProductionVector(motiflength : int,
        alphabet : list,
        unit : Unit,
        maximum_ligation_window_length : int = None
        ) -> Callable[[dict], tuple]:
    """
    Returns a factory that builds MotifProductionVector namedtuples for a
    fixed set of properties.

    Parameters:
    -----------
    motiflength : int
    alphabet : list
    unit : Unit
        passed through make_unit before being stored
    maximum_ligation_window_length : int, optional
        defaults to motiflength when None

    Returns:
    --------
    makeMotifProductionVector : callable
        takes a dict of raw arrays (one per key of the MotifProductionSpace
        domain) and returns a namedtuple 'MotifProductionVector' with the
        fields 'productions' (an ift.MultiField built with
        ift.MultiField.from_raw), 'motiflength', 'alphabet',
        'number_of_letters', 'unit' and 'maximum_ligation_window_length'.
    """
    unit = make_unit(unit)
    if maximum_ligation_window_length is None:
        maximum_ligation_window_length = motiflength
    motif_production_vector_properties = {'motiflength' : motiflength,
            'alphabet' : alphabet,
            'number_of_letters' : len(alphabet),
            'unit' : unit,
            'maximum_ligation_window_length' : maximum_ligation_window_length
            }
    # The namedtuple type and the domain only depend on the fixed
    # properties, so build them once per factory rather than per call.
    motif_production_vector = namedtuple('MotifProductionVector',
            ('productions',) + tuple(motif_production_vector_properties.keys()))
    domain = MotifProductionSpace.make(alphabet, motiflength, maximum_ligation_window_length)

    def makeMotifProductionVector(motif_production_vector_dct : dict):
        productions = ift.MultiField.from_raw(domain, motif_production_vector_dct)
        return motif_production_vector(productions=productions,
                **motif_production_vector_properties)

    return makeMotifProductionVector
def isinstance_motifproductionvector(obj) -> bool:
    """
    Checks whether obj looks like a MotifProductionVector, i.e. a namedtuple
    as produced by the factory returned from MotifProductionVector.

    On the first failed check the reason is printed and False is returned.

    Parameters:
    -----------
    obj : object

    Returns:
    --------
    is_motif_production_vector : bool
    """
    expected_fields = ('motiflength', 'alphabet', 'number_of_letters',
            'unit', 'maximum_ligation_window_length', 'productions')
    # namedtuples are tuples exposing _asdict(); reject anything else up
    # front instead of raising AttributeError below.
    if not (isinstance(obj, tuple) and hasattr(obj, '_asdict')):
        print('Not a MotifProductionVector, not a namedtuple.')
        return False
    fields = obj._asdict()
    for key in expected_fields:
        if key not in fields:
            print('Not a MotifProductionVector, missing key: {}.'.format(key))
            return False
    for key in fields:
        if key not in expected_fields:
            print('Not a MotifProductionVector, unexpected key: {}.'.format(key))
            return False
    if not hasattr(obj.productions, 'keys'):
        print('Not a MotifProductionVector, productions field has no keys.')
        return False
    # Same domain construction as used by the MotifProductionVector factory.
    expected_keys = tuple(MotifProductionSpace.make(
            obj.alphabet,
            obj.motiflength,
            obj.maximum_ligation_window_length
            ).keys())
    production_keys = tuple(obj.productions.keys())
    for key in expected_keys:
        if key not in production_keys:
            print('Not a MotifProductionVector, missing key in productions field: {}.'.format(key))
            return False
    for key in production_keys:
        if key not in expected_keys:
            print('Not a MotifProductionVector, unexpected key in productions field: {}.'.format(key))
            return False
    return True
def _motif_production_array_shape(number_of_letters : int, maximum_ligation_window_length : int ) -> tuple: motif_production_array_shape = (number_of_letters+1,)*(maximum_ligation_window_length-maximum_ligation_window_length//2-1) motif_production_array_shape += (number_of_letters,)*2 motif_production_array_shape += (number_of_letters+1,)*(maximum_ligation_window_length//2-1) motif_production_array_shape += motif_production_array_shape[::-1] return motif_production_array_shape def _motif_production_vector_as_array( motif_production_vector : MotifProductionVector, ) -> np.ndarray: """ transforms a motif vector into a numpy-array Parameters: ----------- motif_production_vector : MotifProductionVector Returns: -------- motif_production_array : np.ndarray """ motiflength = motif_production_vector.motiflength number_of_letters = motif_production_vector.number_of_letters maximum_ligation_window_length = motif_production_vector.maximum_ligation_window_length motif_categories = _return_motif_categories(motiflength) motif_production_array = np.zeros(_motif_production_array_shape(number_of_letters,maximum_ligation_window_length)) if maximum_ligation_window_length < 4: ligation_window_lengths = np.array([maximum_ligation_window_length]) else: ligation_window_lengths = np.arange(4,maximum_ligation_window_length+1) for ligation_window_length in ligation_window_lengths: product_categories, template_categories, ligation_spots = _determine_product_and_template_categories_and_ligation_spots(motiflength, maximum_ligation_window_length, ligation_window_length ) for product_category, template_category, ligation_spot in itertools.product(product_categories, template_categories, ligation_spots): if not _valid_production_channel(product_category, template_category, ligation_window_length, ligation_spot, maximum_ligation_window_length): continue reaction_key = _production_channel_id(product_category, template_category, ligation_window_length, ligation_spot) destination_axes, source_axes = 
_moved_axes(ligation_window_length,ligation_spot, maximum_ligation_window_length) mpa_indices = _reaction_indices(product_category, template_category, ligation_window_length, ligation_spot, maximum_ligation_window_length, axes_moved=False) np.moveaxis(motif_production_array,source_axes,destination_axes)[mpa_indices] = motif_production_vector.productions[reaction_key].val return motif_production_array
def save_motif_production_vector(
        archive_path : str,
        motif_production_vector : MotifProductionVector,
        file_sparse : bool = True
        ):
    """
    Saves a motif production vector below archive_path.

    Writes archive_path+'motif_productions.npz' (sparse, the dense array
    flattened to a single row) or archive_path+'motif_productions.npy'
    (dense), plus archive_path+'properties.yaml'. archive_path is joined by
    plain string concatenation, so it should end with a path separator.

    Parameters:
    -----------
    archive_path : str
    motif_production_vector : MotifProductionVector
    file_sparse : bool
        store the productions as a sparse .npz file (default) or a dense
        .npy file
    """
    create_directory_path_if_not_already_existing(archive_path)
    motif_production_array = _motif_production_vector_as_array(motif_production_vector)
    if file_sparse:
        # load_motif_production_vector restores the shape from properties.yaml
        save_npz(archive_path+'motif_productions',
                coo_matrix(motif_production_array.reshape(1,-1)))
    else:
        np.save(archive_path+'motif_productions', motif_production_array)
    # Cast numeric properties to builtin ints so yaml.safe_load (used when
    # loading) can read them back even if numpy integers were stored.
    # Building the dict before opening the file avoids leaving an empty
    # properties.yaml behind if a conversion fails.
    properties = {'motiflength' : int(motif_production_vector.motiflength),
            'alphabet' : motif_production_vector.alphabet,
            'unit' : transform_unit_to_dict(motif_production_vector.unit),
            'maximum_ligation_window_length' : int(motif_production_vector.maximum_ligation_window_length)
            }
    with open(archive_path+'properties.yaml', 'w') as yaml_file:
        yaml.dump(properties, yaml_file, indent=4)
def load_motif_production_vector(archive_path : str,
        file_sparse : bool = True
        ) -> MotifProductionVector:
    """
    Loads a motif production vector written by save_motif_production_vector.

    Parameters:
    -----------
    archive_path : str
        directory prefix; joined by string concatenation, so it should end
        with a path separator
    file_sparse : bool
        read archive_path+'motif_productions.npz' (default) or
        archive_path+'motif_productions.npy'

    Returns:
    --------
    motif_production_vector : MotifProductionVector
    """
    dct_filename = archive_path+'properties.yaml'
    productions_filename = archive_path+'motif_productions'+('.npz' if file_sparse else '.npy')
    with open(dct_filename, 'r') as yaml_file:
        properties = yaml.safe_load(yaml_file)
    if file_sparse:
        # The sparse file holds the dense array flattened to a single row.
        motif_production_array = load_npz(productions_filename).toarray().reshape(
                _motif_production_array_shape(len(properties['alphabet']),
                    properties['maximum_ligation_window_length'])
                )
    else:
        motif_production_array = np.load(productions_filename)
    properties['unit'] = transform_dict_to_unit(properties['unit'])
    return _array_to_motif_production_vector(motif_production_array, **properties)
def _array_to_motif_production_vector(motif_production_array: np.ndarray, motiflength : int, alphabet : list, unit : Unit, maximum_ligation_window_length : int ) -> MotifProductionVector: unit = make_unit(unit) makeMotifProductionVector = MotifProductionVector(motiflength, alphabet, unit, maximum_ligation_window_length) motif_production_vector_dct = _motif_production_array_to_dct(motif_production_array, motiflength, alphabet, maximum_ligation_window_length ) return makeMotifProductionVector(motif_production_vector_dct) def _motif_production_array_to_dct(motif_production_array: np.ndarray, motiflength : int, alphabet : list, maximum_ligation_window_length : int ) -> dict: """ transforms a motif array into a motif vector Parameters: ----------- motif_production_array : np.ndarray Returns: -------- motif_production_vector : MotifProductionVector """ motif_production_vector_dct = {} number_of_letters = len(alphabet) motif_categories = _return_motif_categories(motiflength) if maximum_ligation_window_length < 4: ligation_window_lengths = np.array([maximum_ligation_window_length]) else: ligation_window_lengths = np.arange(4,maximum_ligation_window_length+1) mpd = make_motif_production_dct( alphabet, motiflength, maximum_ligation_window_length ) for ligation_window_length in ligation_window_lengths: product_categories, template_categories, ligation_spots = _determine_product_and_template_categories_and_ligation_spots(motiflength, maximum_ligation_window_length, ligation_window_length ) for product_category, template_category, ligation_spot in itertools.product(product_categories, template_categories, ligation_spots): if not _valid_production_channel(product_category, template_category, ligation_window_length, ligation_spot, maximum_ligation_window_length): continue reaction_key = _production_channel_id(product_category, template_category, ligation_window_length, ligation_spot) mpa_indices = _reaction_indices(product_category, template_category, ligation_window_length, 
ligation_spot, maximum_ligation_window_length, axes_moved=False) destination_axes , source_axes= _moved_axes(ligation_window_length,ligation_spot, maximum_ligation_window_length) motif_production_vector_dct[reaction_key] = np.moveaxis( motif_production_array, source_axes, destination_axes)[mpa_indices] return motif_production_vector_dct def _moved_axes(ligation_window_length : int, ligation_spot : int, maximum_ligation_window_length : int): """ returns the axes_indixes of the overlap in the vector (source) and in the array (destination) such that the motif is not interrupted in the vector and the ligation spot is in the center of the ligation window for the array. For the array, periodic boundary conditions treat longer arrays, where the end of the motifs is indicated by a 0 either in the motif itself or its hybridized partner. """ ligation_spot_relative_to_center = ligation_spot-(ligation_window_length-ligation_window_length//2-1) source = np.arange(min(0,ligation_spot_relative_to_center),max(0,ligation_spot_relative_to_center)) product_source = (maximum_ligation_window_length+source)%maximum_ligation_window_length product_destination = (maximum_ligation_window_length-source[::-1]-1)%maximum_ligation_window_length template_source = maximum_ligation_window_length+product_destination template_destination = maximum_ligation_window_length+product_source source = list(product_source) + list(template_source) destination = list(product_destination) + list(template_destination) return source, destination def _reaction_indices(product_category : str, template_category : str, ligation_window_length : int, ligation_spot : int, maximum_ligation_window_length : int, axes_moved : bool = True ) -> tuple: motif_categories = _return_motif_categories(maximum_ligation_window_length)#FIXME: ignore monomers product_length = (ligation_window_length -int(product_category not in motif_categories[-2:]) -int(product_category not in motif_categories[-3:-1]) ) # product left_reactant_length 
= ligation_spot+int(product_category in motif_categories[-2:]) right_reactant_length = product_length-left_reactant_length product_first_part_overlap_length = max(0, left_reactant_length-maximum_ligation_window_length+maximum_ligation_window_length//2) product_second_part_overlap_length = max(0, right_reactant_length - maximum_ligation_window_length//2) length_from_first_product_part = min(left_reactant_length, maximum_ligation_window_length-maximum_ligation_window_length//2) length_from_second_product_part = min(right_reactant_length, maximum_ligation_window_length//2) left_ligation_window_length = left_reactant_length + int(product_category not in motif_categories[-2:]) right_ligation_window_length = ligation_window_length-left_ligation_window_length ligation_window_shift = ligation_spot-ligation_window_length+ligation_window_length//2+1 #assert(left_ligation_window_length+product_second_part_overlap_length) #length_from_first_product_part-product_second_part_overlap_length-int(product_category in motif_categories[-1:]) if axes_moved: mpa_indices = (slice(1,None),)*product_second_part_overlap_length mpa_indices += (0,)*(maximum_ligation_window_length-maximum_ligation_window_length//2-length_from_first_product_part-product_second_part_overlap_length) mpa_indices += (slice(1,None),)*(length_from_first_product_part-1) mpa_indices += (slice(None),)*2 mpa_indices += (slice(1,None),)*(length_from_second_product_part-1) mpa_indices += (0,)*(maximum_ligation_window_length-len(mpa_indices)-product_first_part_overlap_length) mpa_indices += (slice(1,None),)*product_first_part_overlap_length else: mpa_indices = (0,)*max(0,maximum_ligation_window_length-maximum_ligation_window_length//2-left_ligation_window_length+ligation_window_shift) mpa_indices += (0,)*int(product_category not in motif_categories[-2:]) mpa_indices += (slice(1,None),)*(left_reactant_length-1) mpa_indices += (slice(None),)*2 mpa_indices += (slice(1,None),)*(right_reactant_length-1) mpa_indices += 
(0,)*(max(0,maximum_ligation_window_length-len(mpa_indices))) #template template_length = (ligation_window_length -int(template_category not in motif_categories[-2:]) -int(template_category not in motif_categories[-3:-1]) ) template_second_part_length = ligation_spot+int(template_category in motif_categories[-3:-1]) template_first_part_length = template_length-template_second_part_length length_from_first_template_part = min(template_first_part_length, maximum_ligation_window_length//2) length_from_second_template_part = min(template_second_part_length, maximum_ligation_window_length-maximum_ligation_window_length//2) template_first_part_overlap_length = template_first_part_length-length_from_first_template_part template_second_part_overlap_length = template_second_part_length-length_from_second_template_part if ((product_second_part_overlap_length +length_from_first_product_part) >=(maximum_ligation_window_length -maximum_ligation_window_length//2 +int((product_second_part_overlap_length==0) or (template_second_part_length<length_from_first_product_part))) ): raise ValueError('Expected: ' + str(product_second_part_overlap_length)+'+' + str(length_from_first_product_part)+'<' + str(maximum_ligation_window_length)+'-' + str(maximum_ligation_window_length//2)+'+int(' + str(product_second_part_overlap_length)+'==0 or ' + str(template_second_part_length) + '<' + str(length_from_first_product_part) +')' + _production_channel_id(product_category, template_category, ligation_window_length, ligation_spot) ) if (product_first_part_overlap_length +length_from_second_product_part >= maximum_ligation_window_length//2 +int(product_first_part_overlap_length==0 or length_from_first_template_part<length_from_second_product_part)): raise ValueError(str(product_first_part_overlap_length)+'+' +str(length_from_second_product_part) +'>=' +str(maximum_ligation_window_length) + '//2 + int(' +str(product_first_part_overlap_length) + '==0 or ' +str(length_from_first_template_part) +'<' 
+str(length_from_second_product_part)+'))') if axes_moved: mpa_indices += (slice(1,None),)*template_second_part_overlap_length mpa_indices += (0,)*(maximum_ligation_window_length//2-length_from_first_template_part-template_second_part_overlap_length) mpa_indices += (slice(1,None),)*(length_from_first_template_part-1) mpa_indices += (slice(None),)*2 mpa_indices += (slice(1,None),)*(length_from_second_template_part-1) mpa_indices += (0,)*(2*maximum_ligation_window_length-len(mpa_indices)-template_first_part_overlap_length) mpa_indices += (slice(1,None),)*template_first_part_overlap_length else: mpa_indices += (0,)*max(0,maximum_ligation_window_length//2-right_ligation_window_length-ligation_window_shift) mpa_indices += (0,)*int(template_category not in motif_categories[-2:]) mpa_indices += (slice(1,None),)*(template_first_part_length-1) mpa_indices += (slice(None),)*2 mpa_indices += (slice(1,None),)*(template_second_part_length-1) mpa_indices += (0,)*(2*maximum_ligation_window_length-len(mpa_indices)) return mpa_indices