Skip to content

Commit 6e9f87c

Browse files
committed
Fixed display issues for Jupyter notebooks in viz module
1 parent 35236e0 commit 6e9f87c

File tree

5 files changed

+40
-21
lines changed

5 files changed

+40
-21
lines changed

agml/viz/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .masks import (
2424
convert_mask_to_colored_image,
2525
annotate_semantic_segmentation,
26+
show_image_and_mask,
2627
show_image_with_overlaid_mask,
2728
show_semantic_segmentation_truth_and_prediction,
2829
)

agml/viz/boxes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ def show_image_and_boxes(image,
193193
**kwargs)
194194

195195
# Display the image.
196-
display_image(image)
196+
_ = display_image(image, matplotlib_figure = False)
197+
return image
197198

198199

199200
def show_object_detection_truth_and_prediction(image,
@@ -312,7 +313,7 @@ def show_object_detection_truth_and_prediction(image,
312313
# Display and return the image.
313314
image = convert_figure_to_image()
314315
if not kwargs.get('no_show', False):
315-
display_image(image)
316+
_ = display_image(image)
316317
return image
317318

318319

agml/viz/display.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,33 @@ def display_image(image, **kwargs):
3434
cv2_imshow(image)
3535
return
3636

37-
if kwargs.get('read_raw', False):
38-
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # convert back to BGR
39-
cv2.imshow('image', image)
40-
cv2.waitKey(0)
41-
cv2.destroyWindow('image')
42-
43-
if get_viz_backend() == 'matplotlib':
44-
plt.figure(figsize = (10, 10))
45-
plt.imshow(image)
46-
plt.gca().axis('off')
47-
plt.gca().set_aspect('equal')
48-
plt.show()
37+
# If running in a Jupyter notebook, then for some weird reason it automatically
38+
# displays images in the background, so don't actually do anything here.
39+
notebook = False
40+
try:
41+
shell = eval("get_ipython().__class__.__name__")
42+
if shell == 'ZMQInteractiveShell':
43+
notebook = True
44+
except NameError:
45+
pass
46+
if notebook:
47+
# If the input content is not a figure, then we can display it.
48+
if kwargs.get('matplotlib_figure', True):
49+
return
50+
51+
else:
52+
if kwargs.get('read_raw', False):
53+
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # convert back to BGR
54+
cv2.imshow('image', image)
55+
cv2.waitKey(0)
56+
cv2.destroyWindow('image')
57+
return
58+
59+
# Default case is matplotlib, since it is the most modular.
60+
fig = plt.figure(figsize = (10, 10))
61+
plt.imshow(image)
62+
plt.gca().axis('off')
63+
plt.gca().set_aspect('equal')
64+
plt.show()
4965

5066

agml/viz/labels.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def show_images_and_labels(images,
5757
The matplotlib figure with the plotted info.
5858
"""
5959
if images is not None and labels is None:
60-
if is_array_like(images[0]):
60+
if is_array_like(images[0], no_list = True):
6161
if images[0].ndim >= 3:
6262
images, labels = images[0], images[1]
6363
else:
@@ -82,8 +82,9 @@ def show_images_and_labels(images,
8282

8383
# Check if the labels are converted to one-hot, and re-convert them back.
8484
if is_array_like(labels):
85-
if labels.ndim == 2: # noqa
86-
labels = np.argmax(labels, axis = -1)
85+
if not isinstance(labels, (list, tuple)):
86+
if labels.ndim == 2: # noqa
87+
labels = np.argmax(labels, axis = -1)
8788

8889
# If a prime number is passed, e.g. 23, then the `_inference_best_shape`
8990
# method will return the shape of (23, 1). Likely, the user is expecting
@@ -118,7 +119,7 @@ def show_images_and_labels(images,
118119
# Display and return the image.
119120
image = convert_figure_to_image()
120121
if not kwargs.get('no_show', False):
121-
display_image(image)
122+
_ = display_image(image)
122123
return image
123124

124125

agml/viz/masks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def show_image_with_overlaid_mask(image,
173173

174174
# Display the annotated image.
175175
if not kwargs.get('no_show', False):
176-
display_image(image)
176+
_ = display_image(image, matplotlib_figure = False)
177177
return image
178178

179179

@@ -221,7 +221,7 @@ def show_image_and_mask(image,
221221
# Display and return the image.
222222
image = convert_figure_to_image()
223223
if not kwargs.get('no_show', False):
224-
display_image(image)
224+
_ = display_image(image)
225225
return image
226226

227227

@@ -262,6 +262,6 @@ def show_semantic_segmentation_truth_and_prediction(image,
262262
# Display and return the image.
263263
image = convert_figure_to_image()
264264
if not kwargs.get('no_show', False):
265-
display_image(image)
265+
_ = display_image(image)
266266
return image
267267

0 commit comments

Comments
 (0)