#!/usr/bin/python2.7
#
#    Copyright (C) 2013
#    Free Software Foundation, Inc.
#
# This file is part of GCC.
#
# GCC is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3, or (at your option)
# any later version.
#
# GCC is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GCC; see the file COPYING3.  If not see
# <http://www.gnu.org/licenses/>.
#


"""Merge two or more gcda profile.
"""

__author__ = 'Seongbae Park, Rong Xu'
__author_email__ = 'spark@google.com, xur@google.com'

import array
from optparse import OptionParser
import os
import struct
import zipfile

new_histogram = None


class Error(Exception):
  """Exception class for profile module."""


def ReadAllAndClose(path):
  """Return the entire byte content of the specified file.

  Args:
    path: The path to the file to be opened and read.

  Returns:
    The byte sequence of the content of the file.
  """
  with open(path, 'rb') as data_file:
    data = data_file.read()
  return data


def ReturnMergedCounters(objs, index, multipliers):
  """Accumulate the counter at "index" from all counters objs."""
  val = 0
  for j in xrange(len(objs)):
    val += multipliers[j] * objs[j].counters[index]
  return val


class DataObject(object):
  """Base class for various datum in GCDA/GCNO file."""

  def __init__(self, tag):
    self.tag = tag


class Function(DataObject):
  """Function and its counters.

  Attributes:
    length: Length of the data on the disk.
    ident: Ident field.
    line_checksum: Checksum of the line number.
    cfg_checksum: Checksum of the control flow graph.
    counters: All counters associated with the function.
    file: The name of the file the function is defined in. Optional.
    line: The line number the function is defined at. Optional.

  Function object contains other counter objects and block/arc/line objects.
  """

  def __init__(self, reader, tag, n_words):
    """Read function record information from a gcda/gcno file.

    Args:
      reader: gcda/gcno file.
      tag: Funtion tag.
      n_words: Length of function record in unit of 4-byte.
    """
    DataObject.__init__(self, tag)
    self.length = n_words
    self.counters = []

    if reader:
      pos = reader.pos
      self.ident = reader.ReadWord()
      self.line_checksum = reader.ReadWord()
      self.cfg_checksum = reader.ReadWord()

      # Function name string is in gcno files, but not
      # in gcda files. Here we make string reading optional.
      if (reader.pos - pos) < n_words:
        reader.ReadStr()

      if (reader.pos - pos) < n_words:
        self.file = reader.ReadStr()
        self.line_number = reader.ReadWord()
      else:
        self.file = ''
        self.line_number = 0
    else:
      self.ident = 0
      self.line_checksum = 0
      self.cfg_checksum = 0
      self.file = None
      self.line_number = 0

  def Write(self, writer):
    """Write out the function."""

    writer.WriteWord(self.tag)
    writer.WriteWord(self.length)
    writer.WriteWord(self.ident)
    writer.WriteWord(self.line_checksum)
    writer.WriteWord(self.cfg_checksum)
    for c in self.counters:
      c.Write(writer)

  def EntryCount(self):
    """Return the number of times the function called."""
    return self.ArcCounters().counters[0]

  def Merge(self, others, multipliers):
    """Merge all functions in "others" into self.

    Args:
      others: A sequence of Function objects
      multipliers: A sequence of integers to be multiplied during merging.
    """
    for o in others:
      assert self.ident == o.ident
      assert self.line_checksum == o.line_checksum
      assert self.cfg_checksum == o.cfg_checksum

    for i in xrange(len(self.counters)):
      self.counters[i].Merge([o.counters[i] for o in others], multipliers)

  def Print(self):
    """Print all the attributes in full detail."""
    print 'function: ident %d length %d line_chksum %x cfg_chksum %x' % (
        self.ident, self.length,
        self.line_checksum, self.cfg_checksum)
    if self.file:
      print 'file:     %s' % self.file
      print 'line_number:   %d' % self.line_number
    for c in self.counters:
      c.Print()

  def ArcCounters(self):
    """Return the counter object containing Arcs counts."""
    for c in self.counters:
      if c.tag == DataObjectFactory.TAG_COUNTER_ARCS:
        return c
    return None


class Blocks(DataObject):
  """Block information for a function."""

  def __init__(self, reader, tag, n_words):
    DataObject.__init__(self, tag)
    self.length = n_words
    self.__blocks = reader.ReadWords(n_words)

  def Print(self):
    """Print the list of block IDs."""
    print 'blocks:  ', ' '.join(self.__blocks)


class Arcs(DataObject):
  """List of outgoing control flow edges for a single basic block."""

  def __init__(self, reader, tag, n_words):
    DataObject.__init__(self, tag)

    self.length = (n_words - 1) / 2
    self.block_id = reader.ReadWord()
    self.__arcs = reader.ReadWords(2 * self.length)

  def Print(self):
    """Print all edge information in full detail."""
    print 'arcs: block', self.block_id
    print 'arcs: ',
    for i in xrange(0, len(self.__arcs), 2):
      print '(%d:%x)' % (self.__arcs[i], self.__arcs[i+1]),
      if self.__arcs[i+1] & 0x01: print 'on_tree'
      if self.__arcs[i+1] & 0x02: print 'fake'
      if self.__arcs[i+1] & 0x04: print 'fallthrough'
    print


