# GromacsWrapper: top.py
# Copyright (c) 2012 Oliver Beckstein <orbeckst@gmail.com>
# Copyright (c) 2010 Tsjerk Wassenaar <tsjerkw@gmail.com>
# Released under the GNU Public License 3 (or higher, your choice)
# See the file COPYING for details.

"""
Gromacs TOP file format
=======================

Classes
-------
.. autoclass:: TOP
   :members:

.. autoclass:: SystemToGroTop
   :members:


History
-------

Sources adapted from code by Reza Salari https://github.com/resal81/PyTopol


Example: Read a processed.top file and scale charges
----------------------------------------------------

Run ``grompp -pp`` to produce a ``processed.top`` file from the ``conf.gro``, ``grompp.mdp`` and ``topol.top`` files::

  $ grompp -pp

This file now contains all the force-field information::

  from gromacs.fileformats import TOP
  top = TOP("processed.top")

Scale the LJ epsilon by an arbitrary number, here 0.9 ::

  scaling = 0.9
  for at in top.atomtypes:
    at.gromacs['param']['lje'] *= scaling

Write out the scaled down topology::

  top.write("output.top")

.. Note::

   You can use this to prepare a series of top files for Hamiltonian Replica
   Exchange (HREX) simulations. See ``scripts/gw-partial_tempering.py`` for an example.

"""

import textwrap
import logging
from collections import OrderedDict as odict

from . import blocks


