《算法导论》——矩阵乘法的Strassen算法
前言:
很多朋友看到我写的《算法导论》系列,可能会觉得云里雾里,不知所云。这里我再次说明,本系列博文时配合《算法导论》一书,给出该书涉及的算法的c++实现。请结合《算法导论》一书阅读该系列博文。我这里有该书的电子版,有需要的朋友可以留言。
正题:
今天讨论的算法是矩阵乘法的Strassen算法,该算法的精髓在于减少n/2矩阵*n/2矩阵的次数。首先,作一些写该算法的基础工作:
/*
* 矩阵的加法运算
*/
void Add(int** matrixA, int** matrixB, int** matrixResult,int length)
{
for(int i = ; i < length; i++) {
for(int j = ; j < length; j++) {
matrixResult[i][j] = matrixA[i][j] + matrixB[i][j];
}
}
} /*
* 矩阵的减法运算
*/
void Sub(int** matrixA, int** matrixB, int** matrixResult,int length)
{
for(int i = ; i < length; i++) {
for(int j = ; j < length; j++) {
matrixResult[i][j] = matrixA[i][j] - matrixB[i][j];
}
}
} /*
* 矩阵乘法
*/
void Mul(int** matrixA, int** matrixB, int** matrixResult){
for(int i = ; i < ; ++i) {
for(int j = ; j < ; ++j) {
matrixResult[i][j] = ;
for(int k = ; k < ; ++k) {
matrixResult[i][j] += matrixA[i][k] * matrixB[k][j];
}
}
}
}
接着进入核心部分:
void Strassen(int** matrixA, int** matrixB, int** matrixResult,int length)
{
int halfLength=length/;
int** a11=new int*[halfLength];
int** a12=new int*[halfLength];
int** a21=new int*[halfLength];
int** a22=new int*[halfLength]; int** b11=new int*[halfLength];
int** b12=new int*[halfLength];
int** b21=new int*[halfLength];
int** b22=new int*[halfLength]; int** s1=new int*[halfLength];
int** s2=new int*[halfLength];
int** s3=new int*[halfLength];
int** s4=new int*[halfLength];
int** s5=new int*[halfLength];
int** s6=new int*[halfLength];
int** s7=new int*[halfLength]; int** matrixResult11=new int*[halfLength];
int** matrixResult12=new int*[halfLength];
int** matrixResult21=new int*[halfLength];
int** matrixResult22=new int*[halfLength]; int** temp=new int*[halfLength];
int** temp1=new int*[halfLength];
if(halfLength==){
Mul(matrixA, matrixB, matrixResult);
}else{
//首先将矩阵A,B 分为4块
for(int i = ; i < halfLength; i++) {
a11[i]=new int[halfLength];
a12[i]=new int[halfLength];
a21[i]=new int[halfLength];
a22[i]=new int[halfLength]; b11[i]=new int[halfLength];
b12[i]=new int[halfLength];
b21[i]=new int[halfLength];
b22[i]=new int[halfLength]; s1[i]=new int[halfLength];
s2[i]=new int[halfLength];
s3[i]=new int[halfLength];
s4[i]=new int[halfLength];
s5[i]=new int[halfLength];
s6[i]=new int[halfLength];
s7[i]=new int[halfLength]; matrixResult11[i]=new int[halfLength];
matrixResult12[i]=new int[halfLength];
matrixResult21[i]=new int[halfLength];
matrixResult22[i]=new int[halfLength]; temp[i]=new int[halfLength];
temp1[i]=new int[halfLength];
for(int j = ; j < halfLength; j++) {
a11[i][j]=matrixA[i][j];
a12[i][j]=matrixA[i][j+halfLength];
a21[i][j]=matrixA[i+halfLength][j];
a22[i][j]=matrixA[i+halfLength][j+halfLength];
b11[i][j]=matrixB[i][j];
b12[i][j]=matrixB[i][j+halfLength];
b21[i][j]=matrixB[i+halfLength][j];
b22[i][j]=matrixB[i+halfLength][j+halfLength];
}
} //计算s1
Sub(b12, b22, temp,halfLength);
Strassen(a11, temp, s1,halfLength);
//计算s2
Add(a11, a12, temp,halfLength);
Strassen(temp, b22, s2,halfLength);
//计算s3
Add(a21, a22, temp,halfLength);
Strassen(temp, b11, s3,halfLength);
//计算s4
Sub(b21, b11, temp,halfLength);
Strassen(a22, temp, s4,halfLength);
//计算s5
Add(a11, a22, temp1,halfLength);
Add(b11, b22, temp,halfLength);
Strassen(temp1, temp, s5,halfLength);
//计算s6
Sub(a12, a22, temp1,halfLength);
Add(b21, b22, temp,halfLength);
Strassen(temp1, temp, s6,halfLength);
//计算s7
Sub(a11, a21, temp1,halfLength);
Add(b11, b12, temp,halfLength);
Strassen(temp1, temp, s7,halfLength); //计算matrixResult11
Add(s5, s4, temp1,halfLength);
Sub(temp1, s2, temp,halfLength);
Add(temp, s6, matrixResult11,halfLength);
//计算matrixResult12
Add(s1, s2, matrixResult12,halfLength);
//计算matrixResult21
Add(s3, s4, matrixResult21,halfLength);
//计算matrixResult22
Add(s5, s1, temp1,halfLength);
Sub(temp1, s3, temp,halfLength);
Sub(temp, s7, matrixResult22,halfLength); //结果送回matrixResult中
for(int i = ; i < halfLength; i++) {
for(int j = ; j < halfLength; j++) {
matrixResult[i][j]=matrixResult11[i][j];
matrixResult[i][j+halfLength]=matrixResult12[i][j];
matrixResult[i+halfLength][j]=matrixResult21[i][j];
matrixResult[i+halfLength][j+halfLength]=matrixResult22[i][j];
}
delete(a11[i]);
delete(a12[i]);
delete(a21[i]);
delete(a22[i]); delete(b11[i]);
delete(b12[i]);
delete(b21[i]);
delete(b22[i]); delete(s1[i]);
delete(s2[i]);
delete(s3[i]);
delete(s4[i]);
delete(s5[i]);
delete(s6[i]);
delete(s7[i]); delete(matrixResult11[i]);
delete(matrixResult12[i]);
delete(matrixResult21[i]);
delete(matrixResult22[i]); delete(temp[i]);
delete(temp1[i]);
}
delete(a11);
delete(a12);
delete(a21);
delete(a22); delete(b11);
delete(b12);
delete(b21);
delete(b22); delete(s1);
delete(s2);
delete(s3);
delete(s4);
delete(s5);
delete(s6);
delete(s7); delete(matrixResult11);
delete(matrixResult12);
delete(matrixResult21);
delete(matrixResult22); delete(temp);
delete(temp1);
}
}
该算法看着或许有些冗长,几乎一半都在进行动态指针的初始化和删除。利用该算法计算矩阵乘的时间复杂度为θ(n^lg7)。
测试一下吧:
#include "stdafx.h"
#include <iostream>
#include "SquareMatrix.h" using namespace std;
using namespace dksl; //STRASSEN矩阵乘法算法 const int N=; //常量N用来定义矩阵的大小
int _tmain(int argc, _TCHAR* argv[])
{
int **a=new int*[];
int **b=new int*[];
int **c=new int*[];
for(int i=;i<;i++)
{
a[i]=new int[];
b[i]=new int[];
c[i]=new int[];
for(int j=;j<;j++)
{
a[i][j]=;
b[i][j]=;
}
}
Strassen(a,b,c,);
for(int i=;i<;i++)
{
for(int j=;j<;j++)
cout<<c[i][j]<<" ";
cout<<endl;
}
system("PAUSE");
return ;
}
《算法导论》——矩阵乘法的Strassen算法的更多相关文章
- 4-2.矩阵乘法的Strassen算法详解
题目描述 请编程实现矩阵乘法,并考虑当矩阵规模较大时的优化方法. 思路分析 根据wikipedia上的介绍:两个矩阵的乘法仅当第一个矩阵B的列数和另一个矩阵A的行数相等时才能定义.如A是m×n矩阵和B ...
- 【算法导论C++代码】Strassen算法
简单方阵矩乘法 SQUARE-MATRIX-MULTIPLY(A,B) n = A.rows let C be a new n*n natrix to n to n cij = to n cij=ci ...
- 算法导论-矩阵乘法-strassen算法
目录 1.矩阵相乘的朴素算法 2.矩阵相乘的strassen算法 3.完整测试代码c++ 4.性能分析 5.参考资料 内容 1.矩阵相乘的朴素算法 T(n) = Θ(n3) 朴素矩阵相乘算法,思想明了 ...
- 第四章 分治策略 4.2 矩阵乘法的Strassen算法
package chap04_Divide_And_Conquer; import static org.junit.Assert.*; import java.util.Arrays; import ...
- 【算法导论】--分治策略Strassen算法(运用下标运算)【c++】
由于偷懒不想用泛型,所以直接用了整型来写了一份 ①首先你得有一个矩阵的class Matrix ②Matrix为了方便用下标进行运算, Matrix的结构如图:(我知道我的字丑...) Matrix. ...
- 算法笔记_081:蓝桥杯练习 算法提高 矩阵乘法(Java)
目录 1 问题描述 2 解决方案 1 问题描述 问题描述 有n个矩阵,大小分别为a0*a1, a1*a2, a2*a3, ..., a[n-1]*a[n],现要将它们依次相乘,只能使用结合率,求最 ...
- 蓝桥 ADV-232 算法提高 矩阵乘法 【区间DP】
算法提高 矩阵乘法 时间限制:3.0s 内存限制:256.0MB 问题描述 有n个矩阵,大小分别为a0*a1, a1*a2, a2*a3, ..., a[n-1]*a[n],现要 ...
- Java实现 蓝桥杯 算法提高 矩阵乘法(暴力)
试题 算法提高 矩阵乘法 问题描述 小明最近刚刚学习了矩阵乘法,但是他计算的速度太慢,于是他希望你能帮他写一个矩阵乘法的运算器. 输入格式 输入的第一行包含三个正整数N,M,K,表示一个NM的矩阵乘以 ...
- Java实现 蓝桥杯 算法训练 矩阵乘法
算法训练 矩阵乘法 时间限制:1.0s 内存限制:512.0MB 提交此题 问题描述 输入两个矩阵,分别是ms,sn大小.输出两个矩阵相乘的结果. 输入格式 第一行,空格隔开的三个正整数m,s,n(均 ...
随机推荐
- VideoPlayer播放
播放网络视频.本地视频:可以暂停.前后拖动.快进.快退.音量调节.下一个视频 环境:Unity5.6以上 Unity正式发布了5.6版本后,作为5.x版本的最后一版还是有不少给力的更新的.其中新加入了 ...
- Python基础+模块、异常
date:2018414+2018415 day1+2 一.python基础 #coding=utf-8 #注释 #算数运算 +(加) -(减) *(乘) /(除) //(取整) %(取余) ...
- echarts环形图自动定位radius
根据后台返回数据条数进行pie图radius定位: var a = 100; var b = 0; var c = 0; var radius = []; for (var i in data ...
- L2-018. 多项式A除以B*
L2-018. 多项式A除以B 参考博客 #include <iostream> #include <map> #include <cmath> #include ...
- SpringBoot的学习【4.快速实现一个SpringBoo的应用】
1.引子 正常创建一个 Spring Boot 应用的顺序: 创建 Maven 项目 pom 文件导入依赖(参照 Spring 官方文档) 编写主程序 编写业务逻辑 但其实IDE( idea 和 Sp ...
- 关于idea的debug
idea的debug真的是超级好用哎.分享几个今天学会的新方式: 1.右键会发现此选项 ,点击出现 在输入框中输入,可以通过某些公式单独计算. 2.点击属性值,右键点击set values 会出现一个 ...
- 微信小程序跳转(当我们不知道是普通页面还是tabbar)
页面跳转一般我们都用wx.navigateTo 或者wx.redirectTo等,当页面为tabbar的某一个页面时, 我们盖如何兼容呢我处理的方式为在navigateTo的fail方法中执行wx.s ...
- ubuntu16.04+caffe+GPU+cuda+cudnn安装教程
步骤简述: 1.安装GPU驱动(系统适配,不采取手动安装的方式) 2.安装依赖(cuda依赖库,caffe依赖) 3.安装cuda 4.安装cudnn(只是复制文件加链接,不需要编译安装的过程) 5. ...
- License控制解决方案
当我们写完一个软件以后一般都会牵扯到软件控制,那么控制版本的原理是什么呢?其实就是在程序中添加了一段经过自己编写算法(这个算法可以是简单的公式运算,也可以是复杂的结合硬件的绑定方式),将形成的序列号注 ...
- .net4.0调用非托管DLL的异常捕获
转发: 由于有些非托管的DLL内部异常未有效处理,当托管程序调用到这样的DLL时,就引起托管程序意外退出. 托管程序使用通常的捕获try……catch块不起作用.原因是.NET 4.0里新的异常处理机 ...