from distutils.version import LooseVersion
import logging
import warnings

from nengo.builder.processes import SimProcess
from nengo.exceptions import SimulationError
from nengo.synapses import Lowpass, LinearFilter
from nengo.utils.filter_design import (cont2discrete, tf2ss, ss2tf,
                                       BadCoefficients)
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_sparse_ops

from nengo_dl import utils
from nengo_dl.builder import Builder, OpBuilder

logger = logging.getLogger(__name__)


@Builder.register(SimProcess)
class SimProcessBuilder(OpBuilder):
    """Builds a group of :class:`~nengo:nengo.builder.processes.SimProcess`
    operators.
    Calls the appropriate sub-build class for the different process types.
    Attributes
    ----------
    TF_PROCESS_IMPL : list of :class:`~nengo:nengo.Process`
        The process types that have a custom implementation
    """
    TF_PROCESS_IMPL = (Lowpass, LinearFilter)

    def __init__(self, ops, signals):
        super(SimProcessBuilder, self).__init__(ops, signals)
        logger.debug("process %s", [op.process for op in ops])
        logger.debug("input %s", [op.input for op in ops])
        logger.debug("output %s", [op.output for op in ops])
        logger.debug("t %s", [op.t for op in ops])
        # if we have a custom tensorflow implementation for this process type,
        # then we build that. otherwise we'll execute the process step
        # function externally (using `tf.py_func`), so we just need to set up
        # the inputs/outputs for that.
        if isinstance(ops[0].process, self.TF_PROCESS_IMPL):
            # note: even though the isinstance check above makes these checks
            # partially redundant, we check the specific types here to make
            # sure that TF_PROCESS_IMPL is kept up to date; Lowpass is
            # checked by exact type because it is a subclass of LinearFilter
            if type(ops[0].process) == Lowpass:
                self.built_process = LowpassBuilder(ops, signals)
            elif isinstance(ops[0].process, LinearFilter):
                self.built_process = LinearFilterBuilder(ops, signals)
        else:
            self.built_process = GenericProcessBuilder(ops, signals)

    def build_step(self, signals):
        self.built_process.build_step(signals)

    def build_post(self, ops, signals, sess, rng):
        if isinstance(self.built_process, GenericProcessBuilder):
            self.built_process.build_post(ops, signals, sess, rng)


class GenericProcessBuilder(OpBuilder):
    """Builds all process types for which there is no custom TensorFlow
    implementation.

    Notes
    -----
    These will be executed as native Python functions, requiring execution to
    move in and out of TensorFlow.  This can significantly slow down the
    simulation, so any performance-critical processes should consider
    adding a custom TensorFlow implementation for their type instead.
    """

    def __init__(self, ops, signals):
        super(GenericProcessBuilder, self).__init__(ops, signals)
        self.input_data = (None if ops[0].input is None else
                           signals.combine([op.input for op in ops]))
        self.output_data = signals.combine([op.output for op in ops])
        self.output_shape = self.output_data.shape + (signals.minibatch_size,)
        self.mode = "inc" if ops[0].mode == "inc" else "update"
        self.prev_result = []
        # placeholders for the step function of each process, indexed as
        # step_fs[op][minibatch_item]; these will be built later, in
        # `build_post`
        self.step_fs = [[None for _ in range(signals.minibatch_size)]
                        for _ in ops]
        # `merged_func` calls the step function for each process and
        # combines the result
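        # note: `align_func` converts the value returned by `merged_func`
        # into an ndarray with the expected shape and dtype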
        @utils.align_func(self.output_shape, self.output_data.dtype)
        def merged_func(time, input):  # pragma: no cover
            if any(x is None for a in self.step_fs for x in a):
                raise SimulationError(
                    "build_post has not been called for %s" % self)
            input_offset = 0
            func_output = []
            for i, op in enumerate(ops):
                if op.input is not None:
                    input_shape = op.input.shape[0]
                    func_input = input[input_offset:input_offset + input_shape]
                    input_offset += input_shape
                mini_out = []
                for j in range(signals.minibatch_size):
                    x = [] if op.input is None else [func_input[..., j]]
                    mini_out += [self.step_fs[i][j](*([time] + x))]
                func_output += [np.stack(mini_out, axis=-1)]
            return np.concatenate(func_output, axis=0)
        self.merged_func = merged_func
        self.merged_func.__name__ = utils.sanitize_name(
            "_".join([type(op.process).__name__ for op in ops]))

    def build_step(self, signals):
        input = ([] if self.input_data is None
                 else signals.gather(self.input_data))
        # note: we need to make sure that the previous call to this function
        # has completed before the next starts, since we don't know that the
        # functions are thread safe
        with tf.control_dependencies(self.prev_result), tf.device("/cpu:0"):
            result = tf.py_func(
                self.merged_func, [signals.time, input],
                self.output_data.dtype, name=self.merged_func.__name__)
        result.set_shape(self.output_shape)
        self.prev_result = [result]
        signals.scatter(self.output_data, result, mode=self.mode) 

    def build_post(self, ops, signals, sess, rng):
        for i, op in enumerate(ops):
            for j in range(signals.minibatch_size):
                self.step_fs[i][j] = op.process.make_step(
                    op.input.shape if op.input is not None else (0,),
                    op.output.shape, signals.dt_val, op.process.get_rng(rng))  


class LowpassBuilder(OpBuilder):
    """Build a group of :class:`~nengo:nengo.Lowpass` synapse operators."""

    def __init__(self, ops, signals):
        super(LowpassBuilder, self).__init__(ops, signals)
        self.input_data = signals.combine([op.input for op in ops])
        self.output_data = signals.combine([op.output for op in ops])
        nums = []
        dens = []
        for op in ops:
            # when tau is very small the filter is effectively a passthrough,
            # so we just use num=1, den=0 (this mirrors the equivalent check
            # in `nengo.Lowpass.make_step`)
            if op.process.tau <= 0.03 * signals.dt_val:
                num = 1
                den = 0
            else:
                num, den, _ = cont2discrete((op.process.num, op.process.den),
                                            signals.dt_val, method="zoh")
                num = num.flatten()
                num = num[1:] if num[0] == 0 else num
                assert len(num) == 1
                num = num[0]
                assert len(den) == 2
                den = den[1]
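                # for reference: for a lowpass filter H(s) = 1 / (tau*s + 1),
                # the zero-order hold discretization gives
                #     num = 1 - exp(-dt/tau)
                #     den = -exp(-dt/tau)
                # so (after the sign flip below) build_step computes
                #     y[t] = exp(-dt/tau)*y[t-1] + (1 - exp(-dt/tau))*x[t]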
            nums += [num] * op.input.shape[0]
            dens += [den] * op.input.shape[0]
        nums = np.asarray(nums)
        while nums.ndim < len(self.input_data.full_shape):
            nums = np.expand_dims(nums, -1)
        # note: we fold the sign flip of the feedback coefficient in here, so
        # that build_step can compute `dens * output + nums * input` directly
        dens = -np.asarray(dens)
        while dens.ndim < len(self.input_data.full_shape):
            dens = np.expand_dims(dens, -1)
        # need to manually broadcast for scatter_mul
        # dens = np.tile(dens, (1, signals.minibatch_size))
        self.nums = signals.constant(nums, dtype=self.output_data.dtype)
        self.dens = signals.constant(dens, dtype=self.output_data.dtype)
        # create a variable to represent the internal state of the filter
        # self.state_sig = signals.make_internal(
        #     "state", self.output_data.shape)

    def build_step(self, signals):
        # signals.scatter(self.output_data, self.dens, mode="mul")
        # input = signals.gather(self.input_data)
        # signals.scatter(self.output_data, self.nums * input, mode="inc")
        input = signals.gather(self.input_data)
        output = signals.gather(self.output_data)
        signals.scatter(self.output_data,
                        self.dens * output + self.nums * input)
        # method using _step
        # note: this build_step function doesn't use _step for efficiency
        # reasons (we can avoid an extra scatter by reusing the output signal
        # as the state signal)
        # input = signals.gather(self.input_data)
        # prev_state = signals.gather(self.state_sig)
        # new_state = self.dens * prev_state + self.nums * input
        # signals.scatter(self.state_sig, new_state)
        # signals.scatter(self.output_data, new_state)


class LinearFilterBuilder(OpBuilder):
    """Build a group of :class:`~nengo:nengo.LinearFilter` synapse
    operators."""

    def __init__(self, ops, signals):
        super(LinearFilterBuilder, self).__init__(ops, signals)
        self.input_data = signals.combine([op.input for op in ops])
        self.output_data = signals.combine([op.output for op in ops])
        self.n_ops = len(ops)
        self.signal_d = ops[0].input.shape[0]
        As = []
        Cs = []
        Ds = []
        # compute the A/C/D matrices for each operator
        for op in ops:
            A, B, C, D = tf2ss(op.process.num, op.process.den)
            if op.process.analog:
                # convert to discrete system
                A, B, C, D, _ = cont2discrete((A, B, C, D), signals.dt_val,
                                              method="zoh")
            # convert to controllable form
            num, den = ss2tf(A, B, C, D)
            if op.process.analog:
                # add a one-timestep shift (appending a zero to the
                # numerator coefficients multiplies the numerator by z)
                num = np.concatenate((num, [[0]]), axis=1)
            with warnings.catch_warnings():
                # ignore the warning about B, since we aren't using it anyway
                warnings.simplefilter("ignore", BadCoefficients)
                A, _, C, D = tf2ss(num, den)
            As.append(A)
            Cs.append(C[0])
            Ds.append(D.item())
        self.state_d = sum(x.shape[0] for x in Cs)
        # build a sparse matrix containing the A matrices as blocks
        # along the diagonal
        sparse_indices = []
        corner = np.zeros(2, dtype=np.int64)
        for A in As:
            idxs = np.reshape(np.dstack(np.meshgrid(
                np.arange(A.shape[0]), np.arange(A.shape[1]),
                indexing="ij")), (-1, 2))
            idxs += corner
            corner += A.shape
            sparse_indices += [idxs]
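        # e.g., for two 2x2 A matrices this produces the indices of the
        # entries of the diagonal blocks in the 4x4 block-diagonal matrix
        #     [[A0  0]
        #      [ 0 A1]]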
        sparse_indices = np.concatenate(sparse_indices, axis=0)
        self.A = signals.constant(np.concatenate(As, axis=0).flatten(),
                                  dtype=signals.dtype)
        self.A_indices = signals.constant(sparse_indices, dtype=(
            tf.int32 if np.all(sparse_indices < np.iinfo(np.int32).max)
            else tf.int64))
        self.A_shape = tf.constant(corner, dtype=tf.int64)
        if np.allclose(Cs, 0):
            self.C = None
        else:
            # add empty dimension for broadcasting
            self.C = signals.constant(np.concatenate(Cs)[:, None],
                                      dtype=signals.dtype)
        if np.allclose(Ds, 0):
            self.D = None
        else:
            # add empty dimension for broadcasting
            self.D = signals.constant(np.asarray(Ds)[:, None],
                                      dtype=signals.dtype)
        self.offsets = tf.expand_dims(
            tf.range(0, len(ops) * As[0].shape[0], As[0].shape[0]),
            axis=1)
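        # note: `tf2ss` returns systems in controllable canonical form, where
        # B = [1, 0, ..., 0]^T, so applying B*u amounts to adding the input
        # to the first element of each op's state block; `offsets` records
        # the row at which each block begins (this assumes every op in the
        # group has the same state dimension, As[0].shape[0])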
        # create a variable to represent the internal state of the filter
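        # the state is stored as a (state_d, minibatch_size * signal_d)
        # tensor, with the per-op state blocks concatenated along the first
        # axis and the minibatch/signal dimensions flattened along the second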
        self.state_sig = signals.make_internal(
            "state", (self.state_d, signals.minibatch_size * self.signal_d),
            minibatched=False)

    def build_step(self, signals):
        input = signals.gather(self.input_data)
        input = tf.reshape(input, (self.n_ops, -1))
        state = signals.gather(self.state_sig)
        # compute output
        if self.C is None:
            output = tf.zeros_like(input)
        else:
            output = state * self.C
            output = tf.reshape(
                output,
                (self.n_ops, -1, signals.minibatch_size * self.signal_d))
            output = tf.reduce_sum(output, axis=1)
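        # note: the reshape/reduce_sum above computes the dot product C.x
        # separately for each op's state block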
        if self.D is not None:
            output += self.D * input
        signals.scatter(self.output_data, output)
        # update state
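        # note: the generated sparse matmul op lost its leading underscore in
        # TensorFlow 1.7, so we pick the correct name based on the version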
        if LooseVersion(tf.__version__) < LooseVersion("1.7.0"):
            mat_mul = gen_sparse_ops._sparse_tensor_dense_mat_mul
        else:
            mat_mul = gen_sparse_ops.sparse_tensor_dense_mat_mul
        r = mat_mul(self.A_indices, self.A, self.A_shape, state)
        with tf.control_dependencies([output]):
            state = r + tf.scatter_nd(self.offsets, input,
                                      self.state_sig.shape)
            # TODO: tensorflow does not yet support sparse_tensor_dense_add
            # on the GPU
            # state = gen_sparse_ops._sparse_tensor_dense_add(
            #     self.offsets, input, self.state_sig.shape, r)
        state.set_shape(self.state_sig.shape)
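        # mark the input/state reads so that the scatter below is correctly
        # sequenced after them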
        signals.mark_gather(self.input_data)
        signals.mark_gather(self.state_sig)
        signals.scatter(self.state_sig, state)