class TOP(blocks.System):
    """Class to make a TOP object from a GROMACS processed.top file

    The force-field and molecules data is exposed as python object.

    .. Note::

       Only processed.top files generated by GROMACS 'grompp -pp'
       are supported - the usual topol.top files are not supported (yet!)

    """

    default_extension = "top"
    logger = logging.getLogger("gromacs.fileformats.TOP")

    def __init__(self, fname):
        """Build a TOP object by parsing a processed GROMACS topology.

        :Arguments:
          *fname*
              path to a processed topology file (as written by ``grompp -pp``)

        After parsing, :attr:`molecules` is stored as a tuple.
        """
        super(TOP, self).__init__()

        self.fname = fname

        # Values of the [ defaults ] directive; each stays None unless the
        # parser finds and fills in that field.
        self.defaults = dict.fromkeys(
            ("nbfunc", "comb-rule", "gen-pairs", "fudgeLJ", "fudgeQQ")
        )

        self.dict_molname_mol = odict()  # moleculetype name -> Molecule
        self.found_sections = []  # section names, in the order encountered
        self.forcefield = "gromacs"

        # The parser appends to a mutable list; freeze it once parsing is done.
        self.molecules = []
        self._parse(fname)
        self.molecules = tuple(self.molecules)

    def write(self, filename):
        """Write this topology out as a GROMACS ``.top`` file.

        The serialization is delegated to :class:`SystemToGroTop`, which is
        given this object and the target file name.

        :Arguments:
          *filename*
              path of the topology file to write
        """
        SystemToGroTop(self, filename)

    def __repr__(self):
        """Summarize the parsed topology as two plain-text tables.

        The first table ("Param types:") counts the system-wide parameter
        types: atom, pair, bond, angle, dihedral and improper types.  The
        second table ("Params:") has one row per molecule type, ordered by
        name, counting that molecule's atoms, pairs, bonds, angles, dihedrals
        and impropers.  Any other section recorded in the relevant
        ``information`` mapping is appended to the row as ``name (count)``.
        """
        header_fmt = "{0:>20s}  {1:>7s} {2:>7s} {3:>7s} {4:>7s} {5:>7s} {6:>7s}"
        row_fmt = "{0:20s}  {1:7d} {2:7d} {3:7d} {4:7d} {5:7d} {6:7d}    {7:s}"
        columns = ("atom", "pair", "bond", "ang", "dih", "imp")
        rule = "=" * 69

        def _untabulated(information, tabulated):
            # Sections without a dedicated column, rendered as "name (count)".
            return " ".join(
                "{0:s} ({1:d})".format(section, len(entries))
                for section, entries in information.items()
                if section not in tabulated
            )

        sorted_molnames = sorted(self.dict_molname_mol)

        # --- system-wide parameter types ---
        system_extra = _untabulated(
            self.information,
            {"atomtypes", "pairtypes", "bondtypes", "angletypes", "dihedraltypes"},
        )
        type_counts = [
            len(self.atomtypes),
            len(self.pairtypes),
            len(self.bondtypes),
            len(self.angletypes),
            len(self.dihedraltypes),
            len(self.impropertypes),
        ]

        lines = ["\n", header_fmt.format("Param types:", *columns), rule]
        lines.append(row_fmt.format("", *type_counts, system_extra))
        lines.append("\n")

        # --- per-molecule-type parameters ---
        molecule_sections = {"atoms", "pairs", "bonds", "angles", "dihedrals"}
        lines.extend([header_fmt.format("Params:", *columns), rule])
        for molname in sorted_molnames:
            mol = self.dict_molname_mol[molname]
            mol_extra = _untabulated(mol.information, molecule_sections)
            counts = [
                len(mol.atoms),
                len(mol.pairs),
                len(mol.bonds),
                len(mol.angles),
                len(mol.dihedrals),
                len(mol.impropers),
            ]
            lines.append(row_fmt.format(mol.name, *counts, mol_extra))

        return "\n".join(lines)

    def _parse(self, fname):
        """Parse a processed.top GROMACS topology file

        The function reads in the file line-by-line, and it's a bunch of 'elif' statements,
        writing parameter/atom line to current section/molecule.

        ParamTypes are added to self.xyztypes (AtomType goes to self.atomtypes).
        Params are added to current molecule (Atom goes to mol.atoms.append(atom))

        MoleculeTypes and Molecules are odd, and are added to
            * MoleculeType to :attr:`self.dict_molname_mol[mol.name] = mol`
            * Molecule to :attr:`self.molecules.append(self.dict_molname_mol[mname])`

        :obj:`curr_sec` variable stores to current section being read-in
        :obj:`mol` variable stores the current molecule being read-in
        :obj:`cmap_lines` are a little odd, since CMAP parameters are stored on multiple lines

        :Arguments:
          *fname*
              name of the processed.top file

        :Returns: None
        """

        def _find_section(line):
            return line.strip("[").strip("]").strip()

        def _add_info(sys_or_mol, section, container):
            # like (mol, 'atomtypes', mol.atomtypes)
            if sys_or_mol.information.get(section, False) is False:
                sys_or_mol.information[section] = container

        mol = None  # to hold the current mol
        curr_sec = None
        cmap_lines = []

        with open(fname) as f:
            for i_line, line in enumerate(f):
                # trimming
                if ";" in line:
                    line = line[0 : line.index(";")]
                line = line.strip()

                if line == "":
                    continue

                if line[0] == "*":
                    continue

                # the topology must be stand-alone (i.e. no includes)
                if line.startswith("#include"):
                    msg = 'The topology file has "#include" statements.'
                    msg += " You must provide a processed topology file that grompp creates."
                    raise ValueError(msg)

                # find sections
                if line[0] == "[":
                    curr_sec = _find_section(line)
                    self.found_sections.append(curr_sec)
                    continue

                fields = line.split()

                if curr_sec == "defaults":
                    """
                    # ; nbfunc        comb-rule       gen-pairs       fudgeLJ fudgeQQ
                    #1               2               yes             0.5     0.8333
                    """
                    assert len(fields) in [2, 5]
                    self.defaults["nbfunc"] = int(fields[0])
                    self.defaults["comb-rule"] = int(fields[1])
                    if len(fields) == 5:
                        self.defaults["gen-pairs"] = fields[2]
                        self.defaults["fudgeLJ"] = float(fields[3])
                        self.defaults["fudgeQQ"] = float(fields[4])

                elif curr_sec == "atomtypes":
                    """
                    # ;name               at.num    mass         charge    ptype  sigma   epsilon
                    # ;name   bond_type   at.num    mass         charge    ptype  sigma   epsilon
                    # ;name                         mass         charge    ptype  c6      c12

                    """
                    if len(fields) not in (6, 7, 8):
                        self.logger.warning(
                            "skipping atomtype line with neither 7 or 8 fields: \n {0:s}".format(
                                line
                            )
                        )
                        continue

                    # shift = 0 if len(fields) == 7 else 1
                    shift = len(fields) - 7
                    at = blocks.AtomType("gromacs")
                    at.atype = fields[0]
                    if shift == 1:
                        at.bond_type = fields[1]

                    at.mass = float(fields[2 + shift])
                    at.charge = float(fields[3 + shift])

                    particletype = fields[4 + shift]
                    assert particletype in ("A", "S", "V", "D")
                    if particletype not in ("A",):
                        self.logger.warning(
                            'warning: non-atom particletype: "{0:s}"'.format(line)
                        )

                    sig = float(fields[5 + shift])
                    eps = float(fields[6 + shift])

                    at.gromacs = {
                        "param": {"lje": eps, "ljl": sig, "lje14": None, "ljl14": None}
                    }

                    self.atomtypes.append(at)

                    _add_info(self, curr_sec, self.atomtypes)

                # extend system.molecules
                elif curr_sec == "moleculetype":
                    assert len(fields) == 2

                    mol = blocks.Molecule()

                    mol.name = fields[0]
                    mol.exclusion_numb = int(fields[1])

                    self.dict_molname_mol[mol.name] = mol

                elif curr_sec == "atoms":
                    """
                    #id    at_type     res_nr  residu_name at_name  cg_nr  charge   mass  typeB    chargeB      massB
                    # 1       OC          1       OH          O1       1      -1.32

                    OR

                    [ atoms ]
                    ; id   at type  res nr  residu name at name     cg nr   charge
                    1       OT      1       SOL              OW             1       -0.834

                    """

                    aserial = int(fields[0])
                    atype = fields[1]
                    resnumb = int(fields[2])
                    resname = fields[3]
                    aname = fields[4]
                    cgnr = int(fields[5])
                    charge = float(fields[6])
                    rest = fields[7:]

                    atom = blocks.Atom()
                    atom.name = aname
                    atom.atomtype = atype
                    atom.number = aserial
                    atom.resname = resname
                    atom.resnumb = resnumb
                    atom.charge = charge

                    if rest:
                        mass = float(rest[0])
                        atom.mass = mass

                    mol.atoms.append(atom)

                    _add_info(mol, curr_sec, mol.atoms)

                elif curr_sec in ("pairtypes", "pairs", "pairs_nb"):
                    """
                    section     #at     fu      #param
                    ---------------------------------
                    pairs       2       1       V,W
                    pairs       2       2       fudgeQQ, qi, qj, V, W
                    pairs_nb    2       1       qi, qj, V, W

                    """

                    ai, aj = fields[:2]
                    fu = int(fields[2])
                    assert fu in (1, 2)

                    pair = blocks.InteractionType("gromacs")
                    if fu == 1:
                        if curr_sec == "pairtypes":
                            pair.atype1 = ai
                            pair.atype2 = aj
                            v, w = list(map(float, fields[3:5]))
                            pair.gromacs = {
                                "param": {
                                    "lje": None,
                                    "ljl": None,
                                    "lje14": w,
                                    "ljl14": v,
                                },
                                "func": fu,
                            }

                            self.pairtypes.append(pair)
                            _add_info(self, curr_sec, self.pairtypes)

                        elif curr_sec == "pairs":
                            ai, aj = list(map(int, [ai, aj]))
                            pair.atom1 = mol.atoms[ai - 1]
                            pair.atom2 = mol.atoms[aj - 1]
                            pair.gromacs["func"] = fu

                            mol.pairs.append(pair)
                            _add_info(mol, curr_sec, mol.pairs)

                        else:
                            raise ValueError

                    else:
                        raise NotImplementedError(
                            "{0:s} with functiontype {1:d} is not supported".format(
                                curr_sec, fu
                            )
                        )

                elif curr_sec == "nonbond_params":
                    """
                    ; typei typej  f.type sigma   epsilon
                    ; f.type=1 means LJ (not buckingham)
                    ; sigma&eps since mixing-rule = 2
                    """

                    assert len(fields) == 5
                    ai, aj = fields[:2]
                    fu = int(fields[2])

                    assert fu == 1
                    sig = float(fields[3])
                    eps = float(fields[4])

                    nonbond_param = blocks.NonbondedParamType("gromacs")
                    nonbond_param.atype1 = ai
                    nonbond_param.atype2 = aj
                    nonbond_param.gromacs["func"] = fu
                    nonbond_param.gromacs["param"] = {"eps": eps, "sig": sig}

                    self.nonbond_params.append(nonbond_param)
                    _add_info(self, curr_sec, self.nonbond_params)

                elif curr_sec in ("bondtypes", "bonds"):
                    """
                    section     #at     fu      #param
                    ----------------------------------
                    bonds       2       1       2
                    bonds       2       2       2
                    bonds       2       3       3
                    bonds       2       4       2
                    bonds       2       5       ??
                    bonds       2       6       2
                    bonds       2       7       2
                    bonds       2       8       ??
                    bonds       2       9       ??
                    bonds       2       10      4
                    """

                    ai, aj = fields[:2]
                    fu = int(fields[2])
                    assert fu in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

                    if fu != 1:
                        raise NotImplementedError(
                            "function {0:d} is not yet supported".format(fu)
                        )

                    bond = blocks.BondType("gromacs")

                    if fu == 1:
                        if curr_sec == "bondtypes":
                            bond.atype1 = ai
                            bond.atype2 = aj

                            b0, kb = list(map(float, fields[3:5]))
                            bond.gromacs = {"param": {"kb": kb, "b0": b0}, "func": fu}

                            self.bondtypes.append(bond)
                            _add_info(self, curr_sec, self.bondtypes)

                        elif curr_sec == "bonds":
                            ai, aj = list(map(int, [ai, aj]))
                            bond.atom1 = mol.atoms[ai - 1]
                            bond.atom2 = mol.atoms[aj - 1]
                            bond.gromacs["func"] = fu

                            if len(fields) > 3:
                                b0, kb = list(map(float, fields[3:5]))
                                bond.gromacs = {
                                    "param": {"kb": kb, "b0": b0},
                                    "func": fu,
                                }

                            mol.bonds.append(bond)
                            _add_info(mol, curr_sec, mol.bonds)

                    else:
                        raise NotImplementedError

                elif curr_sec in {"angletypes", "angles"}:
                    """
                    section     #at     fu      #param
                    ----------------------------------
                    angles      3       1       2
                    angles      3       2       2
                    angles      3       3       3
                    angles      3       4       4
                    angles      3       5       4
                    angles      3       6       6
                    angles      3       8       ??
                    """
                    ai, aj, ak = fields[:3]

                    fu = blocks.AngleFunctionType(int(fields[3]))

                    if len(fields[4:]) != 0 and len(fields[4:]) != fu.num_params:
                        raise ValueError(
                            "Expected {num_params} parameters for function type {fu}, got {len(fields[4:])}".format(
                                num_params=fu.num_params, fu=fu
                            )
                        )

                    ang = blocks.AngleType("gromacs")
                    ang.gromacs = {"func": fu}

                    if curr_sec == "angletypes":
                        ang.atype1 = ai
                        ang.atype2 = aj
                        ang.atype3 = ak
                    elif curr_sec == "angles":
                        ai, aj, ak = list(map(int, [ai, aj, ak]))
                        ang.atom1 = mol.atoms[ai - 1]
                        ang.atom2 = mol.atoms[aj - 1]
                        ang.atom3 = mol.atoms[ak - 1]

                    # Parse parameters based on the function type
                    params = list(map(float, fields[4:]))

                    # Handle parameters based on function types
                    if fu.num_params == len(params):
                        if fu in {
                            blocks.AngleFunctionType.HARMONIC,
                            blocks.AngleFunctionType.G96_ANGLE,
                        }:
                            ang.gromacs["param"] = {
                                "tetha0": params[0],
                                "ktetha": params[1],
                            }
                        elif fu == blocks.AngleFunctionType.CROSS_BOND_BOND:
                            ang.gromacs["param"] = {
                                "r1e": params[0],
                                "r2e": params[1],
                                "krrprime": params[2],
                            }
                        elif fu == blocks.AngleFunctionType.CROSS_BOND_ANGLE:
                            ang.gromacs["param"] = {
                                "r1e": params[0],
                                "r2eprime": params[1],
                                "r3e": params[2],
                                "krtheta": params[3],
                            }
                        elif fu == blocks.AngleFunctionType.UREY_BRADLEY:
                            ang.gromacs["param"] = {
                                "tetha0": params[0],
                                "ktetha": params[1],
                                "s0": params[2],
                                "kub": params[3],
                            }
                        elif fu == blocks.AngleFunctionType.QUARTIC_ANGLE:
                            ang.gromacs["param"] = {
                                "tetha0": params[0],
                                "C1": params[1],
                                "C2": params[2],
                                "C3": params[3],
                                "C4": params[4],
                                "C5": params[5],
                            }
                        elif fu == blocks.AngleFunctionType.TABULATED_ANGLE:
                            ang.gromacs["param"] = {
                                "table_number": params[0],
                                "k": params[1],
                            }  # Assuming 'table number' is a parameter here
                        elif fu == blocks.AngleFunctionType.LINEAR_ANGLE:
                            ang.gromacs["param"] = {
                                "a0": params[0],
                                "klin": params[1],
                            }
                        elif fu == blocks.AngleFunctionType.RESTRICTED_BENDING:
                            ang.gromacs["param"] = {
                                "tetha0": params[0],
                                "ktheta": params[1],
                            }
                        else:
                            raise NotImplementedError(
                                "Function type {fu} is not implemented".forma(fu=fu)
                            )

                    # Add the angle to the appropriate list and call _add_info
                    if curr_sec == "angletypes":
                        self.angletypes.append(ang)
                        _add_info(self, curr_sec, self.angletypes)
                    elif curr_sec == "angles":
                        mol.angles.append(ang)
                        _add_info(mol, curr_sec, mol.angles)
                    else:
                        raise ValueError("Unknown section while parsing angles")

                elif curr_sec in ("dihedraltypes", "dihedrals"):
                    """
                    section     #at     fu      #param
                    ----------------------------------
                    dihedrals   4       1       3
                    dihedrals   4       2       2
                    dihedrals   4       3       6
                    dihedrals   4       4       3
                    dihedrals   4       5       4
                    dihedrals   4       8       ??
                    dihedrals   4       9       3
                    """

                    if curr_sec == "dihedraltypes" and len(fields) == 6:
                        # in oplsaa - quartz parameters
                        fields.insert(2, "X")
                        fields.insert(0, "X")

                    ai, aj, ak, am = fields[:4]
                    fu = int(fields[4])
                    assert fu in (1, 2, 3, 4, 5, 8, 9)

                    if fu not in (1, 2, 3, 4, 9):
                        raise NotImplementedError(
                            "dihedral function {0:d} is not yet supported".format(fu)
                        )

                    dih = blocks.DihedralType("gromacs")
                    imp = blocks.ImproperType("gromacs")
                    # proper dihedrals
                    if fu in (1, 3, 9):
                        if curr_sec == "dihedraltypes":
                            dih.atype1 = ai
                            dih.atype2 = aj
                            dih.atype3 = ak
                            dih.atype4 = am

                            dih.line = i_line + 1

                            if fu == 1:
                                delta, kchi, n = list(map(float, fields[5:8]))
                                dih.gromacs["param"].append(
                                    {"kchi": kchi, "n": n, "delta": delta}
                                )
                            elif fu == 3:
                                c0, c1, c2, c3, c4, c5 = list(map(float, fields[5:11]))
                                m = dict(c0=c0, c1=c1, c2=c2, c3=c3, c4=c4, c5=c5)
                                dih.gromacs["param"].append(m)
                            elif fu == 4:
                                delta, kchi, n = list(map(float, fields[5:8]))
                                dih.gromacs["param"].append(
                                    {"kchi": kchi, "n": int(n), "delta": delta}
                                )
                            elif fu == 9:
                                delta, kchi, n = list(map(float, fields[5:8]))
                                dih.gromacs["param"].append(
                                    {"kchi": kchi, "n": int(n), "delta": delta}
                                )
                            else:
                                raise ValueError

                            dih.gromacs["func"] = fu
                            self.dihedraltypes.append(dih)
                            _add_info(self, curr_sec, self.dihedraltypes)

                        elif curr_sec == "dihedrals":
                            ai, aj, ak, am = list(map(int, fields[:4]))
                            dih.atom1 = mol.atoms[ai - 1]
                            dih.atom2 = mol.atoms[aj - 1]
                            dih.atom3 = mol.atoms[ak - 1]
                            dih.atom4 = mol.atoms[am - 1]
                            dih.gromacs["func"] = fu

                            dih.line = i_line + 1

                            if fu == 1:
                                delta, kchi, n = list(map(float, fields[5:8]))
                                dih.gromacs["param"].append(
                                    {"kchi": kchi, "n": int(n), "delta": delta}
                                )
                            elif fu == 3:
                                pass
                            elif fu == 4:
                                pass
                            elif fu == 9:
                                if len(fields[5:8]) == 3:
                                    delta, kchi, n = list(map(float, fields[5:8]))
                                    dih.gromacs["param"].append(
                                        {"kchi": kchi, "n": int(n), "delta": delta}
                                    )
                            else:
                                raise ValueError

                            mol.dihedrals.append(dih)
                            _add_info(mol, curr_sec, mol.dihedrals)

                        else:
                            raise ValueError
                    # impropers
                    elif fu in (2, 4):
                        if curr_sec == "dihedraltypes":
                            imp.atype1 = ai
                            imp.atype2 = aj
                            imp.atype3 = ak
                            imp.atype4 = am

                            imp.line = i_line + 1

                            if fu == 2:
                                psi0, kpsi = list(map(float, fields[5:7]))
                                imp.gromacs["param"].append(
                                    {"kpsi": kpsi, "psi0": psi0}
                                )
                            elif fu == 4:
                                psi0, kpsi, n = list(map(float, fields[5:8]))
                                imp.gromacs["param"].append(
                                    {"kpsi": kpsi, "psi0": psi0, "n": int(n)}
                                )
                            else:
                                raise ValueError

                            imp.gromacs["func"] = fu
                            self.impropertypes.append(imp)
                            _add_info(self, curr_sec, self.impropertypes)

                        elif curr_sec == "dihedrals":
                            ai, aj, ak, am = list(map(int, fields[:4]))
                            imp.atom1 = mol.atoms[ai - 1]
                            imp.atom2 = mol.atoms[aj - 1]
                            imp.atom3 = mol.atoms[ak - 1]
                            imp.atom4 = mol.atoms[am - 1]
                            imp.gromacs["func"] = fu

                            imp.line = i_line + 1

                            if fu == 2:
                                pass
                            elif fu == 4:
                                # in-line override of dihedral parameters
                                if len(fields[5:8]) == 3:
                                    psi0, kpsi, n = list(map(float, fields[5:8]))
                                    imp.gromacs["param"].append(
                                        {"kpsi": kpsi, "psi0": psi0, "n": int(n)}
                                    )
                            else:
                                raise ValueError

                            mol.impropers.append(imp)
                            _add_info(mol, curr_sec, mol.impropers)

                        else:
                            raise ValueError

                    else:
                        raise NotImplementedError

                elif curr_sec in ("cmaptypes", "cmap"):
                    cmap = blocks.CMapType("gromacs")
                    if curr_sec == "cmaptypes":
                        cmap_lines.append(line)
                        _add_info(self, curr_sec, self.cmaptypes)
                    else:
                        ai, aj, ak, am, an = list(map(int, fields[:5]))
                        fu = int(fields[5])
                        assert fu == 1
                        cmap.atom1 = mol.atoms[ai - 1]
                        cmap.atom2 = mol.atoms[aj - 1]
                        cmap.atom3 = mol.atoms[ak - 1]
                        cmap.atom4 = mol.atoms[am - 1]
                        cmap.atom8 = mol.atoms[an - 1]
                        cmap.gromacs["func"] = fu

                        mol.cmaps.append(cmap)
                        _add_info(mol, curr_sec, mol.cmaps)

                elif curr_sec == "settles":
                    """
                    section     #at     fu      #param
                    ----------------------------------
                    """

                    assert len(fields) == 4
                    ai = int(fields[0])
                    fu = int(fields[1])
                    assert fu == 1

                    settle = blocks.SettleType("gromacs")
                    settle.atom = mol.atoms[ai - 1]
                    settle.dOH = float(fields[2])
                    settle.dHH = float(fields[3])

                    mol.settles.append(settle)
                    _add_info(mol, curr_sec, mol.settles)

                elif curr_sec == "virtual_sites3":
                    """
                    ; Dummy from            funct   a       b
                    4   1   2   3   1   0.131937768 0.131937768
                    """
                    assert len(fields) == 7
                    ai = int(fields[0])
                    aj = int(fields[1])
                    ak = int(fields[2])
                    al = int(fields[3])
                    fu = int(fields[4])
                    assert fu == 1
                    a = float(fields[5])
                    b = float(fields[6])

                    vs3 = blocks.VirtualSites3Type("gromacs")
                    vs3.atom1 = ai
                    vs3.atom2 = aj
                    vs3.atom3 = ak
                    vs3.atom4 = al
                    vs3.gromacs["func"] = fu
                    vs3.gromacs["param"] = {"a": a, "b": b}
                    mol.virtual_sites3.append(vs3)
                    _add_info(mol, curr_sec, mol.virtual_sites3)

                elif curr_sec in ("exclusions",):
                    ai = int(fields[0])
                    other = list(map(int, fields[1:]))

                    exc = blocks.Exclusion()
                    exc.main_atom = mol.atoms[ai - 1]
                    exc.other_atoms = [mol.atoms[k - 1] for k in other]

                    mol.exclusions.append(exc)
                    _add_info(mol, curr_sec, mol.exclusions)

                elif curr_sec in ("constrainttypes", "constraints"):
                    """
                    section     #at     fu      #param
                    ----------------------------------
                    constraints 2       1       1
                    constraints 2       2       1
                    """

                    ai, aj = fields[:2]
                    fu = int(fields[2])
                    assert fu in (1, 2)

                    cons = blocks.ConstraintType("gromacs")

                    # TODO: what's different between 1 and 2
                    if fu in [1, 2]:
                        if curr_sec == "constrainttypes":
                            cons.atype1 = ai
                            cons.atype2 = aj
                            b0 = float(fields[3])
                            cons.gromacs = {"param": {"b0": b0}, "func": fu}

                            self.constrainttypes.append(cons)
                            _add_info(self, curr_sec, self.constrainttypes)

                        elif curr_sec == "constraints":
                            ai, aj = list(map(int, fields[:2]))
                            cons.atom1 = mol.atoms[ai - 1]
                            cons.atom2 = mol.atoms[aj - 1]
                            cons.gromacs["func"] = fu

                            mol.constraints.append(cons)
                            _add_info(mol, curr_sec, mol.constraints)

                        else:
                            raise ValueError
                    else:
                        raise ValueError

                elif curr_sec in (
                    "position_restraints",
                    "distance_restraints",
                    "dihedral_restraints",
                    "orientation_restraints",
                    "angle_restraints",
                    "angle_restraints_z",
                ):
                    pass

                elif curr_sec in ("implicit_genborn_params",):
                    """
                    attype   sar     st      pi      gbr      hct
                    """
                    pass

                elif curr_sec == "system":
                    # assert len(fields) == 1
                    self.name = fields[0]

                elif curr_sec == "molecules":
                    assert len(fields) == 2
                    mname, nmol = fields[0], int(fields[1])

                    # if the number of a molecule is more than 1, add copies to system.molecules
                    for i in range(nmol):
                        self.molecules.append(self.dict_molname_mol[mname])

                else:
                    raise NotImplementedError(
                        "Unknown section in topology: {0}".format(curr_sec)
                    )

        # process cmap_lines
        curr_cons = None
        for line in cmap_lines:
            # cmaptype opening line
            if len(line.split()) == 8:
                cons = blocks.CMapType("gromacs")

                (
                    atype1,
                    atype2,
                    atype3,
                    atype4,
                    atype8,
                    func,
                    sizeX,
                    sizeY,
                ) = line.replace("\\", "").split()
                func, sizeX, sizeY = int(func), int(sizeX), int(sizeY)
                cons.atype1 = atype1
                cons.atype2 = atype2
                cons.atype3 = atype3
                cons.atype4 = atype4
                cons.atype8 = atype8
                cons.gromacs = {"param": [], "func": func}

                curr_cons = cons

            # cmap body
            elif len(line.split()) == 10:
                cmap_param = map(float, line.replace("\\", "").split())
                cons.gromacs["param"] += cmap_param

            # cmaptype cloning line
            elif len(line.split()) == 6:
                cmap_param = map(float, line.replace("\\", "").split())
                cons.gromacs["param"] += cmap_param
                self.cmaptypes.append(curr_cons)
            else:
                raise ValueError


