nixfiles/libs/aes/lib/default.nix
Sebastian Walz 9f7b02e1cd
Tohu vaBohu
2023-04-03 14:38:02 +02:00

258 lines
9.4 KiB
Nix

{ ... } @ libs:
let
inherit(builtins) bitAnd bitOr bitXor elemAt foldl' genList length;
inherit(import ./serde.nix libs) packDWord unpackDWord;
foot = this: elemAt this ((length this) - 1);
repeat
= rounds:
initial:
convert:
foldl'
convert
initial
(genList (x: x) rounds);
subTE0 = elemAt (import ./te0.nix);
subTE1 = elemAt (import ./te1.nix);
subTE2 = elemAt (import ./te2.nix);
subTE3 = elemAt (import ./te3.nix);
subSBox = elemAt (import ./sbox.nix);
substituteLine
= { byte0, byte1, byte2, byte3 }:
{
byte0 = subSBox byte0;
byte1 = subSBox byte1;
byte2 = subSBox byte2;
byte3 = subSBox byte3;
};
rotateLine
= { byte0, byte1, byte2, byte3 }:
{
byte0 = byte3;
byte1 = byte0;
byte2 = byte1;
byte3 = byte2;
};
getRoundConstant = elemAt (import ./rcon.nix);
expandRoundKey
= length:
key:
round:
let
prevKey = foot key;
finalLine
= bitXor
(getRoundConstant round)
(packDWord (substituteLine (rotateLine (unpackDWord prevKey.final))));
dword0 = bitXor prevKey.dword0 finalLine;
dword1 = bitXor prevKey.dword1 dword0;
dword2 = bitXor prevKey.dword2 dword1;
dword3 = bitXor prevKey.dword3 dword2;
dword3'
= if length > 6
then
packDWord (substituteLine (unpackDWord dword3));
else
dword3;
dword4 = bitXor prevKey.dword4 dword3';
dword5 = bitXor prevKey.dword5 dword4;
dword6 = bitXor prevKey.dword6 dword5;
dword7 = bitXor prevKey.dword7 dword6;
in
key
++ [
{
inherit dword0 dword1 dword2 dword3 dword4 dword5 dword6 dword7;
final
= elemAt
[ dword0 dword1 dword2 dword3 dword4 dword5 dword6 dword7 ]
length;
}
];
expandKey
= { length, rounds }:
key:
let
length' = length / 32;
in
{
__type__ = "AESkey";
inherit length;
roundKeys
= repeat
length'
[
{
dword0 = elemAt key 0;
dword1 = elemAt key 1;
dword2 = elemAt key 2;
dword3 = elemAt key 3;
dword4 = elemAt key 4;
dword5 = elemAt key 5;
dword6 = elemAt key 6;
dword7 = elemAt key 7;
final = elemAt key (length - 1);
}
]
(expandRoundKey length');
};
in
{
inherit expandKey;
expand128bitKey = expandKey { length = 128; rounds = 11; };
expand192bitKey = expandKey { length = 192; rounds = 13; };
expand256bitKey = expandKey { length = 256; rounds = 15; };
encrypt
= key:
message:
let
cipher
= foldl'
(
data:
round:
applyRoundKey round data
)
(unpackDWord message)
(genList (r: r) 4);
b = round 0 a;
c = round 1 b;
d = round 2 c;
e = round 3 d;
f = last 4 e;
result
= [
f.a
f.b
f.c
f.d
];
in
result
}
k = elemAt key';
round
= round:
{ a, b, c, d }:
let
i = 8 * round;
a' = bitXor a (k (i + 0));
b' = bitXor b (k (i + 1));
c' = bitXor c (k (i + 2));
d' = bitXor d (k (i + 3));
t0
= bitXor
(bitXor (te0 (byteAt a' 0)) (te1 (byteAt b' 1)))
(bitXor (te2 (byteAt c' 2)) (te3 (byteAt d' 3)));
t1
= bitXor
(bitXor (te0 (byteAt b' 0)) (te1 (byteAt c' 1)))
(bitXor (te2 (byteAt d' 2)) (te3 (byteAt a' 3)));
t2
= bitXor
(bitXor (te0 (byteAt c' 0)) (te1 (byteAt d' 1)))
(bitXor (te2 (byteAt a' 2)) (te3 (byteAt b' 3)));
t3
= bitXor
(bitXor (te0 (byteAt d' 0)) (te1 (byteAt a' 1)))
(bitXor (te2 (byteAt b' 2)) (te3 (byteAt c' 3)));
t0' = bitXor t0 (k (i + 4));
t1' = bitXor t1 (k (i + 5));
t2' = bitXor t2 (k (i + 6));
t3' = bitXor t3 (k (i + 7));
in
{
a
= bitXor
(bitXor (te0 (byteAt t0' 0)) (te1 (byteAt t1' 1)))
(bitXor (te2 (byteAt t2' 2)) (te3 (byteAt t3' 3)));
b
= bitXor
(bitXor (te0 (byteAt t1' 0)) (te1 (byteAt t2' 1)))
(bitXor (te2 (byteAt t3' 2)) (te3 (byteAt t0' 3)));
c
= bitXor
(bitXor (te0 (byteAt t2' 0)) (te1 (byteAt t3' 1)))
(bitXor (te2 (byteAt t0' 2)) (te3 (byteAt t1' 3)));
d
= bitXor
(bitXor (te0 (byteAt t3' 0)) (te1 (byteAt t0' 1)))
(bitXor (te2 (byteAt t1' 2)) (te3 (byteAt t2' 3)));
};
last
= round:
{ a, b, c, d }:
let
i = 8 * round;
a' = bitXor a (k (i + 0));
b' = bitXor b (k (i + 1));
c' = bitXor c (k (i + 2));
d' = bitXor d (k (i + 3));
t0
= bitXor
(bitXor (te0 (byteAt a' 0)) (te1 (byteAt b' 1)))
(bitXor (te2 (byteAt c' 2)) (te3 (byteAt d' 3)));
t1
= bitXor
(bitXor (te0 (byteAt b' 0)) (te1 (byteAt c' 1)))
(bitXor (te2 (byteAt d' 2)) (te3 (byteAt a' 3)));
t2
= bitXor
(bitXor (te0 (byteAt c' 0)) (te1 (byteAt d' 1)))
(bitXor (te2 (byteAt a' 2)) (te3 (byteAt b' 3)));
t3
= bitXor
(bitXor (te0 (byteAt d' 0)) (te1 (byteAt a' 1)))
(bitXor (te2 (byteAt b' 2)) (te3 (byteAt c' 3)));
t0' = bitXor t0 (k (i + 4));
t1' = bitXor t1 (k (i + 5));
t2' = bitXor t2 (k (i + 6));
t3' = bitXor t3 (k (i + 7));
a''
= bitXor
(bitXor (te4 (byteAt t0' 0)) (te4 (byteAt t1' 1) * shift1))
(bitXor (te4 (byteAt t2' 2) * shift2) (te4 (byteAt t3' 3) * shift3));
b''
= bitXor
(bitXor (te4 (byteAt t1' 0)) (te4 (byteAt t2' 1) * shift1))
(bitXor (te4 (byteAt t3' 2) * shift2) (te4 (byteAt t0' 3) * shift3));
c''
= bitXor
(bitXor (te4 (byteAt t2' 0)) (te4 (byteAt t3' 1) * shift1))
(bitXor (te4 (byteAt t0' 2) * shift2) (te4 (byteAt t1' 3) * shift3));
d''
= bitXor
(bitXor (te4 (byteAt t3' 0)) (te4 (byteAt t0' 1) * shift1))
(bitXor (te4 (byteAt t1' 2) * shift2) (te4 (byteAt t2' 3) * shift3));
in
{
a = bitXor a'' (k (i + 8));
b = bitXor b'' (k (i + 9));
c = bitXor c'' (k (i + 10));
d = bitXor d'' (k (i + 11));
};
in
data: