RMI源码浅析
2023-3-4 20:55:0 Author: xz.aliyun.com(查看原文) 阅读量:16 收藏

RMI

前言

Remote Method Invocation(远程方法调用),它是一种机制,能够让在某个 Java虚拟机上的对象调用另一个 Java 虚拟机中的对象上的方法。可以用此方法调用的任何对象必须实现该远程接口。调用这样一个对象时,其参数为 "marshalled" 并将其从本地虚拟机发送到远程虚拟机(该远程虚拟机的参数为 "unmarshalled")上。该方法终止时,将编组来自远程机的结果并将结果发送到调用方的虚拟机。如果方法调用导致抛出异常,则该异常将指示给调用方。它主要是为java分布式而设计的,但由于在数据传输时采用了序列化,并且没有做一定的过滤所以导致了一系列安全问题。

服务端创建Registry

Registry registry = LocateRegistry.createRegistry(1099);

RMI提供了一个静态方法用来创建和获取Registry,跟进源码可以看到new了一个RegistryImpl对象。

public static Registry createRegistry(int port) throws RemoteException {
    return new RegistryImpl(port);
}

继续跟进

public RegistryImpl(int port)
    throws RemoteException
{
    if (port == Registry.REGISTRY_PORT && System.getSecurityManager() != null) {
        // grant permission for default port only.
        try {
            AccessController.doPrivileged(new PrivilegedExceptionAction<Void>() {
                public Void run() throws RemoteException {
                    LiveRef lref = new LiveRef(id, port);
                    setup(new UnicastServerRef(lref));
                    return null;
                }
            }, null, new SocketPermission("localhost:"+port, "listen,accept"));
        } catch (PrivilegedActionException pae) {
            throw (RemoteException)pae.getException();
        }
    } else {
        LiveRef lref = new LiveRef(id, port);
        setup(new UnicastServerRef(lref));
    }
}

前面If判断主要是安全检查,然后创建了一个LiveRef对象和UnicastServerRef对象,然后转入setup方法。我们首先来看LiveRef对象的创建。

public LiveRef(ObjID objID, int port) {
    this(objID, TCPEndpoint.getLocalEndpoint(port), true);
}
public LiveRef(ObjID objID, Endpoint endpoint, boolean isLocal) {
    ep = endpoint;
    id = objID;
    this.isLocal = isLocal;
}

其中调用了两次构造方法,在第一次通过port创建了Endpoint对象,然后传入后面的构造方法对变量赋值。
继续跟进TCPEndpoint.getLocalEndpoint(port)

public static TCPEndpoint getLocalEndpoint(int port) {
    return getLocalEndpoint(port, null, null);
}

public static TCPEndpoint getLocalEndpoint(int port,
                                            RMIClientSocketFactory csf,
                                            RMIServerSocketFactory ssf)
{
    TCPEndpoint ep = null;

    synchronized (localEndpoints) {
        TCPEndpoint endpointKey = new TCPEndpoint(null, port, csf, ssf);
        LinkedList<TCPEndpoint> epList = localEndpoints.get(endpointKey);
        String localHost = resampleLocalHost();

        if (epList == null) {
            ep = new TCPEndpoint(localHost, port, csf, ssf);
            epList = new LinkedList<TCPEndpoint>();
            epList.add(ep);
            ep.listenPort = port;
            ep.transport = new TCPTransport(epList);
            localEndpoints.put(endpointKey, epList);

            if (TCPTransport.tcpLog.isLoggable(Log.BRIEF)) {
                TCPTransport.tcpLog.log(Log.BRIEF,
                    "created local endpoint for socket factory " + ssf +
                    " on port " + port);
            }
        } else {
            synchronized (epList) {
                ep = epList.getLast();
                String lastHost = ep.host;
                int lastPort =  ep.port;
                TCPTransport lastTransport = ep.transport;
                // assert (localHost == null ^ lastHost != null)
                if (localHost != null && !localHost.equals(lastHost)) {
                    if (lastPort != 0) {
                        epList.clear();
                    }
                    ep = new TCPEndpoint(localHost, lastPort, csf, ssf);
                    ep.listenPort = port;
                    ep.transport = lastTransport;
                    epList.add(ep);
                }
            }
        }
    }
    return ep;
}

这里创建了一个TCPEndpoint对象,然后加入静态变量localEndpoints集合中,它存放了不同端口的TCPEndpoint对象。每次创建TCPEndpoint都会先检查localEndpoints是否存在与之端口一致的对象,如果存在,且绑定的host与当前的localhost相同则直接使用该对象。这就是为什么当我们创建多个远程对象时,他们监听的都是一个端口。
TCPEndpoint中除了host,port等变量外还有一个真正负责网络传输的TCPTransport对象,TCPEndpoint只是一个抽象的网络连接对象,实际的socket相关的工作交给了TCPTransport对象负责。TCPTransport的初始化比较简单,就设置了变量值,就不贴源码了。
再回到LiveRef构造函数中设置了ep,id,isLocal变量值后就结束了。
然后进入UnicastServerRef对象的初始化。