class Lines(DataObject):
  """Line number information for a block."""

  def __init__(self, reader, tag, n_words):
    DataObject.__init__(self, tag)
    self.length = n_words
    self.block_id = reader.ReadWord()
    self.line_numbers = []
    line_number = reader.ReadWord()
    src_files = reader.ReadStr()
    while src_files:
      line_number = reader.ReadWord()
      src_lines = [src_files]
      while line_number:
        src_lines.append(line_number)
        line_number = reader.ReadWord()
      self.line_numbers.append(src_lines)
      src_files = reader.ReadStr()

  def Print(self):
    """Print all line numbers in full detail."""
    for l in self.line_numbers:
      print 'line_number: block %d' % self.block_id, ' '.join(l)


class Counters(DataObject):
  """List of counter values.

  Attributes:
    counters: Sequence of counter values.
  """

  def __init__(self, reader, tag, n_words):
    DataObject.__init__(self, tag)
    self.counters = reader.ReadCounters(n_words / 2)

  def Write(self, writer):
    """Write."""
    writer.WriteWord(self.tag)
    writer.WriteWord(len(self.counters) * 2)
    writer.WriteCounters(self.counters)

  def IsComparable(self, other):
    """Returns true if two counters are comparable."""
    return (self.tag == other.tag and
            len(self.counters) == len(other.counters))

  def Merge(self, others, multipliers):
    """Merge all counter values from others into self.

    Args:
      others: other counters to merge.
      multipliers: multiplier to apply to each of the other counters.

    The value in self.counters is overwritten and is not included in merging.
    """
    for i in xrange(len(self.counters)):
      self.counters[i] = ReturnMergedCounters(others, i, multipliers)

  def Print(self):
    """Print the counter values."""
    if self.counters and reduce(lambda x, y: x or y, self.counters):
      print '%10s: ' % data_factory.GetTagName(self.tag), self.counters


def FindMaxKeyValuePair(table):
  """Return (key, value) pair of a dictionary that has maximum value."""
  maxkey = 0
  maxval = 0
  for k, v in table.iteritems():
    if v > maxval:
      maxval = v
      maxkey = k
  return maxkey, maxval


class SingleValueCounters(Counters):
  """Single-value counter.

  Each profiled single value is encoded in 3 counters:
  counters[3 * i + 0]: the most frequent value
  counters[3 * i + 1]: the count of the most frequent value
  counters[3 * i + 2]: the total number of the evaluation of the value
  """

  def Merge(self, others, multipliers):
    """Merge single value counters."""
    for i in xrange(0, len(self.counters), 3):
      table = {}
      for j in xrange(len(others)):
        o = others[j]
        key = o.counters[i]
        if key in table:
          table[key] += multipliers[j] * o.counters[i + 1]
        else:
          table[o.counters[i]] = multipliers[j] * o.counters[i + 1]

      (maxkey, maxval) = FindMaxKeyValuePair(table)

      self.counters[i] = maxkey
      self.counters[i + 1] = maxval

      # Accumulate the overal count
      self.counters[i + 2] = ReturnMergedCounters(others, i + 2, multipliers)


class DeltaValueCounters(Counters):
  """Delta counter.

  Each profiled delta value is encoded in four counters:
  counters[4 * i + 0]: the last measured value
  counters[4 * i + 1]: the most common difference
  counters[4 * i + 2]: the count of the most common difference
  counters[4 * i + 3]: the total number of the evaluation of the value
  Merging is similar to SingleValueCounters.
  """

  def Merge(self, others, multipliers):
    """Merge DeltaValue counters."""
    for i in xrange(0, len(self.counters), 4):
      table = {}
      for j in xrange(len(others)):
        o = others[j]
        key = o.counters[i + 1]
        if key in table:
          table[key] += multipliers[j] * o.counters[i + 2]
        else:
          table[key] = multipliers[j] * o.counters[i + 2]

      maxkey, maxval = FindMaxKeyValuePair(table)

      self.counters[i + 1] = maxkey
      self.counters[i + 2] = maxval

      # Accumulate the overal count
      self.counters[i + 3] = ReturnMergedCounters(others, i + 3, multipliers)


class IorCounters(Counters):
  """Bitwise-IOR counters."""

  def Merge(self, others, _):
    """Merge IOR counter."""
    for i in xrange(len(self.counters)):
      self.counters[i] = 0
      for o in others:
        self.counters[i] |= o.counters[i]


class ICallTopNCounters(Counters):
  """Indirect call top-N counter.

  Each profiled indirect call top-N is encoded in nine counters:
  counters[9 * i + 0]: number_of_evictions
  counters[9 * i + 1]: callee global id
  counters[9 * i + 2]: call_count
  counters[9 * i + 3]: callee global id
  counters[9 * i + 4]: call_count
  counters[9 * i + 5]: callee global id
  counters[9 * i + 6]: call_count
  counters[9 * i + 7]: callee global id
  counters[9 * i + 8]: call_count
  The 4 pairs of counters record the 4 most frequent indirect call targets.
  """

  def Merge(self, others, multipliers):
    """Merge ICallTopN counters."""
    for i in xrange(0, len(self.counters), 9):
      table = {}
      for j, o in enumerate(others):
        multiplier = multipliers[j]
        for k in xrange(0, 4):
          key = o.counters[i+2*k+1]
          value = o.counters[i+2*k+2]
          if key in table:
            table[key] += multiplier * value
          else:
            table[key] = multiplier * value
      for j in xrange(0, 4):
        (maxkey, maxval) = FindMaxKeyValuePair(table)
        self.counters[i+2*j+1] = maxkey
        self.counters[i+2*j+2] = maxval
        if maxkey:
          del table[maxkey]


def IsGidInsane(gid):
  """Return if the given global id looks insane."""
  module_id = gid >> 32
  function_id = gid & 0xFFFFFFFF
  return (module_id == 0) or (function_id == 0)


