Skip to content

Commit cb6b1d1

Browse files
committed
Updated data example notebook and added better backend switching
1 parent 6e9f87c commit cb6b1d1

File tree

6 files changed

+329
-430
lines changed

6 files changed

+329
-430
lines changed

agml/data/managers/transform_helpers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,17 @@ def apply(self, mask):
145145
return out[..., 1:].astype(np.int32)
146146

147147

148+
class ToTensorSeg(TransformApplierBase):
149+
def apply(self, image, mask):
150+
# This is essentially the same as `torchvision.transforms.ToTensor`,
151+
# except it doesn't scale the values of the segmentation mask from
152+
# 0-255 -> 0-1, as this would break the segmentaiotn pipeline.
153+
if image.ndim == 2:
154+
image = image[:, :, None]
155+
image = torch.from_numpy(image).permute(2, 0, 1).contiguous()
156+
if isinstance(image, torch.ByteTensor):
157+
image = image.to(dtype = torch.float32).div(255)
158+
mask = torch.from_numpy(mask)
159+
return mask
160+
161+

agml/data/managers/transforms.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
ScaleTransform,
3131
NormalizationTransform,
3232
OneHotLabelTransform,
33-
MaskToChannelBasisTransform
33+
MaskToChannelBasisTransform,
34+
ToTensorSeg
3435
)
3536
from agml.utils.logging import log
3637

@@ -67,7 +68,8 @@ class TransformManager(AgMLSerializable):
6768
there is no `dual_transform` for image classification.
6869
"""
6970
serializable = frozenset(
70-
('task', 'transforms', 'time_inserted_transforms'))
71+
('task', 'transforms', 'time_inserted_transforms',
72+
'contains_tf_transforms', 'contains_torch_transforms'))
7173

7274
def __init__(self, task):
7375
self._task = task
@@ -87,6 +89,12 @@ def __init__(self, task):
8789
# in this list and applies them as required.
8890
self._time_inserted_transforms = []
8991

92+
# Check if the loader already contains TensorFlow/PyTorch transforms.
93+
# This is to track whether there are issues when a new transform is added
94+
# or if there already exists one, so we whether to apply transforms.
95+
self._contains_tf_transforms = False
96+
self._contains_torch_transforms = False
97+
9098
def get_transform_states(self):
9199
"""Returns a copy of the existing transforms."""
92100
transform_dict = {}
@@ -319,13 +327,34 @@ def _build_normalization_transform(self, transform):
319327
self._time_inserted_transforms.pop(norm_transform_index_time)
320328
return None
321329

330+
def _transform_update_and_check(self, tfm_type):
331+
"""Checks whether a TensorFlow/PyTorch transform has been added."""
332+
# Check for transform conflicts.
333+
if tfm_type == 'torch' and self._contains_tf_transforms:
334+
raise TypeError("Received a PyTorch-type transform, yet the loader "
335+
"already contains TensorFlow/Keras transforms. This "
336+
"will cause an error, please only pass one format. If "
337+
"you want to remove a transform, pass a value of `None` "
338+
"to reset all of the transforms for a certain type.")
339+
if tfm_type == 'tf' and self._contains_torch_transforms:
340+
raise TypeError("Received a TensorFlow/Keras-type transform, yet the "
341+
"loader already contains PyTorch transforms. This "
342+
"will cause an error, please only pass one format. If "
343+
"you want to remove a transform, pass a value of `None` "
344+
"to reset all of the transforms for a certain type.")
345+
346+
# Update the transform type.
347+
if tfm_type == 'torch':
348+
self._contains_torch_transforms = True
349+
if tfm_type == 'tf':
350+
self._contains_tf_transforms = True
351+
322352
# The following methods implement different checks which validate
323353
# as well as process input transformations, and manage the backend.
324354
# The transforms here will be also checked to match a specific
325355
# backend. Alternatively, the backend will dynamically be switched.
326356

327-
@staticmethod
328-
def _construct_single_image_transform(transform):
357+
def _construct_single_image_transform(self, transform):
329358
"""Validates a transform which is applied to a single image.
330359
331360
This is used for image classification transforms, which only
@@ -361,6 +390,7 @@ def _construct_single_image_transform(transform):
361390
if user_changed_backend():
362391
raise StrictBackendError(change = 'tf', obj = transform)
363392
set_backend('torch')
393+
self._transform_update_and_check('torch')
364394
return transform
365395

366396
# A `tf.keras.Sequential` preprocessing model or an individual
@@ -370,15 +400,15 @@ def _construct_single_image_transform(transform):
370400
if user_changed_backend():
371401
raise StrictBackendError(change = 'torch', obj = transform)
372402
set_backend('tf')
403+
self._transform_update_and_check('tf')
373404
return transform
374405

375406
# Otherwise, it may be a transform from a (lesser-known) third-party
376407
# library, in which case we just return it as a callable. Transforms
377408
# which are used in a more complex manner should be passed as decorators.
378409
return transform
379410

