"""$ rio stack""" from collections.abc import Iterable from itertools import zip_longest import logging import click import rasterio from rasterio.rio import options from rasterio.rio.helpers import resolve_inout @click.command(short_help="Stack a number of bands into a multiband dataset.") @options.files_inout_arg @options.output_opt @options.format_opt @options.bidx_magic_opt @options.rgb_opt @options.overwrite_opt @options.creation_options @click.pass_context def stack(ctx, files, output, driver, bidx, photometric, overwrite, creation_options): """Stack a number of bands from one or more input files into a multiband dataset. Input datasets must be of a kind: same data type, dimensions, etc. The output is cloned from the first input. By default, rio-stack will take all bands from each input and write them in same order to the output. Optionally, bands for each input may be specified using a simple syntax: --bidx N takes the Nth band from the input (first band is 1). --bidx M,N,0 takes bands M, N, and O. --bidx M..O takes bands M-O, inclusive. --bidx ..N takes all bands up to and including N. --bidx N.. takes all bands from N to the end. Examples, using the Rasterio testing dataset, which produce a copy. rio stack RGB.byte.tif -o stacked.tif rio stack RGB.byte.tif --bidx 1,2,3 -o stacked.tif rio stack RGB.byte.tif --bidx 1..3 -o stacked.tif rio stack RGB.byte.tif --bidx ..2 RGB.byte.tif --bidx 3.. -o stacked.tif """ logger = logging.getLogger(__name__) try: with ctx.obj['env']: output, files = resolve_inout(files=files, output=output, overwrite=overwrite) output_count = 0 indexes = [] for path, item in zip_longest(files, bidx, fillvalue=None): with rasterio.open(path) as src: src_indexes = src.indexes if item is None: indexes.append(src_indexes) output_count += len(src_indexes) elif '..' in item: start, stop = map( lambda x: int(x) if x else None, item.split('..')) if start is None: start = 1 indexes.append(src_indexes[slice(start - 1, stop)]) output_count += len(src_indexes[slice(start - 1, stop)]) else: parts = list(map(int, item.split(','))) if len(parts) == 1: indexes.append(parts[0]) output_count += 1 else: parts = list(parts) indexes.append(parts) output_count += len(parts) with rasterio.open(files[0]) as first: kwargs = first.meta kwargs.update(**creation_options) if driver: kwargs["driver"] = driver kwargs.update(count=output_count) if photometric: kwargs['photometric'] = photometric with rasterio.open(output, 'w', **kwargs) as dst: dst_idx = 1 for path, index in zip(files, indexes): with rasterio.open(path) as src: if isinstance(index, int): data = src.read(index) dst.write(data, dst_idx) dst_idx += 1 elif isinstance(index, Iterable): data = src.read(index) dst.write(data, range(dst_idx, dst_idx + len(index))) dst_idx += len(index) except Exception: logger.exception("Exception caught during processing") raise click.Abort()