plot module¶
The visualization module
DiscreteColors(colors=None, ncolors=3, seed=None)
¶
Generates a discrete colormap with a specified number of colors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
ncolors |
int |
Number of colors to include in the colormap. Default is 3. |
3 |
seed |
int |
Seed for the random number generator to ensure reproducibility. Default is None. |
None |
Returns:
| Type | Description |
|---|---|
ListedColormap |
A matplotlib ListedColormap object with the specified number of colors. |
Source code in geonate/plot.py
def DiscreteColors(colors=None, ncolors=3, seed=None):
    """
    Generates a discrete colormap with a specified number of colors.

    When ``colors`` is omitted, ``ncolors`` entries are drawn at random (without
    replacement) from a built-in palette of 32 distinct colors. When ``colors``
    is supplied, every color in it is used in the given order and ``ncolors``
    is ignored.

    Args:
        colors (list, optional): Colors to build the colormap from (e.g. hex strings). Default is None, which samples from the built-in palette.
        ncolors (int): Number of colors to draw from the built-in palette. Default is 3.
        seed (int, optional): Seed for the random draw to ensure reproducibility. Default is None.

    Returns:
        ListedColormap: A matplotlib ListedColormap object with the specified number of colors.

    Raises:
        ValueError: If ``ncolors`` is negative or larger than the built-in palette.
    """
    import random
    from matplotlib.colors import ListedColormap

    if colors is not None:
        # Caller-supplied colors are used as-is; materialize so tuples and
        # other iterables are accepted too.
        return ListedColormap(list(colors))

    # Built-in palette, currently 32 distinct colors
    palette = [
        "#000000",  # Black
        "#FFFFFF",  # White
        "#FF0000",  # Red
        "#008000",  # Green (was a duplicate of Lime, #00FF00)
        "#0000FF",  # Blue
        "#FFFF00",  # Yellow
        "#00FFFF",  # Cyan
        "#FF00FF",  # Magenta
        "#808080",  # Gray
        "#C0C0C0",  # Silver
        "#800000",  # Maroon
        "#808000",  # Olive
        "#800080",  # Purple
        "#008080",  # Teal
        "#000080",  # Navy
        "#FFA500",  # Orange
        "#FFC0CB",  # Pink
        "#A52A2A",  # Brown
        "#00FF00",  # Lime
        "#4B0082",  # Indigo
        "#EE82EE",  # Violet
        "#F5F5DC",  # Beige
        "#FF7F50",  # Coral
        "#40E0D0",  # Turquoise
        "#E6E6FA",  # Lavender
        "#FFDAB9",  # Peach
        "#98FF98",  # Mint
        "#F5DEB3",  # Wheat
        "#F0E68C",  # Khaki
        "#DDA0DD",  # Plum
        "#D3D3D3",  # Light Grey
        "#A9A9A9",  # Dark Grey
    ]

    if not 0 <= ncolors <= len(palette):
        raise ValueError(
            f"ncolors must be between 0 and {len(palette)} when sampling the built-in palette, got {ncolors}"
        )

    # A private generator yields the same draw as random.seed(seed) followed by
    # random.sample(...), but leaves the global random state untouched.
    rng = random.Random(seed)
    return ListedColormap(rng.sample(palette, ncolors))
