#!/usr/bin/env python
"""
Extract and store data entries from a set of reStructuredText documents.

  rst-fields DATABASE FILE1 FILE2 ...

Entries look like this (all names are generic)::

  [book]
    :title: Probability Theory
    :subtitle: The Logic of Science
    :authors: E.T. Jaynes, G. Larry Bretthorst
    :isbn: 978-0521592710

You can have the data stored in an existing database, or in a file. This program
can also automatically infer a database schema from the data (and create the
tables). If you only store the data in a database, only the fields defined in
the schema will be stored (the others will be dropped). This allows you to
create your database tables ahead of time, as you like, and to have only valid
fields filled with the data later.

This script provides an ultra-simple way to store data that can be expressed as
name-value pairs embedded in a text file (with other text). The typical kind of
data that you would use this for would be for PIM data (e.g., books, addresses,
links, etc.).

A special source key is used in the outgoing data to identify the source file
that each data entry came from. By default, the source is the filename, but you
can set this source id in the file itself, by creating an 'Id' docinfo entry, at
the top of the document, like this::

  ========
  My Title
  ========
  :Id: <unique-id>
  ...

About normalization:

  By default the field names are sanitized by attempting to remove plurals
  automatically, and the value strings are merged into a single line each. This
  can be optionally disabled.

Notes:

  The ideas for this project originate from project Nabu
  (http://furius.ca/nabu/). It is an attempt at providing its basic
  functionality without all the complications, setup and customization that Nabu
  requires.

"""
#__author__ = 'Martin Blais <blais@furius.ca>'

## FIXME: add an option to force an additional field for all the files processed.
## FIXME: fields with * at the end (e.g. comments*) should not get joined/sanitized.
## FIXME: How do we deal with subtypes?  e.g. :p/cell:


from __future__ import with_statement
try:
    import psyco
except ImportError:
    pass

# stdlib imports
import sys, os, re, getpass, logging
from types import ClassType
from StringIO import StringIO
from collections import defaultdict
from operator import itemgetter, attrgetter
from os.path import *
from pprint import pprint

# psycopg2 (PostgreSQL DBAPI driver) imports
import psycopg2 as dbapi

# docutils imports
from docutils import core, nodes, io, utils


# Name of the extra column that holds, for every stored row, the source id of
# the entry (the document 'Id' docinfo value or the basename of the file).
col_source = '__source__'


#-------------------------------------------------------------------------------
# Database helpers.

class committing(object):
    """
    A DBAPI-compatible committing context manager.

    Entering the context returns a new cursor on 'conn'. On exit, the
    transaction is committed if the body completed normally and rolled back if
    it raised; the cursor is closed in both cases. Exceptions raised by the body
    are never suppressed.
    """
    def __init__(self, conn):
        assert conn is not None and not conn.closed
        self.conn = conn
        self.curs = None

    def __enter__(self):
        c = self.curs = self.conn.cursor()
        return c

    def __exit__(self, _1, exc_value, _2):
        try:
            if exc_value is None:
                self.conn.commit()
            else:
                self.conn.rollback()
        finally:
            # Close the cursor even when commit()/rollback() itself fails.
            self.curs.close()

class dbclosing(object):
    """ Context manager that hands out a cursor on 'conn' and closes it on
    exit. It neither commits nor rolls back anything."""

    def __init__(self, conn):
        self.conn, self.curs = conn, None

    def __enter__(self):
        cursor = self.conn.cursor()
        self.curs = cursor
        return cursor

    def __exit__(self, *exc_info):
        cursor = self.curs
        cursor.close()
        self.curs = None



#-------------------------------------------------------------------------------
# Operations on association lists.

def alist_memq(alist, key):
    """ Return True if 'key' is the key of at least one (key, value) pair of
    the assoc-list 'alist', False otherwise."""
    return any(k == key for k, _ in alist)

def alist_replace(alist, from_, to_):
    """ Rename, in place, every pair of 'alist' whose key equals 'from_' so
    that its key becomes 'to_'. Values and positions are left untouched."""
    for index, (key, value) in enumerate(alist):
        if key == from_:
            alist[index] = (to_, value)

def alist_remove(alist, key):
    """ Delete from 'alist', in place, every pair whose key equals 'key'."""
    doomed = set(pos for pos, (k, _) in enumerate(alist) if k == key)
    alist[:] = [pair for pos, pair in enumerate(alist) if pos not in doomed]


#-------------------------------------------------------------------------------
# Parsing and values extraction code

