...
  
  
     1  
     2  
     3  
     4  
     5  
     6  
     7  package httpresponse
     8  
     9  import (
    10  	"go/ast"
    11  	"go/types"
    12  
    13  	"golang.org/x/tools/go/analysis"
    14  	"golang.org/x/tools/go/analysis/passes/inspect"
    15  	"golang.org/x/tools/go/ast/inspector"
    16  	"golang.org/x/tools/internal/analysisinternal"
    17  	"golang.org/x/tools/internal/typesinternal"
    18  )
    19  
// Doc is the documentation text of the httpresponse analyzer; it is used
// as the Doc field of Analyzer.
const Doc = `check for mistakes using HTTP responses

A common mistake when using the net/http package is to defer a function
call to close the http.Response Body before checking the error that
determines whether the response is valid:

	resp, err := http.Head(url)
	defer resp.Body.Close()
	if err != nil {
		log.Fatal(err)
	}
	// (defer statement belongs here)

This checker helps uncover latent nil dereference bugs by reporting a
diagnostic for such mistakes.`
    35  
// Analyzer is the httpresponse analysis. It requires the result of
// inspect.Analyzer and executes run on each analyzed package.
var Analyzer = &analysis.Analyzer{
	Name:     "httpresponse",
	Doc:      Doc,
	URL:      "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/httpresponse",
	Requires: []*analysis.Analyzer{inspect.Analyzer},
	Run:      run,
}
    43  
    44  func run(pass *analysis.Pass) (any, error) {
    45  	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    46  
    47  	
    48  	
    49  	if !analysisinternal.Imports(pass.Pkg, "net/http") {
    50  		return nil, nil
    51  	}
    52  
    53  	nodeFilter := []ast.Node{
    54  		(*ast.CallExpr)(nil),
    55  	}
    56  	inspect.WithStack(nodeFilter, func(n ast.Node, push bool, stack []ast.Node) bool {
    57  		if !push {
    58  			return true
    59  		}
    60  		call := n.(*ast.CallExpr)
    61  		if !isHTTPFuncOrMethodOnClient(pass.TypesInfo, call) {
    62  			return true 
    63  		}
    64  
    65  		
    66  		
    67  		stmts, ncalls := restOfBlock(stack)
    68  		if len(stmts) < 2 {
    69  			
    70  			return true
    71  		}
    72  
    73  		
    74  		
    75  		if ncalls > 1 {
    76  			return true
    77  		}
    78  
    79  		asg, ok := stmts[0].(*ast.AssignStmt)
    80  		if !ok {
    81  			return true 
    82  		}
    83  
    84  		resp := rootIdent(asg.Lhs[0])
    85  		if resp == nil {
    86  			return true 
    87  		}
    88  
    89  		def, ok := stmts[1].(*ast.DeferStmt)
    90  		if !ok {
    91  			return true 
    92  		}
    93  		root := rootIdent(def.Call.Fun)
    94  		if root == nil {
    95  			return true 
    96  		}
    97  
    98  		if resp.Obj == root.Obj {
    99  			pass.ReportRangef(root, "using %s before checking for errors", resp.Name)
   100  		}
   101  		return true
   102  	})
   103  	return nil, nil
   104  }
   105  
   106  
   107  
   108  
   109  func isHTTPFuncOrMethodOnClient(info *types.Info, expr *ast.CallExpr) bool {
   110  	fun, _ := expr.Fun.(*ast.SelectorExpr)
   111  	sig, _ := info.Types[fun].Type.(*types.Signature)
   112  	if sig == nil {
   113  		return false 
   114  	}
   115  
   116  	res := sig.Results()
   117  	if res.Len() != 2 {
   118  		return false 
   119  	}
   120  	isPtr, named := typesinternal.ReceiverNamed(res.At(0))
   121  	if !isPtr || named == nil || !analysisinternal.IsTypeNamed(named, "net/http", "Response") {
   122  		return false 
   123  	}
   124  
   125  	errorType := types.Universe.Lookup("error").Type()
   126  	if !types.Identical(res.At(1).Type(), errorType) {
   127  		return false 
   128  	}
   129  
   130  	typ := info.Types[fun.X].Type
   131  	if typ == nil {
   132  		id, ok := fun.X.(*ast.Ident)
   133  		return ok && id.Name == "http" 
   134  	}
   135  
   136  	if analysisinternal.IsTypeNamed(typ, "net/http", "Client") {
   137  		return true 
   138  	}
   139  	ptr, ok := types.Unalias(typ).(*types.Pointer)
   140  	return ok && analysisinternal.IsTypeNamed(ptr.Elem(), "net/http", "Client") 
   141  }
   142  
   143  
   144  
   145  
   146  func restOfBlock(stack []ast.Node) ([]ast.Stmt, int) {
   147  	var ncalls int
   148  	for i := len(stack) - 1; i >= 0; i-- {
   149  		if b, ok := stack[i].(*ast.BlockStmt); ok {
   150  			for j, v := range b.List {
   151  				if v == stack[i+1] {
   152  					return b.List[j:], ncalls
   153  				}
   154  			}
   155  			break
   156  		}
   157  
   158  		if _, ok := stack[i].(*ast.CallExpr); ok {
   159  			ncalls++
   160  		}
   161  	}
   162  	return nil, 0
   163  }
   164  
   165  
   166  func rootIdent(n ast.Node) *ast.Ident {
   167  	switch n := n.(type) {
   168  	case *ast.SelectorExpr:
   169  		return rootIdent(n.X)
   170  	case *ast.Ident:
   171  		return n
   172  	default:
   173  		return nil
   174  	}
   175  }
   176  
View as plain text