colormaps()
¶
Display all available colormaps in Matplotlib.
This function generates a plot that shows all the colormaps available in Matplotlib. Each colormap is displayed as a horizontal gradient bar.
Source code in geonate/plot.py
def colormaps():
    """
    Display all available colormaps in Matplotlib.

    This function generates a plot that shows all the colormaps available in Matplotlib.
    Each colormap is displayed as a horizontal gradient bar, labelled with its
    name on the y-axis. The figure is shown with ``plt.show()``; nothing is returned.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # Names of every registered colormap
    cmap_names = plt.colormaps()
    n_maps = len(cmap_names)

    # One 5-row gradient strip from 0 to 1, built once and reused for every bar
    strip = np.vstack([np.linspace(0, 1, 256).reshape(1, -1)] * 5)

    # Figure height grows with the number of colormaps
    fig, ax = plt.subplots(figsize=(10, n_maps * 0.25))

    # Stack the bars vertically: bar `row` occupies y in [row, row + 1]
    for row, name in enumerate(cmap_names):
        ax.imshow(strip, aspect='auto', cmap=name, extent=[0, 10, row, row + 1])

    # Formatting: one centered tick label per bar, no x ticks
    ax.set_yticks(np.arange(n_maps) + 0.5)
    ax.set_yticklabels(cmap_names)
    ax.set_xticks([])
    ax.set_title("Matplotlib Colormaps", fontsize=12, fontweight="bold")
    # Each imshow call resets the limits, so fix the full range last
    ax.set_ylim(0, n_maps)
    plt.show()
plotMap(image, cmap=None, figsize=(6, 6), axis_off=False, colorbar=True, cbar_shrink=0.5, colorbar_name='Cluster', mapTitle=None, fontFamily='Arial', imgPath=None, resolution=300)
¶
Plots an image with discrete values
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
image |
np.ndarray |
The data array in image format (Height x Width x Bands). |
required |
cmap |
str or Colormap |
Colormap to use for the image. Default is None. |
None |
figsize |
tuple |
Size of the figure in inches. Default is (6, 6). |
(6, 6) |
axis_off |
bool |
Hide the axes (ticks, tick labels and frame). Default is False. |
False |
colorbar |
bool |
Whether to display a colorbar. Default is True. |
True |
cbar_shrink |
float |
Fraction by which to multiply the size of the colorbar. Default is 0.5. |
0.5 |
colorbar_name |
str |
Label for the colorbar. Default is 'Cluster'. |
'Cluster' |
mapTitle |
str |
Title of the map. Default is None. |
None |
fontFamily |
str |
Font family for the plot. Default is 'Arial'. |
'Arial' |
imgPath |
str |
Path to save the figure, including the file extension (e.g., *.jpg). Default is None, which skips saving. |
None |
resolution |
int |
Resolution of the saved figure in DPI. Default is 300. |
300 |
Source code in geonate/plot.py
def plotMap(image, cmap=None, figsize=(6,6), axis_off=False, colorbar=True, cbar_shrink=0.5, colorbar_name='Cluster', mapTitle=None, fontFamily='Arial', imgPath=None, resolution=300):
    """
    Plots an image with discrete values

    The figure is displayed with ``plt.show()`` and optionally written to disk.
    Note that the function sets matplotlib's global ``font.family`` rcParam to
    ``fontFamily``; the setting remains in effect after the call.

    Args:
        image (np.ndarray): The data array in image format (Height x Width x Bands).
        cmap (str or Colormap, optional): Colormap to use for the image. Default is None.
        figsize (tuple, optional): Size of the figure in inches. Default is (6, 6).
        axis_off (bool, optional): Hide the axes (ticks, labels and frame). Default is False.
        colorbar (bool, optional): Whether to display a colorbar. Default is True.
        cbar_shrink (float, optional): Fraction by which to multiply the size of the colorbar. Default is 0.5.
        colorbar_name (str, optional): Label for the colorbar. Default is 'Cluster'.
        mapTitle (str, optional): Title of the map. Default is None.
        fontFamily (str, optional): Font family for the plot. Default is 'Arial'.
        imgPath (str, optional): Path to save the figure with extension (e.g., *.jpg). Default is None (figure is not saved).
        resolution (int, optional): Resolution of the saved figure in DPI. Default is 300.

    Raises:
        ValueError: If ``image`` is not a numpy array.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # Validate first so a bad call does not change global matplotlib settings
    if not isinstance(image, np.ndarray):
        raise ValueError('Input image must data array in image format (Height x Width x Bands)')

    plt.rcParams["font.family"] = fontFamily

    plt.figure(figsize=figsize)
    plt.imshow(image, cmap=cmap)

    # Add color bar (truthiness, so numpy booleans behave like Python bools)
    if colorbar:
        plt.colorbar(label=colorbar_name, shrink=cbar_shrink)

    # Add map title
    if mapTitle is not None:
        plt.title(mapTitle)

    # Remove axis decorations
    if axis_off:
        plt.axis('off')

    # Adjust the layout BEFORE saving so the file on disk matches what is shown
    plt.tight_layout()

    # Save plot
    if imgPath is not None:
        plt.savefig(imgPath, dpi=resolution)

    plt.show()
