#!/usr/bin/env python3

# Copyright 2021 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.


import argparse
import collections
import logging
import os
import re
import shlex
import subprocess
import sys

# Module-wide logger used by every function in this script
LOG = logging.getLogger(__name__)

# Byte multipliers for the units accepted in <amount><unit> arguments
UNIT_BYTES = {
    'MB': 1000 ** 2,
    'MiB': 1024 ** 2,
    'GB': 1000 ** 3,
    'GiB': 1024 ** 3,
}
# Every accepted unit suffix, the percentage sign first
UNITS = ['%'] + list(UNIT_BYTES)
# Matches an integer amount immediately followed by one accepted unit
AMOUNT_UNIT_RE = re.compile('^([0-9]+)(%s)$' % '|'.join(UNITS))

# Byte multipliers used when formatting a size for display
UNIT_BYTES_FORMAT = {
    'B': 1,
    'KiB': 1024,
    'MiB': 1024 ** 2,
    'GiB': 1024 ** 3,
}

# Captures the run of digits that ends a string
DIGITS_EOF_STR_RE = re.compile(r'\d+$')

# A growth partition is only created when at least 1GiB is available
MIN_DISK_SPACE_BYTES = UNIT_BYTES['GiB']

# LVM physical extents default to 4MiB
PHYSICAL_EXTENT_BYTES = UNIT_BYTES['MiB'] * 4

# Amount the thin pool metadata is grown by (1GiB)
POOL_METADATA_SIZE = UNIT_BYTES['GiB']


class Command(object):
    """A command argument list paired with an optional comment.

    The repr of an instance is the shell-quoted command line, preceded
    by a blank line and a '# <comment>' line when a comment is set.
    """

    # argument list of the command to run
    cmd = None

    # optional human readable description of the command
    comment = None

    def __init__(self, cmd, comment=None):
        self.cmd, self.comment = cmd, comment

    def __repr__(self):
        rendered = printable_cmd(self.cmd)
        if not self.comment:
            return rendered
        return "\n# %s\n%s" % (self.comment, rendered)


def parse_opts(argv):
    """Parse the command line into an options namespace.

    The positional arguments and the --group and --device options fall
    back to the GROWVOLS_ARGS, GROWVOLS_GROUP and GROWVOLS_DEVICE
    environment variables respectively.

    :param argv: complete argument list; the first element (the program
                 name) is skipped
    :returns: the argparse namespace holding the parsed options
    """
    description = '''
Grow the logical volumes in an LVM group to take available device space.

The positional arguments specify how available space is allocated. They
have the format <volume>=<amount><unit> where:

<volume> is the label or the mountpoint of the logical volume
<amount> is an integer growth amount in the specified unit
<unit> is one of the supported units

Supported units are:

% percentage of available device space before any changes are made
MiB mebibyte (1048576 bytes)
GiB gibibyte (1073741824 bytes)
MB megabyte (1000000 bytes)
GB gigabyte (1000000000 bytes)

Each argument is processed in order and the requested amount is allocated
to each volume until the disk is full. This means that if space is
overallocated, the last volumes may only grow by the remaining space, or
not grow at all, and a warning will be printed. When space is underallocated
the remaining space will be given to the root volume (mounted at /).

The currently supported partition layout is:
- Exactly one of the partitions containing an LVM group
- The disk having unpartitioned space to grow with
- The LVM logical volumes being formatted with XFS filesystems

Example usage:

growvols /var=80% /home=20GB

growvols --device sda --group vg img-rootfs=20% fs_home=20% fs_var=60%

'''
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    env = os.environ
    parser.add_argument(
        'grow_vols',
        metavar="<volume>=<amount><unit>",
        nargs="*",
        default=env.get('GROWVOLS_ARGS', '').split(),
        help='A label or mountpoint, and the proportion to '
             'grow it by. Defaults to $GROWVOLS_ARGS. '
             'For example: /home=80%% img-rootfs=20%%')
    parser.add_argument(
        '--group',
        metavar='GROUP',
        default=env.get('GROWVOLS_GROUP'),
        help='The name of the LVM group to extend. Defaults '
             'to $GROWVOLS_GROUP or the discovered group '
             'if only one exists.')
    parser.add_argument(
        '--device',
        metavar='DEVICE',
        default=env.get('GROWVOLS_DEVICE'),
        help='The name of the disk block device to grow the '
             'volumes in (such as "sda"). Defaults to the '
             'disk containing the root mount.')

    # Every remaining option is an optional on/off switch that is off
    # unless given; argparse derives the attribute name from the long
    # flag.
    switches = (
        (('--exit-on-no-grow',),
         "Exit with error code 2 if no volume could be grown. "),
        (('-d', '--debug'),
         "Print debugging output."),
        (('-v', '--verbose'),
         "Print verbose output."),
        (('--noop',),
         "Return the configuration commands, without applying them."),
        (('-y', '--yes'),
         'Skip yes/no prompt (assume yes).'),
    )
    for flags, help_text in switches:
        parser.add_argument(*flags, action='store_true', default=False,
                            help=help_text)

    return parser.parse_args(argv[1:])