class DCallCounters(Counters):
  """Direct call counter.

  Each profiled direct call is encoded in two counters:
  counters[2 * i + 0]: callee global id
  counters[2 * i + 1]: call count
  """

  def Merge(self, others, multipliers):
    """Merge DCall counters."""
    for i in xrange(0, len(self.counters), 2):
      self.counters[i+1] *= multipliers[0]
      for j, other in enumerate(others[1:]):
        global_id = other.counters[i]
        call_count = multipliers[j] * other.counters[i+1]
        if self.counters[i] != 0 and global_id != 0:
          if IsGidInsane(self.counters[i]):
            self.counters[i] = global_id
          elif IsGidInsane(global_id):
            global_id = self.counters[i]
          assert self.counters[i] == global_id
        elif global_id != 0:
          self.counters[i] = global_id
        self.counters[i+1] += call_count
        if IsGidInsane(self.counters[i]):
          self.counters[i] = 0
          self.counters[i+1] = 0
        if self.counters[i] == 0:
          assert self.counters[i+1] == 0
        if self.counters[i+1] == 0:
          assert self.counters[i] == 0


def WeightedMean2(v1, c1, v2, c2):
  """Weighted arithmetic mean of two values."""
  if c1 + c2 == 0:
    return 0
  return (v1*c1 + v2*c2) / (c1+c2)


class ReuseDistCounters(Counters):
  """ReuseDist counters.

  We merge the counters one by one, which may render earlier counters
  contribute less to the final result due to the truncations. We are doing
  this to match the computation in libgcov, to make the
  result consistent in these two merges.
  """

  def Merge(self, others, multipliers):
    """Merge ReuseDist counters."""
    for i in xrange(0, len(self.counters), 4):
      a_mean_dist = 0
      a_mean_size = 0
      a_count = 0
      a_dist_x_size = 0
      for j, other in enumerate(others):
        mul = multipliers[j]
        f_mean_dist = other.counters[i]
        f_mean_size = other.counters[i+1]
        f_count = other.counters[i+2]
        f_dist_x_size = other.counters[i+3]
        a_mean_dist = WeightedMean2(a_mean_dist, a_count,
                                    f_mean_dist, f_count*mul)
        a_mean_size = WeightedMean2(a_mean_size, a_count,
                                    f_mean_size, f_count*mul)
        a_count += f_count*mul
        a_dist_x_size += f_dist_x_size*mul
      self.counters[i] = a_mean_dist
      self.counters[i+1] = a_mean_size
      self.counters[i+2] = a_count
      self.counters[i+3] = a_dist_x_size


