分类预测 | MATLAB实现WOA-CNN鲸鱼算法优化卷积神经网络数据分类预测

分类预测 | MATLAB实现WOA-CNN-LSTM鲸鱼算法优化卷积长短期记忆网络数据分类预测

目录

分类效果

1
2
3
4
5

基本描述

1.Matlab实现WOA-CNN多特征分类预测,多特征输入模型,运行环境Matlab2018b及以上;
2.基于鲸鱼算法(WOA)优化卷积神经网络(CNN)分类预测,优化参数为,学习率,批处理,正则化参数;
3.多特征输入单输出的二分类及多分类模型。程序内注释详细,直接替换数据就可以用;
程序语言为matlab,程序可出分类效果图,迭代优化图,混淆矩阵图;
4.data为数据集,输入12个特征,分四类;main为主程序,其余为函数文件,无需运行,可在下载区获取数据和程序内容。

程序设计

%%  优化算法参数设置
SearchAgents_no = 3;                  % 数量
Max_iteration = 5;                    % 最大迭代次数
dim = 3;                              % 优化参数个数

 
%% 建立模型
lgraph = [
 
 convolution2dLayer([1, 1], 32)  % 卷积核大小 3*1 生成32张特征图
 batchNormalizationLayer         % 批归一化层
 reluLayer                       % Relu激活层

 dropoutLayer(0.2)               % Dropout层
 fullyConnectedLayer(num_class, "Name", "fc")                     % 全连接层
 softmaxLayer("Name", "softmax")                                  % softmax激活层
 classificationLayer("Name", "classification")];                  % 分类层