def configure_logger(verbose=False, debug=False):
    """Set up basic logging at a level chosen from the two flags.

    :param verbose: log at INFO level when true (and debug is false)
    :param debug: log at DEBUG level when true; wins over verbose
    Without either flag the level is WARN.
    """
    if debug:
        level = logging.DEBUG
    elif verbose:
        level = logging.INFO
    else:
        level = logging.WARN

    logging.basicConfig(level=level,
                        format='[%(levelname)s] %(message)s')


def printable_cmd(cmd):
    """Return the command list as one shell-quoted, space-joined string"""
    return ' '.join(map(shlex.quote, cmd))


def convert_bytes(num):
    """Format a byte count using the largest unit that fits.

    Tries GiB, MiB and KiB in turn and uses the first one that yields at
    least one whole unit, truncating any remainder. Anything smaller is
    reported in B.
    """
    for suffix in ('GiB', 'MiB', 'KiB'):
        whole_units = num // UNIT_BYTES_FORMAT[suffix]
        if whole_units > 0:
            return "%d%s" % (whole_units, suffix)
    return "%d%s" % (num // UNIT_BYTES_FORMAT['B'], 'B')


def shrink_bytes_for_alignment(num):
    """Round a byte count down to whole LVM extents, less two extents.

    The requested size is first truncated to a multiple of the physical
    extent size, then two further extents are taken off to leave slack
    for structural alignment differences caused by varying block sizes.

    :param num: the number of bytes desired
    :returns: the extent-aligned byte count, two extents smaller than the
              truncated request
    :raises Exception: when nothing would be left after the reduction
    """
    whole_extents = num // PHYSICAL_EXTENT_BYTES
    aligned = (whole_extents - 2) * PHYSICAL_EXTENT_BYTES
    if aligned > 0:
        return aligned

    raise Exception(
        'Not enough space available to shrink to alignment. '
        'Requires more than %s, requested: %s' % (
            convert_bytes(PHYSICAL_EXTENT_BYTES * 2),
            convert_bytes(num)))


def execute(cmd):
    """Run a command and return what it wrote to stdout.

    :param cmd: the command as an argument list
    :returns: the captured stdout text
    :raises Exception: when the command exits with a non-zero status; the
                       message carries the command, stdout and stderr
    """
    LOG.info('Running: %s', printable_cmd(cmd))

    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,
                            universal_newlines=True)
    stdout, stderr = proc.communicate()

    if proc.returncode == 0:
        LOG.debug('Result: %s', stdout)
        return stdout

    failure_template = ('Running command failed: cmd "{}", stdout "{}",'
                        ' stderr "{}"')
    error_msg = failure_template.format(' '.join(cmd), stdout, stderr)
    LOG.error(error_msg)
    raise Exception(error_msg)