public UnicastServerRef(LiveRef ref) {
    super(ref);
}
    public UnicastRef(LiveRef liveRef) {
    ref = liveRef;
}

他初始化就赋了个值,然后继续跟进RegistryImpl的setup函数。

private void setup(UnicastServerRef uref)
    throws RemoteException
{
    ref = uref;
    uref.exportObject(this, null, true);
}

可以看到这给把UnicastServerRef赋给了ref变量,现在大致的对象关系是RegistryImpl <- UnicastRef <- LiveRef <- TCPEndpoint <- TCPTransport
然后调用UnicastRef.exportObject()函数

public Remote exportObject(Remote impl, Object data,
                            boolean permanent)
    throws RemoteException
{
    Class<?> implClass = impl.getClass();
    Remote stub;

    try {
        stub = Util.createProxy(implClass, getClientRef(), forceStubUse);
    } catch (IllegalArgumentException e) {
        throw new ExportException(
            "remote object implements illegal remote interface", e);
    }
    if (stub instanceof RemoteStub) {
        setSkeleton(impl);
    }

    Target target =
        new Target(impl, this, stub, ref.getObjID(), permanent);
    ref.exportObject(target);
    hashToMethod_Map = hashToMethod_Maps.get(implClass);
    return stub;
}

这里首先调用Util.createProxy()方法创建了一个Stub。

public static Remote createProxy(Class<?> implClass,
                                  RemoteRef clientRef,
                                  boolean forceStubUse)
    throws StubNotFoundException
{
    Class<?> remoteClass;

    try {
        remoteClass = getRemoteClass(implClass);
    } catch (ClassNotFoundException ex ) {
        throw new StubNotFoundException(
            "object does not implement a remote interface: " +
            implClass.getName());
    }

    if (forceStubUse ||
        !(ignoreStubClasses || !stubClassExists(remoteClass)))
    {
        return createStub(remoteClass, clientRef);
    }

    final ClassLoader loader = implClass.getClassLoader();
    final Class<?>[] interfaces = getRemoteInterfaces(implClass);
    final InvocationHandler handler =
        new RemoteObjectInvocationHandler(clientRef);

    /* REMIND: private remote interfaces? */

    try {
        return AccessController.doPrivileged(new PrivilegedAction<Remote>() {
            public Remote run() {
                return (Remote) Proxy.newProxyInstance(loader,
                                                        interfaces,
                                                        handler);
            }});
    } catch (IllegalArgumentException e) {
        throw new StubNotFoundException("unable to create proxy", e);
    }
}

首先调用getRemoteClass()方法,检查是否继承了Remote接口,若没继承则产生ClassNotFoundException异常,返回实现了Remote接口的类。
然后判断是否存在以_Stub结尾的类,如果存在则调用createStub(),可以发现rmi包中存在RegistryImpl_Stub类,所以调用createStub创建RegistryImpl_Stub类,这个方法比较简单,就是通过反射实例化了RegistryImpl_Stub类,同时将上面创建的LiveRef对作为参数传入。
如果不存在以_Stub结尾的类则继续往下调用Proxy.newProxyInstance()创建一个代理类。创建完Stub后再回到UnicastServerRef中下面判断Stub是否是RemoteStub实例,这里可以看到RegistryImpl_Stub是继承了RemoteStub类的,所以这里继续跟进setSkeleton()方法。

public void setSkeleton(Remote impl) throws RemoteException {
    if (!withoutSkeletons.containsKey(impl.getClass())) {
        try {
            skel = Util.createSkeleton(impl);
        } catch (SkeletonNotFoundException e) {
            withoutSkeletons.put(impl.getClass(), null);
        }
    }
}

判断withoutSkeletons集合中是否包含RegistryImpl。如果对于一个类C,不存在存在C_Skel类则将类C放入该集合,这里可以看到是存在RegistryImpl_Skel类的,所以调用Util.createSkeleton()方法,该方法和前面的Util.createProxy()类似,就不再继续跟进了。
然后再把前面创建的RegistryImpl, UnicastRef,RegistryImpl_Stub类封装到Target对象中,继续调用LiveRef.exportObject() -> TCPEndpoint.exportObject() -> TCPTransport.exportObject(),最后调用了TCPTransport的exportObject方法。

public void exportObject(Target target) throws RemoteException {
    synchronized (this) {
        listen();
        exportCount++;
    }
    boolean ok = false;
    try {
        super.exportObject(target);
        ok = true;
    } finally {
        if (!ok) {
            synchronized (this) {
                decrementExportCount();
            }
        }
    }
}

这里调用了listen()创建socket并监听,下面调用super.exportObject(),将target添加到objTable。

客户端获取Registry

