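channelz: pass parent pointer instead of parent ID to RegisterSubChannel (#7101)

RegisterSubChannel now takes the parent *Channel directly and stores it on
the returned SubChannel, instead of taking the parent's ID and looking the
channel up in the channelz database. The parent is recorded on the
SubChannel whether or not channelz is turned on.

When channelz is on, RegisterSubChannel reads parent.ID to register the
subchannel, so callers must pass a non-nil parent. Tests that previously
passed -1 as a placeholder parent ID now register a real parent channel
through the new channelzSubChannel test helper, which also removes both
the subchannel and its parent channel via t.Cleanup.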

diff --git a/channelz/service/service_test.go b/channelz/service/service_test.go
index 4e41542..8214f12 100644
--- a/channelz/service/service_test.go
+++ b/channelz/service/service_test.go
@@ -334,7 +334,7 @@
 		},
 	})
 
-	subChan := channelz.RegisterSubChannel(cids[0].ID, refNames[2])
+	subChan := channelz.RegisterSubChannel(cids[0], refNames[2])
 	channelz.AddTraceEvent(logger, subChan, 0, &channelz.TraceEvent{
 		Desc:     "SubChannel Created",
 		Severity: channelz.CtInfo,
@@ -432,7 +432,7 @@
 		Desc:     "Channel Created",
 		Severity: channelz.CtInfo,
 	})
-	subChan := channelz.RegisterSubChannel(chann.ID, refNames[1])
+	subChan := channelz.RegisterSubChannel(chann, refNames[1])
 	defer channelz.RemoveEntry(subChan.ID)
 	channelz.AddTraceEvent(logger, subChan, 0, &channelz.TraceEvent{
 		Desc:     subchanCreated,
diff --git a/clientconn.go b/clientconn.go
index d16d058..c7f2607 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -833,7 +833,7 @@
 		addrs:        copyAddressesWithoutBalancerAttributes(addrs),
 		scopts:       opts,
 		dopts:        cc.dopts,
-		channelz:     channelz.RegisterSubChannel(cc.channelz.ID, ""),
+		channelz:     channelz.RegisterSubChannel(cc.channelz, ""),
 		resetBackoff: make(chan struct{}),
 		stateChan:    make(chan struct{}),
 	}
diff --git a/internal/channelz/funcs.go b/internal/channelz/funcs.go
index f461e9b..03e24e1 100644
--- a/internal/channelz/funcs.go
+++ b/internal/channelz/funcs.go
@@ -143,20 +143,21 @@
 // Returns a unique channelz identifier assigned to this subChannel.
 //
 // If channelz is not turned ON, the channelz database is not mutated.
-func RegisterSubChannel(pid int64, ref string) *SubChannel {
+func RegisterSubChannel(parent *Channel, ref string) *SubChannel {
 	id := IDGen.genID()
-	if !IsOn() {
-		return &SubChannel{ID: id}
+	sc := &SubChannel{
+		ID:      id,
+		RefName: ref,
+		parent:  parent,
 	}
 
-	sc := &SubChannel{
-		RefName: ref,
-		ID:      id,
-		sockets: make(map[int64]string),
-		parent:  db.getChannel(pid),
-		trace:   &ChannelTrace{CreationTime: time.Now(), Events: make([]*traceEvent, 0, getMaxTraceEntry())},
+	if !IsOn() {
+		return sc
 	}
-	db.addSubChannel(id, sc, pid)
+
+	sc.sockets = make(map[int64]string)
+	sc.trace = &ChannelTrace{CreationTime: time.Now(), Events: make([]*traceEvent, 0, getMaxTraceEntry())}
+	db.addSubChannel(id, sc, parent.ID)
 	return sc
 }
 
diff --git a/internal/transport/keepalive_test.go b/internal/transport/keepalive_test.go
index 3fafc38..393a454 100644
--- a/internal/transport/keepalive_test.go
+++ b/internal/transport/keepalive_test.go
@@ -249,6 +249,16 @@
 	}
 }
 
+func channelzSubChannel(t *testing.T) *channelz.SubChannel {
+	ch := channelz.RegisterChannel(nil, "test chan")
+	sc := channelz.RegisterSubChannel(ch, "test subchan")
+	t.Cleanup(func() {
+		channelz.RemoveEntry(sc.ID)
+		channelz.RemoveEntry(ch.ID)
+	})
+	return sc
+}
+
 // TestKeepaliveClientClosesUnresponsiveServer creates a server which does not
 // respond to keepalive pings, and makes sure that the client closes the
 // transport once the keepalive logic kicks in. Here, we set the
@@ -257,14 +267,13 @@
 func (s) TestKeepaliveClientClosesUnresponsiveServer(t *testing.T) {
 	connCh := make(chan net.Conn, 1)
 	copts := ConnectOptions{
-		ChannelzParent: channelz.RegisterSubChannel(-1, "test subchan"),
+		ChannelzParent: channelzSubChannel(t),
 		KeepaliveParams: keepalive.ClientParameters{
 			Time:                10 * time.Millisecond,
 			Timeout:             10 * time.Millisecond,
 			PermitWithoutStream: true,
 		},
 	}
-	defer channelz.RemoveEntry(copts.ChannelzParent.ID)
 	client, cancel := setUpWithNoPingServer(t, copts, connCh)
 	defer cancel()
 	defer client.Close(fmt.Errorf("closed manually by test"))
@@ -288,13 +297,12 @@
 func (s) TestKeepaliveClientOpenWithUnresponsiveServer(t *testing.T) {
 	connCh := make(chan net.Conn, 1)
 	copts := ConnectOptions{
-		ChannelzParent: channelz.RegisterSubChannel(-1, "test subchan"),
+		ChannelzParent: channelzSubChannel(t),
 		KeepaliveParams: keepalive.ClientParameters{
 			Time:    10 * time.Millisecond,
 			Timeout: 10 * time.Millisecond,
 		},
 	}
-	defer channelz.RemoveEntry(copts.ChannelzParent.ID)
 	client, cancel := setUpWithNoPingServer(t, copts, connCh)
 	defer cancel()
 	defer client.Close(fmt.Errorf("closed manually by test"))
@@ -319,13 +327,12 @@
 func (s) TestKeepaliveClientClosesWithActiveStreams(t *testing.T) {
 	connCh := make(chan net.Conn, 1)
 	copts := ConnectOptions{
-		ChannelzParent: channelz.RegisterSubChannel(-1, "test subchan"),
+		ChannelzParent: channelzSubChannel(t),
 		KeepaliveParams: keepalive.ClientParameters{
 			Time:    500 * time.Millisecond,
 			Timeout: 500 * time.Millisecond,
 		},
 	}
-	defer channelz.RemoveEntry(copts.ChannelzParent.ID)
 	// TODO(i/6099): Setup a server which can ping and no-ping based on a flag to
 	// reduce the flakiness in this test.
 	client, cancel := setUpWithNoPingServer(t, copts, connCh)
diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go
index 90ce78f..b0be892 100644
--- a/internal/transport/transport_test.go
+++ b/internal/transport/transport_test.go
@@ -434,8 +434,7 @@
 func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
 	server := setUpServerOnly(t, port, sc, ht)
 	addr := resolver.Address{Addr: "localhost:" + server.port}
-	copts.ChannelzParent = channelz.RegisterSubChannel(-1, "test channel")
-	t.Cleanup(func() { channelz.RemoveEntry(copts.ChannelzParent.ID) })
+	copts.ChannelzParent = channelzSubChannel(t)
 
 	connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
 	ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {})
@@ -1321,9 +1320,8 @@
 	connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	time.AfterFunc(100*time.Millisecond, cancel)
 
-	parent := channelz.RegisterSubChannel(-1, "test channel")
+	parent := channelzSubChannel(t)
 	copts := ConnectOptions{ChannelzParent: parent}
-	defer channelz.RemoveEntry(parent.ID)
 	_, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
 	if err == nil {
 		t.Fatalf("NewClientTransport() returned successfully; wanted error")
@@ -1414,8 +1412,7 @@
 	connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
 	defer cancel()
 
-	parent := channelz.RegisterSubChannel(-1, "test channel")
-	defer channelz.RemoveEntry(parent.ID)
+	parent := channelzSubChannel(t)
 	copts := ConnectOptions{ChannelzParent: parent}
 	ct, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
 	if err != nil {
@@ -2425,9 +2422,8 @@
 
 	copts := ConnectOptions{
 		TransportCredentials: creds,
-		ChannelzParent:       channelz.RegisterSubChannel(-1, "test subchannel"),
+		ChannelzParent:       channelzSubChannel(t),
 	}
-	defer channelz.RemoveEntry(copts.ChannelzParent.ID)
 	tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
 	if err != nil {
 		t.Fatalf("NewClientTransport(): %v", err)
@@ -2467,9 +2463,8 @@
 
 	copts := ConnectOptions{
 		Dialer:         dialer,
-		ChannelzParent: channelz.RegisterSubChannel(-1, "test subchannel"),
+		ChannelzParent: channelzSubChannel(t),
 	}
-	defer channelz.RemoveEntry(copts.ChannelzParent.ID)
 	tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
 	if err != nil {
 		t.Fatalf("NewClientTransport(): %v", err)
diff --git a/test/channelz_test.go b/test/channelz_test.go
index 1cc1e2e..cd6b77a 100644
--- a/test/channelz_test.go
+++ b/test/channelz_test.go
@@ -554,8 +554,8 @@
 	// Socket1       Socket2
 
 	topChan := channelz.RegisterChannel(nil, "")
-	subChan1 := channelz.RegisterSubChannel(topChan.ID, "")
-	subChan2 := channelz.RegisterSubChannel(topChan.ID, "")
+	subChan1 := channelz.RegisterSubChannel(topChan, "")
+	subChan2 := channelz.RegisterSubChannel(topChan, "")
 	skt1 := channelz.RegisterSocket(&channelz.Socket{SocketType: channelz.SocketTypeNormal, Parent: subChan1})
 	skt2 := channelz.RegisterSocket(&channelz.Socket{SocketType: channelz.SocketTypeNormal, Parent: subChan1})