def parse_shell_vars(output):
    """Parse lines of shell-style KEY="VALUE" pairs.

    Each non-blank line, such as KEY="VALUE" KEY2="VALUE2", is split with
    shell quoting rules and every token is divided at its first '='.

    :param output: the text to parse
    :return: a generator yielding one dict per non-blank line
    """
    for raw_line in output.strip().split('\n'):
        if not raw_line.strip():
            continue
        pairs = {}
        for token in shlex.split(raw_line):
            key, value = token.split('=', 1)
            pairs[key] = value
        yield pairs


def find_device(devices, keys, value):
    """Return the first device with one of the keys equal to value.

    :param devices: iterable of device dicts
    :param keys: a single key name, or a list of key names to try
    :param value: the value to look for
    :returns: the first matching device dict, or None when none matches
    """
    wanted_keys = [keys] if isinstance(keys, str) else keys
    for candidate in devices:
        if any(candidate.get(k) == value for k in wanted_keys):
            return candidate
    return None


def find_disk(opts, devices):
    """Locate the disk device to grow.

    With opts.device set, the device is looked up by kernel name or name.
    Otherwise the search starts at the device mounted at / and follows
    the parent (PKNAME) links upwards until a disk or multipath device is
    reached or there is no further parent.

    :param opts: parsed options, only opts.device is read
    :param devices: list of device dicts
    :returns: the device dict, whose TYPE is "disk" or "mpath"
    :raises Exception: when no suitable device can be found
    """
    disk_types = ('disk', 'mpath')

    if opts.device:
        device = find_device(devices, ['KNAME', 'NAME'], opts.device)
        if not device:
            raise Exception('Could not find specified --device: %s'
                            % opts.device)
    else:
        device = find_device(devices, 'MOUNTPOINT', '/')
        if not device:
            raise Exception('Could not find mountpoint for /')

        # climb from the root mount towards the device it lives on
        while True:
            device = find_device(devices, 'KNAME', device['PKNAME'])
            if not device:
                raise Exception('Could not detect disk device')
            if device['TYPE'] in disk_types or not device['PKNAME']:
                break

    if device['TYPE'] not in disk_types:
        raise Exception('Expected a device with TYPE="disk" or TYPE="mpath", '
                        'got: %s' % device['TYPE'])
    return device


def grow_gpt(disk_name):
    """Move the GPT structure to the end of the disk when sgdisk asks.

    Runs 'sgdisk -v' on the device; when its output reports that a
    structure does not reside at the end of the disk, 'sgdisk -e' and
    'partprobe' are run.

    :param disk_name: device name below /dev, such as "sda"
    """
    device = '/dev/%s' % disk_name
    verify_output = execute(['sgdisk', '-v', device])

    if "it doesn't reside\nat the end of the disk" not in verify_output:
        return

    LOG.info('Fixing GPT structure, moving to end of device %s', device)
    execute(['sgdisk', '-e', device])
    execute(['partprobe'])


def find_space(disk_name):
    """Ask sgdisk for the largest free block on the disk.

    :param disk_name: device name below /dev, such as "sda"
    :returns: tuple of (start sector, end sector, end minus start)
    """
    LOG.info('Finding spare space to grow into')
    sgdisk_output = execute([
        'sgdisk',
        '--first-aligned-in-largest',
        '--end-of-largest',
        '/dev/%s' % disk_name])

    # first output line is the start sector, second is the end sector
    reported = sgdisk_output.strip().split('\n')
    sector_start = int(reported[0])
    sector_end = int(reported[1])
    return sector_start, sector_end, sector_end - sector_start


def find_devices():
    """Return a list of dicts describing every block device lsblk reports"""
    LOG.info('Finding all block devices')
    columns = 'kname,pkname,name,label,type,fstype,mountpoint'
    lsblk_output = execute(['lsblk', '-Po', columns])
    return list(parse_shell_vars(lsblk_output))