上面讲了服务端创建Registry的过程,下面再说一下客户端获取Registry的过程。客户端调用如下代码即可获取到服务端的Registry。

LocateRegistry.getRegistry("127.0.0.1", 1099);

然后跟进代码

public static Registry getRegistry(String host, int port)
    throws RemoteException
{
    return getRegistry(host, port, null);
}
public static Registry getRegistry(String host, int port,
                                    RMIClientSocketFactory csf)
    throws RemoteException
{
    Registry registry = null;

    if (port <= 0)
        port = Registry.REGISTRY_PORT;

    if (host == null || host.length() == 0) {
        try {
            host = java.net.InetAddress.getLocalHost().getHostAddress();
        } catch (Exception e) {
            host = "";
        }
    }

    LiveRef liveRef =
        new LiveRef(new ObjID(ObjID.REGISTRY_ID),
                    new TCPEndpoint(host, port, csf, null),
                    false);
    RemoteRef ref =
        (csf == null) ? new UnicastRef(liveRef) : new UnicastRef2(liveRef);

    return (Registry) Util.createProxy(RegistryImpl.class, ref, false);
}

前面判断端口如果小于0则取默认端口1099,host如果为空,则获取本地的host。
然后这里和上面一样创建了一个LiveRef类和UnicastRef类,UnicastRef类是UnicastServerRef的子类。然后实例化了RegistryImpl类,客户端获取Registry没有发起网络请求,只是创建了一个RegistryImpl_Stub对象。

服务端创建远程对象

远程对象必须满足以下条件

  • 实现java.rmi.Remote接口
  • 远程调用的方法必须抛出java.rmi.RemoteException异常

有些文章里面可能还说了要继承UnicastRemoteObject,是因为这个类实现了处理远程对象的方法,比如导出对象以及底层传输的协议(JRMP)等,如果不想实现程序运行也不会报错,也可以自己手动调用UnicastRemoteObject的静态导出对象方法。或者自己实现导出方法以及底层传输协议等操作。
下面以实现UnicastRemoteObject类为例分析源码。
首先创建一个远程对象接口RemoteObject并继承Remote接口,在该接口中定义远程方法同时抛出RemoteException异常。创建远程对象实现类继承UnicastRemoteObject和实现RemoteObject接口中的方法。其构造方法可以使用无参构造,创建对象时会自动调用UnicastRemoteObject的无参构造方法。

protected UnicastRemoteObject() throws RemoteException
{
    this(0);
}
protected UnicastRemoteObject(int port) throws RemoteException
{
    this.port = port;
    exportObject((Remote) this, port);
}
public static Remote exportObject(Remote obj, int port)
    throws RemoteException
{
    return exportObject(obj, new UnicastServerRef(port));
}
private static Remote exportObject(Remote obj, UnicastServerRef sref)
    throws RemoteException
{
    // if obj extends UnicastRemoteObject, set its ref.
    if (obj instanceof UnicastRemoteObject) {
        ((UnicastRemoteObject) obj).ref = sref;
    }
    return sref.exportObject(obj, null, false);
}

前面都是类似的,创建UnicastServerRef对象,然后里面封装LiveRef对象等。由于这里设置的端口是0,在最后创建Serversocket的时候会自动分配一个随机端口。
然后UnicastServerRef.exportObject()方法。

public Remote exportObject(Remote impl, Object data,
                            boolean permanent)
    throws RemoteException
{
    Class<?> implClass = impl.getClass();
    Remote stub;

    try {
        stub = Util.createProxy(implClass, getClientRef(), forceStubUse);
    } catch (IllegalArgumentException e) {
        throw new ExportException(
            "remote object implements illegal remote interface", e);
    }
    if (stub instanceof RemoteStub) {
        setSkeleton(impl);
    }

    Target target =
        new Target(impl, this, stub, ref.getObjID(), permanent);
    ref.exportObject(target);
    hashToMethod_Map = hashToMethod_Maps.get(implClass);
    return stub;
}

又是这个熟悉的方法,这里调用Util.createProxy()创建远程对象的代理。

final ClassLoader loader = implClass.getClassLoader();
final Class<?>[] interfaces = getRemoteInterfaces(implClass);
final InvocationHandler handler =
    new RemoteObjectInvocationHandler(clientRef);
try {
    return AccessController.doPrivileged(new PrivilegedAction<Remote>() {
        public Remote run() {
            return (Remote) Proxy.newProxyInstance(loader,
                                                    interfaces,
                                                    handler);
        }});
} catch (IllegalArgumentException e) {
    throw new StubNotFoundException("unable to create proxy", e);
}

这里我就只把创建代理对象这一块拿过来了,上面的步骤和之前一样检查是否实现Remote接口,然后检查是否存在以_Stub结尾的类并创建。
这里可以就是正常的创建代理对象三要素(类加载器,代理对象接口,实现了InvocationHandler的代理类),这里跟进获取接口的方法中。

