#!/usr/bin/env python

"""
QT MOV health checker

Copyright 2015 Codex Digital Ltd.
Author: John Beard

Simple script to rip through a MOV file looking for problems

Inspired by VP-3055 - corrupt frame offset tables!
"""

import struct
import argparse
import os
import sys

try:
    from colorama import init
    from colorama import Fore, Back, Style

    init(autoreset=True)

except ImportError:
    # no colour is OK
    pass

def print_error(s):
    """Print the message `s`, in red when colorama is available.

    `Fore` only exists when the optional colorama import at the top of the
    file succeeded; without it the message is printed with no colour prefix.
    """
    try:
        styled = Fore.RED + s
    except NameError:
        # colorama was not imported: fall back to plain text
        styled = '' + s

    print(styled)


def get4cc(b):
    """Return the four-character code held in the 4-byte string `b`.

    The result is always the interpreter's native ``str`` type. On Python 2
    that is the input bytes unchanged; on Python 3 the bytes are decoded as
    latin-1, which maps every byte value to one character and cannot fail.

    Raises ``struct.error`` if `b` is not exactly four bytes long.
    """
    # unpack first so that a wrong-length input is rejected
    code = b''.join(struct.unpack('cccc', b))

    if not isinstance(code, str):
        # Python 3: struct hands back bytes, callers compare against str
        code = code.decode('latin-1')

    return code

class QtHealthError(object):
    """One health problem found in a file, carried as a description."""

    def __init__(self, desc):
        # kept exactly as supplied; __str__ hands it straight back
        self.desc = desc

    def __str__(self):
        """Return the description given at construction."""
        return self.desc

class QtDocLogger(object):
    """Shared reporting state for one parse: options, nesting and errors.

    Atoms consult the ``print_*`` options to decide what to report, and
    record every problem they find through log_error().
    """

    def __init__(self):
        # current nesting depth, used to indent the atom outline
        self.level = 0

        # print an outline line for every atom encountered
        self.print_atoms = False

        # print these table entries
        self.print_table_entries = []

        # report atoms whose contents overrun their declared size
        self.print_atom_size_errors = False

        # report frames even when their type is a known one
        self.print_known_frame_types = False

        # list of errors encountered
        self.errors = []

    def decend(self):
        """Go one nesting level deeper."""
        self.level += 1

    def ascend(self):
        """Come back up one nesting level."""
        self.level -= 1

    def format_atom(self, a, off, size):
        """Return a one-line description of atom `a` at `off` of `size` bytes."""
        return "%s @ 0x%8x for 0x%x" % (a, off, size)

    def print_atom(self, a, off, size):
        """Print the outline line for an atom, if atom printing is enabled."""
        if self.print_atoms:
            indentation = "    " * self.level
            print("%s%s" % (indentation, self.format_atom(a, off, size)))

    def log_error(self, err):
        """Record `err` in the error list and print it straight away."""
        self.errors.append(err)
        print_error(str(err))


class SampleTable:
    """Empty placeholder type; it holds no state."""

    def __init__(self):
        """Nothing to initialise."""

class QtCheckError(Exception):
    """Base class for the exceptions raised by this checker."""

class OutOfAtom(QtCheckError):
    """Signals that a read would run past the end of the current atom."""

    def __init__(self):
        """Create the exception; it takes no message or arguments."""

class LimitedFileReader(object):
    """Wrap a file object so reads cannot pass a fixed end offset.

    The limit starts at the file's position when the reader is created and
    extends `max_size` bytes from there.
    """

    def __init__(self, f, max_size):
        """Limit reads on `f` to the next `max_size` bytes."""
        self.f = f
        self.start = f.tell()
        self.max_off = self.start + max_size

    def read(self, size):
        """Return exactly `size` bytes from the current position.

        Raises OutOfAtom if the read would cross the limit, or if the
        underlying file ends before `size` bytes could be supplied.
        """
        if self.f.tell() + size > self.max_off:
            raise OutOfAtom

        data = self.f.read(size)

        if len(data) < size:
            # The declared limit reaches past the physical end of the file.
            # A short buffer would only fail later, inside struct.unpack.
            raise OutOfAtom

        return data

    def tell(self):
        """Return the current absolute offset in the underlying file."""
        return self.f.tell()

    def seek(self, to):
        """Seek the underlying file to absolute offset `to` (not limit-checked)."""
        return self.f.seek(to)

    def max(self):
        """Return the absolute offset that reads may not pass."""
        return self.max_off

class QtAtom(object):
    """Generic atom: a typed, sized region of the file.

    This base class treats the atom's contents as opaque and skips over
    them; subclasses override read() to parse their payload.
    """

    def __init__(self, typ, doc):
        """Store the four-character type `typ` and the shared logger `doc`."""
        self.doc = doc
        self.typ = typ

    def __str__(self):
        return self.typ

    def create_atom(self, f):
        """Read the atom at the current offset of `f` and return it.

        `f` is a LimitedFileReader positioned on an atom header. The atom's
        contents are parsed by the class registered for its type, and the
        file is always left at the atom's declared end, whatever the
        contents claimed.

        Returns None for a zero-size terminator, and also (after logging an
        error) for an atom whose declared size is smaller than its own
        header, since such an atom gives no usable position for its
        successor.

        Raises OutOfAtom if the header itself cannot be read within the
        limit of `f`.
        """
        start_off = f.tell()
        size = struct.unpack('>I', f.read(4))[0]

        if size == 0:
            # terminator?
            return None

        atomtype = get4cc(f.read(4))

        if size == 1:
            # a 32-bit size of 1 means the real size follows as 64 bits
            size = struct.unpack('>Q', f.read(8))[0]

        header_len = f.tell() - start_off
        if size < header_len:
            # Seeking to start_off + size would move backwards into (or back
            # to the start of) this header, so the enclosing container could
            # re-read the same bytes forever.
            err = QtHealthError(
                "Atom size 0x%x is smaller than its own %d byte header: %r @ 0x%x"
                % (size, header_len, atomtype, start_off))
            self.doc.log_error(err)
            return None

        # pick the parser: plain container, dedicated class, or opaque atom
        if atomtype in CONTAINERS:
            constr = QtAtomContainer
        else:
            constr = ATOMS.get(atomtype, QtAtom)

        atom = constr(atomtype, self.doc)

        self.doc.print_atom(atom, start_off, size)

        try:
            remaining = size - header_len
            lfr = LimitedFileReader(f.f, remaining)
            atom.read(lfr)
        except OutOfAtom:
            # the logger may not define the option; treat that as "off"
            if getattr(self.doc, 'print_atom_size_errors', False):
                atom_str = self.doc.format_atom(atom, start_off, size)

                err = QtHealthError("Atom too small - overflow averted: %s" % (atom_str))
                self.doc.log_error(err)

        # force seek to declared next atom in case the contents are messed up!
        expected_end = start_off + size
        if f.tell() != expected_end:
            print("  Mismatch 0x%x, should be 0x%x!" % (f.tell(), expected_end))
            f.seek(expected_end)

        return atom

    def read(self, f):
        """Skip the atom's contents by seeking to the end of the limit of `f`."""
        # use up the file..
        f.seek(f.max())

class QtAtomContainer(QtAtom):
    """An atom whose contents are a sequence of further atoms."""

    def __init__(self, *args):
        super(QtAtomContainer, self).__init__(*args)
        # child atoms, in file order
        self.subatoms = []

    def get_subatom_of_type(self, t):
        """Return the first direct child that is an instance of `t`, or None."""
        return next((sub for sub in self.subatoms if isinstance(sub, t)), None)

    def read(self, f):
        """Read child atoms until the container is used up or a terminator.

        A child is only attempted while at least 8 bytes remain, which is
        the size of the smallest possible atom header.
        """
        self.doc.decend()
        try:
            # "<=" so that a header-only, 8-byte atom ending the container
            # is still read
            while f.tell() <= f.max() - 8:
                sub = self.create_atom(f)

                if not sub:
                    break

                self.subatoms.append(sub)
        finally:
            # keep the logger's nesting level right even if a read raises
            self.doc.ascend()

class SampleOffsetTableAtom(QtAtom):
    """Common base for the offset table atoms."""

    def __init__(self, *args):
        super(SampleOffsetTableAtom, self).__init__(*args)
        # offsets read from the table, in table order
        self.offsets = []
        # number of entries the atom claims to hold
        self.expected_elems = 0

    def read_preamble(self, f):
        """Read the table preamble and record the declared entry count.

        The first four bytes (read as 1 + 3 bytes, presumably the version
        and flags fields) are skipped; the following big-endian 32-bit value
        is stored in `expected_elems`.
        """
        f.read(1)
        f.read(3)
        self.expected_elems = struct.unpack('>I', f.read(4))[0]

        if self.typ in self.doc.print_table_entries:
            print("%s: %d entries" % (self.typ, self.expected_elems))