def find_group(opts):
    """Decide which LVM volume group to extend.

    :param opts: parsed options, only opts.group is read
    :returns: opts.group when set and reported by vgs, otherwise the only
              volume group that exists
    :raises Exception: when the requested group is unknown, or when no
                       group was requested and there is not exactly one
    """
    LOG.info('Finding LVM volume group')
    vgs_output = execute(['vgs', '--noheadings', '--options', 'vg_name'])
    vg_names = {line.strip() for line in vgs_output.split('\n')
                if line.strip()}

    if opts.group:
        if opts.group in vg_names:
            return opts.group
        raise Exception('Could not find specified --group: %s'
                        % opts.group)

    if not vg_names:
        raise Exception('No volume groups found')
    if len(vg_names) > 1:
        raise Exception('More than one volume group, specify one to '
                        'use with --group: %s' % ', '.join(sorted(vg_names)))
    return vg_names.pop()


def find_next_partnum(devices, disk):
    """Work out the number to give the next partition on the disk.

    For TYPE "disk" this is one more than the count of devices whose
    parent is the disk. For TYPE "mpath" it is one more than the highest
    trailing number found in the names of the devices whose parent is
    the multipath device. Any other TYPE gives None.

    :param devices: list of device dicts
    :param disk: the device dict of the disk
    :returns: the next partition number, or None
    """
    disk_type = disk['TYPE']

    if disk_type == 'disk':
        parent = disk['KNAME']
        return sum(1 for d in devices if d['PKNAME'] == parent) + 1

    if disk_type != 'mpath':
        return None

    name_offset = len(disk['NAME'])
    highest = 0
    visited = set()
    for d in devices:
        if d['PKNAME'] != disk['KNAME'] or d['NAME'] in visited:
            continue
        child_name = d['NAME']
        visited.add(child_name)
        # The multipath configuration may put a partition_delimiter
        # between the base name and the partition number (by default a
        # 'p' when the base name ends in a digit and no delimiter is
        # configured). Taking the trailing digits with a regular
        # expression works with or without a delimiter.
        suffix = child_name[name_offset:]
        partnum = int(DIGITS_EOF_STR_RE.search(suffix).group(0))
        highest = max(highest, partnum)
    return highest + 1


def find_next_device_name(devices, disk, partnum):
    """Build the device name for the partition about to be created.

    The naming scheme is detected by looking for the device of the
    partition numbered one below partnum, first as <disk><n> and then as
    <disk>p<n>; the scheme that matches is applied to partnum.

    :param devices: list of device dicts
    :param disk: the device dict of the disk
    :param partnum: number of the new partition
    :returns: the name of the new partition device
    :raises Exception: when neither naming scheme matches
    """
    disk_name = disk['NAME']
    previous_partnum = partnum - 1

    # plain numbering (SATA etc) is tried before the 'p' form (NVMe)
    for template in ('%s%d', '%sp%d'):
        candidate = template % (disk_name, previous_partnum)
        LOG.debug('Looking for device %s' % candidate)
        if find_device(devices, 'NAME', candidate):
            return template % (disk_name, partnum)

    raise Exception('Could not find partition naming scheme for %s'
                    % disk_name)


def prompt_user_for_confirmation(message):
    """Ask the user a y/N question on the terminal.

    The message is written to stdout as given, so it should already
    contain the [y/N] hint. One line is read from stdin and any answer
    starting with "y" (in either case) counts as confirmation. When
    stdin is not a terminal nothing is asked and the answer is no.

    :param message: Confirmation string prompt
    :return: boolean true for valid confirmation, false for all others
    """
    try:
        if not sys.stdin.isatty():
            LOG.error('User interaction required, cannot confirm.')
            return False

        sys.stdout.write(message)
        sys.stdout.flush()
        answer = sys.stdin.readline().lower()
        if answer.startswith('y'):
            LOG.info('User confirmed action.')
            return True
        print('Taking no action.')
    except KeyboardInterrupt:  # ctrl-c
        print('User did not confirm action (ctrl-c) so taking no action.')
    except EOFError:  # ctrl-d
        print('User did not confirm action (ctrl-d) so taking no action.')
    return False