# Note: this is generic, perhaps should be part of docutils.nodes.
def find_first(nodetype, node, document):
    """ Find the first node under 'node' that is of the same type as 'nodetype'.

    'nodetype' is a node class; its name selects the visitor method that gets
    installed. 'document' is the document the visitor is created for. Returns
    the first matching node met by the walk, or None if there is none.
    """

    assert isinstance(nodetype, type)
    found = []

    class FindFirst(nodes.SparseNodeVisitor):
        def visit(self, node):
            found.append(node)
            # The return value of a visit method is ignored by walk(); raising
            # StopTraversal is what actually ends the traversal.
            raise nodes.StopTraversal

    # Install the handler under the name the visitor dispatches on.
    setattr(FindFirst, 'visit_%s' % nodetype.__name__,
            FindFirst.visit)

    vis = FindFirst(document)
    node.walk(vis)

    return found[0] if found else None

def get_file_entries(fn):
    """ Parse a file and extract entries from it.

    Returns a (docid, entries) pair, where 'docid' is the value of the 'Id'
    docinfo field of the document if there is one, the basename of 'fn'
    otherwise, and 'entries' is the list of Entry objects found in the document.
    """
    # Open the file once and make sure the handle is released.
    with open(fn) as source:
        document = core.publish_doctree(
            source,
            source_class=io.FileInput,
            reader_name='standalone',
            parser_name='restructuredtext',
            settings_overrides={'report_level': 'error'},
            )

    # Extract the unique document id.
    docid = None
    docinfo = find_first(nodes.docinfo, document, document)
    if docinfo:
        fields = extract_fields(docinfo, document)
        dfields = dict((k.lower(), v) for k, v in fields)
        docid = dfields.get(u'id', None)
    if docid is None:
        docid = basename(fn)

    # Obtain all the data from the document.
    v = FindData(fn, docid, document)
    document.walk(v)
    return docid, v.entries


class Entry(object):
    """ One data entry extracted from a document: the id of its source, the
    name of the table it belongs to, its list of (key, value) pairs and a
    locator string telling where it was found. """

    def __init__(self, source, table, values, locator):
        self.source, self.table = source, table
        self.values, self.locator = values, locator

    def __str__(self):
        header = '[%s]  (%s)\n' % (self.table, self.source)
        body = ['  :%s: %s\n' % pair for pair in self.values]
        return header + ''.join(body)



class FindData(nodes.SparseNodeVisitor):
    """ A visitor that finds all the definition_list_item which match our
    desired tagging for format.

    A matching item has a term that consists of a bracketed tag (e.g. '[book]')
    and a definition whose only child is a field list. Each match is appended to
    'self.entries' as an Entry.
    """

    # Regexp for the definition tag.
    tagre = re.compile(r'[\[\{\(]([a-zA-Z0-9_]+)[\]\}\)]\s*$')

    def __init__(self, filename, docid, *args):
        nodes.SparseNodeVisitor.__init__(self, *args)
        self.filename = filename
        self.docid = docid
        self.entries = []

    def visit_term(self, node):
        """ Record an Entry if 'node' is the term of a tagged definition, and
        skip its subtree; do nothing for any other term. """
        if len(node.children) != 1:
            return
        mo = self.tagre.match(node.astext())
        if not mo:
            return

        table = str2table(mo.group(1))
        dlitem = node.parent
        if len(dlitem.children) != 2:
            return

        defn = dlitem.children[1]
        if not isinstance(defn, nodes.definition) or len(defn.children) != 1:
            return
        flist = defn.children[0]
        if not isinstance(flist, nodes.field_list):
            return

        fields = extract_fields(flist, self.document)
        source, line = utils.get_source_line(node)
        locator = '%s:%s' % (abspath(self.filename), line)

        if opts.enclosing_section:
            # 'fields' is a list of (key, value) pairs, so the keys have to be
            # inspected; a plain "'section' in fields" test never matches.
            if alist_memq(fields, 'section'):
                logging.warning("Entry at %s already has a 'section' field" %
                                locator)
            section = ''
            node_sec = find_enclosing_section(node)
            if node_sec is not None:
                i = node_sec.first_child_matching_class(nodes.title)
                # No title child is reported as None, not as an index.
                if i is not None:
                    section = node_sec.children[i].astext()
            fields.append( ('section', section) )

        e = Entry(self.docid, table, fields, locator)
        self.entries.append(e)

        raise nodes.SkipNode()