private static Class<?>[] getRemoteInterfaces(Class<?> remoteClass) {
    ArrayList<Class<?>> list = new ArrayList<>();
    getRemoteInterfaces(list, remoteClass);
    return list.toArray(new Class<?>[list.size()]);
}

先创建了一个Arraylist存储所有的接口。然后调用getRemoteInterfaces()方法获取remoteClass中Remote及其子类的接口,最后把ArrayList转为数组,继续跟进其方法。

private static void getRemoteInterfaces(ArrayList<Class<?>> list, Class<?> cl) {
        Class<?> superclass = cl.getSuperclass();
        if (superclass != null) {
            getRemoteInterfaces(list, superclass);
        }

        Class<?>[] interfaces = cl.getInterfaces();
        for (int i = 0; i < interfaces.length; i++) {
            Class<?> intf = interfaces[i];
            if (Remote.class.isAssignableFrom(intf)) {
                if (!(list.contains(intf))) {
                    Method[] methods = intf.getMethods();
                    for (int j = 0; j < methods.length; j++) {
                        checkMethod(methods[j]);
                    }
                    list.add(intf);
                }
            }
        }
    }

这里面可以看到先获取其父类,然后如果存在则继续调用该方法,一直循环调用,知道最后不存在父类,然后检查其接口是否是Remote及其子类,如果是则继续获取其接口的Method,然后调用checkMethod()对每个方法进行检查,主要就是检查该方法是否抛出了RemoteException异常,这里就不再跟进了。现在就解释了为什么远程对象必须要继承Remote接口,并且远程方法必须抛出RemoteException异常。
然后再回到UnicastServerRef中的exportObject方法,后面就是创建Target对象,然后继续导出,步骤就和前面一样了。

服务端绑定对象

服务端可以调用bind()方法绑定远程对象。

public void bind(String name, Remote obj)
    throws RemoteException, AlreadyBoundException, AccessException
{
    checkAccess("Registry.bind");
    synchronized (bindings) {
        Remote curr = bindings.get(name);
        if (curr != null)
            throw new AlreadyBoundException(name);
        bindings.put(name, obj);
    }
}

前面创建了各种各样的对象,感觉有点乱了,现在来简单理一下
LocateRegistry.createRegistry(1099)返回一个RegistryImpl对象,然后其属性ref <- UnicastServerRef <- LiveRef <- TCPEndpoint <- TCPTransport ,skel <- RegistryImpl_Skel,还创建一个RegistryImpl_Stub被封装到Target对象中最后存入静态变量ObjectTable.objTable。

然后在绑定的时候把远程对象实现类rmiObjctImpl保存在了RegistryImpl的bindings集合中,它的代理对象也被封装到target中然后保存在ObjectTable.objTable中。

客户端查询注册端Registry

现在整个服务端创建RegistryImpl和远程对象及其绑定所有过程都结束了,都还没有发起网络请求,都只是监听了端口。
下面当客户端执行到lookup方法才真正发起网络请求。因为后面的分析涉及到客户端和服务端都会有响应,分析可能会有点乱,我尽量把他说清楚。

public Remote lookup(String var1) throws AccessException, NotBoundException, RemoteException {
    try {
        RemoteCall var2 = super.ref.newCall(this, operations, 2, 4905912898345647071L);

        try {
            ObjectOutput var3 = var2.getOutputStream();
            var3.writeObject(var1);
        } catch (IOException var18) {
            throw new MarshalException("error marshalling arguments", var18);
        }

        super.ref.invoke(var2);

        Remote var23;
        try {
            ObjectInput var6 = var2.getInputStream();
            var23 = (Remote)var6.readObject();
        } 
        ...
        return var23;
    } 
    ...
}

大概看一下就整个函数里面主要就执行了两个函数super.ref.newCall()和super.ref.invoke(),newCall的返回值作为了invoke方法的参数,然后还有个readObject,其他都是try catch相关的东西。
先跟进newCall函数中。

public RemoteCall newCall(RemoteObject obj, Operation[] ops, int opnum,long hash) throws RemoteException
{
    Connection conn = ref.getChannel().newConnection();
    try {
        ...
        RemoteCall call =
            new StreamRemoteCall(conn, ref.getObjID(), opnum, hash);
        try {
            marshalCustomCallData(call.getOutputStream());
        } catch (IOException e) {
            throw new MarshalException("error marshaling " +
                                        "custom call data");
        }
        return call;
    } catch (RemoteException e) {
        ref.getChannel().free(conn, false);
        throw e;
    }
}

在newConnection函数中从freeList中找一个可使用connection,如果没有最后调用createConnection方法创建一个。创建后发送一个握手包,里面包含了版本信息之类的,然后服务端返回一个确认包,最后回到newCall方法中,将上面创建的conn封装到StreamRemoteCall对象中返回。

