PS/Algorithm
행렬곱 알고리즘
hyelie
2022. 6. 21. 14:48
1. 일반적인 행렬곱 : $O(n^{3})$
vector<vector<int>> solution(vector<vector<int>> arr1, vector<vector<int>> arr2) {
int row1 = arr1.size(), col1 = arr1[0].size(); // = row2
int col2 = arr2[0].size();
vector<vector<int>> answer(row1, vector<int>(col2, 0));
for(int i =0; i<row1; i++){
for(int j = 0; j<col2; j++){
int temp = 0;
for(int k = 0; k<col1; k++){
temp += arr1[i][k] * arr2[k][j];
}
answer[i][j] = temp;
}
}
return answer;
}
일반적인 행렬의 곱은 위와 같이 구현할 것이다. 이 경우 시간복잡도는 3중포문을 돌기 때문에 $O(n^{3})$이다.
그러나 제일 안쪽의 for문을 보자. arr1[i][k] * arr2[k][j]의 형태를 띄고 있다. arr1의 경우에는 첫 번째 index인 i가 변하지 않아 cache에 arr을 저장하고 읽어오는 데 큰 문제가 없지만 arr2의 경우에는 매번 k가 바뀔 때마다 j*sizeof(element)만큼의 메모리를 jump해야하기 때문에 큰 행렬에 대해서 매번 cache miss가 날 것이다.
2. cache를 고려한 행렬곱 - KIJ : $O(n^{3})$
vector<vector<int>> solution(vector<vector<int>> arr1, vector<vector<int>> arr2) {
int row1 = arr1.size(), col1 = arr1[0].size(); // = row2
int col2 = arr2[0].size();
vector<vector<int>> answer(row1, vector<int>(col2, 0));
for(int k = 0; k<col1; k++){
for(int i = 0; i<row1; i++){
int temp = arr1[i][k];
for(int j = 0; j<col2; j++){
answer[i][j] += temp * arr2[k][j];
}
}
}
return answer;
}
위 코드에서
i : arr1의 row idx
j : arr2의 col idx
k : arr1의 col, arr2의 row idx
라고 생각하면 되고, 직관적인 IJK 순서 for loop보다는 KIJ 순서가 cache를 더 신경쓴 방법이기 때문에 더 빠르다. 3번째 for loop에서 answer[i][j], arr[k][j]를 보면 j가 바뀔 때 메모리 한 칸만큼 이동하기 때문에 cache hit rate가 높아질 것이다.
비록 이 코드가 제일 최적화된 코드는 아니지만 짧으면서도 효과적이기 때문에 너무 큰 행렬이 아니라면 KIJ 형식으로 행렬곱을 구현하도록 하자.
3. 슈트라센 알고리즘(divide and conquer)
행렬을 상하좌우 4개의 파트로 나눠 계산하는 방법이다. 단순히 행렬 2개를 곱하는 것보다 4개의 파트로 나누면, 중복되는 행렬이 생기기 때문에 행렬곱에서 시간적인 이득을 볼 수 있다. 시간복잡도는 $O(n^{2.81})$이다.
이 알고리즘을 쓸 일은 아마 없을 것 같기 때문에 내용을 작성하지는 않겠다. 필요 시 찾아보는 걸로~