class Summary(DataObject):
  """Program level summary information."""

  class Summable(object):
    """One instance of summable information in the profile."""

    def __init__(self, num, runs, sum_all, run_max, sum_max):
      self.num = num
      self.runs = runs
      self.sum_all = sum_all
      self.run_max = run_max
      self.sum_max = sum_max

    def Write(self, writer):
      """Serialize to the byte stream."""

      writer.WriteWord(self.num)
      writer.WriteWord(self.runs)
      writer.WriteCounter(self.sum_all)
      writer.WriteCounter(self.run_max)
      writer.WriteCounter(self.sum_max)

    def Merge(self, others, multipliers):
      """Merge the summary."""
      sum_all = 0
      run_max = 0
      sum_max = 0
      runs = 0
      for i in xrange(len(others)):
        sum_all += others[i].sum_all * multipliers[i]
        sum_max += others[i].sum_max * multipliers[i]
        run_max = max(run_max, others[i].run_max * multipliers[i])
        runs += others[i].runs
      self.sum_all = sum_all
      self.run_max = run_max
      self.sum_max = sum_max
      self.runs = runs

    def Print(self):
      """Print the program summary value."""
      print '%10d %10d %15d %15d %15d' % (
          self.num, self.runs, self.sum_all, self.run_max, self.sum_max)

  class HistogramBucket(object):
    def __init__(self, num_counters, min_value, cum_value):
      self.num_counters = num_counters
      self.min_value = min_value
      self.cum_value = cum_value

    def Print(self, ix):
      if self.num_counters != 0:
        print 'ix=%d num_count=%d min_count=%d cum_count=%d' % (
            ix, self.num_counters, self.min_value, self.cum_value)

  class Histogram(object):
    """Program level histogram information."""

    def __init__(self):
      self.size = 252
      self.bitvector_size = (self.size + 31) / 32
      self.histogram = [[None]] * self.size
      self.bitvector = [0] * self.bitvector_size

    def ComputeCntandBitvector(self):
      h_cnt = 0
      for h_ix in range(0, self.size):
        if self.histogram[h_ix] != [None]:
          if self.histogram[h_ix].num_counters:
            self.bitvector[h_ix/32] |= (1 << (h_ix %32))
            h_cnt += 1
      self.h_cnt = h_cnt

    def Index(self, value):
      """Return the bucket index of a histogram value."""
      r = 1
      prev2bits = 0

      if value <= 3:
        return value
      v = value
      while v > 3:
        r += 1
        v >>= 1
      v = value
      prev2bits = (v >> (r - 2)) & 0x3
      return (r - 1) * 4 + prev2bits

    def Insert(self, value):
      """Add a count value to histogram."""
      i = self.Index(value)
      if self.histogram[i] != [None]:
        self.histogram[i].num_counters += 1
        self.histogram[i].cum_value += value
        if value < self.histogram[i].min_value:
          self.histogram[i].min_value = value
      else:
        self.histogram[i] = Summary.HistogramBucket(1, value, value)

    def Print(self):
      """Print a histogram."""
      print 'Histogram:'
      for i in range(self.size):
        if self.histogram[i] != [None]:
          self.histogram[i].Print(i)

    def Write(self, writer):
      for bv_ix in range(0, self.bitvector_size):
        writer.WriteWord(self.bitvector[bv_ix])
      for h_ix in range(0, self.size):
        if self.histogram[h_ix] != [None]:
          writer.WriteWord(self.histogram[h_ix].num_counters)
          writer.WriteCounter(self.histogram[h_ix].min_value)
          writer.WriteCounter(self.histogram[h_ix].cum_value)

  def SummaryLength(self, h_cnt):
    """Return the of of summary for a given histogram count."""
    return 1 + (10 + 3 * 2) + h_cnt * 5

  def __init__(self, reader, tag, n_words):
    DataObject.__init__(self, tag)
    self.length = n_words
    self.checksum = reader.ReadWord()
    self.sum_counter = []
    self.histograms = []

    for _ in xrange(DataObjectFactory.N_SUMMABLE):
      num = reader.ReadWord()
      runs = reader.ReadWord()
      sum_all = reader.ReadCounter()
      run_max = reader.ReadCounter()
      sum_max = reader.ReadCounter()

      histogram = self.Histogram()
      histo_bitvector = [[None]] * histogram.bitvector_size
      h_cnt = 0

      for bv_ix in xrange(histogram.bitvector_size):
        val = reader.ReadWord()
        histo_bitvector[bv_ix] = val
        while val != 0:
          h_cnt += 1
          val &= (val-1)
      bv_ix = 0
      h_ix = 0
      cur_bitvector = 0
      for _ in xrange(h_cnt):
        while cur_bitvector == 0:
          h_ix = bv_ix * 32
          cur_bitvector = histo_bitvector[bv_ix]
          bv_ix += 1
        assert bv_ix <= histogram.bitvector_size
        while (cur_bitvector & 0x1) == 0:
          h_ix += 1
          cur_bitvector >>= 1
          assert h_ix < histogram.size
        n_counters = reader.ReadWord()
        minv = reader.ReadCounter()
        maxv = reader.ReadCounter()
        histogram.histogram[h_ix] = self.HistogramBucket(n_counters,
                                                         minv, maxv)
        cur_bitvector >>= 1
        h_ix += 1

      self.histograms.append(histogram)
      self.sum_counter.append(self.Summable(
          num, runs, sum_all, run_max, sum_max))

  def Write(self, writer):
    """Serialize to byte stream."""
    writer.WriteWord(self.tag)
    assert new_histogram
    self.length = self.SummaryLength(new_histogram[0].h_cnt)
    writer.WriteWord(self.length)
    writer.WriteWord(self.checksum)
    for i, s in enumerate(self.sum_counter):
      s.Write(writer)
      new_histogram[i].Write(writer)

  def Merge(self, others, multipliers):
    """Merge with the other counter. Histogram will be recomputed afterwards."""
    for i in xrange(len(self.sum_counter)):
      self.sum_counter[i].Merge([o.sum_counter[i] for o in others], multipliers)

  def Print(self):
    """Print all the summary info for a given module/object summary."""
    print '%s: checksum %X' % (
        data_factory.GetTagName(self.tag), self.checksum)
    print '%10s %10s %15s %15s %15s' % (
        'num', 'runs', 'sum_all', 'run_max', 'sum_max')
    for i in xrange(DataObjectFactory.N_SUMMABLE):
      self.sum_counter[i].Print()
      self.histograms[i].Print()


class ModuleInfo(DataObject):
  """Module information."""

  def __init__(self, reader, tag, n_words):
    DataObject.__init__(self, tag)
    self.length = n_words
    self.module_id = reader.ReadWord()
    self.is_primary = reader.ReadWord()
    self.flags = reader.ReadWord()
    self.ggc_memory = reader.ReadWord()
    self.language = reader.ReadWord()
    self.num_quote_paths = reader.ReadWord()
    self.num_bracket_paths = reader.ReadWord()
    self.num_system_paths = reader.ReadWord()
    self.num_cpp_defines = reader.ReadWord()
    self.num_cpp_includes = reader.ReadWord()
    self.num_cl_args = reader.ReadWord()
    self.filename_len = reader.ReadWord()
    self.filename = []
    for _ in xrange(self.filename_len):
      self.filename.append(reader.ReadWord())
    self.src_filename_len = reader.ReadWord()
    self.src_filename = []
    for _ in xrange(self.src_filename_len):
      self.src_filename.append(reader.ReadWord())
    self.string_lens = []
    self.strings = []
    for _ in xrange(self.num_quote_paths + self.num_bracket_paths +
                    self.num_system_paths +
                    self.num_cpp_defines + self.num_cpp_includes +
                    self.num_cl_args):
      string_len = reader.ReadWord()
      string = []
      self.string_lens.append(string_len)
      for _ in xrange(string_len):
        string.append(reader.ReadWord())
      self.strings.append(string)

  def Write(self, writer):
    """Serialize to byte stream."""
    writer.WriteWord(self.tag)
    writer.WriteWord(self.length)
    writer.WriteWord(self.module_id)
    writer.WriteWord(self.is_primary)
    writer.WriteWord(self.flags)
    writer.WriteWord(self.language)
    writer.WriteWord(self.ggc_memory)
    writer.WriteWord(self.num_quote_paths)
    writer.WriteWord(self.num_bracket_paths)
    writer.WriteWord(self.num_system_paths)
    writer.WriteWord(self.num_cpp_defines)
    writer.WriteWord(self.num_cpp_includes)
    writer.WriteWord(self.num_cl_args)
    writer.WriteWord(self.filename_len)
    for i in xrange(self.filename_len):
      writer.WriteWord(self.filename[i])
    writer.WriteWord(self.src_filename_len)
    for i in xrange(self.src_filename_len):
      writer.WriteWord(self.src_filename[i])
    for i in xrange(len(self.string_lens)):
      writer.WriteWord(self.string_lens[i])
      string = self.strings[i]
      for j in xrange(self.string_lens[i]):
        writer.WriteWord(string[j])

  def Print(self):
    """Print the module info."""
    fn = ''
    for fn4 in self.src_filename:
      fn += chr((fn4) & 0xFF)
      fn += chr((fn4 >> 8) & 0xFF)
      fn += chr((fn4 >> 16) & 0xFF)
      fn += chr((fn4 >> 24) & 0xFF)
    print ('%s: %s [%s, %s, %s]'
           % (data_factory.GetTagName(self.tag),
              fn,
              ('primary', 'auxiliary')[self.is_primary == 0],
              ('exported', 'not-exported')[(self.flags & 0x1) == 0],
              ('include_all', '')[(self.flags & 0x2) == 0]))


