"""
``sfftk.core.prep``
==========================
This module consists of preparation utilities to condition segmentation files prior to conversion.
"""
import asyncio
import json
import pathlib
import re
import sys
from typing import List
import mrcfile
import numpy
import numpy.lib.mixins
from sfftkrw.core import _str
from sfftkrw.core.print_tools import print_date
from stl import Mesh
from ..readers.starreader import RelionStarReader
def _label_generator():
yield from (*range(1, 128), *range(-128, 0))
class MergedMask:
    """This class describes a special mask used to perform mask merging. It automatically handles
    complex cases involving mask overlaps by constructing a label tree showing the relations
    between masks. The trivial case of non-overlapping masks will have all labels children of
    the root label (0).

    There are only three ways that an overlap can happen:

    1. no overlap is the trivial case - no elements are shared between masks;
    2. complete overlap: one set of elements is completely contained in another set;
    3. partial overlap: some elements are shared.

    For this functionality to work we need several functions:

    - vectorised addition of masks to the merged mask;
    - a way to decide the next label to use, which is not necessarily the current label plus one;
    - a way to capture the relationship between labels.

    Consider the simple exercise of merging the following non-trivial (overlapping) masks:

    .. code:: python

        mask1 = [0, 1, 0, 0]
        mask2 = [0, 1, 0, 0]
        mask3 = [0, 0, 1, 0]
        mask4 = [0, 1, 1, 1]
        mask5 = [1, 0, 0, 0]
        mask6 = [1, 0, 1, 0]

    We will build our merged mask by successively adding each mask to the empty mask: ``[0, 0, 0, 0]``.
    We assume that all masks are positive binary with values ``0`` (background) and ``1`` (elements of
    interest). At each iteration we set a new label which identifies the particular mask; the mask is
    multiplied by this label before it is added.

    Because elements can overlap, we need a way to keep track of labels so that we can record when we
    have to assign labels that indicate either complete or partial overlap. We therefore examine the
    resulting values and infer the relationships between labels from them. To do this, we keep a set of
    admitted labels and compute the set of new labels (values now present in the merged mask that have
    not been admitted yet). Every new label is the sum of the current label and a previously admitted
    label, and those two labels are its parents in the label tree.

    .. code-block:: python

        merged_mask = [0, 0, 0, 0]  # the internal value of MergedMask's array
        label = 1
        label_set = set()
        label_tree = dict()
        # mask 1
        merged_mask = merged_mask + [0, 1, 0, 0] * 1  # => [0, 1, 0, 0]
        label_set = {1}
        label_tree[1] = 0  # 1 is a child of the root (0) => {1: 0}
        new_labels = set()
        label = numpy.amax(merged_mask) + 1  # => 2
        # mask 2
        merged_mask = [0, 1, 0, 0] + [0, 1, 0, 0] * 2  # => [0, 3, 0, 0]
        label_set = {1, 2}
        label_tree[2] = 0  # => {1: 0, 2: 0}
        new_labels = {3}
        label_tree[3] = [1, 2]  # 3 = 1 + 2 is a child of 1 and 2 (overlap) => {1: 0, 2: 0, 3: [1, 2]}
        label_set = {1, 2, 3}
        label = numpy.amax(merged_mask) + 1  # => 4
        # mask 3
        merged_mask = [0, 3, 0, 0] + [0, 0, 1, 0] * 4  # => [0, 3, 4, 0]
        label_set = {1, 2, 3, 4}
        label_tree[4] = 0  # => {1: 0, 2: 0, 3: [1, 2], 4: 0}
        new_labels = set()
        label = numpy.amax(merged_mask) + 1  # => 5
        # mask 4
        merged_mask = [0, 3, 4, 0] + [0, 1, 1, 1] * 5  # => [0, 8, 9, 5]
        label_set = {1, 2, 3, 4, 5}
        label_tree[5] = 0  # => {1: 0, 2: 0, 3: [1, 2], 4: 0, 5: 0}
        new_labels = {8, 9}
        label_tree[8] = [3, 5]
        label_tree[9] = [4, 5]  # => {1: 0, 2: 0, 3: [1, 2], 4: 0, 5: 0, 8: [3, 5], 9: [4, 5]}
        label_set = {1, 2, 3, 4, 5, 8, 9}
        label = numpy.amax(merged_mask) + 1  # => 10
        # mask 5
        merged_mask = [0, 8, 9, 5] + [1, 0, 0, 0] * 10  # => [10, 8, 9, 5]
        label_set = {1, 2, 3, 4, 5, 8, 9, 10}
        label_tree[10] = 0  # => {1: 0, 2: 0, 3: [1, 2], 4: 0, 5: 0, 8: [3, 5], 9: [4, 5], 10: 0}
        new_labels = set()  # 10 has just been admitted, so nothing is new
        label = numpy.amax(merged_mask) + 1  # => 11
        # mask 6
        merged_mask = [10, 8, 9, 5] + [1, 0, 1, 0] * 11  # => [21, 8, 20, 5]
        label_set = {1, 2, 3, 4, 5, 8, 9, 10, 11}
        label_tree[11] = 0
        new_labels = {20, 21}
        label_tree[20] = [9, 11]
        label_tree[21] = [10, 11]  # => {1: 0, 2: 0, 3: [1, 2], 4: 0, 5: 0, 8: [3, 5], 9: [4, 5], 10: 0,
                                   #     11: 0, 20: [9, 11], 21: [10, 11]}
        label_set = {1, 2, 3, 4, 5, 8, 9, 10, 11, 20, 21}
        label = numpy.amax(merged_mask) + 1  # => 22

    In the implementation the keys of ``label_tree`` are strings (e.g. ``{'1': 0, '3': [1, 2]}``).

    Objects of this class have a number of important properties germane to working with collation of masks:

    - they know what the next label value is implicitly;
    - they handle iterative addition of masks to construct the merged mask;
    - they keep track of the label tree;
    - they refuse (with ``OverflowError``) to merge a mask whose labels would not fit in the integer
      dtype of the merged array, instead of silently wrapping around.

    The internal array instantiation is lazy---it is only created once we know the size of the
    masks to be merged. Every mask merged into one object must have that same shape.

    Using a ``MergedMask`` object converts the complexity of the above into the following:

    .. code-block:: python

        merged_mask = MergedMask()
        for mask in masks:  # masks is a list of n-dimensional binary-valued arrays
            merged_mask.merge(mask)

    Internally, merging is a vectorised, in-place addition of arrays implemented by overloading the
    ``__add__``, ``__radd__`` and ``__iadd__`` protocols (all three modify and return ``self``).
    However, it is safest to use the :py:meth:`MergedMask.merge` method because ``numpy`` arrays also
    implement the addition protocols, so ``array + merged_mask`` does not reliably reach ``__radd__``.

    Once the masks have been merged, we can now interrogate the merged mask for some attributes:

    .. code-block:: python

        merged_mask.label  # the next label to be used; autoincremented appropriately
        merged_mask.label_tree  # the hierarchy of labels (complex tree of labels)
        merged_mask.mask_to_label  # the relations between masks and labels
    """

    def __init__(self, data=None, dtype=numpy.dtype('int16'), mask_name_prefix="mask_", zfill=4):
        """
        :param data: optional initial merged array; when ``None`` it is created from the first mask
        :param dtype: dtype of the lazily created merged array (default ``int16``)
        :param str mask_name_prefix: prefix of auto-generated mask names
        :param int zfill: zero-padding width of the number in auto-generated mask names
        """
        # we could use int8 but labels quickly exceed its range (max 127);
        # int16 gives a positive ceiling of 32767
        self._label = 1  # initial label value
        self._label_tree = dict()
        self._label_set = set()
        self._dtype = dtype
        self._data = data
        self._mask_to_label = dict()
        self._mask_id = 1
        self._mask_name_prefix = mask_name_prefix
        self._zfill = zfill
        self._mask_name = None

    def _init_data(self, mask: numpy.ndarray):
        """Validate ``mask`` and lazily create the internal array from its shape.

        :raises TypeError: if ``mask`` is not a ``numpy.ndarray``
        :raises ValueError: if ``mask`` does not consist of exactly the values 0 and 1, or if its
            shape differs from that of the merged array
        """
        if not isinstance(mask, numpy.ndarray):
            raise TypeError("mask must be a numpy.ndarray object")
        unique_values = numpy.unique(mask)
        # explicit check: an `assert` would be stripped when Python runs with -O
        if not (len(unique_values) == 2 and 0 in unique_values and 1 in unique_values):
            raise ValueError(f"non-binary mask with values: {unique_values}")
        if self._data is None:
            self._data = numpy.zeros(mask.shape, self._dtype)
        elif numpy.shape(self._data) != mask.shape:
            # in-place addition would otherwise silently broadcast a mismatched mask
            raise ValueError(
                f"mask shape {mask.shape} does not match merged mask shape {numpy.shape(self._data)}"
            )

    def __repr__(self):
        """Representation"""
        return f"{self.__class__.__qualname__}(data={self.data}, dtype={self._dtype})"

    def __array__(self, dtype=None, copy=None):
        """Numpy array interface; accepts the ``dtype``/``copy`` arguments numpy may pass."""
        if copy:
            return numpy.array(self._data, dtype=dtype, copy=True)
        if dtype is None:
            return self._data
        return numpy.asarray(self._data, dtype=dtype)

    @property
    def shape(self):
        """Shape of the merged array, or ``None`` before the first merge."""
        if self._data is not None:
            return self._data.shape
        return None

    @property
    def data(self):
        """The merged array (``None`` before the first merge unless ``data`` was given)."""
        return self._data

    @property
    def dtype(self):
        """The dtype requested at construction."""
        return self._dtype

    @property
    def label(self):
        """The label that the next merged mask will receive."""
        return self._label

    @property
    def label_tree(self):
        """Mapping of ``str(label)`` to ``0`` (child of the root) or ``[parent1, parent2]``."""
        return self._label_tree

    @property
    def label_set(self):
        """All labels admitted so far."""
        return self._label_set

    @property
    def mask_to_label(self):
        """Mapping of mask name to the label assigned to that mask."""
        return self._mask_to_label

    @property
    def mask_name(self):
        """Name of the mask being merged: the explicit name if given, else ``<prefix><zero-padded id>``."""
        if self._mask_name is not None:
            return self._mask_name
        return f"{self._mask_name_prefix}{self._mask_id:0>{self._zfill}}"

    def merge(self, mask: numpy.ndarray, mask_name=None):
        """Merge ``mask`` into this merged mask under the next free label.

        :param mask: binary (0/1) array with the same shape as previously merged masks
        :param mask_name: name recorded in :py:attr:`mask_to_label`; auto-generated when ``None``
        :raises TypeError: if ``mask`` is not a ``numpy.ndarray``
        :raises ValueError: if ``mask`` is not binary or has the wrong shape
        :raises OverflowError: if the resulting labels would not fit in the merged array's dtype
        """
        # the name is only meaningful for this one merge, so always reset it, even on failure
        self._mask_name = mask_name
        try:
            self._add_mask(mask)
        finally:
            self._mask_name = None

    def _add_mask(self, mask) -> 'MergedMask':
        """Add ``mask`` (scaled by the current label) to the merged array in place; return ``self``."""
        self._init_data(mask)
        label = int(self._label)
        target_dtype = self._data.dtype
        if numpy.issubdtype(target_dtype, numpy.integer):
            # the largest value this merge can produce is the label plus the largest value already
            # under the mask; integer arrays wrap around silently, so check before adding
            covered = self._data[mask.astype(bool)]
            ceiling = label + (int(covered.max()) if covered.size else 0)
            limit = numpy.iinfo(target_dtype).max
            if ceiling > limit:
                raise OverflowError(
                    f"merging mask '{self.mask_name}' would create label {ceiling}, which exceeds the "
                    f"maximum value ({limit}) of {target_dtype}; use a wider integer dtype"
                )
        # cast the validated 0/1 mask to the merged dtype so that e.g. float masks can be merged in place
        self._data += mask.astype(target_dtype) * target_dtype.type(label)
        self._update_label()
        return self

    def _update_label(self):
        """Record the label just used, derive parents of any new overlap labels and pick the next label."""
        label = int(self._label)
        self._label_set.add(label)
        self._label_tree[str(label)] = 0  # a direct child of the root (0, representing background)
        self._mask_to_label[self.mask_name] = label
        # new labels are values now present in the data that have not been admitted before
        present = {int(value) for value in numpy.unique(self._data)}
        new_labels = present - self._label_set - {0}
        for new_label in new_labels:
            # overlapping voxels hold previous_label + label, so the other parent is new_label - label
            other_parent = new_label - label
            if other_parent in self._label_set:
                self._label_tree[str(new_label)] = sorted([other_parent, label])
        self._label_set |= new_labels
        # use a Python int so that the increment cannot wrap around in the array's dtype
        self._label = int(numpy.amax(self._data)) + 1
        self._mask_id += 1

    def __add__(self, mask) -> 'MergedMask':
        """Merge ``mask`` in place (note: modifies ``self``) and return ``self``."""
        return self._add_mask(mask)

    def __radd__(self, mask) -> 'MergedMask':
        """Merge ``mask`` in place (note: modifies ``self``) and return ``self``."""
        return self._add_mask(mask)

    def __iadd__(self, mask) -> 'MergedMask':
        """Merge ``mask`` in place and return ``self``."""
        return self._add_mask(mask)

    def __eq__(self, other):
        """Equal when data, shape and dtype match; defer to Python for non-``MergedMask`` objects."""
        if not isinstance(other, MergedMask):
            return NotImplemented
        return numpy.array_equal(self.data, other.data) and self.shape == other.shape and self.dtype == other.dtype