def find_enclosing_section(node):
    """ Follow the 'parent' links from 'node' (included) and return the first
    section node met, or None if the chain ends before one is found. """
    candidate = node
    while candidate and not isinstance(candidate, nodes.section):
        candidate = candidate.parent
    return candidate if candidate else None

def str2table(s):
    """ Turn 's' into a table name: lowercase it and replace every space with
    an underscore. """
    lowered = s.lower()
    return '_'.join(lowered.split(' '))

def extract_fields(node, document):
    """ Walk 'node' with an ExtractFields visitor and return the (key, value)
    pairs it collected, as a plain list. """
    collector = ExtractFields(document)
    node.walk(collector)
    return [pair for pair in collector]

class ExtractFields(nodes.SparseNodeVisitor, list):
    """ A visitor that is also the list of the (name, value) pairs it has
    collected from the fields it visited. """

    def visit_field_name(self, node):
        name = node.astext()
        if opts.raw:
            self.key = name
        else:
            # Keep only the first word of the field name, then sanitize it.
            self.key = sanitize_column(name.split()[0])

    def visit_field_body(self, node):
        text = node.astext()
        if not opts.raw:
            text = sanitize_value(text)
        # Pair the body with the name seen just before it.
        self.append( (self.key, text) )
        self.key = None


def sanitize_column(colname):
    """ Make a column name out of 'colname': strip and lowercase it, then
    replace every character that is not in a-z with an underscore. """
    normalized = colname.strip().lower()
    return re.sub('[^a-z]', '_', normalized)

def sanitize_value(value):
    """ Merge 'value' into a single line: strip every line, drop the ones that
    end up empty and join the rest with single spaces. """
    pieces = []
    for raw_line in value.splitlines():
        stripped = raw_line.strip()
        if stripped:
            pieces.append(stripped)
    return ' '.join(pieces)

def uniquify_keys(entries):
    """ Rename, in place, the repeated field keys of every entry so that all
    the keys of an entry are distinct. A repeated key 'k' becomes 'k_2', 'k_3',
    ... (the first suffix that is still free, up to 999)."""

    def free_name(key, taken):
        # First 'key_N' (N = 2..999) that is not taken yet.
        for suffix in range(2, 1000):
            candidate = '%s_%d' % (key, suffix)
            if candidate not in taken:
                return candidate
        raise RuntimeError("Internal error in uniquify_keys.")

    for entry in entries:
        taken = set()
        pending = []
        for index, (key, value) in enumerate(entry.values):
            name = key
            if name in taken:
                name = free_name(key, taken)
                pending.append( (index, (name, value)) )
            taken.add(name)
        # Apply the renamings only once the whole entry has been scanned.
        for index, pair in pending:
            entry.values[index] = pair

#-------------------------------------------------------------------------------
# Table definition inference code.

# Note: this is generic utils code.
def seq2dict(seq, classify_fun):
    """ Group the objects of the list or tuple 'seq' by the key that
    'classify_fun' computes for each of them. Returns a mapping from key to the
    list of objects with that key, in their original order. Objects that cannot
    be classified (an exception is raised) are left out."""
    assert isinstance(seq, (list, tuple)), seq
    groups = defaultdict(list)
    for item in seq:
        try:
            bucket = groups[classify_fun(item)]
        except Exception:
            continue
        bucket.append(item)
    return groups

def infer_tables(entries):
    """ Group 'entries' by table name and infer a table description for each
    group. Returns a dict of table name to table description. """
    by_table = seq2dict(entries, attrgetter('table'))
    models = {}
    for name, group in by_table.iteritems():
        models[name] = infer_table(group)
    return models

# Patterns that a value has to match, from its start to its end, to count as
# an integer or as a decimal number. The decimal pattern accepts at most one
# dot and requires at least one digit, so that strings such as '1.2.3' or '.'
# (which float() rejects) do not make a column numeric.
intre = re.compile(r'[0-9]+$')
floatre = re.compile(r'(?:[0-9]+\.?[0-9]*|\.[0-9]+)$')

# A column needs more than this many values to be given a non-text type.
min_nb_values_for_type = 5

