#!/usr/bin/env python3

# Copyright (c) 2017-2023 California Institute of Technology ("Caltech"). U.S.
# Government sponsorship acknowledged. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0

r'''Visualize calibration residuals for one or more observations of a board

SYNOPSIS

  $ mrcal-show-residuals-board-observation
      --from-worst --explore left.cameramodel 0-2

  ... a plot pops up showing the 3 worst-fitting chessboard observations in this
  ... solve. And a REPL opens up to allow further study

The residuals come from the optimization inputs stored in the cameramodel file.
A cameramodel that's missing this data cannot be used with this tool.

To plot the residuals on top of the image used in the calibration the image
paths are loaded from the optimization inputs. The paths are used directly,
relative to the current directory. If the paths are unavailable or if the image
cannot be read, the plot is made without the underlying image.

'''


import sys
import argparse
import re
import os

def parse_args():
    r'''Parse and validate the commandline

    Reads sys.argv via argparse. In addition to the checks argparse does itself,
    this function

    - rejects the mutually-exclusive option pairs: --image-path-prefix with
      --image-directory, --from-worst with --from-glob, --title with
      --extratitle

    - unless --from-glob was given, expands the 'observations' arguments
      (non-negative integers and/or A-B ranges) into a flat list of integers.
      With --from-glob the observations are left as the given glob strings

    - if --hardcopy was given, requires exactly one entry in the observations
      list

    Returns the argparse namespace. On any invalid input a message is printed to
    stderr, and the process exits with status 1

    '''

    parser = \
        argparse.ArgumentParser(description = __doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('--title',
                        type=str,
                        default = None,
                        help='''Title string for the plot. Overrides the default
                        title. Exclusive with --extratitle''')
    parser.add_argument('--extratitle',
                        type=str,
                        default = None,
                        help='''Additional string for the plot to append to the
                        default title. Exclusive with --title''')

    parser.add_argument('--vectorscale',
                        type = float,
                        default = 1.0,
                        help='''Scale all the vectors by this factor. Useful to improve legibility if the
                        vectors are too small to see''')

    parser.add_argument('--circlescale',
                        type = float,
                        default = 1.0,
                        help='''Scale all the plotted circles by this factor. Useful to improve legibility if
                        the vectors are too big or too small''')

    parser.add_argument('--cbmax',
                        type=float,
                        help='''Maximum range of the colorbar. If omitted, we
                        autoscale''')

    parser.add_argument('--from-worst',
                        action = 'store_true',
                        help='''If given, the requested observations index from
                        the worst-fitting observations to the best-fitting
                        (observation 0 is the worst-fitting observation).
                        Exclusive with --from-glob. By default we index the
                        observations in the order they appear in the solve''')

    parser.add_argument('--from-glob',
                        action = 'store_true',
                        help='''If given, the requested observations are
                        specified as glob(s) matching the image filenames.
                        Exclusive with --from-worst''')

    parser.add_argument('--image-path-prefix',
                        help='''If given, we prepend the given prefix to the image paths. Exclusive with
                        --image-directory''')

    parser.add_argument('--image-directory',
                        help='''If given, we extract the filenames from the image paths in the solve, and use
                        the given directory to find those filenames. Exclusive
                        with --image-path-prefix''')

    parser.add_argument('--hardcopy',
                        type=str,
                        help='''Write the output to disk, instead of making an
                        interactive plot. If given, only ONE observation may be
                        specified''')
    parser.add_argument('--terminal',
                        type=str,
                        help=r'''gnuplotlib terminal. The default is good almost always, so most people don't
                        need this option''')
    parser.add_argument('--set',
                        type=str,
                        action='append',
                        help='''Extra 'set' directives to gnuplotlib. Can be given multiple times''')
    parser.add_argument('--unset',
                        type=str,
                        action='append',
                        help='''Extra 'unset' directives to gnuplotlib. Can be given multiple times''')
    parser.add_argument('--explore',
                        action='store_true',
                        help='''If given, the tool drops into a REPL before exiting, to allow the user to
                        follow-up with more diagnostics''')

    parser.add_argument('model',
                        type=str,
                        help=r'''Camera model that contains the optimization_inputs that describe the solve.
                        The displayed observations may come from ANY of the
                        cameras in the solve, not necessarily the one given by
                        this model''')

    parser.add_argument('observations',
                        type=str,
                        nargs='+',
                        help=r'''The observations we're looking at. Unless
                        --from-glob, these are a list of integers A and/or A-B
                        ranges. By default these index the board observations in
                        the order they appear in the solve. If --from-worst,
                        then we index them from the worst-fitting to the
                        best-fitting instead. If --from-glob then we treat the
                        strings as globs to match against the image paths''')

    args = parser.parse_args()

    def error_exit(msg):
        r'''Report a commandline error on stderr and exit with status 1'''
        print(msg, file=sys.stderr)
        sys.exit(1)

    if args.image_path_prefix is not None and \
       args.image_directory is not None:
        error_exit("--image-path-prefix and --image-directory are mutually exclusive")

    if args.from_worst and args.from_glob:
        error_exit("--from-worst and --from-glob are mutually exclusive")

    def list_ints_from_range_string(s):
        r'''Expand one observation argument into a list of integers

        "5" produces [5]. "2-4" produces [2,3,4]. Anything else (including a
        negative integer or a range whose end precedes its start) is an error:
        we report it and exit

        '''
        msg_invalid = \
            f"Observations should be a list of non-negative integers and/or A-B ranges. Invalid observation given: '{s}'"

        # A lone integer?
        try:
            i = int(s)
        except ValueError:
            # Not a lone integer. Maybe it's an A-B range; checked below
            pass
        else:
            if i < 0:
                error_exit(msg_invalid)
            return [i]

        # An A-B range?
        m = re.match(r"^([0-9]+)-([0-9]+)$", s)
        if m is None:
            error_exit(msg_invalid)
        try:
            i0 = int(m.group(1))
            i1 = int(m.group(2))
        except ValueError:
            error_exit(msg_invalid)

        if i1 < i0:
            # range(i0,i1+1) would silently be empty, and the tool would then
            # quietly do nothing for this argument. Complain instead
            error_exit(f"Observation ranges A-B must have A <= B. Invalid observation given: '{s}'")

        return list(range(i0,i1+1))

    if not args.from_glob:
        args.observations = [o for obs in args.observations for o in list_ints_from_range_string(obs)]

    if args.hardcopy is not None:
        # With --from-glob the observations are still the unexpanded glob
        # strings at this point, so this counts the globs, not the images they
        # will match
        if len(args.observations) != 1:
            error_exit(f"--hardcopy given, so exactly one observation should have been given. Instead got these observations: {args.observations}")

    if args.title      is not None and \
       args.extratitle is not None:
        error_exit("--title and --extratitle are exclusive")

    return args

# Parse the commandline right away. On --help or on a bad commandline the
# process exits inside this call
args = parse_args()

# arg-parsing is done before the imports below so that --help works without
# building stuff. This is what allows the manpages and README to be generated


import mrcal
import numpy as np
import numpysane as nps
import pprint


# Extra keyword arguments to pass through to the plotting call. Only the options
# the user actually gave are included
plotkwargs_extra = {}
if args.set is not None:
    plotkwargs_extra['set'] = args.set
if args.unset is not None:
    plotkwargs_extra['unset'] = args.unset

if args.title is not None:
    plotkwargs_extra['title'] = args.title
if args.extratitle is not None:
    plotkwargs_extra['extratitle'] = args.extratitle

try:
    model = mrcal.cameramodel(args.model)
except Exception as e:
    print(f"Couldn't load camera model '{args.model}': {e}", file=sys.stderr)
    sys.exit(1)

# Everything this tool reports comes from the optimization inputs stored in the
# model. Without them there's nothing to show
optimization_inputs = model.optimization_inputs()
if optimization_inputs is None:
    print(f"Camera model '{args.model}' does not contain optimization inputs. Residuals aren't available", file=sys.stderr)
    sys.exit(1)

try:
    paths = optimization_inputs['imagepaths']
except KeyError:
    # older models don't have this data. Only a missing key is treated as "no
    # paths": anything else (a Ctrl-C for instance) is allowed to propagate
    paths = None


if args.from_glob:
    # The observations were given as globs. Convert them to integer observation
    # indices by matching each glob against the image paths stored in the solve
    if paths is None:
        print("--from-glob requires image paths, but the given model doesn't have them",
              file=sys.stderr)
        sys.exit(1)

    import fnmatch

    def matching_path_indices(obs):
        r'''Return the indices of all the image paths that match a glob

        The glob "obs" is translated to a regular expression, and matched
        against each entry of "paths". If the glob can't be compiled or if it
        matches nothing, we report the error and exit

        '''
        try:
            r = re.compile( fnmatch.translate(obs) )
        except Exception as e:
            print(f"Error translating, compiling glob '{obs}': {e}",
                  file=sys.stderr)
            sys.exit(1)

        m = [i for i,path in enumerate(paths) if r.match(path)]
        if not m:
            print(f"Glob '{obs}' did not match any images",
                  file=sys.stderr)
            sys.exit(1)
        return m

    args.observations = [i for obs in args.observations for i in matching_path_indices(obs)]

    # The commandline parser could only count the globs, not the images they
    # match. A single glob matching several images would make several plots, all
    # written to the same --hardcopy file, so the count is checked again here
    if args.hardcopy is not None and len(args.observations) != 1:
        print(f"--hardcopy given, so exactly one observation should have been given. Instead the given glob(s) matched these observations: {args.observations}",
              file=sys.stderr)
        sys.exit(1)



# The second value returned by the optimizer callback is used as the vector of
# residuals of the solve
residuals = mrcal.optimizer_callback(**optimization_inputs)[1]

# for --explore
indices_frame_camintrinsics_camextrinsics = optimization_inputs['indices_frame_camintrinsics_camextrinsics']

# i_observations_sorted_from_worst is computed below even if I don't need it.
# --explore might want to look at it


# shape (Nobservations, object_height_n, object_width_n, 3)
observations = optimization_inputs['observations_board']
residuals_shape = observations.shape[:-1] + (2,)
# shape (Nobservations, object_height_n, object_width_n, 2)
#
# The leading np.prod(residuals_shape) values of the residual vector are taken
# to be the board residuals. np.prod() is used instead of its alias
# np.product(): the alias was removed in numpy 2.0
residuals_reshaped = residuals[:np.prod(residuals_shape)].reshape(*residuals_shape)
# shape (Nobservations,)
#
# Sum-of-squares of all the residuals in each observation
err_per_observation = nps.norm2(nps.clump(residuals_reshaped, n=-3))
# Observation indices, ordered from the largest error to the smallest
i_observations_sorted_from_worst = \
    list(reversed(np.argsort(err_per_observation)))





def show(i_observation):
    r'''Plot the residuals of one board observation

    Calls mrcal.show_residuals_board_observation() for the given observation,
    using the solve data computed by this script and the plotting options from
    the commandline. Whatever is in plotkwargs_extra at call time is passed
    along as well.

    Returns nothing. If anything at all goes wrong, the error is reported on
    stderr, and the process exits with status 1

    '''
    try:
        # The options that come from this script. Kept separate from
        # plotkwargs_extra: both are expanded in the call below, so a key
        # appearing in both is still an error
        options = dict(from_worst                       = args.from_worst,
                       i_observations_sorted_from_worst = i_observations_sorted_from_worst,
                       residuals                        = residuals,
                       paths                            = paths,
                       cbmax                            = args.cbmax,
                       image_path_prefix                = args.image_path_prefix,
                       image_directory                  = args.image_directory,
                       circlescale                      = args.circlescale,
                       vectorscale                      = args.vectorscale,
                       hardcopy                         = args.hardcopy,
                       terminal                         = args.terminal)

        mrcal.show_residuals_board_observation(optimization_inputs,
                                               i_observation,
                                               **options,
                                               **plotkwargs_extra)
    except Exception as err:
        print(f"Couldn't show_residuals_board_observation(): {err}",
              file=sys.stderr)
        sys.exit(1)


Nplots = len(args.observations)

if args.hardcopy:
    # Writing to a file: make the plot(s) in this process, and exit. show()
    # exits with an error status itself if the plot couldn't be made
    for i in range(Nplots):
        show(args.observations[i])
    print(f"Wrote {args.hardcopy}")
    sys.exit(0)

# Interactive plots. Each plot is made in its own forked child process, which
# then blocks until the user closes that plot. This leaves the parent free to
# run the --explore REPL while all the plots are up
plotkwargs_extra['wait'] = True
pids = [0] * Nplots
for i in range(Nplots):
    pid = os.fork()
    if pid == 0:
        # child
        # make the plot, and wait for it to be closed by the user
        show(args.observations[i])
        sys.exit()

    # parent
    pids[i] = pid


if args.explore:

    # The call described in this message lists the same arguments as the call
    # made by show(), so that running it really does re-create the plot
    explore_message = \
f'''We're exploring. The first plot being shown can be re-created with:

mrcal.show_residuals_board_observation(optimization_inputs,
                                       {args.observations[0]},
                                       from_worst                       = args.from_worst,
                                       i_observations_sorted_from_worst = i_observations_sorted_from_worst,
                                       residuals                        = residuals,
                                       paths                            = paths,
                                       cbmax                            = args.cbmax,
                                       image_path_prefix                = args.image_path_prefix,
                                       image_directory                  = args.image_directory,
                                       circlescale                      = args.circlescale,
                                       vectorscale                      = args.vectorscale,
                                       hardcopy                         = args.hardcopy,
                                       terminal                         = args.terminal,
                                       **{pprint.pformat(plotkwargs_extra)})

This is available via a shorthand function call show({args.observations[0]})


The first 10 worst-fitting observations (i_observations_sorted_from_worst[:10])

{i_observations_sorted_from_worst[:10]}


The corresponding 10 (iframe, icamintrinsics, icamextrinsics) tuples
(indices_frame_camintrinsics_camextrinsics[i_observations_sorted_from_worst[:10] ]):

{indices_frame_camintrinsics_camextrinsics[ i_observations_sorted_from_worst[:10] ]}

'''
    if paths is not None:
        explore_message += \
f'''The corresponding image paths (paths[i_observations_sorted_from_worst[:10]]):

{paths[i_observations_sorted_from_worst[:10]]}
'''

    explore_message += \
        r"""
The optimization inputs are available in the optimization_inputs dict
"""

    print(explore_message)
    import IPython
    IPython.embed()


# Wait for all the plots to be closed. A child reports a plot that couldn't be
# made with a non-zero exit status; propagate that instead of always claiming
# success
any_plot_failed = False
for i in range(Nplots):
    _, status = os.waitpid(pids[i], 0)
    if status != 0:
        any_plot_failed = True
sys.exit(1 if any_plot_failed else 0)