plotRGB(input, rgb=(0, 1, 2), stretch=True, str_clip=2, figsize=(10, 10), **kwargs)
¶
Plot a 3-band RGB image using earthpy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input |
rasterio.DatasetReader | np.ndarray |
Rasterio image or data array. |
required |
rgb |
tuple |
Indices of the RGB bands. Defaults to (0, 1, 2). |
(0, 1, 2) |
stretch |
bool |
Apply contrast stretching. Defaults to True. |
True |
str_clip |
int |
Percentage of pixel values clipped at each end of the histogram when stretching. Default is 2 (the 2nd and 98th percentiles). |
2 |
figsize |
numeric tuple |
Width and Height. Defaults to (10, 10) inches. |
(10, 10) |
**kwargs |
Additional keyword arguments forwarded to earthpy.plot.plot_rgb(), such as title='True color' to add a plot title. |
{} |
Source code in geonate/plot.py
def plotRGB(input, rgb=(0, 1, 2), stretch=True, str_clip: int = 2, figsize=(10,10), **kwargs):
    """
    Plot a 3-band RGB image using earthpy.

    Args:
        input (rasterio.DatasetReader | np.ndarray): Rasterio image or data array. Arrays are expected band-first, i.e. (Bands x Height x Width), which is what ``DatasetReader.read()`` returns.
        rgb (tuple, optional): Indices of the RGB bands. Defaults to (0, 1, 2).
        stretch (bool, optional): Apply contrast stretching. Defaults to True.
        str_clip (int): The percentage of clip to apply to the stretch. Default = 2 (2 and 98).
        figsize (numeric tuple): Width and Height. Defaults to (10, 10) inches.
        **kwargs: Additional keyword arguments forwarded to earthpy.plot.plot_rgb().

    Raises:
        ValueError: If ``input`` is neither a rasterio dataset nor a numpy array,
            or if the data is not a 3-dimensional array with at least 3 bands.
    """
    import numpy as np
    import rasterio
    import earthpy.plot as ep

    ### Check input data
    if isinstance(input, rasterio.DatasetReader):
        dataset = input.read()
    elif isinstance(input, np.ndarray):
        dataset = input
    else:
        raise ValueError('Input data is not supported')

    # Make sure it is a multi-band image. Checking only len(dataset) would let a
    # 2-D single-band array through whenever it has more than two rows, so the
    # number of dimensions is verified as well as the size of the band axis.
    if dataset.ndim != 3 or dataset.shape[0] <= 2:
        raise ValueError('Image has only one band, please provide at least 3-band image')

    # Visualize the input dataset
    ep.plot_rgb(dataset, rgb=rgb, stretch=stretch, str_clip=str_clip, figsize=figsize, **kwargs)
plot_bands(input, cmap='Greys_r', cols=3, figsize=(10, 10), cbar=True, **kwargs)
¶
Plot a raster image or data array using earthpy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input |
DatasetReader | np.ndarray |
Rasterio image or data array |
required |
cmap |
str |
Colormap for the plot. Defaults to 'Greys_r'. |
'Greys_r' |
cols |
int |
Number of columns in the plot grid. Defaults to 3. |
3 |
figsize |
numeric tuple |
Width and Height. Defaults to (10, 10) inches. |
(10, 10) |
cbar |
bool |
Whether to show a colorbar. Defaults to True. |
True |
**kwargs |
AnyStr |
Additional keyword arguments forwarded to earthpy.plot.plot_bands(), such as title, vmin or vmax. |
{} |
Source code in geonate/plot.py
def plot_bands(input, cmap='Greys_r', cols=3, figsize=(10,10), cbar=True, **kwargs):
    """Plot the bands of a raster image or data array using earthpy.

    Args:
        input (DatasetReader | np.ndarray): Rasterio image (all bands are read) or data array (used as given).
        cmap (str, optional): Colormap for the plot. Defaults to 'Greys_r'.
        cols (int): Number of columns in the plot grid. Defaults to 3.
        figsize (numeric tuple): Width and Height. Defaults to (10, 10) inches.
        cbar (bool): Whether to show a colorbar. Defaults to True.
        **kwargs: Extra keyword arguments forwarded to earthpy.plot.plot_bands().

    Raises:
        ValueError: If ``input`` is neither a rasterio dataset nor a numpy array.
    """
    import numpy as np
    import rasterio
    import earthpy.plot as ep

    # Resolve the input to a plain pixel array
    if isinstance(input, rasterio.DatasetReader):
        pixels = input.read()
    else:
        if not isinstance(input, np.ndarray):
            raise ValueError('Input data is not supported')
        pixels = input

    # Hand everything over to earthpy for the actual drawing
    ep.plot_bands(pixels, cmap=cmap, cols=cols, figsize=figsize, cbar=cbar, **kwargs)
