#!/usr/bin/env python3
import argparse
import collections
import decimal
import itertools
import json
import re
import os
import shlex
import subprocess
import sys

# Prefix prepended to every metric name printed by this script.
PREFIX = 'smartmon_'

# When this file exists, main() exits immediately without collecting anything.
GUARD_FILE = '/etc/smartmon.disable'

# Names of the ATA SMART attributes that get exported. Attribute names
# reported by smartctl are normalized with to_label_value() (lower case,
# spaces replaced by underscores) before being looked up in this set;
# attributes whose normalized name is not listed here are skipped.
smart_attributes_whitelist = {
    'airflow_temperature_cel',
    'command_timeout',
    'current_pending_sector',
    'end_to_end_error',
    'erase_fail_count_total',
    'g_sense_error_rate',
    'hardware_ecc_recovered',
    'host_reads_mib',
    'host_reads_32mib',
    'host_writes_mib',
    'host_writes_32mib',
    'load_cycle_count',
    'media_wearout_indicator',
    'wear_leveling_count',
    'nand_writes_1gib',
    'offline_uncorrectable',
    'power_cycle_count',
    'power_on_hours',
    'program_fail_count',
    'raw_read_error_rate',
    'reallocated_event_count',
    'reallocated_sector_ct',
    'reported_uncorrect',
    'sata_downshift_count',
    'seek_error_rate',
    'spin_retry_count',
    'spin_up_time',
    'start_stop_count',
    'temperature_case',
    'temperature_celsius',
    'temperature_internal',
    'total_lbas_read',
    'total_lbas_written',
    'udma_crc_error_count',
    'unsafe_shutdown_count',
    'workld_host_reads_perc',
    'workld_media_wear_indic',
    'workload_minutes',
}


def to_label_value(s):
    """Normalize *s* into a token usable as a label value.

    The text is lower-cased and every space character is replaced by
    an underscore.
    """
    lowered = s.lower()
    return '_'.join(lowered.split(' '))


def quote(s):
    """Quote a value with double quotes, escaping it for use as a label value.

    Backslashes, double quotes and line feeds are escaped (as ``\\\\``,
    ``\\"`` and ``\\n``) so that a value containing any of them cannot
    terminate the quoted string early or split the output line.

    Args:
        s: (str) The raw label value.

    Returns:
        (str) The escaped value wrapped in double quotes.
    """
    # The backslash must be escaped first, otherwise the backslashes
    # introduced by the two later replacements would be doubled.
    escaped = s.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
    return f'"{escaped}"'


# Captures the run of digits at the very start of a raw attribute string.
_smart_value_rx = re.compile(r'^(\d+)')


def smart_value(s):
    """Extract the leading integer from a SMART raw-value string.

    The string form is parsed (rather than a numeric field) because it
    is the only "processed" (transformed) value available in the
    smartctl JSON output.

    Returns:
        (int) The number formed by the leading digits of *s*, or 0 when
        *s* does not start with a digit.
    """
    leading = _smart_value_rx.match(s)
    return int(leading.group(1)) if leading else 0


class _Metric(collections.namedtuple('Metric', ['name', 'labels', 'value', 'type'])):
    """A single labelled sample: metric name, label dict, value and type."""

    def print(self):
        """Write the sample line ``<prefix><name>{<labels>} <value>`` to stdout."""
        pairs = []
        for key, val in self.labels.items():
            pairs.append(f'{key}={quote(val)}')
        label_text = ','.join(pairs)
        # Going through Decimal turns bool values into 0/1 as well.
        numeric = decimal.Decimal(self.value)
        print(f'{PREFIX}{self.name}{{{label_text}}} {numeric}')

    def print_meta(self):
        """Write the ``# HELP`` and ``# TYPE`` header lines for this metric."""
        full_name = f'{PREFIX}{self.name}'
        print(f'# HELP {full_name} SMART metric {self.name}')
        print(f'# TYPE {full_name} {self.type}')


def Gauge(name, labels, value):
    """Build a metric whose type is 'gauge'."""
    return _Metric(name=name, labels=labels, value=value, type='gauge')


class Collection():
    """Metrics grouped by name, kept in the order names were first added."""

    def __init__(self):
        # Maps a metric name to the list of metrics added under that name.
        self.metrics = {}

    def add(self, metric):
        """Append *metric* to the group that shares its name."""
        self.metrics.setdefault(metric.name, []).append(metric)

    def print(self):
        """Print every group: its header lines once, then each sample."""
        for group in self.metrics.values():
            # The first metric of a group supplies the header for all of them.
            head = group[0]
            head.print_meta()
            for entry in group:
                entry.print()


