...

Source file src/net/rpc/server_test.go

Documentation: net/rpc

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rpc
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"net"
    13  	"net/http/httptest"
    14  	"reflect"
    15  	"runtime"
    16  	"strings"
    17  	"sync"
    18  	"sync/atomic"
    19  	"testing"
    20  	"time"
    21  )
    22  
    23  var (
    24  	newServer                 *Server
    25  	serverAddr, newServerAddr string
    26  	httpServerAddr            string
    27  	once, newOnce, httpOnce   sync.Once
    28  )
    29  
    30  const (
    31  	newHttpPath = "/foo"
    32  )
    33  
    34  type Args struct {
    35  	A, B int
    36  }
    37  
    38  type Reply struct {
    39  	C int
    40  }
    41  
    42  type Arith int
    43  
    44  // Some of Arith's methods have value args, some have pointer args. That's deliberate.
    45  
    46  func (t *Arith) Add(args Args, reply *Reply) error {
    47  	reply.C = args.A + args.B
    48  	return nil
    49  }
    50  
    51  func (t *Arith) Mul(args *Args, reply *Reply) error {
    52  	reply.C = args.A * args.B
    53  	return nil
    54  }
    55  
    56  func (t *Arith) Div(args Args, reply *Reply) error {
    57  	if args.B == 0 {
    58  		return errors.New("divide by zero")
    59  	}
    60  	reply.C = args.A / args.B
    61  	return nil
    62  }
    63  
    64  func (t *Arith) String(args *Args, reply *string) error {
    65  	*reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
    66  	return nil
    67  }
    68  
    69  func (t *Arith) Scan(args string, reply *Reply) (err error) {
    70  	_, err = fmt.Sscan(args, &reply.C)
    71  	return
    72  }
    73  
    74  func (t *Arith) Error(args *Args, reply *Reply) error {
    75  	panic("ERROR")
    76  }
    77  
    78  func (t *Arith) SleepMilli(args *Args, reply *Reply) error {
    79  	time.Sleep(time.Duration(args.A) * time.Millisecond)
    80  	return nil
    81  }
    82  
    83  type hidden int
    84  
    85  func (t *hidden) Exported(args Args, reply *Reply) error {
    86  	reply.C = args.A + args.B
    87  	return nil
    88  }
    89  
    90  type Embed struct {
    91  	hidden
    92  }
    93  
    94  type BuiltinTypes struct{}
    95  
    96  func (BuiltinTypes) Map(args *Args, reply *map[int]int) error {
    97  	(*reply)[args.A] = args.B
    98  	return nil
    99  }
   100  
   101  func (BuiltinTypes) Slice(args *Args, reply *[]int) error {
   102  	*reply = append(*reply, args.A, args.B)
   103  	return nil
   104  }
   105  
   106  func (BuiltinTypes) Array(args *Args, reply *[2]int) error {
   107  	(*reply)[0] = args.A
   108  	(*reply)[1] = args.B
   109  	return nil
   110  }
   111  
   112  func listenTCP() (net.Listener, string) {
   113  	l, err := net.Listen("tcp", "127.0.0.1:0") // any available address
   114  	if err != nil {
   115  		log.Fatalf("net.Listen tcp :0: %v", err)
   116  	}
   117  	return l, l.Addr().String()
   118  }
   119  
   120  func startServer() {
   121  	Register(new(Arith))
   122  	Register(new(Embed))
   123  	RegisterName("net.rpc.Arith", new(Arith))
   124  	Register(BuiltinTypes{})
   125  
   126  	var l net.Listener
   127  	l, serverAddr = listenTCP()
   128  	log.Println("Test RPC server listening on", serverAddr)
   129  	go Accept(l)
   130  
   131  	HandleHTTP()
   132  	httpOnce.Do(startHttpServer)
   133  }
   134  
   135  func startNewServer() {
   136  	newServer = NewServer()
   137  	newServer.Register(new(Arith))
   138  	newServer.Register(new(Embed))
   139  	newServer.RegisterName("net.rpc.Arith", new(Arith))
   140  	newServer.RegisterName("newServer.Arith", new(Arith))
   141  
   142  	var l net.Listener
   143  	l, newServerAddr = listenTCP()
   144  	log.Println("NewServer test RPC server listening on", newServerAddr)
   145  	go newServer.Accept(l)
   146  
   147  	newServer.HandleHTTP(newHttpPath, "/bar")
   148  	httpOnce.Do(startHttpServer)
   149  }
   150  
   151  func startHttpServer() {
   152  	server := httptest.NewServer(nil)
   153  	httpServerAddr = server.Listener.Addr().String()
   154  	log.Println("Test HTTP RPC server listening on", httpServerAddr)
   155  }
   156  
   157  func TestRPC(t *testing.T) {
   158  	once.Do(startServer)
   159  	testRPC(t, serverAddr)
   160  	newOnce.Do(startNewServer)
   161  	testRPC(t, newServerAddr)
   162  	testNewServerRPC(t, newServerAddr)
   163  }
   164  
   165  func testRPC(t *testing.T, addr string) {
   166  	client, err := Dial("tcp", addr)
   167  	if err != nil {
   168  		t.Fatal("dialing", err)
   169  	}
   170  	defer client.Close()
   171  
   172  	// Synchronous calls
   173  	args := &Args{7, 8}
   174  	reply := new(Reply)
   175  	err = client.Call("Arith.Add", args, reply)
   176  	if err != nil {
   177  		t.Errorf("Add: expected no error but got string %q", err.Error())
   178  	}
   179  	if reply.C != args.A+args.B {
   180  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   181  	}
   182  
   183  	// Methods exported from unexported embedded structs
   184  	args = &Args{7, 0}
   185  	reply = new(Reply)
   186  	err = client.Call("Embed.Exported", args, reply)
   187  	if err != nil {
   188  		t.Errorf("Add: expected no error but got string %q", err.Error())
   189  	}
   190  	if reply.C != args.A+args.B {
   191  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   192  	}
   193  
   194  	// Nonexistent method
   195  	args = &Args{7, 0}
   196  	reply = new(Reply)
   197  	err = client.Call("Arith.BadOperation", args, reply)
   198  	// expect an error
   199  	if err == nil {
   200  		t.Error("BadOperation: expected error")
   201  	} else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
   202  		t.Errorf("BadOperation: expected can't find method error; got %q", err)
   203  	}
   204  
   205  	// Unknown service
   206  	args = &Args{7, 8}
   207  	reply = new(Reply)
   208  	err = client.Call("Arith.Unknown", args, reply)
   209  	if err == nil {
   210  		t.Error("expected error calling unknown service")
   211  	} else if !strings.Contains(err.Error(), "method") {
   212  		t.Error("expected error about method; got", err)
   213  	}
   214  
   215  	// Out of order.
   216  	args = &Args{7, 8}
   217  	mulReply := new(Reply)
   218  	mulCall := client.Go("Arith.Mul", args, mulReply, nil)
   219  	addReply := new(Reply)
   220  	addCall := client.Go("Arith.Add", args, addReply, nil)
   221  
   222  	addCall = <-addCall.Done
   223  	if addCall.Error != nil {
   224  		t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
   225  	}
   226  	if addReply.C != args.A+args.B {
   227  		t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
   228  	}
   229  
   230  	mulCall = <-mulCall.Done
   231  	if mulCall.Error != nil {
   232  		t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
   233  	}
   234  	if mulReply.C != args.A*args.B {
   235  		t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
   236  	}
   237  
   238  	// Error test
   239  	args = &Args{7, 0}
   240  	reply = new(Reply)
   241  	err = client.Call("Arith.Div", args, reply)
   242  	// expect an error: zero divide
   243  	if err == nil {
   244  		t.Error("Div: expected error")
   245  	} else if err.Error() != "divide by zero" {
   246  		t.Error("Div: expected divide by zero error; got", err)
   247  	}
   248  
   249  	// Bad type.
   250  	reply = new(Reply)
   251  	err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
   252  	if err == nil {
   253  		t.Error("expected error calling Arith.Add with wrong arg type")
   254  	} else if !strings.Contains(err.Error(), "type") {
   255  		t.Error("expected error about type; got", err)
   256  	}
   257  
   258  	// Non-struct argument
   259  	const Val = 12345
   260  	str := fmt.Sprint(Val)
   261  	reply = new(Reply)
   262  	err = client.Call("Arith.Scan", &str, reply)
   263  	if err != nil {
   264  		t.Errorf("Scan: expected no error but got string %q", err.Error())
   265  	} else if reply.C != Val {
   266  		t.Errorf("Scan: expected %d got %d", Val, reply.C)
   267  	}
   268  
   269  	// Non-struct reply
   270  	args = &Args{27, 35}
   271  	str = ""
   272  	err = client.Call("Arith.String", args, &str)
   273  	if err != nil {
   274  		t.Errorf("String: expected no error but got string %q", err.Error())
   275  	}
   276  	expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
   277  	if str != expect {
   278  		t.Errorf("String: expected %s got %s", expect, str)
   279  	}
   280  
   281  	args = &Args{7, 8}
   282  	reply = new(Reply)
   283  	err = client.Call("Arith.Mul", args, reply)
   284  	if err != nil {
   285  		t.Errorf("Mul: expected no error but got string %q", err.Error())
   286  	}
   287  	if reply.C != args.A*args.B {
   288  		t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
   289  	}
   290  
   291  	// ServiceName contain "." character
   292  	args = &Args{7, 8}
   293  	reply = new(Reply)
   294  	err = client.Call("net.rpc.Arith.Add", args, reply)
   295  	if err != nil {
   296  		t.Errorf("Add: expected no error but got string %q", err.Error())
   297  	}
   298  	if reply.C != args.A+args.B {
   299  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   300  	}
   301  }
   302  
   303  func testNewServerRPC(t *testing.T, addr string) {
   304  	client, err := Dial("tcp", addr)
   305  	if err != nil {
   306  		t.Fatal("dialing", err)
   307  	}
   308  	defer client.Close()
   309  
   310  	// Synchronous calls
   311  	args := &Args{7, 8}
   312  	reply := new(Reply)
   313  	err = client.Call("newServer.Arith.Add", args, reply)
   314  	if err != nil {
   315  		t.Errorf("Add: expected no error but got string %q", err.Error())
   316  	}
   317  	if reply.C != args.A+args.B {
   318  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   319  	}
   320  }
   321  
   322  func TestHTTP(t *testing.T) {
   323  	once.Do(startServer)
   324  	testHTTPRPC(t, "")
   325  	newOnce.Do(startNewServer)
   326  	testHTTPRPC(t, newHttpPath)
   327  }
   328  
   329  func testHTTPRPC(t *testing.T, path string) {
   330  	var client *Client
   331  	var err error
   332  	if path == "" {
   333  		client, err = DialHTTP("tcp", httpServerAddr)
   334  	} else {
   335  		client, err = DialHTTPPath("tcp", httpServerAddr, path)
   336  	}
   337  	if err != nil {
   338  		t.Fatal("dialing", err)
   339  	}
   340  	defer client.Close()
   341  
   342  	// Synchronous calls
   343  	args := &Args{7, 8}
   344  	reply := new(Reply)
   345  	err = client.Call("Arith.Add", args, reply)
   346  	if err != nil {
   347  		t.Errorf("Add: expected no error but got string %q", err.Error())
   348  	}
   349  	if reply.C != args.A+args.B {
   350  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   351  	}
   352  }
   353  
   354  func TestBuiltinTypes(t *testing.T) {
   355  	once.Do(startServer)
   356  
   357  	client, err := DialHTTP("tcp", httpServerAddr)
   358  	if err != nil {
   359  		t.Fatal("dialing", err)
   360  	}
   361  	defer client.Close()
   362  
   363  	// Map
   364  	args := &Args{7, 8}
   365  	replyMap := map[int]int{}
   366  	err = client.Call("BuiltinTypes.Map", args, &replyMap)
   367  	if err != nil {
   368  		t.Errorf("Map: expected no error but got string %q", err.Error())
   369  	}
   370  	if replyMap[args.A] != args.B {
   371  		t.Errorf("Map: expected %d got %d", args.B, replyMap[args.A])
   372  	}
   373  
   374  	// Slice
   375  	args = &Args{7, 8}
   376  	replySlice := []int{}
   377  	err = client.Call("BuiltinTypes.Slice", args, &replySlice)
   378  	if err != nil {
   379  		t.Errorf("Slice: expected no error but got string %q", err.Error())
   380  	}
   381  	if e := []int{args.A, args.B}; !reflect.DeepEqual(replySlice, e) {
   382  		t.Errorf("Slice: expected %v got %v", e, replySlice)
   383  	}
   384  
   385  	// Array
   386  	args = &Args{7, 8}
   387  	replyArray := [2]int{}
   388  	err = client.Call("BuiltinTypes.Array", args, &replyArray)
   389  	if err != nil {
   390  		t.Errorf("Array: expected no error but got string %q", err.Error())
   391  	}
   392  	if e := [2]int{args.A, args.B}; !reflect.DeepEqual(replyArray, e) {
   393  		t.Errorf("Array: expected %v got %v", e, replyArray)
   394  	}
   395  }
   396  
   397  // CodecEmulator provides a client-like api and a ServerCodec interface.
   398  // Can be used to test ServeRequest.
   399  type CodecEmulator struct {
   400  	server        *Server
   401  	serviceMethod string
   402  	args          *Args
   403  	reply         *Reply
   404  	err           error
   405  }
   406  
   407  func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error {
   408  	codec.serviceMethod = serviceMethod
   409  	codec.args = args
   410  	codec.reply = reply
   411  	codec.err = nil
   412  	var serverError error
   413  	if codec.server == nil {
   414  		serverError = ServeRequest(codec)
   415  	} else {
   416  		serverError = codec.server.ServeRequest(codec)
   417  	}
   418  	if codec.err == nil && serverError != nil {
   419  		codec.err = serverError
   420  	}
   421  	return codec.err
   422  }
   423  
   424  func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
   425  	req.ServiceMethod = codec.serviceMethod
   426  	req.Seq = 0
   427  	return nil
   428  }
   429  
   430  func (codec *CodecEmulator) ReadRequestBody(argv any) error {
   431  	if codec.args == nil {
   432  		return io.ErrUnexpectedEOF
   433  	}
   434  	*(argv.(*Args)) = *codec.args
   435  	return nil
   436  }
   437  
   438  func (codec *CodecEmulator) WriteResponse(resp *Response, reply any) error {
   439  	if resp.Error != "" {
   440  		codec.err = errors.New(resp.Error)
   441  	} else {
   442  		*codec.reply = *(reply.(*Reply))
   443  	}
   444  	return nil
   445  }
   446  
   447  func (codec *CodecEmulator) Close() error {
   448  	return nil
   449  }
   450  
   451  func TestServeRequest(t *testing.T) {
   452  	once.Do(startServer)
   453  	testServeRequest(t, nil)
   454  	newOnce.Do(startNewServer)
   455  	testServeRequest(t, newServer)
   456  }
   457  
   458  func testServeRequest(t *testing.T, server *Server) {
   459  	client := CodecEmulator{server: server}
   460  	defer client.Close()
   461  
   462  	args := &Args{7, 8}
   463  	reply := new(Reply)
   464  	err := client.Call("Arith.Add", args, reply)
   465  	if err != nil {
   466  		t.Errorf("Add: expected no error but got string %q", err.Error())
   467  	}
   468  	if reply.C != args.A+args.B {
   469  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   470  	}
   471  
   472  	err = client.Call("Arith.Add", nil, reply)
   473  	if err == nil {
   474  		t.Errorf("expected error calling Arith.Add with nil arg")
   475  	}
   476  }
   477  
   478  type ReplyNotPointer int
   479  type ArgNotPublic int
   480  type ReplyNotPublic int
   481  type NeedsPtrType int
   482  type local struct{}
   483  
   484  func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
   485  	return nil
   486  }
   487  
   488  func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error {
   489  	return nil
   490  }
   491  
   492  func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
   493  	return nil
   494  }
   495  
   496  func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error {
   497  	return nil
   498  }
   499  
   500  // Check that registration handles lots of bad methods and a type with no suitable methods.
   501  func TestRegistrationError(t *testing.T) {
   502  	err := Register(new(ReplyNotPointer))
   503  	if err == nil {
   504  		t.Error("expected error registering ReplyNotPointer")
   505  	}
   506  	err = Register(new(ArgNotPublic))
   507  	if err == nil {
   508  		t.Error("expected error registering ArgNotPublic")
   509  	}
   510  	err = Register(new(ReplyNotPublic))
   511  	if err == nil {
   512  		t.Error("expected error registering ReplyNotPublic")
   513  	}
   514  	err = Register(NeedsPtrType(0))
   515  	if err == nil {
   516  		t.Error("expected error registering NeedsPtrType")
   517  	} else if !strings.Contains(err.Error(), "pointer") {
   518  		t.Error("expected hint when registering NeedsPtrType")
   519  	}
   520  }
   521  
   522  type WriteFailCodec int
   523  
   524  func (WriteFailCodec) WriteRequest(*Request, any) error {
   525  	// the panic caused by this error used to not unlock a lock.
   526  	return errors.New("fail")
   527  }
   528  
   529  func (WriteFailCodec) ReadResponseHeader(*Response) error {
   530  	select {}
   531  }
   532  
   533  func (WriteFailCodec) ReadResponseBody(any) error {
   534  	select {}
   535  }
   536  
   537  func (WriteFailCodec) Close() error {
   538  	return nil
   539  }
   540  
   541  func TestSendDeadlock(t *testing.T) {
   542  	client := NewClientWithCodec(WriteFailCodec(0))
   543  	defer client.Close()
   544  
   545  	done := make(chan bool)
   546  	go func() {
   547  		testSendDeadlock(client)
   548  		testSendDeadlock(client)
   549  		done <- true
   550  	}()
   551  	select {
   552  	case <-done:
   553  		return
   554  	case <-time.After(5 * time.Second):
   555  		t.Fatal("deadlock")
   556  	}
   557  }
   558  
   559  func testSendDeadlock(client *Client) {
   560  	defer func() {
   561  		recover()
   562  	}()
   563  	args := &Args{7, 8}
   564  	reply := new(Reply)
   565  	client.Call("Arith.Add", args, reply)
   566  }
   567  
   568  func dialDirect() (*Client, error) {
   569  	return Dial("tcp", serverAddr)
   570  }
   571  
   572  func dialHTTP() (*Client, error) {
   573  	return DialHTTP("tcp", httpServerAddr)
   574  }
   575  
   576  func countMallocs(dial func() (*Client, error), t *testing.T) float64 {
   577  	once.Do(startServer)
   578  	client, err := dial()
   579  	if err != nil {
   580  		t.Fatal("error dialing", err)
   581  	}
   582  	defer client.Close()
   583  
   584  	args := &Args{7, 8}
   585  	reply := new(Reply)
   586  	return testing.AllocsPerRun(100, func() {
   587  		err := client.Call("Arith.Add", args, reply)
   588  		if err != nil {
   589  			t.Errorf("Add: expected no error but got string %q", err.Error())
   590  		}
   591  		if reply.C != args.A+args.B {
   592  			t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   593  		}
   594  	})
   595  }
   596  
   597  func TestCountMallocs(t *testing.T) {
   598  	if testing.Short() {
   599  		t.Skip("skipping malloc count in short mode")
   600  	}
   601  	if runtime.GOMAXPROCS(0) > 1 {
   602  		t.Skip("skipping; GOMAXPROCS>1")
   603  	}
   604  	fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
   605  }
   606  
   607  func TestCountMallocsOverHTTP(t *testing.T) {
   608  	if testing.Short() {
   609  		t.Skip("skipping malloc count in short mode")
   610  	}
   611  	if runtime.GOMAXPROCS(0) > 1 {
   612  		t.Skip("skipping; GOMAXPROCS>1")
   613  	}
   614  	fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t))
   615  }
   616  
   617  type writeCrasher struct {
   618  	done chan bool
   619  }
   620  
   621  func (writeCrasher) Close() error {
   622  	return nil
   623  }
   624  
   625  func (w *writeCrasher) Read(p []byte) (int, error) {
   626  	<-w.done
   627  	return 0, io.EOF
   628  }
   629  
   630  func (writeCrasher) Write(p []byte) (int, error) {
   631  	return 0, errors.New("fake write failure")
   632  }
   633  
   634  func TestClientWriteError(t *testing.T) {
   635  	w := &writeCrasher{done: make(chan bool)}
   636  	c := NewClient(w)
   637  	defer c.Close()
   638  
   639  	res := false
   640  	err := c.Call("foo", 1, &res)
   641  	if err == nil {
   642  		t.Fatal("expected error")
   643  	}
   644  	if err.Error() != "fake write failure" {
   645  		t.Error("unexpected value of error:", err)
   646  	}
   647  	w.done <- true
   648  }
   649  
   650  func TestTCPClose(t *testing.T) {
   651  	once.Do(startServer)
   652  
   653  	client, err := dialHTTP()
   654  	if err != nil {
   655  		t.Fatalf("dialing: %v", err)
   656  	}
   657  	defer client.Close()
   658  
   659  	args := Args{17, 8}
   660  	var reply Reply
   661  	err = client.Call("Arith.Mul", args, &reply)
   662  	if err != nil {
   663  		t.Fatal("arith error:", err)
   664  	}
   665  	t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
   666  	if reply.C != args.A*args.B {
   667  		t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
   668  	}
   669  }
   670  
   671  func TestErrorAfterClientClose(t *testing.T) {
   672  	once.Do(startServer)
   673  
   674  	client, err := dialHTTP()
   675  	if err != nil {
   676  		t.Fatalf("dialing: %v", err)
   677  	}
   678  	err = client.Close()
   679  	if err != nil {
   680  		t.Fatal("close error:", err)
   681  	}
   682  	err = client.Call("Arith.Add", &Args{7, 9}, new(Reply))
   683  	if err != ErrShutdown {
   684  		t.Errorf("Forever: expected ErrShutdown got %v", err)
   685  	}
   686  }
   687  
   688  // Tests the fix to issue 11221. Without the fix, this loops forever or crashes.
   689  func TestAcceptExitAfterListenerClose(t *testing.T) {
   690  	newServer := NewServer()
   691  	newServer.Register(new(Arith))
   692  	newServer.RegisterName("net.rpc.Arith", new(Arith))
   693  	newServer.RegisterName("newServer.Arith", new(Arith))
   694  
   695  	var l net.Listener
   696  	l, _ = listenTCP()
   697  	l.Close()
   698  	newServer.Accept(l)
   699  }
   700  
   701  func TestShutdown(t *testing.T) {
   702  	var l net.Listener
   703  	l, _ = listenTCP()
   704  	ch := make(chan net.Conn, 1)
   705  	go func() {
   706  		defer l.Close()
   707  		c, err := l.Accept()
   708  		if err != nil {
   709  			t.Error(err)
   710  		}
   711  		ch <- c
   712  	}()
   713  	c, err := net.Dial("tcp", l.Addr().String())
   714  	if err != nil {
   715  		t.Fatal(err)
   716  	}
   717  	c1 := <-ch
   718  	if c1 == nil {
   719  		t.Fatal(err)
   720  	}
   721  
   722  	newServer := NewServer()
   723  	newServer.Register(new(Arith))
   724  	go newServer.ServeConn(c1)
   725  
   726  	args := &Args{7, 8}
   727  	reply := new(Reply)
   728  	client := NewClient(c)
   729  	err = client.Call("Arith.Add", args, reply)
   730  	if err != nil {
   731  		t.Fatal(err)
   732  	}
   733  
   734  	// On an unloaded system 10ms is usually enough to fail 100% of the time
   735  	// with a broken server. On a loaded system, a broken server might incorrectly
   736  	// be reported as passing, but we're OK with that kind of flakiness.
   737  	// If the code is correct, this test will never fail, regardless of timeout.
   738  	args.A = 10 // 10 ms
   739  	done := make(chan *Call, 1)
   740  	call := client.Go("Arith.SleepMilli", args, reply, done)
   741  	c.(*net.TCPConn).CloseWrite()
   742  	<-done
   743  	if call.Error != nil {
   744  		t.Fatal(err)
   745  	}
   746  }
   747  
   748  func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
   749  	once.Do(startServer)
   750  	client, err := dial()
   751  	if err != nil {
   752  		b.Fatal("error dialing:", err)
   753  	}
   754  	defer client.Close()
   755  
   756  	// Synchronous calls
   757  	args := &Args{7, 8}
   758  	b.ResetTimer()
   759  
   760  	b.RunParallel(func(pb *testing.PB) {
   761  		reply := new(Reply)
   762  		for pb.Next() {
   763  			err := client.Call("Arith.Add", args, reply)
   764  			if err != nil {
   765  				b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
   766  			}
   767  			if reply.C != args.A+args.B {
   768  				b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B)
   769  			}
   770  		}
   771  	})
   772  }
   773  
   774  func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
   775  	if b.N == 0 {
   776  		return
   777  	}
   778  	const MaxConcurrentCalls = 100
   779  	once.Do(startServer)
   780  	client, err := dial()
   781  	if err != nil {
   782  		b.Fatal("error dialing:", err)
   783  	}
   784  	defer client.Close()
   785  
   786  	// Asynchronous calls
   787  	args := &Args{7, 8}
   788  	procs := 4 * runtime.GOMAXPROCS(-1)
   789  	send := int32(b.N)
   790  	recv := int32(b.N)
   791  	var wg sync.WaitGroup
   792  	wg.Add(procs)
   793  	gate := make(chan bool, MaxConcurrentCalls)
   794  	res := make(chan *Call, MaxConcurrentCalls)
   795  	b.ResetTimer()
   796  
   797  	for p := 0; p < procs; p++ {
   798  		go func() {
   799  			for atomic.AddInt32(&send, -1) >= 0 {
   800  				gate <- true
   801  				reply := new(Reply)
   802  				client.Go("Arith.Add", args, reply, res)
   803  			}
   804  		}()
   805  		go func() {
   806  			for call := range res {
   807  				A := call.Args.(*Args).A
   808  				B := call.Args.(*Args).B
   809  				C := call.Reply.(*Reply).C
   810  				if A+B != C {
   811  					b.Errorf("incorrect reply: Add: expected %d got %d", A+B, C)
   812  					return
   813  				}
   814  				<-gate
   815  				if atomic.AddInt32(&recv, -1) == 0 {
   816  					close(res)
   817  				}
   818  			}
   819  			wg.Done()
   820  		}()
   821  	}
   822  	wg.Wait()
   823  }
   824  
   825  func BenchmarkEndToEnd(b *testing.B) {
   826  	benchmarkEndToEnd(dialDirect, b)
   827  }
   828  
   829  func BenchmarkEndToEndHTTP(b *testing.B) {
   830  	benchmarkEndToEnd(dialHTTP, b)
   831  }
   832  
   833  func BenchmarkEndToEndAsync(b *testing.B) {
   834  	benchmarkEndToEndAsync(dialDirect, b)
   835  }
   836  
   837  func BenchmarkEndToEndAsyncHTTP(b *testing.B) {
   838  	benchmarkEndToEndAsync(dialHTTP, b)
   839  }
   840  

View as plain text