@@ -643,12 +643,12 @@ func TestEngine_Sync(t *testing.T) {
643
643
644
644
func TestEngine_UpdateNetworkMapWithRoutes (t * testing.T ) {
645
645
testCases := []struct {
646
- name string
647
- inputErr error
648
- networkMap * mgmtProto.NetworkMap
649
- expectedLen int
650
- expectedRoutes [] * route.Route
651
- expectedSerial uint64
646
+ name string
647
+ inputErr error
648
+ networkMap * mgmtProto.NetworkMap
649
+ expectedLen int
650
+ expectedClientRoutes route.HAMap
651
+ expectedSerial uint64
652
652
}{
653
653
{
654
654
name : "Routes Config Should Be Passed To Manager" ,
@@ -676,22 +676,26 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
676
676
},
677
677
},
678
678
expectedLen : 2 ,
679
- expectedRoutes : []* route.Route {
680
- {
681
- ID : "a" ,
682
- Network : netip .MustParsePrefix ("192.168.0.0/24" ),
683
- NetID : "n1" ,
684
- Peer : "p1" ,
685
- NetworkType : 1 ,
686
- Masquerade : false ,
679
+ expectedClientRoutes : route.HAMap {
680
+ "n1|192.168.0.0/24" : []* route.Route {
681
+ {
682
+ ID : "a" ,
683
+ Network : netip .MustParsePrefix ("192.168.0.0/24" ),
684
+ NetID : "n1" ,
685
+ Peer : "p1" ,
686
+ NetworkType : 1 ,
687
+ Masquerade : false ,
688
+ },
687
689
},
688
- {
689
- ID : "b" ,
690
- Network : netip .MustParsePrefix ("192.168.1.0/24" ),
691
- NetID : "n2" ,
692
- Peer : "p1" ,
693
- NetworkType : 1 ,
694
- Masquerade : false ,
690
+ "n2|192.168.1.0/24" : []* route.Route {
691
+ {
692
+ ID : "b" ,
693
+ Network : netip .MustParsePrefix ("192.168.1.0/24" ),
694
+ NetID : "n2" ,
695
+ Peer : "p1" ,
696
+ NetworkType : 1 ,
697
+ Masquerade : false ,
698
+ },
695
699
},
696
700
},
697
701
expectedSerial : 1 ,
@@ -704,9 +708,9 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
704
708
RemotePeersIsEmpty : false ,
705
709
Routes : nil ,
706
710
},
707
- expectedLen : 0 ,
708
- expectedRoutes : [] * route. Route {} ,
709
- expectedSerial : 1 ,
711
+ expectedLen : 0 ,
712
+ expectedClientRoutes : nil ,
713
+ expectedSerial : 1 ,
710
714
},
711
715
{
712
716
name : "Error Shouldn't Break Engine" ,
@@ -717,9 +721,9 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
717
721
RemotePeersIsEmpty : false ,
718
722
Routes : nil ,
719
723
},
720
- expectedLen : 0 ,
721
- expectedRoutes : [] * route. Route {} ,
722
- expectedSerial : 1 ,
724
+ expectedLen : 0 ,
725
+ expectedClientRoutes : nil ,
726
+ expectedSerial : 1 ,
723
727
},
724
728
}
725
729
@@ -762,16 +766,29 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
762
766
engine .wgInterface , err = iface .NewWGIFace (opts )
763
767
assert .NoError (t , err , "shouldn't return error" )
764
768
input := struct {
765
- inputSerial uint64
766
- inputRoutes [] * route.Route
769
+ inputSerial uint64
770
+ clientRoutes route.HAMap
767
771
}{}
768
772
769
773
mockRouteManager := & routemanager.MockManager {
770
- UpdateRoutesFunc : func (updateSerial uint64 , newRoutes [ ]* route.Route ) error {
774
+ UpdateRoutesFunc : func (updateSerial uint64 , serverRoutes map [route. ID ]* route.Route , clientRoutes route. HAMap , useNewDNSRoute bool ) error {
771
775
input .inputSerial = updateSerial
772
- input .inputRoutes = newRoutes
776
+ input .clientRoutes = clientRoutes
773
777
return testCase .inputErr
774
778
},
779
+ ClassifyRoutesFunc : func (newRoutes []* route.Route ) (map [route.ID ]* route.Route , route.HAMap ) {
780
+ if len (newRoutes ) == 0 {
781
+ return nil , nil
782
+ }
783
+
784
+ // Classify all routes as client routes (not matching our public key)
785
+ clientRoutes := make (route.HAMap )
786
+ for _ , r := range newRoutes {
787
+ haID := r .GetHAUniqueID ()
788
+ clientRoutes [haID ] = append (clientRoutes [haID ], r )
789
+ }
790
+ return nil , clientRoutes
791
+ },
775
792
}
776
793
777
794
engine .routeManager = mockRouteManager
@@ -789,8 +806,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
789
806
err = engine .updateNetworkMap (testCase .networkMap )
790
807
assert .NoError (t , err , "shouldn't return error" )
791
808
assert .Equal (t , testCase .expectedSerial , input .inputSerial , "serial should match" )
792
- assert .Len (t , input .inputRoutes , testCase .expectedLen , "clientRoutes len should match" )
793
- assert .Equal (t , testCase .expectedRoutes , input .inputRoutes , "clientRoutes should match" )
809
+ assert .Len (t , input .clientRoutes , testCase .expectedLen , "clientRoutes len should match" )
810
+ assert .Equal (t , testCase .expectedClientRoutes , input .clientRoutes , "clientRoutes should match" )
794
811
})
795
812
}
796
813
}
@@ -951,7 +968,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
951
968
assert .NoError (t , err , "shouldn't return error" )
952
969
953
970
mockRouteManager := & routemanager.MockManager {
954
- UpdateRoutesFunc : func (updateSerial uint64 , newRoutes [ ]* route.Route ) error {
971
+ UpdateRoutesFunc : func (updateSerial uint64 , serverRoutes map [route. ID ]* route.Route , clientRoutes route. HAMap , useNewDNSRoute bool ) error {
955
972
return nil
956
973
},
957
974
}
0 commit comments