def bin_map(args, configs):
    """Bin the CCP4 map: binarise it about a contour level.

    Voxels greater than ``args.contour_level`` (or less than it when ``args.negate`` is set) are
    set to ``args.mask_value`` and all others to 0. The result is written to ``args.output`` with
    the cell dimensions (``cella``) of the input map.

    :param args: parsed arguments
    :type args: :py:class:`argparse.Namespace`
    :param configs: configurations object
    :type configs: :py:class:`sfftk.core.configs.Configs`
    :return: exit status: 0 on success, 64 for an unsupported ``args.bytes_per_voxel``, 65 if the
        output file could not be created (e.g. it exists and overwriting was not requested)
    :rtype: int
    """
    # numpy has no 128-bit integer type, so only these widths can be produced
    out_types = {1: numpy.int8, 2: numpy.int16, 4: numpy.int32, 8: numpy.int64}
    out_type = out_types.get(args.bytes_per_voxel)
    if out_type is None:
        print_date(f"error: unsupported number of bytes per voxel ({args.bytes_per_voxel}); "
                   f"use one of {', '.join(str(size) for size in sorted(out_types))}")
        return 64
    if args.verbose:
        print_date('Reading in data from {}...'.format(args.from_file))
    with mrcfile.mmap(args.from_file) as mrc:
        if args.verbose:
            print_date('Voxels will be of type {}'.format(out_type))
            print_date('Binarising to {} about contour-level of {}'.format(args.mask_value, args.contour_level))
        if args.negate:
            print_date('Negating...')
            selected = mrc.data < args.contour_level
        else:
            selected = mrc.data > args.contour_level
        data = (selected * args.mask_value).astype(out_type)
        if args.verbose:
            print_date('Creating output file...')
        try:
            mrc2 = mrcfile.new(args.output, data, overwrite=args.overwrite)
        except ValueError as err:
            print_date("Binarising preparation failed")
            print_date("Attempting to overwrite without explicit --overwrite argument")
            # the ValueError may have another cause (e.g. an unsupported data type), so report it
            print_date(f"Reason: {err}")
            return 65
        # the context manager guarantees the output file is closed even if header writing fails
        with mrc2:
            if args.verbose:
                print_date('Writing header data...')
            # keep the physical cell dimensions of the source map
            mrc2.header.cella = mrc.header.cella
            mrc2.flush()
    if args.verbose:
        print_date('Binarising complete!')
    return 0