class DataObjectFactory(object):
  """A factory of profile data objects."""

  TAG_FUNCTION = 0x01000000
  TAG_BLOCK = 0x01410000
  TAG_ARCS = 0x01430000
  TAG_LINES = 0x01450000
  TAG_COUNTER_ARCS = 0x01a10000 + (0 << 17)
  TAG_COUNTER_INTERVAL = TAG_COUNTER_ARCS + (1 << 17)
  TAG_COUNTER_POW2 = TAG_COUNTER_ARCS + (2 << 17)
  TAG_COUNTER_SINGLE = TAG_COUNTER_ARCS + (3 << 17)
  TAG_COUNTER_DELTA = TAG_COUNTER_ARCS + (4 << 17)
  TAG_COUNTER_INDIRECT_CALL = TAG_COUNTER_ARCS + (5 << 17)
  TAG_COUNTER_AVERAGE = TAG_COUNTER_ARCS + (6 << 17)
  TAG_COUNTER_IOR = TAG_COUNTER_ARCS + (7 << 17)
  TAG_COUNTER_ICALL_TOPN = TAG_COUNTER_ARCS + (8 << 17)
  TAG_COUNTER_DCALL = TAG_COUNTER_ARCS + (9 << 17)
  TAG_COUNTER_REUSE_DIST = TAG_COUNTER_ARCS + (10 << 17)

  TAG_PROGRAM_SUMMARY = 0x0a3000000L
  TAG_MODULE_INFO = 0x0ab000000L

  N_SUMMABLE = 1

  DATA_MAGIC = 0x67636461
  NOTE_MAGIC = 0x67636e6f

  def __init__(self):
    self.__tagname = {}
    self.__tagname[self.TAG_FUNCTION] = ('function', Function)
    self.__tagname[self.TAG_BLOCK] = ('blocks', Blocks)
    self.__tagname[self.TAG_ARCS] = ('cfg_arcs', Arcs)
    self.__tagname[self.TAG_LINES] = ('lines', Lines)
    self.__tagname[self.TAG_PROGRAM_SUMMARY] = ('program_summary', Summary)
    self.__tagname[self.TAG_MODULE_INFO] = ('module_info', ModuleInfo)
    self.__tagname[self.TAG_COUNTER_ARCS] = ('arcs', Counters)
    self.__tagname[self.TAG_COUNTER_INTERVAL] = ('interval', Counters)
    self.__tagname[self.TAG_COUNTER_POW2] = ('pow2', Counters)
    self.__tagname[self.TAG_COUNTER_SINGLE] = ('single', SingleValueCounters)
    self.__tagname[self.TAG_COUNTER_DELTA] = ('delta', DeltaValueCounters)
    self.__tagname[self.TAG_COUNTER_INDIRECT_CALL] = (
        'icall', SingleValueCounters)
    self.__tagname[self.TAG_COUNTER_AVERAGE] = ('average', Counters)
    self.__tagname[self.TAG_COUNTER_IOR] = ('ior', IorCounters)
    self.__tagname[self.TAG_COUNTER_ICALL_TOPN] = ('icall_topn',
                                                   ICallTopNCounters)
    self.__tagname[self.TAG_COUNTER_DCALL] = ('dcall', DCallCounters)
    self.__tagname[self.TAG_COUNTER_REUSE_DIST] = ('reuse_dist',
                                                   ReuseDistCounters)

  def GetTagName(self, tag):
    """Return the name for a given tag."""
    return self.__tagname[tag][0]

  def Create(self, reader, tag, n_words):
    """Read the raw data from reader and return the data object."""
    if tag not in self.__tagname:
      print tag

    assert tag in self.__tagname
    return self.__tagname[tag][1](reader, tag, n_words)


# Singleton factory object.
data_factory = DataObjectFactory()