回到最开始looup方法中,将查询的对象名字调用writeObject序列化写入输出流中,然后调用super.ref.invoke()方法。该方法里面又调了call.executeCall()方法。注意这里调用的对象call是前面创建newCall创建的。

public void executeCall() throws Exception {
    byte returnType;
    try {
      ...
        releaseOutputStream();
        DataInputStream rd = new DataInputStream(conn.getInputStream());
        byte op = rd.readByte();
    ...
        getInputStream();
        returnType = in.readByte();
        in.readID();        // id for DGC acknowledgement
    } catch (UnmarshalException e) {
      ...
    }

    // read return value
    switch (returnType) {
    case TransportConstants.NormalReturn:
        break;

    case TransportConstants.ExceptionalReturn:
        Object ex;
        try {
            ex = in.readObject();
        } catch (Exception e) {
            throw new UnmarshalException("Error unmarshaling return", e);
        }
        ...
    default:
    ...
    }
}

在函数前面先调用了一个releaseOutputStream()方法释放了输出流,就是将刚刚序列化写入输出流的远程对象名发送了出去。然后解析返回包,如果成功发送则returnType是1,即TransportConstants.NormalReturn就直接break结束了。如果出现了TransportConstants.ExceptionalReturn异常,这里会调用一个readObject读取该异常对象,所以这里就是一个反序列化点。

这里结束后,如果没有任何异常回到lookup方法,最后就从返回的字节流中读取反序列化的对象了。这个对象实际上是前面服务端创建的代理对象。

下面分析一下在服务端的情况,首先定位到服务端的监听代码。前面注册中心最后调用的TCPTransport.exportObject中有一个listen()函数,这里面实现了注册中的socket监听。
下面跟进listen()

private void listen() throws RemoteException {
    assert Thread.holdsLock(this);
    TCPEndpoint ep = getEndpoint();
    int port = ep.getPort();

    if (server == null) {
        try {
            server = ep.newServerSocket();

            Thread t = AccessController.doPrivileged(
                new NewThreadAction(new AcceptLoop(server),
                                    "TCP Accept-" + port, true));
            t.start();
        } 
        ...

    } else {
      ...
    }
}

先调用newServerSocket方法,创建一个Serversocket,若传入的port是0,则系统会自动分配一个随机端口。然后开启了一个新线程并用上面创建的socket初始化了AcceptLoop,我们跟进AcceptLoop的run()方法.

public void run() {
    try {
        executeAcceptLoop();
    } finally {
        try {
            serverSocket.close();
        } catch (IOException e) {
        }
    }
}
private void executeAcceptLoop() {
            while (true) {
                Socket socket = null;
                try {
                    socket = serverSocket.accept();
                    InetAddress clientAddr = socket.getInetAddress();
                    String clientHost = (clientAddr != null
                                         ? clientAddr.getHostAddress()
                                         : "0.0.0.0");
                    try {
                        connectionThreadPool.execute(
                            new ConnectionHandler(socket, clientHost));
                    } catch (RejectedExecutionException e) {
                        ...
                    }

                } catch (Throwable t) {
                    ...
                    }
                    ...
                }
            }
        }

它里面又执行了executeAcceptLoop()方法,这里面终于看到了accept()函数,serversockt会停在这等待连接,当客户端连接后它又创建了一个新的线程,我们继续跟进ConnectionHandler。

public void run() {
    Thread t = Thread.currentThread();
    String name = t.getName();
    try {
      ...
        AccessController.doPrivileged((PrivilegedAction<Void>)() -> {
            run0();
            return null;
        }, NOPERMS_ACC);
    } finally {
        t.setName(name);
    }
}

run()方法中又调用了run0(),其实前面说了那么多真正处理socket请求的方法就是run0(),其实这也是JRMP的实现。JRMP就是rmi底层网络传输的协议,这个方法太长了,下面会分成几段来说。

