function onset_data = SMIDIBT_route_onsets(send_file,params)
% [onset_data] = SMIDIBT_route_onsets(send_file,params)
%   Extracts the onset and offset times of the MIDI triggers for the send 
%   and read Arduino.
%
%   INPUTS:
%   send_file       The name of a ".wav" file whose name ends in "_send"
%                   (e.g., "trial01_send.wav") that contains the MIDI onset
%                   and offset triggers of the send Arduino. This file MUST
%                   have a complementary audio file in the same folder whose
%                   name ends in "_read" (e.g., "trial01_read.wav") to
%                   extract the MIDI onset and offset triggers of the read
%                   Arduino.
%
%   params          (optional) A struct containing the parameters for onset
%                   extraction (e.g., amplitude and timing thresholds).
%                   Missing fields are filled in with the parameters used
%                   in Schultz (2017, Experiment 2). Optional extra fields:
%                     out_dir  - folder in which "<name>_onsets.mat" is saved
%                     plot_dir - folder in which a .fig of both signals is
%                                saved
%
%   OUTPUT:
%   onset_data      The times (in milliseconds) of the MIDI onsets of the
%                   SMIDIB Send Arduino (first column), MIDI offsets of the 
%                   SMIDIB Send Arduino (second column), MIDI onsets of the
%                   SMIDIB Read Arduino (third column), and MIDI offsets of
%                   the SMIDIB Read Arduino (fourth column). One row per
%                   detected trigger; unused preallocated rows are removed.
%
%   This version has dependencies appended at the bottom.
%
%   2017-09-28 ben.schultz@maastrichtuniversity.nl
%   Copyright (c) 2017, Benjamin Schultz, Maastricht University.

%   This script is described in more detail in the publication:
%   Schultz, B. G. (submitted). The Schultz MIDI Benchmarking toolbox for 
%   MIDI interfaces, percussion pads, and sound cards. Behavior Research 
%   Methods.
%
%   The SMIDIB Toolbox is distributed in the hope that it will be useful, but
%   WITHOUT ANY WARRANTY; without even the implied warranty of
%   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
%   General Public License for more details.
% 
%   You should have received a copy of the GNU General Public License
%   along with the SMIDIB Toolbox; if not, write to the Free Software
%   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
%   02110-1301 USA
% 
%   See the file "COPYING" for the text of the license.


%% Check inputs
% check input parameters and set defaults
if ~exist('params','var'); params = struct; end
params = check_params(params);

% check file names: the file name (without extension) must end in '_send';
% the read file has the same folder/extension with '_send' -> '_read'.
% Building the name from its parts (rather than a string replacement on
% '_send.wav') guarantees that read_file never silently equals send_file.
[file_dir,send_name,file_ext] = fileparts(send_file);
send_suffix = '_send';
n_suffix = numel(send_suffix);
if numel(send_name) < n_suffix || ~strcmp(send_name(end-n_suffix+1:end),send_suffix)
    error('Filename does not have the correct format. Should end in ''_send''');
end
base_name = send_name(1:end-n_suffix);
read_file = fullfile(file_dir,[base_name,'_read',file_ext]);

if ~exist(send_file,'file') || ~exist(read_file,'file')
    error('One or more files do not exist in specified location');
end

% load send and read MIDI triggers
[send_data,send_Fs] = audioread(send_file);
[read_data,read_Fs] = audioread(read_file);

if send_Fs~=read_Fs
     error('Sample rates do not match. Files may not be synchronous.');
end

% invert both signals and scale so that the largest inverted sample is 1.
% NOTE(review): assumes single-channel recordings; with several channels
% max() returns a row vector and '/' becomes a matrix right-division.
send_data = (send_data*-1)/max(send_data*-1);
read_data = (read_data*-1)/max(read_data*-1);
read_data = filt_read_data(read_data,read_Fs);

% gap (in samples) added after each offset before searching for the next
% trigger
off_time_thresh_samps = round(params.off_time_thresh_ms*(send_Fs/1000));

% preallocate with NaN so that rows never filled can be dropped at the end
onset_data = nan(params.n_trigs,4);

prev_offset = 1;       % first send sample searched for the next trigger
prev_read_offset = 1;  % first read sample searched for the next trigger
cur_pos = 1;           % next row of onset_data to fill

