3030 ScaleTransform ,
3131 NormalizationTransform ,
3232 OneHotLabelTransform ,
33- MaskToChannelBasisTransform
33+ MaskToChannelBasisTransform ,
34+ ToTensorSeg
3435)
3536from agml .utils .logging import log
3637
@@ -67,7 +68,8 @@ class TransformManager(AgMLSerializable):
6768 there is no `dual_transform` for image classification.
6869 """
6970 serializable = frozenset (
70- ('task' , 'transforms' , 'time_inserted_transforms' ))
71+ ('task' , 'transforms' , 'time_inserted_transforms' ,
72+ 'contains_tf_transforms' , 'contains_torch_transforms' ))
7173
7274 def __init__ (self , task ):
7375 self ._task = task
@@ -87,6 +89,12 @@ def __init__(self, task):
8789 # in this list and applies them as required.
8890 self ._time_inserted_transforms = []
8991
92+ # Check if the loader already contains TensorFlow/PyTorch transforms.
93+ # This is to track whether there are issues when a new transform is added
94+ # or if there already exists one, so we whether to apply transforms.
95+ self ._contains_tf_transforms = False
96+ self ._contains_torch_transforms = False
97+
9098 def get_transform_states (self ):
9199 """Returns a copy of the existing transforms."""
92100 transform_dict = {}
@@ -319,13 +327,34 @@ def _build_normalization_transform(self, transform):
319327 self ._time_inserted_transforms .pop (norm_transform_index_time )
320328 return None
321329
330+ def _transform_update_and_check (self , tfm_type ):
331+ """Checks whether a TensorFlow/PyTorch transform has been added."""
332+ # Check for transform conflicts.
333+ if tfm_type == 'torch' and self ._contains_tf_transforms :
334+ raise TypeError ("Received a PyTorch-type transform, yet the loader "
335+ "already contains TensorFlow/Keras transforms. This "
336+ "will cause an error, please only pass one format. If "
337+ "you want to remove a transform, pass a value of `None` "
338+ "to reset all of the transforms for a certain type." )
339+ if tfm_type == 'tf' and self ._contains_torch_transforms :
340+ raise TypeError ("Received a TensorFlow/Keras-type transform, yet the "
341+ "loader already contains PyTorch transforms. This "
342+ "will cause an error, please only pass one format. If "
343+ "you want to remove a transform, pass a value of `None` "
344+ "to reset all of the transforms for a certain type." )
345+
346+ # Update the transform type.
347+ if tfm_type == 'torch' :
348+ self ._contains_torch_transforms = True
349+ if tfm_type == 'tf' :
350+ self ._contains_tf_transforms = True
351+
322352 # The following methods implement different checks which validate
323353 # as well as process input transformations, and manage the backend.
324354 # The transforms here will be also checked to match a specific
325355 # backend. Alternatively, the backend will dynamically be switched.
326356
327- @staticmethod
328- def _construct_single_image_transform (transform ):
357+ def _construct_single_image_transform (self , transform ):
329358 """Validates a transform which is applied to a single image.
330359
331360 This is used for image classification transforms, which only
@@ -361,6 +390,7 @@ def _construct_single_image_transform(transform):
361390 if user_changed_backend ():
362391 raise StrictBackendError (change = 'tf' , obj = transform )
363392 set_backend ('torch' )
393+ self ._transform_update_and_check ('torch' )
364394 return transform
365395
366396 # A `tf.keras.Sequential` preprocessing model or an individual
@@ -370,15 +400,15 @@ def _construct_single_image_transform(transform):
370400 if user_changed_backend ():
371401 raise StrictBackendError (change = 'torch' , obj = transform )
372402 set_backend ('tf' )
403+ self ._transform_update_and_check ('tf' )
373404 return transform
374405
375406 # Otherwise, it may be a transform from a (lesser-known) third-party
376407 # library, in which case we just return it as a callable. Transforms
377408 # which are used in a more complex manner should be passed as decorators.
378409 return transform
379410
380- @staticmethod
381- def _construct_image_and_mask_transform (transform ):
411+ def _construct_image_and_mask_transform (self , transform ):
382412 """Validates a transform for an image and annotation mask.
383413
384414 This is used for a semantic segmentation transform. Such
@@ -426,18 +456,41 @@ def _construct_image_and_mask_transform(transform):
426456 if user_changed_backend ():
427457 raise StrictBackendError (
428458 change = 'tf' , obj = transform )
429- set_backend ('tf' )
459+ set_backend ('torch' )
460+
461+ # Update `torchvision.transforms.ToTensor` to a custom
462+ # updated class as this will modify the mask incorrectly.
463+ import torchvision
464+ if isinstance (transform , torchvision .transforms .ToTensor ):
465+ transform = ToTensorSeg (None )
466+ log ("Updated `ToTensor` transform in the provided pipeline "
467+ f"{ transform } to an updated transform which does not "
468+ f"modify the mask. If you want to change this behaviour, "
469+ f"please raise an error with the AgML team." )
470+ elif isinstance (transform , torchvision .transforms .Compose ):
471+ tfm_list = transform .transforms .copy ()
472+ for i , compose_tfm in enumerate (transform .transforms ):
473+ if isinstance (compose_tfm , torchvision .transforms .ToTensor ):
474+ tfm_list [i ] = ToTensorSeg (None )
475+ transform = torchvision .transforms .Compose (tfm_list )
476+ log ("Updated `ToTensor` transform in the provided pipeline "
477+ f"{ transform } to an updated transform which does not "
478+ f"modify the mask. If you want to change this behaviour, "
479+ f"please raise an error with the AgML team." )
480+
481+ self ._transform_update_and_check ('torch' )
430482 elif 'keras.layers' in transform .__module__ :
431483 if get_backend () != 'tf' :
432484 if user_changed_backend ():
433485 raise StrictBackendError (
434486 change = 'torch' , obj = transform )
435- set_backend ('torch ' )
487+ set_backend ('tf ' )
436488 log ('Got a Keras transformation for a dual image and '
437489 'mask transform. If you are passing preprocessing '
438490 'layers to this method, then use `agml.data.experimental'
439491 '.generate_keras_segmentation_dual_transform` in order '
440492 'for the random state to be applied properly.' , 'warning' )
493+ self ._transform_update_and_check ('tf' )
441494 return SameStateImageMaskTransform (transform )
442495
443496 # Another type of transform, most likely some form of transform
0 commit comments