private void run0() {
    TCPEndpoint endpoint = getEndpoint();
    int port = endpoint.getPort();
      ...
    try {
        InputStream sockIn = socket.getInputStream();
        InputStream bufIn = sockIn.markSupported()
                ? sockIn
                : new BufferedInputStream(sockIn);

        // Read magic (or HTTP wrapper)
        bufIn.mark(4);
        DataInputStream in = new DataInputStream(bufIn);
        int magic = in.readInt();

        if (magic == POST) {
          ...//一些http请求的处理
        }
        short version = in.readShort();
        if (magic != TransportConstants.Magic ||
            version != TransportConstants.Version) {
            closeSocket(socket);
            return;
        }

        OutputStream sockOut = socket.getOutputStream();
        BufferedOutputStream bufOut = new BufferedOutputStream(sockOut);
        DataOutputStream out = new DataOutputStream(bufOut);

        int remotePort = socket.getPort();
        ...

        TCPEndpoint ep;
        TCPChannel ch;
        TCPConnection conn;

        // send ack (or nack) for protocol
        byte protocol = in.readByte();

这是第一部分是前面客户端发送的第一个数据包的处理部分,可以结合wireshark抓包看一下。

可以看到上面代码中读取了三次输入流

int magic = in.readInt();
short version = in.readShort();
byte protocol = in.readByte();

就是对应的数据包的三个参数值。然后下面就根据protocal进入对应的case分支语句。

switch (protocol) {
    case TransportConstants.SingleOpProtocol:
        ...
    case TransportConstants.StreamProtocol:
        // send ack
        out.writeByte(TransportConstants.ProtocolAck);
        ...
        out.writeUTF(remoteHost);
        out.writeInt(remotePort);
        out.flush();

        String clientHost = in.readUTF();
        int    clientPort = in.readInt();

        ep = new TCPEndpoint(remoteHost, socket.getLocalPort(),
                              endpoint.getClientSocketFactory(),
                              endpoint.getServerSocketFactory());
        ch = new TCPChannel(TCPTransport.this, ep);
        conn = new TCPConnection(ch, socket, bufIn, bufOut);

        // read input messages
        handleMessages(conn, true);
        break;

    case TransportConstants.MultiplexProtocol:
        ...

    default:
        // protocol not understood, send nack and close socket
        out.writeByte(TransportConstants.ProtocolNack);
        out.flush();
        break;
    }

因为上面抓包中protocal是0x4b(75),对应的第二个case分支,所以我就把其他代码删了,在这个case里面,可以看到前面先发送了ack,包含host,port和一个ack标志。
然后下面它又重新封装了一个TCPConnection对象传入handleMessage函数中。

void handleMessages(Connection conn, boolean persistent) {
    int port = getEndpoint().getPort();

    try {
        DataInputStream in = new DataInputStream(conn.getInputStream());
        do {
            int op = in.read();     // transport op
            ...
            switch (op) {
            case TransportConstants.Call:
                // service incoming RMI call
                RemoteCall call = new StreamRemoteCall(conn);
                if (serviceCall(call) == false)
                    return;
                break;

            case TransportConstants.Ping:
                // send ack for ping
                DataOutputStream out = new DataOutputStream(conn.getOutputStream());
                out.writeByte(TransportConstants.PingAck);
                conn.releaseOutputStream();
                break;

            case TransportConstants.DGCAck:
                DGCAckHandler.received(UID.read(in));
                break;

            default:
                throw new IOException("unknown transport op " + op);
            }
        } while (persistent);

    } 
    ...
}

这个函数有点类似与一个请求分发器,读取客户端请求的操作码,然后进入对应分支,这里的请求是Call,所以进入第一个分支。先是创建了一个StreamRemoteCall对象,还记得这个对象在前面哪提到过吗,就是在客户端请求的lookup方法里面的调用的newCall函数里面也创建这个对象。然后继续跟进

public boolean serviceCall(final RemoteCall call) {
    try {
        /* read object id */
        final Remote impl;
        ObjID id;
    ...
        Transport transport = id.equals(dgcID) ? null : this;
        Target target = ObjectTable.getTarget(new ObjectEndpoint(id, transport));
    ...
        final Dispatcher disp = target.getDispatcher();
        target.incrementCallCount();
        try {
            /* call the dispatcher */
            transportLog.log(Log.VERBOSE, "call dispatcher");

            final AccessControlContext acc = target.getAccessControlContext();
            ClassLoader ccl = target.getContextClassLoader();

            ClassLoader savedCcl = Thread.currentThread().getContextClassLoader();

            try {
                setContextClassLoader(ccl);
                currentTransport.set(this);
                try {
                    java.security.AccessController.doPrivileged(
                        new java.security.PrivilegedExceptionAction<Void>() {
                        public Void run() throws IOException {
                            checkAcceptPermission(acc);
                            disp.dispatch(impl, call);
                            return null;
                        }
                    }, acc);
                } catch (java.security.PrivilegedActionException pae) {
                    throw (IOException) pae.getException();
                }
            } finally {
                setContextClassLoader(savedCcl);
                currentTransport.set(null);
            }
        } catch (IOException ex) {
            transportLog.log(Log.BRIEF,"exception thrown by dispatcher: ", ex);
            return false;
        } finally {
            target.decrementCallCount();
        }
    } catch (RemoteException e) {
        ...
    }
    return true;
}

前面我们在创建代理对象的时候每个stub最后都被封装到了Target对象中最后保存到了静态对象ObjectTable.objTable中。这个函数里面开始根据id,transport获取了RegistryImpl_Stub对应的target,然后下面获取dispatcher,实际上就是UnicastServerRef对象。下面设置了一些值和异常处理,然后调用了disp.dispatch(impl, call),impl是从target中获取的,call是前面传递的函数参数。

public void dispatch(Remote obj, RemoteCall call) throws IOException {
        int num;
        long op;

        try {
            // read remote call header
            ObjectInput in;e
            try {
                in = call.getInputStream();
                num = in.readInt();
                if (num >= 0) {
                    if (skel != null) {
                        oldDispatch(obj, call, num);
                        return;
                    } else {
                        throw new UnmarshalException("skeleton class not found but required " + "for client version");
                    }
                }
                op = in.readLong();
            } catch (Exception readEx) {
                throw new UnmarshalException("error unmarshalling call header",
                                             readEx);
            }
            ....

首先读取了call数据包中的操作码,判断客户端是查询对象还是绑定远程对象或者解绑等操作。这了显然是skel是不为空的,它是RegistryImpl_Skel对象,所以继续调用oldDispatch()。

public void oldDispatch(Remote obj, RemoteCall call, int op)
        throws IOException
{
    long hash;              // hash for matching stub with skeleton

    try {
        // read remote call header
        ObjectInput in;
        try {
            in = call.getInputStream();
            try {
                Class<?> clazz = Class.forName("sun.rmi.transport.DGCImpl_Skel");
                if (clazz.isAssignableFrom(skel.getClass())) {
                    ((MarshalInputStream)in).useCodebaseOnly();
                }
            } catch (ClassNotFoundException ignore) { }
            hash = in.readLong();
        } catch (Exception readEx) {
            throw new UnmarshalException("error unmarshalling call header",
                                          readEx);
        }

        logCall(obj, skel.getOperations()[op]);
        unmarshalCustomCallData(in);
        // dispatch to skeleton for remote object
        skel.dispatch(obj, call, op, hash);

    } catch (Throwable e) {
        ...
    } finally {
        call.releaseInputStream(); // in case skeleton doesn't
        call.releaseOutputStream();
    }
}

这个函数里面先做了一些判断skel等,然后打印日志等。后面又调用了skel.dispatch(obj, call, op, hash)。

解释一下各个参数值,obj就是前面从target中获取的RegistryImpl_Stub,call是前面创建的客户端连接,op是前面读取的操作数,hash是读取的序列化对象的hash值,用于在反序列化前判断。

public void dispatch(Remote var1, RemoteCall var2, int var3, long var4) throws Exception {
    if (var4 != 4905912898345647071L) {
        throw new SkeletonMismatchException("interface hash mismatch");
    } else {
        RegistryImpl var6 = (RegistryImpl)var1;
        String var7;
        Remote var8;
        ObjectInput var10;
        ObjectInput var11;
        switch(var3) {
        case 0:
            ...
        case 1:
            ...
        case 2:
            try {
                var10 = var2.getInputStream();
                var7 = (String)var10.readObject();
            } catch (IOException var89) {
                throw new UnmarshalException("error unmarshalling arguments", var89);
            } catch (ClassNotFoundException var90) {
                throw new UnmarshalException("error unmarshalling arguments", var90);
            } finally {
                var2.releaseInputStream();
            }

            var8 = var6.lookup(var7);

            try {
                ObjectOutput var9 = var2.getResultStream(true);
                var9.writeObject(var8);
                break;
            } catch (IOException var88) {
                throw new MarshalException("error marshalling return", var88);
            }
        case 3:
            ...
        case 4:
            ...
        default:
            throw new UnmarshalException("invalid method number");
        }

    }
}

这里才是真正获取到远程对象的地方,这个类没有源码,只有class文件反编译的代码,也不能调试,所以代码不太好看。因为查询对象主要是第二个case分支,所以我就把其他代码删了。看case2的代码逻辑,它先从输入流中读取了一个对象,其实就是客户端序列化写入的远程对象名字的字符串,var6就是RegistryImpl,然后调用它的lookup从bindings中获取到远程对象。然后下面写入输出流中,最后回到上面的oldDispatch中的finnally语句中将输出流中的数据发送出去。

看到这大家应该就大概明白了客户端是怎么获取到远程对象的了,但有人细心调试后可能会发现一点猫腻,我们服务端绑定到bindings的是一个远程对象,我们这读取的到的也是远程对象,最后客户端获取到的对象怎么变成了它的代理对象。

我们知道我们服务端在创建远程对象的过程中还会调用Util.createProxy()创建了一个代理对象,这个代理对象最后被封装到了target对象中,然后存入ObjectTable.objTable静态变量中。我们继续跟进上面的writeObject方法,看看里面是怎么写入对象的。

可以看到这里它又调用了readObject0(),我们继续跟进。

最后我们发现当它执行到这里调用了replaceObject(obj),然后返回了其对应的代理对象。我们猜测这个方法可能被重写了,我们可以看到下面调试框中显示的当前对象this实际上是ConnectionOutputStream,我们跟进这个对象中找到这个方法。

最后我们在它的子类中找到了这个重写的方法,发现它这里使用远程对象在ObjectTable中查找了其对应的代理对象,这我们就知道了为什么我们客户端获取到的是代理对象。

客户端执行代理对象

上面说了客户端获取远程对象时客户端和服务端的行为,下面继续说一下客户端在获取到代理对象后执行函数时的代码。
上面我们所有的交互都是在和注册中心1099端口交互,现在我们获取到了远程对象的ip和端口,如果我们知道远程对象的ip和端口我们也可以不访问注册中心,直接访问远程对象。
我们前面说了远程对象的创建过程,知道它执行的invoke方法在RemoteObjectInvocationHandler类中,我们可以跟进看一下。

public Object invoke(Object proxy, Method method, Object[] args) throws Throwable
{
    if (! Proxy.isProxyClass(proxy.getClass())) {
        throw new IllegalArgumentException("not a proxy");
    }

    if (Proxy.getInvocationHandler(proxy) != this) {
        throw new IllegalArgumentException("handler mismatch");
    }

    if (method.getDeclaringClass() == Object.class) {
        return invokeObjectMethod(proxy, method, args);
    } else if ("finalize".equals(method.getName()) && method.getParameterCount() == 0 &&
        !allowFinalizeInvocation) {
        return null; // ignore
    } else {
        return invokeRemoteMethod(proxy, method, args);
    }
}

可以看到前面做了一些判断,然后判断调用的方法是否存在Object对象中(如hashcode,toString等),这些方法可以就在本地调用。其他的方法就调用invokeRemoteMethod(proxy, method, args)实现远程调用。跟进invokeRemoteMethod()方法可以看到它主要调用了UnicastRef.invoke()方法,继续跟进。

public Object invoke(Remote obj,Method method,Object[] params,long opnum) throws Exception
{
    ...
    Connection conn = ref.getChannel().newConnection();
    RemoteCall call = null;
    boolean reuse = true;
    boolean alreadyFreed = false;

    try {
      ...
        // create call context
        call = new StreamRemoteCall(conn, ref.getObjID(), -1, opnum);

        // marshal parameters
        try {
            ObjectOutput out = call.getOutputStream();
            marshalCustomCallData(out);
            Class<?>[] types = method.getParameterTypes();
            for (int i = 0; i < types.length; i++) {
                marshalValue(types[i], params[i], out);
            }
        } 
        ...
        // unmarshal return
        call.executeCall();

        try {
            Class<?> rtype = method.getReturnType();
            if (rtype == void.class)
                return null;
            ObjectInput in = call.getInputStream();
            Object returnValue = unmarshalValue(rtype, in);
            alreadyFreed = true;
            clientRefLog.log(Log.BRIEF, "free connection (reuse = true)");
            ref.getChannel().free(conn, true);
            return returnValue;
        } 
        ...
}

在这个函数中使用JRMP协议调用远程对象的方法,协议交互过程和lookup类似,先创建一个连接,自定义握手过程,然后将param通过marshalValue()方法序列化写入输出流,然后调用call.executeCall()将参数发送给服务端,然后判断返回值,如果是void就直接返回null,本次调用结束,否则调用unmarshalValue()获取返回值最后释放连接返回结果。

服务端的过程和前面lookup大致相同,只是在UnicastServerRef.dispatch()方法调用的过程中,判断的是否存在skel,如果存在,则调用oldDispatch,这就是上面lookup的逻辑,当方法调用时是远程对象的不存在其对应的以_Skel结尾的对象,所以这里判断结果为假,然后继续向下执行。下面就是方法的调用过程,基本上也就是上面客户端的逆过程,就不再分析代码了。

总结

总的来说,RMI整个过程主要涉及的知识点有Socket和动态代理。在服务端每个远程对象(RemoteObject)都会监听一个端口,同时创建一个代理对象(Stub)。注册中心可以说是一个特殊的远程对象。因为其他远程对象的端口实在创建过程中系统随机分配的,客户端只知道注册中心的端口,然后先请求注册中心,他们通过JRMP协议交互进行交互,获取对应远程对象的代理对象。当客户端获取到代理对象后就不会再和注册中心交互了,因为获取到的代理对象中包括了远程对象的监听端口等属性值,所以后面就可以通过代理对象访问远程对象最后将执行结果返回给客户端。在这个过程中的数据处理主要以玩序列化方式传输,所以可能导致反序列化漏洞。

对于客户端而言反序列化漏洞的利用点主要有下面几个地方

  • StreamRemoteCall.executeCall()中异常读取(line:245)
  • RegistryImpl_Stub.lookup()读取服务端返回的代理对象(line:104)
  • 客户端读取服务端执行结果返回值(UnicastRef.unmarshalValue:302)

对于服务端而言可能产生反序列化漏洞的利用点主要有下面几个地方

  • 服务端读取客户端(查询/绑定/解绑)对象字符串(RegistryImpl_Skel.dispatch)
  • 服务端读取客户端远程方法的参数值(UnicastRef.unmarshalValue:302)

参考链接

JAVA安全基础(四)-- RMI机制


文章来源: https://xz.aliyun.com/t/12254
如有侵权请联系:admin#unsafe.sh