async def _mask_is_binary(mask, verbose=False):
    """Coroutine reporting whether the mask file ``mask`` is a binary mask.

    When the loaded map's voxel values do not include 1, the map is first asked to fix itself
    via ``fix_mask(mask_value=1)``.

    NOTE: the body contains no ``await``, so several of these gathered together run one after another.

    :param str mask: mask filename
    :param bool verbose: print progress messages
    :return: the ``is_mask`` attribute of the loaded map
    """
    from ..readers.mapreader import Map
    candidate = Map(mask)
    if verbose:
        print_date(f"info: assessing {mask}...")
    needs_fix = 1 not in candidate._voxel_values
    if not needs_fix:
        return candidate.is_mask
    # binary, but the non-zero value is not 1
    if verbose:
        print_date(f"info: fixing {mask} with voxel values {candidate._voxel_values}...")
    candidate.fix_mask(mask_value=1)
    return candidate.is_mask
async def _check_masks_binary(args, configs):
    """Coroutine that assesses every mask in ``args.masks``.

    :return: one ``_mask_is_binary`` result per mask, in the order of ``args.masks``
    """
    pending = [_mask_is_binary(mask_file, verbose=args.verbose) for mask_file in args.masks]
    return await asyncio.gather(*pending)
def check_mask_is_binary(fn, verbose=False):
    """Check whether a mask is binary or not

    :param str fn: map filename
    :param bool verbose: verbosity flag
    :return: boolean, True if binary mask
    :rtype: bool
    """
    if sys.version_info >= (3, 7):
        # asyncio.run creates, runs and closes a fresh event loop
        return asyncio.run(_mask_is_binary(fn, verbose=verbose))
    # Python 3.6 fallback (asyncio.run does not exist there)
    loop = asyncio.get_event_loop()
    if loop.is_closed():
        loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(_mask_is_binary(fn, verbose=verbose))
    finally:
        # close the loop even if the check raises
        loop.close()