class SystemToGroTop(object):
    """Converter class - represent TOP objects as GROMACS topology file.

    Instantiating the class immediately renders the topology and writes it to
    *outfile* (see :meth:`assemble_topology`).  Each ``_make_*`` method renders
    one topology section and returns a list of strings; the strings are joined
    and substituted for the matching ``*PLACEHOLDER*`` in :attr:`toptemplate`
    (force-field sections) or :attr:`itptemplate` (per-molecule sections).
    """

    formats = {
        "atomtypes": "{:<7s} {:3s} {:>7} {} {:3s} {} {}\n",
        "atoms": "{:6d} {:>10s} {:6d} {:6s} {:6s} {:6d} {} {}\n",
        "atoms_nomass": "{:6d} {:>10s} {:6d} {:6s} {:6s} {:6d} {}\n",
        "nonbond_params": "{:20s}  {:20s}  {:1d}  {}  {}\n",
        "bondtypes": "{:5s}  {:5s}  {:1d}  {}  {}\n",
        "bonds": "{:3d}  {:3d}   {:1d}\n",
        "bonds_ext": "{:3d}  {:3d}   {:1d} {} {}\n",
        "settles": "{:3d}  {:3d}  {} {}\n",
        "virtual_sites3": "{:3d}  {:3d}  {:3d}  {:3d}   {:1d}  {}  {}\n",
        "exclusions": "{:3d}  {}\n",
        "pairtypes": "{:6s} {:6s}   {:d}    {:.13g}     {:.13g}\n",
        "pairs": "{:3d} {:3d}   {:1d}\n",
        "angletypes_1": "{:>8s} {:>8s} {:>8s} {:1d}    {}    {}\n",
        "angletypes_5": "{:>8s} {:>8s} {:>8s} {:1d}    {}    {}    {}    {}\n",
        "constrainttypes": "{:6s} {:6s} {:1d}    {}\n",
        "angles": "{:3d} {:3d} {:3d}   {:1d}\n",
        "angles_ext": "{:3d} {:3d} {:3d}   {:1d} {} {}\n",
        # func 5 angles with explicit parameters; same parameter order as
        # angletypes_5: tetha0, ktetha, s0, kub
        "angles_ext_5": "{:3d} {:3d} {:3d}   {:1d} {} {} {} {}\n",
        "dihedraltypes": "{:6s} {:6s} {:6s} {:6s}   {:1d}    {}    {}    {:1d}\n",
        "dihedrals": "{:3d} {:3d} {:3d} {:3d}   {:1d}\n",
        "dihedrals_ext": "{:3d} {:3d} {:3d} {:3d}   {:1d}    {}    {}    {:1d}\n",
        "impropertypes_2": "{:6s} {:6s} {:6s} {:6s}   {:1d} {} {} \n",
        "impropertypes_4": "{:6s} {:6s} {:6s} {:6s}   {:1d} {} {} {:2d}\n",
        "impropers": "{:3d} {:3d} {:3d} {:3d}   {:1d}\n",
        "impropers_2": "{:3d} {:3d} {:3d} {:3d}   {:1d} {} {} \n",
        "impropers_4": "{:3d} {:3d} {:3d} {:3d}   {:1d} {} {} {:2d}\n",
    }

    toptemplate = """
            [ defaults ]
            *DEFAULTS*
            [ atomtypes ]
            *ATOMTYPES*
            [ nonbond_params ]
            *NONBOND_PARAM*
            [ pairtypes ]
            *PAIRTYPES*
            [ bondtypes ]
            *BONDTYPES*
            [ angletypes ]
            *ANGLETYPES*
            [ constrainttypes ]
            *CONSTRAINTTYPES*
            [ dihedraltypes ]
            *DIHEDRALTYPES*
            [ dihedraltypes ]
            *IMPROPERTYPES*
            [ cmaptypes ]
            *CMAPTYPES*
            """
    toptemplate = textwrap.dedent(toptemplate)

    itptemplate = """
            [ moleculetype ]
            *MOLECULETYPE*
            [ atoms ]
            *ATOMS*
            [ bonds ]
            *BONDS*
            [ pairs ]
            *PAIRS*
            [ settles ]
            *SETTLES*
            [ virtual_sites3 ]
            *VIRTUAL_SITES3*
            [ exclusions ]
            *EXCLUSIONS*
            [ angles ]
            *ANGLES*
            [ dihedrals ]
            *DIHEDRALS*
            [ dihedrals ]
            *IMPROPERS*
            [ cmap ]
            *CMAPS*
            """
    itptemplate = textwrap.dedent(itptemplate)

    # default CMAP grid edge length, used when it cannot be inferred from the
    # number of stored CMAP parameters
    _default_cmap_grid = 24

    def __init__(self, system, outfile="output.top", multiple_output=False):
        """Initialize GROMACS topology writer and write the topology.

        :Arguments:
          *system*
              :class:`blocks.System` object, containing the topology
          *outfile*
              name of the file to write to
          *multiple_output*
              if True, write moleculetypes to separate files, named
              mol_MOLNAME.itp, placed in the same directory as *outfile*
              (default: False)
        """
        self.logger = logging.getLogger("gromacs.fileformats.SystemToGroTop")
        self.logger.debug(">> entering SystemToGroTop")

        self.system = system
        self.outfile = outfile
        self.multiple_output = multiple_output
        self.assemble_topology()

        self.logger.debug("<< leaving SystemToGroTop")

    @staticmethod
    def _redefine_atomtypes(mol):
        """Rename the atom type of every atom in *mol* to ``at001``, ``at002``, ..."""
        for i, atom in enumerate(mol.atoms):
            atom.atomtype = "at{0:03d}".format(i + 1)

    @staticmethod
    def _fill_template(template, sections, obj):
        """Substitute the placeholders of *template* one after another.

        *sections* is a sequence of ``(placeholder, make_lines)`` pairs;
        ``make_lines(obj)`` returns an iterable of strings that are joined and
        substituted for *placeholder*.  The pairs are processed in order, so
        the ``_make_*`` methods run in the same order as they are listed.
        """
        for placeholder, make_lines in sections:
            template = template.replace(placeholder, "".join(make_lines(obj)))
        return template

    @staticmethod
    def _count_molecule_runs(molecules):
        """Collapse consecutive molecules with equal ``name`` into ``[name, count]``.

        Order is preserved and non-adjacent repeats stay separate, so the
        ``[ molecules ]`` section follows ``system.molecules``; e.g. names
        W, W, NA, W give ``[["W", 2], ["NA", 1], ["W", 1]]``.
        """
        runs = []
        for mol in molecules:
            if runs and runs[-1][0] == mol.name:
                runs[-1][1] += 1
            else:
                runs.append([mol.name, 1])
        return runs

    @staticmethod
    def _with_inline_comment(line, comment):
        """Insert *comment* right before the final character (the newline) of *line*."""
        return line[:-1] + comment + line[-1:]

    def assemble_topology(self):
        """Call the various member self._make_* functions to convert the topology object into a string

        The force-field sections fill :attr:`toptemplate`; every molecule type
        in ``system.dict_molname_mol`` is rendered with :attr:`itptemplate`.
        With ``multiple_output`` each molecule type goes to ``mol_MOLNAME.itp``
        in the directory of :attr:`outfile` (so the relative ``#include``
        written into the main file resolves) instead of being inlined.
        Finally the ``[system]``/``[molecules]`` sections are appended and the
        result is written to :attr:`outfile`.
        """
        import os  # only needed to place the per-molecule itp files

        self.logger.debug("starting to assemble topology...")

        self.logger.debug("making atom/pair/bond/angle/dihedral/improper types")
        forcefield_sections = (
            ("*DEFAULTS*", self._make_defaults),
            ("*ATOMTYPES*", self._make_atomtypes),
            ("*NONBOND_PARAM*", self._make_nonbond_param),
            ("*PAIRTYPES*", self._make_pairtypes),
            ("*BONDTYPES*", self._make_bondtypes),
            ("*CONSTRAINTTYPES*", self._make_constrainttypes),
            ("*ANGLETYPES*", self._make_angletypes),
            ("*DIHEDRALTYPES*", self._make_dihedraltypes),
            ("*IMPROPERTYPES*", self._make_impropertypes),
            ("*CMAPTYPES*", self._make_cmaptypes),
        )
        top = self._fill_template(self.toptemplate, forcefield_sections, self.system)

        molecule_sections = (
            ("*ATOMS*", self._make_atoms),
            ("*BONDS*", self._make_bonds),
            ("*PAIRS*", self._make_pairs),
            ("*SETTLES*", self._make_settles),
            ("*VIRTUAL_SITES3*", self._make_virtual_sites3),
            ("*EXCLUSIONS*", self._make_exclusions),
            ("*ANGLES*", self._make_angles),
            ("*DIHEDRALS*", self._make_dihedrals),
            ("*IMPROPERS*", self._make_impropers),
            ("*CMAPS*", self._make_cmaps),
        )
        itp_dir = os.path.dirname(self.outfile)

        for molname, mol in self.system.dict_molname_mol.items():
            itp = self.itptemplate.replace(
                "*MOLECULETYPE*",
                "".join(self._make_moleculetype(mol, molname, mol.exclusion_numb)),
            )
            itp = self._fill_template(itp, molecule_sections, mol)

            if not self.multiple_output:
                top += itp
            else:
                itp_name = "mol_{0}.itp".format(molname)
                top += '#include "mol_{0}.itp" \n'.format(molname)
                # the #include is resolved relative to the including file, so
                # the itp must live next to outfile, not in the CWD
                with open(os.path.join(itp_dir, itp_name), "w") as f:
                    f.write(itp)

        top += "\n[system]  \nConvertedSystem\n\n"
        top += "[molecules] \n"
        for molname, count in self._count_molecule_runs(self.system.molecules):
            top += "{0:s}     {1:d}\n".format(molname, count)
        top += "\n"

        with open(self.outfile, "w") as f:
            f.write(top)

    def _make_defaults(self, m):
        """Return the ``[ defaults ]`` line for system *m* as a one-element list.

        The gen-pairs/fudgeLJ/fudgeQQ columns are written only when all three
        are set.  ``None`` means "not set", so legitimate falsy values such as
        ``fudgeQQ = 0.0`` are still written.
        """
        defaults = m.defaults
        optional = ("gen-pairs", "fudgeLJ", "fudgeQQ")
        if all(defaults.get(key) is not None for key in optional):
            line = "{0:d}          {1:d}           {2}          {3}       {4} \n".format(
                defaults["nbfunc"],
                defaults["comb-rule"],
                defaults["gen-pairs"],
                defaults["fudgeLJ"],
                defaults["fudgeQQ"],
            )
        else:
            line = "{0:d}          {1:d}\n".format(
                defaults["nbfunc"],
                defaults["comb-rule"],
            )
        return [line]

    def _make_atomtypes(self, m):
        """Return one ``[ atomtypes ]`` line per atom type of system *m*.

        Each type is converted to GROMACS form first; the particle type column
        is always ``A``.  A non-empty ``comment`` is appended verbatim after
        the line's newline.
        """
        result = []
        for at in m.atomtypes:
            at.convert("gromacs")
            param = at.gromacs["param"]
            line = self.formats["atomtypes"].format(
                at.atype,
                at.bond_type if at.bond_type else "",
                at.mass,
                at.charge,
                "A",
                param["ljl"],
                param["lje"],
            )
            if at.comment:
                line += at.comment
            result.append(line)

        return result

    def _make_nonbond_param(self, m):
        """Return one ``[ nonbond_params ]`` line per pair in ``m.nonbond_params``.

        The function type is always written as 1.  A comment, if present, is
        placed before the line's newline.
        """
        result = []
        for pr in m.nonbond_params:
            # NOTE(review): unlike the *types writers, convert("gromacs") is not
            # called here; the stored sig/eps are written as-is.
            eps = pr.gromacs["param"]["eps"]
            sig = pr.gromacs["param"]["sig"]

            fu = 1  # TODO
            line = self.formats["nonbond_params"].format(
                pr.atype1, pr.atype2, fu, sig, eps
            )
            if pr.comment:
                line = self._with_inline_comment(line, pr.comment)
            result.append(line)

        return result

    def _make_pairtypes(self, m):
        """Return one ``[ pairtypes ]`` line per entry in ``m.pairtypes``.

        A comment, if present, is placed before the line's newline (the same
        way as for nonbond_params) so the following line is not glued onto it.
        """
        result = []
        for pt in m.pairtypes:
            param = pt.gromacs["param"]
            line = self.formats["pairtypes"].format(
                pt.atype1,
                pt.atype2,
                pt.gromacs["func"],
                param["ljl14"],
                param["lje14"],
            )
            if pt.comment:
                line = self._with_inline_comment(line, pt.comment)
            result.append(line)

        return result

    def _make_bondtypes(self, m):
        """Return one ``[ bondtypes ]`` line (b0 before kb) per entry in ``m.bondtypes``."""
        result = []
        for bond in m.bondtypes:
            bond.convert("gromacs")
            param = bond.gromacs["param"]
            line = self.formats["bondtypes"].format(
                bond.atype1, bond.atype2, bond.gromacs["func"], param["b0"], param["kb"]
            )
            result.append(line)

        return result

    def _make_constrainttypes(self, m):
        """Return one ``[ constrainttypes ]`` line per entry in ``m.constrainttypes``."""
        result = []
        for con in m.constrainttypes:
            line = self.formats["constrainttypes"].format(
                con.atype1, con.atype2, con.gromacs["func"], con.gromacs["param"]["b0"]
            )
            result.append(line)

        return result

    def _make_angletypes(self, m):
        """Return one ``[ angletypes ]`` line per entry in ``m.angletypes``.

        The format is chosen by function type (``angletypes_1`` or
        ``angletypes_5``); other function types raise ``KeyError``.
        """
        result = []
        for ang in m.angletypes:
            ang.convert("gromacs")
            param = ang.gromacs["param"]
            fu = ang.gromacs["func"]

            # the func-1 format simply ignores the trailing s0/kub arguments
            line = self.formats["angletypes_{0:d}".format(fu)].format(
                ang.atype1,
                ang.atype2,
                ang.atype3,
                fu,
                param["tetha0"],
                param["ktetha"],
                param.get("s0"),
                param.get("kub"),
            )
            result.append(line)

        return result

    def _make_dihedraltypes(self, m):
        """Return one ``[ dihedraltypes ]`` line per parameter set of each dihedral type.

        Disabled types are prefixed with their ``comment`` (presumably a
        ``;`` comment that comments the line out - confirm with the producer
        of ``disabled``).
        """
        result = []
        for dih in m.dihedraltypes:
            dih.convert("gromacs")
            fu = dih.gromacs["func"]

            for dpar in dih.gromacs["param"]:
                line = self.formats["dihedraltypes"].format(
                    dih.atype1,
                    dih.atype2,
                    dih.atype3,
                    dih.atype4,
                    fu,
                    dpar["delta"],
                    dpar["kchi"],
                    dpar["n"],
                )
                if dih.disabled:
                    line = dih.comment + line
                result.append(line)

        return result

    def _make_impropertypes(self, m):
        """Return one improper ``[ dihedraltypes ]`` line per parameter set.

        Only function types 2 and 4 are supported.

        :Raises: ``ValueError`` for any other function type (previously the
                 line of an earlier entry could be silently repeated).
        """
        result = []
        for imp in m.impropertypes:
            imp.convert("gromacs")
            fu = imp.gromacs["func"]
            atypes = (imp.atype1, imp.atype2, imp.atype3, imp.atype4)

            for ipar in imp.gromacs["param"]:
                if fu == 2:
                    line = self.formats["impropertypes_2"].format(
                        atypes[0], atypes[1], atypes[2], atypes[3],
                        fu, ipar["psi0"], ipar["kpsi"],
                    )
                elif fu == 4:
                    line = self.formats["impropertypes_4"].format(
                        atypes[0], atypes[1], atypes[2], atypes[3],
                        fu, ipar["psi0"], ipar["kpsi"], ipar["n"],
                    )
                else:
                    raise ValueError(
                        "Unsupported improper function type {0} "
                        "(only 2 and 4 are supported)".format(fu)
                    )

                if imp.disabled:
                    line = imp.comment + line
                result.append(line)

        return result

    @classmethod
    def _cmap_grid_size(cls, n_params):
        """Return the edge length of a square CMAP grid holding *n_params* values.

        Falls back to ``_default_cmap_grid`` (24) if *n_params* is not a
        positive perfect square.
        """
        side = int(round(n_params ** 0.5))
        if side > 0 and side * side == n_params:
            return side
        return cls._default_cmap_grid

    def _make_cmaptypes(self, m):
        """Return one ``[ cmaptypes ]`` entry per CMAP type of system *m*.

        The header gives atom types 1-4 and 8, the function type and the grid
        size (inferred from the number of parameters, 24x24 for the usual 576
        values). Values follow on backslash-continued lines, 10 per line, each
        formatted as ``%12.8f``; every entry ends with a blank line.
        """
        result = []
        for cmap in m.cmaptypes:
            cmap.convert("gromacs")

            params = cmap.gromacs["param"]
            grid = self._cmap_grid_size(len(params))
            line = "{0:s} {1:s} {2:s} {3:s} {4:s} {5:d} {6:d} {6:d}".format(
                cmap.atype1,
                cmap.atype2,
                cmap.atype3,
                cmap.atype4,
                cmap.atype8,
                cmap.gromacs["func"],
                grid,
            )
            for i, c in enumerate(params):
                # start a new continuation line every 10 values
                line += "\\\n" if i % 10 == 0 else " "
                line += "{0:12.8f}".format(c)

            line += "\n\n"
            result.append(line)

        return result

    def _make_moleculetype(self, m, molname, nrexcl):
        """Return the ``[ moleculetype ]`` body (name and nrexcl) as a one-element list."""
        return ["; Name \t\t  nrexcl \n {0}    {1} \n".format(molname, nrexcl)]

    def _make_atoms(self, m):
        """Return the ``[ atoms ]`` lines of molecule *m*, preceded by a count comment.

        The charge-group number is written equal to the atom number; the mass
        column is omitted for atoms without a ``mass`` attribute.
        """
        result = []
        for atom in m.atoms:
            numb = cgnr = atom.number
            atype = atom.get_atomtype()

            assert atype != False
            assert hasattr(atom, "charge")  # and hasattr(atom, 'mass')

            if hasattr(atom, "mass"):
                line = self.formats["atoms"].format(
                    numb,
                    atype,
                    atom.resnumb,
                    atom.resname,
                    atom.name,
                    cgnr,
                    atom.charge,
                    atom.mass,
                )
            else:
                line = self.formats["atoms_nomass"].format(
                    numb,
                    atype,
                    atom.resnumb,
                    atom.resname,
                    atom.name,
                    cgnr,
                    atom.charge,
                )
            result.append(line)

        result.insert(0, "; {0:5d} atoms\n".format(len(result)))
        return result

    def _make_pairs(self, m):
        """Return the ``[ pairs ]`` lines (function type 1) of *m*, preceded by a count comment."""
        result = []
        for pr in m.pairs:
            fu = 1
            line = self.formats["pairs"].format(pr.atom1.number, pr.atom2.number, fu)
            result.append(line)

        result.insert(0, "; {0:5d} pairs\n".format(len(result)))
        return result

    def _make_bonds(self, m):
        """Return the ``[ bonds ]`` lines of *m*, preceded by a count comment.

        Explicit ``b0``/``kb`` are written whenever both are set (not
        ``None``); a zero force constant is a valid explicit parameter and is
        written, too.  Otherwise only atoms and function type are written so
        grompp looks the parameters up in the bondtypes.
        """
        result = []
        for bond in m.bonds:
            fu = bond.gromacs["func"]
            params = bond.gromacs.get("param") or {}
            kb, b0 = params.get("kb"), params.get("b0")

            if kb is not None and b0 is not None:
                line = self.formats["bonds_ext"].format(
                    bond.atom1.number, bond.atom2.number, fu, b0, kb
                )
            else:
                line = self.formats["bonds"].format(
                    bond.atom1.number, bond.atom2.number, fu
                )

            result.append(line)

        result.insert(0, "; {0:5d} bonds\n".format(len(result)))
        return result

    def _make_angles(self, m):
        """Return the ``[ angles ]`` lines of *m*, preceded by a count comment.

        Explicit ``tetha0``/``ktetha`` are written whenever both are set (not
        ``None``); for function type 5 the ``s0``/``kub`` parameters are
        appended when they are set as well.  Otherwise only atoms and
        function type are written.
        """
        result = []
        for ang in m.angles:
            fu = ang.gromacs["func"]
            params = ang.gromacs.get("param") or {}
            ktetha, tetha0 = params.get("ktetha"), params.get("tetha0")
            a1, a2, a3 = ang.atom1.number, ang.atom2.number, ang.atom3.number

            if ktetha is None or tetha0 is None:
                line = self.formats["angles"].format(a1, a2, a3, fu)
            else:
                kub, s0 = params.get("kub"), params.get("s0")
                if fu == 5 and kub is not None and s0 is not None:
                    line = self.formats["angles_ext_5"].format(
                        a1, a2, a3, fu, tetha0, ktetha, s0, kub
                    )
                else:
                    line = self.formats["angles_ext"].format(
                        a1, a2, a3, fu, tetha0, ktetha
                    )
            result.append(line)

        result.insert(0, "; {0:5d} angles\n".format(len(result)))
        return result

    def _make_settles(self, m):
        """Return the ``[ settles ]`` lines (function type 1) of *m*, preceded by a count comment."""
        result = []
        for st in m.settles:
            line = self.formats["settles"].format(st.atom.number, 1, st.dOH, st.dHH)
            result.append(line)

        result.insert(0, "; {0:5d} settles\n".format(len(result)))
        return result

    def _make_virtual_sites3(self, m):
        """Return the ``[ virtual_sites3 ]`` lines (function type 1) of *m*, preceded by a count comment."""
        result = []
        for vs in m.virtual_sites3:
            fu = 1
            line = self.formats["virtual_sites3"].format(
                vs.atom1,
                vs.atom2,
                vs.atom3,
                vs.atom4,
                fu,
                vs.gromacs["param"]["a"],
                vs.gromacs["param"]["b"],
            )
            result.append(line)

        result.insert(0, "; {0:5d} virtual_sites3\n".format(len(result)))
        return result

    def _make_exclusions(self, m):
        """Return the ``[ exclusions ]`` lines of *m*, preceded by a count comment."""
        result = []
        for excl in m.exclusions:
            other_atoms = ["  {:3d}".format(at.number) for at in excl.other_atoms]
            line = self.formats["exclusions"].format(
                excl.main_atom.number, "".join(other_atoms)
            )
            result.append(line)

        result.insert(0, "; {0:5d} exclusions\n".format(len(result)))
        return result

    def _make_dihedrals(self, m):
        """Return the ``[ dihedrals ]`` lines of *m*, preceded by a count comment.

        Dihedrals without parameters get a single atoms+function line; others
        get one line per parameter set, prefixed with ``comment`` if present.
        """
        result = []
        for dih in m.dihedrals:
            fu = dih.gromacs["func"]
            atoms = (
                dih.atom1.number,
                dih.atom2.number,
                dih.atom3.number,
                dih.atom4.number,
            )

            if not dih.gromacs["param"]:
                line = self.formats["dihedrals"].format(
                    atoms[0], atoms[1], atoms[2], atoms[3], fu
                )
                result.append(line)

            for dpar in dih.gromacs["param"]:
                line = self.formats["dihedrals_ext"].format(
                    atoms[0],
                    atoms[1],
                    atoms[2],
                    atoms[3],
                    fu,
                    dpar["delta"],
                    dpar["kchi"],
                    dpar["n"],
                )
                if dih.comment:
                    line = dih.comment + line
                result.append(line)

        result.insert(0, "; {0:5d} dihedrals\n".format(len(result)))
        return result

    def _make_impropers(self, m):
        """Return the improper ``[ dihedrals ]`` lines of *m*, preceded by a count comment.

        Impropers without parameters get a single atoms+function line; others
        get one line per parameter set (function types 2 and 4 only),
        prefixed with ``comment`` if present.

        :Raises: ``ValueError`` for parameterised impropers of any other
                 function type.
        """
        result = []
        for imp in m.impropers:
            fu = imp.gromacs["func"]
            atoms = (
                imp.atom1.number,
                imp.atom2.number,
                imp.atom3.number,
                imp.atom4.number,
            )

            if not imp.gromacs["param"]:
                line = self.formats["impropers"].format(
                    atoms[0], atoms[1], atoms[2], atoms[3], fu
                )
                result.append(line)

            for ipar in imp.gromacs["param"]:
                if fu == 2:
                    line = self.formats["impropers_2"].format(
                        atoms[0], atoms[1], atoms[2], atoms[3],
                        fu, ipar["psi0"], ipar["kpsi"],
                    )
                elif fu == 4:
                    line = self.formats["impropers_4"].format(
                        atoms[0], atoms[1], atoms[2], atoms[3],
                        fu, ipar["psi0"], ipar["kpsi"], ipar["n"],
                    )
                else:
                    raise ValueError(
                        "Unsupported improper function type {0} "
                        "(only 2 and 4 are supported)".format(fu)
                    )

                if imp.comment:
                    line = imp.comment + line
                result.append(line)

        result.insert(0, "; {0:5d} impropers\n".format(len(result)))
        return result

    def _make_cmaps(self, m):
        """Return the ``[ cmap ]`` lines (atoms 1-4 and 8, function type 1) of *m*, preceded by a count comment."""
        result = []
        for cmap in m.cmaps:
            fu = 1
            line = "{0:5d} {1:5d} {2:5d} {3:5d} {4:5d}   {5:d}\n".format(
                cmap.atom1.number,
                cmap.atom2.number,
                cmap.atom3.number,
                cmap.atom4.number,
                cmap.atom8.number,
                fu,
            )
            result.append(line)

        result.insert(0, "; {0:5d} cmaps\n".format(len(result)))
        return result
