Sunday, November 13, 2022

Fix your matplotlib colorbars!

Matplotlib colorbars can be a mess, especially when using subplots: I often get colorbars that are taller than the images they belong to, as in this case:

  

but what I wanted instead was this:

 


This great article by Joseph Long describes this issue in detail and provides some solutions.
The code for generating the first figure is:

import matplotlib.pyplot as plt
from skimage import img_as_float, data, color, transform
from scipy import ndimage, signal, fft, io
import numpy as np
 
# Grayscale test image and its power spectrum |FFT|^2 (orthonormal scaling).
img = color.rgb2gray(img_as_float(data.astronaut()))
I1 = np.abs(fft.fft2(img, norm='ortho')**2)

# To force a specific figure size, uncomment:
# plt.figure(figsize=(7.5, 2.5))

# One (title, data) pair per panel of the 1x2 grid, left to right.
# The spectrum is fftshift-ed (zero frequency moved to the center) and
# log-scaled; the 1e-6 offset keeps log() finite where the spectrum is zero.
panels = [
    ('image', img),
    ('log spectrum', np.log(fft.fftshift(I1) + 1e-6)),
]

for position, (title, values) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(values)
    plt.colorbar()        # default pyplot colorbar: can end up taller than the image
    if position == 2:
        plt.yticks([])    # disable the y ticks on the second panel
    plt.title(title)

plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)
plt.show()

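One solution from the article is to use a custom colorbar function. The important changes are the new colorbar() helper, which uses make_axes_locatable to append a colorbar axes of matching height next to each image, and its use in place of plt.colorbar():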

import matplotlib.pyplot as plt
from skimage import img_as_float, data, color, transform
from scipy import ndimage, signal, fft, io
import numpy as np
 
def colorbar(mappable, position="right", size="5%", pad=0.05, **kwargs):
    """Draw a colorbar for *mappable* whose length matches its parent axes.

    ``plt.colorbar()`` steals space from the parent axes and can produce a
    bar taller than the image. This helper instead splits off a thin axes on
    one side of ``mappable.axes`` with ``make_axes_locatable`` and draws the
    colorbar there, so the bar follows the image's size. The pyplot "current
    axes" is restored afterwards, so later ``plt.*`` calls still target the
    axes that was current before the call.

    Parameters
    ----------
    mappable : matplotlib.cm.ScalarMappable
        The artist to describe, e.g. the ``AxesImage`` returned by
        ``imshow``. It must already be drawn in an axes.
    position : {"right", "left", "top", "bottom"}, default "right"
        Side of the parent axes on which the colorbar is appended.
    size : str or float, default "5%"
        Thickness of the colorbar axes, passed to ``append_axes``
        (``"N%"`` is a fraction of the parent axes; a float is absolute).
    pad : float or str, default 0.05
        Gap between the parent axes and the colorbar, passed to
        ``append_axes``.
    **kwargs
        Forwarded to ``Figure.colorbar``. For "top"/"bottom" the
        ``orientation`` defaults to ``"horizontal"``. For "left"/"top" the
        ``ticklocation`` defaults to that side, so the ticks face outward.

    Returns
    -------
    matplotlib.colorbar.Colorbar
        The colorbar that was created.

    Raises
    ------
    ValueError
        If *mappable* is not attached to an axes.
    """
    # Validate first: a detached artist (axes is None, or no .axes at all)
    # would otherwise fail later with an opaque AttributeError.
    ax = getattr(mappable, "axes", None)
    if ax is None:
        raise ValueError(
            "colorbar() needs a mappable that is attached to an Axes "
            "(e.g. the object returned by imshow)"
        )

    from mpl_toolkits.axes_grid1 import make_axes_locatable
    import matplotlib.pyplot as plt

    # Remember pyplot's current axes: creating the colorbar axes can change it.
    last_axes = plt.gca()
    fig = ax.figure
    divider = make_axes_locatable(ax)
    cax = divider.append_axes(position, size=size, pad=pad)

    # Bars above/below the image must be horizontal. Ticks go on the side
    # facing away from the image. The default "right" case adds no kwargs,
    # so the call matches the original fig.colorbar(mappable, cax=cax).
    if position in ("top", "bottom"):
        kwargs.setdefault("orientation", "horizontal")
    if position in ("left", "top"):
        kwargs.setdefault("ticklocation", position)

    cbar = fig.colorbar(mappable, cax=cax, **kwargs)
    plt.sca(last_axes)
    return cbar
 
# Same data as the first figure: grayscale image and its power spectrum.
img = color.rgb2gray(img_as_float(data.astronaut()))
I1 = np.abs(fft.fft2(img, norm='ortho')**2)

# One (title, data) pair per panel of the 1x2 grid, left to right.
# The spectrum is centered with fftshift and log-scaled (1e-6 avoids log(0)).
panels = (
    ('image', img),
    ('log spectrum', np.log(fft.fftshift(I1) + 1e-6)),
)

for position, (title, values) in enumerate(panels, start=1):
    panel_axes = plt.subplot(1, 2, position)
    image = plt.imshow(values)     # an AxesImage, not an Axes
    colorbar(image)                # divider-based colorbar sized to this panel
    if position == 2:
        panel_axes.get_yaxis().set_ticks([])  # disable y ticks on the second panel
    panel_axes.set_title(title, size=14)

plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)
plt.show()

No comments:

Post a Comment