class Device(collections.namedtuple('DeviceBase', ['path', 'type'])):
    """Representation of a device as found by smartctl --scan output.

    Fields:
        path: The device path (first token of the parsed line).
        type: The value of the -d/--device option on that line, or None
            when the line carries no such option.
    """

    @property
    def labels(self):
        """Metric labels identifying this device.

        A missing type is reported as an empty string so that every
        label value is a string.
        """
        return {'disk': self.path, 'type': self.type or ''}

    @property
    def smartctl_select(self):
        """Command-line arguments that select this device in smartctl.

        The --device option is only emitted when a type is known; a
        None entry in the argument list would make the subprocess call
        fail, so without a type only the path is passed.
        """
        if not self.type:
            return [self.path]
        return ['--device', self.type, self.path]

    @staticmethod
    def from_string(string):
        """Parse one line of text into a Device.

        The first token is the device path; a following ``-d TYPE`` or
        ``--device TYPE`` option sets the type. Other options are
        ignored, as is everything after a ``#`` comment marker.

        Returns:
            (Device) The parsed device, or None when the line holds no
            tokens (blank or comment-only).
        """
        # add_help=False: a stray -h/--help token in the line must not
        # make argparse print a usage message and exit the process.
        parser = argparse.ArgumentParser(add_help=False)
        parser.add_argument('-d', '--device', dest='type')

        tokens = shlex.split(string, comments=True)
        if not tokens:
            return None

        args, _ = parser.parse_known_args(tokens[1:])
        return Device(tokens[0], args.type)

    def __hash__(self):
        return hash((self.path, self.type))

    def __eq__(self, other):
        # Defer to the other operand instead of raising AttributeError
        # when compared against something that is not a Device.
        if not isinstance(other, Device):
            return NotImplemented
        return ((self.path, self.type) == (other.path, other.type))


def smart_ctl(*args, check=True):
    """Run the smartctl binary with *args* and return what it printed.

    With ``check`` enabled a non-zero exit status raises inside
    subprocess; that error is caught here and the output captured up
    to that point is returned instead.

    Returns:
        (str) Data piped to stdout by the smartctl subprocess.
    """
    command = ['/usr/sbin/smartctl']
    command.extend(args)
    try:
        completed = subprocess.run(
            command, stdout=subprocess.PIPE, check=check)
    except subprocess.CalledProcessError as err:
        raw = err.output
    else:
        raw = completed.stdout
    return raw.decode('utf-8')


def scan_devices():
    """Discover SMART devices by asking smartctl to scan for them.

    Each line of the ``--scan-open`` output is parsed into a Device;
    lines that do not produce one are dropped.

    Yields:
        (Device) Single device found by smartctl --scan-open
    """
    listing = smart_ctl('--scan-open')
    stripped_lines = map(str.strip, listing.split('\n'))
    parsed = map(Device.from_string, stripped_lines)
    # filter(None, ...) keeps only the truthy results, i.e. real devices.
    yield from filter(None, parsed)


def smartd_devices(config='/etc/smartd.conf'):
    """Find SMART devices by reading smartd configuration.

    Only lines that start with ``/dev/`` are treated as device entries.
    A configuration file that does not exist yields no devices instead
    of raising, so that a host without smartd configuration can still
    be handled through other discovery means.

    Args:
        config: (str) Path of the smartd configuration file.

    Yields:
        (Device) Single device configured
    """
    try:
        f = open(config)
    except FileNotFoundError:
        # No configuration file: nothing is configured.
        return
    with f:
        for line in f:
            if line.startswith('/dev/'):
                yield Device.from_string(line.strip())


def _wrap_hours(power_on_hours, n):
    # SMART self-test lifetime_hours parameter is a 16-bit field that
    # wraps around. Try to put it back in the range of the current
    # value of power_on_hours, although now we're just making up data
    # and this will fail spectacularly in certain circumstances.
    if power_on_hours < 65536:
        return n
    if n < (power_on_hours & 0xffff):
        return (power_on_hours & 0xffff0000) + n
    return n