plot_raster(input, layername=None, rgb=None, stretch='linear', brightness=None, contrast=None, opacity=1, zoom=5, basemap='OSM', output=None)
¶
Plots a basemap with an overlay of raster data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input |
DatasetReader |
The input raster dataset. |
required |
layername |
AnyStr |
Layer name of image. |
None |
rgb |
list |
List of RGB bands to visualize. Defaults to None. |
None |
stretch |
AnyStr |
Stretch method for the image ('linear', 'hist', 'custom'). Defaults to 'linear'. |
'linear' |
brightness |
float |
Brightness value for custom stretch. Defaults to None. |
None |
contrast |
float |
Contrast value for custom stretch. Defaults to None. |
None |
opacity |
float |
Opacity of the image overlay. Defaults to 1. |
1 |
zoom |
float |
Initial zoom level of the map. Defaults to 5. |
5 |
basemap |
AnyStr |
Basemap type ('OSM', 'CartoDB Positron', 'CartoDB Dark Matter', 'OpenTopoMap', 'Esri Satellite', 'Esri Street Map', 'Esri Topo', 'Esri Canvas'). Defaults to 'OSM'. |
'OSM' |
output |
AnyStr |
File path to write out html file to local directory. Defaults to None. |
None |
Returns:
| Type | Description |
|---|---|
folium.Map |
A folium map object with the raster data overlay. |
Source code in geonate/plot.py
def plot_raster(input, layername: Optional[AnyStr]=None, rgb: Optional[list]=None, stretch: Optional[AnyStr]='linear', brightness: Optional[float]=None, contrast: Optional[float]=None, opacity: Optional[float]=1, zoom: Optional[float]=5, basemap: Optional[AnyStr]='OSM', output: Optional[AnyStr]= None):
    """
    Plots a basemap with an overlay of raster data.

    Args:
        input (DatasetReader): The input raster dataset.
        layername (AnyStr, optional): Layer name of image. Defaults to None, which names the layer 'Layer'.
        rgb (list, optional): List of RGB bands to visualize, passed to the dataset's ``read`` (e.g. [3, 2, 1]). Required when the raster has 3 or more bands. Defaults to None.
        stretch (AnyStr, optional): Stretch method for the image ('linear', 'hist', 'custom'), case-insensitive. None skips stretching. Defaults to 'linear'.
        brightness (float, optional): Brightness value for custom stretch. Defaults to None.
        contrast (float, optional): Contrast value for custom stretch. Defaults to None.
        opacity (float, optional): Opacity of the image overlay. Defaults to 1.
        zoom (float, optional): Initial zoom level of the map. Defaults to 5.
        basemap (AnyStr, optional): Basemap type ('OSM', 'CartoDB Positron', 'CartoDB Dark Matter', 'OpenTopoMap', 'Esri Satellite', 'Esri Street Map', 'Esri Topo', 'Esri Canvas'), case-insensitive. Defaults to 'OSM'.
        output (AnyStr, optional): File path to write out html file to local directory. Defaults to None.

    Returns:
        folium.Map: A folium map object with the raster data overlay.

    Raises:
        ValueError: If ``input`` is not a rasterio dataset, if ``rgb`` is missing
            for a raster with 3 or more bands, if ``stretch`` is unknown, if
            ``contrast``/``brightness`` are missing for the custom stretch, or
            if ``basemap`` is not supported.
    """
    import folium
    from folium.raster_layers import ImageOverlay
    import rasterio
    import numpy as np
    from .common import meter2degree, get_extent_local
    from .processor import reproject

    ### Check input data is raster or not
    if not isinstance(input, rasterio.DatasetReader):
        raise ValueError("Input data is not supported. It must be raster image")

    # Convert image to lat/long if input is not in lat/long system
    if input.crs.to_string() == "EPSG:4326":
        input_converted = input
    else:
        resolution_degree = meter2degree(input.res[0])
        input_converted = reproject(input, reference='EPSG:4326', res=resolution_degree)

    # Extract data from image to visualize.
    # NOTE(review): pixels are read from the original `input`, while the overlay
    # bounds below come from `input_converted`. For rasters that are not in
    # EPSG:4326 the image is therefore not warped, only stretched to the
    # reprojected extent - confirm whether reading from `input_converted` is intended.
    if input.count <= 2:
        print('Input image/data has less than 2 bands, it will load the first band only')
        imgData = input.read(1)[:, :, np.newaxis]
    else:
        if rgb is None:
            raise ValueError('Input is multiple band image, please provide rgb bands to visualize [3,2,1]')
        # Transpose from raster dims (bands, height, width) to image dims (height, width, bands)
        imgData = np.transpose(input.read(rgb), (1, 2, 0))

    # Apply the requested stretch
    method = None if stretch is None else stretch.lower()
    if method is None:
        data = imgData
    elif method == 'linear':
        # Linear stretching based on min/max values
        low, high = imgData.min(), imgData.max()
        if high == low:
            # Constant image: avoid dividing by zero and render it uniformly
            data = np.zeros(imgData.shape, dtype=np.uint8)
        else:
            data = np.clip((imgData - low) / (high - low) * 255, 0, 255).astype(np.uint8)
    elif method in ('hist', 'histogram'):
        from skimage import exposure
        data = exposure.equalize_hist(imgData)  # floating point image with values between 0 and 1
        data = (data * 255).astype(np.uint8)  # convert back to 8-bit image for display
    elif method == 'custom':
        if (contrast is None) or (brightness is None):
            raise ValueError("contrast and brightness must be given for custom stretching method")
        data = np.clip((imgData * contrast + brightness), 0, 255).astype(np.uint8)
    else:
        raise ValueError("Stretch method is not supported ('linear', 'hist', 'custom')")

    # Get bounds values (folium expects [[south, west], [north, east]])
    left, bottom, right, top = get_extent_local(input_converted)[0]
    center = [(top + bottom) / 2, (left + right) / 2]
    bounds = [[bottom, left], [top, right]]

    # Create overlay image
    image_overlay = ImageOverlay(
        image=data,
        bounds=bounds,
        opacity=opacity,
        name=layername if layername is not None else 'Layer',
    )

    # Basemaps that folium knows by name: (accepted aliases, folium tiles name)
    named_basemaps = (
        (('openstreetmap', 'osm', 'open street map'), 'OpenStreetMap'),
        (('cartodbpositron', 'cartodb positron', 'light'), 'Cartodb Positron'),
        (('cartodbdarkmatter', 'cartodb dark matter', 'cartodb dark', 'dark'), 'Cartodb dark_matter'),
    )
    # Basemaps served from a tile URL: (accepted aliases, (url, attribution, layer name))
    url_basemaps = (
        (('opentopomap', 'opentopo', 'topo'),
         ('https://{s}.tile.opentopomap.org/{z}/{x}/{y}.png', '© Topo Map', 'Open Topo Map')),
        (('esri satellite', 'esrisatellite', 'satellite'),
         ('https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}', '© Esri', 'Esri Satellite')),
        # 'esri street map' is the documented spelling and must be accepted too
        (('esri street map', 'esri street', 'esristreet', 'streetmap', 'street map'),
         ('https://server.arcgisonline.com/ArcGIS/rest/services/World_Street_Map/MapServer/tile/{z}/{y}/{x}', '© Esri', 'Esri Street Map')),
        (('esri topo', 'esritopo'),
         ('https://server.arcgisonline.com/ArcGIS/rest/services/World_Topo_Map/MapServer/tile/{z}/{y}/{x}', '© Esri', 'Esri Topo Map')),
        (('esri canvas', 'esricanvas', 'canvas'),
         ('https://server.arcgisonline.com/ArcGIS/rest/services/Canvas/World_Light_Gray_Base/MapServer/tile/{z}/{y}/{x}', '© Esri', 'Esri Canvas Gray')),
    )
    named_lookup = {alias: tiles for aliases, tiles in named_basemaps for alias in aliases}
    url_lookup = {alias: spec for aliases, spec in url_basemaps for alias in aliases}

    # Non-string values fall through to the "not supported" error below
    basemap_key = basemap.lower() if isinstance(basemap, str) else None

    if basemap_key in named_lookup:
        m = folium.Map(location=center, zoom_start=zoom, tiles=named_lookup[basemap_key])
    elif basemap_key in url_lookup:
        tiles_url, attribution, tiles_name = url_lookup[basemap_key]
        # The default map is created first and the custom tiles are added as an extra layer
        m = folium.Map(location=center, zoom_start=zoom)
        folium.TileLayer(tiles=tiles_url, attr=attribution, name=tiles_name).add_to(m)
    else:
        raise ValueError("Basemap is not supported, please select one of these maps ('OSM', 'CartoDB Positron', 'CartoDB Dark Matter', 'OpenTopoMap', 'Esri Satellite', 'Esri Street Map', 'Esri Topo', 'Esri Canvas')")

    # Add image to basemap, plus a control to toggle layers
    image_overlay.add_to(m)
    folium.LayerControl().add_to(m)

    # Save map
    if output is not None:
        m.save(output)

    return m