while true
    % --- send Arduino ---
    % find() returns an index relative to the searched sub-vector, so the
    % absolute sample index is (relative index + start index - 1).
    cur_onset = find(send_data(prev_offset:end)>params.on_thresh,1,'first')+prev_offset-1;
    if isempty(cur_onset)
        break; % no further send triggers
    end

    cur_offset = find(abs(send_data(cur_onset+1:end))<params.off_thresh,1,'first')+cur_onset;
    if isempty(cur_offset)
        break; % last send trigger never drops below the offset threshold
    end

    % position of the first maximum between onset and offset
    [~,peak_idx] = max(send_data(cur_onset:cur_offset));
    cur_trig_offset = peak_idx+cur_onset-1;

    % step back while the signal is still rising towards cur_onset
    while cur_onset>1 && send_data(cur_onset-1)<send_data(cur_onset)
        cur_onset = cur_onset-1;
    end

    % --- read Arduino ---
    cur_read_onset = find(read_data(prev_read_offset:end)>params.read_on_thresh,1,'first')+prev_read_offset-1;
    if isempty(cur_read_onset)
        break; % no read trigger matching this send trigger
    end

    cur_read_offset = find(read_data(cur_read_onset+1:end)<params.read_off_thresh,1,'first')+cur_read_onset;
    if isempty(cur_read_offset)
        break; % last read trigger never drops below the offset threshold
    end

    % the next searches start after the (un-adjusted) offsets plus the gap
    prev_offset = cur_offset+off_time_thresh_samps;
    prev_read_offset = cur_read_offset+off_time_thresh_samps;

    % step back while the signal is still falling towards cur_read_offset
    while cur_read_offset>1 && read_data(cur_read_offset-1)>read_data(cur_read_offset)
        cur_read_offset = cur_read_offset-1;
    end

    % step back while the signal is still rising towards cur_read_onset
    while cur_read_onset>1 && read_data(cur_read_onset-1)<read_data(cur_read_onset)
        cur_read_onset = cur_read_onset-1;
    end

    % add onsets (in samples) to data output and move to next row
    onset_data(cur_pos,1:4) = [cur_onset,cur_trig_offset,cur_read_onset,cur_read_offset];
    cur_pos = cur_pos+1;
end

% remove unused (NaN) rows and convert sample indices to milliseconds
onset_data = onset_data(~isnan(onset_data(:,1)),:)/(send_Fs/1000);

%% Save onsets (recommended for large datasets)
if isfield(params,'out_dir')
    out_filename = [base_name,'_onsets.mat'];
    Fs = send_Fs; %#ok<NASGU> saved by name below
    save(fullfile(params.out_dir,out_filename),'onset_data','Fs');
end

%% Print plot
if isfield(params,'plot_dir')
    % separate time axes so recordings of different lengths can be plotted
    t_send = (1:numel(send_data))/send_Fs;
    t_read = (1:numel(read_data))/read_Fs;
    cur_fig = figure;
    plot(t_send,send_data);
    hold on;
    plot(t_read,read_data,'r');
    hold off;
    saveas(cur_fig,fullfile(params.plot_dir,base_name),'fig');
    close(cur_fig);
end

%% FUNCTIONS
function params = check_params(params)
% Fill in every missing field of PARAMS with its default value; fields the
% caller already set are left untouched. n_trigs is the expected number of
% onsets (used to preallocate the output).

defaults = struct( ...
    'on_thresh',          0.2, ...
    'off_thresh',         0.2, ...
    'read_on_thresh',     0.15, ...
    'read_off_thresh',    0.15, ...
    'off_time_thresh_ms', 0.1, ...
    'time_thresh_ms',     0.1, ...
    'n_trigs',            8002);

% walk the defaults in declaration order so new fields are appended in the
% same order as before
default_names = fieldnames(defaults);
for k = 1:numel(default_names)
    field_name = default_names{k};
    if ~isfield(params,field_name)
        params.(field_name) = defaults.(field_name);
    end
end

function read_data = filt_read_data(read_data,read_Fs)
% High-pass filter the read-Arduino signal (remove drift), then rescale it
% so that its largest absolute sample equals 1.
%   read_data   audio samples of the read Arduino
%               (NOTE(review): assumed single channel -- with several
%               columns the final '/' is a matrix right-division)
%   read_Fs     sample rate in Hz

cutoff_hz = 50;
nyquist_hz = read_Fs/2;
[num,den] = butter(1,cutoff_hz/nyquist_hz,'high');  % 1st-order IIR high-pass

% filtfilt runs the filter forward and backward (zero-phase filtering)
filtered = filtfilt(num,den,read_data);
read_data = filtered/max(abs(filtered));