def _masks_all_binary(args, configs):
    """Validate that all masks in ``args.masks`` are binary masks.

    :return: ``True`` if every mask is binary
    :rtype: bool
    """
    # TODO: for small files read all data
    # TODO: for large files only read the first X bytes
    # TODO: give the user the option to read full files for large files
    if sys.version_info >= (3, 7):
        all_binary = asyncio.run(_check_masks_binary(args, configs))
    else:
        # Python 3.6 fallback (asyncio.run does not exist there)
        loop = asyncio.get_event_loop()
        if loop.is_closed():
            loop = asyncio.new_event_loop()
        try:
            all_binary = loop.run_until_complete(_check_masks_binary(args, configs))
        finally:
            # close the loop even if a check raises
            loop.close()
    return all(all_binary)
def _masks_no_overlap(args, configs):
    """Check that no voxel is set in more than one of the masks in ``args.masks``.

    Masks are read in order and their voxels summed; as soon as the running sum exceeds 1
    anywhere, a warning is printed and the check stops.

    :param args: parsed arguments (uses ``args.masks``)
    :param configs: configurations object (unused)
    :return: ``True`` if no overlap was found (also when there are no masks), else ``False``
    :rtype: bool
    """
    from ..readers.mapreader import Map
    running_sum = None
    previous_mask = None
    for mask in args.masks:
        voxels = Map(mask).voxels
        if running_sum is None:
            running_sum = numpy.zeros(voxels.shape)
        running_sum += voxels
        if numpy.amax(running_sum) > 1:
            # NOTE: the overlap may involve any earlier mask, not only `previous_mask`
            print_date(f"warning: segment overlap between mask {mask} and {previous_mask}")
            return False
        previous_mask = mask
    # a maximum of 0 (empty masks) or 1 means no voxel is claimed twice
    return True
