Bläddra i källkod

Update KDB Tree creation

maxep 6 år sedan
förälder
incheckning
a4b5fa97e7
4 ändrade filer med 116 tillägg och 55 borttagningar
  1. 25 29
      Sources/KDB/Database.swift
  2. 25 5
      Sources/KDB/Entry.swift
  3. 37 3
      Sources/KDB/Group.swift
  4. 29 18
      Sources/KDB/Row.swift

+ 25 - 29
Sources/KDB/Database.swift

@@ -27,7 +27,7 @@ public class Database {
 
     let header: Header
 
-    public private(set) var root: Group!
+    public let root: Group
 
     public required init(from input: Input, compositeKey: CompositeKey) throws {
         header = try input.read()
@@ -52,7 +52,7 @@ public class Database {
         }
 
         let stream = Input(bytes: content)
-        self.root = try tree(from: stream)
+        self.root = try Root(from: stream, groups: Int(header.groups), entries: Int(header.entries))
     }
 
     public convenience init(from file: URL, compositeKey: CompositeKey) throws {
@@ -79,46 +79,42 @@ extension Database: Writable {
 
 }
 
-extension Database {
+private func Root(from input: Input, groups: Int, entries: Int) throws -> Group {
 
-    private func tree(from input: Input) throws -> Group {
+    let groups: [Group] = try input.read(maxLenght: groups)
+    let entries: [Entry] = try input.read(maxLenght: entries)
 
-        let groups: [Group] = try input.read(maxLenght: Int(header.groups))
-        let entries: [Entry] = try input.read(maxLenght: Int(header.entries))
+    let root = Group()
 
-        let root = Group()
+    for (i, group) in groups.enumerated() {
 
-        for (i, group) in groups.enumerated() {
+        guard let level1: UInt16 = group[.groupLevel] else { throw KDBError.corruptedDatabase }
 
-            guard let level1: UInt16 = group[.groupLevel] else { throw KDBError.corruptedDatabase }
-
-            if level1 == 0 {
-                root.childs.append(group)
-                continue
-            }
-
-            for (j, parent) in groups[0..<i].enumerated().reversed() {
-                guard let level2: UInt16 = parent[.groupLevel] else { throw KDBError.corruptedDatabase }
+        if level1 == 0 {
+            root.childs.append(group)
+            continue
+        }
 
-                if level2 < level1 {
-                    guard (level1 - level2) == 1 else { throw KDBError.corruptedDatabase }
-                    parent.childs.append(group)
-                    break
-                }
+        for (j, parent) in groups[0..<i].enumerated().reversed() {
+            guard let level2: UInt16 = parent[.groupLevel] else { throw KDBError.corruptedDatabase }
 
-                guard j > 0 else { throw KDBError.corruptedDatabase }
+            if level2 < level1 {
+                guard (level1 - level2) == 1 else { throw KDBError.corruptedDatabase }
+                parent.childs.append(group)
+                break
             }
 
+            guard j > 0 else { throw KDBError.corruptedDatabase }
         }
 
-        for entry in entries {
-            guard let groupID: UInt32 = entry[.groupID] else { throw KDBError.corruptedDatabase }
+    }
 
-            let group = try groups.first(where: .groupID, { $0 == groupID }) ?? root
-            group.entries.append(entry)
-        }
+    for entry in entries {
+        guard let groupID: UInt32 = entry[.groupID] else { throw KDBError.corruptedDatabase }
 
-        return root
+        let group = try groups.first(where: .groupID, { $0 == groupID }) ?? root
+        group.add(entry)
     }
 
+    return root
 }

+ 25 - 5
Sources/KDB/Entry.swift

@@ -41,9 +41,9 @@ let MetaEntryKeePassKitUserTemplates         = "KeePassKit User Templates"
 
 public final class Entry: Row, Streamable {
 
-    public static let End = Field.end
+    public static let End = Type.end
 
-    public enum Field: UInt16, Streamable {
+    public enum `Type`: UInt16, Streamable {
         case reserved           = 0x0000
         case uuid               = 0x0001
         case groupID            = 0x0002
@@ -62,12 +62,13 @@ public final class Entry: Row, Streamable {
         case end                = 0xFFFF
     }
 
-    public var fields: [TLV<Field, UInt32>]
+    var parent: Group?
+
+    public var fields: [Field<Type>]
 
     public required init() {
         fields = []
-    }
-    
+    }    
 }
 
 extension Entry {
@@ -75,4 +76,23 @@ extension Entry {
     var isMetaEntry: Bool {
         return false
     }
+
+    public func removeFromParent() {
+        parent?.entries.removeAll(where: { $0 == self })
+        fields.removeAll(.groupID)
+    }
+}
+
+extension Entry: Hashable {
+
+    public static func == (lhs: Entry, rhs: Entry) -> Bool {
+        guard let lhs = lhs[.uuid], let rhs = rhs[.uuid] else { return false }
+        return lhs == rhs
+    }
+
+    public func hash(into hasher: inout Hasher) {
+        if let uuid = self[.uuid] { hasher.combine(uuid) }
+        else if let title = self[.title] { hasher.combine(title) }
+    }
+
 }

+ 37 - 3
Sources/KDB/Group.swift

@@ -21,9 +21,9 @@ import Binary
 
 public final class Group: Row, Streamable {
 
-    public static let End = Field.end
+    public static let End = Type.end
 
-    public enum Field: UInt16, Streamable {
+    public enum `Type`: UInt16, Streamable {
         case reserved           = 0x0000
         case groupID            = 0x0001
         case name               = 0x0002
@@ -37,7 +37,9 @@ public final class Group: Row, Streamable {
         case end                = 0xFFFF
     }
 
-    public var fields: [TLV<Field, UInt32>]
+    var parent: Group?
+
+    public var fields: [Field<Type>]
 
     public var childs: [Group]
 
@@ -49,3 +51,35 @@ public final class Group: Row, Streamable {
         entries = []
     }
 }
+
+extension Group {
+
+    public func removeFromParent() {
+        parent?.childs.removeAll(where: { $0 == self })
+    }
+
+    public func add(_ entry: Entry) {
+        entry.removeFromParent()
+        entries.append(entry)
+        entry[.groupID] = self[.groupID]
+    }
+
+    public func add(_ group: Group) {
+        group.removeFromParent()
+        childs.append(group)
+    }
+}
+
+extension Group: Hashable {
+
+    public static func == (lhs: Group, rhs: Group) -> Bool {
+        guard let lhs = lhs[.groupLevel], let rhs = rhs[.groupLevel] else { return false }
+        return lhs == rhs
+    }
+
+    public func hash(into hasher: inout Hasher) {
+        if let groupLevel = self[.groupLevel] { hasher.combine(groupLevel) }
+        if let groupFlags = self[.groupFlags] { hasher.combine(groupFlags) }
+    }
+
+}

+ 29 - 18
Sources/KDB/Row.swift

@@ -19,34 +19,46 @@
 import Foundation
 import Binary
 
-public protocol Row {
-    associatedtype Field: Streamable, Equatable
+public typealias Field<Type> = TLV<Type, UInt32>
 
-    init()
+public protocol Row: class {
+
+    associatedtype `Type`: Streamable, Equatable
+
+    static var End: Type { get }
 
-    var fields: [TLV<Field, UInt32>] { get set }
+    var fields: [Field<Type>] { get set }
 
-    static var End: Field { get }
+    init()
 }
 
 extension Row {
 
-    public subscript (_ field: Field) -> Bytes? {
-        get { fields.first(where: { $0.type == field })?.value }
+    public subscript(_ type: Type) -> Bytes? {
+        get { fields.first(where: { $0.type == type })?.value }
         set {
-            fields.removeAll(where: { $0.type == field })
+            fields.removeAll(where: { $0.type == type })
             guard let value = newValue else { return }
-            let tlv = TLV<Field, UInt32>(type: field, value: value)
+            let tlv = Field(type: type, value: value)
             fields.insert(tlv, at: 0)
         }
     }
 
-    public subscript <T>(_ field: Field) -> T? where T: BytesRepresentable {
+    public func set(_ field: Field<Type>) {
+        fields.removeAll(where: { $0.type == field.type })
+        fields.insert(field, at: 0)
+    }
+
+    public subscript<T>(_ type: Type) -> T? where T: BytesRepresentable {
         get {
-            guard let bytes = self[field] else { return nil }
+            guard let bytes = self[type] else { return nil }
             return try? T(bytes)
         }
-        set { self[field] = newValue?.bytes }
+        set { self[type] = newValue?.bytes }
+    }
+
+    public func remove(_ type: Type) {
+        fields.removeAll(where: { $0.type == type })
     }
 }
 
@@ -55,7 +67,7 @@ extension Readable where Self: Row {
     public init(from input: Input) throws {
         self.init()
         while true {
-            let field: TLV<Field, UInt32> = try input.read()
+            let field = try input.read() as Field<Type>
             guard field.type != Self.End else { break }
             fields.append(field)
         }
@@ -67,7 +79,7 @@ extension Writable where Self: Row {
 
     public func write(to output: Output) throws {
         try output.write(fields)
-        let end = TLV<Field, UInt32>(type: Self.End, value: [])
+        let end = Field(type: Self.End, value: [])
         try output.write(end)
     }
     
@@ -75,19 +87,18 @@ extension Writable where Self: Row {
 
 extension Sequence where Element: Row {
 
-    public func first<T>(where field: Element.Field, _ predicate: (T) throws -> Bool) throws -> Element? where T: BytesRepresentable {
+    public func first<T>(where type: Element.`Type`, _ predicate: (T) throws -> Bool) throws -> Element? where T: BytesRepresentable {
         return try first(where: {
-            guard let bytes = $0[field] else { return false }
+            guard let bytes = $0[type] else { return false }
             return try predicate(try T(bytes))
         })
     }
 
-    public func sorted<T>(field: Element.Field, by areInIncreasingOrder: (T, T) throws -> Bool) throws -> [Self.Element] where T: BytesRepresentable {
+    public func sorted<T>(field: Element.`Type`, by areInIncreasingOrder: (T, T) throws -> Bool) throws -> [Self.Element] where T: BytesRepresentable {
         return try sorted(by: {
             guard let rhs = $0[field], let lhs = $1[field] else { return false }
             return try areInIncreasingOrder(try T(rhs), try T(lhs))
         })
     }
 
-    
 }