class StcoAtom(SampleOffsetTableAtom):
    """Offset table whose entries are 32-bit big-endian values."""

    def read(self, f):
        """Read the preamble, then the declared number of 32-bit offsets.

        Offsets read before a failure stay in `offsets`; reading past the
        atom is reported by `f` raising OutOfAtom.
        """
        self.read_preamble(f)

        # Count down rather than iterate range(): on Python 2 range() builds
        # the whole list first, and a corrupt count can be up to 2**32 - 1.
        remaining = self.expected_elems
        while remaining > 0:
            self.offsets.append(struct.unpack('>I', f.read(4))[0])
            remaining -= 1

class Co64Atom(SampleOffsetTableAtom):
    """Offset table whose entries are 64-bit big-endian values."""

    def read(self, f):
        """Read the preamble, then the declared number of 64-bit offsets.

        Offsets read before a failure stay in `offsets`; reading past the
        atom is reported by `f` raising OutOfAtom.
        """
        self.read_preamble(f)

        # Count down rather than iterate range(): on Python 2 range() builds
        # the whole list first, and a corrupt count can be up to 2**32 - 1.
        remaining = self.expected_elems
        while remaining > 0:
            self.offsets.append(struct.unpack('>Q', f.read(8))[0])
            remaining -= 1

class StszAtom(QtAtom):
    """Sample size atom: one constant size, or a table of per-sample sizes."""

    def __init__(self, *args):
        super(StszAtom, self).__init__(*args)
        # per-sample sizes; stays empty when a constant size is declared
        self.sizes = []
        # constant size shared by all samples, or 0 if the table is used
        self.sample_size = 0
        # number of samples declared by the atom
        self.sample_count = 0

    def read(self, f):
        """Read the constant sample size, the sample count and the table.

        The per-sample table is only read when the constant size field is
        zero; a non-zero value means all samples share that size and no
        table follows.
        """
        f.read(1)
        f.read(3)
        samplesize = struct.unpack('>I', f.read(4))[0]
        elems = struct.unpack('>I', f.read(4))[0]

        self.sample_size = samplesize
        self.sample_count = elems

        if self.typ in self.doc.print_table_entries:
            print("%s: %d entries, size %d" % (self.typ, elems, samplesize))

        if samplesize != 0:
            # constant-size samples: there is no table to read
            return

        # Count down rather than iterate range(): on Python 2 range() builds
        # the whole list first, and a corrupt count can be up to 2**32 - 1.
        remaining = elems
        while remaining > 0:
            self.sizes.append(struct.unpack('>I', f.read(4))[0])
            remaining -= 1

    def get_sample_size(self, i):
        """Return the size of sample `i` (zero-based).

        Raises IndexError if the table is in use and `i` is outside it.
        """
        if self.sample_size:
            return self.sample_size

        return self.sizes[i]

class SttsAtom(QtAtom):
    """Time-to-sample atom: a table of (sample count, duration) pairs."""

    def __init__(self, *args):
        super(SttsAtom, self).__init__(*args)
        # (sample count, sample duration) tuples in table order
        self.time_to_samps = []

    def read(self, f):
        """Read the entry count and then each (count, duration) pair.

        Entries are printed as they are read when this atom's type is among
        the logger's `print_table_entries`.
        """
        f.read(1)
        f.read(3)

        elems = struct.unpack('>I', f.read(4))[0]
        show_entries = self.typ in self.doc.print_table_entries

        if show_entries:
            print("%s: %d entries" % (self.typ, elems))

        # Count down rather than iterate range(): on Python 2 range() builds
        # the whole list first, and a corrupt count can be up to 2**32 - 1.
        remaining = elems
        while remaining > 0:
            (sam_cnt, sam_dur) = struct.unpack('>II', f.read(8))
            self.time_to_samps.append((sam_cnt, sam_dur))

            if show_entries:
                print("   count %d: duration %d" % (sam_cnt, sam_dur))

            remaining -= 1