def _mergemask(masks: List[str]) -> 'MergedMask':
    """Merge the mask files in ``masks``, in order, into a single :py:class:`MergedMask`.

    Each mask is recorded under its file name (without directories) in ``mask_to_label``.
    """
    from ..readers.mapreader import Map
    # the merged array is created lazily from the first mask, since all masks share a shape
    combined = MergedMask()
    for path in masks:
        combined.merge(Map(path).voxels, mask_name=pathlib.Path(path).name)
    return combined
def mergemask(args, configs):
    """Merge two or more binary masks into one mask with a distinct label for each mask.

    The merged mask is written to ``<output_prefix>.<mask_extension>`` with the voxel size of the
    first mask. The label metadata (``mask_to_label`` and ``label_tree``) is written as JSON to
    ``<output_prefix>.json``.

    :param args: parsed arguments
    :type args: :py:class:`argparse.Namespace`
    :param configs: configurations object
    :type configs: :py:class:`sfftk.core.configs.Configs`
    :return: exit status: 0 on success, 64 if the output exists and overwriting was not requested,
        65 if the masks fail the binary or overlap checks
    :rtype: int
    """
    merged_path = f"{args.output_prefix}.{args.mask_extension}"
    # fail fast: never clobber an existing output unless explicitly asked to
    if not args.overwrite and pathlib.Path(merged_path).exists():
        print_date("error: the file already exists; use --overwrite to overwrite the existing merged_mask or set a "
                   "new output prefix using --output-prefix")
        return 64
    if args.skip_assessment:
        print_date("info: skipping mask assessment; assuming all masks are binary...")
    elif not _masks_all_binary(args, configs) and not args.skip:
        print_date("error: one or more masks are non-binary; use --verbose to view details")
        return 65
    # TODO: allow cases where one or more files are non-binary
    # the overlap check always runs; --allow-overlap only turns a detected overlap into a non-error
    if not _masks_no_overlap(args, configs) and not args.allow_overlap:
        print_date("error: one or more masks overlap; use --verbose to view details")
        print_date("info: if overlapping segments are expected re-run with the --allow-overlap argument; "
                   "see 'sff prep mergemask' for more information")
        return 65
    if args.verbose:
        print_date("info: proceeding to merge masks...")
    merged_mask = _mergemask(args.masks)
    if args.verbose:
        print_date("info: merge complete...")
        print_date(f"info: attempting to write output to '{merged_path}'...")
    # the first input mask supplies the voxel size of the merged output
    with mrcfile.new(merged_path, overwrite=args.overwrite) as mrc, mrcfile.open(args.masks[0]) as reference:
        mrc.set_data(merged_mask.data)
        mrc.voxel_size = reference.voxel_size
    if args.verbose:
        print_date(f"info: attempting to write mask metadata below to '{args.output_prefix}.json'...")
    metadata_json = json.dumps(
        {'mask_to_label': merged_mask.mask_to_label, 'label_tree': merged_mask.label_tree},
        indent=4,
    )
    with open(f"{args.output_prefix}.json", 'w') as label_file:
        if args.verbose:
            print_date(f"info: mask metadata:\n{metadata_json}")
        print(metadata_json, file=label_file)
    if args.verbose:
        print_date("info: merge complete!")
    return 0