class ProfileDataFile(object):
  """Structured representation of a gcda/gcno file.

  Attributes:
    buffer: The binary representation of the file.
    pos: The current position in the buffer.
    magic: File type magic number.
    version: Compiler version.
    stamp: Time stamp.
    functions: A sequence of all Function objects.
        The order is preserved from the binary representation.

  One profile data file (gcda or gcno file) is a collection
  of Function data objects and object/program summaries.
  """

  def __init__(self, buf=None):
    """If buf is None, create a skeleton. Otherwise, read from buf."""
    self.pos = 0
    self.functions = []
    self.program_summaries = []
    self.module_infos = []

    if buf:
      self.buffer = buf
      # Convert the entire buffer to ints as store in an array.  This
      # is a bit more convenient and faster.
      self.int_array = array.array('I', self.buffer)
      self.n_ints = len(self.int_array)
      self.magic = self.ReadWord()
      self.version = self.ReadWord()
      self.stamp = self.ReadWord()
      if (self.magic == data_factory.DATA_MAGIC or
          self.magic == data_factory.NOTE_MAGIC):
        self.ReadObjects()
      else:
        print 'error: %X is not a known gcov magic' % self.magic
    else:
      self.buffer = None
      self.magic = 0
      self.version = 0
      self.stamp = 0

  def WriteToBuffer(self):
    """Return a string that contains the binary representation of the file."""
    self.pos = 0
    # When writing, accumulate written values in a list, then flatten
    # into a string.  This is _much_ faster than accumulating within a
    # string.
    self.buffer = []
    self.WriteWord(self.magic)
    self.WriteWord(self.version)
    self.WriteWord(self.stamp)
    for s in self.program_summaries:
      s.Write(self)
    for f in self.functions:
      f.Write(self)
    for m in self.module_infos:
      m.Write(self)
    self.WriteWord(0)  # EOF marker
    # Flatten buffer into a string.
    self.buffer = ''.join(self.buffer)
    return self.buffer

  def WriteWord(self, word):
    """Write one word - 32-bit integer to buffer."""
    self.buffer.append(struct.pack('I', word & 0xFFFFFFFF))

  def WriteWords(self, words):
    """Write a sequence of words to buffer."""
    for w in words:
      self.WriteWord(w)

  def WriteCounter(self, c):
    """Write one counter to buffer."""
    self.WriteWords((int(c), int(c >> 32)))

  def WriteCounters(self, counters):
    """Write a sequence of Counters to buffer."""
    for c in counters:
      self.WriteCounter(c)

  def WriteStr(self, s):
    """Write a string to buffer."""
    l = len(s)
    self.WriteWord((l + 4) / 4)  # Write length
    self.buffer.append(s)
    for _ in xrange(4 * ((l + 4) / 4) - l):
      self.buffer.append('\x00'[0])

  def ReadWord(self):
    """Read a word from buffer."""
    self.pos += 1
    return self.int_array[self.pos - 1]

  def ReadWords(self, n_words):
    """Read the specified number of words (n_words) from buffer."""
    self.pos += n_words
    return self.int_array[self.pos - n_words:self.pos]

  def ReadCounter(self):
    """Read a counter value from buffer."""
    v = self.ReadWord()
    return v | (self.ReadWord() << 32)

  def ReadCounters(self, n_counters):
    """Read the specified number of counter values from buffer."""
    words = self.ReadWords(2 * n_counters)
    return [words[2 * i] | (words[2 * i + 1] << 32) for i in xrange(n_counters)]

  def ReadStr(self):
    """Read a string from buffer."""
    length = self.ReadWord()
    if not length:
      return None
    # Read from the original string buffer to avoid having to convert
    # from int back to string.  The position counter is a count of
    # ints, so we need to multiply it by 4.
    ret = self.buffer[4 * self.pos: 4 * self.pos + 4 * length]
    self.pos += length
    return ret.rstrip('\x00')

  def ReadObjects(self):
    """Read and process all data objects from buffer."""
    function = None
    while self.pos < self.n_ints:
      obj = None
      tag = self.ReadWord()
      if not tag and self.program_summaries:
        break

      length = self.ReadWord()
      obj = data_factory.Create(self, tag, length)
      if obj:
        if tag == data_factory.TAG_FUNCTION:
          function = obj
          self.functions.append(function)
        elif tag == data_factory.TAG_PROGRAM_SUMMARY:
          self.program_summaries.append(obj)
        elif tag == data_factory.TAG_MODULE_INFO:
          self.module_infos.append(obj)
        else:
          # By default, all objects belong to the preceding function,
          # except for program summary or new function.
          function.counters.append(obj)
      else:
        print 'WARNING: unknown tag - 0x%X' % tag

  def PrintBrief(self):
    """Print the list of functions in the file."""
    print 'magic:   0x%X' % self.magic
    print 'version: 0x%X' % self.version
    print 'stamp:   0x%X' % self.stamp
    for function in self.functions:
      print '%d' % function.EntryCount()

  def Print(self):
    """Print the content of the file in full detail."""
    for function in self.functions:
      function.Print()
    for s in self.program_summaries:
      s.Print()
    for m in self.module_infos:
      m.Print()

  def MergeFiles(self, files, multipliers):
    """Merge ProfileDataFiles and return a merged file."""
    for f in files:
      assert self.version == f.version
      assert len(self.functions) == len(f.functions)

    for i in range(len(self.functions)):
      self.functions[i].Merge([f.functions[i] for f in files], multipliers)

    for i in range(len(self.program_summaries)):
      self.program_summaries[i].Merge([f.program_summaries[i] for f in files],
                                      multipliers)

    if self.module_infos:
      primary_module_id = self.module_infos[0].module_id
      module_group_ids = set(m.module_id for m in self.module_infos)
      for f in files:
        assert f.module_infos
        assert primary_module_id == f.module_infos[0].module_id
        assert ((f.module_infos[0].flags & 0x2) ==
                (self.module_infos[0].flags & 0x2))
        f.module_infos[0].flags |= self.module_infos[0].flags
        for m in f.module_infos:
          if m.module_id not in module_group_ids:
            module_group_ids.add(m.module_id)
            self.module_infos.append(m)