class StscAtom(QtAtom):
    """Sample-to-chunk atom: runs of chunks sharing a samples-per-chunk value."""

    def __init__(self, *args):
        super(StscAtom, self).__init__(*args)
        # (zero-based first chunk of the run, samples per chunk) tuples
        self.chunks = []

    def read(self, f):
        """Read the entry count and then each run descriptor.

        The first-chunk number in the file is stored minus one, so the
        entries in `chunks` start from zero. The third field of each entry
        is read but not kept.
        """
        f.read(1)
        f.read(3)
        elems = struct.unpack('>I', f.read(4))[0]
        show_entries = self.typ in self.doc.print_table_entries

        if show_entries:
            print("%s: %d entries" % (self.typ, elems))

        # Count down rather than iterate range(): on Python 2 range() builds
        # the whole list first, and a corrupt count can be up to 2**32 - 1.
        remaining = elems
        while remaining > 0:
            (first_chunk, sample_per_chunk, desc) = struct.unpack('>III', f.read(12))

            self.chunks.append((first_chunk - 1, sample_per_chunk))

            if show_entries:
                print("   chunk %4d: %d samples/chunk" % (first_chunk, sample_per_chunk))

            remaining -= 1

    def get_chunk_size(self, sample):
        """Return the samples-per-chunk value of the run containing `sample`.

        NOTE(review): despite the parameter name, the value is compared with
        the zero-based first-chunk numbers held in `chunks`, so it is treated
        as a chunk index. Nothing in this file calls the method; confirm
        against external callers.

        Raises IndexError if the table is empty.
        """
        entry = 0

        # move on only while the *next* run has already started at `sample`,
        # which leaves `entry` on the run that contains it
        while entry + 1 < len(self.chunks) and self.chunks[entry + 1][0] <= sample:
            entry += 1

        return self.chunks[entry][1]