def infer_table(entries):
    """ Infer a table description from entries that all belong to one table.

    Returns a list of (column-name, column-type) pairs. The source column comes
    first; the other columns are sorted by the sum of the positions at which
    they appear in the entries. A column is 'integer' or 'float' only if it has
    more than 'min_nb_values_for_type' values that all match the corresponding
    pattern, and 'text' otherwise.
    """

    def guess_type(samples):
        if len(samples) <= min_nb_values_for_type:
            return 'text'
        if all(intre.match(s) for s in samples):
            return 'integer'
        if all(floatre.match(s) for s in samples):
            return 'float'
        return 'text'

    # Gather the values and accumulate the positions of every column.
    values_by_column = defaultdict(list)
    position_sum = defaultdict(int)
    for entry in entries:
        position = 0
        for name, value in entry.values:
            values_by_column[name].append(value)
            position_sum[name] += position
            position += 1

    coldefs = {}
    for name, samples in values_by_column.iteritems():
        coldefs[name] = guess_type(samples)

    ordered = sorted(coldefs.items(), key=lambda item: position_sum[item[0]])
    return [(col_source, 'varchar(256)')] + ordered

def table2sql(table, tabledef):
    """ Generate SQL table definition code given the table name and columns
    definition.

    'tabledef' is a sequence of (column-name, column-type) pairs. The statement
    is returned as a string whose lines are joined with os.linesep, and that
    ends with a line separator. An empty 'tabledef' yields a table definition
    without any column.
    """
    columns = ['  "%s" %s' % coldef for coldef in tabledef]

    lines = ['CREATE TABLE "%s" (' % table]
    if columns:
        # Commas separate the columns; there is none after the last one.
        lines.append((',' + os.linesep).join(columns))
    lines.append(');')
    lines.append('')
    return os.linesep.join(lines)


def sanitize_model(infmodel, dbmodel):
    """ Adapt the inferred model 'infmodel' so that it fits 'dbmodel' better
    ('dbmodel' may be None). When a table has both a column and its plural
    (same name plus 's'), one of the two is removed from the inferred model and
    a renaming is recorded. Returns the modified inferred model and the data
    that sanitize_data() needs to apply the same changes to the entries."""

    plural_re = re.compile('(.*)s\s*$')

    column_renamings = []
    for table, columns in infmodel.iteritems():
        # Columns of this table in the database, if it is known there.
        try:
            dbcolumns = dict(dbmodel[table])
        except (KeyError, TypeError):
            dbcolumns = {}

        coldict = dict(columns)
        for name in coldict:
            mo = plural_re.match(name)
            if not mo:
                continue
            sing = mo.group(1)
            if sing not in coldict:
                continue

            # Both forms are present: keep the plural only if the database
            # has it and does not have the singular.
            plur = '%ss' % sing
            if plur in dbcolumns and sing not in dbcolumns:
                renaming = table, sing, plur
            else:
                renaming = table, plur, sing

            logging.warning("In table '%s', renaming %s to %s" % renaming)
            column_renamings.append(renaming)

            # The renamed column disappears from the model definition.
            alist_remove(columns, renaming[1])

    return infmodel, (column_renamings,)

def sanitize_data(entries_list, sani_info):
    """ Apply the sanitization changes required to make the data fit the model
    changes. This modifies the entries in-place and does not return anything.

    'sani_info' is the second value returned by sanitize_model(): a 1-tuple
    that holds the list of (table, from-column, to-column) renamings."""
    (column_renamings,) = sani_info

    # Rename all the columns we need to.
    for table, from_, to_ in column_renamings:
        for e in entries_list:
            # A renaming is decided for one table; the other tables keep
            # their column names, and so must their entries.
            if e.table != table:
                continue
            alist_replace(e.values, from_, to_)


#-------------------------------------------------------------------------------
# Database introspection.

def db_get_model(conn):
    """ Return a dict that maps the name of every table of the database to the
    list of its columns, as returned by db_get_table_columns(). """
    return dict((name, db_get_table_columns(conn, name))
                for name in db_get_tables(conn))

def db_get_tables(conn):
    """ Return the names of the tables of the 'public' schema, as a list. """
    query = """
          SELECT table_name FROM information_schema.tables
            WHERE table_schema = 'public';
            """
    with dbclosing(conn) as cursor:
        cursor.execute(query)
        names = [row[0] for row in cursor]
    return names

def db_get_table_columns(conn, table):
    """ Return the (column_name, data_type) rows of 'table' in the 'public'
    schema, as a list; an empty list if the database reports an error. """
    query = """
              SELECT column_name, data_type FROM information_schema.columns
                WHERE table_schema = 'public' AND
                      table_name = %s
                """
    try:
        with dbclosing(conn) as cursor:
            cursor.execute(query, (table,))
            return [row for row in cursor]
    except dbapi.Error:
        return []