380-
@staticmethod
381-
def _construct_image_and_mask_transform(transform):
411+
def _construct_image_and_mask_transform(self, transform):
382412
"""Validates a transform for an image and annotation mask.
383413
384414
This is used for a semantic segmentation transform. Such
@@ -426,18 +456,41 @@ def _construct_image_and_mask_transform(transform):
426456
if user_changed_backend():
427457
raise StrictBackendError(
428458
change = 'tf', obj = transform)
429-
set_backend('tf')
459+
set_backend('torch')
460+
461+
# Update `torchvision.transforms.ToTensor` to a custom
462+
# updated class as this will modify the mask incorrectly.
463+
import torchvision
464+
if isinstance(transform, torchvision.transforms.ToTensor):
465+
transform = ToTensorSeg(None)
466+
log("Updated `ToTensor` transform in the provided pipeline "
467+
f"{transform} to an updated transform which does not "
468+
f"modify the mask. If you want to change this behaviour, "
469+
f"please raise an error with the AgML team.")
470+
elif isinstance(transform, torchvision.transforms.Compose):
471+
tfm_list = transform.transforms.copy()
472+
for i, compose_tfm in enumerate(transform.transforms):
473+
if isinstance(compose_tfm, torchvision.transforms.ToTensor):
474+
tfm_list[i] = ToTensorSeg(None)
475+
transform = torchvision.transforms.Compose(tfm_list)
476+
log("Updated `ToTensor` transform in the provided pipeline "
477+
f"{transform} to an updated transform which does not "
478+
f"modify the mask. If you want to change this behaviour, "
479+
f"please raise an error with the AgML team.")
480+
481+
self._transform_update_and_check('torch')
430482
elif 'keras.layers' in transform.__module__:
431483
if get_backend() != 'tf':
432484
if user_changed_backend():
433485
raise StrictBackendError(
434486
change = 'torch', obj = transform)
435-
set_backend('torch')
487+
set_backend('tf')
436488
log('Got a Keras transformation for a dual image and '
437489
'mask transform. If you are passing preprocessing '
438490
'layers to this method, then use `agml.data.experimental'
439491
'.generate_keras_segmentation_dual_transform` in order '
440492
'for the random state to be applied properly.', 'warning')
493+
self._transform_update_and_check('tf')
441494
return SameStateImageMaskTransform(transform)
442495

443496
# Another type of transform, most likely some form of transform

agml/viz/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
convert_mask_to_colored_image,
2525
annotate_semantic_segmentation,
2626
show_image_and_mask,
27-
show_image_with_overlaid_mask,
27+
show_image_and_overlaid_mask,
2828
show_semantic_segmentation_truth_and_prediction,
2929
)
3030
from .boxes import (

agml/viz/general.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from agml.backend.tftorch import is_array_like
1818
from agml.viz.boxes import show_image_and_boxes
19-
from agml.viz.masks import show_image_with_overlaid_mask
19+
from agml.viz.masks import show_image_and_overlaid_mask
2020
from agml.viz.labels import show_images_and_labels
2121
from agml.viz.tools import format_image, _inference_best_shape, convert_figure_to_image
2222
from agml.viz.display import display_image
@@ -52,7 +52,7 @@ def show_sample(loader, image_only = False, **kwargs):
5252
return show_image_and_boxes(
5353
sample, info = loader.info, no_show = kwargs.get('no_show', False))
5454
elif loader.task == 'semantic_segmentation':
55-
return show_image_with_overlaid_mask(
55+
return show_image_and_overlaid_mask(
5656
sample, no_show = kwargs.get('no_show', False))
5757
elif loader.task == 'image_classification':
5858
return show_images_and_labels(

agml/viz/masks.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ def binary_to_channel_by_channel(mask, num_classes = None):
2828
if mask.ndim == 3 and not np.all(mask[:, :, 0] == mask[:, :, 1]):
2929
return mask
3030

31-
# For binary classification tasks, return the mask as-is.
31+
# For binary classification tasks, return the mask with an additional channel.
3232
if len(np.unique(mask)) == 2 and mask.max() == 1:
33-
return mask
33+
return np.expand_dims(mask, -1)
3434

3535
# Otherwise, convert the mask to channel-by-channel format.
3636
input_shape = mask.shape
@@ -137,11 +137,11 @@ def annotate_semantic_segmentation(image,
137137
return image
138138

139139

140-
def show_image_with_overlaid_mask(image,
141-
mask = None,
142-
alpha = 0.3,
143-
border = True,
144-
**kwargs):
140+
def show_image_and_overlaid_mask(image,
141+
mask = None,
142+
alpha = 0.3,
143+
border = True,
144+
**kwargs):
145145
"""Displays an image with an annotated segmentation mask.
146146
147147
This method overlays a segmentation mask over an image. It uses contours

examples/AgML-Data.ipynb

Lines changed: 244 additions & 412 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)