# NeuroPype Sample Code 

import logging

from ...engine import *
from ...utilities.cloud import storage

logger = logging.getLogger(__name__)


class ImportSET(Node):
    """Import data from an EEGLAB set file.

    The node loads the file named by ``filename`` (optionally fetched from
    cloud storage first) and emits its content as a single non-streaming
    packet on the first update after a port has been assigned. On every
    other update the output is ``None``.
    """

    # --- Input/output ports ---
    data = DataPort(Packet, "Output signal.", OUT)

    # --- Properties ---
    filename = StringPort("", """Name of the file to import (set). If a
        relative path is given, a file of this name will be looked up in the
        standard data directories (these include cpe/resources and examples/).
        """, is_filename=True)

    # options for cloud-hosted files
    cloud_host = EnumPort("Default", ["Default", "Azure", "S3", "Google",
                                      "Local", "None"], """Cloud storage host to
        use (if any). You can override this option to select from what kind of
        cloud storage service data should be downloaded. On some environments
        (e.g., on NeuroScale), the value Default will be map to the default
        storage provider on that environment.""")
    cloud_account = StringPort("", """Cloud account name on storage provider
        (use default if omitted). You can override this to choose a non-default
        account name for some storage provider (e.g., Azure or S3.). On some
        environments (e.g., on NeuroScale), this value will be
        default-initialized to your account.""")
    cloud_bucket = StringPort("", """Cloud bucket to read from (use default if
        omitted). This is the bucket or container on the cloud storage provider
        that the file would be read from. On some environments (e.g., on
        NeuroScale), this value will be default-initialized to a bucket
        that has been created for you.""")
    cloud_credentials = StringPort("", """Secure credential to access cloud data
        (use default if omitted). These are the security credentials (e.g.,
        password or access token) for the the cloud storage provider. On some
        environments (e.g., on NeuroScale), this value will be
        default-initialized to the right credentials for you.""")

    def __init__(self, **kwargs):
        """Create a new node. Accepts initial values for the ports."""
        # must exist before the base constructor runs, since assigning the
        # initial port values may trigger on_port_assigned()
        self._has_emitted = False
        super().__init__(**kwargs)

    @classmethod
    def description(cls):
        """Declare descriptive information about the node"""
        return Description(name='Import SET',
                           description="""\
                           Import data from an EEGLAB set file. This format
                           is a commonly-used interchange file format for EEG,
                           for which converters from most other EEG file
                           formats exist. This node will import the
                           respective data, which is assumed to be continuous
                           (non-epoched) EEG/EXG, optionally with event
                           markers. The node outputs the entire data in a
                           single large packet on the first update, so any
                           processing applied to the result will be "offline"
                           or batch" style processing on data that isn't
                           streaming (consequently, the packet is flagged as
                           non-streaming). However, if you intend to simulate
                           online processing, it is possible to chain a
                           Stream Data node after this one, which will take
                           the imported recording and play it back in a
                           streaming fashion (that is, in small chunks, one at a
                           time).

                           Technically, the packet generated by this node is
                           formatted as follows: the first stream is called
                           'eeg' and holds the EEG/EXG data, and the packet
                           contains a single chunk for this stream with a
                           2d array with a space axis (channels) and a time
                           axis (samples). If the data had markers, a second
                           stream named 'markers' is also included in the
                           packet, which has a vector of numeric data (one
                           per marker, initially all NaN) and a single instance
                           axis that indexes the marker instances and whose
                           entry are associated each with a timestamp and a
                           string payload that is the respective event/marker
                           type from the .set file. The numeric data can be
                           overridden based on the event type string using the
                           Assign Targets node, which is required for
                           segmentation and supervised machine learning.""",
                           url='https://en.wikipedia.org/wiki/EEGLAB',
                           version='1.0.0', status=DevStatus.production)

    @Node.update.setter
    def update(self, v):
        """Load the configured file and place the packet on the output, once.

        :param v: value assigned to the update property (not used here)
        """
        # hack to ensure that the data is only loaded/propagated once; will
        # be replaced by proper management of the "dirty" state
        if not self.filename or self._has_emitted:
            self._data = None
            return
        self._has_emitted = True

        # the cloud helper returns the name of the file to actually read;
        # the user-specified name is kept as the alias for logging/source URL
        filename = storage.cloud_get(self.filename, host=self.cloud_host,
                                     account=self.cloud_account,
                                     bucket=self.cloud_bucket,
                                     credentials=self.cloud_credentials)
        self._data = self.import_set(filename, filealias=self.filename)
        logger.info("Imported file %s.", self.filename)

    def on_port_assigned(self):
        """Callback to reset internal state when a value was assigned to a
        port (unless the port's setter has been overridden)."""
        self._has_emitted = False
        self.signal_changed(True)

    def is_finished(self):
        """Whether this node is finished producing output."""
        return not self.filename or self._has_emitted

    @staticmethod
    def import_set(filename, filealias=None):
        """Read a continuous (non-epoched) EEGLAB .set file into a packet.

        Declared as a static method because it uses no instance state; it can
        be called as ``self.import_set(...)`` or ``ImportSET.import_set(...)``.

        :param filename: Name of the file to import (set). If a
            relative path is given, a file of this name will be looked up in the
            standard data directories (these include cpe/resources and examples/).
        :param filealias: Optionally an alias for the file used for e.g., logging
            and the source URL (e.g., if filename points to a temp file)

        :return: packet read from set file; it holds an 'eeg' chunk and, if
            the file has events, a 'markers' chunk
        :raises RuntimeError: if the file reports more than one trial
        """
        # Local imports keep this function self-contained; the file has no
        # explicit module-level import for these names.
        import os
        from collections import OrderedDict

        import numpy as np
        # public import location; the scipy.io.matlab.mio module path is
        # deprecated in recent SciPy releases
        from scipy.io import loadmat

        if filealias is None:
            filealias = filename

        logger.info("Importing SET file %s...", filealias)
        filename = resolve_filename(filename)
        header = loadmat(filename)
        eeg = header['EEG']
        srate = float(eeg['srate'][0][0][0][0])
        xmin = float(eeg['xmin'][0][0][0][0])
        pnts = int(eeg['pnts'][0][0][0][0])
        nbchan = int(eeg['nbchan'][0][0][0][0])
        data = eeg['data'][0][0]

        # optional per-sample time stamps
        try:
            sample_times = np.asanyarray(eeg['times'][0][0][0], dtype=float)
            # a median sample spacing of ~1000/srate means the stamps are in
            # milliseconds; rescale to seconds
            if np.abs(np.median(np.diff(sample_times)) * srate - 1000) < 10:
                logger.info("Fixing time-stamp scale factor (ms).")
                sample_times = sample_times / 1000.0
        except (KeyError, IndexError, ValueError):
            # ValueError covers a missing field on a NumPy structured array
            sample_times = None

        # sanity check
        trials = int(eeg['trials'][0][0][0][0])
        if trials > 1:
            raise RuntimeError('the given .set file %s is segmented; cannot '
                               'import.' % filealias)

        # parse data
        if isinstance(data[0][0], str):
            # the data field names a companion .fdt file next to the .set
            # file; load the raw little-endian float32 samples from there
            # noinspection PyTypeChecker
            path = os.path.join(os.path.dirname(filename), str(data[0]))
            raw = np.fromfile(path, dtype=np.dtype('<f'), sep='',
                              count=pnts * nbchan)
            data = raw.astype('float').reshape((pnts, nbchan)).T

        # parse chaninfo
        try:
            nosedir = str(eeg['chaninfo'][0][0][0]['nosedir'][0][0])
        except Exception:
            logger.info("ImportSET: nosedir property missing from file; assuming +X")
            nosedir = '+X'
        try:
            ls = str(eeg['chaninfo'][0][0][0]['labelscheme'][0][0])
        except Exception:
            ls = 'unknown'

        # parse chanlocs
        try:
            locs = eeg['chanlocs'][0][0]
        except Exception:
            locs = None
        if (locs is not None and locs.ndim == 2 and locs.shape[0] == 1
                and locs.shape[1] > 1):
            locs = locs.T  # for some .set files this is swapped
        chn_labels = [None] * nbchan
        chn_pos = np.zeros((nbchan, 3))
        for c in range(nbchan):
            # label; any lookup failure (including locs being None) falls
            # back to a generated channel name
            try:
                label = locs[c]['labels'][0]
                chn_labels[c] = str(label if not label.ndim else label[0])
            except Exception:
                chn_labels[c] = 'ch' + str(c)

            # position; a channel without usable coordinates gets NaNs
            try:
                chn_pos[c] = [float(locs[c][coord][0][0])
                              for coord in ('X', 'Y', 'Z')]
            except Exception:
                chn_pos[c] = np.nan

        # parse markers
        if eeg['event'][0][0].size:
            event = eeg['event'][0][0][0]
            event_types = [None] * event.size
            event_times = np.zeros(event.size)
            for e in range(event.size):
                try:
                    event_types[e] = str(event[e]['type'][0])
                except Exception:
                    event_types[e] = None
                tmplatency = float(event[e]['latency'][0][0])
                # latencies are 1-based sample indices; convert to seconds
                # relative to the recording's start time
                event_times[e] = ((tmplatency - 1) / srate + xmin)
        else:
            # no markers present
            event_types = []
            event_times = np.array([])
        del header, eeg

        data = np.asanyarray(data, dtype=float)

        # optionally internalize coordinates
        orientations = {
            '+Y': dict(x='right', y='front', z='up'),  # assume internal system
            '+X': dict(x='front', y='left', z='up'),
            '-X': dict(x='back', y='right', z='up'),
        }
        orientation = orientations.get(nosedir)
        if orientation is None:
            warn_once("Can not infer the nose direction from input file.",
                      logger=logger)
        else:
            chn_pos = internalize_coordinates(chn_pos, unit='guess',
                                              **orientation)

        # only compare lengths when time stamps were actually read
        if sample_times is not None and pnts != len(sample_times):
            logger.info("Discarding inconsistently-sized sample times.")
            sample_times = None

        # construct the packet
        time_args = dict(nominal_rate=srate, init_time=xmin, num_times=pnts)
        if sample_times is not None:  # EEG includes a usable .times field
            time_args['times'] = sample_times
        data_axes = (SpaceAxis(names=chn_labels, naming_system=ls,
                               positions=chn_pos, coordinate_system=nosedir),
                     TimeAxis(**time_args))
        data_block = Block(data=data, axes=data_axes)
        data_props = {Flags.is_streaming: False, Flags.is_signal: True,
                      Origin.modality: 'EEG',
                      Origin.source_url: 'file://' + filealias}
        chunks = OrderedDict([('eeg', Chunk(block=data_block,
                                            props=data_props))])

        if event_types:
            marker_axes = (InstanceAxis(times=event_times, data=event_types,
                                        instance_type=InstanceType.marker),)
            marker_block = Block(data=np.full(event_times.size, np.nan),
                                 axes=marker_axes)
            marker_props = {Flags.is_streaming: False, Flags.has_markers: True}
            chunks['markers'] = Chunk(block=marker_block, props=marker_props)

        return Packet(chunks)

"""Copyright (C) 2014-2020 Intheon. All rights reserved."""