#-------------------------------------------------------------------------------
# Filling up the database.

# Number of entries between two progress messages logged by store_entries().
progress_count = 100

def store_entries(entries_list, dbmodel, curs):
    """ Given a list of entries to be stored, try to store as much data as
    possible in the given database model.

    'dbmodel' maps table names to lists of (column-name, data-type) pairs and
    'curs' is the cursor used to run the statements. The rows previously stored
    for the sources of the entries are deleted first. Entries whose table is not
    in the model are skipped, and so are the fields that have no column in the
    model. One row is inserted per remaining entry; it always includes the
    source column.

    Raises NotImplementedError for a column whose data type is not supported;
    the int() and float() conversions raise ValueError on values that do not
    fit the type of their column.
    """

    # Supported column types. This covers both the names used in the schemas
    # generated by this program and the ones that information_schema reports
    # for them (e.g. 'double precision' for a column declared as 'float').
    text_types = ('text', 'varchar', 'char', 'character varying', 'character')
    int_types = ('integer', 'smallint', 'bigint')
    float_types = ('numeric', 'float', 'real', 'double precision')

    logging.info("  ... dropping old data")

    # Drop the old data from this document from the previous tables.
    tables = set(x.table for x in entries_list)
    sources = set(x.source for x in entries_list)
    for table in tables:
        if table not in dbmodel:
            continue
        for source in sources:
            curs.execute('DELETE FROM "%s" WHERE %s = %%s' % (table, col_source),
                         (source,))

    dbmodel = dict((k, dict(v)) for (k,v) in dbmodel.iteritems())
    for i, e in enumerate(entries_list):
        if i > 0 and (i % progress_count) == 0:
            logging.info("  ... progress: %d/%d" % (i, len(entries_list)))

        try:
            dbcols = dbmodel[e.table]
        except KeyError:
            continue # Table not available.

        outcols = [col_source]
        outvalues = [e.source]
        colseen = set()
        for cname, cvalue in e.values:
            # If the columns cannot be stored in the current model, don't.
            if cname not in dbcols:
                continue

            if cname in colseen:
                logging.warning("Duplicate field name at %s" % e.locator)
                continue
            colseen.add(cname)

            dbtype = dbcols[cname].lower()
            if dbtype in text_types:
                value = unicode(cvalue)
            elif dbtype in int_types:
                value = int(cvalue)
            elif dbtype in float_types:
                value = float(cvalue)
            else:
                raise NotImplementedError("Unsupported type: '%s'" % dbtype)

            outcols.append(cname)
            outvalues.append(value)

        # The source column is always present, so there is always a row to
        # write, even when none of the fields could be stored.
        curs.execute("""
              INSERT INTO "%s" (%s) VALUES (%s)
              """ % (e.table,
                     ', '.join('"%s"' % x for x in outcols),
                     ', '.join(['%s'] * len(outvalues))),
                     outvalues)


#-------------------------------------------------------------------------------
# Main program.

def parse_dburi(dburi):
    """ Parse the database connection URI and return a dict of the values that
    allow us to connect to a database.

    'dburi' is either SCHEME://[USER[:PASSWORD]@]HOST/DBNAME, where SCHEME is
    one of db, postgres, postgresql or psql, or just a database name. The
    result has the keys 'user', 'host', 'password' and 'database'. The host
    defaults to 'localhost', the user to the current user and the password to
    'pg'.

    Raises ValueError if 'dburi' is neither a valid URI nor a database name.
    """

    user, passwd, host, dbname = [None] * 4
    mo = re.match(r'(db|postgres|postgresql|psql)://(?:([^:@]+)'
                  r'(?::([^:@]+))?@)?([A-Za-z0-9_.-]+)/([A-Za-z0-9_-]+)/?$',
                  dburi)
    if mo:
        user, passwd, host, dbname = mo.group(2, 3, 4, 5)
    elif re.match(r'[A-Za-z0-9_-]+$', dburi):
        # The whole string has to be a name: a malformed URI must not be
        # mistaken for a database name.
        dbname = dburi
    else:
        raise ValueError("Invalid database connection string.")

    if host is None:
        host = 'localhost'
    if user is None:
        user = getpass.getuser()
    if passwd is None:
        # NOTE(review): hard-coded default password; consider prompting with
        # getpass.getpass() instead.
        passwd = 'pg' ## getpass.getpass()

    r = {'user': user,
         'host': host,
         'password': passwd,
         'database': dbname}

    assert None not in r.values(), r
    return r

