Module imageapply.flexible_model
Expand source code
# TODO: Improve the package name
# TODO: Support single-image calls (i.e. inputs without a batch dimension)
# TODO: Support non-batched models (models that can only handle one image at a time)
# TODO: Support PyTorch tensors (and other tensor types)
# TODO: Limit image dimensions to 3
# TODO: Implement a Data class (with PyTorch, NumPy, and TensorFlow subclasses) to handle the different data types
from .revtransform import PadCrop, DivideCombine, CombinedModel, BasicTTA
from .tools import apply_model
class FlexibleModel:
    """
    This model is designed to be used with models that can only handle a certain input size.

    Internally it builds a ``CombinedModel`` from these stages, in this order:

    1. ``BasicTTA()`` when ``basic_tta`` is true (``None`` otherwise),
    2. ``PadCrop`` to the model's input size (``pad_mode="zeros"``, ``pad_position="end"``),
    3. ``DivideCombine`` with the model's input size,
    4. the wrapped model, applied through ``apply_model`` with ``max_batch_size``.

    See ``revtransform`` for how ``CombinedModel`` composes these stages.
    """

    def __init__(self, model, input_size, max_batch_size=None, basic_tta=False):
        """
        Creates a new FlexibleModel object.

        Args:
            model (function): The model to apply to the data. Must be callable.
            input_size (tuple): The size of the input to the model, including the
                first (batch) dimension, which must be given as ``None``. Only the
                remaining dimensions are used for padding/cropping and dividing.
            max_batch_size (int): The maximum batch size to use when applying the
                model. Forwarded as-is to ``apply_model`` (``None`` included).
            basic_tta (bool): Whether to use basic test time augmentation.

        Raises:
            TypeError: If ``model`` is not callable (this includes ``None``).
            ValueError: If ``input_size`` is empty or its first entry is not ``None``.
        """
        # Explicit raises instead of ``assert``: assertions are stripped when
        # Python runs with ``-O``, which would silently skip this validation.
        if not callable(model):  # callable(None) is False, so None is rejected too
            raise TypeError("Model must be callable")
        # Check emptiness first so that ``()`` gets a clear error instead of IndexError.
        if not input_size or input_size[0] is not None:
            raise ValueError("First dimension of input size must be None")

        self.model = model
        # Drop the batch dimension; the transforms below work on per-sample sizes.
        self.input_size = input_size[1:]
        self.max_batch_size = max_batch_size
        self.tta = basic_tta
        self.combined = CombinedModel([
            BasicTTA() if self.tta else None,
            PadCrop(self.input_size, pad_mode="zeros", pad_position="end"),
            DivideCombine(self.input_size),
            # The lambda reads self.model / self.max_batch_size at call time, so
            # reassigning those attributes later affects subsequent calls.
            lambda x: apply_model(self.model, x, batch_size=self.max_batch_size),
        ])

    def __call__(self, batch):
        """
        Runs the model on the batch of data.

        Args:
            batch (T): The batch of data to run the model on.

        Returns:
            T: The output of the model.
        """
        # NOTE: currently assumes ``batch`` is a batched NumPy array.
        return self.combined(batch)
Classes
class FlexibleModel(model, input_size, max_batch_size=None, basic_tta=False)
This model is designed to be used with models that can only handle a certain input size.
Creates a new FlexibleModel object.
Args
- model (function): The model to apply to the data; it must be callable.
- input_size (tuple): The size of the input to the model; its first (batch) dimension must be None.
- max_batch_size (int): The maximum batch size to use when applying the model.
- basic_tta (bool): Whether to use basic test time augmentation.
Returns
- FlexibleModel: The new FlexibleModel object.
Expand source code
class FlexibleModel:
    """
    This model is designed to be used with models that can only handle a certain input size.

    Internally it builds a ``CombinedModel`` from these stages, in this order:

    1. ``BasicTTA()`` when ``basic_tta`` is true (``None`` otherwise),
    2. ``PadCrop`` to the model's input size (``pad_mode="zeros"``, ``pad_position="end"``),
    3. ``DivideCombine`` with the model's input size,
    4. the wrapped model, applied through ``apply_model`` with ``max_batch_size``.

    See ``revtransform`` for how ``CombinedModel`` composes these stages.
    """

    def __init__(self, model, input_size, max_batch_size=None, basic_tta=False):
        """
        Creates a new FlexibleModel object.

        Args:
            model (function): The model to apply to the data. Must be callable.
            input_size (tuple): The size of the input to the model, including the
                first (batch) dimension, which must be given as ``None``. Only the
                remaining dimensions are used for padding/cropping and dividing.
            max_batch_size (int): The maximum batch size to use when applying the
                model. Forwarded as-is to ``apply_model`` (``None`` included).
            basic_tta (bool): Whether to use basic test time augmentation.

        Raises:
            TypeError: If ``model`` is not callable (this includes ``None``).
            ValueError: If ``input_size`` is empty or its first entry is not ``None``.
        """
        # Explicit raises instead of ``assert``: assertions are stripped when
        # Python runs with ``-O``, which would silently skip this validation.
        if not callable(model):  # callable(None) is False, so None is rejected too
            raise TypeError("Model must be callable")
        # Check emptiness first so that ``()`` gets a clear error instead of IndexError.
        if not input_size or input_size[0] is not None:
            raise ValueError("First dimension of input size must be None")

        self.model = model
        # Drop the batch dimension; the transforms below work on per-sample sizes.
        self.input_size = input_size[1:]
        self.max_batch_size = max_batch_size
        self.tta = basic_tta
        self.combined = CombinedModel([
            BasicTTA() if self.tta else None,
            PadCrop(self.input_size, pad_mode="zeros", pad_position="end"),
            DivideCombine(self.input_size),
            # The lambda reads self.model / self.max_batch_size at call time, so
            # reassigning those attributes later affects subsequent calls.
            lambda x: apply_model(self.model, x, batch_size=self.max_batch_size),
        ])

    def __call__(self, batch):
        """
        Runs the model on the batch of data.

        Args:
            batch (T): The batch of data to run the model on.

        Returns:
            T: The output of the model.
        """
        # NOTE: currently assumes ``batch`` is a batched NumPy array.
        return self.combined(batch)