%% 参数设置
options = trainingOptions('adam', ...     % Adam 梯度下降算法
    'MaxEpochs', 10,...                 % 最大训练次数 
    'MiniBatchSize',best_hd, ...
    'InitialLearnRate', best_lr,...          % 初始学习率为0.001
    'L2Regularization', best_l2,...         % L2正则化参数
    'LearnRateSchedule', 'piecewise',...  % 学习率下降
    'LearnRateDropFactor', 0.1,...        % 学习率下降因子 0.1
    'LearnRateDropPeriod', 400,...        % 经过800次训练后 学习率
%% 训练
net = trainNetwork(p_train, t_train, lgraph, options);

%% 预测
t_sim1 = predict(net, p_train); 
t_sim2 = predict(net, p_test ); 
%_________________________________________________________________________%
%  Whale Optimization Algorithm (WOA) source codes demo 1.0               
% The Whale Optimization Algorithm
function [Best_Cost,Best_pos,curve]=WOA(pop,Max_iter,lb,ub,dim,fobj)

% initialize position vector and score for the leader
Best_pos=zeros(1,dim);
Best_Cost=inf; %change this to -inf for maximization problems


%Initialize the positions of search agents
Positions=initialization(pop,dim,ub,lb);

curve=zeros(1,Max_iter);

t=0;% Loop counter

% Main loop
while t<Max_iter
    for i=1:size(Positions,1)
        
        % Return back the search agents that go beyond the boundaries of the search space
        Flag4ub=Positions(i,:)>ub;
        Flag4lb=Positions(i,:)<lb;
        Positions(i,:)=(Positions(i,:).*(~(Flag4ub+Flag4lb)))+ub.*Flag4ub+lb.*Flag4lb;
        
        % Calculate objective function for each search agent
        fitness=fobj(Positions(i,:));
        
        % Update the leader
        if fitness<Best_Cost % Change this to > for maximization problem
            Best_Cost=fitness; % Update alpha
            Best_pos=Positions(i,:);
        end
        
    end
    
    a=2-t*((2)/Max_iter); % a decreases linearly fron 2 to 0 in Eq. (2.3)
    
    % a2 linearly dicreases from -1 to -2 to calculate t in Eq. (3.12)
    a2=-1+t*((-1)/Max_iter);
    
    % Update the Position of search agents 
    for i=1:size(Positions,1)
        r1=rand(); % r1 is a random number in [0,1]
        r2=rand(); % r2 is a random number in [0,1]
        
        A=2*a*r1-a;  % Eq. (2.3) in the paper
        C=2*r2;      % Eq. (2.4) in the paper
        
        
        b=1;               %  parameters in Eq. (2.5)
        l=(a2-1)*rand+1;   %  parameters in Eq. (2.5)
        
        p = rand();        % p in Eq. (2.6)
        
        for j=1:size(Positions,2)
            
            if p<0.5   
                if abs(A)>=1
                    rand_leader_index = floor(pop*rand()+1);
                    X_rand = Positions(rand_leader_index, :);
                    D_X_rand=abs(C*X_rand(j)-Positions(i,j)); % Eq. (2.7)
                    Positions(i,j)=X_rand(j)-A*D_X_rand;      % Eq. (2.8)
                    
                elseif abs(A)<1
                    D_Leader=abs(C*Best_pos(j)-Positions(i,j)); % Eq. (2.1)
                    Positions(i,j)=Best_pos(j)-A*D_Leader;      % Eq. (2.2)
                end
                
            elseif p>=0.5
              
                distance2Leader=abs(Best_pos(j)-Positions(i,j));
                % Eq. (2.5)
                Positions(i,j)=distance2Leader*exp(b.*l).*cos(l.*2*pi)+Best_pos(j);
                
            end
            
        end
    end
    t=t+1;
    curve(t)=Best_Cost;
    [t Best_Cost]
end

参考资料

[1] https://blog.csdn.net/kjm13182345320/article/details/129036772?spm=1001.2014.3001.5502
[2] https://blog.csdn.net/kjm13182345320/article/details/128690229


http://www.niftyadmin.cn/n/273348.html

相关文章

opencv-python加载pytorch训练好的onnx格式线性回归模型

opencv是一个开源的图形库&#xff0c;有针对java,c,python的库依赖&#xff0c;它本身对模型训练支持的不好&#xff0c;但是可以加载其他框架训练的模型来进行预测。 这里举一个最简单的线性回归的例子&#xff0c;使用深度学习框架pytorch训练模型&#xff0c;最后保存模型为…

【Linux】进程信号 --- 信号产生 信号递达和阻塞 信号捕捉

&#x1f34e;作者&#xff1a;阿润菜菜 &#x1f4d6;专栏&#xff1a;Linux系统编程 文章目录 一、预备知识二、信号产生1. 通过终端按键产生信号1.1 signal()1.2 core dump标志位、核心存储文件 2.通过系统调用向进程发送信号3.由软件条件产生信号3.1 alarm函数和SIGALRM信号…

Linux云服务器的使用,以及运行Python程序、相关Linux指令

目录 1、使用Linux云服务器的软件 2、Linux系统运行Python程序 3、Linux系统查看包、虚拟环境、安装包等 以下几个深度学习服务器都不错&#xff1a;智星云、AutoDL、恒源云 1、使用Linux云服务器的软件 MobaXterm_Personal 推荐MobaXterm_Personal mobaxterm是一款方便网站…

【2023/04/21-04/28】回溯算法

学习链接&#xff1a; 回溯算法解题套路框架回溯算法秒杀所有排列-组合-子集问题一文秒杀所有岛屿题目 1.分割回文串 题目来源&#xff1a;131.分割回文串 题解&#xff1a; class Solution { public:vector<vector<int>> f;vector<vector<string>&g…

C#基础(转义字符)

什么是转义字符 它是字符串的一部分 用来表示一些特殊含义的字符 比如&#xff1a;在字符串中表现 单引号 引号 空行等等 写法 固定写法&#xff1a; \字符 常用的转义字符单引号\双引号\"换行\n斜杠 \\ 计算机文件路径是要用到\符号的 不常用的转义字符&…

外卖项目优化-01-redis缓存短信验证码、菜品数据

文章目录 外卖项目优化-01课程内容前言1. 环境搭建1.1 版本控制解决branch和tag命名冲突 1.2 环境准备 2. 缓存短信验证码2.1 思路分析2.2 代码改造2.3 功能测试 3. 缓存菜品信息3.1 实现思路3.2 代码改造3.2.1 查询菜品缓存3.2.2 清理菜品缓存 3.3 功能测试3.4 提交并推送代码…

abc 283E 经典dp

题意&#xff1a;https://www.luogu.com.cn/problem/AT_abc283_e 思路&#xff1a;非常经典的dp&#xff0c;设为前i行第i行是否反转和第i1行是否反转。 /*keep on going and never give up*/ #include<cstdio> #include<iostream> #include<queue> #inclu…

【Linux - Shell常用命令】- 判断文件是否存在、去掉文件后缀

目录 一、判断文件是否存在1.1 判断目录是否存在1.2 判断文件是否存在1.3 其他文件类型判断 二、字符串截取&#xff08;去掉文件后缀&#xff09;2.1 获取文件后缀2.2 获取文件前缀 一、判断文件是否存在 1.1 判断目录是否存在 将下面代码保存为dirExist.sh &#xff0c;运行…