class OneImport(object):
  """Representation of one import for a primary module.

  Attributes:
    src: Text line naming the imported source module.
    gcda: Newline-terminated text line naming the corresponding .gcda file.
  """

  def __init__(self, src, gcda):
    self.src = src
    self.gcda = gcda
    # Validate explicitly instead of with "assert", which would be silently
    # stripped when running under "python -O".
    if not self.gcda.endswith('.gcda\n'):
      raise Error('Malformed import line - "%s"' % gcda)

  def GetLines(self):
    """Return the two text lines (src, gcda) for the import."""
    return [self.src, self.gcda]


class ImportsFile(object):
  """Representation of one .gcda.imports file.

  Attributes:
    filename: Relative path of the .imports file inside the archive.
    imports: List of OneImport entries, in file order.
  """

  def __init__(self, profile_archive, import_file):
    """Parse import_file out of profile_archive.

    Args:
      profile_archive: Archive (directory- or zip-based) holding the file.
      import_file: Relative path of the .imports file inside the archive.
    """
    self.filename = import_file
    if profile_archive.dir:
      # Context manager guarantees the file is closed even on error.
      with open(os.path.join(profile_archive.dir, import_file), 'rb') as f:
        lines = f.readlines()
    else:
      assert profile_archive.zip
      buf = profile_archive.zip.read(import_file)
      lines = []
      if buf:
        # Re-append the newline stripped by split() so both branches yield
        # identical newline-terminated lines.
        lines = [ln + '\n' for ln in buf.rstrip('\n').split('\n')]

    # Lines come in (src, gcda) pairs; an odd line count raises IndexError,
    # flagging a truncated imports file.
    self.imports = []
    for i in range(0, len(lines), 2):
      self.imports.append(OneImport(lines[i], lines[i + 1]))

  def MergeFiles(self, files):
    """Merge the imports of "files" into self, de-duplicated by src line."""
    seen = set(imp.src for imp in self.imports)
    for other in files:
      for imp in other.imports:
        if imp.src not in seen:
          seen.add(imp.src)
          self.imports.append(imp)

  def Write(self, datafile):
    """Write out to datafile as text lines."""
    lines = []
    for imp in self.imports:
      lines.extend(imp.GetLines())
    datafile.writelines(lines)

  def WriteToBuffer(self):
    """Return a string that contains the binary representation of the file."""
    self.pos = 0
    # Single join instead of quadratic string concatenation.
    self.buffer = ''.join(line for imp in self.imports
                          for line in imp.GetLines())
    return self.buffer

  def Print(self):
    """Print the filename followed by every import line."""
    print('Imports for %s\n' % self.filename)
    for imp in self.imports:
      for line in imp.GetLines():
        print(line)