class StblAtom(QtAtomContainer):
    """Sample table container, able to check where its samples point.

    The offset, sample-to-chunk and sample size children are combined into
    an absolute list of samples, and the bytes at each sample offset are
    checked for a known frame signature.
    """

    # four-character codes accepted at bytes 4..8 of a frame
    FRAME_TYPES = ['icpf']

    class SampleTableError(QtCheckError):
        """Raised when the sample table cannot be assembled."""

        def __init__(self, reason):
            self.reason = reason

        def __str__(self):
            return self.reason

    def read(self, f):
        QtAtomContainer.read(self, f)

    @staticmethod
    def _describe_4cc(code):
        """Return `code` itself if printable ASCII, else as "0x" + hex digits."""
        if all(0x20 <= ord(ch) < 0x7f for ch in code):
            return code

        return "0x" + "".join("%02x" % ord(ch) for ch in code)

    @staticmethod
    def _read_frame_header(file_obj, offset):
        """Return the 8 bytes at `offset`, or None if they cannot be read."""
        try:
            file_obj.seek(offset)
            header = file_obj.read(8)
        except (IOError, OSError, OverflowError, ValueError):
            # a corrupt table can hold offsets the file object cannot seek to
            return None

        if len(header) < 8:
            # the offset lies at or beyond the end of the file
            return None

        return header

    def scan_for_signature(self, file_obj, sig, offset, scan_range, expected_offset):
        """Look for `sig` starting at any of `scan_range` offsets from `offset`.

        The outcome is printed, with the hit's distance from
        `expected_offset`. A window starting before the beginning of the
        file is trimmed to start at 0. The file position is restored before
        returning.
        """
        if not isinstance(sig, bytes):
            # file data is bytes; a text signature would never compare equal
            sig = sig.encode('latin-1')

        if offset < 0:
            # keep the end of the window where it was, drop the part before 0
            scan_range += offset
            offset = 0

        # store so we can return
        old_off = file_obj.tell()

        hit = -1
        try:
            if scan_range > 0:
                # Read the whole window once instead of one seek and read per
                # candidate offset. The extra len(sig) - 1 bytes let a match
                # start on the very last offset of the range.
                file_obj.seek(offset)
                window = file_obj.read(scan_range + len(sig) - 1)
                hit = window.find(sig)
        finally:
            # put it back where we found it
            file_obj.seek(old_off)

        if hit < 0:
            print("    No signature found within 0x%x of offset 0x%x" % (scan_range, offset))
            return

        found_off = offset + hit
        diff = found_off - expected_offset
        if diff < 0:
            print("    Found signature at 0x%x, which is 0x%x before the expected offset 0x%x" % (found_off, -diff, expected_offset))
        else:
            print("    Found signature at 0x%x, which is 0x%x past the expected offset 0x%x" % (found_off, diff, expected_offset))

    def check_for_sample_sanity(self, fileObj):
        """Check that every sample offset points at a known frame type.

        `fileObj` is the underlying file; its position is restored before
        returning. Each problem is logged through the logger, and for a
        frame of unknown type the surrounding 50kb on each side is scanned
        for the known signatures.

        Returns False if the sample table could not be assembled, and None
        otherwise.
        """
        try:
            samples = self.get_absolute_sample_table()
        except self.SampleTableError as e:
            print("Error: the sample table doesn't seem to make sense:")
            print(e)
            return False

        # the logger may not define the option; treat that as "off"
        show_known = getattr(self.doc, 'print_known_frame_types', False)

        # look for signature within 50kb on each side
        scan_range = 1024 * 50

        saved = fileObj.tell()

        try:
            for s, (sample_off, _sample_size) in enumerate(samples):
                header = self._read_frame_header(fileObj, sample_off)

                if header is None:
                    err = QtHealthError("  Unreadable frame: frame %d @ 0x%x lies outside the file" % (s, sample_off))
                    self.doc.log_error(err)
                    continue

                # bytes 0..4 hold the frame size, bytes 4..8 the frame type
                frame_type = get4cc(header[4:8])

                if frame_type in self.FRAME_TYPES:
                    if show_known:
                        print("  Frame %d of known type %s @ 0x%x" % (s, frame_type, sample_off))
                    continue

                ft_str = self._describe_4cc(frame_type)
                err = QtHealthError("  Unknown frame type: frame %d: '%s' @ 0x%x" % (s, ft_str, sample_off))
                self.doc.log_error(err)

                for sig in self.FRAME_TYPES:
                    self.scan_for_signature(fileObj, sig, sample_off - scan_range, scan_range * 2, sample_off)
        finally:
            # leave the file how we found it
            fileObj.seek(saved)

    def get_absolute_sample_table(self):
        """Return a list of (file offset, size) tuples, one per sample.

        Chunks are taken in order from the offset table; within a chunk the
        samples are laid out back to back using the sample size table. If
        the offset table or the size table runs out early, a notice is
        printed and the samples gathered so far are returned.

        Raises SampleTableError if one of the three required child atoms is
        missing.
        """
        stco = self.get_subatom_of_type(SampleOffsetTableAtom)
        stsc = self.get_subatom_of_type(StscAtom)
        stsz = self.get_subatom_of_type(StszAtom)

        if stco is None or stsc is None or stsz is None:
            raise self.SampleTableError(
                "stbl doesn't seem to have the right atoms in it to check the sample table"
            )

        sample_tab = []
        max_sample = len(stsz.sizes) # work it out based on track data?
        show_entries = stsc.typ in self.doc.print_table_entries

        sample = 0
        chunk_num = 0
        for c, (chunk_start, chunk_size) in enumerate(stsc.chunks):

            if c + 1 < len(stsc.chunks):
                # chunks covered by this desc
                num_chunks_in_desc = stsc.chunks[c + 1][0] - chunk_start

                # which is how many samples?
                last_sample_in_chunk_set = sample + num_chunks_in_desc * chunk_size
            else:
                # the last descriptor runs to the end of the size table
                last_sample_in_chunk_set = max_sample

            if show_entries:
                print("Chunk descriptor: from sample %d, chunk size is %d, last sample is %d" % (chunk_start, chunk_size, last_sample_in_chunk_set))

            while sample < last_sample_in_chunk_set:

                # offset of start of chunk
                try:
                    offset = stco.offsets[chunk_num]
                except IndexError:
                    print("Couldn't get offset for chunk %d from the offset table (sample %d)" % (chunk_num, sample))
                    # return what we have...
                    return sample_tab

                index_in_chk = 0
                while index_in_chk < chunk_size:
                    if sample >= max_sample:
                        # the descriptors promise more samples than the size
                        # table holds
                        print("Couldn't get size for sample %d from the size table (chunk %d)" % (sample, chunk_num))
                        # return what we have...
                        return sample_tab

                    sam_sz = stsz.sizes[sample]
                    sample_tab.append((offset, sam_sz))

                    index_in_chk += 1
                    sample += 1
                    offset += sam_sz

                chunk_num += 1

        return sample_tab


class VmhdAtom(QtAtom):
    """Marker class for 'vmhd' atoms; the contents are skipped like any QtAtom."""

