In-place Radix Sort with Small Space Overhead

The following code implements an in-place radix sort. Its only extra memory is a list of partition boundaries; a recursive formulation would need just O(k) extra space (see below). It currently doesn't handle signed values, but that should be relatively easy to add: the high (sign) bit just needs to be sorted in reverse order.

Unlike comparison sorts, radix sort operates only on integers. Its running time is linear in the number of elements in the list (n) multiplied by the number of digits in each integer (k), i.e. O(n·k). The implementation below operates in base 2, so the digits are bits.

This implementation sorts one bit at a time, starting with the most significant bit. It maintains pointers to the two ends of the array (called left and right below). It moves elements that have a one at the current bit position to the left and elements that have a zero to the right, so the final order is descending. It does this by swapping elements each time it finds a pair in the wrong places, moving the left and right pointers toward each other until they meet.

This sorts a single bit position; to sort the entire array, the algorithm proceeds recursively. Once the ones and zeros have been separated at the current position, the radix sort sorts the group of ones by the next lower bit, and separately sorts the group of zeros by the next lower bit. While this is most easily described recursively, the implementation keeps track of the partitions in a vector and sorts them in a loop.

The implementation below keeps track of all the bin partitions, so its space overhead is O(min(2^k, n)), where k is the number of digits. Each partition corresponds to a distinct bit prefix, and there are never more non-empty partitions than elements. A recursive solution would likely be more space-efficient (O(k) stack depth), as it would not need to keep track of all partitions down to the final bit.

Full listing:
#include <iostream>
#include <utility>
#include <vector>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
 
using namespace std;
 
// Number of bit positions processed: inplace_radix_sort starts at bit
// bits-1 and works down to bit 0, and bin_dump prints this many bits.
// 32 covers the full width of int32_t.
int bits = 32;

// Forward declarations for helpers defined further down the file.
void dump_array(vector<int32_t> array);
void bin_dump(int32_t v);
void inplace_radix_sort_bit(vector<int32_t> &a,int start,int end,int bit,int &breakpoint);
 
// Exchanges the elements at indices l and r of `a`.
// Indices are not bounds-checked; l == r is a harmless no-op.
void swap(vector<int32_t> &a,int l,int r) {
  const int32_t saved = a[r];
  a[r] = a[l];
  a[l] = saved;
}
 
// Returns true when bit number `bit` (0 = least significant) of v is 1.
//
// The value is examined as its unsigned 32-bit pattern. The previous
// `(v & (1 << bit)) > 0` test could never report bit 31: the mask is
// INT_MIN, so the AND result is negative or zero, never > 0.
// `bit` must be in [0, 31].
bool bit_is_set(int32_t v,int bit) {
  return ((static_cast<uint32_t>(v) >> bit) & 1u) != 0;
}
 
 
void inplace_radix_sort(vector<int32_t> &a) {
 
  int breakpoint;
  vector<int> breakpoints;
 
  inplace_radix_sort_bit(a,0,a.size(),bits-1,breakpoint);
  breakpoints.push_back(0);
  breakpoints.push_back(breakpoint);
  breakpoints.push_back(a.size());
 
  for(int bit=bits-2;bit>=0;bit--) {
 
    cout << "breakpoints: ";
    for(int n=0;n<breakpoints.size();n++) cout << breakpoints[n] << " ";
    cout << endl;
 
    vector<int> newbreakpoints;
    for(int n=0;n<breakpoints.size()-1;n++) {
      if(breakpoints[n] != breakpoints[n+1])
     // inplace_radix_sort_bit(a,breakpoints[n]+1,breakpoints[n+1],bit,breakpoint);
      inplace_radix_sort_bit(a,breakpoints[n],breakpoints[n+1]-1,bit,breakpoint);
      newbreakpoints.push_back(breakpoint);
    }
 
    // create new breakpoint list (equiv for recursion)
    vector<int> mergebreakpoints;
    for(int n=0;n<breakpoints.size();n++) {
      mergebreakpoints.push_back(breakpoints[n]);
      if(n!=(breakpoints.size()-1)) mergebreakpoints.push_back(newbreakpoints[n]);
    }
    breakpoints = mergebreakpoints;
 
    // remove duplicates
    vector<int> cleanedbreakpoints;
    cleanedbreakpoints.push_back(breakpoints[0]);
    for(int n=1;n<breakpoints.size();n++) {
      if(breakpoints[n] != breakpoints[n-1]) cleanedbreakpoints.push_back(breakpoints[n]);
    }
    breakpoints = cleanedbreakpoints;
  }
 
}
 
// Partitions the inclusive index range [start, end] of `a` on one bit.
// Elements whose `bit` is set are moved to the front and elements whose
// bit is clear to the back, using two pointers that walk toward each other
// and swap misplaced pairs.
//
// On return, `breakpoint` is the index of the first element with the bit
// clear: [start, breakpoint) all have the bit set and [breakpoint, end] all
// have it clear. An empty range (start > end) yields breakpoint == start.
//
// Bits are read from the unsigned 32-bit pattern, so bit 31 of a negative
// value counts as set. `bit` must be in [0, 31].
//
// The previous loop exited without examining a[l_pos] when the pointers met
// after a "1 left / 0 right" step, and never examined a single-element
// range. Either case could report a breakpoint one too small.
void inplace_radix_sort_bit(std::vector<int32_t> &a,int start,int end,int bit,int &breakpoint) {
  // Clamp to valid indices: a caller passing a.size() as `end` would
  // otherwise read and swap one element past the end of the vector.
  const int last = static_cast<int>(a.size()) - 1;
  if (start < 0) start = 0;
  if (end > last) end = last;

  auto has_bit = [bit](int32_t v) {
    return ((static_cast<uint32_t>(v) >> bit) & 1u) != 0;
  };

  int l_pos = start;
  int r_pos = end;
  // Invariant: [start, l_pos) has the bit set, (r_pos, end] has it clear.
  while (l_pos <= r_pos) {
    if (has_bit(a[l_pos])) {
      ++l_pos;
    } else if (!has_bit(a[r_pos])) {
      --r_pos;
    } else {
      // a[l_pos] is clear and a[r_pos] is set, hence l_pos < r_pos: swap them.
      std::swap(a[l_pos], a[r_pos]);
      ++l_pos;
      --r_pos;
    }
  }
  breakpoint = l_pos;
}
 
void bin_dump(int32_t v) {
 
  for(int n=bits-1;n>=0;n--) {
    if(v & (1 << n)) cout << "1"; else cout << "0";
  }
 
}
 
void dump_array(vector<int32_t> array) {
   
  for(int n=0;n<array.size();n++) {
    printf("%20d ",array[n]);
    bin_dump(array[n]);
    cout << endl;
  }
 
}
 
int main() {
 
  vector<int32_t> array;
  for(int n=0;n<20;n++) {
    array.push_back(rand());
  }
 
  cout << "random array" << endl;
  dump_array(array);
 
  inplace_radix_sort(array);
 
  cout << "sorted array" << endl;
  dump_array(array);
}

Leave a Reply