def isdb(deststr):
    """ Return True if the destination string is for a database connection,
    i.e. if it starts with a lowercase scheme followed by '://', and False
    otherwise (including for an empty string or None). """
    return bool(deststr) and bool(re.match('[a-z]+://', deststr))


def main():
    """ Command-line entry point: parse the options, extract the entries from
    the given files, optionally recreate the database schema inferred from
    them, and store the entries in the database. """
    global opts

    import optparse
    parser = optparse.OptionParser(__doc__.strip())

    parser.add_option('-s', '--schema', action='store_true',
                      help="Infer the schema definition from the data and "
                      "recreate it, dropping the existing tables in the process.")

    parser.add_option('--schema-sql', action='store_true',
                      help="Output DDL commands.")

    parser.add_option('-r', '--raw', action='store_true',
                      help="Do not sanitize the values, column names, etc.")

    parser.add_option('-v', '--verbose', action='store_true',
                      help="Turn on verbose output.")

    parser.add_option('-e', '--enclosing-section', action='store_true',
                      help="Add a field for enclosing sections to each entry.")

    opts, args = parser.parse_args()
    if len(args) < 2:
        parser.error("You must provide a database URI and a list of filenames "
                     "to process.")

    logging.basicConfig(level=logging.INFO if opts.verbose else logging.WARNING,
                        format='%(levelname)-8s: %(message)s')

    dburi, args = args[0], args[1:]
    if not isdb(dburi):
        if not re.match('[a-zA-Z0-9]+', dburi):
            parser.error("Invalid database name: %s" % dburi)
        else:
            dburi = 'postgres://localhost/%s' % dburi

    #---------------------------------------------------------------------------
    logging.info("Opening database '%s'." % dburi)

    # Report a rejected connection string as a usage error, not a traceback.
    try:
        conn_params = parse_dburi(dburi)
    except ValueError:
        parser.error("Invalid database connection string.")
    conn = dbapi.connect(**conn_params)

    # Whatever happens from here on, release the connection.
    try:
        #-----------------------------------------------------------------------
        logging.info("Processing input files, extract all the data entries.")

        # Disable the conversion of system messages into text.
        nodes.system_message.astext = lambda *args: u''

        entries_list = []
        for fn in args:
            logging.info("  %s" % fn)
            docid, entries = get_file_entries(fn)
            entries_list.extend(entries)

        if len(entries_list) == 0:
            logging.info("No entries, exiting.")
            sys.exit(1)

        # Make sure that the values have unique key names.
        uniquify_keys(entries_list)

        #-----------------------------------------------------------------------
        logging.info("Inferring model (and read target model if necessary).")

        infmodel = infer_tables(entries_list)

        # If relevant, we read the schema of the target database in order to
        # skew the sanitization process for a better fit to it.
        dbmodel = db_get_model(conn)

        # Sanitize the model.
        if not opts.raw:
            infmodel, sani_info = sanitize_model(
                infmodel, dbmodel if not opts.schema else None)

        #-----------------------------------------------------------------------
        logging.info("Storing the schema.")

        if opts.schema:
            sqlcreate = ''.join(table2sql(table, tabledef) + '\n'
                                for table, tabledef in infmodel.iteritems())

            if opts.schema_sql:
                # Write the table creation commands to stdout.
                logging.warning("Outputting SQL schema creation commands.")
                sys.stdout.write(sqlcreate)

            # Drop the tables and apply to the given database.
            for table in infmodel:
                try:
                    with committing(conn) as curs:
                        curs.execute('DROP TABLE "%s" CASCADE' % table)
                except dbapi.Error:
                    pass # Ignore "table does not exist" errors.

            with committing(conn) as curs:
                curs.execute(sqlcreate)

            # The model read above predates the tables that were just
            # (re)created; read it again, otherwise the entries of the new
            # tables would be skipped when storing the data.
            dbmodel = db_get_model(conn)

        #-----------------------------------------------------------------------
        logging.info("Storing the data.")

        if not opts.raw:
            logging.info("  ... sanitizing the data")
            sanitize_data(entries_list, sani_info)

        logging.info("  ... writing to database")

        # Store everything in a single transaction.
        with committing(conn) as curs:
            store_entries(entries_list, dbmodel, curs)
    finally:
        conn.close()


# Run the program only when this file is executed as a script.
if __name__ == '__main__':
    main()