def amount_unit_to_extent(amount_unit, total_size_bytes, remaining_bytes):
    """Turn an <amount><unit> request into an extent aligned byte count.

    A percentage is taken of total_size_bytes, any other unit is an
    absolute size. The result is capped at remaining_bytes (zero when
    nothing remains) and rounded down to a whole physical extent.

    :param amount_unit: the request, such as "80%" or "20GiB"
    :param total_size_bytes: the space percentages are calculated from
    :param remaining_bytes: the space not yet allocated
    :returns: tuple of (bytes granted, remaining bytes after the grant)
    :raises Exception: when amount_unit is not a valid <amount><unit>
    """
    parsed = AMOUNT_UNIT_RE.match(amount_unit)
    if parsed is None:
        raise Exception('Value for <amount><unit> not valid: %s'
                        % amount_unit)

    amount, unit = int(parsed.group(1)), parsed.group(2)
    if unit not in UNIT_BYTES:
        # the only unit without a byte multiplier is the percentage
        granted = int((amount / 100.0) * total_size_bytes)
        LOG.debug('%s of %s is %s'
                  % (amount_unit,
                     convert_bytes(total_size_bytes),
                     convert_bytes(granted)))
    else:
        granted = amount * UNIT_BYTES[unit]
        LOG.debug('%s is %s' % (amount_unit, convert_bytes(granted)))

    if remaining_bytes <= 0:
        LOG.debug('No remaining space available to allocate %s' % amount_unit)
        granted = 0
    elif granted > remaining_bytes:
        LOG.debug('Cannot allocate all of %s, only %s remaining space'
                  % (convert_bytes(granted), convert_bytes(remaining_bytes)))
        granted = remaining_bytes

    # reduce to align on physical extent size
    granted -= granted % PHYSICAL_EXTENT_BYTES

    return granted, remaining_bytes - granted


def find_grow_vols(opts, devices, group, total_size_bytes):
    """Work out how many bytes each logical volume should grow by.

    The <volume>=<amount><unit> arguments in opts.grow_vols are processed
    in order, each taking its share from the space that is still
    unallocated. An implicit "/=100%" is processed last so that any
    space left over goes to the root volume.

    The same volume may be named more than once, for example through
    its label and its mountpoint, or when the root volume is listed
    explicitly and then receives the implicit remainder. Such
    allocations are added together, so that every byte taken from the
    remaining space is actually given to the volume.

    :param opts: parsed options, only opts.grow_vols is read
    :param devices: list of device dicts
    :param group: not used by this function
    :param total_size_bytes: the space available to distribute
    :returns: ordered dict mapping /dev/mapper/<name> volume paths to the
              number of bytes to grow by; volumes granted no space are
              left out
    :raises Exception: when an argument is malformed, names no device,
                       or names a device that is not an XFS formatted
                       LVM volume
    """
    grow_vols = collections.OrderedDict()
    remaining_bytes = total_size_bytes

    arg_error = ('<grow_vols> must be of the format '
                 '<volume>=<amount><unit>, '
                 'for example: /home=100% or fs_home=100%')
    grow_vols_opt = [gv for gv in opts.grow_vols if gv]
    # append an implicit /=100% to grow root by any underallocation
    grow_vols_opt.append('/=100%')

    for gv in grow_vols_opt:
        if '=' not in gv:
            raise Exception(arg_error)
        devname, amount_unit = gv.split('=', 1)
        device = find_device(devices, ['LABEL', 'MOUNTPOINT'], devname)
        if not device:
            raise Exception('Could not find device %s for argument: %s'
                            % (devname, gv))

        if device['TYPE'] != 'lvm':
            raise Exception('Device %(NAME)s is not of type lvm: %(TYPE)s'
                            % device)

        if device['FSTYPE'] != 'xfs':
            raise Exception('Device %(NAME)s is not of fstype xfs: %(FSTYPE)s'
                            % device)

        size_bytes, remaining_bytes = amount_unit_to_extent(
            amount_unit, total_size_bytes, remaining_bytes)

        if size_bytes == 0:
            continue

        volume_path = '/dev/mapper/%s' % device['NAME']

        # Accumulate rather than assign: a later entry for the same
        # volume must not discard space an earlier entry already took
        # out of remaining_bytes.
        grow_vols[volume_path] = grow_vols.get(volume_path, 0) + size_bytes

    return grow_vols