class RelionCompositeStarReader(RelionStarReader):
    """Relion composite star file reader.

    Identical to :py:class:`RelionStarReader` apart from ``maximum_tomograms``; construction
    arguments are passed straight to the base class.
    """
    # overrides the base-class attribute; None presumably means "no limit" -- confirm in RelionStarReader
    maximum_tomograms = None
def starsplit(args, configs):
    """Split a star file into multiple star files based on the image name of each row.

    Rows of the ``_rln`` table are grouped by the stem of their ``ImageName``. When
    ``args.image_name_prefix`` is given, the name is reduced to the part matched by that regular
    expression, and rows that do not match are skipped with a warning. Each group is written to
    ``<output_prefix><image name>.star`` with the header of the input table. Each row's
    ``ImageName`` is rewritten to ``<image_path>/<image name>.<image_extension>``.

    :param args: parsed arguments
    :type args: :py:class:`argparse.Namespace`
    :param configs: configurations object
    :type configs: :py:class:`sfftk.core.configs.Configs`
    :return: exit status
    :rtype: int
    """
    composite_star_reader = RelionCompositeStarReader(image_name_field=args.image_name_field)
    if args.verbose:
        print_date(f"info: parsing {args.star_file}...", newline=False)
    composite_star_reader.parse(args.star_file)
    if args.verbose:
        print_date("done.", incl_date=False)
    table = composite_star_reader.tables['_rln']
    file_handlers = dict()
    try:
        for row in table:
            image_name = pathlib.Path(row.ImageName).stem
            if args.image_name_prefix:
                # use a regex to match the prefix
                image_name_match = re.match(args.image_name_prefix, image_name)
                if not image_name_match:
                    print_date(f"warning: image name '{image_name}' does not match prefix '{args.image_name_prefix}'")
                    continue
                image_name = image_name_match.group()
            if image_name not in file_handlers:
                output_name = f"{args.output_prefix}{image_name}.star"
                if args.verbose:
                    print_date(f"info: creating file handler for '{image_name}'...")
                handler = open(output_name, 'w')
                file_handlers[image_name] = handler
                if args.verbose:
                    print_date(f"info: writing header for '{output_name}'...")
                handler.write("data_\n\n")
                handler.write(table.header)
                handler.write('\n')
            # point the row at the image in args.image_path (the original directory is discarded)
            row.setattr('ImageName', pathlib.Path(args.image_path) / f"{image_name}.{args.image_extension}")
            file_handlers[image_name].write(str(row.raw_data()))
            file_handlers[image_name].write('\n')
    finally:
        # always release every handle, even if a row fails part-way through
        if args.verbose:
            print_date("info: closing all file handlers...", newline=False)
        for handler in file_handlers.values():
            handler.close()
        if args.verbose:
            print_date("done.", incl_date=False)
    return 0