def collect_self_test_status(device, data):
    """Extract SMART self-test status from logs.

    For every self-test type present in the standard self-test log the
    entry with the highest lifetime_hours is selected, and up to two
    gauges are produced for it: 'self_test_status' (only when the entry
    reports a pass/fail result) and 'self_test_hours'.

    Args:
        device: (Device) The device the data belongs to.
        data: (dict) Parsed smartctl JSON output for the device.

    Yields:
        Gauge metrics labelled with the test type and the device labels.
        Nothing is yielded when the log, its 'standard' section or its
        table is missing or empty.
    """
    log = data.get('ata_smart_self_test_log', {})
    table = log.get('standard', {}).get('table')
    if not table:
        return

    power_on_hours = data['power_on_time']['hours']

    # Attempt to extract the most recent self test status by type.
    most_recent_test_by_type = {}
    for test in table:
        key = test['type']['value']
        if (key not in most_recent_test_by_type) or \
           (test['lifetime_hours'] > most_recent_test_by_type[key]['lifetime_hours']):
            most_recent_test_by_type[key] = test
    for test in most_recent_test_by_type.values():
        labels = {'test': to_label_value(test['type']['string'])}
        labels.update(device.labels)
        # NOTE(review): smartctl appears to omit 'passed' for entries
        # with no pass/fail verdict (e.g. aborted or interrupted tests)
        # -- confirm. Such entries get no status sample rather than
        # aborting the whole collection with a KeyError.
        passed = test['status'].get('passed')
        if passed is not None:
            yield Gauge('self_test_status', labels, passed)
        yield Gauge('self_test_hours', labels,
                    _wrap_hours(power_on_hours, test['lifetime_hours']))


def collect_ata_attributes(device, data):
    """Turn whitelisted ATA SMART attributes into 'attribute' gauges.

    Nothing is produced when *data* has no 'ata_smart_attributes'
    section. Attribute names are normalized before the whitelist check,
    and the gauge value is parsed from the attribute's raw string.
    """
    if 'ata_smart_attributes' in data:
        for entry in data['ata_smart_attributes']['table']:
            attr_name = to_label_value(entry['name'])
            if attr_name in smart_attributes_whitelist:
                yield Gauge(
                    'attribute',
                    {'attr': attr_name, **device.labels},
                    smart_value(entry['raw']['string']),
                )


def collect_nvme_attributes(device, data):
    """Turn integer NVMe health-log fields into 'attribute' gauges.

    Nothing is produced when *data* has no
    'nvme_smart_health_information_log' section. Fields whose value is
    not an int instance are skipped.
    """
    if 'nvme_smart_health_information_log' in data:
        health_log = data['nvme_smart_health_information_log']
        for field, reading in health_log.items():
            if isinstance(reading, int):
                yield Gauge(
                    'attribute', {'attr': field, **device.labels}, reading)


def collect_device_metrics(device):
    """Yield every SMART metric for one device.

    smartctl is run with ``-a --json`` for the device and its output is
    parsed. The availability and enabled gauges are always produced;
    everything else only when SMART support is reported as available.
    """
    raw_json = smart_ctl('-a', '--json', *device.smartctl_select)
    data = json.loads(raw_json)

    support = data['smart_support']
    smart_available = support['available']
    yield Gauge('device_smart_available', device.labels, smart_available)
    yield Gauge('device_smart_enabled', device.labels, support['enabled'])
    if not smart_available:
        return

    yield Gauge('device_smart_healthy', device.labels, data['smart_status']['passed'])
    yield Gauge('power_on_hours', device.labels, data['power_on_time']['hours'])
    yield Gauge('power_cycle_count', device.labels, data['power_cycle_count'])

    # Constant-valued info metric carrying the identification fields
    # that are present in the output as extra labels.
    info_keys = ('model_name', 'model_family', 'serial_number', 'firmware_version')
    info_labels = dict(device.labels)
    info_labels.update((key, data[key]) for key in info_keys if key in data)
    yield Gauge('device_info', info_labels, 1)

    yield from itertools.chain(
        collect_ata_attributes(device, data),
        collect_nvme_attributes(device, data),
        collect_self_test_status(device, data),
    )


def collect_metrics(devices):
    """Yield the SMART metrics of each device in *devices*, in order."""
    for device in devices:
        yield from collect_device_metrics(device)


def main():
    """Collect SMART metrics for all known devices and print them."""
    # Emergency stop: do nothing at all while the guard file exists.
    if os.path.exists(GUARD_FILE):
        sys.exit(0)

    # Merge the devices found by scanning with those from configuration;
    # the set removes entries reported by both sources.
    devices = set(scan_devices())
    devices.update(smartd_devices())

    collection = Collection()
    for metric in collect_metrics(devices):
        collection.add(metric)
    collection.print()


if __name__ == '__main__':
    main()