def find_sector_size(disk_name):
    """Read the logical block size of a disk from sysfs.

    :param disk_name: block device name as found below /sys/block
    :returns: the value read from the file, as an integer
    """
    sector_path = '/sys/block/%s/queue/logical_block_size' % disk_name

    LOG.info('Reading sector size from %s' % sector_path)
    with open(sector_path) as f:
        raw_value = f.read()

    size = int(raw_value.strip())
    LOG.info('Sector size of %s is %s'
             % (disk_name, convert_bytes(size)))
    return size


def find_thin_pool(devices, group):
    """Find the LVM thin pool, if any, and check every volume is in it.

    The first logical volume reported by lvs whose attributes start with
    't' is taken as the thin pool. Every logical volume whose attributes
    start with 'V' must then name that pool as its pool_lv.

    :param devices: not used by this function
    :param group: not used by this function
    :returns: tuple of (pool device mapper path, pool name), or
              (None, None) when there is no thin pool
    :raises Exception: when a volume does not belong to the pool
    """
    LOG.info('Finding LVM thin pool')
    lvs = execute(['lvs', '--noheadings', '--options',
                   'lv_name,lv_dm_path,lv_attr,pool_lv'])
    rows = [line.split() for line in lvs.split('\n') if line.strip()]

    # find any thin pool
    pool_path = pool_name = None
    for row in rows:
        name, dm_path, attr = row[:3]
        if attr.startswith('t'):
            pool_name, pool_path = name, dm_path
            break

    if not pool_path:
        return None, None

    # ensure every volume uses the pool
    for row in rows:
        name, _dm_path, attr = row[:3]
        if attr.startswith('V'):
            # the pool_lv column is missing when the row has no pool
            volume_pool = row[3] if len(row) == 4 else None
            if volume_pool != pool_name:
                raise Exception('All volumes need to be in pool %s. '
                                '%s is in pool %s' %
                                (pool_name, name, volume_pool))
    return pool_path, pool_name