def starcrop(args, configs):
    """Crop a star file to have at most the given number of rows.

    The header of the ``_rln`` table is copied to ``args.output``, followed by at most
    ``args.rows`` data rows. If the file has fewer rows, all of them are written.

    :param args: parsed arguments
    :type args: :py:class:`argparse.Namespace`
    :param configs: configurations object
    :type configs: :py:class:`sfftk.core.configs.Configs`
    :return: exit status
    :rtype: int
    """
    import itertools
    composite_star_reader = RelionCompositeStarReader(image_name_field=args.image_name_field)
    if args.verbose:
        print_date(f"info: parsing {args.star_file}...", newline=False)
    composite_star_reader.parse(args.star_file)
    if args.verbose:
        print_date("done.", incl_date=False)
        print_date(f"info: writing output to {args.output}...")
    table = composite_star_reader.tables['_rln']
    with open(args.output, 'w') as outfile:
        if args.verbose:
            print_date(f"info: writing header for '{args.output}'...")
        outfile.write("data_\n\n")
        outfile.write(table.header)
        outfile.write('\n')
        if args.verbose:
            print_date(f"info: writing {args.rows} rows of data for '{args.output}'...")
        # islice stops cleanly if the table is shorter than requested; a bare next() would raise
        # StopIteration. Negative counts write nothing.
        for row in itertools.islice(table, max(args.rows, 0)):
            outfile.write(str(row.raw_data()))
            outfile.write('\n')
    if args.verbose:
        print_date("info: done.")
    return 0