class ProfileArchive(object):
  """A container for all gcda/gcno files under a directory (recursively).

  Attributes:
    gcda: A dictionary with the gcda file path as key.
          If the value is 0, it means the file exists in the archive
          but not yet read.
    gcno: A dictionary with the gcno file path as key.
    imports: A dictionary with the .imports file path as key.
    dir: A path to the directory containing the gcda/gcno.
         If set, the archive is a directory.
    zip: A ZipFile instance. If set, the archive is a zip file.

  ProfileArchive can be either a directory containing a directory tree
  containing gcda/gcno files, or a single zip file that contains
  the similar directory hierarchy.
  """

  def __init__(self, path):
    """Scan the archive at "path" for gcda/gcno/imports files.

    Args:
      path: Path to either a directory or a .zip file.

    Raises:
      Error: If path is neither a directory nor a .zip file.
    """
    self.gcda = {}
    self.gcno = {}
    self.imports = {}
    if os.path.isdir(path):
      self.dir = path
      self.zip = None
      self.ScanDir(path)
    elif path.endswith('.zip'):
      self.zip = zipfile.ZipFile(path)
      self.dir = None
      self.ScanZip()
    else:
      # Fail fast: previously self.dir/self.zip were left unset here and
      # surfaced later as a confusing AttributeError.
      raise Error('Unsupported archive - "%s"' % path)

  def ReadFile(self, path):
    """Read the content of the file and return it.

    Args:
      path: a relative path of the file inside the archive.

    Returns:
      Sequence of bytes containing the content of the file.

    Raises:
      Error: If file is not found.
    """
    if self.dir:
      return ReadAllAndClose(os.path.join(self.dir, path))
    elif self.zip:
      return self.zip.read(path)
    raise Error('File not found - "%s"' % path)

  def ScanZip(self):
    """Find all .gcda/.gcno/.imports files in the zip."""
    for f in self.zip.namelist():
      if f.endswith('.gcda'):
        self.gcda[f] = 0
      elif f.endswith('.gcno'):
        self.gcno[f] = 0
      elif f.endswith('.imports'):
        self.imports[f] = 0

  def ScanDir(self, direc):
    """Recursively visit all subdirs and find all .gcda/.gcno/.imports files."""
    # Record paths relative to direc (e.g. './sub/x.gcda') to save memory;
    # restore the working directory even if the walk raises.
    cwd = os.getcwd()
    os.chdir(direc)
    try:
      # os.walk replaces the Python-2-only os.path.walk and yields the
      # same './...'-relative dirpaths when walking '.'.
      for dirpath, _, filenames in os.walk('.'):
        for f in filenames:
          path = os.path.join(dirpath, f)
          if f.endswith('.gcda'):
            self.gcda[path] = 0
          elif f.endswith('.gcno'):
            self.gcno[path] = 0
          elif f.endswith('.imports'):
            self.imports[path] = 0
    finally:
      os.chdir(cwd)

  def ReadAll(self):
    """Read all gcda/gcno/imports files found inside the archive."""
    for f in self.gcda:
      self.gcda[f] = ProfileDataFile(self.ReadFile(f))
    for f in self.gcno:
      self.gcno[f] = ProfileDataFile(self.ReadFile(f))
    for f in self.imports:
      self.imports[f] = ImportsFile(self, f)

  def Print(self):
    """Print all files in full detail - including all counter values."""
    for data_file in self.gcda.values():
      data_file.Print()
    for data_file in self.gcno.values():
      data_file.Print()
    for imports_file in self.imports.values():
      imports_file.Print()

  def PrintBrief(self):
    """Print only the summary information without the counter values."""
    for data_file in self.gcda.values():
      data_file.PrintBrief()
    for data_file in self.gcno.values():
      data_file.PrintBrief()
    # NOTE(review): ImportsFile defines no PrintBrief in this file, so this
    # raises AttributeError when imports are present -- confirm intended.
    for imports_file in self.imports.values():
      imports_file.PrintBrief()

  def Write(self, output_path):
    """Write the archive to disk, as a zip file or a directory tree.

    Args:
      output_path: Destination; treated as a zip file if it ends in '.zip',
        otherwise as a directory (created if missing).
    """
    if output_path.endswith('.zip'):
      zip_out = zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED)
      for f in self.gcda:
        zip_out.writestr(f, self.gcda[f].WriteToBuffer())
      for f in self.imports:
        zip_out.writestr(f, self.imports[f].WriteToBuffer())
      zip_out.close()
    else:
      if not os.path.exists(output_path):
        os.makedirs(output_path)
      for f in self.gcda:
        outfile_path = os.path.join(output_path, f)
        out_dir = os.path.dirname(outfile_path)
        if not os.path.exists(out_dir):
          os.makedirs(out_dir)
        # "with" closes the file even if the write fails.
        with open(outfile_path, 'wb') as data_file:
          data_file.write(self.gcda[f].WriteToBuffer())
      for f in self.imports:
        outfile_path = os.path.join(output_path, f)
        out_dir = os.path.dirname(outfile_path)
        if not os.path.exists(out_dir):
          os.makedirs(out_dir)
        with open(outfile_path, 'wb') as data_file:
          self.imports[f].Write(data_file)

  def Merge(self, archives, multipliers):
    """Merge the gcda and imports data of "archives" into self.

    Args:
      archives: List of ProfileArchive objects to merge (may include self).
      multipliers: Per-archive scale factors for counter values.
    """
    # Read
    for a in archives:
      a.ReadAll()
    if self not in archives:
      self.ReadAll()

    # First create set of all gcda files
    all_gcda_files = set()
    for a in [self] + archives:
      all_gcda_files.update(a.gcda)

    # Iterate over all gcda files and create a merged object
    # containing all profile data which exists for this file
    # among self and archives.
    for gcda_file in all_gcda_files:
      files = []
      mults = []
      for i, a in enumerate(archives):
        if gcda_file in a.gcda:
          files.append(a.gcda[gcda_file])
          mults.append(multipliers[i])
      if gcda_file not in self.gcda:
        self.gcda[gcda_file] = files[0]
      self.gcda[gcda_file].MergeFiles(files, mults)

    # Same process for imports files
    all_imports_files = set()
    for a in [self] + archives:
      all_imports_files.update(a.imports)

    for imports_file in all_imports_files:
      files = [a.imports[imports_file]
               for a in archives if imports_file in a.imports]
      if imports_file not in self.imports:
        self.imports[imports_file] = files[0]
      self.imports[imports_file].MergeFiles(files)

  def ComputeHistogram(self):
    """Compute and return one histogram per summable counter kind."""
    # Build the histograms directly; the old code first created a throwaway
    # [[None]] * N list that was immediately overwritten.
    histogram = [Summary.Histogram()
                 for _ in range(DataObjectFactory.N_SUMMABLE)]

    for path in self.gcda:
      for func in self.gcda[path].functions:
        for n, counter_obj in enumerate(func.counters):
          if n < DataObjectFactory.N_SUMMABLE:
            for value in counter_obj.counters:
              histogram[n].Insert(value)
    for h in histogram:
      h.ComputeCntandBitvector()
    return histogram


def main():
  """Merge multiple profile data."""

  global new_histogram

  parser = OptionParser(
      'usage: %prog [options] <list of dirs/zip_files to be merged>')
  parser.add_option('-w', '--multipliers',
                    dest='multipliers',
                    help='Comma separated list of multipliers to be applied '
                    'for each corresponding profile.')
  parser.add_option('-o', '--output',
                    dest='output_profile',
                    help='Output directory or zip file to dump the '
                    'merged profile. Default output is profile-merged.zip.')

  options, args = parser.parse_args()

  if len(args) < 2:
    parser.error('Please provide at least 2 input profiles.')

  archives = [ProfileArchive(path) for path in args]

  if options.multipliers:
    weights = [long(w) for w in options.multipliers.split(',')]
    if len(weights) != len(archives):
      parser.error('--multipliers has different number of elements from '
                   '--inputs.')
  else:
    # Unweighted merge by default.
    weights = [1] * len(archives)

  output = options.output_profile or 'profile-merged.zip'

  # Merge everything into the first archive, then write it out.
  archives[0].Merge(archives, weights)
  new_histogram = archives[0].ComputeHistogram()
  archives[0].Write(output)

if __name__ == '__main__':
  main()
