멀티프로세싱에서 메모리 관리의 이해

멀티프로세싱에서 메모리 관리의 이해
Photo by Fredy Jacob / Unsplash

이 글을 쓰게 된 배경이 되는 사건의 발단은 다음과 같은데, accelerate를 이용해 8개 프로세스를 만들고 num_workers를 8로 설정했더니 100기가바이트에 달하는 엄청난 RAM을 먹는 문제가 관찰되었다.

그 이유는 Dataset 객체를 8개 프로세스당 8개 워커가 만들어, 64개의 Dataset 객체가 메모리에 상주하고 내가 사용하는 데이터셋이 인덱싱 역할하는 파일만 해도 용량이 상당히 큰 편이라 그랬다. 그럼 여기서 드는 의문. 멀티프로세싱 과정에서 객체는 어떻게 메모리에서 관리되는 것일까?

먼저, 프로세스간에 통신할 때 파이썬에서는 pickle을 통해 데이터를 직렬화/역직렬화한다. 간단히 말하면 직렬화는 메모리 여기저기 흩어져 있는 객체의 정보를 모아 바이트스트림 형태로 전송할 수 있도록 해주는 것이다. 이에 관련해서는 https://docs.python.org/3/library/multiprocessing.html#pipes-and-queueshttps://hyperconnect.github.io/2023/05/30/Python-Performance-Tips.html 와 같은 게시글을 참고하면 좋다.

또한 멀티프로세스 라이브러리에서는 크게 두 가지 방법으로 프로세스를 생성한다. 한 가지 방법은 fork, 다른 방법은 spawn으로, python의 multiprocessing 라이브러리에서 기본값은 리눅스에서 fork, 윈도우즈에서 spawn이다. spawn은 실행하는 파이썬 파일을 여러 번 똑같이 실행하는 것과 비슷하게 작동하며 fork는 현재 메모리에 올라가 있는 현재 프로세스의 정보를 복사하는 방식으로 작동한다. 이런 작동 방법의 차이로 인해 윈도우즈에서 멀티프로세싱을 사용하려면 ifmain을 설정해줘야 하는 등 이슈가 있다.

또, fork는 메모리에서 파일을 복사한다고 말했는데, 이 과정에서 프로세스간의 통신이 일어나고, pickle이 사용된다. 또한, 어떤 객체가 pickle당할 때는 getstate, setstate 메서드(물론 앞뒤에 언더바가 붙는다)가 호출된다. (참고: https://stackoverflow.com/questions/1939058/simple-example-of-use-of-setstate-and-getstate ) 어떻게 작동하는지는 아래 예제를 실행해보면 바로 알 수 있다.

import multiprocessing as mp


class Test:
    def __init__(self):
        self.x = 1
        print("init")

    def __call__(self, x):
        return x + self.x

    def __getstate__(self):
        print("get state")
        return self.__dict__

    def __setstate__(self, state):
        print("set state")
        self.__dict__ = state


if __name__ == "__main__":
    pool = mp.Pool()
    test = Test()
    print(test(1))
    print(pool.map(test, [1, 2, 3]))
    pool.close()

멀티프로세싱 라이브러리인 ray에서는 ray.put을 이용해 큰 객체를 object store에 저장해 메모리를 아끼는데 사실 이건 함수 인자로 전해줄 때 이야기다. (마침 ray가 나왔으니 서술하자면 지금까지 멀티프로세싱 라이브러리로 multiprocessing, ray, joblib을 써봤는데 다들 장단점이 있는 듯하다.)

다 좋은데, 조금 주의해야 할 점은 여러 프로세스가 같은 핸들(특히 파일 입출력)을 사용할 때이다. https://github.com/jotaf98/simple-tar-dataset/blob/master/tardataset.py#L64C9-L64C9 을 참고하면 아래와 같이 소스코드가 구성되어 있는데, 워커마다 다른 파일 핸들을 쓰도록 만든다! 메인 워커에서 init을 실행하며 TarDataset을 만들면 이를 워커 개수만큼 fork 방식으로 복사하기 때문에, init에서 하나의 파일 핸들만을 만들면 모든 워커가 하나의 파일 핸들을 공유하게 된다. TarFile이 thread-safe하지 않기에 이런 경우 주의해야 한다.

class TarDataset(Dataset):
  def __init__(self, archive, transform=to_tensor, extensions=('.png', '.jpg', '.jpeg'),
    is_valid_file=None, ignore_unexpected_eof=False):
    if not isinstance(archive, TarDataset):
      # open tar file. in a multiprocessing setting (e.g. DataLoader workers), we
      # have to open one file handle per worker (stored as the tar_obj dict), since
      # when the multiprocessing method is 'fork', the workers share this TarDataset.
      # we want one file handle per worker because TarFile is not thread-safe.
      worker = get_worker_info()
      worker = worker.id if worker else None
      self.tar_obj = {worker: tarfile.open(archive) if ignore_unexpected_eof is False else UnexpectedEOFTarFile.open(archive)}
      self.archive = archive

      # store headers of all files and folders by name
      members = sorted(self.tar_obj[worker].getmembers(), key=lambda m: m.name)
      self.members_by_name = {m.name: m for m in members}