def main(argv):
    """Plan, print and optionally run the commands that grow the volumes.

    The disk is inspected, the largest free block is turned into a new
    partition that is added to the volume group, and the logical volumes
    and their XFS filesystems are then grown into it. The planned
    commands are always printed; they are only executed when --noop is
    not given and the user confirms (or --yes is given).

    :param argv: complete argument list including the program name
    :returns: 2 when there is too little free space and --exit-on-no-grow
              was given, otherwise 0 (also when the user declines)
    """
    opts = parse_opts(argv)
    configure_logger(opts.verbose, opts.debug)

    devices = find_devices()
    disk = find_disk(opts, devices)
    disk_name = disk['KNAME']

    # NOTE: this may already modify the disk (sgdisk -e), before any
    # confirmation and regardless of --noop
    grow_gpt(disk_name)

    sector_bytes = find_sector_size(disk_name)
    sector_start, sector_end, size_sectors = find_space(disk_name)

    # express the minimum free space in sectors of this disk
    min_required_sectors = MIN_DISK_SPACE_BYTES // sector_bytes

    size_bytes = size_sectors * sector_bytes
    if size_sectors < min_required_sectors:
        msg = ('Need at least %s sectors to expand into, there '
               'is only: %s' % (min_required_sectors, size_sectors))
        # too little space is only an error when explicitly requested
        if opts.exit_on_no_grow:
            LOG.error(msg)
            return 2

        LOG.info(msg)
        return 0

    group = find_group(opts)
    partnum = find_next_partnum(devices, disk)
    devname = find_next_device_name(devices, disk, partnum)
    thin_pool, thin_pool_name = find_thin_pool(devices, group)
    if thin_pool:
        # reserve for the size of the metadata volume, which
        # is *allocated* from the underlying storage as well.
        size_bytes -= POOL_METADATA_SIZE
        # reserve for the size of the spare metadata volume,
        # used for metadata check and repair
        size_bytes -= POOL_METADATA_SIZE
        # round down to a whole extent
        size_bytes -= size_bytes % PHYSICAL_EXTENT_BYTES
        # reduce for metadata overhead
        size_bytes -= PHYSICAL_EXTENT_BYTES
    # the new partition of a multipath device is addressed through
    # /dev/mapper, any other one directly below /dev
    if disk['TYPE'] == 'mpath':
        dev_path = '/dev/mapper/%s' % devname
    else:
        dev_path = '/dev/%s' % devname
    # split the (possibly reduced) space between the requested volumes
    grow_vols = find_grow_vols(opts, devices, group, size_bytes)

    # the commands are collected first so that they can be shown to the
    # user before anything is run
    commands = []

    # the partition spans the whole free block found by find_space
    commands.append(Command([
        'sgdisk',
        '--new=%d:%d:%d' % (partnum, sector_start, sector_end),
        '--change-name=%d:growvols' % partnum,
        '/dev/%s' % disk_name
    ], 'Create new partition %s' % devname))

    if disk['TYPE'] == 'mpath':
        commands.append(Command(
            ['multipath', '-r'],
            'Force a reload of all existing multipath maps'))

    commands.append(Command(
        ['partprobe'],
        'Inform the OS of partition table changes'))

    commands.append(Command([
        'pvcreate',
        '-ff',
        '--yes',
        dev_path
    ], 'Initialize %s for use by LVM' % devname))

    commands.append(Command([
        'vgextend',
        group,
        dev_path
    ], 'Add physical volume %s to group %s' % (devname, group)))

    if thin_pool:
        # grow the pool metadata first, then the pool itself
        commands.append(Command([
            'lvextend',
            '--poolmetadatasize',
            '+%sB' % POOL_METADATA_SIZE,
            '%s/%s' % (group, thin_pool_name)
        ], 'Add %s to thin pool metadata %s' % (
            convert_bytes(POOL_METADATA_SIZE), thin_pool)))
        # NOTE: the command uses the alignment-shrunk size while the
        # comment reports size_bytes before shrinking
        commands.append(Command([
            'lvextend',
            '-L+%sB' % shrink_bytes_for_alignment(size_bytes),
            thin_pool,
            dev_path
        ], 'Add %s to thin pool %s' % (convert_bytes(size_bytes),
                                       thin_pool)))

    # NOTE: the loop variable rebinds size_bytes; the total computed
    # above is not used again after this point
    for volume_path, size_bytes in grow_vols.items():
        if size_bytes > 0:
            extend_args = [
                'lvextend',
                '--size',
                '+%sB' % shrink_bytes_for_alignment(size_bytes),
                volume_path
            ]
            # without a thin pool the volume is extended onto the new
            # physical volume explicitly
            if not thin_pool:
                extend_args.append(dev_path)
            commands.append(Command(
                extend_args, 'Add %s to logical volume %s' % (
                    convert_bytes(size_bytes), volume_path)))

    # the filesystems are grown only after every lvextend command
    for volume_path, size_bytes in grow_vols.items():
        if size_bytes > 0:
            commands.append(Command([
                'xfs_growfs',
                volume_path
            ], 'Grow XFS filesystem for %s' % volume_path))

    if opts.noop:
        print('The following commands would have run:')
    else:
        print('The following commands will be run:')

    for cmd in commands:
        print(cmd)
    # make sure the plan is visible before the prompt or command output
    sys.stdout.flush()

    if opts.noop:
        return 0

    yes = opts.yes
    if not yes:
        yes = prompt_user_for_confirmation(
            '\nAre you sure you want to run these commands? [y/N] ')
    if yes:
        # execute raises on the first failing command, which stops the
        # remaining commands from running
        for cmd in commands:
            execute(cmd.cmd)

    return 0


# Run only when executed as a script; the return value of main() becomes
# the process exit status
if __name__ == '__main__':
    sys.exit(main(sys.argv))