class MinfAtom(QtAtomContainer):
    """'minf' container that checks the sample table once it has been read."""

    def read(self, f):
        """Read the child atoms, then sanity-check the samples.

        The check only runs when a VmhdAtom child is present. If such a
        container has no sample table child, that is logged as an error
        instead of being dereferenced.
        """
        QtAtomContainer.read(self, f)

        if not self.get_subatom_of_type(VmhdAtom):
            return

        stbl = self.get_subatom_of_type(StblAtom)

        if stbl is None:
            err = QtHealthError("  No sample table ('stbl') found in a 'minf' atom that has a 'vmhd'")
            self.doc.log_error(err)
            return

        # the check seeks freely, so it gets the unrestricted file object
        stbl.check_for_sample_sanity(f.f)

class QtRootAtom(QtAtomContainer):
    """Top-level container, created with an empty type string."""

    def __init__(self, *args):
        # the root has no four-character code of its own
        root_type = ""
        super(QtRootAtom, self).__init__(root_type, *args)

    def read(self, f):
        """Read the atoms in `f` exactly as any other container does."""
        super(QtRootAtom, self).read(f)

# Atom types read as plain containers of further atoms.
CONTAINERS = ["moov", "trak", "mdia"]

# Atom types that have a dedicated parser class, keyed by four-character code.
ATOMS = dict(
    stco=StcoAtom,
    stsz=StszAtom,
    stsc=StscAtom,
    stbl=StblAtom,
    stts=SttsAtom,
    co64=Co64Atom,
    vmhd=VmhdAtom,
    minf=MinfAtom,
)

class QtChecker(object):
    """Runs the atom parser over one open file."""

    def __init__(self, f):
        """Remember the open, seekable file object `f` to be checked."""
        self.mov = f

    def parse_qt(self, logger):
        """Parse the whole file, reporting through `logger`.

        The file size is found by seeking to the end, and parsing starts
        again from offset 0. An atom header that runs past the end of the
        file is logged as an error rather than escaping as an exception.
        """
        self.mov.seek(0, os.SEEK_END)
        size = self.mov.tell()
        self.mov.seek(0)

        lfr = LimitedFileReader(self.mov, size)
        root = QtRootAtom(logger)

        try:
            root.read(lfr)
        except OutOfAtom:
            # raised while reading a top-level header: the file ends early
            err = QtHealthError("Truncated atom header near 0x%x: the file ends before the header does" % self.mov.tell())
            logger.log_error(err)

# Exit codes: main() returns RET_OK when no errors were logged and RET_QTBAD
# otherwise.
RET_OK, RET_QTBAD, RET_INTERNAL = 0, 1, 2

def main():
    """Parse the command line, check the given file and return an exit code.

    Returns RET_QTBAD if any error was logged while parsing, else RET_OK.
    """
    parser = argparse.ArgumentParser(description='Health checks for QT files')
    # The file is binary data: text mode would translate line endings on
    # some platforms and yields str instead of bytes on Python 3.
    parser.add_argument('-f', '--file', metavar='FILE', type=argparse.FileType('rb'), required=True,
                        help='the file to read')
    parser.add_argument('-t', '--tables', metavar='TABLE', type=str, nargs='+', default = [],
                        help='table atoms to print (eg co64)')
    parser.add_argument('-a', '--print-all-atoms', action='store_true',
                        help='print outline of all atoms')
    parser.add_argument('-z', '--print-atom-size-errors', action='store_true',
                        help='print notice if atoms don\'t seem to match their declared sizes')
    parser.add_argument('-k', '--print-known-frame-types', action='store_true',
                        help='print frame types and offsets even if we know the type')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='print verbose general program state')
    args = parser.parse_args()

    try:
        qtc = QtChecker(args.file)

        if args.verbose:
            print("Checking file: %s" % args.file.name)

        logger = QtDocLogger()
        logger.print_atoms = args.print_all_atoms
        logger.print_table_entries = args.tables
        logger.print_atom_size_errors = args.print_atom_size_errors
        logger.print_known_frame_types = args.print_known_frame_types

        qtc.parse_qt(logger)
    finally:
        # argparse opened the file for us; release it whatever happened
        args.file.close()

    if logger.errors:
        return RET_QTBAD

    return RET_OK

if __name__ == "__main__":
    # Run the checker and use its result as the process exit status.
    # An interrupt from the keyboard just quits, with RET_OK.
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        sys.exit(RET_OK)
